pw_rpc: Expand server-side packet processing - Handle various error cases with incoming packets. - Dynamically assign a channel when a packet's ID doesn't exist. - Reserve space in the response buffer for packet "header" fields. Change-Id: Ibdce99c8ff1d37aa46bb4e400a4d8f8e646a8ac7
diff --git a/pw_protobuf/encoder.cc b/pw_protobuf/encoder.cc index 05ce88a..c20b093 100644 --- a/pw_protobuf/encoder.cc +++ b/pw_protobuf/encoder.cc
@@ -56,7 +56,10 @@ return encode_status_; } - memcpy(cursor_, ptr, size); + // Memmove the value into place as it's possible that it shares the encode + // buffer on a memory-constrained system. + std::memmove(cursor_, ptr, size); + cursor_ += size; return Status::OK; } @@ -167,7 +170,7 @@ to_copy = end - read_cursor; } - memmove(write_cursor, read_cursor, to_copy); + std::memmove(write_cursor, read_cursor, to_copy); write_cursor += to_copy; read_cursor += to_copy;
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn index 192a9bc..5226243 100644 --- a/pw_rpc/BUILD.gn +++ b/pw_rpc/BUILD.gn
@@ -62,7 +62,10 @@ } pw_test("server_test") { - deps = [ ":pw_rpc" ] + deps = [ + ":protos_pwpb", + ":pw_rpc", + ] sources = [ "server_test.cc" ] }
diff --git a/pw_rpc/packet.cc b/pw_rpc/packet.cc index 57e6a47..38a9174 100644 --- a/pw_rpc/packet.cc +++ b/pw_rpc/packet.cc
@@ -70,11 +70,13 @@ pw::protobuf::NestedEncoder encoder(buffer); RpcPacket::Encoder rpc_packet(&encoder); + // The payload is encoded first, as it may share the encode buffer. + rpc_packet.WritePayload(payload_); + rpc_packet.WriteType(type_); rpc_packet.WriteChannelId(channel_id_); rpc_packet.WriteServiceId(service_id_); rpc_packet.WriteMethodId(method_id_); - rpc_packet.WritePayload(payload_); rpc_packet.WriteStatus(status_); span<const std::byte> proto;
diff --git a/pw_rpc/packet_test.cc b/pw_rpc/packet_test.cc index 83d4ee8..dd413dd 100644 --- a/pw_rpc/packet_test.cc +++ b/pw_rpc/packet_test.cc
@@ -24,8 +24,7 @@ TEST(Packet, EncodeDecode) { constexpr byte payload[]{byte(0x00), byte(0x01), byte(0x02), byte(0x03)}; - Packet packet = Packet::Empty(); - packet.set_type(PacketType::RPC); + Packet packet = Packet::Empty(PacketType::RPC); packet.set_channel_id(12); packet.set_service_id(0xdeadbeef); packet.set_method_id(0x03a82921);
diff --git a/pw_rpc/public/pw_rpc/channel.h b/pw_rpc/public/pw_rpc/channel.h index 1ef6880..68fceca 100644 --- a/pw_rpc/public/pw_rpc/channel.h +++ b/pw_rpc/public/pw_rpc/channel.h
@@ -15,6 +15,7 @@ #include <cstdint> +#include "pw_assert/assert.h" #include "pw_span/span.h" #include "pw_status/status.h" @@ -40,16 +41,22 @@ class Channel { public: + static constexpr uint32_t kUnassignedChannelId = 0; + // Creates a dynamically assignable channel without a set ID or output. constexpr Channel() : id_(kUnassignedChannelId), output_(nullptr) {} // Creates a channel with a static ID. The channel's output can also be // static, or it can set to null to allow dynamically opening connections // through the channel. - constexpr Channel(uint32_t id, ChannelOutput* output) - : id_(id), output_(output) {} + template <uint32_t id> + static Channel Create(ChannelOutput* output) { + static_assert(id != kUnassignedChannelId, "Channel ID cannot be 0"); + return Channel(id, output); + } constexpr uint32_t id() const { return id_; } + constexpr bool assigned() const { return id_ != kUnassignedChannelId; } span<std::byte> AcquireBuffer() const { return output_->AcquireBuffer(); } void SendAndReleaseBuffer(size_t size) const { @@ -57,7 +64,13 @@ } private: - static constexpr uint32_t kUnassignedChannelId = 0; + friend class Server; + + constexpr Channel(uint32_t id, ChannelOutput* output) + : id_(id), output_(output) { + PW_CHECK_UINT_NE(id, kUnassignedChannelId); + } + uint32_t id_; ChannelOutput* output_; };
diff --git a/pw_rpc/public/pw_rpc/internal/packet.h b/pw_rpc/public/pw_rpc/internal/packet.h index 0ed74b9..2f3a164 100644 --- a/pw_rpc/public/pw_rpc/internal/packet.h +++ b/pw_rpc/public/pw_rpc/internal/packet.h
@@ -29,8 +29,8 @@ static Packet FromBuffer(span<const std::byte> data); // Returns an empty packet with default values set. - static constexpr Packet Empty() { - return Packet(PacketType::RPC, 0, 0, 0, {}, Status::OK); + static constexpr Packet Empty(PacketType type) { + return Packet(type, 0, 0, 0, {}, Status::OK); } // Encodes the packet into its wire format. Returns the encoded size.
diff --git a/pw_rpc/public/pw_rpc/internal/service.h b/pw_rpc/public/pw_rpc/internal/service.h index e457285..377b6c1 100644 --- a/pw_rpc/public/pw_rpc/internal/service.h +++ b/pw_rpc/public/pw_rpc/internal/service.h
@@ -42,7 +42,8 @@ // Handles an incoming packet and populates a response. Errors that occur // should be set within the response packet. void ProcessPacket(const internal::Packet& request, - internal::Packet& response); + internal::Packet& response, + span<std::byte> payload_buffer); private: friend class internal::ServiceRegistry;
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h index 2faf1b8..b2727cd 100644 --- a/pw_rpc/public/pw_rpc/server.h +++ b/pw_rpc/public/pw_rpc/server.h
@@ -42,7 +42,18 @@ using Service = internal::Service; using ServiceRegistry = internal::ServiceRegistry; - Channel* FindChannel(uint32_t id); + void SendResponse(const Channel& channel, + const internal::Packet& response, + span<std::byte> response_buffer) const; + + // Determines the space required to encode the packet proto fields for a + // response, and splits the buffer into reserved space and available space for + // the payload. Returns a subspan of the payload space. + span<std::byte> ResponsePayloadUsableSpace(const internal::Packet& request, + span<std::byte> buffer) const; + + Channel* FindChannel(uint32_t id) const; + Channel* AssignChannel(uint32_t id, ChannelOutput& interface); span<Channel> channels_; ServiceRegistry services_;
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc index fb44de7..badf8a9 100644 --- a/pw_rpc/server.cc +++ b/pw_rpc/server.cc
@@ -20,6 +20,7 @@ namespace pw::rpc { using internal::Packet; +using internal::PacketType; void Server::ProcessPacket(span<const std::byte> data, ChannelOutput& interface) { @@ -29,44 +30,51 @@ return; } - if (packet.service_id() == 0 || packet.method_id() == 0) { + if (packet.channel_id() == Channel::kUnassignedChannelId || + packet.service_id() == 0 || packet.method_id() == 0) { // Malformed packet; don't even try to process it. PW_LOG_ERROR("Received incomplete RPC packet on interface %u", unsigned(interface.id())); return; } + Packet response = Packet::Empty(PacketType::RPC); + Channel* channel = FindChannel(packet.channel_id()); if (channel == nullptr) { - // TODO(frolv): Dynamically assign channel. - return; + // If the requested channel doesn't exist, try to dynamically assign one. + channel = AssignChannel(packet.channel_id(), interface); + if (channel == nullptr) { + // If a channel can't be assigned, send back a response indicating that + // the server cannot process the request. The channel_id in the response + // is not set, to allow clients to detect this error case. + Channel temp_channel(packet.channel_id(), &interface); + response.set_status(Status::RESOURCE_EXHAUSTED); + SendResponse(temp_channel, response, temp_channel.AcquireBuffer()); + return; + } } span<std::byte> response_buffer = channel->AcquireBuffer(); + span<std::byte> payload_buffer = + ResponsePayloadUsableSpace(packet, response_buffer); + + response.set_channel_id(channel->id()); Service* service = services_.Find(packet.service_id()); if (service == nullptr) { - // TODO(frolv): Send back a NOT_FOUND response. - channel->SendAndReleaseBuffer(0); + // Couldn't find the requested service. Reply with a NOT_FOUND response + // without the server_id field set. + response.set_status(Status::NOT_FOUND); + SendResponse(*channel, response, response_buffer); return; } - Packet response = Packet::Empty(); - response.set_channel_id(channel->id()); - - service->ProcessPacket(packet, response); - - StatusWithSize sws = response.Encode(response_buffer); - if (!sws.ok()) { - // TODO(frolv): What should be done here? - channel->SendAndReleaseBuffer(0); - return; - } - - channel->SendAndReleaseBuffer(sws.size()); + service->ProcessPacket(packet, response, payload_buffer); + SendResponse(*channel, response, response_buffer); } -Channel* Server::FindChannel(uint32_t id) { +Channel* Server::FindChannel(uint32_t id) const { for (Channel& c : channels_) { if (c.id() == id) { return &c; @@ -75,4 +83,48 @@ return nullptr; } +Channel* Server::AssignChannel(uint32_t id, ChannelOutput& interface) { + Channel* channel = FindChannel(Channel::kUnassignedChannelId); + if (channel == nullptr) { + return nullptr; + } + + *channel = Channel(id, &interface); + return channel; +} + +void Server::SendResponse(const Channel& channel, + const Packet& response, + span<std::byte> response_buffer) const { + StatusWithSize sws = response.Encode(response_buffer); + if (!sws.ok()) { + // TODO(frolv): What should be done here? + channel.SendAndReleaseBuffer(0); + PW_LOG_ERROR("Failed to encode response packet to channel buffer"); + return; + } + + channel.SendAndReleaseBuffer(sws.size()); +} + +span<std::byte> Server::ResponsePayloadUsableSpace( + const Packet& request, span<std::byte> buffer) const { + size_t reserved_size = 0; + + reserved_size += 1; // channel_id key + reserved_size += varint::EncodedSize(request.channel_id()); + reserved_size += 1; // service_id key + reserved_size += varint::EncodedSize(request.service_id()); + reserved_size += 1; // method_id key + reserved_size += varint::EncodedSize(request.method_id()); + + // Packet type always takes two bytes to encode (varint key + varint enum). + reserved_size += 2; + + // Status field always takes two bytes to encode (varint key + varint status). + reserved_size += 2; + + return buffer.subspan(reserved_size); +} + } // namespace pw::rpc
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc index aae410e..85eecc0 100644 --- a/pw_rpc/server_test.cc +++ b/pw_rpc/server_test.cc
@@ -15,10 +15,13 @@ #include "pw_rpc/server.h" #include "gtest/gtest.h" +#include "pw_rpc/internal/packet.h" namespace pw::rpc { namespace { +using internal::Packet; +using internal::PacketType; using std::byte; template <size_t buffer_size> @@ -39,35 +42,110 @@ span<const byte> sent_packet_; }; -TestOutput<512> output(1); +Packet MakePacket(uint32_t channel_id, + uint32_t service_id, + uint32_t method_id, + span<const byte> payload) { + Packet packet = Packet::Empty(PacketType::RPC); + packet.set_channel_id(channel_id); + packet.set_service_id(service_id); + packet.set_method_id(method_id); + packet.set_payload(payload); + return packet; +} -// clang-format off -constexpr uint8_t encoded_packet[] = { - // type = PacketType::kRpc - 0x08, 0x00, - // channel_id = 1 - 0x10, 0x01, - // service_id = 42 - 0x18, 0x2a, - // method_id = 27 - 0x20, 0x1b, - // payload - 0x82, 0x02, 0xff, 0xff, -}; -// clang-format on - -TEST(Server, DoesStuff) { +TEST(Server, ProcessPacket_SendsResponse) { + TestOutput<128> output(1); Channel channels[] = { - Channel(1, &output), - Channel(2, &output), + Channel::Create<1>(&output), + Channel::Create<2>(&output), }; Server server(channels); internal::Service service(42, {}); server.RegisterService(service); - server.ProcessPacket(as_bytes(span(encoded_packet)), output); - auto packet = output.sent_packet(); - EXPECT_GT(packet.size(), 0u); + byte encoded_packet[64]; + constexpr byte payload[] = {byte(0x82), byte(0x02), byte(0xff), byte(0xff)}; + Packet request = MakePacket(1, 42, 27, payload); + auto sws = request.Encode(encoded_packet); + + server.ProcessPacket(span(encoded_packet, sws.size()), output); + Packet packet = Packet::FromBuffer(output.sent_packet()); + EXPECT_EQ(packet.status(), Status::OK); + EXPECT_EQ(packet.channel_id(), 1u); + EXPECT_EQ(packet.service_id(), 42u); +} + +TEST(Server, ProcessPacket_SendsNotFoundOnInvalidService) { + TestOutput<128> output(1); + Channel channels[] = { + Channel::Create<1>(&output), + Channel::Create<2>(&output), + }; + Server server(channels); + internal::Service service(42, {}); + server.RegisterService(service); + + byte encoded_packet[64]; + constexpr byte payload[] = {byte(0x82), byte(0x02), byte(0xff), byte(0xff)}; + Packet request = MakePacket(1, 43, 27, payload); + auto sws = request.Encode(encoded_packet); + + server.ProcessPacket(span(encoded_packet, sws.size()), output); + Packet packet = Packet::FromBuffer(output.sent_packet()); + EXPECT_EQ(packet.status(), Status::NOT_FOUND); + EXPECT_EQ(packet.channel_id(), 1u); + EXPECT_EQ(packet.service_id(), 0u); +} + +TEST(Server, ProcessPacket_AssignsAnUnassignedChannel) { + TestOutput<128> output(1); + Channel channels[] = { + Channel::Create<1>(&output), + Channel::Create<2>(&output), + Channel(), + }; + Server server(channels); + internal::Service service(42, {}); + server.RegisterService(service); + + byte encoded_packet[64]; + constexpr byte payload[] = {byte(0x82), byte(0x02), byte(0xff), byte(0xff)}; + Packet request = MakePacket(/*channel_id=*/99, 42, 27, payload); + auto sws = request.Encode(encoded_packet); + + TestOutput<128> unassigned_output(2); + server.ProcessPacket(span(encoded_packet, sws.size()), unassigned_output); + ASSERT_EQ(channels[2].id(), 99u); + + Packet packet = Packet::FromBuffer(unassigned_output.sent_packet()); + EXPECT_EQ(packet.status(), Status::OK); + EXPECT_EQ(packet.channel_id(), 99u); + EXPECT_EQ(packet.service_id(), 42u); +} + +TEST(Server, ProcessPacket_SendsResourceExhaustedWhenChannelCantBeAssigned) { + TestOutput<128> output(1); + Channel channels[] = { + Channel::Create<1>(&output), + Channel::Create<2>(&output), + }; + Server server(channels); + internal::Service service(42, {}); + server.RegisterService(service); + + byte encoded_packet[64]; + constexpr byte payload[] = {byte(0x82), byte(0x02), byte(0xff), byte(0xff)}; + Packet request = MakePacket(/*channel_id=*/99, 42, 27, payload); + auto sws = request.Encode(encoded_packet); + + server.ProcessPacket(span(encoded_packet, sws.size()), output); + + Packet packet = Packet::FromBuffer(output.sent_packet()); + EXPECT_EQ(packet.status(), Status::RESOURCE_EXHAUSTED); + EXPECT_EQ(packet.channel_id(), 0u); + EXPECT_EQ(packet.service_id(), 0u); + EXPECT_EQ(packet.method_id(), 0u); } } // namespace
diff --git a/pw_rpc/service.cc b/pw_rpc/service.cc index 996043c..819e6cd 100644 --- a/pw_rpc/service.cc +++ b/pw_rpc/service.cc
@@ -18,8 +18,9 @@ namespace pw::rpc::internal { -void Service::ProcessPacket(const Packet& request, Packet& response) { - response.set_type(PacketType::RPC); +void Service::ProcessPacket(const Packet& request, + Packet& response, + span<std::byte> payload_buffer) { response.set_service_id(id_); for (const Method& method : methods_) { @@ -28,6 +29,8 @@ response.set_method_id(method.id); } } + + (void)payload_buffer; } } // namespace pw::rpc::internal