pw_rpc: Make service impls derived classes
This updates the generated RPC service class to act as a (non-virtual)
base which calls into a user-implemented derived class. This allows
users to extend generated service classes with custom members and
dependencies, while keeping the core static method dispatch structure.
Change-Id: I9ea6327790071a26fc940e3ac89b5bea2a2c4495
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/14123
Reviewed-by: Wyatt Hepler <hepler@google.com>
Commit-Queue: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/nanopb/codegen_test.cc b/pw_rpc/nanopb/codegen_test.cc
index 1cd40f0..5a717fa 100644
--- a/pw_rpc/nanopb/codegen_test.cc
+++ b/pw_rpc/nanopb/codegen_test.cc
@@ -20,23 +20,25 @@
namespace pw::rpc {
namespace test {
-Status TestService::TestRpc(ServerContext&,
- const pw_rpc_test_TestRequest& request,
- pw_rpc_test_TestResponse& response) {
- response.value = request.integer + 1;
- return static_cast<Status::Code>(request.status_code);
-}
-
-void TestService::TestStreamRpc(
- ServerContext&,
- const pw_rpc_test_TestRequest& request,
- ServerWriter<pw_rpc_test_TestStreamResponse>& writer) {
- for (int i = 0; i < request.integer; ++i) {
- writer.Write({.number = static_cast<uint32_t>(i)});
+class TestServiceImpl : public TestService<TestServiceImpl> {
+ public:
+ Status TestRpc(ServerContext&,
+ const pw_rpc_test_TestRequest& request,
+ pw_rpc_test_TestResponse& response) {
+ response.value = request.integer + 1;
+ return static_cast<Status::Code>(request.status_code);
}
- writer.Finish(static_cast<Status::Code>(request.status_code));
-}
+ void TestStreamRpc(ServerContext&,
+ const pw_rpc_test_TestRequest& request,
+ ServerWriter<pw_rpc_test_TestStreamResponse>& writer) {
+ for (int i = 0; i < request.integer; ++i) {
+ writer.Write({.number = static_cast<uint32_t>(i)});
+ }
+
+ writer.Finish(static_cast<Status::Code>(request.status_code));
+ }
+};
} // namespace test
@@ -44,13 +46,13 @@
namespace {
TEST(NanopbCodegen, CompilesProperly) {
- test::TestService service;
+ test::TestServiceImpl service;
EXPECT_EQ(service.id(), Hash("pw.rpc.test.TestService"));
EXPECT_STREQ(service.name(), "TestService");
}
TEST(NanopbCodegen, InvokeUnaryRpc) {
- PW_RPC_TEST_METHOD_CONTEXT(test::TestService, TestRpc) context;
+ PW_RPC_TEST_METHOD_CONTEXT(test::TestServiceImpl, TestRpc) context;
EXPECT_EQ(Status::OK,
context.call({.integer = 123, .status_code = Status::OK}));
@@ -64,7 +66,7 @@
}
TEST(NanopbCodegen, InvokeStreamingRpc) {
- PW_RPC_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc) context;
+ PW_RPC_TEST_METHOD_CONTEXT(test::TestServiceImpl, TestStreamRpc) context;
context.call({.integer = 0, .status_code = Status::ABORTED});
@@ -86,7 +88,7 @@
}
TEST(NanopbCodegen, InvokeStreamingRpc_ContextKeepsFixedNumberOfResponses) {
- PW_RPC_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc, 3) context;
+ PW_RPC_TEST_METHOD_CONTEXT(test::TestServiceImpl, TestStreamRpc, 3) context;
ASSERT_EQ(3u, context.responses().max_size());
diff --git a/pw_rpc/nanopb/method.cc b/pw_rpc/nanopb/method.cc
index 87997b5..a096ca8 100644
--- a/pw_rpc/nanopb/method.cc
+++ b/pw_rpc/nanopb/method.cc
@@ -64,8 +64,7 @@
return;
}
- const Status status =
- function_.unary(call.context(), request_struct, response_struct);
+ const Status status = function_.unary(call, request_struct, response_struct);
SendResponse(call.channel(), request, response_struct, status);
}
@@ -77,7 +76,7 @@
}
internal::BaseServerWriter server_writer(call);
- function_.server_streaming(call.context(), request_struct, server_writer);
+ function_.server_streaming(call, request_struct, server_writer);
}
bool Method::DecodeRequest(Channel& channel,
diff --git a/pw_rpc/nanopb/method_test.cc b/pw_rpc/nanopb/method_test.cc
index 350dec2..bbdf753 100644
--- a/pw_rpc/nanopb/method_test.cc
+++ b/pw_rpc/nanopb/method_test.cc
@@ -47,21 +47,31 @@
return std::as_bytes(buffer.first(output.bytes_written));
}
+template <typename Implementation>
class FakeGeneratedService : public Service {
public:
constexpr FakeGeneratedService(uint32_t id) : Service(id, kMethods) {}
- static Status DoNothing(ServerContext&,
- const pw_rpc_test_Empty&,
- pw_rpc_test_Empty&);
+ static Status DoNothing(ServerCall& call,
+ const pw_rpc_test_Empty& request,
+ pw_rpc_test_Empty& response) {
+ return static_cast<Implementation&>(call.service())
+ .DoNothing(call.context(), request, response);
+ }
- static Status AddFive(ServerContext&,
- const pw_rpc_test_TestRequest&,
- pw_rpc_test_TestResponse&);
+ static Status AddFive(ServerCall& call,
+ const pw_rpc_test_TestRequest& request,
+ pw_rpc_test_TestResponse& response) {
+ return static_cast<Implementation&>(call.service())
+ .AddFive(call.context(), request, response);
+ }
- static void StartStream(ServerContext&,
- const pw_rpc_test_TestRequest&,
- ServerWriter<pw_rpc_test_TestResponse>&);
+ static void StartStream(ServerCall& call,
+ const pw_rpc_test_TestRequest& request,
+ ServerWriter<pw_rpc_test_TestResponse>& writer) {
+ static_cast<Implementation&>(call.service())
+ .StartStream(call.context(), request, writer);
+ }
static constexpr std::array<Method, 3> kMethods = {
Method::Unary<DoNothing>(
@@ -76,34 +86,38 @@
pw_rpc_test_TestRequest last_request;
ServerWriter<pw_rpc_test_TestResponse> last_writer;
-Status FakeGeneratedService::AddFive(ServerContext&,
- const pw_rpc_test_TestRequest& request,
- pw_rpc_test_TestResponse& response) {
- last_request = request;
- response.value = request.integer + 5;
- return Status::UNAUTHENTICATED;
-}
+class FakeGeneratedServiceImpl
+ : public FakeGeneratedService<FakeGeneratedServiceImpl> {
+ public:
+ FakeGeneratedServiceImpl(uint32_t id) : FakeGeneratedService(id) {}
-Status FakeGeneratedService::DoNothing(ServerContext&,
- const pw_rpc_test_Empty&,
- pw_rpc_test_Empty&) {
- return Status::UNKNOWN;
-}
+ Status AddFive(ServerContext&,
+ const pw_rpc_test_TestRequest& request,
+ pw_rpc_test_TestResponse& response) {
+ last_request = request;
+ response.value = request.integer + 5;
+ return Status::UNAUTHENTICATED;
+ }
-void FakeGeneratedService::StartStream(
- ServerContext&,
- const pw_rpc_test_TestRequest& request,
- ServerWriter<pw_rpc_test_TestResponse>& writer) {
- last_request = request;
+ Status DoNothing(ServerContext&,
+ const pw_rpc_test_Empty&,
+ pw_rpc_test_Empty&) {
+ return Status::UNKNOWN;
+ }
- last_writer = std::move(writer);
-}
+ void StartStream(ServerContext&,
+ const pw_rpc_test_TestRequest& request,
+ ServerWriter<pw_rpc_test_TestResponse>& writer) {
+ last_request = request;
+ last_writer = std::move(writer);
+ }
+};
TEST(Method, UnaryRpc_SendsResponse) {
ENCODE_PB(pw_rpc_test_TestRequest, {.integer = 123}, request);
- const Method& method = std::get<1>(FakeGeneratedService::kMethods);
- ServerContextForTest<FakeGeneratedService> context(method);
+ const Method& method = std::get<1>(FakeGeneratedServiceImpl::kMethods);
+ ServerContextForTest<FakeGeneratedServiceImpl> context(method);
method.Invoke(context.get(), context.packet(request));
const Packet& response = context.output().sent_packet();
@@ -123,8 +137,8 @@
TEST(Method, UnaryRpc_InvalidPayload_SendsError) {
std::array<byte, 8> bad_payload{byte{0xFF}, byte{0xAA}, byte{0xDD}};
- const Method& method = std::get<0>(FakeGeneratedService::kMethods);
- ServerContextForTest<FakeGeneratedService> context(method);
+ const Method& method = std::get<0>(FakeGeneratedServiceImpl::kMethods);
+ ServerContextForTest<FakeGeneratedServiceImpl> context(method);
method.Invoke(context.get(), context.packet(bad_payload));
const Packet& packet = context.output().sent_packet();
@@ -138,9 +152,9 @@
constexpr int64_t value = 0x7FFFFFFF'FFFFFF00ll;
ENCODE_PB(pw_rpc_test_TestRequest, {.integer = value}, request);
- const Method& method = std::get<1>(FakeGeneratedService::kMethods);
+ const Method& method = std::get<1>(FakeGeneratedServiceImpl::kMethods);
// Output buffer is too small for the response, but can fit an error packet.
- ServerContextForTest<FakeGeneratedService, 22> context(method);
+ ServerContextForTest<FakeGeneratedServiceImpl, 22> context(method);
ASSERT_LT(context.output().buffer_size(),
context.packet(request).MinEncodedSizeBytes() + request.size() + 1);
@@ -158,8 +172,8 @@
TEST(Method, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) {
ENCODE_PB(pw_rpc_test_TestRequest, {.integer = 555}, request);
- const Method& method = std::get<2>(FakeGeneratedService::kMethods);
- ServerContextForTest<FakeGeneratedService> context(method);
+ const Method& method = std::get<2>(FakeGeneratedServiceImpl::kMethods);
+ ServerContextForTest<FakeGeneratedServiceImpl> context(method);
method.Invoke(context.get(), context.packet(request));
@@ -168,8 +182,8 @@
}
TEST(Method, ServerWriter_SendsResponse) {
- const Method& method = std::get<2>(FakeGeneratedService::kMethods);
- ServerContextForTest<FakeGeneratedService> context(method);
+ const Method& method = std::get<2>(FakeGeneratedServiceImpl::kMethods);
+ ServerContextForTest<FakeGeneratedServiceImpl> context(method);
method.Invoke(context.get(), context.packet({}));
@@ -188,14 +202,14 @@
}
TEST(Method, ServerStreamingRpc_ServerWriterBufferTooSmall_InternalError) {
- const Method& method = std::get<2>(FakeGeneratedService::kMethods);
+ const Method& method = std::get<2>(FakeGeneratedServiceImpl::kMethods);
constexpr size_t kNoPayloadPacketSize = 2 /* type */ + 2 /* channel */ +
5 /* service */ + 5 /* method */ +
2 /* payload */ + 2 /* status */;
// Make the buffer barely fit a packet with no payload.
- ServerContextForTest<FakeGeneratedService, kNoPayloadPacketSize> context(
+ ServerContextForTest<FakeGeneratedServiceImpl, kNoPayloadPacketSize> context(
method);
// Verify that the encoded size of a packet with an empty payload is correct.
diff --git a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
index 956c4ef..4cd8303 100644
--- a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
+++ b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
@@ -57,8 +57,7 @@
// Specialization for unary RPCs.
template <typename RequestType, typename ResponseType>
-struct RpcTraits<Status (*)(
- ServerContext&, const RequestType&, ResponseType&)> {
+struct RpcTraits<Status (*)(ServerCall&, const RequestType&, ResponseType&)> {
using Request = RequestType;
using Response = ResponseType;
@@ -70,7 +69,7 @@
// Specialization for server streaming RPCs.
template <typename RequestType, typename ResponseType>
struct RpcTraits<void (*)(
- ServerContext&, const RequestType&, ServerWriter<ResponseType>&)> {
+ ServerCall&, const RequestType&, ServerWriter<ResponseType>&)> {
using Request = RequestType;
using Response = ResponseType;
@@ -79,6 +78,20 @@
static constexpr bool kClientStreaming = false;
};
+// Member function specialization for unary RPCs.
+template <typename T, typename RequestType, typename ResponseType>
+struct RpcTraits<Status (T::*)(
+ ServerContext&, const RequestType&, ResponseType&)>
+ : public RpcTraits<Status (*)(
+ ServerCall&, const RequestType&, ResponseType&)> {};
+
+// Member function specialization for server streaming RPCs.
+template <typename T, typename RequestType, typename ResponseType>
+struct RpcTraits<void (T::*)(
+ ServerContext&, const RequestType&, ServerWriter<ResponseType>&)>
+ : public RpcTraits<void (*)(
+ ServerCall&, const RequestType&, ServerWriter<ResponseType>&)> {};
+
template <auto method>
using Request = typename RpcTraits<decltype(method)>::Request;
@@ -109,9 +122,9 @@
// In optimized builds, the compiler inlines the user-defined function into
// this wrapper, elminating any overhead.
return Method({.unary =
- [](ServerContext& ctx, const void* req, void* resp) {
+ [](ServerCall& call, const void* req, void* resp) {
return method(
- ctx,
+ call,
*static_cast<const Request<method>*>(req),
*static_cast<Response<method>*>(resp));
}},
@@ -133,8 +146,8 @@
// union, defined below.
return Method(
{.server_streaming =
- [](ServerContext& ctx, const void* req, BaseServerWriter& resp) {
- method(ctx,
+ [](ServerCall& call, const void* req, BaseServerWriter& resp) {
+ method(call,
*static_cast<const Request<method>*>(req),
static_cast<ServerWriter<Response<method>>&>(resp));
}},
@@ -163,17 +176,17 @@
private:
// Generic version of the unary RPC function signature:
//
- // Status(ServerContext&, const Request&, Response&)
+ // Status(ServerCall&, const Request&, Response&)
//
- using UnaryFunction = Status (*)(ServerContext&,
+ using UnaryFunction = Status (*)(ServerCall&,
const void* request,
void* response);
// Generic version of the server streaming RPC function signature:
//
- // Status(ServerContext&, const Request&, ServerWriter<Response>&)
+ // Status(ServerCall&, const Request&, ServerWriter<Response>&)
//
- using ServerStreamingFunction = void (*)(ServerContext&,
+ using ServerStreamingFunction = void (*)(ServerCall&,
const void* request,
BaseServerWriter& writer);
diff --git a/pw_rpc/public/pw_rpc/test_method_context.h b/pw_rpc/public/pw_rpc/test_method_context.h
index ea5dc17..b0fdfc8 100644
--- a/pw_rpc/public/pw_rpc/test_method_context.h
+++ b/pw_rpc/public/pw_rpc/test_method_context.h
@@ -68,23 +68,30 @@
::pw::rpc::test_internal::ServiceTestUtilities< \
service, \
::pw::rpc::internal::Hash(#method_name)>, \
- service::method_name PW_COMMA_ARGS(__VA_ARGS__)>
+ &service::method_name PW_COMMA_ARGS(__VA_ARGS__)>
// Internal classes that implement PW_RPC_TEST_METHOD_CONTEXT.
namespace pw::rpc::test_internal {
+// Identifies a base class from a member function it defines. This should be
+// used with decltype to retrieve the base class.
+template <typename T, typename U>
+T BaseFromMember(U T::*);
+
// Finds the method object in a service at compile time. This class friended by
// the generated service classes to give it access to the internal method list.
template <typename ServiceType, uint32_t method_hash>
class ServiceTestUtilities {
public:
+ using BaseService =
+ decltype(BaseFromMember(&ServiceType::_PwRpcInternalGeneratedBase));
using Service = ServiceType;
static constexpr const internal::Method& method() { return *FindMethod(); }
private:
static constexpr const internal::Method* FindMethod() {
- for (const internal::Method& method : Service::kMethods) {
+ for (const internal::Method& method : BaseService::kMethods) {
if (method.id() == method_hash) {
return &method;
}
@@ -177,7 +184,8 @@
ctx_.output.clear();
ctx_.responses.emplace_back();
ctx_.responses.back() = {};
- return function(ctx_.call.context(), request, ctx_.responses.back());
+ return (ctx_.service.*function)(
+ ctx_.call.context(), request, ctx_.responses.back());
}
// Gives access to the RPC's response.
@@ -204,9 +212,10 @@
void call(const Request& request) {
ctx_.output.clear();
internal::BaseServerWriter server_writer(ctx_.call);
- function(ctx_.call.context(),
- request,
- static_cast<ServerWriter<Response>&>(server_writer));
+ return (ctx_.service.*function)(
+ ctx_.call.context(),
+ request,
+ static_cast<ServerWriter<Response>&>(server_writer));
}
// Returns the responses that have been recorded. The maximum number of
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index 7512cce..6e8b997 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -75,21 +75,32 @@
req_type = method.request_type().nanopb_name()
res_type = method.response_type().nanopb_name()
+ implementation_cast = 'static_cast<Implementation&>(call.service())'
output.write_line()
if method.type() == ProtoServiceMethod.Type.UNARY:
- output.write_line(f'static ::pw::Status {method.name()} (')
+ output.write_line(f'static ::pw::Status {method.name()}(')
with output.indent(4):
- output.write_line('ServerContext& ctx,')
+ output.write_line('::pw::rpc::internal::ServerCall& call,')
output.write_line(f'const {req_type}& request,')
- output.write_line(f'{res_type}& response);')
+ output.write_line(f'{res_type}& response) {{')
+ with output.indent():
+ output.write_line(f'return {implementation_cast}')
+ output.write_line(
+ f' .{method.name()}(call.context(), request, response);')
+ output.write_line('}')
elif method.type() == ProtoServiceMethod.Type.SERVER_STREAMING:
- output.write_line(f'static void {method.name()} (')
+ output.write_line(f'static void {method.name()}(')
with output.indent(4):
- output.write_line('ServerContext& ctx,')
+ output.write_line('::pw::rpc::internal::ServerCall& call,')
output.write_line(f'const {req_type}& request,')
- output.write_line(f'ServerWriter<{res_type}>& writer);')
+ output.write_line(f'ServerWriter<{res_type}>& writer) {{')
+ with output.indent():
+ output.write_line(implementation_cast)
+ output.write_line(
+ f' .{method.name()}(call.context(), request, writer);')
+ output.write_line('}')
else:
raise NotImplementedError(
'Only unary and server streaming RPCs are currently supported')
@@ -100,8 +111,9 @@
"""Generates a C++ derived class for a nanopb RPC service."""
base_class = f'{RPC_NAMESPACE}::internal::Service'
+ output.write_line('\ntemplate <typename Implementation>')
output.write_line(
- f'\nclass {service.cpp_namespace(root)} : public {base_class} {{')
+ f'class {service.cpp_namespace(root)} : public {base_class} {{')
output.write_line(' public:')
with output.indent():
@@ -124,8 +136,10 @@
output.write_line(f'static constexpr const char* name() '
f'{{ return "{service.name()}"; }}')
- for method in service.methods():
- _generate_code_for_method(method, output)
+ output.write_line()
+ output.write_line('// Used in test code to identify a base service.')
+ output.write_line(
+ 'constexpr void _PwRpcInternalGeneratedBase() const {}')
service_name_hash = pw_rpc.ids.calculate(service.proto_path())
output.write_line('\n private:')
@@ -135,8 +149,11 @@
output.write_line(
f'static constexpr uint32_t kServiceId = {hex(service_name_hash)};'
)
- output.write_line()
+ for method in service.methods():
+ _generate_code_for_method(method, output)
+
+ output.write_line()
output.write_line(
f'static constexpr std::array<{RPC_NAMESPACE}::internal::Method,'
f' {len(service.methods())}> kMethods = {{')