pw_rpc: Split ServerContext public/internal - internal::ServerCall contains context needed to manage an ongoing RPC call. - ServerContext is the user-facing interface to the ServerCall. - Add a Server pointer to the ServerCall, which will be used for streaming RPCs. Change-Id: I800ad139942e4f8d1d189c8cbba75ba0582912e5
diff --git a/pw_rpc/BUILD b/pw_rpc/BUILD index 296b223..ac70d4f 100644 --- a/pw_rpc/BUILD +++ b/pw_rpc/BUILD
@@ -61,6 +61,7 @@ hdrs = [ "public/pw_rpc/channel.h", "public/pw_rpc/internal/base_method.h", + "public/pw_rpc/internal/call.h", "public/pw_rpc/internal/packet.h", "public/pw_rpc/server.h", ],
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn index c8512e1..2a6b637 100644 --- a/pw_rpc/BUILD.gn +++ b/pw_rpc/BUILD.gn
@@ -32,8 +32,9 @@ template("_pw_rpc_server_library") { assert(defined(invoker.implementation), "_pw_rpc_server_library requires an implementation to be set") + _target_name = target_name - source_set(target_name) { + source_set(_target_name) { forward_variables_from(invoker, "*") public_configs = [ ":default_config" ] @@ -63,6 +64,17 @@ allow_circular_includes_from = [ implementation ] friend = [ "./*" ] } + + source_set("test_utils_$_target_name") { + public = [ "pw_rpc_private/test_utils.h" ] + public_configs = [ ":private_includes" ] + public_deps = [ + ":$_target_name", + ":common", + dir_pw_span, + ] + visibility = [ "./*" ] + } } # Classes with no dependencies on the protobuf library for method invocations. @@ -79,6 +91,7 @@ "channel.cc", "packet.cc", "public/pw_rpc/internal/base_method.h", + "public/pw_rpc/internal/call.h", "public/pw_rpc/internal/packet.h", ] friend = [ "./*" ] @@ -93,17 +106,7 @@ config("private_includes") { include_dirs = [ "." ] - visibility = [ ":test_utils" ] -} - -source_set("test_utils") { - public = [ "pw_rpc_private/test_utils.h" ] - public_configs = [ ":private_includes" ] - public_deps = [ - ":common", - dir_pw_span, - ] - visibility = [ "./*" ] + visibility = [ ":*" ] } pw_proto_library("protos") { @@ -137,7 +140,7 @@ pw_test("base_server_writer_test") { deps = [ ":test_server", - ":test_utils", + ":test_utils_test_server", ] sources = [ "base_server_writer_test.cc" ] } @@ -154,7 +157,7 @@ deps = [ ":protos_pwpb", ":test_server", - ":test_utils", + ":test_utils_test_server", dir_pw_assert, ] sources = [ "server_test.cc" ]
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc index 49b9a37..5156f4d 100644 --- a/pw_rpc/base_server_writer.cc +++ b/pw_rpc/base_server_writer.cc
@@ -44,7 +44,7 @@ } PW_DCHECK(response_.empty()); - response_ = context_.channel_->AcquireBuffer(); + response_ = context_.channel().AcquireBuffer(); // Reserve space for the RPC packet header. return packet().PayloadUsableSpace(response_); @@ -61,19 +61,19 @@ response_ = {}; if (!encoded.ok()) { - context_.channel_->SendAndReleaseBuffer(0); + context_.channel().SendAndReleaseBuffer(0); return Status::INTERNAL; } // TODO(hepler): Should Channel::SendAndReleaseBuffer return Status? - context_.channel_->SendAndReleaseBuffer(encoded.size()); + context_.channel().SendAndReleaseBuffer(encoded.size()); return Status::OK; } Packet BaseServerWriter::packet() const { return Packet(PacketType::RPC, context_.channel_id(), - context_.service_->id(), + context_.service().id(), method().id()); }
diff --git a/pw_rpc/base_server_writer_test.cc b/pw_rpc/base_server_writer_test.cc index a376aa0..7bf8d7f 100644 --- a/pw_rpc/base_server_writer_test.cc +++ b/pw_rpc/base_server_writer_test.cc
@@ -58,8 +58,7 @@ class FakeServerWriter : public BaseServerWriter { public: - constexpr FakeServerWriter(ServerContext& context) - : BaseServerWriter(context) {} + constexpr FakeServerWriter(ServerCall& context) : BaseServerWriter(context) {} constexpr FakeServerWriter() = default;
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn index 2aae9d9..a28d3a4 100644 --- a/pw_rpc/nanopb/BUILD.gn +++ b/pw_rpc/nanopb/BUILD.gn
@@ -41,7 +41,7 @@ deps = [ "..:nanopb_server", "..:test_protos_nanopb", - "..:test_utils", + "..:test_utils_nanopb_server", ] sources = [ "method_test.cc" ] enable_if = dir_third_party_nanopb != ""
diff --git a/pw_rpc/nanopb/method.cc b/pw_rpc/nanopb/method.cc index 366e53a..7ed0522 100644 --- a/pw_rpc/nanopb/method.cc +++ b/pw_rpc/nanopb/method.cc
@@ -56,7 +56,7 @@ return StatusWithSize::INTERNAL; } -StatusWithSize Method::CallUnary(ServerContext& context, +StatusWithSize Method::CallUnary(ServerCall& call, span<const byte> request_buffer, span<byte> response_buffer, void* request_struct, @@ -66,7 +66,7 @@ return StatusWithSize(status, 0); } - status = function_.unary(context, request_struct, response_struct); + status = function_.unary(call.context(), request_struct, response_struct); StatusWithSize encoded = EncodeResponse(response_struct, response_buffer); if (encoded.ok()) { @@ -75,7 +75,7 @@ return encoded; } -StatusWithSize Method::CallServerStreaming(ServerContext& context, +StatusWithSize Method::CallServerStreaming(ServerCall& call, span<const byte> request_buffer, void* request_struct) const { Status status = DecodeRequest(request_buffer, request_struct); @@ -83,9 +83,10 @@ return StatusWithSize(status, 0); } - internal::BaseServerWriter server_writer(context); + internal::BaseServerWriter server_writer(call); return StatusWithSize( - function_.server_streaming(context, request_struct, server_writer), 0); + function_.server_streaming(call.context(), request_struct, server_writer), + 0); } } // namespace pw::rpc::internal
diff --git a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h index 8dd7764..3f74472 100644 --- a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h +++ b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
@@ -136,10 +136,10 @@ // The pw::rpc::Server calls method.Invoke to call a user-defined RPC. Invoke // calls the invoker function, which encodes and decodes the request and // response (if any) and calls the user-defined RPC function. - StatusWithSize Invoke(ServerContext& context, + StatusWithSize Invoke(ServerCall& call, span<const std::byte> request, span<std::byte> payload_buffer) const { - return invoker_(*this, context, request, payload_buffer); + return invoker_(*this, call, request, payload_buffer); } // Decodes a request protobuf with Nanopb to the provided buffer. @@ -186,7 +186,7 @@ // RPC according to its type (unary, server streaming, etc.). The Invoker // returns the number of bytes written to the response buffer, if any. using Invoker = StatusWithSize (&)(const Method&, - ServerContext&, + ServerCall&, span<const std::byte>, span<std::byte>); @@ -201,13 +201,13 @@ request_fields_(request), response_fields_(response) {} - StatusWithSize CallUnary(ServerContext& context, + StatusWithSize CallUnary(ServerCall& call, span<const std::byte> request_buffer, span<std::byte> response_buffer, void* request_struct, void* response_struct) const; - StatusWithSize CallServerStreaming(ServerContext& context, + StatusWithSize CallServerStreaming(ServerCall& call, span<const std::byte> request_buffer, void* request_struct) const; @@ -218,7 +218,7 @@ // this function for each request/response type. template <size_t request_size, size_t response_size> static StatusWithSize UnaryInvoker(const Method& method, - ServerContext& context, + ServerCall& call, span<const std::byte> request_buffer, span<std::byte> response_buffer) { std::aligned_storage_t<request_size, alignof(std::max_align_t)> @@ -226,7 +226,7 @@ std::aligned_storage_t<response_size, alignof(std::max_align_t)> response_struct{}; - return method.CallUnary(context, + return method.CallUnary(call, request_buffer, response_buffer, &request_struct, @@ -239,13 +239,13 @@ template <size_t request_size> static StatusWithSize ServerStreamingInvoker( const Method& method, - ServerContext& context, + ServerCall& call, span<const std::byte> request_buffer, span<std::byte> /* payload not used */) { std::aligned_storage_t<request_size, alignof(std::max_align_t)> request_struct{}; - return method.CallServerStreaming(context, request_buffer, &request_struct); + return method.CallServerStreaming(call, request_buffer, &request_struct); } // Allocates memory for the request/response structs and invokes the
diff --git a/pw_rpc/public/pw_rpc/channel.h b/pw_rpc/public/pw_rpc/channel.h index ec5d5bd..08ae219 100644 --- a/pw_rpc/public/pw_rpc/channel.h +++ b/pw_rpc/public/pw_rpc/channel.h
@@ -57,7 +57,7 @@ // static, or it can set to null to allow dynamically opening connections // through the channel. template <uint32_t id> - static Channel Create(ChannelOutput* output) { + constexpr static Channel Create(ChannelOutput* output) { static_assert(id != kUnassignedChannelId, "Channel ID cannot be 0"); return Channel(id, output); }
diff --git a/pw_rpc/public/pw_rpc/internal/base_method.h b/pw_rpc/public/pw_rpc/internal/base_method.h index bf6c95f..07db616 100644 --- a/pw_rpc/public/pw_rpc/internal/base_method.h +++ b/pw_rpc/public/pw_rpc/internal/base_method.h
@@ -26,7 +26,7 @@ // Implementations must provide the Invoke method, which the Server calls: // - // StatusWithSize Invoke(ServerContext& context, + // StatusWithSize Invoke(ServerCall& call, // span<const std::byte> request, // span<std::byte> payload_buffer) const;
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h index f696ed9..8c95816 100644 --- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h +++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -33,7 +33,7 @@ // cancelling / terminating ongoing streaming RPCs. class BaseServerWriter { public: - constexpr BaseServerWriter(ServerContext& context) + constexpr BaseServerWriter(ServerCall& context) : context_(context), state_{kOpen} {} BaseServerWriter(BaseServerWriter&& other) { *this = std::move(other); } @@ -51,7 +51,7 @@ protected: constexpr BaseServerWriter() : state_{kClosed} {} - const Method& method() const { return *context_.method_; } + const Method& method() const { return context_.method(); } span<std::byte> AcquireBuffer(); @@ -60,7 +60,7 @@ private: Packet packet() const; - ServerContext context_; + ServerCall context_; span<std::byte> response_; enum { kClosed, kOpen } state_; };
diff --git a/pw_rpc/public/pw_rpc/internal/call.h b/pw_rpc/public/pw_rpc/internal/call.h new file mode 100644 index 0000000..580f7fa --- /dev/null +++ b/pw_rpc/public/pw_rpc/internal/call.h
@@ -0,0 +1,91 @@ +// Copyright 2020 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include <cstddef> +#include <cstdint> + +#include "pw_assert/assert.h" +#include "pw_rpc/channel.h" + +namespace pw::rpc { + +class Server; +class ServerContext; + +namespace internal { + +class Method; +class Service; + +// Collects information for an ongoing RPC being processed by the server. +// The Server creates a ServerCall object to represent a method invocation. The +// ServerCall is copied into a ServerWriter or ServerReader for streaming RPCs. +// +// ServerCall is a strictly internal class. ServerContext is the public +// interface to the internal::ServerCall. +class ServerCall { + public: + uint32_t channel_id() const { return channel().id(); } + + constexpr ServerCall() + : server_(nullptr), + channel_(nullptr), + service_(nullptr), + method_(nullptr) {} + + constexpr ServerCall(Server& server, + Channel& channel, + internal::Service& service, + const internal::Method& method) + : server_(&server), + channel_(&channel), + service_(&service), + method_(&method) {} + + constexpr ServerCall(const ServerCall&) = default; + constexpr ServerCall& operator=(const ServerCall&) = default; + + // Access the ServerContext for this call. Defined in pw_rpc/server_context.h. + ServerContext& context(); + + Server& server() const { + PW_DCHECK_NOTNULL(server_); + return *server_; + } + + Channel& channel() const { + PW_DCHECK_NOTNULL(channel_); + return *channel_; + } + + internal::Service& service() const { + PW_DCHECK_NOTNULL(service_); + return *service_; + } + + const internal::Method& method() const { + PW_DCHECK_NOTNULL(method_); + return *method_; + } + + private: + Server* server_; + Channel* channel_; + internal::Service* service_; + const internal::Method* method_; +}; + +} // namespace internal +} // namespace pw::rpc
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h index dd55c93..6a9809e 100644 --- a/pw_rpc/public/pw_rpc/server.h +++ b/pw_rpc/public/pw_rpc/server.h
@@ -45,13 +45,10 @@ constexpr size_t channel_count() const { return channels_.size(); } private: - using Service = internal::Service; - using ServiceRegistry = internal::ServiceRegistry; - void InvokeMethod(const internal::Packet& request, Channel& channel, internal::Packet& response, - span<std::byte> buffer) const; + span<std::byte> buffer); void SendResponse(const Channel& output, const internal::Packet& response, @@ -61,7 +58,7 @@ Channel* AssignChannel(uint32_t id, ChannelOutput& interface); span<Channel> channels_; - ServiceRegistry services_; + internal::ServiceRegistry services_; }; } // namespace pw::rpc
diff --git a/pw_rpc/public/pw_rpc/server_context.h b/pw_rpc/public/pw_rpc/server_context.h index 2699039..91e47d1 100644 --- a/pw_rpc/public/pw_rpc/server_context.h +++ b/pw_rpc/public/pw_rpc/server_context.h
@@ -16,48 +16,38 @@ #include <cstddef> #include <cstdint> -#include "pw_assert/assert.h" -#include "pw_rpc/channel.h" +#include "pw_rpc/internal/call.h" namespace pw::rpc { -namespace internal { -class Method; -class Service; -class BaseServerWriter; - -} // namespace internal - -// The ServerContext collects context for an RPC being invoked on a server. -class ServerContext { +// The ServerContext collects context for an RPC being invoked on a server. The +// ServerContext is passed into RPC functions and is user-facing. +// +// The ServerContext is a public-facing view of the internal::ServerCall class. +// It uses inheritance to avoid copying or creating an extra reference to the +// underlying ServerCall. Private inheritance prevents exposing the +// internal-facing ServerCall interface. +class ServerContext : private internal::ServerCall { public: - uint32_t channel_id() const { - PW_DCHECK_NOTNULL(channel_); - return channel_->id(); - } + // Returns the ID for the channel this RPC is using. + uint32_t channel_id() const { return channel().id(); } - private: - friend class Server; - friend class internal::BaseServerWriter; + constexpr ServerContext() = delete; - // Allow ServerContexts to be created in tests. - template <typename, size_t, uint32_t, uint32_t> - friend class ServerContextForTest; + constexpr ServerContext(const ServerContext&) = delete; + constexpr ServerContext& operator=(const ServerContext&) = delete; - constexpr ServerContext() - : channel_(nullptr), service_(nullptr), method_(nullptr) {} + constexpr ServerContext(ServerContext&&) = delete; + constexpr ServerContext& operator=(ServerContext&&) = delete; - constexpr ServerContext(Channel& channel, - internal::Service& service, - const internal::Method& method) - : channel_(&channel), service_(&service), method_(&method) {} - - constexpr ServerContext(const ServerContext&) = default; - constexpr ServerContext& operator=(const ServerContext&) = default; - - Channel* channel_; - internal::Service* service_; - const internal::Method* method_; + friend class internal::ServerCall; // Allow down-casting from ServerCall. }; +namespace internal { + +inline ServerContext& ServerCall::context() { + return static_cast<ServerContext&>(*this); +} + +} // namespace internal } // namespace pw::rpc
diff --git a/pw_rpc/pw_rpc_private/test_utils.h b/pw_rpc/pw_rpc_private/test_utils.h index 321eb3f..262a079 100644 --- a/pw_rpc/pw_rpc_private/test_utils.h +++ b/pw_rpc/pw_rpc_private/test_utils.h
@@ -17,7 +17,9 @@ #include <cstdint> #include "pw_rpc/channel.h" +#include "pw_rpc/internal/method.h" #include "pw_rpc/internal/packet.h" +#include "pw_rpc/server.h" #include "pw_span/span.h" namespace pw::rpc { @@ -41,12 +43,6 @@ span<const std::byte> sent_packet_; }; -namespace internal { - -class Method; - -} // namespace internal - template <typename Service, size_t output_buffer_size = 128, uint32_t channel_id = 99, @@ -58,8 +54,11 @@ ServerContextForTest(const internal::Method& method) : channel_(Channel::Create<kChannelId>(&output_)), + server_(span(&channel_, 1)), service_(kServiceId), - context_(channel_, service_, method) {} + context_(server_, channel_, service_, method) { + server_.RegisterService(service_); + } ServerContextForTest() : ServerContextForTest(service_.method) {} @@ -68,20 +67,21 @@ return internal::Packet(internal::PacketType::RPC, kChannelId, kServiceId, - context_.method_->id(), + context_.method().id(), payload, Status::OK); } - ServerContext& get() { return context_; } + internal::ServerCall& get() { return context_; } const auto& output() const { return output_; } private: TestOutput<output_buffer_size> output_; Channel channel_; + Server server_; Service service_; - ServerContext context_; + internal::ServerCall context_; }; } // namespace pw::rpc
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc index 7ad7cf0..b239f3d 100644 --- a/pw_rpc/server.cc +++ b/pw_rpc/server.cc
@@ -70,8 +70,8 @@ void Server::InvokeMethod(const Packet& request, Channel& channel, internal::Packet& response, - span<std::byte> buffer) const { - Service* service = services_.Find(request.service_id()); + span<std::byte> buffer) { + internal::Service* service = services_.Find(request.service_id()); if (service == nullptr) { // Couldn't find the requested service. Reply with a NOT_FOUND response // without the service_id field set. @@ -92,11 +92,12 @@ response.set_method_id(method->id()); - ServerContext context(channel, *service, *method); - span<byte> response_buffer = request.PayloadUsableSpace(buffer); + + internal::ServerCall call(*this, channel, *service, *method); StatusWithSize result = - method->Invoke(context, request.payload(), response_buffer); + method->Invoke(call, request.payload(), response_buffer); + response.set_status(result.status()); response.set_payload(response_buffer.first(result.size())); }
diff --git a/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h b/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h index ecd8f02..4c91255 100644 --- a/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h +++ b/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h
@@ -29,10 +29,10 @@ public: constexpr Method(uint32_t id) : BaseMethod(id), last_channel_id_(0) {} - StatusWithSize Invoke(ServerContext& context, + StatusWithSize Invoke(ServerCall& call, span<const std::byte> request, span<std::byte> payload_buffer) const { - last_channel_id_ = context.channel_id(); + last_channel_id_ = call.channel_id(); last_request_ = request; last_payload_buffer_ = payload_buffer;