pw_rpc: Make service impls derived classes

This updates the generated RPC service class to act as a (non-virtual)
base which calls into a user-implemented derived class. This allows
users to extend generated service classes with custom members and
dependencies, while keeping the core static method dispatch structure.

Change-Id: I9ea6327790071a26fc940e3ac89b5bea2a2c4495
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/14123
Reviewed-by: Wyatt Hepler <hepler@google.com>
Commit-Queue: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/nanopb/codegen_test.cc b/pw_rpc/nanopb/codegen_test.cc
index 1cd40f0..5a717fa 100644
--- a/pw_rpc/nanopb/codegen_test.cc
+++ b/pw_rpc/nanopb/codegen_test.cc
@@ -20,23 +20,25 @@
 namespace pw::rpc {
 namespace test {
 
-Status TestService::TestRpc(ServerContext&,
-                            const pw_rpc_test_TestRequest& request,
-                            pw_rpc_test_TestResponse& response) {
-  response.value = request.integer + 1;
-  return static_cast<Status::Code>(request.status_code);
-}
-
-void TestService::TestStreamRpc(
-    ServerContext&,
-    const pw_rpc_test_TestRequest& request,
-    ServerWriter<pw_rpc_test_TestStreamResponse>& writer) {
-  for (int i = 0; i < request.integer; ++i) {
-    writer.Write({.number = static_cast<uint32_t>(i)});
+class TestServiceImpl : public TestService<TestServiceImpl> {
+ public:
+  Status TestRpc(ServerContext&,
+                 const pw_rpc_test_TestRequest& request,
+                 pw_rpc_test_TestResponse& response) {
+    response.value = request.integer + 1;
+    return static_cast<Status::Code>(request.status_code);
   }
 
-  writer.Finish(static_cast<Status::Code>(request.status_code));
-}
+  void TestStreamRpc(ServerContext&,
+                     const pw_rpc_test_TestRequest& request,
+                     ServerWriter<pw_rpc_test_TestStreamResponse>& writer) {
+    for (int i = 0; i < request.integer; ++i) {
+      writer.Write({.number = static_cast<uint32_t>(i)});
+    }
+
+    writer.Finish(static_cast<Status::Code>(request.status_code));
+  }
+};
 
 }  // namespace test
 
@@ -44,13 +46,13 @@
 namespace {
 
 TEST(NanopbCodegen, CompilesProperly) {
-  test::TestService service;
+  test::TestServiceImpl service;
   EXPECT_EQ(service.id(), Hash("pw.rpc.test.TestService"));
   EXPECT_STREQ(service.name(), "TestService");
 }
 
 TEST(NanopbCodegen, InvokeUnaryRpc) {
-  PW_RPC_TEST_METHOD_CONTEXT(test::TestService, TestRpc) context;
+  PW_RPC_TEST_METHOD_CONTEXT(test::TestServiceImpl, TestRpc) context;
 
   EXPECT_EQ(Status::OK,
             context.call({.integer = 123, .status_code = Status::OK}));
@@ -64,7 +66,7 @@
 }
 
 TEST(NanopbCodegen, InvokeStreamingRpc) {
-  PW_RPC_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc) context;
+  PW_RPC_TEST_METHOD_CONTEXT(test::TestServiceImpl, TestStreamRpc) context;
 
   context.call({.integer = 0, .status_code = Status::ABORTED});
 
@@ -86,7 +88,7 @@
 }
 
 TEST(NanopbCodegen, InvokeStreamingRpc_ContextKeepsFixedNumberOfResponses) {
-  PW_RPC_TEST_METHOD_CONTEXT(test::TestService, TestStreamRpc, 3) context;
+  PW_RPC_TEST_METHOD_CONTEXT(test::TestServiceImpl, TestStreamRpc, 3) context;
 
   ASSERT_EQ(3u, context.responses().max_size());
 
diff --git a/pw_rpc/nanopb/method.cc b/pw_rpc/nanopb/method.cc
index 87997b5..a096ca8 100644
--- a/pw_rpc/nanopb/method.cc
+++ b/pw_rpc/nanopb/method.cc
@@ -64,8 +64,7 @@
     return;
   }
 
-  const Status status =
-      function_.unary(call.context(), request_struct, response_struct);
+  const Status status = function_.unary(call, request_struct, response_struct);
   SendResponse(call.channel(), request, response_struct, status);
 }
 
@@ -77,7 +76,7 @@
   }
 
   internal::BaseServerWriter server_writer(call);
