blob: ea804ac22d6f4dbe880116d7f50f4c948eac6b08 [file] [log] [blame]
"""Tracing support."""
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from genericpath import exists
from typing import Dict, List, Optional, Sequence
import logging
import os
import sys
from . import binding as _binding
# PyYAML is optional: when it cannot be imported, Tracer instances disable
# themselves instead of failing at import time.
_has_yaml = True
try:
    import yaml
except ModuleNotFoundError:
    _has_yaml = False
# Public API of this module.
__all__ = ["get_default_tracer", "Tracer", "TRACE_PATH_ENV_KEY"]

# Environment variable read by get_default_tracer(): names the directory that
# traces are written into. Tracing is off when it is unset or empty.
TRACE_PATH_ENV_KEY = "IREE_SAVE_CALLS"
class Tracer:
    """Object for tracing calls made into the runtime.

    Trace artifacts (persisted ``.vmfb`` files and call logs) are written
    under ``trace_path``. If PyYAML is not installed the tracer is created
    with ``enabled = False`` and does not create ``trace_path``.
    """

    def __init__(self, trace_path: str):
        """Creates a tracer rooted at ``trace_path``.

        Args:
          trace_path: Directory for trace artifacts; created if missing
            (only when tracing is enabled).
        """
        # Set these unconditionally so a disabled tracer still has the
        # attributes its methods read, instead of raising AttributeError.
        self.trace_path = trace_path
        self._name_count: Dict[str, int] = {}
        if not _has_yaml:
            self.enabled = False
            logging.warning("PyYAML not installed: tracing will be disabled")
            return
        self.enabled = True
        os.makedirs(trace_path, exist_ok=True)

    def persist_vm_module(self, vm_module: _binding.VmModule) -> "TracedModule":
        """Saves ``vm_module`` to the trace directory when possible.

        If the module carries a stashed flatbuffer blob, it is written to a
        uniquely named ``<module name>.vmfb`` under ``trace_path`` and the
        returned TracedModule records that path. Otherwise a TracedModule
        without a path is returned.
        """
        # Depending on how the module was created, there are different bits
        # of information available to reconstruct.
        name = vm_module.name
        flatbuffer_blob = vm_module.stashed_flatbuffer_blob
        if flatbuffer_blob:
            save_path = os.path.join(self.trace_path,
                                     self.get_unique_name(f"{name}.vmfb"))
            logging.info("Saving traced vmfb to %s", save_path)
            with open(save_path, "wb") as f:
                f.write(flatbuffer_blob)
            return TracedModule(self, vm_module, save_path)
        # No persistent form, but likely they are built-in modules.
        return TracedModule(self, vm_module)

    def get_unique_name(self, local_name: str) -> str:
        """Returns a file name derived from ``local_name`` not handed out before.

        The first request for a name returns it unchanged. Later requests
        return ``{stem}__{index}{ext}`` with index 1, 2, ... Every name that
        is returned is recorded. A generated name is therefore never handed
        out twice, even if a caller later requests it literally.
        """
        if local_name not in self._name_count:
            self._name_count[local_name] = 1
            return local_name
        stem, ext = os.path.splitext(local_name)
        index = self._name_count[local_name]
        # Skip candidates that were already issued, either generated or
        # requested verbatim, so two callers never receive the same name.
        while True:
            candidate = f"{stem}__{index}{ext}"
            index += 1
            if candidate not in self._name_count:
                break
        self._name_count[local_name] = index
        self._name_count[candidate] = 1
        return candidate
class TracedModule:
    """A VmModule plus, when it was persisted, the path of its saved vmfb.

    Tracer.persist_vm_module constructs these; ``vmfb_path`` is None for
    modules that had no stashed flatbuffer to save.
    """

    def __init__(self,
                 parent: Tracer,
                 vm_module: _binding.VmModule,
                 vmfb_path: Optional[str] = None):
        self._parent = parent
        self._vm_module = vm_module
        self._vmfb_path = vmfb_path

    def serialize(self):
        """Returns a plain dict describing this module for a trace frame.

        Always contains ``name`` and ``type``. ``type`` is "bytecode" when a
        vmfb path is known, in which case ``path`` holds that path relative
        to the parent tracer's ``trace_path``. Otherwise ``type`` is
        "builtin".
        """
        record = {"name": self._vm_module.name}
        if not self._vmfb_path:
            record["type"] = "builtin"
            return record
        record["type"] = "bytecode"
        record["path"] = os.path.relpath(self._vmfb_path,
                                         self._parent.trace_path)
        return record
