pw_rpc: Split ServerContext public/internal
- internal::ServerCall contains context needed to manage an ongoing RPC
call.
- ServerContext is the user-facing interface to the ServerCall.
- Add a Server pointer to the ServerCall, which will be used for
streaming RPCs.
Change-Id: I800ad139942e4f8d1d189c8cbba75ba0582912e5
diff --git a/pw_rpc/BUILD b/pw_rpc/BUILD
index 296b223..ac70d4f 100644
--- a/pw_rpc/BUILD
+++ b/pw_rpc/BUILD
@@ -61,6 +61,7 @@
hdrs = [
"public/pw_rpc/channel.h",
"public/pw_rpc/internal/base_method.h",
+ "public/pw_rpc/internal/call.h",
"public/pw_rpc/internal/packet.h",
"public/pw_rpc/server.h",
],
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index c8512e1..2a6b637 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -32,8 +32,9 @@
template("_pw_rpc_server_library") {
assert(defined(invoker.implementation),
"_pw_rpc_server_library requires an implementation to be set")
+ _target_name = target_name
- source_set(target_name) {
+ source_set(_target_name) {
forward_variables_from(invoker, "*")
public_configs = [ ":default_config" ]
@@ -63,6 +64,17 @@
allow_circular_includes_from = [ implementation ]
friend = [ "./*" ]
}
+
+ source_set("test_utils_$_target_name") {
+ public = [ "pw_rpc_private/test_utils.h" ]
+ public_configs = [ ":private_includes" ]
+ public_deps = [
+ ":$_target_name",
+ ":common",
+ dir_pw_span,
+ ]
+ visibility = [ "./*" ]
+ }
}
# Classes with no dependencies on the protobuf library for method invocations.
@@ -79,6 +91,7 @@
"channel.cc",
"packet.cc",
"public/pw_rpc/internal/base_method.h",
+ "public/pw_rpc/internal/call.h",
"public/pw_rpc/internal/packet.h",
]
friend = [ "./*" ]
@@ -93,17 +106,7 @@
config("private_includes") {
include_dirs = [ "." ]
- visibility = [ ":test_utils" ]
-}
-
-source_set("test_utils") {
- public = [ "pw_rpc_private/test_utils.h" ]
- public_configs = [ ":private_includes" ]
- public_deps = [
- ":common",
- dir_pw_span,
- ]
- visibility = [ "./*" ]
+ visibility = [ ":*" ]
}
pw_proto_library("protos") {
@@ -137,7 +140,7 @@
pw_test("base_server_writer_test") {
deps = [
":test_server",
- ":test_utils",
+ ":test_utils_test_server",
]
sources = [ "base_server_writer_test.cc" ]
}
@@ -154,7 +157,7 @@
deps = [
":protos_pwpb",
":test_server",
- ":test_utils",
+ ":test_utils_test_server",
dir_pw_assert,
]
sources = [ "server_test.cc" ]
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc
index 49b9a37..5156f4d 100644
--- a/pw_rpc/base_server_writer.cc
+++ b/pw_rpc/base_server_writer.cc
@@ -44,7 +44,7 @@
}
PW_DCHECK(response_.empty());
- response_ = context_.channel_->AcquireBuffer();
+ response_ = context_.channel().AcquireBuffer();
// Reserve space for the RPC packet header.
return packet().PayloadUsableSpace(response_);
@@ -61,19 +61,19 @@
response_ = {};
if (!encoded.ok()) {
- context_.channel_->SendAndReleaseBuffer(0);
+ context_.channel().SendAndReleaseBuffer(0);
return Status::INTERNAL;
}
// TODO(hepler): Should Channel::SendAndReleaseBuffer return Status?
- context_.channel_->SendAndReleaseBuffer(encoded.size());
+ context_.channel().SendAndReleaseBuffer(encoded.size());
return Status::OK;
}
Packet BaseServerWriter::packet() const {
return Packet(PacketType::RPC,
context_.channel_id(),
- context_.service_->id(),
+ context_.service().id(),
method().id());
}
diff --git a/pw_rpc/base_server_writer_test.cc b/pw_rpc/base_server_writer_test.cc
index a376aa0..7bf8d7f 100644
--- a/pw_rpc/base_server_writer_test.cc
+++ b/pw_rpc/base_server_writer_test.cc
@@ -58,8 +58,7 @@
class FakeServerWriter : public BaseServerWriter {
public:
- constexpr FakeServerWriter(ServerContext& context)
- : BaseServerWriter(context) {}
+ constexpr FakeServerWriter(ServerCall& context) : BaseServerWriter(context) {}
constexpr FakeServerWriter() = default;
diff --git a/pw_rpc/nanopb/BUILD.gn b/pw_rpc/nanopb/BUILD.gn
index 2aae9d9..a28d3a4 100644
--- a/pw_rpc/nanopb/BUILD.gn
+++ b/pw_rpc/nanopb/BUILD.gn
@@ -41,7 +41,7 @@
deps = [
"..:nanopb_server",
"..:test_protos_nanopb",
- "..:test_utils",
+ "..:test_utils_nanopb_server",
]
sources = [ "method_test.cc" ]
enable_if = dir_third_party_nanopb != ""
diff --git a/pw_rpc/nanopb/method.cc b/pw_rpc/nanopb/method.cc
index 366e53a..7ed0522 100644
--- a/pw_rpc/nanopb/method.cc
+++ b/pw_rpc/nanopb/method.cc
@@ -56,7 +56,7 @@
return StatusWithSize::INTERNAL;
}
-StatusWithSize Method::CallUnary(ServerContext& context,
+StatusWithSize Method::CallUnary(ServerCall& call,
span<const byte> request_buffer,
span<byte> response_buffer,
void* request_struct,
@@ -66,7 +66,7 @@
return StatusWithSize(status, 0);
}
- status = function_.unary(context, request_struct, response_struct);
+ status = function_.unary(call.context(), request_struct, response_struct);
StatusWithSize encoded = EncodeResponse(response_struct, response_buffer);
if (encoded.ok()) {
@@ -75,7 +75,7 @@
return encoded;
}
-StatusWithSize Method::CallServerStreaming(ServerContext& context,
+StatusWithSize Method::CallServerStreaming(ServerCall& call,
span<const byte> request_buffer,
void* request_struct) const {
Status status = DecodeRequest(request_buffer, request_struct);
@@ -83,9 +83,10 @@
return StatusWithSize(status, 0);
}
- internal::BaseServerWriter server_writer(context);
+ internal::BaseServerWriter server_writer(call);
return StatusWithSize(
- function_.server_streaming(context, request_struct, server_writer), 0);
+ function_.server_streaming(call.context(), request_struct, server_writer),
+ 0);
}
} // namespace pw::rpc::internal
diff --git a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
index 8dd7764..3f74472 100644
--- a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
+++ b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
@@ -136,10 +136,10 @@
// The pw::rpc::Server calls method.Invoke to call a user-defined RPC. Invoke
// calls the invoker function, which encodes and decodes the request and
// response (if any) and calls the user-defined RPC function.
- StatusWithSize Invoke(ServerContext& context,
+ StatusWithSize Invoke(ServerCall& call,
span<const std::byte> request,
span<std::byte> payload_buffer) const {
- return invoker_(*this, context, request, payload_buffer);
+ return invoker_(*this, call, request, payload_buffer);
}
// Decodes a request protobuf with Nanopb to the provided buffer.
@@ -186,7 +186,7 @@
// RPC according to its type (unary, server streaming, etc.). The Invoker
// returns the number of bytes written to the response buffer, if any.
using Invoker = StatusWithSize (&)(const Method&,
- ServerContext&,
+ ServerCall&,
span<const std::byte>,
span<std::byte>);
@@ -201,13 +201,13 @@
request_fields_(request),
response_fields_(response) {}
- StatusWithSize CallUnary(ServerContext& context,
+ StatusWithSize CallUnary(ServerCall& call,
span<const std::byte> request_buffer,
span<std::byte> response_buffer,
void* request_struct,
void* response_struct) const;
- StatusWithSize CallServerStreaming(ServerContext& context,
+ StatusWithSize CallServerStreaming(ServerCall& call,
span<const std::byte> request_buffer,
void* request_struct) const;
@@ -218,7 +218,7 @@
// this function for each request/response type.
template <size_t request_size, size_t response_size>
static StatusWithSize UnaryInvoker(const Method& method,
- ServerContext& context,
+ ServerCall& call,
span<const std::byte> request_buffer,
span<std::byte> response_buffer) {
std::aligned_storage_t<request_size, alignof(std::max_align_t)>
@@ -226,7 +226,7 @@
std::aligned_storage_t<response_size, alignof(std::max_align_t)>
response_struct{};
- return method.CallUnary(context,
+ return method.CallUnary(call,
request_buffer,
response_buffer,
&request_struct,
@@ -239,13 +239,13 @@
template <size_t request_size>
static StatusWithSize ServerStreamingInvoker(
const Method& method,
- ServerContext& context,
+ ServerCall& call,
span<const std::byte> request_buffer,
span<std::byte> /* payload not used */) {
std::aligned_storage_t<request_size, alignof(std::max_align_t)>
request_struct{};
- return method.CallServerStreaming(context, request_buffer, &request_struct);
+ return method.CallServerStreaming(call, request_buffer, &request_struct);
}
// Allocates memory for the request/response structs and invokes the
diff --git a/pw_rpc/public/pw_rpc/channel.h b/pw_rpc/public/pw_rpc/channel.h
index ec5d5bd..08ae219 100644
--- a/pw_rpc/public/pw_rpc/channel.h
+++ b/pw_rpc/public/pw_rpc/channel.h
@@ -57,7 +57,7 @@
// static, or it can set to null to allow dynamically opening connections
// through the channel.
template <uint32_t id>
- static Channel Create(ChannelOutput* output) {
+ constexpr static Channel Create(ChannelOutput* output) {
static_assert(id != kUnassignedChannelId, "Channel ID cannot be 0");
return Channel(id, output);
}
diff --git a/pw_rpc/public/pw_rpc/internal/base_method.h b/pw_rpc/public/pw_rpc/internal/base_method.h
index bf6c95f..07db616 100644
--- a/pw_rpc/public/pw_rpc/internal/base_method.h
+++ b/pw_rpc/public/pw_rpc/internal/base_method.h
@@ -26,7 +26,7 @@
// Implementations must provide the Invoke method, which the Server calls:
//
- // StatusWithSize Invoke(ServerContext& context,
+ // StatusWithSize Invoke(ServerCall& call,
// span<const std::byte> request,
// span<std::byte> payload_buffer) const;
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
index f696ed9..8c95816 100644
--- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h
+++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -33,7 +33,7 @@
// cancelling / terminating ongoing streaming RPCs.
class BaseServerWriter {
public:
- constexpr BaseServerWriter(ServerContext& context)
+ constexpr BaseServerWriter(ServerCall& context)
: context_(context), state_{kOpen} {}
BaseServerWriter(BaseServerWriter&& other) { *this = std::move(other); }
@@ -51,7 +51,7 @@
protected:
constexpr BaseServerWriter() : state_{kClosed} {}
- const Method& method() const { return *context_.method_; }
+ const Method& method() const { return context_.method(); }
span<std::byte> AcquireBuffer();
@@ -60,7 +60,7 @@
private:
Packet packet() const;
- ServerContext context_;
+ ServerCall context_;
span<std::byte> response_;
enum { kClosed, kOpen } state_;
};
diff --git a/pw_rpc/public/pw_rpc/internal/call.h b/pw_rpc/public/pw_rpc/internal/call.h
new file mode 100644
index 0000000..580f7fa
--- /dev/null
+++ b/pw_rpc/public/pw_rpc/internal/call.h
@@ -0,0 +1,91 @@
+// Copyright 2020 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#include "pw_assert/assert.h"
+#include "pw_rpc/channel.h"
+
+namespace pw::rpc {
+
+class Server;
+class ServerContext;
+
+namespace internal {
+
+class Method;
+class Service;
+
+// Collects information for an ongoing RPC being processed by the server.
+// The Server creates a ServerCall object to represent a method invocation. The
+// ServerCall is copied into a ServerWriter or ServerReader for streaming RPCs.
+//
+// ServerCall is a strictly internal class. ServerContext is the public
+// interface to the internal::ServerCall.
+class ServerCall {
+ public:
+ uint32_t channel_id() const { return channel().id(); }
+
+ constexpr ServerCall()
+ : server_(nullptr),
+ channel_(nullptr),
+ service_(nullptr),
+ method_(nullptr) {}
+
+ constexpr ServerCall(Server& server,
+ Channel& channel,
+ internal::Service& service,
+ const internal::Method& method)
+ : server_(&server),
+ channel_(&channel),
+ service_(&service),
+ method_(&method) {}
+
+ constexpr ServerCall(const ServerCall&) = default;
+ constexpr ServerCall& operator=(const ServerCall&) = default;
+
+ // Access the ServerContext for this call. Defined in pw_rpc/server_context.h.
+ ServerContext& context();
+
+ Server& server() const {
+ PW_DCHECK_NOTNULL(server_);
+ return *server_;
+ }
+
+ Channel& channel() const {
+ PW_DCHECK_NOTNULL(channel_);
+ return *channel_;
+ }
+
+ internal::Service& service() const {
+ PW_DCHECK_NOTNULL(service_);
+ return *service_;
+ }
+
+ const internal::Method& method() const {
+ PW_DCHECK_NOTNULL(method_);
+ return *method_;
+ }
+
+ private:
+ Server* server_;
+ Channel* channel_;
+ internal::Service* service_;
+ const internal::Method* method_;
+};
+
+} // namespace internal
+} // namespace pw::rpc
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h
index dd55c93..6a9809e 100644
--- a/pw_rpc/public/pw_rpc/server.h
+++ b/pw_rpc/public/pw_rpc/server.h
@@ -45,13 +45,10 @@
constexpr size_t channel_count() const { return channels_.size(); }
private:
- using Service = internal::Service;
- using ServiceRegistry = internal::ServiceRegistry;
-
void InvokeMethod(const internal::Packet& request,
Channel& channel,
internal::Packet& response,
- span<std::byte> buffer) const;
+ span<std::byte> buffer);
void SendResponse(const Channel& output,
const internal::Packet& response,
@@ -61,7 +58,7 @@
Channel* AssignChannel(uint32_t id, ChannelOutput& interface);
span<Channel> channels_;
- ServiceRegistry services_;
+ internal::ServiceRegistry services_;
};
} // namespace pw::rpc
diff --git a/pw_rpc/public/pw_rpc/server_context.h b/pw_rpc/public/pw_rpc/server_context.h
index 2699039..91e47d1 100644
--- a/pw_rpc/public/pw_rpc/server_context.h
+++ b/pw_rpc/public/pw_rpc/server_context.h
@@ -16,48 +16,38 @@
#include <cstddef>
#include <cstdint>
-#include "pw_assert/assert.h"
-#include "pw_rpc/channel.h"
+#include "pw_rpc/internal/call.h"
namespace pw::rpc {
-namespace internal {
-class Method;
-class Service;
-class BaseServerWriter;
-
-} // namespace internal
-
-// The ServerContext collects context for an RPC being invoked on a server.
-class ServerContext {
+// The ServerContext collects context for an RPC being invoked on a server. The
+// ServerContext is passed into RPC functions and is user-facing.
+//
+// The ServerContext is a public-facing view of the internal::ServerCall class.
+// It uses inheritance to avoid copying or creating an extra reference to the
+// underlying ServerCall. Private inheritance prevents exposing the
+// internal-facing ServerCall interface.
+class ServerContext : private internal::ServerCall {
public:
- uint32_t channel_id() const {
- PW_DCHECK_NOTNULL(channel_);
- return channel_->id();
- }
+ // Returns the ID for the channel this RPC is using.
+ uint32_t channel_id() const { return channel().id(); }
- private:
- friend class Server;
- friend class internal::BaseServerWriter;
+ constexpr ServerContext() = delete;
- // Allow ServerContexts to be created in tests.
- template <typename, size_t, uint32_t, uint32_t>
- friend class ServerContextForTest;
+ constexpr ServerContext(const ServerContext&) = delete;
+ constexpr ServerContext& operator=(const ServerContext&) = delete;
- constexpr ServerContext()
- : channel_(nullptr), service_(nullptr), method_(nullptr) {}
+ constexpr ServerContext(ServerContext&&) = delete;
+ constexpr ServerContext& operator=(ServerContext&&) = delete;
- constexpr ServerContext(Channel& channel,
- internal::Service& service,
- const internal::Method& method)
- : channel_(&channel), service_(&service), method_(&method) {}
-
- constexpr ServerContext(const ServerContext&) = default;
- constexpr ServerContext& operator=(const ServerContext&) = default;
-
- Channel* channel_;
- internal::Service* service_;
- const internal::Method* method_;
+ friend class internal::ServerCall; // Allow down-casting from ServerCall.
};
+namespace internal {
+
+inline ServerContext& ServerCall::context() {
+ return static_cast<ServerContext&>(*this);
+}
+
+} // namespace internal
} // namespace pw::rpc
diff --git a/pw_rpc/pw_rpc_private/test_utils.h b/pw_rpc/pw_rpc_private/test_utils.h
index 321eb3f..262a079 100644
--- a/pw_rpc/pw_rpc_private/test_utils.h
+++ b/pw_rpc/pw_rpc_private/test_utils.h
@@ -17,7 +17,9 @@
#include <cstdint>
#include "pw_rpc/channel.h"
+#include "pw_rpc/internal/method.h"
#include "pw_rpc/internal/packet.h"
+#include "pw_rpc/server.h"
#include "pw_span/span.h"
namespace pw::rpc {
@@ -41,12 +43,6 @@
span<const std::byte> sent_packet_;
};
-namespace internal {
-
-class Method;
-
-} // namespace internal
-
template <typename Service,
size_t output_buffer_size = 128,
uint32_t channel_id = 99,
@@ -58,8 +54,11 @@
ServerContextForTest(const internal::Method& method)
: channel_(Channel::Create<kChannelId>(&output_)),
+ server_(span(&channel_, 1)),
service_(kServiceId),
- context_(channel_, service_, method) {}
+ context_(server_, channel_, service_, method) {
+ server_.RegisterService(service_);
+ }
ServerContextForTest() : ServerContextForTest(service_.method) {}
@@ -68,20 +67,21 @@
return internal::Packet(internal::PacketType::RPC,
kChannelId,
kServiceId,
- context_.method_->id(),
+ context_.method().id(),
payload,
Status::OK);
}
- ServerContext& get() { return context_; }
+ internal::ServerCall& get() { return context_; }
const auto& output() const { return output_; }
private:
TestOutput<output_buffer_size> output_;
Channel channel_;
+ Server server_;
Service service_;
- ServerContext context_;
+ internal::ServerCall context_;
};
} // namespace pw::rpc
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index 7ad7cf0..b239f3d 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -70,8 +70,8 @@
void Server::InvokeMethod(const Packet& request,
Channel& channel,
internal::Packet& response,
- span<std::byte> buffer) const {
- Service* service = services_.Find(request.service_id());
+ span<std::byte> buffer) {
+ internal::Service* service = services_.Find(request.service_id());
if (service == nullptr) {
// Couldn't find the requested service. Reply with a NOT_FOUND response
// without the service_id field set.
@@ -92,11 +92,12 @@
response.set_method_id(method->id());
- ServerContext context(channel, *service, *method);
-
span<byte> response_buffer = request.PayloadUsableSpace(buffer);
+
+ internal::ServerCall call(*this, channel, *service, *method);
StatusWithSize result =
- method->Invoke(context, request.payload(), response_buffer);
+ method->Invoke(call, request.payload(), response_buffer);
+
response.set_status(result.status());
response.set_payload(response_buffer.first(result.size()));
}
diff --git a/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h b/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h
index ecd8f02..4c91255 100644
--- a/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h
+++ b/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h
@@ -29,10 +29,10 @@
public:
constexpr Method(uint32_t id) : BaseMethod(id), last_channel_id_(0) {}
- StatusWithSize Invoke(ServerContext& context,
+ StatusWithSize Invoke(ServerCall& call,
span<const std::byte> request,
span<std::byte> payload_buffer) const {
- last_channel_id_ = context.channel_id();
+ last_channel_id_ = call.channel_id();
last_request_ = request;
last_payload_buffer_ = payload_buffer;