| """Tracing support.""" |
| |
| # Copyright 2021 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| from genericpath import exists |
| from typing import Dict, List, Optional, Sequence |
| |
| import logging |
| import os |
| import sys |
| |
| from . import binding as _binding |
| |
# PyYAML is an optional dependency. Record whether it imported so that Tracer
# can disable itself (see Tracer.__init__) instead of failing at import time.
try:
    import yaml
except ModuleNotFoundError:
    _has_yaml = False
else:
    _has_yaml = True

# Names exported by a star-import of this module.
__all__ = [
    "get_default_tracer",
    "Tracer",
    "TRACE_PATH_ENV_KEY",
]

# Environment variable read by get_default_tracer(); its value is used as the
# directory that trace artifacts are written to.
TRACE_PATH_ENV_KEY = "IREE_SAVE_CALLS"
| |
| |
class Tracer:
    """Object for tracing calls made into the runtime.

    Trace artifacts are written beneath ``trace_path``. Tracing requires
    PyYAML; when it is not installed the tracer is still constructed but
    reports ``enabled = False`` and does not create the trace directory.

    Attributes:
      enabled: Whether tracing is active (False when PyYAML is unavailable).
      trace_path: Directory that receives all trace artifacts.
    """

    def __init__(self, trace_path: str):
        """Initializes the tracer and, when enabled, creates ``trace_path``.

        Args:
          trace_path: Directory to write trace artifacts to.
        """
        # Populate the bookkeeping attributes unconditionally so that a
        # disabled tracer is still a fully formed object rather than one that
        # raises AttributeError on first use.
        self.trace_path = trace_path
        self._name_count = dict()  # type: Dict[str, int]
        self.enabled = _has_yaml
        if not self.enabled:
            logging.warning("PyYAML not installed: tracing will be disabled")
            return
        os.makedirs(trace_path, exist_ok=True)

    def persist_vm_module(self,
                          vm_module: "_binding.VmModule") -> "TracedModule":
        """Saves whatever persistent form of ``vm_module`` is available.

        Args:
          vm_module: The module to persist.

        Returns:
          A TracedModule that carries the path of the saved .vmfb file when
          the module has a stashed flatbuffer blob, and no path otherwise.
        """
        # Depending on how the module was created, there are different bits
        # of information available to reconstruct.
        name = vm_module.name
        flatbuffer_blob = vm_module.stashed_flatbuffer_blob
        if not flatbuffer_blob:
            # No persistent form, but likely they are built-in modules.
            return TracedModule(self, vm_module)

        save_path = os.path.join(self.trace_path,
                                 self.get_unique_name(f"{name}.vmfb"))
        logging.info("Saving traced vmfb to %s", save_path)
        with open(save_path, "wb") as f:
            f.write(flatbuffer_blob)
        return TracedModule(self, vm_module, save_path)

    def get_unique_name(self, local_name: str) -> str:
        """Returns a file name that this tracer has not handed out before.

        The first request for a name returns it unchanged. Repeated requests
        return ``<stem>__<n><ext>`` with an increasing ``n``, skipping any
        candidate that has already been issued.

        Args:
          local_name: Desired file name (no directory component).

        Returns:
          ``local_name`` or a derived name, unique for this tracer.
        """
        if local_name not in self._name_count:
            self._name_count[local_name] = 1
            return local_name
        stem, ext = os.path.splitext(local_name)
        index = self._name_count[local_name]
        while True:
            unique_name = f"{stem}__{index}{ext}"
            index += 1
            # A generated name can coincide with one that was requested
            # verbatim (or generated earlier); keep counting until it is free.
            if unique_name not in self._name_count:
                break
        self._name_count[local_name] = index
        # Record the generated name too, so a later verbatim request for it
        # is de-duplicated instead of colliding with this one.
        self._name_count[unique_name] = 1
        return unique_name