class ContextTracer:
    """Traces invocations against a context.

    Each instance writes to its own uniquely named ``calls*.yaml`` file
    inside the parent tracer's ``trace_path``. The file holds a stream of
    YAML documents separated by ``---`` lines, one document per frame.
    """

    def __init__(self, parent: Tracer, is_dynamic: bool,
                 modules: Sequence[TracedModule]):
        """Starts a fresh trace file and records the initial modules.

        Args:
          parent: Tracer supplying ``trace_path`` and unique file names.
          is_dynamic: Accepted but not used by this class.
            NOTE(review): confirm whether it should be recorded in the trace.
          modules: Modules present when the context is created; a
            ``module_load`` frame is emitted for each, after an initial
            ``context_load`` frame.
        """
        self._parent = parent
        self._modules: List[TracedModule] = list(modules)
        self._frame_count = 0
        self._file_path = os.path.join(parent.trace_path,
                                       parent.get_unique_name("calls.yaml"))
        # Ensure the directory that holds the trace file exists. Note that
        # dirname(trace_path) would be the trace directory's *parent*, and
        # is "" for a bare relative name, which makes os.makedirs raise.
        trace_dir = os.path.dirname(self._file_path)
        if trace_dir:
            os.makedirs(trace_dir, exist_ok=True)
        # Start from an empty file (truncating any previous run);
        # emit_frame() appends to it.
        with open(self._file_path, "wt"):
            pass
        logging.info("Tracing context events to: %s", self._file_path)
        self.emit_frame({
            "type": "context_load",
        })
        for module in self._modules:
            self.emit_frame({
                "type": "module_load",
                "module": module.serialize(),
            })

    def add_module(self, module: TracedModule):
        """Records a module added after construction as a module_load frame."""
        self._modules.append(module)
        self.emit_frame({
            "type": "module_load",
            "module": module.serialize(),
        })

    def start_call(self, function: _binding.VmFunction):
        """Begins tracing a call to ``function``.

        Returns a CallTrace whose record has ``type: call`` and the function
        named as ``<module_name>.<name>``. The record is written when the
        CallTrace's ``end_call()`` is invoked.
        """
        logging.info("Tracing call to %s.%s", function.module_name, function.name)
        # Start assembling the call record.
        record = {
            "type": "call",
            "function": "%s.%s" % (function.module_name, function.name),
        }
        return CallTrace(self, record)

    def emit_frame(self, frame: dict):
        """Appends ``frame`` to the trace file as one YAML document.

        Every frame after the first is preceded by a ``---`` separator line.
        Keys are written in insertion order (``sort_keys=False``).
        """
        self._frame_count += 1
        with open(self._file_path, "at") as f:
            if self._frame_count != 1:
                f.write("---\n")
            contents = yaml.dump(frame, sort_keys=False)
            f.write(contents)
class CallTrace:
    """Collects the record of one traced call until it is emitted.

    ``record`` is the dict started by ContextTracer.start_call. Values are
    added to it with add_vm_list(), and end_call() hands it to the parent
    ContextTracer as a single frame.
    """

    def __init__(self, parent: ContextTracer, record: dict):
        self._parent = parent
        self._record = record

    def add_vm_list(self, vm_list: _binding.VmVariantList, key: str):
        """Stores each element's serialized trace value under ``key``."""
        self._record[key] = [
            vm_list.get_serialized_trace_value(position)
            for position in range(len(vm_list))
        ]

    def end_call(self):
        """Emits the accumulated call record via the parent ContextTracer."""
        self._parent.emit_frame(self._record)
def get_default_tracer() -> Optional[Tracer]:
    """Builds a Tracer from the IREE_SAVE_CALLS environment variable.

    Returns:
      A Tracer rooted at the directory named by the variable, or None when
      the variable is unset or empty.
    """
    trace_dir = os.getenv(TRACE_PATH_ENV_KEY)
    return Tracer(trace_dir) if trace_dir else None