-  function_.server_streaming(call.context(), request_struct, server_writer);
+  function_.server_streaming(call, request_struct, server_writer);
 }
 
 bool Method::DecodeRequest(Channel& channel,
diff --git a/pw_rpc/nanopb/method_test.cc b/pw_rpc/nanopb/method_test.cc
index 350dec2..bbdf753 100644
--- a/pw_rpc/nanopb/method_test.cc
+++ b/pw_rpc/nanopb/method_test.cc
@@ -47,21 +47,31 @@
   return std::as_bytes(buffer.first(output.bytes_written));
 }
 
+template <typename Implementation>
 class FakeGeneratedService : public Service {
  public:
   constexpr FakeGeneratedService(uint32_t id) : Service(id, kMethods) {}
 
-  static Status DoNothing(ServerContext&,
-                          const pw_rpc_test_Empty&,
-                          pw_rpc_test_Empty&);
+  static Status DoNothing(ServerCall& call,
+                          const pw_rpc_test_Empty& request,
+                          pw_rpc_test_Empty& response) {
+    return static_cast<Implementation&>(call.service())
+        .DoNothing(call.context(), request, response);
+  }
 
-  static Status AddFive(ServerContext&,
-                        const pw_rpc_test_TestRequest&,
-                        pw_rpc_test_TestResponse&);
+  static Status AddFive(ServerCall& call,
+                        const pw_rpc_test_TestRequest& request,
+                        pw_rpc_test_TestResponse& response) {
+    return static_cast<Implementation&>(call.service())
+        .AddFive(call.context(), request, response);
+  }
 
-  static void StartStream(ServerContext&,
-                          const pw_rpc_test_TestRequest&,
-                          ServerWriter<pw_rpc_test_TestResponse>&);
+  static void StartStream(ServerCall& call,
+                          const pw_rpc_test_TestRequest& request,
+                          ServerWriter<pw_rpc_test_TestResponse>& writer) {
+    static_cast<Implementation&>(call.service())
+        .StartStream(call.context(), request, writer);
+  }
 
   static constexpr std::array<Method, 3> kMethods = {
       Method::Unary<DoNothing>(
@@ -76,34 +86,38 @@
 pw_rpc_test_TestRequest last_request;
 ServerWriter<pw_rpc_test_TestResponse> last_writer;
 
-Status FakeGeneratedService::AddFive(ServerContext&,
-                                     const pw_rpc_test_TestRequest& request,
-                                     pw_rpc_test_TestResponse& response) {
-  last_request = request;
-  response.value = request.integer + 5;
-  return Status::UNAUTHENTICATED;
-}
+class FakeGeneratedServiceImpl
+    : public FakeGeneratedService<FakeGeneratedServiceImpl> {
+ public:
+  FakeGeneratedServiceImpl(uint32_t id) : FakeGeneratedService(id) {}
 
-Status FakeGeneratedService::DoNothing(ServerContext&,
-                                       const pw_rpc_test_Empty&,
-                                       pw_rpc_test_Empty&) {
-  return Status::UNKNOWN;
-}
+  Status AddFive(ServerContext&,
+                 const pw_rpc_test_TestRequest& request,
+                 pw_rpc_test_TestResponse& response) {
+    last_request = request;
+    response.value = request.integer + 5;
+    return Status::UNAUTHENTICATED;
+  }
 
-void FakeGeneratedService::StartStream(
-    ServerContext&,
-    const pw_rpc_test_TestRequest& request,
-    ServerWriter<pw_rpc_test_TestResponse>& writer) {
-  last_request = request;
+  Status DoNothing(ServerContext&,
+                   const pw_rpc_test_Empty&,
+                   pw_rpc_test_Empty&) {
+    return Status::UNKNOWN;
+  }
 
-  last_writer = std::move(writer);
-}
+  void StartStream(ServerContext&,
+                   const pw_rpc_test_TestRequest& request,
+                   ServerWriter<pw_rpc_test_TestResponse>& writer) {
+    last_request = request;
+    last_writer = std::move(writer);
+  }
+};
 
 TEST(Method, UnaryRpc_SendsResponse) {
   ENCODE_PB(pw_rpc_test_TestRequest, {.integer = 123}, request);
 
-  const Method& method = std::get<1>(FakeGeneratedService::kMethods);
-  ServerContextForTest<FakeGeneratedService> context(method);
+  const Method& method = std::get<1>(FakeGeneratedServiceImpl::kMethods);
+  ServerContextForTest<FakeGeneratedServiceImpl> context(method);
   method.Invoke(context.get(), context.packet(request));
 
   const Packet& response = context.output().sent_packet();
@@ -123,8 +137,8 @@
 TEST(Method, UnaryRpc_InvalidPayload_SendsError) {
   std::array<byte, 8> bad_payload{byte{0xFF}, byte{0xAA}, byte{0xDD}};
 
-  const Method& method = std::get<0>(FakeGeneratedService::kMethods);
-  ServerContextForTest<FakeGeneratedService> context(method);
+  const Method& method = std::get<0>(FakeGeneratedServiceImpl::kMethods);
+  ServerContextForTest<FakeGeneratedServiceImpl> context(method);
   method.Invoke(context.get(), context.packet(bad_payload));
 
   const Packet& packet = context.output().sent_packet();
@@ -138,9 +152,9 @@
   constexpr int64_t value = 0x7FFFFFFF'FFFFFF00ll;
   ENCODE_PB(pw_rpc_test_TestRequest, {.integer = value}, request);
 
-  const Method& method = std::get<1>(FakeGeneratedService::kMethods);
+  const Method& method = std::get<1>(FakeGeneratedServiceImpl::kMethods);
   // Output buffer is too small for the response, but can fit an error packet.
-  ServerContextForTest<FakeGeneratedService, 22> context(method);
+  ServerContextForTest<FakeGeneratedServiceImpl, 22> context(method);
   ASSERT_LT(context.output().buffer_size(),
             context.packet(request).MinEncodedSizeBytes() + request.size() + 1);
 
@@ -158,8 +172,8 @@
 TEST(Method, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) {
   ENCODE_PB(pw_rpc_test_TestRequest, {.integer = 555}, request);
 
-  const Method& method = std::get<2>(FakeGeneratedService::kMethods);
-  ServerContextForTest<FakeGeneratedService> context(method);
+  const Method& method = std::get<2>(FakeGeneratedServiceImpl::kMethods);
+  ServerContextForTest<FakeGeneratedServiceImpl> context(method);
 
   method.Invoke(context.get(), context.packet(request));
 
@@ -168,8 +182,8 @@
 }
 
 TEST(Method, ServerWriter_SendsResponse) {
-  const Method& method = std::get<2>(FakeGeneratedService::kMethods);
-  ServerContextForTest<FakeGeneratedService> context(method);
+  const Method& method = std::get<2>(FakeGeneratedServiceImpl::kMethods);
+  ServerContextForTest<FakeGeneratedServiceImpl> context(method);
 
   method.Invoke(context.get(), context.packet({}));
 
@@ -188,14 +202,14 @@
 }
 
 TEST(Method, ServerStreamingRpc_ServerWriterBufferTooSmall_InternalError) {
-  const Method& method = std::get<2>(FakeGeneratedService::kMethods);
+  const Method& method = std::get<2>(FakeGeneratedServiceImpl::kMethods);
 
   constexpr size_t kNoPayloadPacketSize = 2 /* type */ + 2 /* channel */ +
                                           5 /* service */ + 5 /* method */ +
                                           2 /* payload */ + 2 /* status */;
 
   // Make the buffer barely fit a packet with no payload.
-  ServerContextForTest<FakeGeneratedService, kNoPayloadPacketSize> context(
+  ServerContextForTest<FakeGeneratedServiceImpl, kNoPayloadPacketSize> context(
       method);
 
   // Verify that the encoded size of a packet with an empty payload is correct.
diff --git a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
index 956c4ef..4cd8303 100644
--- a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
+++ b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
@@ -57,8 +57,7 @@
 
 // Specialization for unary RPCs.
 template <typename RequestType, typename ResponseType>
-struct RpcTraits<Status (*)(
-    ServerContext&, const RequestType&, ResponseType&)> {
+struct RpcTraits<Status (*)(ServerCall&, const RequestType&, ResponseType&)> {
   using Request = RequestType;
   using Response = ResponseType;
 
@@ -70,7 +69,7 @@
 // Specialization for server streaming RPCs.
 template <typename RequestType, typename ResponseType>
 struct RpcTraits<void (*)(
-    ServerContext&, const RequestType&, ServerWriter<ResponseType>&)> {
+    ServerCall&, const RequestType&, ServerWriter<ResponseType>&)> {
   using Request = RequestType;
   using Response = ResponseType;
 
@@ -79,6 +78,20 @@
   static constexpr bool kClientStreaming = false;
 };
 
+// Member function specialization for unary RPCs.
+template <typename T, typename RequestType, typename ResponseType>
+struct RpcTraits<Status (T::*)(
+    ServerContext&, const RequestType&, ResponseType&)>
+    : public RpcTraits<Status (*)(
+          ServerCall&, const RequestType&, ResponseType&)> {};
+
+// Member function specialization for server streaming RPCs.
+template <typename T, typename RequestType, typename ResponseType>
+struct RpcTraits<void (T::*)(
+    ServerContext&, const RequestType&, ServerWriter<ResponseType>&)>
+    : public RpcTraits<void (*)(
+          ServerCall&, const RequestType&, ServerWriter<ResponseType>&)> {};
+
 template <auto method>
 using Request = typename RpcTraits<decltype(method)>::Request;
 
@@ -109,9 +122,9 @@
     // In optimized builds, the compiler inlines the user-defined function into
     // this wrapper, elminating any overhead.
     return Method({.unary =
-                       [](ServerContext& ctx, const void* req, void* resp) {
+                       [](ServerCall& call, const void* req, void* resp) {
                          return method(
-                             ctx,
+                             call,
                              *static_cast<const Request<method>*>(req),
                              *static_cast<Response<method>*>(resp));
                        }},
@@ -133,8 +146,8 @@
     // union, defined below.
     return Method(
         {.server_streaming =
-             [](ServerContext& ctx, const void* req, BaseServerWriter& resp) {
-               method(ctx,
+             [](ServerCall& call, const void* req, BaseServerWriter& resp) {
+               method(call,
                       *static_cast<const Request<method>*>(req),
                       static_cast<ServerWriter<Response<method>>&>(resp));
              }},
@@ -163,17 +176,17 @@
  private:
   // Generic version of the unary RPC function signature:
   //
-  //   Status(ServerContext&, const Request&, Response&)
+  //   Status(ServerCall&, const Request&, Response&)
   //
-  using UnaryFunction = Status (*)(ServerContext&,
+  using UnaryFunction = Status (*)(ServerCall&,
                                    const void* request,
                                    void* response);
 
   // Generic version of the server streaming RPC function signature:
   //
-  //   Status(ServerContext&, const Request&, ServerWriter<Response>&)
+  //   Status(ServerCall&, const Request&, ServerWriter<Response>&)
   //
-  using ServerStreamingFunction = void (*)(ServerContext&,
+  using ServerStreamingFunction = void (*)(ServerCall&,
                                            const void* request,
                                            BaseServerWriter& writer);
 
diff --git a/pw_rpc/public/pw_rpc/test_method_context.h b/pw_rpc/public/pw_rpc/test_method_context.h
index ea5dc17..b0fdfc8 100644
--- a/pw_rpc/public/pw_rpc/test_method_context.h
+++ b/pw_rpc/public/pw_rpc/test_method_context.h
@@ -68,23 +68,30 @@
       ::pw::rpc::test_internal::ServiceTestUtilities<         \
           service,                                            \
           ::pw::rpc::internal::Hash(#method_name)>,           \
-      service::method_name PW_COMMA_ARGS(__VA_ARGS__)>
+      &service::method_name PW_COMMA_ARGS(__VA_ARGS__)>
 
 // Internal classes that implement PW_RPC_TEST_METHOD_CONTEXT.
 namespace pw::rpc::test_internal {
 
+// Identifies a base class from a member function it defines. This should be
+// used with decltype to retrieve the base class.
+template <typename T, typename U>
+T BaseFromMember(U T::*);
+
 // Finds the method object in a service at compile time. This class friended by
 // the generated service classes to give it access to the internal method list.
 template <typename ServiceType, uint32_t method_hash>
 class ServiceTestUtilities {
  public:
+  using BaseService =
+      decltype(BaseFromMember(&ServiceType::_PwRpcInternalGeneratedBase));
   using Service = ServiceType;
 
   static constexpr const internal::Method& method() { return *FindMethod(); }
 
  private:
   static constexpr const internal::Method* FindMethod() {
-    for (const internal::Method& method : Service::kMethods) {
+    for (const internal::Method& method : BaseService::kMethods) {
       if (method.id() == method_hash) {
         return &method;
       }
@@ -177,7 +184,8 @@
     ctx_.output.clear();
     ctx_.responses.emplace_back();
     ctx_.responses.back() = {};
-    return function(ctx_.call.context(), request, ctx_.responses.back());
+    return (ctx_.service.*function)(
+        ctx_.call.context(), request, ctx_.responses.back());
   }
 
   // Gives access to the RPC's response.
@@ -204,9 +212,10 @@
   void call(const Request& request) {
     ctx_.output.clear();
     internal::BaseServerWriter server_writer(ctx_.call);
-    function(ctx_.call.context(),
-             request,
-             static_cast<ServerWriter<Response>&>(server_writer));
+    return (ctx_.service.*function)(
+        ctx_.call.context(),
+        request,
+        static_cast<ServerWriter<Response>&>(server_writer));
   }
 
   // Returns the responses that have been recorded. The maximum number of
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index 7512cce..6e8b997 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -75,21 +75,32 @@
 
     req_type = method.request_type().nanopb_name()
     res_type = method.response_type().nanopb_name()
+    implementation_cast = 'static_cast<Implementation&>(call.service())'
 
     output.write_line()
 
     if method.type() == ProtoServiceMethod.Type.UNARY:
-        output.write_line(f'static ::pw::Status {method.name()} (')
+        output.write_line(f'static ::pw::Status {method.name()}(')
         with output.indent(4):
-            output.write_line('ServerContext& ctx,')
+            output.write_line('::pw::rpc::internal::ServerCall& call,')
             output.write_line(f'const {req_type}& request,')
-            output.write_line(f'{res_type}& response);')
+            output.write_line(f'{res_type}& response) {{')
+        with output.indent():
+            output.write_line(f'return {implementation_cast}')
+            output.write_line(
+                f'    .{method.name()}(call.context(), request, response);')
+        output.write_line('}')
     elif method.type() == ProtoServiceMethod.Type.SERVER_STREAMING:
-        output.write_line(f'static void {method.name()} (')
+        output.write_line(f'static void {method.name()}(')
         with output.indent(4):
-            output.write_line('ServerContext& ctx,')
+            output.write_line('::pw::rpc::internal::ServerCall& call,')
             output.write_line(f'const {req_type}& request,')
-            output.write_line(f'ServerWriter<{res_type}>& writer);')
+            output.write_line(f'ServerWriter<{res_type}>& writer) {{')
+        with output.indent():
+            output.write_line(implementation_cast)
+            output.write_line(
+                f'    .{method.name()}(call.context(), request, writer);')
+        output.write_line('}')
     else:
         raise NotImplementedError(
             'Only unary and server streaming RPCs are currently supported')
@@ -100,8 +111,9 @@
     """Generates a C++ derived class for a nanopb RPC service."""
 
     base_class = f'{RPC_NAMESPACE}::internal::Service'
+    output.write_line('\ntemplate <typename Implementation>')
     output.write_line(
-        f'\nclass {service.cpp_namespace(root)} : public {base_class} {{')
+        f'class {service.cpp_namespace(root)} : public {base_class} {{')
     output.write_line(' public:')
 
     with output.indent():
@@ -124,8 +136,10 @@
         output.write_line(f'static constexpr const char* name() '
                           f'{{ return "{service.name()}"; }}')
 
-        for method in service.methods():
-            _generate_code_for_method(method, output)
+        output.write_line()
+        output.write_line('// Used in test code to identify a base service.')
+        output.write_line(
+            'constexpr void _PwRpcInternalGeneratedBase() const {}')
 
     service_name_hash = pw_rpc.ids.calculate(service.proto_path())
     output.write_line('\n private:')
@@ -135,8 +149,11 @@
         output.write_line(
             f'static constexpr uint32_t kServiceId = {hex(service_name_hash)};'
         )
-        output.write_line()
 
+        for method in service.methods():
+            _generate_code_for_method(method, output)
+
+        output.write_line()
         output.write_line(
             f'static constexpr std::array<{RPC_NAMESPACE}::internal::Method,'
             f' {len(service.methods())}> kMethods = {{')