blob: 14240e7ddacbe11726e3e04b9fe23c818d4e06aa [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstring>
#include <exception>
#include <functional>
#include <memory>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "flatbuffers/flatbuffers.h"
#include "iree/base/status.h"
#include "iree/rt/debug/debug_server.h"
#include "iree/rt/debug/debug_service.h"
#include "iree/rt/debug/debug_tcp_util.h"
#include "iree/schemas/debug_service_generated.h"
namespace iree {
namespace rt {
namespace debug {
namespace {
// Writes the given typed response message to |fd|.
//
// The message is stored in the `message` union of an rpc::Response. That
// response goes in the `response` field of an rpc::ServicePacket, and the
// packet is finished size-prefixed in |fbb| and written with tcp::WriteBuffer.
//
// |message_offs| must already be finished in |fbb| before |fbb| is moved into
// this call. Function-argument evaluation order is unspecified. Finishing the
// message inline, as in WriteResponse(fd, builder.Finish(), std::move(fbb)),
// may therefore run Finish() on an already moved-from builder. Finish into a
// local first:
//
//   ::flatbuffers::FlatBufferBuilder fbb;
//   rpc::SuspendInvocationResponseBuilder response(fbb);
//   auto response_offs = response.Finish();
//   RETURN_IF_ERROR(WriteResponse(fd_, response_offs, std::move(fbb)));
//
// Returns the status of tcp::WriteBuffer.
template <typename T>
Status WriteResponse(int fd, ::flatbuffers::Offset<T> message_offs,
                     ::flatbuffers::FlatBufferBuilder fbb) {
  // rpc::Response { message_type, message }.
  rpc::ResponseBuilder response_builder(fbb);
  response_builder.add_message_type(rpc::ResponseUnionTraits<T>::enum_value);
  response_builder.add_message(message_offs.Union());
  auto response_offs = response_builder.Finish();
  // rpc::ServicePacket { response } as the size-prefixed root.
  rpc::ServicePacketBuilder packet_builder(fbb);
  packet_builder.add_response(response_offs);
  fbb.FinishSizePrefixed(packet_builder.Finish());
  return tcp::WriteBuffer(fd, fbb.Release());
}
class TcpDebugSession : public DebugSession {
public:
using ClosedCallback =
std::function<void(TcpDebugSession* session, Status status)>;
static StatusOr<std::unique_ptr<TcpDebugSession>> Accept(
DebugService* debug_service, int client_fd,
ClosedCallback closed_callback) {
VLOG(2) << "Client " << client_fd << ": Setting up socket options...";
// Disable Nagel's algorithm to ensure we have low latency.
RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(client_fd, false));
// Enable keepalive assuming the client is local and this high freq is ok.
RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(client_fd, true));
// Linger around for a bit to flush all data.
RETURN_IF_ERROR(tcp::ToggleSocketLinger(client_fd, true));
return absl::make_unique<TcpDebugSession>(debug_service, client_fd,
std::move(closed_callback));
}
TcpDebugSession(DebugService* debug_service, int client_fd,
ClosedCallback closed_callback)
: debug_service_(debug_service),
client_fd_(client_fd),
closed_callback_(std::move(closed_callback)) {
CHECK_OK(debug_service_->RegisterDebugSession(this));
session_thread_ = std::thread([this]() { SessionThread(); });
}
~TcpDebugSession() override {
CHECK_OK(debug_service_->UnregisterDebugSession(this));
VLOG(2) << "Client " << client_fd_ << ": Shutting down session socket...";
::shutdown(client_fd_, SHUT_RD);
if (session_thread_.joinable() &&
session_thread_.get_id() != std::this_thread::get_id()) {
VLOG(2) << "Client " << client_fd_ << ": Joining socket thread...";
session_thread_.join();
VLOG(2) << "Client " << client_fd_ << ": Joined socket thread!";
} else {
VLOG(2) << "Client " << client_fd_ << ": Detaching socket thread...";
session_thread_.detach();
}
VLOG(2) << "Client " << client_fd_ << ": Closing session socket...";
::close(client_fd_);
VLOG(2) << "Client " << client_fd_ << ": Closed session socket!";
client_fd_ = -1;
}
Status OnServiceShutdown() {
VLOG(2) << "Client " << client_fd_ << ": Post OnServiceShutdown()";
::flatbuffers::FlatBufferBuilder fbb;
rpc::ServiceShutdownEventBuilder event(fbb);
return PostEvent(event.Finish(), std::move(fbb));
}
Status OnContextRegistered(Context* context) override {
VLOG(2) << "Client " << client_fd_ << ": Post OnContextRegistered("
<< context->id() << ")";
::flatbuffers::FlatBufferBuilder fbb;
rpc::ContextRegisteredEventBuilder event(fbb);
event.add_context_id(context->id());
return PostEvent(event.Finish(), std::move(fbb));
}
Status OnContextUnregistered(Context* context) override {
VLOG(2) << "Client " << client_fd_ << ": Post OnContextUnregistered("
<< context->id() << ")";
::flatbuffers::FlatBufferBuilder fbb;
rpc::ContextUnregisteredEventBuilder event(fbb);
event.add_context_id(context->id());
return PostEvent(event.Finish(), std::move(fbb));
}
Status OnModuleLoaded(Context* context, Module* module) override {
VLOG(2) << "Client " << client_fd_ << ": Post OnModuleLoaded("
<< context->id() << ", " << module->name() << ")";
::flatbuffers::FlatBufferBuilder fbb;
auto module_name_offs =
fbb.CreateString(module->name().data(), module->name().size());
rpc::ModuleLoadedEventBuilder event(fbb);
event.add_context_id(context->id());
event.add_module_name(module_name_offs);
return PostEvent(event.Finish(), std::move(fbb));
}
Status OnInvocationRegistered(Invocation* invocation) override {
VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationRegistered("
<< invocation->id() << ")";
::flatbuffers::FlatBufferBuilder fbb;
rpc::InvocationRegisteredEventBuilder event(fbb);
event.add_invocation_id(invocation->id());
return PostEvent(event.Finish(), std::move(fbb));
}
Status OnInvocationUnregistered(Invocation* invocation) override {
VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationUnregistered("
<< invocation->id() << ")";
::flatbuffers::FlatBufferBuilder fbb;
rpc::InvocationUnregisteredEventBuilder event(fbb);
event.add_invocation_id(invocation->id());
return PostEvent(event.Finish(), std::move(fbb));
}
Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint,
Context* context) override {
VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointResolved("
<< breakpoint.breakpoint_id << ", " << context->id() << ", "
<< breakpoint.function_ordinal << ")";
rpc::BreakpointResolvedEventT event;
event.breakpoint = absl::make_unique<rpc::BreakpointDefT>();
*event.breakpoint = breakpoint;
event.context_id = context->id();
::flatbuffers::FlatBufferBuilder fbb;
return PostEvent(rpc::BreakpointResolvedEvent::Pack(fbb, &event),
std::move(fbb));
}
Status OnBreakpointHit(int breakpoint_id,
const Invocation& invocation) override {
VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointHit("
<< breakpoint_id << ", " << invocation.id() << ")";
::flatbuffers::FlatBufferBuilder fbb;
ASSIGN_OR_RETURN(auto invocation_offs,
debug_service_->SerializeInvocation(invocation, &fbb));
rpc::BreakpointHitEventBuilder event(fbb);
event.add_breakpoint_id(breakpoint_id);
event.add_invocation(invocation_offs);
return PostEvent(event.Finish(), std::move(fbb));
}
private:
void SessionThread() {
VLOG(2) << "Client " << client_fd_ << ": Thread entry";
Status session_status = OkStatus();
while (session_status.ok()) {
auto buffer_or = tcp::ReadBuffer<rpc::Request>(client_fd_);
if (!buffer_or.ok()) {
if (IsCancelled(buffer_or.status())) {
// Graceful shutdown.
VLOG(2) << "Client " << client_fd_ << ": Graceful shutdown requested";
break;
}
// Error reading.
session_status = std::move(buffer_or).status();
LOG(ERROR) << "Client " << client_fd_
<< ": Error reading request buffer: " << session_status;
break;
}
auto request_buffer = std::move(buffer_or).ValueOrDie();
session_status = DispatchRequest(request_buffer.GetRoot());
if (!session_status.ok()) {
LOG(ERROR) << "Client " << client_fd_
<< ": Error dispatching request: " << session_status;
break;
}
}
VLOG(2) << "Client " << client_fd_ << ": Thread exit";
AbortSession(session_status);
}
void AbortSession(Status status) {
if (status.ok()) {
VLOG(2) << "Debug client disconnected";
} else {
LOG(ERROR) << "Debug session aborted; " << status;
::flatbuffers::FlatBufferBuilder fbb;
auto message_offs =
fbb.CreateString(status.message().data(), status.message().size());
rpc::StatusBuilder status_builder(fbb);
status_builder.add_code(static_cast<int>(status.code()));
status_builder.add_message(message_offs);
auto status_offs = status_builder.Finish();
rpc::ResponseBuilder response(fbb);
response.add_status(status_offs);
fbb.FinishSizePrefixed(response.Finish());
tcp::WriteBuffer(client_fd_, fbb.Release()).IgnoreError();
}
closed_callback_(this, std::move(status));
}
template <typename T>
Status PostEvent(::flatbuffers::Offset<T> event_offs,
::flatbuffers::FlatBufferBuilder fbb) {
rpc::ServicePacketBuilder packet_builder(fbb);
packet_builder.add_event_type(rpc::EventUnionTraits<T>::enum_value);
packet_builder.add_event(event_offs.Union());
fbb.FinishSizePrefixed(packet_builder.Finish());
return tcp::WriteBuffer(client_fd_, fbb.Release());
}
Status DispatchRequest(const rpc::Request& request) {
::flatbuffers::FlatBufferBuilder fbb;
switch (request.message_type()) {
#define DISPATCH_REQUEST(method_name) \
case rpc::RequestUnion::method_name##Request: { \
VLOG(2) << "Client " << client_fd_ \
<< ": DispatchRequest(" #method_name ")..."; \
ASSIGN_OR_RETURN(auto response_offs, \
debug_service_->method_name( \
*request.message_as_##method_name##Request(), &fbb)); \
return WriteResponse(client_fd_, response_offs, std::move(fbb)); \
}
DISPATCH_REQUEST(MakeReady);
DISPATCH_REQUEST(GetStatus);
DISPATCH_REQUEST(ListContexts);
DISPATCH_REQUEST(GetModule);
DISPATCH_REQUEST(GetFunction);
DISPATCH_REQUEST(ListInvocations);
DISPATCH_REQUEST(SuspendInvocations);
DISPATCH_REQUEST(ResumeInvocations);
DISPATCH_REQUEST(StepInvocation);
DISPATCH_REQUEST(GetInvocationLocal);
DISPATCH_REQUEST(SetInvocationLocal);
DISPATCH_REQUEST(ListBreakpoints);
DISPATCH_REQUEST(AddBreakpoint);
DISPATCH_REQUEST(RemoveBreakpoint);
DISPATCH_REQUEST(StartProfiling);
DISPATCH_REQUEST(StopProfiling);
default:
return UnimplementedErrorBuilder(IREE_LOC)
<< "Unimplemented debug service request: "
<< static_cast<int>(request.message_type());
}
}
DebugService* debug_service_;
int client_fd_;
ClosedCallback closed_callback_;
std::thread session_thread_;
};
class TcpDebugServer final : public DebugServer {
public:
static StatusOr<std::unique_ptr<TcpDebugServer>> Listen(int port) {
// We support both IPv4 and IPv6 by using the IN6ADDR_ANY. This requires
// that we setup the socket as INET6 and enable reuse (so the same port can
// be bound for both IPv4 and IPv6).
int listen_fd = ::socket(AF_INET6, SOCK_STREAM, 0);
RETURN_IF_ERROR(tcp::ToggleSocketAddressReuse(listen_fd, true));
struct sockaddr_in6 socket_addr = {0};
socket_addr.sin6_family = AF_INET6;
socket_addr.sin6_port = htons(port);
socket_addr.sin6_addr = in6addr_any;
if (::bind(listen_fd, reinterpret_cast<struct sockaddr*>(&socket_addr),
sizeof(socket_addr)) < 0) {
return AlreadyExistsErrorBuilder(IREE_LOC)
<< "Unable to bind socket to port " << port << ": (" << errno
<< ") " << ::strerror(errno);
}
if (::listen(listen_fd, 1)) {
::close(listen_fd);
return AlreadyExistsErrorBuilder(IREE_LOC)
<< "Unable to listen on port " << port << ": (" << errno << ") "
<< ::strerror(errno);
}
return absl::make_unique<TcpDebugServer>(listen_fd);
}
TcpDebugServer(int listen_fd) : listen_fd_(listen_fd) {
server_thread_ = std::thread([this]() { ListenThread(); });
}
~TcpDebugServer() ABSL_LOCKS_EXCLUDED(mutex_) override {
absl::ReleasableMutexLock lock(&mutex_);
LOG(INFO) << "Shutting down debug server...";
// Notify all sessions.
for (auto& session : sessions_) {
session->OnServiceShutdown().IgnoreError();
}
// Shut down listen socket first so that we can't accept new connections.
VLOG(2) << "Shutting down listen socket...";
::shutdown(listen_fd_, SHUT_RDWR);
if (server_thread_.joinable()) {
VLOG(2) << "Joining listen thread...";
server_thread_.join();
VLOG(2) << "Joined listen thread!";
}
VLOG(2) << "Closing listen socket...";
::close(listen_fd_);
listen_fd_ = -1;
VLOG(2) << "Closed listen socket!";
// Kill all active sessions. Note that we must do this outside of our lock.
std::vector<std::unique_ptr<TcpDebugSession>> sessions =
std::move(sessions_);
std::vector<std::function<void()>> at_exit_callbacks =
std::move(at_exit_callbacks_);
lock.Release();
VLOG(2) << "Clearing live sessions...";
sessions.clear();
VLOG(2) << "Calling AtExit callbacks...";
for (auto& callback : at_exit_callbacks) {
callback();
}
LOG(INFO) << "Debug server shutdown!";
}
DebugService* debug_service() { return &debug_service_; }
Status AcceptNewSession(int client_fd) {
LOG(INFO) << "Accepting new client session as " << client_fd;
ASSIGN_OR_RETURN(auto session,
TcpDebugSession::Accept(
&debug_service_, client_fd,
[this](TcpDebugSession* session, Status status) {
absl::MutexLock lock(&mutex_);
for (auto it = sessions_.begin();
it != sessions_.end(); ++it) {
if (it->get() == session) {
sessions_.erase(it);
break;
}
}
return OkStatus();
}));
absl::MutexLock lock(&mutex_);
sessions_.push_back(std::move(session));
return OkStatus();
}
void AtExit(std::function<void()> callback) override {
absl::MutexLock lock(&mutex_);
at_exit_callbacks_.push_back(std::move(callback));
}
Status WaitUntilSessionReady() override {
return debug_service_.WaitUntilAllSessionsReady();
}
protected:
Status RegisterContext(Context* context) override {
return debug_service_.RegisterContext(context);
}
Status UnregisterContext(Context* context) override {
return debug_service_.UnregisterContext(context);
}
Status RegisterContextModule(Context* context, Module* module) override {
return debug_service_.RegisterContextModule(context, module);
}
Status RegisterInvocation(Invocation* invocation) override {
return debug_service_.RegisterInvocation(invocation);
}
Status UnregisterInvocation(Invocation* invocation) override {
return debug_service_.UnregisterInvocation(invocation);
}
private:
void ListenThread() {
VLOG(2) << "Listen thread entry";
while (true) {
struct sockaddr_in accept_socket_addr;
socklen_t accept_socket_addr_length = sizeof(accept_socket_addr);
int accepted_fd = ::accept(
listen_fd_, reinterpret_cast<struct sockaddr*>(&accept_socket_addr),
&accept_socket_addr_length);
if (accepted_fd < 0) {
if (errno == EINVAL) {
// Shutting down gracefully.
break;
}
// We may be able to recover from some of these cases, but... shrug.
LOG(FATAL) << "Failed to accept client socket: (" << errno << ") "
<< ::strerror(errno);
break;
}
auto accept_status = AcceptNewSession(accepted_fd);
if (!accept_status.ok()) {
LOG(ERROR) << "Failed to accept incoming debug client: "
<< accept_status;
}
}
VLOG(2) << "Listen thread exit";
}
int listen_fd_;
std::thread server_thread_;
absl::Mutex mutex_;
std::vector<std::unique_ptr<TcpDebugSession>> sessions_
ABSL_GUARDED_BY(mutex_);
std::vector<std::function<void()>> at_exit_callbacks_ ABSL_GUARDED_BY(mutex_);
DebugService debug_service_;
};
} // namespace
// static
// Creates the debug server for this build, which serves debugger clients over
// TCP on |listen_port|. Any error from TcpDebugServer::Listen (socket creation,
// bind or listen failure) is propagated unchanged.
//
// NOTE(review): TcpDebugServer::Listen binds in6addr_any, i.e. every interface,
// even though the log line below says "localhost".
StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) {
  ASSIGN_OR_RETURN(std::unique_ptr<TcpDebugServer> tcp_server,
                   TcpDebugServer::Listen(listen_port));
  LOG(INFO) << "Debug server listening on localhost:" << listen_port;
  return std::unique_ptr<DebugServer>(std::move(tcp_server));
}
} // namespace debug
} // namespace rt
} // namespace iree