blob: d69a57fc39ef649a3c987ee0f69186c59b17ecac [file] [log] [blame]
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Creates an RPC client."""
import abc
from collections import defaultdict
from dataclasses import dataclass
import logging
from queue import SimpleQueue
from typing import Any, Collection, Dict, Iterable
from pw_rpc import descriptors, packets
from pw_rpc.descriptors import Channel, Service, Method, PendingRpc
from pw_status import Status
# Module-level logger. Reports discarded responses, undecodable packets and
# payloads, and packets that do not match a known channel/service/method.
_LOG = logging.getLogger(__name__)
class ClientImpl(abc.ABC):
    """The internal interface of the RPC client.

    This interface defines the semantics for invoking an RPC on a particular
    client. The return values can be objects that provide for synchronous or
    asynchronous behavior.

    Every method below is abstract; a subclass must implement all of them
    before it can be instantiated.
    """
    @abc.abstractmethod
    def invoke_unary(self, rpc: PendingRpc, request) -> Any:
        """Invokes a unary RPC.

        Args:
          rpc: identifies the channel, service, and method being called
          request: the request message to send

        Returns:
          An implementation-defined value.
        """

    @abc.abstractmethod
    def invoke_server_streaming(self, rpc: PendingRpc, request) -> Any:
        """Invokes a server streaming RPC."""

    @abc.abstractmethod
    def invoke_client_streaming(self, rpc: PendingRpc) -> Any:
        """Invokes a client streaming RPC."""

    @abc.abstractmethod
    def invoke_bidirectional_streaming(self, rpc: PendingRpc) -> Any:
        """Invokes a bidirectional streaming RPC."""

    @abc.abstractmethod
    def process_response(self, rpc: PendingRpc, payload,
                         status: Status) -> None:
        """Processes a response from the RPC server.

        Args:
          rpc: the RPC the response belongs to
          payload: the decoded response message; Client.process_packet passes
              None when the payload could not be decoded
          status: the status reported in the response packet
        """
class SimpleSynchronousClient(ClientImpl):
    """A client that blocks until a response is received for unary RPCs.

    Streaming RPC types are not supported and raise NotImplementedError.
    """
    def __init__(self):
        # One response queue per RPC: process_response() puts into it and
        # invoke_unary() blocks reading from it.
        self._responses: Dict[PendingRpc,
                              SimpleQueue] = defaultdict(SimpleQueue)
        # Whether an invoke_unary() call is currently waiting on each RPC.
        self._pending: Dict[PendingRpc, bool] = defaultdict(bool)

    def invoke_unary(self, rpc: PendingRpc, request: packets.Message):
        """Sends a unary request and blocks until its response arrives.

        Args:
          rpc: the RPC to invoke; the request is written to rpc.channel
          request: the request message to encode and send

        Returns:
          The (status, payload) tuple handed to process_response().

        Raises:
          AssertionError: if a call for this RPC is already pending.
        """
        queue = self._responses[rpc]

        # Explicit check rather than `assert` so the guard still runs under
        # `python -O`; AssertionError is kept for existing callers.
        if self._pending[rpc]:
            raise AssertionError(f'{rpc} is already pending!')

        self._pending[rpc] = True
        try:
            rpc.channel.output(packets.encode(rpc, request))
            result = queue.get()
        finally:
            # Clear the flag even if sending fails, so later calls can proceed.
            self._pending[rpc] = False
        return result

    def invoke_server_streaming(self, rpc: PendingRpc, request):
        raise NotImplementedError

    def invoke_client_streaming(self, rpc: PendingRpc):
        raise NotImplementedError

    def invoke_bidirectional_streaming(self, rpc: PendingRpc):
        raise NotImplementedError

    def process_response(self, rpc: PendingRpc, payload,
                         status: Status) -> None:
        """Queues a response for the invoke_unary() call waiting on it.

        Responses for RPCs with no pending call are logged and dropped.
        """
        # Use .get() so that looking up an RPC with no pending call does not
        # insert a new defaultdict entry as a side effect.
        if not self._pending.get(rpc, False):
            _LOG.warning('Discarding packet for %s', rpc)
            return
        self._responses[rpc].put((status, payload))
class _MethodClient:
    """A method that can be invoked for a particular channel."""
    @classmethod
    def create(cls, client_impl: ClientImpl, channel: Channel, method: Method):
        """Instantiates a _MethodClient according to the RPC type.

        Raises:
          NotImplementedError: if the method is not a unary RPC.
        """
        if method.type is Method.Type.UNARY:
            return UnaryMethodClient(client_impl, channel, method)

        raise NotImplementedError('Streaming methods are not yet supported')

    def __init__(self, client_impl: ClientImpl, channel: Channel,
                 method: Method):
        self._client_impl = client_impl
        self.channel = channel
        self.method = method

    def _get_request(self, proto: packets.Message,
                     kwargs: dict) -> packets.Message:
        """Returns the request message to send for this method.

        Args:
          proto: a request message object, or None to build one from kwargs
          kwargs: field values used to construct the request when proto is
              None

        Raises:
          TypeError: if both proto and kwargs are provided, or if proto is not
              an instance of the method's request type
        """
        # Compare against None instead of relying on truthiness: a message
        # object that evaluates as falsy must still be rejected when kwargs
        # are also given, rather than silently discarding the kwargs.
        if proto is not None and kwargs:
            raise TypeError(
                'Requests must be provided either as a message object or a '
                'series of keyword args, but both were provided')

        if proto is None:
            return self.method.request_type(**kwargs)

        if not isinstance(proto, self.method.request_type):
            # Prefer the protobuf full name for messages; fall back to the
            # Python type name for arbitrary objects.
            try:
                bad_type = proto.DESCRIPTOR.full_name
            except AttributeError:
                bad_type = type(proto).__name__

            raise TypeError(
                f'Expected a message of type '
                f'{self.method.request_type.DESCRIPTOR.full_name}, '
                f'got {bad_type}')

        return proto
class UnaryMethodClient(_MethodClient):
    """Calls a unary RPC method on the channel it was created for."""

    # TODO(hepler): This function should make _request a positional-only
    # argument, to avoid confusion with keyword-specified protobuf fields.
    # However, yapf does not yet support Python 3.8's grammar, and
    # positional-only arguments crash it.
    def __call__(self, _request=None, **request_fields):
        """Calls this unary method over its associated channel.

        Supply the request either as a single message object or as keyword
        arguments naming the message's fields; supplying both is an error.
        The result is whatever the ClientImpl's invoke_unary() returns.
        """
        pending = PendingRpc(self.channel, self.method.service, self.method)
        message = self._get_request(_request, request_fields)
        return self._client_impl.invoke_unary(pending, message)
class _MethodClients(descriptors.ServiceAccessor[_MethodClient]):
    """Exposes the methods of one service, bound to one ClientImpl and channel.

    Method clients are registered by method name and are also reachable as
    attributes (as_attrs=True).
    """
    def __init__(self, client_impl: ClientImpl, channel: Channel,
                 methods: Collection[Method]):
        by_name = {}
        for method in methods:
            by_name[method.name] = _MethodClient.create(
                client_impl, channel, method)

        super().__init__(by_name, as_attrs=True)
class _ServiceClients(descriptors.ServiceAccessor[_MethodClients]):
    """Navigates the services provided by a ChannelClient.

    Each service name maps to a _MethodClients bound to the same ClientImpl
    and channel; services are also reachable as attributes (as_attrs=True).
    """
    # client_impl is annotated to match the parallel _MethodClients signature.
    def __init__(self, client_impl: ClientImpl, channel: Channel,
                 services: Collection[Service]):
        super().__init__(
            {
                service.name: _MethodClients(client_impl, channel,
                                             service.methods)
                for service in services
            },
            as_attrs=True)
@dataclass(frozen=True, eq=False)
class ChannelClient:
    """RPC services and methods bound to a particular channel.

    RPCs are invoked from a ChannelClient using its call member. The service
    and method may be selected as attributes or by indexing call with the
    service and method name or ID:

      response = client.channel(1).call.FooService.SomeMethod(foo=bar)
      response = client.channel(1).call[foo_service_id]['SomeMethod'](foo=bar)

    The type and semantics of the return value, if there is one, are
    determined by the ClientImpl instance used by the Client.
    """
    # frozen=True: fields cannot be reassigned after construction.
    # eq=False: no __eq__ is generated, so instances compare by identity.

    # The channel that every RPC invoked through `call` is sent on.
    channel: Channel
    # Service and method accessors bound to `channel`.
    call: _ServiceClients
class Client:
    """Sends requests and handles responses for a set of channels.

    RPC invocations occur through a ChannelClient.
    """
    @classmethod
    def from_modules(cls, impl: ClientImpl, channels: Iterable[Channel],
                     modules: Iterable):
        """Creates a Client for every service declared in the given modules.

        Args:
          impl: the ClientImpl that performs invocations and handles responses
          channels: the channels this client communicates over
          modules: objects with a DESCRIPTOR.services_by_name mapping; every
              service listed there is added to the client
        """
        return cls(
            impl, channels,
            (Service.from_descriptor(module, service) for module in modules
             for service in module.DESCRIPTOR.services_by_name.values()))

    def __init__(self, impl: ClientImpl, channels: Iterable[Channel],
                 services: Iterable[Service]):
        self.services = descriptors.Services(services)
        self._impl = impl
        # One ChannelClient per channel; all share this client's impl and
        # service descriptors.
        self._channels_by_id = {
            channel.id:
            ChannelClient(channel,
                          _ServiceClients(self._impl, channel, self.services))
            for channel in channels
        }

    def channel(self, channel_id: int) -> ChannelClient:
        """Returns a ChannelClient, which is used to call RPCs on a channel.

        Raises:
          KeyError: if no channel with this ID was given to the Client
        """
        return self._channels_by_id[channel_id]

    def process_packet(self, data: bytes) -> bool:
        """Processes an incoming packet.

        Args:
          data: raw binary data for exactly one RPC packet

        Returns:
          True if the packet was decoded and handled by this client; False if
          it could not be decoded, did not match a known channel, service, and
          method, or carried an unrecognized status code
        """
        try:
            packet = packets.decode(data)
        except packets.DecodeError as err:
            _LOG.warning('Failed to decode packet: %s', err)
            _LOG.debug('Raw packet: %r', data)
            return False

        try:
            rpc = self._lookup_packet(packet)
        except ValueError as err:
            _LOG.warning('Unable to process packet: %s', err)
            return False

        # The status comes off the wire, so an unknown code must not escape as
        # an exception. NOTE(review): assumes Status raises ValueError for
        # unknown values, as enum.Enum does -- confirm against pw_status.
        try:
            status = Status(packet.status)
        except ValueError:
            _LOG.warning('Unrecognized status %s in response for %s',
                         packet.status, rpc.method.full_name)
            return False

        try:
            response = packets.decode_payload(packet, rpc.method.response_type)
        except packets.DecodeError as err:
            # Still deliver the status; the payload is reported as None.
            response = None
            _LOG.warning('Failed to decode %s response for %s: %s',
                         rpc.method.response_type.DESCRIPTOR.full_name,
                         rpc.method.full_name, err)

        self._impl.process_response(rpc, response, status)
        return True

    def _lookup_packet(self, packet: packets.RpcPacket) -> PendingRpc:
        """Finds the channel, service, and method a packet refers to.

        Raises:
          ValueError: if the channel, service, or method ID is not recognized
        """
        # `from None`: the ValueError message already names the missing ID,
        # so the chained KeyError adds only noise.
        try:
            channel_client = self._channels_by_id[packet.channel_id]
        except KeyError:
            raise ValueError(
                f'Unrecognized channel ID {packet.channel_id}') from None

        try:
            service = self.services[packet.service_id]
        except KeyError:
            raise ValueError(
                f'Unrecognized service ID {packet.service_id}') from None

        try:
            method = service.methods[packet.method_id]
        except KeyError:
            raise ValueError(
                f'No method ID {packet.method_id} in service {service.name}'
            ) from None

        return PendingRpc(channel_client.channel, service, method)