pw_protobuf: Reimplement decoder as field iterator This changes the pw::protobuf::Decoder API to expose functions to iterate over fields instead of using a virtual callback interface, making the core decoder simpler and more flexible. The virtual callback interface is kept but renamed to a CallbackDecoder, and reimplemented in terms of basic Decoder. Change-Id: Idff321cd5e37184aa730251475c9e336136596d2
diff --git a/pw_protobuf/decoder.cc b/pw_protobuf/decoder.cc index 54d4675..3183556 100644 --- a/pw_protobuf/decoder.cc +++ b/pw_protobuf/decoder.cc
@@ -20,75 +20,178 @@ namespace pw::protobuf { -Status Decoder::Decode(span<const std::byte> proto) { - if (handler_ == nullptr || state_ != kReady) { - return Status::FAILED_PRECONDITION; - } - - state_ = kDecodeInProgress; - proto_ = proto; - - // Iterate over each field in the proto, calling the handler with the field - // key. - while (state_ == kDecodeInProgress && !proto_.empty()) { - const std::byte* original_cursor = proto_.data(); - - uint64_t key; - size_t bytes_read = varint::Decode(proto_, &key); - if (bytes_read == 0) { - state_ = kDecodeFailed; - return Status::DATA_LOSS; - } - - uint32_t field_number = key >> kFieldNumberShift; - Status status = handler_->ProcessField(this, field_number); - if (!status.ok()) { - state_ = status == Status::CANCELLED ? kDecodeCancelled : kDecodeFailed; +Status Decoder::Next() { + if (!previous_field_consumed_) { + if (Status status = SkipField(); !status.ok()) { return status; } + } + if (proto_.empty()) { + return Status::OUT_OF_RANGE; + } + previous_field_consumed_ = false; + return FieldSize() == 0 ? Status::DATA_LOSS : Status::OK; +} - // The callback function can modify the decoder's state; check that - // everything is still okay. - if (state_ == kDecodeFailed) { - break; - } - - // If the cursor has not moved, the user has not consumed the field in their - // callback. Skip ahead to the next field. - if (original_cursor == proto_.data()) { - SkipField(); - } +Status Decoder::SkipField() { + if (proto_.empty()) { + return Status::OUT_OF_RANGE; } - if (state_ != kDecodeInProgress) { + size_t bytes_to_skip = FieldSize(); + if (bytes_to_skip == 0) { return Status::DATA_LOSS; } - state_ = kReady; + proto_ = proto_.subspan(bytes_to_skip); + return proto_.empty() ? Status::OUT_OF_RANGE : Status::OK; +} + +uint32_t Decoder::FieldNumber() const { + uint64_t key; + varint::Decode(proto_, &key); + return key >> kFieldNumberShift; +} + +Status Decoder::ReadUint32(uint32_t* out) { + uint64_t value = 0; + Status status = ReadUint64(&value); + if (!status.ok()) { + return status; + } + if (value > std::numeric_limits<uint32_t>::max()) { + return Status::OUT_OF_RANGE; + } + *out = value; return Status::OK; } -Status Decoder::ReadVarint(uint32_t field_number, uint64_t* out) { - Status status = ConsumeKey(field_number, WireType::kVarint); +Status Decoder::ReadSint32(int32_t* out) { + int64_t value = 0; + Status status = ReadSint64(&value); if (!status.ok()) { return status; } + if (value > std::numeric_limits<int32_t>::max()) { + return Status::OUT_OF_RANGE; + } + *out = value; + return Status::OK; +} + +Status Decoder::ReadSint64(int64_t* out) { + uint64_t value = 0; + Status status = ReadUint64(&value); + if (!status.ok()) { + return status; + } + *out = varint::ZigZagDecode(value); + return Status::OK; +} + +Status Decoder::ReadBool(bool* out) { + uint64_t value = 0; + Status status = ReadUint64(&value); + if (!status.ok()) { + return status; + } + *out = value; + return Status::OK; +} + +Status Decoder::ReadString(std::string_view* out) { + span<const std::byte> bytes; + Status status = ReadDelimited(&bytes); + if (!status.ok()) { + return status; + } + *out = std::string_view(reinterpret_cast<const char*>(bytes.data()), + bytes.size()); + return Status::OK; +} + +size_t Decoder::FieldSize() const { + uint64_t key; + size_t key_size = varint::Decode(proto_, &key); + if (key_size == 0) { + return 0; + } + + span<const std::byte> remainder = proto_.subspan(key_size); + WireType wire_type = static_cast<WireType>(key & kWireTypeMask); + uint64_t value = 0; + size_t expected_size = 0; + + switch (wire_type) { + case WireType::kVarint: + expected_size = varint::Decode(remainder, &value); + if (expected_size == 0) { + return 0; + } + break; + + case WireType::kDelimited: + // Varint at cursor indicates size of the field. + expected_size = varint::Decode(remainder, &value); + if (expected_size == 0) { + return 0; + } + expected_size += value; + break; + + case WireType::kFixed32: + expected_size = sizeof(uint32_t); + break; + + case WireType::kFixed64: + expected_size = sizeof(uint64_t); + break; + } + + if (remainder.size() < expected_size) { + return 0; + } + + return key_size + expected_size; +} + +Status Decoder::ConsumeKey(WireType expected_type) { + uint64_t key; + size_t bytes_read = varint::Decode(proto_, &key); + if (bytes_read == 0) { + return Status::FAILED_PRECONDITION; + } + + WireType wire_type = static_cast<WireType>(key & kWireTypeMask); + if (wire_type != expected_type) { + return Status::FAILED_PRECONDITION; + } + + // Advance past the key. + proto_ = proto_.subspan(bytes_read); + return Status::OK; +} + +Status Decoder::ReadVarint(uint64_t* out) { + if (Status status = ConsumeKey(WireType::kVarint); !status.ok()) { + return status; + } size_t bytes_read = varint::Decode(proto_, out); if (bytes_read == 0) { - state_ = kDecodeFailed; return Status::DATA_LOSS; } // Advance to the next field. proto_ = proto_.subspan(bytes_read); + previous_field_consumed_ = true; return Status::OK; } -Status Decoder::ReadFixed(uint32_t field_number, std::byte* out, size_t size) { +Status Decoder::ReadFixed(std::byte* out, size_t size) { WireType expected_wire_type = size == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64; - Status status = ConsumeKey(field_number, expected_wire_type); + Status status = ConsumeKey(expected_wire_type); if (!status.ok()) { return status; } @@ -99,13 +202,13 @@ std::memcpy(out, proto_.data(), size); proto_ = proto_.subspan(size); + previous_field_consumed_ = true; return Status::OK; } -Status Decoder::ReadDelimited(uint32_t field_number, - span<const std::byte>* out) { - Status status = ConsumeKey(field_number, WireType::kDelimited); +Status Decoder::ReadDelimited(span<const std::byte>* out) { + Status status = ConsumeKey(WireType::kDelimited); if (!status.ok()) { return status; } @@ -113,80 +216,60 @@ uint64_t length; size_t bytes_read = varint::Decode(proto_, &length); if (bytes_read == 0) { - state_ = kDecodeFailed; return Status::DATA_LOSS; } proto_ = proto_.subspan(bytes_read); if (proto_.size() < length) { - state_ = kDecodeFailed; return Status::DATA_LOSS; } *out = proto_.first(length); proto_ = proto_.subspan(length); + previous_field_consumed_ = true; return Status::OK; } -Status Decoder::ConsumeKey(uint32_t field_number, WireType expected_type) { +Status CallbackDecoder::Decode(span<const std::byte> proto) { + if (handler_ == nullptr || state_ != kReady) { + return Status::FAILED_PRECONDITION; + } + + state_ = kDecodeInProgress; + decoder_.Reset(proto); + + // Iterate the proto, calling the handler with each field number. + while (state_ == kDecodeInProgress) { + if (Status status = decoder_.Next(); !status.ok()) { + if (status == Status::OUT_OF_RANGE) { + // Reached the end of the proto. + break; + } + + // Proto data is malformed. + return status; + } + + Status status = handler_->ProcessField(*this, decoder_.FieldNumber()); + if (!status.ok()) { + state_ = status == Status::CANCELLED ? kDecodeCancelled : kDecodeFailed; + return status; + } + + // The callback function can modify the decoder's state; check that + // everything is still okay. + if (state_ == kDecodeFailed) { + break; + } + } + if (state_ != kDecodeInProgress) { - return Status::FAILED_PRECONDITION; + return Status::DATA_LOSS; } - uint64_t key; - size_t bytes_read = varint::Decode(proto_, &key); - if (bytes_read == 0) { - state_ = kDecodeFailed; - return Status::FAILED_PRECONDITION; - } - - uint32_t field = key >> kFieldNumberShift; - WireType wire_type = static_cast<WireType>(key & kWireTypeMask); - - if (field != field_number || wire_type != expected_type) { - state_ = kDecodeFailed; - return Status::FAILED_PRECONDITION; - } - - // Advance past the key. - proto_ = proto_.subspan(bytes_read); + state_ = kReady; return Status::OK; } -void Decoder::SkipField() { - uint64_t key; - proto_ = proto_.subspan(varint::Decode(proto_, &key)); - - WireType wire_type = static_cast<WireType>(key & kWireTypeMask); - size_t bytes_to_skip = 0; - uint64_t value = 0; - - switch (wire_type) { - case WireType::kVarint: - bytes_to_skip = varint::Decode(proto_, &value); - break; - - case WireType::kDelimited: - // Varint at cursor indicates size of the field. - bytes_to_skip += varint::Decode(proto_, &value); - bytes_to_skip += value; - break; - - case WireType::kFixed32: - bytes_to_skip = sizeof(uint32_t); - break; - - case WireType::kFixed64: - bytes_to_skip = sizeof(uint64_t); - break; - } - - if (bytes_to_skip == 0) { - state_ = kDecodeFailed; - } else { - proto_ = proto_.subspan(bytes_to_skip); - } -} - } // namespace pw::protobuf
diff --git a/pw_protobuf/decoder_test.cc b/pw_protobuf/decoder_test.cc index c655b09..abc9d54 100644 --- a/pw_protobuf/decoder_test.cc +++ b/pw_protobuf/decoder_test.cc
@@ -22,27 +22,28 @@ class TestDecodeHandler : public DecodeHandler { public: - Status ProcessField(Decoder* decoder, uint32_t field_number) override { + Status ProcessField(CallbackDecoder& decoder, + uint32_t field_number) override { std::string_view str; switch (field_number) { case 1: - decoder->ReadInt32(field_number, &test_int32); + decoder.ReadInt32(&test_int32); break; case 2: - decoder->ReadSint32(field_number, &test_sint32); + decoder.ReadSint32(&test_sint32); break; case 3: - decoder->ReadBool(field_number, &test_bool); + decoder.ReadBool(&test_bool); break; case 4: - decoder->ReadDouble(field_number, &test_double); + decoder.ReadDouble(&test_double); break; case 5: - decoder->ReadFixed32(field_number, &test_fixed32); + decoder.ReadFixed32(&test_fixed32); break; case 6: - decoder->ReadString(field_number, &str); + decoder.ReadString(&str); std::memcpy(test_string, str.data(), str.size()); test_string[str.size()] = '\0'; break; @@ -62,7 +63,101 @@ }; TEST(Decoder, Decode) { - Decoder decoder; + // clang-format off + uint8_t encoded_proto[] = { + // type=int32, k=1, v=42 + 0x08, 0x2a, + // type=sint32, k=2, v=-13 + 0x10, 0x19, + // type=bool, k=3, v=false + 0x18, 0x00, + // type=double, k=4, v=3.14159 + 0x21, 0x6e, 0x86, 0x1b, 0xf0, 0xf9, 0x21, 0x09, 0x40, + // type=fixed32, k=5, v=0xdeadbeef + 0x2d, 0xef, 0xbe, 0xad, 0xde, + // type=string, k=6, v="Hello world" + 0x32, 0x0b, 'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', + }; + // clang-format on + + Decoder decoder(as_bytes(span(encoded_proto))); + + int32_t v1 = 0; + EXPECT_EQ(decoder.Next(), Status::OK); + ASSERT_EQ(decoder.FieldNumber(), 1u); + EXPECT_EQ(decoder.ReadInt32(&v1), Status::OK); + EXPECT_EQ(v1, 42); + + int32_t v2 = 0; + EXPECT_EQ(decoder.Next(), Status::OK); + ASSERT_EQ(decoder.FieldNumber(), 2u); + EXPECT_EQ(decoder.ReadSint32(&v2), Status::OK); + EXPECT_EQ(v2, -13); + + bool v3 = true; + EXPECT_EQ(decoder.Next(), Status::OK); + ASSERT_EQ(decoder.FieldNumber(), 3u); + EXPECT_EQ(decoder.ReadBool(&v3), Status::OK); + EXPECT_FALSE(v3); + + double v4 = 0; + EXPECT_EQ(decoder.Next(), Status::OK); + ASSERT_EQ(decoder.FieldNumber(), 4u); + EXPECT_EQ(decoder.ReadDouble(&v4), Status::OK); + EXPECT_EQ(v4, 3.14159); + + uint32_t v5 = 0; + EXPECT_EQ(decoder.Next(), Status::OK); + ASSERT_EQ(decoder.FieldNumber(), 5u); + EXPECT_EQ(decoder.ReadFixed32(&v5), Status::OK); + EXPECT_EQ(v5, 0xdeadbeef); + + std::string_view v6; + char buffer[16]; + EXPECT_EQ(decoder.Next(), Status::OK); + ASSERT_EQ(decoder.FieldNumber(), 6u); + EXPECT_EQ(decoder.ReadString(&v6), Status::OK); + std::memcpy(buffer, v6.data(), v6.size()); + buffer[v6.size()] = '\0'; + EXPECT_STREQ(buffer, "Hello world"); + + EXPECT_EQ(decoder.Next(), Status::OUT_OF_RANGE); +} + +TEST(Decoder, Decode_SkipsUnusedFields) { + // clang-format off + uint8_t encoded_proto[] = { + // type=int32, k=1, v=42 + 0x08, 0x2a, + // type=sint32, k=2, v=-13 + 0x10, 0x19, + // type=bool, k=3, v=false + 0x18, 0x00, + // type=double, k=4, v=3.14159 + 0x21, 0x6e, 0x86, 0x1b, 0xf0, 0xf9, 0x21, 0x09, 0x40, + // type=fixed32, k=5, v=0xdeadbeef + 0x2d, 0xef, 0xbe, 0xad, 0xde, + // type=string, k=6, v="Hello world" + 0x32, 0x0b, 'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', + }; + // clang-format on + + Decoder decoder(as_bytes(span(encoded_proto))); + + // Don't process any fields except for the fourth. Next should still iterate + // correctly despite field values not being consumed. + EXPECT_EQ(decoder.Next(), Status::OK); + EXPECT_EQ(decoder.Next(), Status::OK); + EXPECT_EQ(decoder.Next(), Status::OK); + EXPECT_EQ(decoder.Next(), Status::OK); + ASSERT_EQ(decoder.FieldNumber(), 4u); + EXPECT_EQ(decoder.Next(), Status::OK); + EXPECT_EQ(decoder.Next(), Status::OK); + EXPECT_EQ(decoder.Next(), Status::OUT_OF_RANGE); +} + +TEST(CallbackDecoder, Decode) { + CallbackDecoder decoder; TestDecodeHandler handler; // clang-format off @@ -93,8 +188,8 @@ EXPECT_STREQ(handler.test_string, "Hello world"); } -TEST(Decoder, Decode_OverridesDuplicateFields) { - Decoder decoder; +TEST(CallbackDecoder, Decode_OverridesDuplicateFields) { + CallbackDecoder decoder; TestDecodeHandler handler; // clang-format off @@ -114,8 +209,8 @@ EXPECT_EQ(handler.test_int32, 44); } -TEST(Decoder, Decode_Empty) { - Decoder decoder; +TEST(CallbackDecoder, Decode_Empty) { + CallbackDecoder decoder; TestDecodeHandler handler; decoder.set_handler(&handler); @@ -125,8 +220,8 @@ EXPECT_EQ(handler.test_sint32, 0); } -TEST(Decoder, Decode_BadData) { - Decoder decoder; +TEST(CallbackDecoder, Decode_BadData) { + CallbackDecoder decoder; TestDecodeHandler handler; // Field key without a value. @@ -139,13 +234,14 @@ // Only processes fields numbered 1 or 3. class OneThreeDecodeHandler : public DecodeHandler { public: - Status ProcessField(Decoder* decoder, uint32_t field_number) override { + Status ProcessField(CallbackDecoder& decoder, + uint32_t field_number) override { switch (field_number) { case 1: - EXPECT_EQ(decoder->ReadInt32(field_number, &field_one), Status::OK); + EXPECT_EQ(decoder.ReadInt32(&field_one), Status::OK); break; case 3: - EXPECT_EQ(decoder->ReadInt32(field_number, &field_three), Status::OK); + EXPECT_EQ(decoder.ReadInt32(&field_three), Status::OK); break; default: // Do nothing. @@ -161,8 +257,8 @@ int32_t field_three = 0; }; -TEST(Decoder, Decode_SkipsUnprocessedFields) { - Decoder decoder; +TEST(CallbackDecoder, Decode_SkipsUnprocessedFields) { + CallbackDecoder decoder; OneThreeDecodeHandler handler; // clang-format off @@ -192,16 +288,17 @@ EXPECT_EQ(handler.field_three, 99); } -// Only processes fields numbered 1 or 3. +// Only processes fields numbered 1 or 3, and stops the decode after hitting 1. class ExitOnOneDecoder : public DecodeHandler { public: - Status ProcessField(Decoder* decoder, uint32_t field_number) override { + Status ProcessField(CallbackDecoder& decoder, + uint32_t field_number) override { switch (field_number) { case 1: - EXPECT_EQ(decoder->ReadInt32(field_number, &field_one), Status::OK); + EXPECT_EQ(decoder.ReadInt32(&field_one), Status::OK); return Status::CANCELLED; case 3: - EXPECT_EQ(decoder->ReadInt32(field_number, &field_three), Status::OK); + EXPECT_EQ(decoder.ReadInt32(&field_three), Status::OK); break; default: // Do nothing. @@ -215,8 +312,8 @@ int32_t field_three = 1111; }; -TEST(Decoder, Decode_StopsOnNonOkStatus) { - Decoder decoder; +TEST(CallbackDecoder, Decode_StopsOnNonOkStatus) { + CallbackDecoder decoder; ExitOnOneDecoder handler; // clang-format off
diff --git a/pw_protobuf/find.cc b/pw_protobuf/find.cc index b98c4b9..70daabb 100644 --- a/pw_protobuf/find.cc +++ b/pw_protobuf/find.cc
@@ -16,7 +16,7 @@ namespace pw::protobuf { -Status FindDecodeHandler::ProcessField(Decoder* decoder, +Status FindDecodeHandler::ProcessField(CallbackDecoder& decoder, uint32_t field_number) { if (field_number != field_number_) { // Continue to the next field. @@ -29,12 +29,11 @@ } span<const std::byte> submessage; - if (Status status = decoder->ReadBytes(field_number, &submessage); - !status.ok()) { + if (Status status = decoder.ReadBytes(&submessage); !status.ok()) { return status; } - Decoder subdecoder; + CallbackDecoder subdecoder; subdecoder.set_handler(nested_handler_); return subdecoder.Decode(submessage); }
diff --git a/pw_protobuf/find_test.cc b/pw_protobuf/find_test.cc index 860e338..2a7edd6 100644 --- a/pw_protobuf/find_test.cc +++ b/pw_protobuf/find_test.cc
@@ -41,7 +41,7 @@ }; TEST(FindDecodeHandler, SingleLevel_FindsExistingField) { - Decoder decoder; + CallbackDecoder decoder; FindDecodeHandler finder(3); decoder.set_handler(&finder); @@ -52,7 +52,7 @@ } TEST(FindDecodeHandler, SingleLevel_DoesntFindNonExistingField) { - Decoder decoder; + CallbackDecoder decoder; FindDecodeHandler finder(8); decoder.set_handler(&finder); @@ -63,7 +63,7 @@ } TEST(FindDecodeHandler, MultiLevel_FindsExistingNestedField) { - Decoder decoder; + CallbackDecoder decoder; FindDecodeHandler nested_finder(1); FindDecodeHandler finder(7, &nested_finder); @@ -76,7 +76,7 @@ } TEST(FindDecodeHandler, MultiLevel_DoesntFindNonExistingNestedField) { - Decoder decoder; + CallbackDecoder decoder; FindDecodeHandler nested_finder(3); FindDecodeHandler finder(7, &nested_finder);
diff --git a/pw_protobuf/public/pw_protobuf/decoder.h b/pw_protobuf/public/pw_protobuf/decoder.h index a3aca75..184c38f 100644 --- a/pw_protobuf/public/pw_protobuf/decoder.h +++ b/pw_protobuf/public/pw_protobuf/decoder.h
@@ -21,10 +21,8 @@ #include "pw_varint/varint.h" // This file defines a low-level event-based protobuf wire format decoder. -// The decoder processes an encoded message by iterating over its fields and -// notifying a handler for each field it encounters. The handler receives a -// reference to the decoder object and can extract the field's value from the -// message. +// The decoder processes an encoded message by iterating over its fields. The +// caller can extract the values of any fields it cares about. // // The decoder does not provide any in-memory data structures to represent a // protobuf message's data. More sophisticated APIs can be built on top of the @@ -32,17 +30,161 @@ // // Example usage: // +// Decoder decoder(proto); +// while (decoder.Next().ok()) { +// switch (decoder.FieldNumber()) { +// case 1: +// decoder.ReadUint32(&my_uint32); +// break; +// // ... and other fields. +// } +// } +// +namespace pw::protobuf { + +class Decoder { + public: + constexpr Decoder(span<const std::byte> proto) + : proto_(proto), previous_field_consumed_(true) {} + + Decoder(const Decoder& other) = delete; + Decoder& operator=(const Decoder& other) = delete; + + // Advances to the next field in the proto. + // + // If Next() returns OK, there is guaranteed to be a valid protobuf field at + // the current cursor position. + // + // Return values: + // + // OK: Advanced to a valid proto field. + // OUT_OF_RANGE: Reached the end of the proto message. + // DATA_LOSS: Invalid protobuf data. + // + Status Next(); + + // Returns the field number of the field at the current cursor position. + uint32_t FieldNumber() const; + + // Reads a proto int32 value from the current cursor. + Status ReadInt32(int32_t* out) { + return ReadUint32(reinterpret_cast<uint32_t*>(out)); + } + + // Reads a proto uint32 value from the current cursor. + Status ReadUint32(uint32_t* out); + + // Reads a proto int64 value from the current cursor. + Status ReadInt64(int64_t* out) { + return ReadVarint(reinterpret_cast<uint64_t*>(out)); + } + + // Reads a proto uint64 value from the current cursor. + Status ReadUint64(uint64_t* out) { return ReadVarint(out); } + + // Reads a proto sint32 value from the current cursor. + Status ReadSint32(int32_t* out); + + // Reads a proto sint64 value from the current cursor. + Status ReadSint64(int64_t* out); + + // Reads a proto bool value from the current cursor. + Status ReadBool(bool* out); + + // Reads a proto fixed32 value from the current cursor. + Status ReadFixed32(uint32_t* out) { return ReadFixed(out); } + + // Reads a proto fixed64 value from the current cursor. + Status ReadFixed64(uint64_t* out) { return ReadFixed(out); } + + // Reads a proto sfixed32 value from the current cursor. + Status ReadSfixed32(int32_t* out) { + return ReadFixed32(reinterpret_cast<uint32_t*>(out)); + } + + // Reads a proto sfixed64 value from the current cursor. + Status ReadSfixed64(int64_t* out) { + return ReadFixed64(reinterpret_cast<uint64_t*>(out)); + } + + // Reads a proto float value from the current cursor. + Status ReadFloat(float* out) { + static_assert(sizeof(float) == sizeof(uint32_t), + "Float and uint32_t must be the same size for protobufs"); + return ReadFixed(out); + } + + // Reads a proto double value from the current cursor. + Status ReadDouble(double* out) { + static_assert(sizeof(double) == sizeof(uint64_t), + "Double and uint64_t must be the same size for protobufs"); + return ReadFixed(out); + } + + // Reads a proto string value from the current cursor and returns a view of it + // in `out`. The raw protobuf data must outlive `out`. If the string field is + // invalid, `out` is not modified. + Status ReadString(std::string_view* out); + + // Reads a proto bytes value from the current cursor and returns a view of it + // in `out`. The raw protobuf data must outlive the `out` span. If the bytes + // field is invalid, `out` is not modified. + Status ReadBytes(span<const std::byte>* out) { return ReadDelimited(out); } + + // Resets the decoder to start reading a new proto message. + void Reset(span<const std::byte> proto) { + proto_ = proto; + previous_field_consumed_ = true; + } + + private: + // Advances the cursor to the next field in the proto. + Status SkipField(); + + // Returns the size of the current field, or 0 if the field is invalid. + size_t FieldSize() const; + + Status ConsumeKey(WireType expected_type); + + // Reads a varint key-value pair from the current cursor position. + Status ReadVarint(uint64_t* out); + + // Reads a fixed-size key-value pair from the current cursor position. + Status ReadFixed(std::byte* out, size_t size); + + template <typename T> + Status ReadFixed(T* out) { + static_assert( + sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t), + "Protobuf fixed-size fields must be 32- or 64-bit"); + return ReadFixed(reinterpret_cast<std::byte*>(out), sizeof(T)); + } + + Status ReadDelimited(span<const std::byte>* out); + + span<const std::byte> proto_; + bool previous_field_consumed_; +}; + +class DecodeHandler; + +// A protobuf decoder that iterates over an encoded protobuf, calling a handler +// for each field it encounters. +// +// Example usage: +// // class FooProtoHandler : public DecodeHandler { // public: -// Status ProcessField(Decoder* decoder, uint32_t field_number) override { +// Status ProcessField(CallbackDecoder& decoder, +// uint32_t field_number) override { // switch (field_number) { // case FooFields::kBar: -// if (!decoder->ReadSint32(field_number, &bar).ok()) { +// if (!decoder.ReadSint32(&bar).ok()) { // bar = 0; // } // break; // case FooFields::kBaz: -// if (!decoder->ReadUint32(field_number, &baz).ok()) { +// if (!decoder.ReadUint32(&baz).ok()) { // baz = 0; // } // break; @@ -68,19 +210,13 @@ // handler.bar, handler.baz); // } // - -namespace pw::protobuf { - -class DecodeHandler; - -// A protobuf decoder that iterates over an encoded protobuf, calling a handler -// for each field it encounters. -class Decoder { +class CallbackDecoder { public: - constexpr Decoder() : handler_(nullptr), state_(kReady) {} + constexpr CallbackDecoder() + : decoder_({}), handler_(nullptr), state_(kReady) {} - Decoder(const Decoder& other) = delete; - Decoder& operator=(const Decoder& other) = delete; + CallbackDecoder(const CallbackDecoder& other) = delete; + CallbackDecoder& operator=(const CallbackDecoder& other) = delete; void set_handler(DecodeHandler* handler) { handler_ = handler; } @@ -89,123 +225,54 @@ Status Decode(span<const std::byte> proto); // Reads a proto int32 value from the current cursor. - Status ReadInt32(uint32_t field_number, int32_t* out) { - return ReadUint32(field_number, reinterpret_cast<uint32_t*>(out)); - } + Status ReadInt32(int32_t* out) { return decoder_.ReadInt32(out); } // Reads a proto uint32 value from the current cursor. - Status ReadUint32(uint32_t field_number, uint32_t* out) { - uint64_t value = 0; - Status status = ReadUint64(field_number, &value); - if (!status.ok()) { - return status; - } - if (value > std::numeric_limits<uint32_t>::max()) { - return Status::OUT_OF_RANGE; - } - *out = value; - return Status::OK; - } + Status ReadUint32(uint32_t* out) { return decoder_.ReadUint32(out); } // Reads a proto int64 value from the current cursor. - Status ReadInt64(uint32_t field_number, int64_t* out) { - return ReadVarint(field_number, reinterpret_cast<uint64_t*>(out)); - } + Status ReadInt64(int64_t* out) { return decoder_.ReadInt64(out); } // Reads a proto uint64 value from the current cursor. - Status ReadUint64(uint32_t field_number, uint64_t* out) { - return ReadVarint(field_number, out); - } - - // Reads a proto sint32 value from the current cursor. - Status ReadSint32(uint32_t field_number, int32_t* out) { - int64_t value = 0; - Status status = ReadSint64(field_number, &value); - if (!status.ok()) { - return status; - } - if (value > std::numeric_limits<int32_t>::max()) { - return Status::OUT_OF_RANGE; - } - *out = value; - return Status::OK; - } + Status ReadUint64(uint64_t* out) { return decoder_.ReadUint64(out); } // Reads a proto sint64 value from the current cursor. - Status ReadSint64(uint32_t field_number, int64_t* out) { - uint64_t value = 0; - Status status = ReadUint64(field_number, &value); - if (!status.ok()) { - return status; - } - *out = varint::ZigZagDecode(value); - return Status::OK; - } + Status ReadSint32(int32_t* out) { return decoder_.ReadSint32(out); } + + // Reads a proto sint64 value from the current cursor. + Status ReadSint64(int64_t* out) { return decoder_.ReadSint64(out); } // Reads a proto bool value from the current cursor. - Status ReadBool(uint32_t field_number, bool* out) { - uint64_t value = 0; - Status status = ReadUint64(field_number, &value); - if (!status.ok()) { - return status; - } - *out = value; - return Status::OK; - } + Status ReadBool(bool* out) { return decoder_.ReadBool(out); } // Reads a proto fixed32 value from the current cursor. - Status ReadFixed32(uint32_t field_number, uint32_t* out) { - return ReadFixed(field_number, out); - } + Status ReadFixed32(uint32_t* out) { return decoder_.ReadFixed32(out); } // Reads a proto fixed64 value from the current cursor. - Status ReadFixed64(uint32_t field_number, uint64_t* out) { - return ReadFixed(field_number, out); - } + Status ReadFixed64(uint64_t* out) { return decoder_.ReadFixed64(out); } // Reads a proto sfixed32 value from the current cursor. - Status ReadSfixed32(uint32_t field_number, int32_t* out) { - return ReadFixed32(field_number, reinterpret_cast<uint32_t*>(out)); - } + Status ReadSfixed32(int32_t* out) { return decoder_.ReadSfixed32(out); } // Reads a proto sfixed64 value from the current cursor. - Status ReadSfixed64(uint32_t field_number, int64_t* out) { - return ReadFixed64(field_number, reinterpret_cast<uint64_t*>(out)); - } + Status ReadSfixed64(int64_t* out) { return decoder_.ReadSfixed64(out); } // Reads a proto float value from the current cursor. - Status ReadFloat(uint32_t field_number, float* out) { - static_assert(sizeof(float) == sizeof(uint32_t), - "Float and uint32_t must be the same size for protobufs"); - return ReadFixed(field_number, out); - } + Status ReadFloat(float* out) { return decoder_.ReadFloat(out); } // Reads a proto double value from the current cursor. - Status ReadDouble(uint32_t field_number, double* out) { - static_assert(sizeof(double) == sizeof(uint64_t), - "Double and uint64_t must be the same size for protobufs"); - return ReadFixed(field_number, out); - } + Status ReadDouble(double* out) { return decoder_.ReadDouble(out); } // Reads a proto string value from the current cursor and returns a view of it // in `out`. The raw protobuf data must outlive `out`. If the string field is // invalid, `out` is not modified. - Status ReadString(uint32_t field_number, std::string_view* out) { - span<const std::byte> bytes; - Status status = ReadDelimited(field_number, &bytes); - if (!status.ok()) { - return status; - } - *out = std::string_view(reinterpret_cast<const char*>(bytes.data()), - bytes.size()); - return Status::OK; - } + Status ReadString(std::string_view* out) { return decoder_.ReadString(out); } // Reads a proto bytes value from the current cursor and returns a view of it // in `out`. The raw protobuf data must outlive the `out` span. If the bytes // field is invalid, `out` is not modified. - Status ReadBytes(uint32_t field_number, span<const std::byte>* out) { - return ReadDelimited(field_number, out); + Status ReadBytes(span<const std::byte>* out) { + return decoder_.ReadBytes(out); } bool cancelled() const { return state_ == kDecodeCancelled; }; @@ -218,43 +285,14 @@ kDecodeFailed, }; - // Reads a varint key-value pair from the current cursor position. - Status ReadVarint(uint32_t field_number, uint64_t* out); - - // Reads a fixed-size key-value pair from the current cursor position. - Status ReadFixed(uint32_t field_number, std::byte* out, size_t size); - - template <typename T> - Status ReadFixed(uint32_t field_number, T* out) { - static_assert( - sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t), - "Protobuf fixed-size fields must be 32- or 64-bit"); - union { - T value; - std::byte bytes[sizeof(T)]; - }; - Status status = ReadFixed(field_number, bytes, sizeof(bytes)); - if (!status.ok()) { - return status; - } - *out = value; - return Status::OK; - } - - Status ReadDelimited(uint32_t field_number, span<const std::byte>* out); - - Status ConsumeKey(uint32_t field_number, WireType expected_type); - - // Advances the cursor to the next field in the proto. - void SkipField(); - + Decoder decoder_; DecodeHandler* handler_; State state_; - span<const std::byte> proto_; }; -// The event-handling interface implemented for a proto decoding operation. +// The event-handling interface implemented for a proto callback decoding +// operation. class DecodeHandler { public: virtual ~DecodeHandler() = default; @@ -266,7 +304,8 @@ // If the status returned is not Status::OK, the decode operation is exited // with the provided status. Returning Status::CANCELLED allows a convenient // way of stopping a decode early (for example, if a desired field is found). - virtual Status ProcessField(Decoder* decoder, uint32_t field_number) = 0; + virtual Status ProcessField(CallbackDecoder& decoder, + uint32_t field_number) = 0; }; } // namespace pw::protobuf
diff --git a/pw_protobuf/public/pw_protobuf/find.h b/pw_protobuf/public/pw_protobuf/find.h index 1e2f3e8..64ad94d 100644 --- a/pw_protobuf/public/pw_protobuf/find.h +++ b/pw_protobuf/public/pw_protobuf/find.h
@@ -29,7 +29,7 @@ constexpr FindDecodeHandler(uint32_t field_number, FindDecodeHandler* nested) : field_number_(field_number), found_(false), nested_handler_(nested) {} - Status ProcessField(Decoder* decoder, uint32_t field_number) override; + Status ProcessField(CallbackDecoder& decoder, uint32_t field_number) override; bool found() const { return found_; }
diff --git a/pw_protobuf/size_report/decoder_full.cc b/pw_protobuf/size_report/decoder_full.cc index d879012..561335e 100644 --- a/pw_protobuf/size_report/decoder_full.cc +++ b/pw_protobuf/size_report/decoder_full.cc
@@ -28,71 +28,38 @@ // clang-format on } // namespace -class TestDecodeHandler : public pw::protobuf::DecodeHandler { - public: - pw::Status ProcessField(pw::protobuf::Decoder* decoder, - uint32_t field_number) override { - std::string_view str; - - switch (field_number) { - case 1: - if (!decoder->ReadInt32(field_number, &test_int32).ok()) { - test_int32 = 0; - } - break; - case 2: - if (!decoder->ReadSint32(field_number, &test_sint32).ok()) { - test_sint32 = 0; - } - break; - case 3: - if (!decoder->ReadBool(field_number, &test_bool).ok()) { - test_bool = false; - } - break; - case 4: - if (!decoder->ReadDouble(field_number, &test_double).ok()) { - test_double = 0; - } - break; - case 5: - if (!decoder->ReadFixed32(field_number, &test_fixed32).ok()) { - test_fixed32 = 0; - } - break; - case 6: - if (decoder->ReadString(field_number, &str).ok()) { - // In real code: - // assert(str.size() < sizeof(test_string)); - std::memcpy(test_string, str.data(), str.size()); - test_string[str.size()] = '\0'; - } - break; - } - - return pw::Status::OK; - } - - int32_t test_int32 = 0; - int32_t test_sint32 = 0; - bool test_bool = false; - double test_double = 0; - uint32_t test_fixed32 = 0; - char test_string[16]; -}; - int* volatile non_optimizable_pointer; int main() { pw::bloat::BloatThisBinary(); - pw::protobuf::Decoder decoder; - TestDecodeHandler handler; + int32_t test_int32, test_sint32; + std::string_view str; + float f; + double d; - decoder.set_handler(&handler); - decoder.Decode(pw::as_bytes(pw::span(encoded_proto))); + pw::protobuf::Decoder decoder(pw::as_bytes(pw::span(encoded_proto))); + while (decoder.Next().ok()) { + switch (decoder.FieldNumber()) { + case 1: + decoder.ReadInt32(&test_int32); + break; + case 2: + decoder.ReadSint32(&test_sint32); + break; + case 3: + decoder.ReadString(&str); + break; + case 4: + decoder.ReadFloat(&f); + break; + case 5: + decoder.ReadDouble(&d); + break; + } + } - *non_optimizable_pointer = handler.test_int32 + handler.test_sint32; + *non_optimizable_pointer = test_int32 + test_sint32; return 0; }
diff --git a/pw_protobuf/size_report/decoder_incremental.cc b/pw_protobuf/size_report/decoder_incremental.cc index ef6f2db..064f579 100644 --- a/pw_protobuf/size_report/decoder_incremental.cc +++ b/pw_protobuf/size_report/decoder_incremental.cc
@@ -28,98 +28,50 @@ // clang-format on } // namespace -class TestDecodeHandler : public pw::protobuf::DecodeHandler { - public: - pw::Status ProcessField(pw::protobuf::Decoder* decoder, - uint32_t field_number) override { - std::string_view str; - - switch (field_number) { - case 1: - if (!decoder->ReadInt32(field_number, &test_int32).ok()) { - test_int32 = 0; - } - break; - case 2: - if (!decoder->ReadSint32(field_number, &test_sint32).ok()) { - test_sint32 = 0; - } - break; - case 3: - if (!decoder->ReadBool(field_number, &test_bool).ok()) { - test_bool = false; - } - break; - case 4: - if (!decoder->ReadDouble(field_number, &test_double).ok()) { - test_double = 0; - } - break; - case 5: - if (!decoder->ReadFixed32(field_number, &test_fixed32).ok()) { - test_fixed32 = 0; - } - break; - case 6: - if (decoder->ReadString(field_number, &str).ok()) { - // In real code: - // assert(str.size() < sizeof(test_string)); - std::memcpy(test_string, str.data(), str.size()); - test_string[str.size()] = '\0'; - } - break; - - // Extra fields. - case 21: - if (!decoder->ReadInt32(field_number, &test_int32).ok()) { - test_int32 = 0; - } - break; - case 22: - if (!decoder->ReadInt32(field_number, &test_int32).ok()) { - test_int32 = 0; - } - break; - case 23: - if (!decoder->ReadInt32(field_number, &test_int32).ok()) { - test_int32 = 0; - } - break; - case 24: - if (!decoder->ReadSint32(field_number, &test_sint32).ok()) { - test_sint32 = 0; - } - break; - case 25: - if (!decoder->ReadSint32(field_number, &test_sint32).ok()) { - test_sint32 = 0; - } - break; - } - - return pw::Status::OK; - } - - int32_t test_int32 = 0; - int32_t test_sint32 = 0; - bool test_bool = false; - double test_double = 0; - uint32_t test_fixed32 = 0; - char test_string[16]; -}; - int* volatile non_optimizable_pointer; int main() { pw::bloat::BloatThisBinary(); - pw::protobuf::Decoder decoder; - TestDecodeHandler handler; + int32_t test_int32, test_sint32; + std::string_view str; + float f; + double d; + uint32_t uint; - decoder.set_handler(&handler); - decoder.Decode(pw::as_bytes(pw::span(encoded_proto))); + pw::protobuf::Decoder decoder(pw::as_bytes(pw::span(encoded_proto))); + while (decoder.Next().ok()) { + switch (decoder.FieldNumber()) { + case 1: + decoder.ReadInt32(&test_int32); + break; + case 2: + decoder.ReadSint32(&test_sint32); + break; + case 3: + decoder.ReadString(&str); + break; + case 4: + decoder.ReadFloat(&f); + break; + case 5: + decoder.ReadDouble(&d); + break; - *non_optimizable_pointer = handler.test_int32 + handler.test_sint32; + // Extra fields over decoder_full. + case 21: + decoder.ReadInt32(&test_int32); + break; + case 22: + decoder.ReadUint32(&uint); + break; + case 23: + decoder.ReadSint32(&test_sint32); + break; + } + } + + *non_optimizable_pointer = test_int32 + test_sint32; return 0; }