| |
| |
class TracedModule:
    """Pairs a VmModule with the details needed to describe it in a trace."""

    def __init__(self,
                 parent: Tracer,
                 vm_module: _binding.VmModule,
                 vmfb_path: Optional[str] = None):
        """Stores the owning tracer, the module and its saved path (if any).

        Args:
          parent: Tracer whose trace directory paths are made relative to.
          vm_module: The wrapped module.
          vmfb_path: Where the module's .vmfb was saved, or None.
        """
        self._vmfb_path = vmfb_path
        self._vm_module = vm_module
        self._parent = parent

    def serialize(self):
        """Builds the record that describes this module in a trace frame.

        Returns:
          A dict with "name" and "type". Modules that have a saved .vmfb are
          typed "bytecode" and carry a "path" relative to the trace
          directory; all others are typed "builtin".
        """
        module_name = self._vm_module.name
        if not self._vmfb_path:
            return {"name": module_name, "type": "builtin"}

        relative_path = os.path.relpath(self._vmfb_path,
                                        self._parent.trace_path)
        return {
            "name": module_name,
            "type": "bytecode",
            "path": relative_path,
        }
| |
| |
class ContextTracer:
    """Traces invocations against a context.

    Each event is appended as one YAML document to a per-context calls file
    located in the parent tracer's trace directory.
    """

    def __init__(self, parent: Tracer, is_dynamic: bool,
                 modules: Sequence[TracedModule]):
        """Starts a fresh calls file and records the initial frames.

        Args:
          parent: Tracer that owns the trace directory and hands out unique
            file names.
          is_dynamic: Accepted for interface compatibility; not consulted by
            this class.
          modules: Modules already present in the context; a "module_load"
            frame is emitted for each, after the "context_load" frame.
        """
        self._parent = parent
        self._modules = list(modules)  # type: List[TracedModule]
        self._frame_count = 0
        self._file_path = os.path.join(parent.trace_path,
                                       parent.get_unique_name("calls.yaml"))
        # Ensure the directory that actually holds the calls file exists.
        # (Creating dirname(trace_path) instead targets the wrong directory
        # and raises for a one-component relative trace path, whose dirname
        # is the empty string.)
        trace_dir = os.path.dirname(self._file_path)
        if trace_dir:
            os.makedirs(trace_dir, exist_ok=True)
        # Create the file, or truncate a stale one, so that the appends made
        # by emit_frame always start from an empty file.
        with open(self._file_path, "wt"):
            pass
        logging.info("Tracing context events to: %s", self._file_path)
        self.emit_frame({
            "type": "context_load",
        })
        for module in self._modules:
            self.emit_frame({
                "type": "module_load",
                "module": module.serialize(),
            })

    def add_module(self, module: TracedModule):
        """Registers an additional module and emits its "module_load" frame."""
        self._modules.append(module)
        self.emit_frame({
            "type": "module_load",
            "module": module.serialize(),
        })

    def start_call(self, function: _binding.VmFunction):
        """Begins tracing a call to ``function``.

        Returns:
          A CallTrace seeded with the call record; nothing is written to the
          calls file until its end_call() is invoked.
        """
        logging.info("Tracing call to %s.%s", function.module_name,
                     function.name)

        # Start assembling the call record.
        record = {
            "type": "call",
            "function": "%s.%s" % (function.module_name, function.name),
        }
        return CallTrace(self, record)

    def emit_frame(self, frame: dict):
        """Appends ``frame`` to the calls file as one YAML document.

        Every frame after the first is preceded by a "---" document
        separator. Keys are written in insertion order.
        """
        self._frame_count += 1
        with open(self._file_path, "at") as f:
            if self._frame_count != 1:
                f.write("---\n")
            contents = yaml.dump(frame, sort_keys=False)
            f.write(contents)
| |
| |
class CallTrace:
    """Accumulates the record of one call and emits it when the call ends."""

    def __init__(self, parent: ContextTracer, record: dict):
        """Keeps the owning context tracer and the record being assembled."""
        self._record = record
        self._parent = parent

    def add_vm_list(self, vm_list: _binding.VmVariantList, key: str):
        """Stores the serialized trace value of each list element under ``key``."""
        self._record[key] = [
            vm_list.get_serialized_trace_value(position)
            for position in range(len(vm_list))
        ]

    def end_call(self):
        """Hands the assembled record to the parent tracer as a frame."""
        self._parent.emit_frame(self._record)
| |
| |
def get_default_tracer() -> Optional[Tracer]:
    """Builds the default call tracer from the environment, if configured.

    Returns:
      A Tracer rooted at the directory named by the TRACE_PATH_ENV_KEY
      environment variable, or None when that variable is unset or empty.
    """
    trace_dir = os.environ.get(TRACE_PATH_ENV_KEY)
    return Tracer(trace_dir) if trace_dir else None