pw_rpc: Use full service names in Python client
- Use fully qualified names for service IDs (package.Service).
- Use gRPC-style package.Service/Name as the full name for methods.
- Refactor package navigation code from pw_protobuf_compiler to make it
reusable.
- Access services in an RPC client by package name. Without this, there
would be conflicts if there were services with the same name in
different packages (e.g. imu.Calibration, touchpad.Calibration).
Change-Id: Ib3c0c11768e1082bec23b3f8421c271ae86bb3ee
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/14280
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py b/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
index 6dc72d3..5f22802 100644
--- a/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
+++ b/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
@@ -21,7 +21,8 @@
import shlex
import tempfile
from types import ModuleType
-from typing import Dict, Iterable, Iterator, List, Set, Tuple, Union
+from typing import Dict, Generic, Iterable, Iterator, List, NamedTuple, Set
+from typing import Tuple, TypeVar, Union
_LOG = logging.getLogger(__name__)
@@ -137,11 +138,14 @@
yield from compile_and_import(protos, includes, output_dir)
-class _ProtoPackage:
- """Used by the Library class for accessing protocol buffer modules."""
+T = TypeVar('T')
+
+
+class _NestedPackage(Generic[T]):
+ """Facilitates navigating protobuf packages as attributes."""
def __init__(self, package: str):
- self._packages: Dict[str, _ProtoPackage] = {}
- self._modules: List[ModuleType] = []
+ self._packages: Dict[str, _NestedPackage[T]] = {}
+ self._items: List[T] = []
self._package = package
def __getattr__(self, attr: str):
@@ -149,13 +153,59 @@
if attr in self._packages:
return self._packages[attr]
- for module in self._modules:
+ for module in self._items:
if hasattr(module, attr):
return getattr(module, attr)
raise AttributeError(
f'Proto package "{self._package}" does not contain "{attr}"')
+ def __iter__(self) -> Iterator['_NestedPackage[T]']:
+ return iter(self._packages.values())
+
+ def __repr__(self) -> str:
+ return f'_NestedPackage({self._package!r})'
+
+ def __str__(self) -> str:
+ return self._package
+
+
+class Packages(NamedTuple):
+ """Items in a protobuf package structure; returned from as_package."""
+ items_by_package: Dict[str, List]
+ packages: _NestedPackage
+
+
+def as_packages(items: Iterable[Tuple[str, T]],
+ packages: Packages = None) -> Packages:
+ """Places items in a proto-style package structure navigable by attributes.
+
+ Args:
+ items: (package, item) tuples to insert into the package structure
+ packages: if provided, update this Packages instead of creating a new one
+ """
+ if packages is None:
+ packages = Packages({}, _NestedPackage(''))
+
+ for package, item in items:
+ packages.items_by_package.setdefault(package, []).append(item)
+
+ entry = packages.packages
+ subpackages = package.split('.')
+
+ # pylint: disable=protected-access
+ for i, subpackage in enumerate(subpackages, 1):
+ if subpackage not in entry._packages:
+ entry._packages[subpackage] = _NestedPackage('.'.join(
+ subpackages[:i]))
+
+ entry = entry._packages[subpackage]
+
+ entry._items.append(item)
+ # pylint: enable=protected-access
+
+ return packages
+
class Library:
"""A collection of protocol buffer modules sorted by package.
@@ -190,24 +240,9 @@
protos = Library(compile_and_import(list_of_proto_files))
"""
- self.modules_by_package: Dict[str, List[ModuleType]] = {}
- self.packages = _ProtoPackage('')
-
- for module in modules:
- package = module.DESCRIPTOR.package # type: ignore[attr-defined]
- self.modules_by_package.setdefault(package, []).append(module)
-
- entry = self.packages
- subpackages = package.split('.')
-
- for i, subpackage in enumerate(subpackages, 1):
- if subpackage not in entry._packages:
- entry._packages[subpackage] = _ProtoPackage('.'.join(
- subpackages[:i]))
-
- entry = entry._packages[subpackage]
-
- entry._modules.append(module)
+ self.modules_by_package, self.packages = as_packages(
+ (m.DESCRIPTOR.package, m) # type: ignore[attr-defined]
+ for m in modules)
def modules(self) -> Iterable[ModuleType]:
"""Allows iterating over all protobuf modules in this library."""
diff --git a/pw_rpc/py/callback_client_test.py b/pw_rpc/py/callback_client_test.py
index bb46031..d9cbe80 100755
--- a/pw_rpc/py/callback_client_test.py
+++ b/pw_rpc/py/callback_client_test.py
@@ -25,7 +25,7 @@
TEST_PROTO_1 = """\
syntax = "proto3";
-package pw.call.test1;
+package pw.test1;
message SomeMessage {
uint32 magic_number = 1;
@@ -118,7 +118,7 @@
return message
def test_invoke_unary_rpc(self):
- stub = self._client.channel(1).call.PublicService
+ stub = self._client.channel(1).rpcs.pw.test1.PublicService
method = stub.SomeUnary.method
for _ in range(3):
@@ -135,7 +135,7 @@
self.assertEqual('0_o', response.payload)
def test_invoke_unary_rpc_with_callback(self):
- stub = self._client.channel(1).call.PublicService
+ stub = self._client.channel(1).rpcs.pw.test1.PublicService
method = stub.SomeUnary.method
for _ in range(3):
@@ -153,7 +153,7 @@
self._sent_payload(method.request_type).magic_number)
def test_invoke_unary_rpc_callback_errors_suppressed(self):
- stub = self._client.channel(1).call.PublicService.SomeUnary
+ stub = self._client.channel(1).rpcs.pw.test1.PublicService.SomeUnary
self._enqueue_response(1, stub.method)
exception_msg = 'YOU BROKE IT O-]-<'
@@ -169,7 +169,7 @@
self.assertIs(status, Status.UNKNOWN)
def test_invoke_unary_rpc_with_callback_cancel(self):
- stub = self._client.channel(1).call.PublicService
+ stub = self._client.channel(1).rpcs.pw.test1.PublicService
callback = mock.Mock()
for _ in range(3):
@@ -191,7 +191,7 @@
callback.assert_not_called()
def test_invoke_server_streaming(self):
- stub = self._client.channel(1).call.PublicService
+ stub = self._client.channel(1).rpcs.pw.test1.PublicService
method = stub.SomeServerStreaming.method
rep1 = method.response_type(payload='!!!')
@@ -210,7 +210,7 @@
self._sent_payload(method.request_type).magic_number)
def test_invoke_server_streaming_with_callback(self):
- stub = self._client.channel(1).call.PublicService
+ stub = self._client.channel(1).rpcs.pw.test1.PublicService
method = stub.SomeServerStreaming.method
rep1 = method.response_type(payload='!!!')
@@ -235,7 +235,8 @@
self._sent_payload(method.request_type).magic_number)
def test_invoke_server_streaming_with_callback_cancel(self):
- stub = self._client.channel(1).call.PublicService.SomeServerStreaming
+ stub = self._client.channel(
+ 1).rpcs.pw.test1.PublicService.SomeServerStreaming
resp = stub.method.response_type(payload='!!!')
self._enqueue_response(1, stub.method, response=resp)
@@ -263,8 +264,8 @@
])
def test_ignore_bad_packets_with_pending_rpc(self):
- rpcs = self._client.channel(1).call
- method = rpcs.PublicService.SomeUnary.method
+ rpcs = self._client.channel(1).rpcs
+ method = rpcs.pw.test1.PublicService.SomeUnary.method
service_id = method.service.id
# Unknown channel
@@ -276,18 +277,20 @@
# For RPC not pending (valid=True because the packet is processed)
self._enqueue_response(
1,
- ids=(service_id, rpcs.PublicService.SomeBidiStreaming.method.id),
+ ids=(service_id,
+ rpcs.pw.test1.PublicService.SomeBidiStreaming.method.id),
valid=True)
self._enqueue_response(1, method, valid=True)
- status, response = rpcs.PublicService.SomeUnary(magic_number=6)
+ status, response = rpcs.pw.test1.PublicService.SomeUnary(
+ magic_number=6)
self.assertIs(Status.OK, status)
self.assertEqual('', response.payload)
def test_pass_none_if_payload_fails_to_decode(self):
- rpcs = self._client.channel(1).call
- method = rpcs.PublicService.SomeUnary.method
+ rpcs = self._client.channel(1).rpcs
+ method = rpcs.pw.test1.PublicService.SomeUnary.method
self._enqueue_response(1,
method,
@@ -295,7 +298,8 @@
b'INVALID DATA!!!',
valid=True)
- status, response = rpcs.PublicService.SomeUnary(magic_number=6)
+ status, response = rpcs.pw.test1.PublicService.SomeUnary(
+ magic_number=6)
self.assertIs(status, Status.OK)
self.assertIsNone(response)
diff --git a/pw_rpc/py/client_test.py b/pw_rpc/py/client_test.py
index a2486b6..b0e9435 100755
--- a/pw_rpc/py/client_test.py
+++ b/pw_rpc/py/client_test.py
@@ -23,7 +23,7 @@
TEST_PROTO_1 = """\
syntax = "proto3";
-package pw.call.test1;
+package pw.test1;
message SomeMessage {
uint32 magic_number = 1;
@@ -51,7 +51,7 @@
TEST_PROTO_2 = """\
syntax = "proto2";
-package pw.call.test2;
+package pw.test2;
message Request {
optional float magic_number = 1;
@@ -82,54 +82,75 @@
def test_access_service_client_as_attribute_or_index(self):
self.assertIs(
- self._client.channel(1).call.PublicService,
- self._client.channel(1).call['PublicService'])
+ self._client.channel(1).rpcs.pw.test1.PublicService,
+ self._client.channel(1).rpcs['pw.test1.PublicService'])
self.assertIs(
- self._client.channel(1).call.PublicService,
- self._client.channel(1).call[pw_rpc.ids.calculate(
- 'PublicService')])
+ self._client.channel(1).rpcs.pw.test1.PublicService,
+ self._client.channel(1).rpcs[pw_rpc.ids.calculate(
+ 'pw.test1.PublicService')])
def test_access_method_client_as_attribute_or_index(self):
self.assertIs(
- self._client.channel(1).call.Alpha.Unary,
- self._client.channel(1).call['Alpha']['Unary'])
+ self._client.channel(1).rpcs.pw.test2.Alpha.Unary,
+ self._client.channel(1).rpcs['pw.test2.Alpha']['Unary'])
self.assertIs(
- self._client.channel(1).call.Alpha.Unary,
- self._client.channel(1).call['Alpha'][pw_rpc.ids.calculate(
- 'Unary')])
+ self._client.channel(1).rpcs.pw.test2.Alpha.Unary,
+ self._client.channel(1).rpcs['pw.test2.Alpha'][
+ pw_rpc.ids.calculate('Unary')])
+
+ def test_service_name(self):
+ self.assertEqual(
+ self._client.channel(1).rpcs.pw.test2.Alpha.Unary.service.name,
+ 'Alpha')
+ self.assertEqual(
+ self._client.channel(
+ 1).rpcs.pw.test2.Alpha.Unary.service.full_name,
+ 'pw.test2.Alpha')
+
+ def test_method_name(self):
+ self.assertEqual(
+ self._client.channel(1).rpcs.pw.test2.Alpha.Unary.method.name,
+ 'Unary')
+ self.assertEqual(
+ self._client.channel(1).rpcs.pw.test2.Alpha.Unary.method.full_name,
+ 'pw.test2.Alpha/Unary')
def test_check_for_presence_of_services(self):
- self.assertIn('PublicService', self._client.channel(1).call)
- self.assertIn(pw_rpc.ids.calculate('PublicService'),
- self._client.channel(1).call)
- self.assertNotIn('NotAService', self._client.channel(1).call)
- self.assertNotIn(-1213, self._client.channel(1).call)
+ self.assertIn('pw.test1.PublicService', self._client.channel(1).rpcs)
+ self.assertIn(pw_rpc.ids.calculate('pw.test1.PublicService'),
+ self._client.channel(1).rpcs)
+
+ def test_check_for_presence_of_missing_services(self):
+ self.assertNotIn('PublicService', self._client.channel(1).rpcs)
+ self.assertNotIn('NotAService', self._client.channel(1).rpcs)
+ self.assertNotIn(-1213, self._client.channel(1).rpcs)
def test_check_for_presence_of_methods(self):
- self.assertIn('SomeUnary', self._client.channel(1).call.PublicService)
- self.assertIn(pw_rpc.ids.calculate('SomeUnary'),
- self._client.channel(1).call.PublicService)
+ service = self._client.channel(1).rpcs.pw.test1.PublicService
+ self.assertIn('SomeUnary', service)
+ self.assertIn(pw_rpc.ids.calculate('SomeUnary'), service)
- self.assertNotIn('Unary', self._client.channel(1).call.PublicService)
- self.assertNotIn(12345, self._client.channel(1).call.PublicService)
+ def test_check_for_presence_of_missing_methods(self):
+ service = self._client.channel(1).rpcs.pw.test1.PublicService
+ self.assertNotIn('Some', service)
+ self.assertNotIn('Unary', service)
+ self.assertNotIn(12345, service)
def test_method_get_request_with_both_message_and_kwargs(self):
- req = self._client.services['Alpha'].methods['Unary'].request_type()
+ method = self._client.services['pw.test2.Alpha'].methods['Unary']
with self.assertRaisesRegex(TypeError, r'either'):
- self._client.services['Alpha'].methods['Unary'].get_request(
- req, {'magic_number': 1.0})
+ method.get_request(method.request_type(), {'magic_number': 1.0})
def test_method_get_request_with_wrong_type(self):
- with self.assertRaisesRegex(TypeError, r'pw\.call\.test2\.Request'):
- self._client.services['Alpha'].methods['Unary'].get_request(
- 'str!', {})
+ method = self._client.services['pw.test2.Alpha'].methods['Unary']
+ with self.assertRaisesRegex(TypeError, r'pw\.test2\.Request'):
+ method.get_request('a str!', {})
def test_method_get_with_incorrect_message_type(self):
- msg = self._protos.packages.pw.call.test1.AnotherMessage()
- with self.assertRaisesRegex(TypeError,
- r'pw\.call\.test1\.SomeMessage'):
- self._client.services['PublicService'].methods[
+ msg = self._protos.packages.pw.test1.AnotherMessage()
+ with self.assertRaisesRegex(TypeError, r'pw\.test1\.SomeMessage'):
+ self._client.services['pw.test1.PublicService'].methods[
'SomeUnary'].get_request(msg, {})
def test_process_packet_invalid_proto_data(self):
@@ -140,21 +161,20 @@
self._client.process_packet(
packets.encode_request(
(123, 456, 789),
- self._protos.packages.pw.call.test2.Request())))
+ self._protos.packages.pw.test2.Request())))
def test_process_packet_unrecognized_service(self):
self.assertFalse(
self._client.process_packet(
packets.encode_request(
- (1, 456, 789),
- self._protos.packages.pw.call.test2.Request())))
+ (1, 456, 789), self._protos.packages.pw.test2.Request())))
def test_process_packet_unrecognized_method(self):
self.assertFalse(
self._client.process_packet(
packets.encode_request(
(1, next(iter(self._client.services)).id, 789),
- self._protos.packages.pw.call.test2.Request())))
+ self._protos.packages.pw.test2.Request())))
if __name__ == '__main__':
diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py
index cb4b360..ba775a8 100644
--- a/pw_rpc/py/pw_rpc/client.py
+++ b/pw_rpc/py/pw_rpc/client.py
@@ -112,10 +112,10 @@
channel: Channel, methods: Collection[Method]):
super().__init__(
{
- method.name: client_impl.method_client(rpcs, channel, method)
+ method: client_impl.method_client(rpcs, channel, method)
for method in methods
},
- as_attrs=True)
+ as_attrs='members')
class _ServiceClients(descriptors.ServiceAccessor[_MethodClients]):
@@ -124,10 +124,10 @@
services: Collection[Service]):
super().__init__(
{
- s.name: _MethodClients(rpcs, client_impl, channel, s.methods)
+ s: _MethodClients(rpcs, client_impl, channel, s.methods)
for s in services
},
- as_attrs=True)
+ as_attrs='packages')
def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]:
@@ -159,19 +159,19 @@
class ChannelClient:
"""RPC services and methods bound to a particular channel.
- RPCs are invoked from a ChannelClient using its call member. The service and
- method may be selected as attributes or by indexing call with service and
+ RPCs are invoked from a ChannelClient using its rpcs member. The service and
+ method may be selected as attributes or by indexing rpcs with service and
method name or ID:
- response = client.channel(1).call.FooService.SomeMethod(foo=bar)
+ response = client.channel(1).rpcs.FooService.SomeMethod(foo=bar)
- response = client.channel(1).call[foo_service_id]['SomeMethod'](foo=bar)
+ response = client.channel(1).rpcs[foo_service_id]['SomeMethod'](foo=bar)
The type and semantics of the return value, if there is one, are determined
by the ClientImpl instance used by the Client.
"""
channel: Channel
- call: _ServiceClients
+ rpcs: _ServiceClients
class Client:
@@ -190,6 +190,7 @@
def __init__(self, impl: ClientImpl, channels: Iterable[Channel],
services: Iterable[Service]):
self._impl = impl
+
self.services = descriptors.Services(services)
self._rpcs = PendingRpcs()
diff --git a/pw_rpc/py/pw_rpc/descriptors.py b/pw_rpc/py/pw_rpc/descriptors.py
index 95b2d36..2313738 100644
--- a/pw_rpc/py/pw_rpc/descriptors.py
+++ b/pw_rpc/py/pw_rpc/descriptors.py
@@ -15,11 +15,12 @@
from dataclasses import dataclass
import enum
-from typing import Any, Callable, Collection, Dict, Iterable, Iterator, Tuple
-from typing import TypeVar, Union
+from typing import Any, Callable, Collection, Dict, Generic, Iterable, Iterator
+from typing import Tuple, TypeVar, Union
from google.protobuf import descriptor_pb2
from pw_rpc import ids
+from pw_protobuf_compiler import python_protos
@dataclass(frozen=True)
@@ -36,11 +37,17 @@
"""Describes an RPC service."""
name: str
id: int
+ package: str
methods: 'ServiceAccessor'
+ @property
+ def full_name(self):
+ return f'{self.package}.{self.name}'
+
@classmethod
def from_descriptor(cls, module, descriptor):
- service = cls(descriptor.name, ids.calculate(descriptor.name), None)
+ service = cls(descriptor.name, ids.calculate(descriptor.full_name),
+ descriptor.file.package, None)
object.__setattr__(
service, 'methods',
Methods(
@@ -50,7 +57,10 @@
return service
def __repr__(self) -> str:
- return f'Service({self.name!r})'
+ return f'Service({self.full_name!r})'
+
+ def __str__(self) -> str:
+ return self.full_name
def _streaming_attributes(method) -> Tuple[bool, bool]:
@@ -97,7 +107,7 @@
@property
def full_name(self) -> str:
- return f'{self.service.name}.{self.name}'
+ return f'{self.service.full_name}/{self.name}'
@property
def type(self) -> 'Method.Type':
@@ -143,34 +153,51 @@
def __repr__(self) -> str:
return f'Method({self.name!r})'
+ def __str__(self) -> str:
+ return self.full_name
+
T = TypeVar('T')
+def _name(item: Union[Service, Method]) -> str:
+ return item.full_name if isinstance(item, Service) else item.name
+
+
+class _AccessByName(Generic[T]):
+ """Wrapper for accessing types by name within a proto package structure."""
+ def __init__(self, name: str, item: T):
+ setattr(self, name, item)
+
+
class ServiceAccessor(Collection[T]):
"""Navigates RPC services by name or ID."""
- def __init__(self, members, as_attrs: bool):
- """Creates accessor from a {name: value} dict or [values] iterable."""
+ def __init__(self, members, as_attrs: str = ''):
+ """Creates accessor from an {item: value} dict or [values] iterable."""
if isinstance(members, dict):
- by_name = members
- self._by_id = {
- ids.calculate(name): m
- for name, m in by_name.items()
- }
+ by_name = {_name(k): v for k, v in members.items()}
+ self._by_id = {k.id: v for k, v in members.items()}
else:
- by_name = {m.name: m for m in members}
+ by_name = {_name(m): m for m in members}
self._by_id = {m.id: m for m in by_name.values()}
- if as_attrs:
+ if as_attrs == 'members':
for name, member in by_name.items():
setattr(self, name, member)
+ elif as_attrs == 'packages':
+ for package in python_protos.as_packages(
+ (m.package, _AccessByName(m.name, members[m]))
+ for m in members).packages:
+ setattr(self, str(package), package)
+ elif as_attrs:
+ raise ValueError(f'Unexpected value {as_attrs!r} for as_attrs')
def __getitem__(self, name_or_id: Union[str, int]):
"""Accesses a service/method by the string name or ID."""
try:
return self._by_id[_id(name_or_id)]
except KeyError:
- name = ' (name_or_id)' if isinstance(name_or_id, str) else ''
+ name = f' ("{name_or_id}")' if isinstance(name_or_id, str) else ''
raise KeyError(f'Unknown ID {_id(name_or_id)}{name}')
def __iter__(self) -> Iterator[T]:
@@ -194,10 +221,10 @@
class Services(ServiceAccessor[Service]):
"""A collection of Service descriptors."""
def __init__(self, services: Iterable[Service]):
- super().__init__(services, as_attrs=False)
+ super().__init__(services)
class Methods(ServiceAccessor[Method]):
"""A collection of Method descriptors in a Service."""
def __init__(self, method: Iterable[Method]):
- super().__init__(method, as_attrs=False)
+ super().__init__(method)