blob: 851a968379898b414a3c32bed85f6bb5350a2c1b [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include <cstring>
#include <queue>
#include "third_party/absl/container/flat_hash_map.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/ascii.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_split.h"
#include "third_party/absl/types/span.h"
#include "third_party/flatbuffers/include/flatbuffers/base.h"
#include "third_party/flatbuffers/include/flatbuffers/flatbuffers.h"
#include "third_party/mlir_edge/iree/base/flatbuffer_util.h"
#include "third_party/mlir_edge/iree/base/status.h"
#include "third_party/mlir_edge/iree/schemas/debug_service_generated.h"
#include "third_party/mlir_edge/iree/schemas/module_def_generated.h"
#include "third_party/mlir_edge/iree/vm/debug/debug_client.h"
#include "third_party/mlir_edge/iree/vm/debug/debug_tcp_util.h"
#include "third_party/mlir_edge/iree/vm/module.h"
namespace iree {
namespace vm {
namespace debug {
namespace {
using ::flatbuffers::FlatBufferBuilder;
using ::iree::vm::ModuleFile;
// Parses a host:port address, with support for the RFC 3986 IPv6 [host]:port
// format. Returns a pair of (hostname, port), with port being 0 if none was
// specified.
//
// Parses:
//   foo (port 0)     / foo:123
//   1.2.3.4 (port 0) / 1.2.3.4:123
//   [foo] (port 0)   / [foo]:123
//   [::1] (port 0)   / [::1]:123
//
// Surrounding ASCII whitespace is ignored. Returns InvalidArgument when:
//   - brackets are mismatched (e.g. "foo]" or "[foo"),
//   - anything other than ":<port>" follows the closing bracket ("[foo]123"),
//   - the port is not an integer in [0, 65535].
//
// Unbracketed IPv6 literals such as "::1" are split at the last ':' and so do
// not parse as a bare host; use the bracketed form for them.
StatusOr<std::pair<std::string, int>> ParseAddress(absl::string_view address) {
  constexpr int kMaxPort = 65535;
  address = absl::StripAsciiWhitespace(address);
  absl::string_view hostname;
  absl::string_view port_str;
  size_t bracket_loc = address.find_last_of(']');
  if (bracket_loc != absl::string_view::npos) {
    // Bracketed form: the address must open with '['.
    if (address.find('[') != 0) {
      return InvalidArgumentErrorBuilder(ABSL_LOC)
             << "Mismatched brackets in address: " << address;
    }
    hostname = address.substr(1, bracket_loc - 1);
    port_str = address.substr(bracket_loc + 1);
    if (!port_str.empty()) {
      // Only ":<port>" may follow the closing bracket; without this check
      // "[foo]123" would be accepted as port 123.
      if (port_str.front() != ':') {
        return InvalidArgumentErrorBuilder(ABSL_LOC)
               << "Expected ':' after ']' in address: " << address;
      }
      port_str.remove_prefix(1);
    }
  } else {
    // No closing bracket, so an opening one is unmatched.
    if (address.find('[') != absl::string_view::npos) {
      return InvalidArgumentErrorBuilder(ABSL_LOC)
             << "Mismatched brackets in address: " << address;
    }
    size_t colon_loc = address.find_last_of(':');
    if (colon_loc != absl::string_view::npos) {
      hostname = address.substr(0, colon_loc);
      port_str = address.substr(colon_loc + 1);
    } else {
      hostname = address;
      port_str = "";
    }
  }
  int port = 0;
  if (!port_str.empty()) {
    if (!absl::SimpleAtoi(port_str, &port)) {
      return InvalidArgumentErrorBuilder(ABSL_LOC)
             << "Unable to parse port '" << port_str << "' from " << address;
    }
    if (port < 0 || port > kMaxPort) {
      return InvalidArgumentErrorBuilder(ABSL_LOC)
             << "Port " << port << " out of range [0, " << kMaxPort
             << "] in " << address;
    }
  }
  return std::make_pair(std::string(hostname), port);
}
class TcpDebugClient final : public DebugClient {
public:
class TcpRemoteBreakpoint : public RemoteBreakpoint {
 public:
  // `client` is accepted for parity with the other remote proxies but is not
  // stored by this class.
  TcpRemoteBreakpoint(int id, Type type, TcpDebugClient* client)
      : RemoteBreakpoint(id, type) {}

  // Accessors over the cached (object-API) breakpoint definition.
  const std::string& module_name() const override {
    return def_.module_name;
  }
  const std::string& function_name() const override {
    return def_.function_name;
  }
  int function_ordinal() const override { return def_.function_ordinal; }
  int bytecode_offset() const override { return def_.bytecode_offset; }

  // Unpacks `incoming` into the cached definition. Never fails.
  Status MergeFrom(const rpc::BreakpointDef& incoming) {
    incoming.UnPackTo(&def_);
    return OkStatus();
  }

 private:
  // Object-API copy of the most recently merged breakpoint definition.
  rpc::BreakpointDefT def_;
};
class TcpRemoteFunction final : public RemoteFunction {
public:
TcpRemoteFunction(RemoteModule* module, int function_ordinal,
const FunctionDef* function_def, TcpDebugClient* client)
: RemoteFunction(module, function_ordinal),
def_(function_def),
client_(client) {
name_ = def_->name() ? std::string(WrapString(def_->name())) : "";
}
const std::string& name() const override { return name_; }
const FunctionDef& def() override { return *def_; }
bool is_loaded() const override {
return contents_.flatbuffers_buffer.size() > 0;
}
bool CheckLoadedOrRequest() override {
if (!is_loaded()) {
DemandContents();
}
return is_loaded();
}
void WhenLoaded(LoadCallback callback) override {
if (is_loaded()) {
callback(this);
return;
}
load_callbacks_.push_back(std::move(callback));
}
const BytecodeDef* bytecode() override {
CHECK(is_loaded());
return contents_.bytecode_def;
}
private:
void DemandContents() {
if (!has_requested_contents_) {
VLOG(2) << "Client " << client_->fd() << ": GetFunction("
<< module()->context_id() << ", " << module()->name() << ", "
<< ordinal() << ")";
FlatBufferBuilder fbb;
rpc::GetFunctionRequestT request;
request.session_id = client_->session_id();
request.context_id = module()->context_id();
request.module_name = module()->name();
request.function_ordinal = ordinal();
auto status =
client_->IssueRequest<rpc::GetFunctionRequest,
rpc::ResponseUnion::GetFunctionResponse>(
rpc::GetFunctionRequest::Pack(fbb, &request), std::move(fbb),
[this](Status status,
const rpc::Response& response_union) -> Status {
if (!status.ok()) return status;
const auto& response =
*response_union.message_as_GetFunctionResponse();
VLOG(2) << "Client " << client_->fd() << ": GetFunction("
<< module()->context_id() << ", " << module()->name()
<< ", " << ordinal() << ") = ...";
RETURN_IF_ERROR(MergeFrom(response));
for (auto& callback : load_callbacks_) {
callback(this);
}
load_callbacks_.clear();
return OkStatus();
});
if (!status.ok()) {
LOG(ERROR) << "Failed to request module: " << status;
return;
}
has_requested_contents_ = true;
}
}
Status MergeFrom(const rpc::GetFunctionResponse& response) {
// Clone and retain the contents.
// TODO(benvanik): find a way to steal to avoid the reserialization.
BytecodeDefT bytecode_def_storage;
response.bytecode()->UnPackTo(&bytecode_def_storage);
::flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(response.bytecode()->Pack(fbb, &bytecode_def_storage));
contents_.flatbuffers_buffer = fbb.Release();
contents_.bytecode_def = ::flatbuffers::GetRoot<BytecodeDef>(
contents_.flatbuffers_buffer.data());
return OkStatus();
}
const FunctionDef* def_;
TcpDebugClient* client_;
std::string name_;
bool has_requested_contents_ = false;
std::vector<LoadCallback> load_callbacks_;
struct {
::flatbuffers::DetachedBuffer flatbuffers_buffer;
const BytecodeDef* bytecode_def = nullptr;
} contents_;
};
// Client-side proxy for a module loaded into a remote context. The module
// definition is fetched from the debug service on demand and cloned into a
// local ModuleFile. One TcpRemoteFunction proxy is created per function.
class TcpRemoteModule final : public RemoteModule {
 public:
  TcpRemoteModule(int context_id, std::string module_name,
                  TcpDebugClient* client)
      : RemoteModule(context_id, std::move(module_name)), client_(client) {}

  // Requires is_loaded().
  const ModuleDef& def() override {
    CHECK(is_loaded());
    return *module_file_->root();
  }

  bool is_loaded() const override { return module_file_ != nullptr; }

  // Issues a fetch if the definition is not yet present and returns whether
  // it is loaded right now (the fetch completes asynchronously).
  bool CheckLoadedOrRequest() override {
    if (!is_loaded()) {
      DemandModuleDef();
    }
    return is_loaded();
  }

  // Invokes `callback` immediately if loaded; otherwise queues it until the
  // GetModule response has been merged.
  void WhenLoaded(LoadCallback callback) override {
    if (is_loaded()) {
      callback(this);
      return;
    }
    load_callbacks_.push_back(std::move(callback));
  }

  // Returns the function proxies. If the definition has not arrived yet,
  // issues a fetch and returns an empty span.
  absl::Span<RemoteFunction*> functions() override {
    auto* module_def = DemandModuleDef();
    if (!module_def) return {};
    // A genuine array of raw pointers. Reinterpreting the unique_ptr storage
    // of functions_ as RemoteFunction** would be undefined behavior.
    return absl::MakeSpan(function_ptrs_);
  }

 private:
  // Returns the module definition if loaded. Otherwise sends a GetModule
  // request (only once after a successful send) and returns nullptr.
  const ModuleDef* DemandModuleDef() {
    if (module_file_) {
      return module_file_->root();
    }
    if (!has_requested_module_def_) {
      VLOG(2) << "Client " << client_->fd() << ": GetModule(" << context_id()
              << ", " << name() << ")";
      FlatBufferBuilder fbb;
      rpc::GetModuleRequestT request;
      request.session_id = client_->session_id();
      request.context_id = context_id();
      request.module_name = name();
      // Pack before `fbb` is moved into IssueRequest: argument evaluation
      // order is unspecified, so packing inside the same call could write
      // into an already moved-from builder.
      auto request_offs = rpc::GetModuleRequest::Pack(fbb, &request);
      auto status =
          client_->IssueRequest<rpc::GetModuleRequest,
                                rpc::ResponseUnion::GetModuleResponse>(
              request_offs, std::move(fbb),
              [this](Status status,
                     const rpc::Response& response_union) -> Status {
                if (!status.ok()) return status;
                const auto& response =
                    *response_union.message_as_GetModuleResponse();
                VLOG(2) << "Client " << client_->fd() << ": GetModule("
                        << context_id() << ", " << name() << ") = ...";
                RETURN_IF_ERROR(MergeFrom(response));
                for (auto& callback : load_callbacks_) {
                  callback(this);
                }
                load_callbacks_.clear();
                return OkStatus();
              });
      if (!status.ok()) {
        LOG(ERROR) << "Failed to request module: " << status;
        return nullptr;
      }
      has_requested_module_def_ = true;
    }
    return nullptr;
  }

  // Clones the module from `response` into an owned ModuleFile and builds one
  // function proxy per function-table entry. A missing function table (or
  // functions vector) is treated as empty rather than dereferenced.
  Status MergeFrom(const rpc::GetModuleResponse& response) {
    const auto* remote_module = response.module_();
    if (!remote_module) {
      return InvalidArgumentErrorBuilder(ABSL_LOC)
             << "GetModule response contains no module";
    }
    // Clone and retain the module.
    // TODO(benvanik): find a way to steal to avoid the reserialization.
    ModuleDefT module_def_storage;
    remote_module->UnPackTo(&module_def_storage);
    FlatBufferBuilder fbb;
    auto module_offs = remote_module->Pack(fbb, &module_def_storage);
    FinishModuleDefBuffer(fbb, module_offs);
    ASSIGN_OR_RETURN(auto module_file,
                     ModuleFile::CreateWithBackingBuffer(fbb.Release()));
    const auto& module_def = module_file->root();
    const auto* function_table = module_def->function_table();
    const auto* function_defs =
        function_table ? function_table->functions() : nullptr;
    std::vector<std::unique_ptr<RemoteFunction>> functions;
    std::vector<RemoteFunction*> function_ptrs;
    if (function_defs) {
      int function_count = static_cast<int>(function_defs->size());
      functions.reserve(function_count);
      function_ptrs.reserve(function_count);
      for (int i = 0; i < function_count; ++i) {
        auto function = absl::make_unique<TcpRemoteFunction>(
            this, i, function_defs->Get(i), client_);
        function_ptrs.push_back(function.get());
        functions.push_back(std::move(function));
      }
    }
    functions_ = std::move(functions);
    function_ptrs_ = std::move(function_ptrs);
    module_file_ = std::move(module_file);
    return OkStatus();
  }

  TcpDebugClient* client_;
  bool has_requested_module_def_ = false;
  std::vector<LoadCallback> load_callbacks_;
  std::unique_ptr<ModuleFile> module_file_;
  // Owns the function proxies (declared after module_file_ so they are
  // destroyed first; they borrow FunctionDefs from its buffer).
  std::vector<std::unique_ptr<RemoteFunction>> functions_;
  // Non-owning mirror of functions_, exposed by functions().
  std::vector<RemoteFunction*> function_ptrs_;
};
// Client-side view of a remote context. Owns the module proxies added to it,
// keyed by module name, and exposes them in insertion order.
class TcpRemoteContext final : public RemoteContext {
 public:
  TcpRemoteContext(int context_id, TcpDebugClient* client)
      : RemoteContext(context_id), client_(client) {}

  absl::Span<RemoteModule* const> modules() const override {
    return absl::MakeConstSpan(modules_);
  }

  // Takes ownership of `module`. Fails with FailedPrecondition if a module
  // with the same name is already present. In that case `module` is destroyed
  // and the context is unchanged. Previously the failed map insert freed the
  // module while modules_ still held a pointer to it.
  Status AddModule(std::unique_ptr<TcpRemoteModule> module) {
    std::string module_name = module->name();
    TcpRemoteModule* module_ptr = module.get();
    // try_emplace leaves `module` untouched if the key already exists.
    if (!module_map_.try_emplace(module_name, std::move(module)).second) {
      return FailedPreconditionErrorBuilder(ABSL_LOC)
             << "Module '" << module_name
             << "' already registered in context";
    }
    modules_.push_back(module_ptr);
    return OkStatus();
  }

  // Currently a no-op; context definitions carry nothing merged here yet.
  Status MergeFrom(const rpc::ContextDef& context_def) { return OkStatus(); }

 private:
  TcpDebugClient* client_;
  std::vector<RemoteModule*> modules_;
  absl::flat_hash_map<std::string, std::unique_ptr<TcpRemoteModule>>
      module_map_;
};
// Cached client-side copy of one remote fiber's state, updated through
// MergeFrom whenever the service reports new state.
class TcpRemoteFiberState final : public RemoteFiberState {
 public:
  TcpRemoteFiberState(int fiber_id, TcpDebugClient* client)
      : RemoteFiberState(fiber_id), client_(client) {}

  const rpc::FiberStateDefT& def() const override { return def_; }

  // Unpacks `incoming` into the cached object-API state. Never fails.
  Status MergeFrom(const rpc::FiberStateDef& incoming) {
    incoming.UnPackTo(&def_);
    return OkStatus();
  }

 private:
  TcpDebugClient* client_;   // Stored but not dereferenced in this class.
  rpc::FiberStateDefT def_;  // Latest merged fiber state.
};
// Configures `fd` as a low-latency, non-blocking session socket, wraps it in a
// client, and issues the initial state refresh.
//
// Once the client is constructed it owns `fd` (the destructor closes it). If
// a socket option fails first, `fd` is not closed here.
static StatusOr<std::unique_ptr<TcpDebugClient>> Create(int fd,
                                                        Listener* listener) {
  VLOG(2) << "Client " << fd << ": Setting up socket options...";
  // Latency over throughput: turn off Nagle's algorithm.
  RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(fd, false));
  // Frequent keepalives; acceptable because the peer is assumed local.
  RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(fd, true));
  // Give pending writes a chance to flush on close.
  RETURN_IF_ERROR(tcp::ToggleSocketLinger(fd, true));
  // All reads are poll-driven, so the socket must never block.
  RETURN_IF_ERROR(tcp::ToggleSocketBlocking(fd, false));

  auto client = absl::make_unique<TcpDebugClient>(fd, listener);
  RETURN_IF_ERROR(client->Refresh());
  return client;
}
// Takes ownership of the already-configured socket `fd`; see Create().
TcpDebugClient(int fd, Listener* listener) : fd_(fd), listener_(listener) {}

// Shuts down the write side of the session socket, then closes it.
~TcpDebugClient() override {
  const int socket_fd = fd_;
  VLOG(2) << "Client " << socket_fd << ": Shutting down session socket...";
  ::shutdown(socket_fd, SHUT_WR);
  VLOG(2) << "Client " << socket_fd << ": Closing session socket...";
  ::close(socket_fd);
  VLOG(2) << "Client " << socket_fd << ": Closed session socket!";
  fd_ = -1;
}
// Session socket descriptor (also used as the log prefix).
int fd() const { return fd_; }
// Session id placed in outgoing request messages.
int session_id() const { return session_id_; }

// Non-owning views over the client's registries. The pointees are owned by
// the corresponding id-keyed maps held by this client.
absl::Span<RemoteContext* const> contexts() const override {
  return absl::MakeConstSpan(contexts_.data(), contexts_.size());
}
absl::Span<RemoteFiberState* const> fiber_states() const override {
  return absl::MakeConstSpan(fiber_states_.data(), fiber_states_.size());
}
absl::Span<RemoteBreakpoint* const> breakpoints() const override {
  return absl::MakeConstSpan(breakpoints_.data(), breakpoints_.size());
}
// Wraps the typed request `request_offs` (already built in `fbb`) in a
// size-prefixed rpc::Request union and writes it to `fd`. While a service
// shutdown is pending, an Unavailable write error is reported as OK.
//
// Example:
//   FlatBufferBuilder fbb;
//   rpc::SuspendFibersRequestBuilder request(fbb);
//   auto request_offs = request.Finish();  // finish before moving fbb
//   RETURN_IF_ERROR(WriteRequest(fd_, request_offs, std::move(fbb)));
template <typename T>
Status WriteRequest(int fd, ::flatbuffers::Offset<T> request_offs,
                    FlatBufferBuilder fbb) {
  rpc::RequestBuilder envelope(fbb);
  envelope.add_message_type(rpc::RequestUnionTraits<T>::enum_value);
  envelope.add_message(request_offs.Union());
  fbb.FinishSizePrefixed(envelope.Finish());
  auto write_status = tcp::WriteBuffer(fd, fbb.Release());
  if (!shutdown_pending_ || !IsUnavailable(write_status)) {
    return write_status;
  }
  return OkStatus();
}
// Asks the service to resolve `function_name` in `module_name` to a function
// ordinal. `callback` receives the ordinal, or the error status when the
// server reports a failure. The response handler itself always returns OK.
Status ResolveFunction(
    std::string module_name, std::string function_name,
    std::function<void(StatusOr<int> function_ordinal)> callback) override {
  VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name << ", "
          << function_name << ")";
  FlatBufferBuilder fbb;
  rpc::ResolveFunctionRequestT request;
  request.session_id = session_id_;
  request.module_name = module_name;
  request.function_name = function_name;
  // Pack before `fbb` is moved into IssueRequest: argument evaluation order
  // is unspecified, so packing inside the same call could write into an
  // already moved-from builder.
  auto request_offs = rpc::ResolveFunctionRequest::Pack(fbb, &request);
  return IssueRequest<rpc::ResolveFunctionRequest,
                      rpc::ResponseUnion::ResolveFunctionResponse>(
      request_offs, std::move(fbb),
      [this, module_name, function_name, callback](
          Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) {
          callback(std::move(status));
          return OkStatus();
        }
        const auto& response =
            *response_union.message_as_ResolveFunctionResponse();
        VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name
                << ", " << function_name
                << ") = " << response.function_ordinal();
        callback(response.function_ordinal());
        return OkStatus();
      });
}
// Delivers the proxy for function `function_ordinal` of the first known
// module named `module_name` to `callback`. Delivery is immediate if the
// module definition is already loaded. Otherwise it happens once the
// definition arrives, and a fetch is issued if none has been sent yet.
// An out-of-range ordinal is reported to `callback` as InvalidArgument
// instead of indexing past the function list.
//
// Returns Unimplemented if no known context holds a module with that name.
Status GetFunction(std::string module_name, int function_ordinal,
                   std::function<void(StatusOr<RemoteFunction*> function)>
                       callback) override {
  RemoteModule* target_module = nullptr;
  for (auto* context : contexts_) {
    for (auto* module : context->modules()) {
      if (module->name() == module_name) {
        target_module = module;
        break;
      }
    }
    if (target_module) break;
  }
  if (!target_module) {
    // TODO(benvanik): fetch contexts first.
    return UnimplementedErrorBuilder(ABSL_LOC)
           << "Demand fetch contexts not yet implemented";
  }

  // Bounds-checked lookup of the requested function in a loaded module.
  auto deliver = [callback, function_ordinal](RemoteModule* module) {
    auto functions = module->functions();
    if (function_ordinal < 0 ||
        static_cast<size_t>(function_ordinal) >= functions.size()) {
      Status status = InvalidArgumentErrorBuilder(ABSL_LOC)
                      << "Function ordinal " << function_ordinal
                      << " out of range for module " << module->name()
                      << " with " << functions.size() << " functions";
      callback(std::move(status));
      return;
    }
    callback(functions[function_ordinal]);
  };

  // CheckLoadedOrRequest also issues the module fetch if needed. Only waiting
  // via WhenLoaded would never complete for a module nobody has requested.
  if (target_module->CheckLoadedOrRequest()) {
    deliver(target_module);
    return OkStatus();
  }
  target_module->WhenLoaded(
      [callback, deliver](StatusOr<RemoteModule*> module_or) {
        if (!module_or.ok()) {
          callback(module_or.status());
          return;
        }
        deliver(module_or.ValueOrDie());
      });
  return OkStatus();
}
// Asks the service to add a breakpoint at `offset` in `function_name` of
// `module_name`. The function ordinal is sent as -1 (resolve by name). When
// the response arrives, the returned breakpoint is registered locally and, if
// `callback` is set, passed to it.
Status AddFunctionBreakpoint(
    std::string module_name, std::string function_name, int offset,
    std::function<void(const RemoteBreakpoint& breakpoint)> callback)
    override {
  VLOG(2) << "Client " << fd_ << ": AddFunctionBreakpoint(" << module_name
          << ", " << function_name << ", " << offset << ")";
  FlatBufferBuilder fbb;
  auto breakpoint = absl::make_unique<rpc::BreakpointDefT>();
  breakpoint->module_name = module_name;
  breakpoint->function_name = function_name;
  breakpoint->function_ordinal = -1;
  breakpoint->bytecode_offset = offset;
  rpc::AddBreakpointRequestT request;
  request.session_id = session_id_;
  request.breakpoint = std::move(breakpoint);
  // Pack before `fbb` is moved into IssueRequest: argument evaluation order
  // is unspecified, so packing inside the same call could write into an
  // already moved-from builder.
  auto request_offs = rpc::AddBreakpointRequest::Pack(fbb, &request);
  return IssueRequest<rpc::AddBreakpointRequest,
                      rpc::ResponseUnion::AddBreakpointResponse>(
      request_offs, std::move(fbb),
      [this, callback](Status status,
                       const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        const auto& response =
            *response_union.message_as_AddBreakpointResponse();
        RETURN_IF_ERROR(RegisterBreakpoint(*response.breakpoint()));
        if (callback) {
          ASSIGN_OR_RETURN(
              auto breakpoint,
              GetBreakpoint(response.breakpoint()->breakpoint_id()));
          callback(*breakpoint);
        }
        return OkStatus();
      });
}
// Unregisters the breakpoint locally, then asks the service to remove it.
// `breakpoint` may be the locally registered object itself. That object is
// destroyed by UnregisterBreakpoint, so only its id is used afterwards.
Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) override {
  VLOG(2) << "Client " << fd_ << ": RemoveBreakpoint(" << breakpoint.id()
          << ")";
  int breakpoint_id = breakpoint.id();
  ASSIGN_OR_RETURN(auto* breakpoint_ptr, GetBreakpoint(breakpoint_id));
  RETURN_IF_ERROR(UnregisterBreakpoint(breakpoint_ptr));
  FlatBufferBuilder fbb;
  rpc::RemoveBreakpointRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  request.add_breakpoint_id(breakpoint_id);
  // Finish before `fbb` is moved into IssueRequest: argument evaluation order
  // is unspecified.
  auto request_offs = request.Finish();
  return IssueRequest<rpc::RemoveBreakpointRequest,
                      rpc::ResponseUnion::RemoveBreakpointResponse>(
      request_offs, std::move(fbb),
      [](Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        // No non-error status.
        return OkStatus();
      });
}
// Sends a MakeReady request for this session. The response handler just
// propagates the response status.
Status MakeReady() override {
  FlatBufferBuilder fbb;
  rpc::MakeReadyRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  // Finish before `fbb` is moved into IssueRequest: argument evaluation order
  // is unspecified.
  auto request_offs = request.Finish();
  return IssueRequest<rpc::MakeReadyRequest,
                      rpc::ResponseUnion::MakeReadyResponse>(
      request_offs, std::move(fbb),
      [](Status status, const rpc::Response& response_union) {
        return status;
      });
}
// Requests suspension of all fibers (no fiber ids are sent). On success,
// re-requests the fiber state list.
Status SuspendAllFibers() override {
  VLOG(2) << "Client " << fd_ << ": SuspendAllFibers()";
  FlatBufferBuilder fbb;
  rpc::SuspendFibersRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  // Finish before `fbb` is moved into IssueRequest: argument evaluation order
  // is unspecified.
  auto request_offs = request.Finish();
  return IssueRequest<rpc::SuspendFibersRequest,
                      rpc::ResponseUnion::SuspendFibersResponse>(
      request_offs, std::move(fbb),
      [this](Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        return RefreshFiberStates();
      });
}

// Requests resumption of all fibers (no fiber ids are sent). On success,
// re-requests the fiber state list.
Status ResumeAllFibers() override {
  VLOG(2) << "Client " << fd_ << ": ResumeAllFibers()";
  FlatBufferBuilder fbb;
  rpc::ResumeFibersRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  auto request_offs = request.Finish();
  return IssueRequest<rpc::ResumeFibersRequest,
                      rpc::ResponseUnion::ResumeFibersResponse>(
      request_offs, std::move(fbb),
      [this](Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        return RefreshFiberStates();
      });
}

// Requests suspension of the given fibers. On success, re-requests the
// fiber state list.
Status SuspendFibers(absl::Span<RemoteFiberState*> fibers) override {
  VLOG(2) << "Client " << fd_ << ": SuspendFibers(...)";
  FlatBufferBuilder fbb;
  auto fiber_ids_offs = fbb.CreateVector<int32_t>(
      fibers.size(), [&fibers](size_t i) { return fibers[i]->id(); });
  rpc::SuspendFibersRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  request.add_fiber_ids(fiber_ids_offs);
  auto request_offs = request.Finish();
  return IssueRequest<rpc::SuspendFibersRequest,
                      rpc::ResponseUnion::SuspendFibersResponse>(
      request_offs, std::move(fbb),
      [this](Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        return RefreshFiberStates();
      });
}

// Requests resumption of the given fibers. On success, re-requests the
// fiber state list.
Status ResumeFibers(absl::Span<RemoteFiberState*> fibers) override {
  VLOG(2) << "Client " << fd_ << ": ResumeFibers(...)";
  FlatBufferBuilder fbb;
  auto fiber_ids_offs = fbb.CreateVector<int32_t>(
      fibers.size(), [&fibers](size_t i) { return fibers[i]->id(); });
  rpc::ResumeFibersRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  request.add_fiber_ids(fiber_ids_offs);
  auto request_offs = request.Finish();
  return IssueRequest<rpc::ResumeFibersRequest,
                      rpc::ResponseUnion::ResumeFibersResponse>(
      request_offs, std::move(fbb),
      [this](Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        return RefreshFiberStates();
      });
}
// Requests a single step of `fiber_state`. `callback` is stored under a
// freshly allocated step id (presumably invoked from OnStepCompleted).
Status StepFiber(const RemoteFiberState& fiber_state,
                 std::function<void()> callback) override {
  const int step_id = next_step_id_++;
  VLOG(2) << "Client " << fd_ << ": StepFiber(" << fiber_state.id()
          << ") as step_id=" << step_id;
  rpc::StepFiberRequestT step_request;
  step_request.fiber_id = fiber_state.id();
  step_request.step_id = step_id;
  step_request.step_mode = rpc::StepMode::STEP_ONCE;
  return StepFiber(&step_request, std::move(callback));
}

// Requests that `fiber_state` run until it reaches `bytecode_offset`.
// `callback` is stored under a freshly allocated step id.
Status StepFiberToOffset(const RemoteFiberState& fiber_state,
                         int bytecode_offset,
                         std::function<void()> callback) override {
  const int step_id = next_step_id_++;
  VLOG(2) << "Client " << fd_ << ": StepFiberToOffset(" << fiber_state.id()
          << ", " << bytecode_offset << ") as step_id=" << step_id;
  rpc::StepFiberRequestT step_request;
  step_request.fiber_id = fiber_state.id();
  step_request.step_id = step_id;
  step_request.step_mode = rpc::StepMode::STEP_TO_OFFSET;
  step_request.bytecode_offset = bytecode_offset;
  return StepFiber(&step_request, std::move(callback));
}
// Reads and dispatches every packet that is currently readable, then returns.
// An Unavailable read after a ServiceShutdown event is reported as Cancelled
// ("Service shutdown").
Status Poll() override {
  while (tcp::CanReadBuffer(fd_)) {
    auto packet_buffer_or = tcp::ReadBuffer<rpc::ServicePacket>(fd_);
    if (!packet_buffer_or.ok()) {
      const bool graceful_close =
          shutdown_pending_ && IsUnavailable(packet_buffer_or.status());
      if (graceful_close) {
        return CancelledErrorBuilder(ABSL_LOC) << "Service shutdown";
      }
      return packet_buffer_or.status();
    }
    const auto& packet = packet_buffer_or.ValueOrDie().GetRoot();
    // A packet may carry a response, an event, or both.
    if (packet.response()) {
      RETURN_IF_ERROR(DispatchResponse(*packet.response()));
    }
    if (packet.event()) {
      RETURN_IF_ERROR(DispatchEvent(packet));
    }
  }
  return OkStatus();
}
// Handler for a response: receives the (possibly server-reported) status and
// the response envelope, and returns a status that Poll propagates.
using ResponseCallback =
    std::function<Status(Status status, const rpc::Response& response)>;

// Writes the request. Only if the write succeeds, queues `callback` to handle
// the next response. Responses are matched first-in first-out, and
// DispatchResponse checks that the message type equals `response_type`.
template <typename T, rpc::ResponseUnion response_type>
Status IssueRequest(::flatbuffers::Offset<T> request_offs,
                    FlatBufferBuilder fbb, ResponseCallback callback) {
  RETURN_IF_ERROR(WriteRequest(fd_, request_offs, std::move(fbb)));
  pending_responses_.push({response_type, std::move(callback)});
  return OkStatus();
}
private:
// Issues the context, fiber-state, and breakpoint list requests in that
// order, stopping at the first write failure. Results are applied later,
// when the responses are dispatched.
Status Refresh() {
  using RefreshFn = Status (TcpDebugClient::*)();
  for (RefreshFn refresh : {&TcpDebugClient::RefreshContexts,
                            &TcpDebugClient::RefreshFiberStates,
                            &TcpDebugClient::RefreshBreakpoints}) {
    RETURN_IF_ERROR((this->*refresh)());
  }
  return OkStatus();
}
// Requests the context list. On response, registers unknown contexts and
// merges each definition into its proxy. A response without a contexts
// vector is treated as empty.
Status RefreshContexts() {
  VLOG(2) << "Request contexts refresh...";
  FlatBufferBuilder fbb;
  rpc::ListContextsRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  // Finish before `fbb` is moved into IssueRequest: argument evaluation order
  // is unspecified.
  auto request_offs = request.Finish();
  return IssueRequest<rpc::ListContextsRequest,
                      rpc::ResponseUnion::ListContextsResponse>(
      request_offs, std::move(fbb),
      [this](Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        VLOG(2) << "Refreshing contexts...";
        const auto& response =
            *response_union.message_as_ListContextsResponse();
        if (response.contexts()) {
          for (auto* context_def : *response.contexts()) {
            auto context_or = GetContext(context_def->context_id());
            if (!context_or.ok()) {
              // Not found; add new.
              RETURN_IF_ERROR(RegisterContext(context_def->context_id()));
              context_or = GetContext(context_def->context_id());
            }
            RETURN_IF_ERROR(context_or.status());
            RETURN_IF_ERROR(context_or.ValueOrDie()->MergeFrom(*context_def));
          }
        }
        VLOG(2) << "Refreshed contexts!";
        return OkStatus();
      });
}

// Requests the fiber list. On response, registers unknown fibers and merges
// each state into its proxy. A response without a fiber_states vector is
// treated as empty.
Status RefreshFiberStates() {
  VLOG(2) << "Request fiber states refresh...";
  FlatBufferBuilder fbb;
  rpc::ListFibersRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  auto request_offs = request.Finish();
  return IssueRequest<rpc::ListFibersRequest,
                      rpc::ResponseUnion::ListFibersResponse>(
      request_offs, std::move(fbb),
      [this](Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        VLOG(2) << "Refreshing fiber states...";
        const auto& response =
            *response_union.message_as_ListFibersResponse();
        if (response.fiber_states()) {
          for (auto* fiber_state_def : *response.fiber_states()) {
            auto fiber_state_or = GetFiberState(fiber_state_def->fiber_id());
            if (!fiber_state_or.ok()) {
              // Not found; add new.
              RETURN_IF_ERROR(
                  RegisterFiberState(fiber_state_def->fiber_id()));
              fiber_state_or = GetFiberState(fiber_state_def->fiber_id());
            }
            RETURN_IF_ERROR(fiber_state_or.status());
            RETURN_IF_ERROR(
                fiber_state_or.ValueOrDie()->MergeFrom(*fiber_state_def));
          }
        }
        // TODO(benvanik): handle removals/deaths.
        VLOG(2) << "Refreshed fiber states!";
        return OkStatus();
      });
}

// Requests the breakpoint list. On response, registers unknown breakpoints
// and merges each definition into its proxy. A response without a
// breakpoints vector is treated as empty.
Status RefreshBreakpoints() {
  VLOG(2) << "Requesting breakpoint refresh...";
  FlatBufferBuilder fbb;
  rpc::ListBreakpointsRequestBuilder request(fbb);
  request.add_session_id(session_id_);
  auto request_offs = request.Finish();
  return IssueRequest<rpc::ListBreakpointsRequest,
                      rpc::ResponseUnion::ListBreakpointsResponse>(
      request_offs, std::move(fbb),
      [this](Status status, const rpc::Response& response_union) -> Status {
        if (!status.ok()) return status;
        VLOG(2) << "Refreshing breakpoints...";
        const auto& response =
            *response_union.message_as_ListBreakpointsResponse();
        if (response.breakpoints()) {
          for (auto* breakpoint_def : *response.breakpoints()) {
            auto breakpoint_or =
                GetBreakpoint(breakpoint_def->breakpoint_id());
            if (!breakpoint_or.ok()) {
              // Not found; add new.
              RETURN_IF_ERROR(RegisterBreakpoint(*breakpoint_def));
              breakpoint_or = GetBreakpoint(breakpoint_def->breakpoint_id());
            }
            RETURN_IF_ERROR(breakpoint_or.status());
            RETURN_IF_ERROR(
                breakpoint_or.ValueOrDie()->MergeFrom(*breakpoint_def));
          }
        }
        // TODO(benvanik): handle removals/deaths.
        VLOG(2) << "Refreshed breakpoints!";
        return OkStatus();
      });
}
// Hands `response` to the oldest pending request's callback (responses are
// matched in request order). A server-reported error is forwarded to the
// callback as a non-OK status. An empty body or a message type other than
// the expected one is an error, and the callback is not run.
Status DispatchResponse(const rpc::Response& response) {
  if (pending_responses_.empty()) {
    return FailedPreconditionErrorBuilder(ABSL_LOC)
           << "Response received but no request is pending";
  }
  auto pending = std::move(pending_responses_.front());
  pending_responses_.pop();
  const auto expected_type = pending.first;
  auto& on_response = pending.second;

  if (const auto* server_status = response.status()) {
    Status client_status =
        StatusBuilder(static_cast<StatusCode>(server_status->code()), ABSL_LOC)
        << "Server request failed: " << WrapString(server_status->message());
    return on_response(std::move(client_status), response);
  }
  if (!response.message()) {
    return InvalidArgumentErrorBuilder(ABSL_LOC)
           << "Response contains no message body";
  }
  if (response.message_type() != expected_type) {
    return DataLossErrorBuilder(ABSL_LOC)
           << "Out of order response (mismatch pending)";
  }
  return on_response(OkStatus(), response);
}
// Routes a service-pushed event to the matching On<Name>(...) handler based
// on the packet's event union type and returns that handler's status.
// Event types without a case below yield Unimplemented.
Status DispatchEvent(const rpc::ServicePacket& packet) {
switch (packet.event_type()) {
// Each DISPATCH_EVENT(Name) expands to a case for rpc::EventUnion::NameEvent
// that logs the event name and forwards the typed event table to On<Name>.
// The macro is not #undef'd within this function.
#define DISPATCH_EVENT(event_name) \
case rpc::EventUnion::event_name##Event: { \
VLOG(2) << "EVENT: " << #event_name; \
return On##event_name(*packet.event_as_##event_name##Event()); \
}
DISPATCH_EVENT(ServiceShutdown);
DISPATCH_EVENT(ContextRegistered);
DISPATCH_EVENT(ContextUnregistered);
DISPATCH_EVENT(ModuleLoaded);
DISPATCH_EVENT(FiberRegistered);
DISPATCH_EVENT(FiberUnregistered);
DISPATCH_EVENT(BreakpointResolved);
DISPATCH_EVENT(BreakpointHit);
DISPATCH_EVENT(StepCompleted);
// Any other event type (including ones added to the schema later).
default:
return UnimplementedErrorBuilder(ABSL_LOC)
<< "Unimplemented debug service event: "
<< static_cast<int>(packet.event_type());
}
}
// Returns the registered context proxy for `context_id`, or NotFound.
StatusOr<TcpRemoteContext*> GetContext(int context_id) {
  auto it = context_map_.find(context_id);
  if (it != context_map_.end()) return it->second.get();
  return NotFoundErrorBuilder(ABSL_LOC) << "Context was never registered";
}

// Marks the session as shutting down. Afterwards, Unavailable socket errors
// are treated as a graceful close (see WriteRequest and Poll).
Status OnServiceShutdown(const rpc::ServiceShutdownEvent& event) {
  LOG(INFO) << "Service is shutting down; setting pending shutdown flag";
  shutdown_pending_ = true;
  return OkStatus();
}

// Creates and records a proxy for `context_id`, then notifies the listener.
// Callers are expected to have checked that the id is not yet registered.
Status RegisterContext(int context_id) {
  auto owned_context = absl::make_unique<TcpRemoteContext>(context_id, this);
  VLOG(2) << "RegisterContext(" << context_id << ")";
  TcpRemoteContext* context = owned_context.get();
  context_map_.insert({context_id, std::move(owned_context)});
  contexts_.push_back(context);
  return listener_->OnContextRegistered(*context);
}

// Handles a ContextRegistered event. A duplicate id is a FailedPrecondition.
Status OnContextRegistered(const rpc::ContextRegisteredEvent& event) {
  const auto context_id = event.context_id();
  VLOG(2) << "OnContextRegistered(" << context_id << ")";
  if (context_map_.count(context_id) != 0) {
    return FailedPreconditionErrorBuilder(ABSL_LOC)
           << "Context already registered";
  }
  return RegisterContext(context_id);
}

// Handles a ContextUnregistered event. The proxy is kept alive until the
// listener has been notified.
Status OnContextUnregistered(const rpc::ContextUnregisteredEvent& event) {
  const auto context_id = event.context_id();
  VLOG(2) << "OnContextUnregistered(" << context_id << ")";
  auto it = context_map_.find(context_id);
  if (it == context_map_.end()) {
    return FailedPreconditionErrorBuilder(ABSL_LOC)
           << "Context was never registered";
  }
  auto context = std::move(it->second);
  context_map_.erase(it);
  contexts_.erase(
      std::find(contexts_.begin(), contexts_.end(), context.get()));
  return listener_->OnContextUnregistered(*context);
}
// Handles a ModuleLoaded event: adds an on-demand module proxy to the owning
// context and notifies the listener.
Status OnModuleLoaded(const rpc::ModuleLoadedEvent& event) {
  const auto module_name = WrapString(event.module_name());
  VLOG(2) << "OnModuleLoaded(" << event.context_id() << ", " << module_name
          << ")";
  ASSIGN_OR_RETURN(auto* context, GetContext(event.context_id()));
  auto owned_module = absl::make_unique<TcpRemoteModule>(
      event.context_id(), std::string(module_name), this);
  TcpRemoteModule* module = owned_module.get();
  RETURN_IF_ERROR(context->AddModule(std::move(owned_module)));
  return listener_->OnModuleLoaded(*context, *module);
}
// Returns the registered fiber proxy for `fiber_id`, or NotFound.
StatusOr<TcpRemoteFiberState*> GetFiberState(int fiber_id) {
  auto it = fiber_state_map_.find(fiber_id);
  if (it != fiber_state_map_.end()) return it->second.get();
  return NotFoundErrorBuilder(ABSL_LOC) << "Fiber was never registered";
}

// Creates and records a proxy for `fiber_id`, requests a fresh fiber list,
// then notifies the listener.
Status RegisterFiberState(int fiber_id) {
  VLOG(2) << "RegisterFiberState(" << fiber_id << ")";
  auto owned_fiber_state =
      absl::make_unique<TcpRemoteFiberState>(fiber_id, this);
  TcpRemoteFiberState* fiber_state = owned_fiber_state.get();
  fiber_state_map_.insert({fiber_id, std::move(owned_fiber_state)});
  fiber_states_.push_back(fiber_state);
  RETURN_IF_ERROR(RefreshFiberStates());
  return listener_->OnFiberRegistered(*fiber_state);
}

// Handles a FiberRegistered event. A duplicate id is a FailedPrecondition.
Status OnFiberRegistered(const rpc::FiberRegisteredEvent& event) {
  const auto fiber_id = event.fiber_id();
  VLOG(2) << "OnFiberRegistered(" << fiber_id << ")";
  if (fiber_state_map_.count(fiber_id) != 0) {
    return FailedPreconditionErrorBuilder(ABSL_LOC)
           << "Fiber already registered";
  }
  return RegisterFiberState(fiber_id);
}

// Handles a FiberUnregistered event. The proxy is kept alive until the
// listener has been notified.
Status OnFiberUnregistered(const rpc::FiberUnregisteredEvent& event) {
  const auto fiber_id = event.fiber_id();
  VLOG(2) << "OnFiberUnregistered(" << fiber_id << ")";
  auto it = fiber_state_map_.find(fiber_id);
  if (it == fiber_state_map_.end()) {
    return FailedPreconditionErrorBuilder(ABSL_LOC)
           << "Fiber was never registered";
  }
  auto fiber_state = std::move(it->second);
  fiber_state_map_.erase(it);
  fiber_states_.erase(std::find(fiber_states_.begin(), fiber_states_.end(),
                                fiber_state.get()));
  return listener_->OnFiberUnregistered(*fiber_state);
}
// Returns the registered breakpoint proxy for `breakpoint_id`, or NotFound.
StatusOr<TcpRemoteBreakpoint*> GetBreakpoint(int breakpoint_id) {
  auto it = breakpoint_map_.find(breakpoint_id);
  if (it != breakpoint_map_.end()) return it->second.get();
  return NotFoundErrorBuilder(ABSL_LOC)
         << "Breakpoint " << breakpoint_id << " was never registered";
}

// Merges `breakpoint_def` into the existing proxy with the same id, or
// creates and records a new proxy if none exists.
Status RegisterBreakpoint(const rpc::BreakpointDef& breakpoint_def) {
  const auto breakpoint_id = breakpoint_def.breakpoint_id();
  auto existing = breakpoint_map_.find(breakpoint_id);
  if (existing != breakpoint_map_.end()) {
    VLOG(2) << "RegisterBreakpoint(" << breakpoint_id << ") (update)";
    return existing->second->MergeFrom(breakpoint_def);
  }
  VLOG(2) << "RegisterBreakpoint(" << breakpoint_id << ")";
  const auto breakpoint_type =
      static_cast<RemoteBreakpoint::Type>(breakpoint_def.breakpoint_type());
  auto breakpoint = absl::make_unique<TcpRemoteBreakpoint>(
      breakpoint_id, breakpoint_type, this);
  RETURN_IF_ERROR(breakpoint->MergeFrom(breakpoint_def));
  breakpoints_.push_back(breakpoint.get());
  breakpoint_map_.insert({breakpoint->id(), std::move(breakpoint)});
  return OkStatus();
}
// Removes `breakpoint` from both registries and destroys it. Returns NotFound
// if its id is not registered.
Status UnregisterBreakpoint(RemoteBreakpoint* breakpoint) {
  VLOG(2) << "UnregisterBreakpoint(" << breakpoint->id() << ")";
  auto it = breakpoint_map_.find(breakpoint->id());
  if (it == breakpoint_map_.end()) {
    return NotFoundErrorBuilder(ABSL_LOC)
           << "Breakpoint was never registered";
  }
  // Take ownership out of the map and drop the list entry while the object is
  // still alive. Erasing the map entry first would destroy it and leave
  // `breakpoint` dangling during the list search, following the same
  // move-then-erase order as OnContextUnregistered/OnFiberUnregistered.
  auto owned_breakpoint = std::move(it->second);
  breakpoint_map_.erase(it);
  auto list_it =
      std::find(breakpoints_.begin(), breakpoints_.end(), breakpoint);
  if (list_it != breakpoints_.end()) {
    breakpoints_.erase(list_it);
  }
  return OkStatus();
}
// Handles a BreakpointResolved event by registering the breakpoint or merging
// the resolved definition into the existing proxy.
Status OnBreakpointResolved(const rpc::BreakpointResolvedEvent& event) {
  const auto& resolved = *event.breakpoint();
  VLOG(2) << "OnBreakpointResolved(" << resolved.breakpoint_id() << ")";
  auto it = breakpoint_map_.find(resolved.breakpoint_id());
  if (it != breakpoint_map_.end()) {
    RETURN_IF_ERROR(it->second->MergeFrom(resolved));
  } else {
    RETURN_IF_ERROR(RegisterBreakpoint(resolved));
  }
  return OkStatus();
}
// Handles a BreakpointHitEvent from the debug service.
//
// Steps:
//  1. Look up the hit breakpoint; a NotFound error is propagated if it was
//     never registered.
//  2. Look up the fiber the event reports. If that lookup fails, register the
//     fiber and look it up again; a second failure is propagated.
//  3. Merge the event's fiber state into the local fiber record.
//  4. Forward the hit to listener_.
Status OnBreakpointHit(const rpc::BreakpointHitEvent& event) {
  VLOG(2) << "OnBreakpointHit(" << event.breakpoint_id() << ")";
  ASSIGN_OR_RETURN(auto* hit_breakpoint,
                   GetBreakpoint(event.breakpoint_id()));
  const auto* state_def = event.fiber_state();
  auto lookup = GetFiberState(state_def->fiber_id());
  if (!lookup.ok()) {
    // Unknown fiber: register it, then retry the lookup.
    RETURN_IF_ERROR(RegisterFiberState(state_def->fiber_id()));
    lookup = GetFiberState(state_def->fiber_id());
  }
  RETURN_IF_ERROR(lookup.status());
  auto* fiber_state = lookup.ValueOrDie();
  RETURN_IF_ERROR(fiber_state->MergeFrom(*state_def));
  return listener_->OnBreakpointHit(*hit_breakpoint, *fiber_state);
}
// Sends a StepFiberRequest built from |step_request|.
//
// The response callback only propagates the request status; the
// StepFiberResponse payload is ignored. If the request is issued
// successfully, |callback| is stored under step_request->step_id, and
// OnStepCompleted later invokes it when the matching StepCompletedEvent
// arrives. If IssueRequest fails, its error is returned and |callback| is
// not registered.
Status StepFiber(rpc::StepFiberRequestT* step_request,
                 std::function<void()> callback) {
  FlatBufferBuilder fbb;
  // Serialize into |fbb| before the IssueRequest call. Function arguments are
  // evaluated in an unspecified order. If Pack(fbb, ...) sat next to
  // std::move(fbb) in the same argument list, the callee's builder parameter
  // could be move-constructed from |fbb| first, and Pack would then write
  // into a moved-from builder.
  auto request_offset = rpc::StepFiberRequest::Pack(fbb, step_request);
  auto status = IssueRequest<rpc::StepFiberRequest,
                             rpc::ResponseUnion::StepFiberResponse>(
      request_offset, std::move(fbb),
      [](Status status, const rpc::Response& response_union) -> Status {
        return status;
      });
  RETURN_IF_ERROR(status);
  pending_step_callbacks_[step_request->step_id] = std::move(callback);
  return OkStatus();
}
// Handles a StepCompletedEvent from the debug service.
//
// First merges every fiber state carried by the event into the matching
// local fiber. It fails if a fiber is unknown or a merge fails. The event may
// carry only a subset of the relevant fibers, or none at all.
//
// Then dispatches the callback that StepFiber registered for event.step_id().
// If no callback is pending (it may have been cancelled), logs a warning and
// calls MakeReady() so execution can resume.
Status OnStepCompleted(const rpc::StepCompletedEvent& event) {
  VLOG(2) << "OnStepCompleted(" << event.step_id() << ")";
  // An unset flatbuffers vector field reads back as nullptr, so guard before
  // iterating instead of dereferencing unconditionally.
  if (const auto* fiber_state_defs = event.fiber_states()) {
    for (const auto* fiber_state_def : *fiber_state_defs) {
      ASSIGN_OR_RETURN(auto fiber_state,
                       GetFiberState(fiber_state_def->fiber_id()));
      RETURN_IF_ERROR(fiber_state->MergeFrom(*fiber_state_def));
    }
  }
  auto it = pending_step_callbacks_.find(event.step_id());
  if (it == pending_step_callbacks_.end()) {
    LOG(WARNING) << "Step " << event.step_id()
                 << " not found; was cancelled?";
    RETURN_IF_ERROR(MakeReady());
    return OkStatus();
  }
  // Take the callback out and erase the entry before running it. The
  // callback may start another step, and StepFiber inserting into
  // pending_step_callbacks_ can rehash the table and invalidate |it|.
  std::function<void()> callback = std::move(it->second);
  pending_step_callbacks_.erase(it);
  if (callback) {
    callback();
  }
  return OkStatus();
}
// Session identifier. NOTE(review): neither where it is used nor what the
// hard-coded 123 means is visible in this chunk; confirm.
int session_id_ = 123;
// Socket descriptor, -1 when unset. Presumably the connected fd that
// Connect() hands to TcpDebugClient::Create; confirm against the constructor.
int fd_ = -1;
// Receives client notifications such as OnFiberUnregistered and
// OnBreakpointHit. NOTE(review): ownership/lifetime is not established here;
// presumably owned by the caller of Connect().
Listener* listener_;
// NOTE(review): set and read outside this chunk; presumably marks that a
// shutdown has been requested.
bool shutdown_pending_ = false;
// Expected response types, each paired with the callback to run for it.
// Presumably consumed in request order (FIFO); confirm in response dispatch.
std::queue<std::pair<rpc::ResponseUnion, ResponseCallback>>
pending_responses_;
// Each registry below has two views: a vector of non-owning pointers and an
// id-keyed map holding the owning unique_ptrs. The register/unregister paths
// update both together.
std::vector<RemoteContext*> contexts_;
absl::flat_hash_map<int, std::unique_ptr<TcpRemoteContext>> context_map_;
std::vector<RemoteFiberState*> fiber_states_;
absl::flat_hash_map<int, std::unique_ptr<TcpRemoteFiberState>>
fiber_state_map_;
std::vector<RemoteBreakpoint*> breakpoints_;
absl::flat_hash_map<int, std::unique_ptr<TcpRemoteBreakpoint>>
breakpoint_map_;
// NOTE(review): not referenced in this chunk; presumably the counter used to
// allocate step ids for StepFiber requests; confirm.
int next_step_id_ = 1;
// Callbacks registered by StepFiber, keyed by step id. OnStepCompleted
// invokes and removes the matching entry.
absl::flat_hash_map<int, std::function<void()>> pending_step_callbacks_;
};
} // namespace
// static
// Connects to a debug service at |service_address| and returns a client that
// reports events to |listener|.
//
// |service_address| is parsed by ParseAddress (host:port, with the RFC 3986
// [host]:port form accepted for IPv6). The host is resolved with
// getaddrinfo(), and each resolved address is tried in order until a TCP
// connect() succeeds. The connected socket is then passed to
// TcpDebugClient::Create.
//
// Errors:
//   - Errors from ParseAddress are propagated.
//   - InvalidArgument if no port was specified.
//   - Unavailable if name resolution fails.
//   - Unavailable if no resolved address accepts a connection. The message
//     carries the errno from the last failed socket()/connect() attempt.
StatusOr<std::unique_ptr<DebugClient>> DebugClient::Connect(
    absl::string_view service_address, Listener* listener) {
  // Parse address into hostname and port.
  ASSIGN_OR_RETURN(auto hostname_port, ParseAddress(service_address));
  if (hostname_port.second == 0) {
    return InvalidArgumentErrorBuilder(ABSL_LOC)
           << "No port specified in service address; port must match the "
              "server: "
           << service_address;
  }
  // Attempt to resolve the address.
  // Note that if we only wanted local debugging we could remove the dep on
  // getaddrinfo/having a valid DNS setup.
  addrinfo hints = {};  // Value-initialize every field (flags, protocol, ...).
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  addrinfo* resolved_address = nullptr;
  auto port_str = std::to_string(hostname_port.second);
  int getaddrinfo_ret = ::getaddrinfo(
      hostname_port.first.c_str(), port_str.c_str(), &hints, &resolved_address);
  if (getaddrinfo_ret != 0) {
    return UnavailableErrorBuilder(ABSL_LOC)
           << "Unable to resolve debug service address for " << service_address
           << ": (" << getaddrinfo_ret << ") "
           << ::gai_strerror(getaddrinfo_ret);
  }
  // Attempt to connect with each address returned from the query.
  // errno is captured right at each failing call. The later ::close() and
  // ::freeaddrinfo() calls may overwrite errno, which would make the final
  // error message report the wrong reason.
  int fd = -1;
  int last_errno = 0;
  for (addrinfo* rp = resolved_address; rp != nullptr; rp = rp->ai_next) {
    fd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (fd == -1) {
      last_errno = errno;
      continue;
    }
    if (::connect(fd, rp->ai_addr, rp->ai_addrlen) == 0) {
      break;  // Success!
    }
    last_errno = errno;
    ::close(fd);
    fd = -1;
  }
  ::freeaddrinfo(resolved_address);
  if (fd == -1) {
    return UnavailableErrorBuilder(ABSL_LOC)
           << "Unable to connect to " << service_address << " on any address: ("
           << last_errno << ") " << ::strerror(last_errno);
  }
  LOG(INFO) << "Connected to debug service at " << service_address;
  return TcpDebugClient::Create(fd, listener);
}
} // namespace debug
} // namespace vm
} // namespace iree