Adding initial dispatch instrumention support. (#12357)
This adds a few new `hal.instrument.*` ops, a pass that instruments
dispatches to a basic level, LLVM CPU support, a runtime tooling flag,
and a prototype tool to dump the instrument data.
At the core of this is support for a new compiler-generated function
`__query_instruments` that allows modules to pass back a list of
buffers. The `--instrument_file=` tooling flag will gather all buffers
from all modules and concatenate them together into a binary file. The
format of the resulting file is defined by a chunked transport stream
containing today just the dispatch instrumentation chunk types.
Dispatch instrumentation will be enabled when the
`--iree-hal-instrument-dispatches=` flag is set to a power-of-two buffer
size. Most programs can usually get by with 16mib while memory access
instrumentation may require 256mib or 2gib.
On the CPU side the `--iree-llvmcpu-instrument-memory-accesses=true`
flag will enable tracking every load/store from/to a memref (scalars and
vectors) by address and length. This can be used to observe memory
access patterns and the addresses being accessed by particular
workgroups. We should be able to support this on other backends
(definitely CUDA, but possibly with SPIR-V using relative buffer offsets
or something).
In addition to tracking workgroup launches and optionally memory
accesses there are also placeholders for printf-style string formatting
and value probes. `hal.instrument.print` still needs conversion work in
each backend and though `hal.instrument.value` works there's no nice way
of inserting them today.
Example commands for getting a memory access dump (16MB is pretty small
for this, 2gib is better when access tracking is enabled):
```sh
iree-compile \
--iree-hal-target-backends=llvm-cpu \
--iree-hal-instrument-dispatches=16mib \
--iree-llvmcpu-instrument-memory-accesses=true \
runtime/src/iree/runtime/testdata/simple_mul.mlir \
-o=simple_mul_instr.vmfb
iree-run-module \
--device=local-sync \
--module=simple_mul_instr.vmfb \
--function=simple_mul \
--input=4xf32=2 \
--input=4xf32=4 \
--instrument_file=instrument.bin
iree-dump-instruments instrument.bin
```
Expected output for a simple_mul which has 4x4096 (multiple workgroups),
note that export sources are listed as well as all sites to that export
(in this case just one):
```
$ ../iree-build/tools/iree-dump-instruments instrument_mem.bin
//===----------------------------------------------------------------------===//
// export[0]: simple_mul_dispatch_0_generic_16384
//===----------------------------------------------------------------------===//
func.func @simple_mul_dispatch_0_generic_16384(%arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}, %arg2: !stream.binding {stream.alignment = 64 : index}) {
%c0 = arith.constant 0 : index
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16384xf32>>
%1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16384xf32>>
%2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<16384xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [16384], strides = [1] : !flow.dispatch.tensor<readonly:tensor<16384xf32>> -> tensor<16384xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [16384], strides = [1] : !flow.dispatch.tensor<readonly:tensor<16384xf32>> -> tensor<16384xf32>
%5 = tensor.empty() : tensor<16384xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%3, %4 : tensor<16384xf32>, tensor<16384xf32>) outs(%5 : tensor<16384xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%7 = arith.mulf %in, %in_0 : f32
linalg.yield %7 : f32
} -> tensor<16384xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [16384], strides = [1] : tensor<16384xf32> -> !flow.dispatch.tensor<writeonly:tensor<16384xf32>>
return
}
//===----------------------------------------------------------------------===//
// dispatch site 0: simple_mul_dispatch_0_generic_16384
//===----------------------------------------------------------------------===//
0000000000000000 | WORKGROUP dispatch(0 simple_mul_dispatch_0_generic_16384 4x1x1) 0,0,0 pid:52
0000000000000000 | LOAD 000002705bc3ad80 16
0000000000000000 | LOAD 000002705bc4ae80 16
0000000000000000 | STORE 000002705bc7a100 16
0000000000000000 | LOAD 000002705bc3ad90 16
0000000000000000 | LOAD 000002705bc4ae90 16
0000000000000000 | STORE 000002705bc7a110 16
0000000000000000 | LOAD 000002705bc3ada0 16
0000000000000000 | LOAD 000002705bc4aea0 16
0000000000000000 | STORE 000002705bc7a120 16
0000000000000000 | LOAD 000002705bc3adb0 16
0000000000000000 | LOAD 000002705bc4aeb0 16
0000000000000000 | STORE 000002705bc7a130 16
0000000000000000 | LOAD 000002705bc3adc0 16
0000000000000000 | LOAD 000002705bc4aec0 16
0000000000000000 | STORE 000002705bc7a140 16
...
000000000000c040 | WORKGROUP dispatch(0 simple_mul_dispatch_0_generic_16384 4x1x1) 1,0,0 pid:52
000000000000c040 | LOAD 000002705bc3ed80 16
000000000000c040 | LOAD 000002705bc4ee80 16
000000000000c040 | STORE 000002705bc7e100 16
...
```
The printed data is currently just a proof of concept - the first column
is the workgroup key - more advanced visualizations can do better things
(still WIP, showing a dispatch listing and memory accesses by a
particular dispatch):



There's still some iteration needed on printing values. For now a simple
test with this:
```diff
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 55bab04a8..24abf2255 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -442,6 +442,22 @@ struct ConvertHALInstrumentWorkgroupOp
rewriter.create<LLVM::ConstantOp>(loc, i64Type, 0xFFFFFFFFFFll)),
rewriter.create<LLVM::ConstantOp>(loc, i64Type, 24));
+ // HACK: test writing out a value.
+ {
+ Value valueOperand = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getUI32IntegerAttr(0xFFFFFFFFu));
+ rewriter.create<IREE::HAL::InstrumentValueOp>(
+ loc, valueOperand.getType(), instrumentOp.getBuffer(), workgroupKey,
+ rewriter.getI8IntegerAttr(0), valueOperand);
+ }
+ {
+ Value valueOperand = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getF32FloatAttr(1.234f));
+ rewriter.create<IREE::HAL::InstrumentValueOp>(
+ loc, valueOperand.getType(), instrumentOp.getBuffer(), workgroupKey,
+ rewriter.getI8IntegerAttr(1), valueOperand);
+ }
+
rewriter.replaceOp(instrumentOp, workgroupKey);
return success();
}
```
will produce this output:
```
0000000000000000 | WORKGROUP dispatch(0 simple_mul_dispatch_0_generic_16384 4x1x1) 0,0,0 pid:50
0000000000000000 | VALUE 0000 = 4294967295
0000000000000000 | VALUE 0001 = 1.234000e+00 1.234000
0000000000000040 | WORKGROUP dispatch(0 simple_mul_dispatch_0_generic_16384 4x1x1) 1,0,0 pid:50
0000000000000040 | VALUE 0000 = 4294967295
0000000000000040 | VALUE 0001 = 1.234000e+00 1.234000
0000000000000080 | WORKGROUP dispatch(0 simple_mul_dispatch_0_generic_16384 4x1x1) 2,0,0 pid:50
0000000000000080 | VALUE 0000 = 4294967295
0000000000000080 | VALUE 0001 = 1.234000e+00 1.234000
00000000000000C0 | WORKGROUP dispatch(0 simple_mul_dispatch_0_generic_16384 4x1x1) 3,0,0 pid:50
00000000000000C0 | VALUE 0000 = 4294967295
00000000000000C0 | VALUE 0001 = 1.234000e+00 1.234000
```
Ergonomics-wise for things like the value instrumentation we could try
adding dialect attributes or something earlier on that tracks values or
flagged patterns/etc for particular experiments.
In the future we can add both additional dispatch instrumentation and
new chunk types from various modules. Examples include
compiler-generated profile-guided optimization markers, VM code coverage
markers, HAL counters for allocations or submissions, or device
timestamp streams.
Only the CPU backend is supported right now as that's what I'm familiar
with. I'd hoped to do the `hal.instrument.*` op lowering to memrefs so
it could be shared but memref descriptors are unfortunately still a
thing and with something tracking every single memory access it created
hundreds of thousands of instructions. I eagerly await the day when
someone works to kill memref descriptors.
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD b/compiler/src/iree/compiler/Codegen/Common/BUILD
index fc321ee..cb4597a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD
@@ -122,6 +122,7 @@
"GPUPipelining.cpp",
"HoistStaticallyBoundAllocations.cpp",
"IREEComprehensiveBufferizePass.cpp",
+ "InstrumentMemoryAccesses.cpp",
"LowerUKernelsToCalls.cpp",
"MaterializeEncodingIntoPackUnPack.cpp",
"OptimizeVectorTransferPass.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 304afa9..e887349 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -97,6 +97,7 @@
"GPUPipelining.cpp"
"HoistStaticallyBoundAllocations.cpp"
"IREEComprehensiveBufferizePass.cpp"
+ "InstrumentMemoryAccesses.cpp"
"LowerUKernelsToCalls.cpp"
"MaterializeEncodingIntoPackUnPack.cpp"
"OptimizeVectorTransferPass.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/InstrumentMemoryAccesses.cpp b/compiler/src/iree/compiler/Codegen/Common/InstrumentMemoryAccesses.cpp
new file mode 100644
index 0000000..dabac41
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/InstrumentMemoryAccesses.cpp
@@ -0,0 +1,94 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct InstrumentMemoryAccessesPass
+ : InstrumentMemoryAccessesBase<InstrumentMemoryAccessesPass> {
+ void runOnOperation() override {
+ // Lookup the root instrumentation op. If not present it means the dispatch
+ // is not instrumented and we can skip it.
+ IREE::HAL::InstrumentWorkgroupOp instrumentOp;
+ getOperation().walk([&](IREE::HAL::InstrumentWorkgroupOp op) {
+ instrumentOp = op;
+ return WalkResult::interrupt();
+ });
+ if (!instrumentOp) {
+ // Not instrumented.
+ return;
+ }
+
+ auto buffer = instrumentOp.getBuffer();
+ auto workgroupKey = instrumentOp.getWorkgroupKey();
+ getOperation()->walk([&](Operation *op) {
+ TypeSwitch<Operation *>(op)
+ .Case<memref::LoadOp>([&](auto loadOp) {
+ OpBuilder builder(loadOp);
+ builder.setInsertionPointAfter(loadOp);
+ auto instrumentOp =
+ builder.create<IREE::HAL::InstrumentMemoryLoadOp>(
+ loadOp.getLoc(), loadOp.getResult().getType(), buffer,
+ workgroupKey, loadOp.getResult(), loadOp.getMemRef(),
+ loadOp.getIndices());
+ loadOp.getResult().replaceAllUsesExcept(instrumentOp.getResult(),
+ instrumentOp);
+ })
+ .Case<memref::StoreOp>([&](auto storeOp) {
+ OpBuilder builder(storeOp);
+ auto instrumentOp =
+ builder.create<IREE::HAL::InstrumentMemoryStoreOp>(
+ storeOp.getLoc(), storeOp.getValueToStore().getType(),
+ buffer, workgroupKey, storeOp.getValueToStore(),
+ storeOp.getMemRef(), storeOp.getIndices());
+ storeOp.getValueMutable().assign(instrumentOp.getResult());
+ })
+ .Case<vector::LoadOp>([&](auto loadOp) {
+ OpBuilder builder(loadOp);
+ builder.setInsertionPointAfter(loadOp);
+ auto instrumentOp =
+ builder.create<IREE::HAL::InstrumentMemoryLoadOp>(
+ loadOp.getLoc(), loadOp.getVectorType(), buffer,
+ workgroupKey, loadOp.getResult(), loadOp.getBase(),
+ loadOp.getIndices());
+ loadOp.getResult().replaceAllUsesExcept(instrumentOp.getResult(),
+ instrumentOp);
+ })
+ .Case<vector::StoreOp>([&](auto storeOp) {
+ OpBuilder builder(storeOp);
+ auto instrumentOp =
+ builder.create<IREE::HAL::InstrumentMemoryStoreOp>(
+ storeOp.getLoc(), storeOp.getVectorType(), buffer,
+ workgroupKey, storeOp.getValueToStore(), storeOp.getBase(),
+ storeOp.getIndices());
+ storeOp.getValueToStoreMutable().assign(instrumentOp.getResult());
+ })
+ .Default([&](Operation *) {});
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createInstrumentMemoryAccessesPass() {
+ return std::make_unique<InstrumentMemoryAccessesPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
index 5ebf191..5efb99a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -61,6 +61,7 @@
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialectPasses",
"//runtime/src/iree/builtins/ukernel:exported_bits",
+ "//runtime/src/iree/schemas/instruments",
"@llvm-project//llvm:BinaryFormat",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 693c113..91a49c1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -105,6 +105,7 @@
iree::compiler::Dialect::HAL::Utils
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
+ iree::schemas::instruments
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 70715fe..307432e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/schemas/instruments/dispatch.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"
@@ -304,6 +305,362 @@
}
};
+struct InstrumentationEntry {
+ // !llvm.ptr<i8> pointing at the base of the ringbuffer.
+ Value basePtr;
+ // !llvm.ptr<i8> pointing at the start of the entry (basePtr + offset).
+ Value entryPtr;
+ // i64 offset within the ringbuffer of the entry.
+ Value offset;
+};
+
+// entrySize must be 16-byte aligned
+static InstrumentationEntry acquireInstrumentationEntry(Location loc,
+ Value buffer,
+ Value bufferPtr,
+ Value entrySize,
+ OpBuilder &builder) {
+ auto i64Type = builder.getI64Type();
+ auto bufferType = buffer.getType().cast<MemRefType>();
+ int64_t totalBufferSize =
+ (bufferType.getNumElements() * bufferType.getElementTypeBitWidth()) / 8;
+ int64_t headOffset = totalBufferSize - 8;
+ int64_t ringSize = totalBufferSize - IREE_INSTRUMENT_DISPATCH_PADDING;
+ assert(llvm::isPowerOf2_64(ringSize) &&
+ "ringbuffer storage size must be a power-of-two");
+
+ Value basePtr = MemRefDescriptor(bufferPtr).alignedPtr(builder, loc);
+
+ Value offsetIndex =
+ builder.create<LLVM::ConstantOp>(loc, i64Type, headOffset);
+ Value offsetPtr =
+ builder.create<LLVM::GEPOp>(loc, basePtr.getType(), basePtr, offsetIndex,
+ /*inbounds=*/true);
+ Value offsetPtrI64 = builder.create<LLVM::BitcastOp>(
+ loc, LLVM::LLVMPointerType::get(i64Type), offsetPtr);
+ Value rawOffset = builder.create<LLVM::AtomicRMWOp>(
+ loc, LLVM::AtomicBinOp::add, offsetPtrI64, entrySize,
+ LLVM::AtomicOrdering::monotonic);
+ Value offsetMask =
+ builder.create<LLVM::ConstantOp>(loc, i64Type, ringSize - 1);
+ Value wrappedOffset = builder.create<LLVM::AndOp>(loc, rawOffset, offsetMask);
+
+ Value entryPtr = builder.create<LLVM::GEPOp>(loc, basePtr.getType(), basePtr,
+ wrappedOffset);
+
+ return {basePtr, entryPtr, wrappedOffset};
+}
+
+static InstrumentationEntry appendInstrumentationEntry(
+ Location loc, Value buffer, Value bufferPtr, LLVM::LLVMStructType entryType,
+ ArrayRef<Value> entryValues, DataLayout &dataLayout, OpBuilder &builder) {
+ auto i64Type = builder.getI64Type();
+
+ Value entrySize = builder.create<LLVM::ConstantOp>(
+ loc, i64Type, dataLayout.getTypeSize(entryType));
+ auto entry =
+ acquireInstrumentationEntry(loc, buffer, bufferPtr, entrySize, builder);
+
+ Value entryStruct = builder.create<LLVM::UndefOp>(loc, entryType);
+ for (auto entryValue : llvm::enumerate(entryValues)) {
+ entryStruct = builder.create<LLVM::InsertValueOp>(
+ loc, entryStruct, entryValue.value(), entryValue.index());
+ }
+
+ builder.create<LLVM::StoreOp>(
+ loc, entryStruct,
+ builder.create<LLVM::BitcastOp>(
+ loc, LLVM::LLVMPointerType::get(entryType), entry.entryPtr),
+ /*alignment=*/16);
+
+ return entry;
+}
+
+static int64_t getMemoryAccessByteSize(Type type) {
+ if (auto vectorType = type.dyn_cast<VectorType>()) {
+ return (vectorType.getNumElements() * vectorType.getElementTypeBitWidth()) /
+ 8;
+ } else {
+ return type.getIntOrFloatBitWidth() / 8;
+ }
+}
+
+struct ConvertHALInstrumentWorkgroupOp
+ : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InstrumentWorkgroupOp> {
+ using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern;
+ LogicalResult matchAndRewrite(
+ IREE::HAL::InstrumentWorkgroupOp instrumentOp,
+ IREE::HAL::InstrumentWorkgroupOpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = instrumentOp.getLoc();
+ auto dataLayout =
+ getTypeConverter()->getDataLayoutAnalysis()->getAbove(instrumentOp);
+ auto i32Type = rewriter.getI32Type();
+ auto i64Type = rewriter.getI64Type();
+
+ auto entryType = LLVM::LLVMStructType::getLiteral(
+ getContext(), {
+ i32Type, // header
+ i32Type, // workgroup_id_x
+ i32Type, // workgroup_id_y
+ i32Type, // workgroup_id_z
+ i32Type, // workgroup_count_x
+ i32Type, // workgroup_count_y
+ i32Type, // workgroup_count_z
+ i32Type, // processor_id
+ });
+
+ // 8 bit tag = 00 | 24 bit dispatch id
+ // NOTE: we could pre-shift this to avoid needing to do it in each group.
+ // We just need to do the shift - the bottom two bits will be the 00 tag.
+ Value rawDispatchId = instrumentOp.getDispatchId();
+ Value header = rewriter.create<LLVM::ShlOp>(
+ loc, i32Type, rawDispatchId,
+ rewriter.create<LLVM::ConstantOp>(loc, i32Type, 8)); // | 8bit tag
+
+ auto entry = appendInstrumentationEntry(
+ loc, instrumentOp.getBuffer(), operands.getBuffer(), entryType,
+ {
+ header,
+ abi.loadWorkgroupID(instrumentOp, 0, i32Type, rewriter),
+ abi.loadWorkgroupID(instrumentOp, 1, i32Type, rewriter),
+ abi.loadWorkgroupID(instrumentOp, 2, i32Type, rewriter),
+ abi.loadWorkgroupCount(instrumentOp, 0, i32Type, rewriter),
+ abi.loadWorkgroupCount(instrumentOp, 1, i32Type, rewriter),
+ abi.loadWorkgroupCount(instrumentOp, 2, i32Type, rewriter),
+ abi.loadProcessorID(instrumentOp, rewriter),
+ },
+ dataLayout, rewriter);
+
+ // Prepare the 40-bit key used by all accesses - we do this once so that we
+ // can ensure it's hoisted.
+ // Consumers expect 40 bits of offset << 24 bits.
+ Value workgroupKey = rewriter.create<LLVM::ShlOp>(
+ loc,
+ rewriter.create<LLVM::AndOp>(
+ loc, entry.offset,
+ rewriter.create<LLVM::ConstantOp>(loc, i64Type, 0xFFFFFFFFFFll)),
+ rewriter.create<LLVM::ConstantOp>(loc, i64Type, 24));
+
+ rewriter.replaceOp(instrumentOp, workgroupKey);
+ return success();
+ }
+};
+
+static Optional<uint64_t> mapValueType(Type type) {
+ return TypeSwitch<Type, Optional<uint64_t>>(type)
+ .Case<IntegerType>([&](Type type) -> Optional<uint64_t> {
+ if (type.isUnsignedInteger()) {
+ switch (type.getIntOrFloatBitWidth()) {
+ case 8:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_8;
+ case 16:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_16;
+ case 32:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_32;
+ case 64:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_64;
+ default:
+ return std::nullopt;
+ }
+ }
+ switch (type.getIntOrFloatBitWidth()) {
+ case 8:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_8;
+ case 16:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_16;
+ case 32:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_32;
+ case 64:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_64;
+ default:
+ return std::nullopt;
+ }
+ })
+ .Case<FloatType>([&](Type type) -> Optional<uint64_t> {
+ if (type.isBF16()) {
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_BFLOAT_16;
+ }
+ switch (type.getIntOrFloatBitWidth()) {
+ case 16:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_16;
+ case 32:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_32;
+ case 64:
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_64;
+ default:
+ return std::nullopt;
+ }
+ })
+ .Case<IndexType>([&](Type type) -> Optional<uint64_t> {
+ return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_64;
+ })
+ .Default([&](Type) -> Optional<uint64_t> { return std::nullopt; });
+}
+
+struct ConvertHALInstrumentValueOp
+ : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InstrumentValueOp> {
+ using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern;
+ LogicalResult matchAndRewrite(
+ IREE::HAL::InstrumentValueOp instrumentOp,
+ IREE::HAL::InstrumentValueOpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = instrumentOp.getLoc();
+
+ // Only convert ops we can handle, otherwise warn and discard.
+ Optional<uint64_t> valueType;
+ if (operands.getOperand().getType().isa<LLVM::LLVMPointerType>()) {
+ valueType = IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_POINTER;
+ } else {
+ valueType = mapValueType(instrumentOp.getType());
+ }
+ if (!valueType) {
+ mlir::emitWarning(loc,
+ "skipping hal.instrument.value on unsupported type: ")
+ << instrumentOp.getType();
+ rewriter.replaceOp(instrumentOp, {operands.getOperand()});
+ return success();
+ }
+
+ auto dataLayout =
+ getTypeConverter()->getDataLayoutAnalysis()->getAbove(instrumentOp);
+ auto i64Type = rewriter.getI64Type();
+
+ auto entryType =
+ LLVM::LLVMStructType::getLiteral(getContext(), {
+ i64Type, // header
+ i64Type, // value
+ });
+
+ // 8 bit tag
+ // 8 bit type
+ // 8 bit ordinal
+ // 40 bit workgroup offset
+ Value header = rewriter.create<LLVM::OrOp>(
+ loc, operands.getWorkgroupKey(),
+ rewriter.create<LLVM::ConstantOp>(
+ loc, i64Type,
+ (instrumentOp.getOrdinal().getZExtValue() << 16) |
+ (valueType.value() << 8) |
+ IREE_INSTRUMENT_DISPATCH_TYPE_VALUE));
+
+ // Bitcast to an integer and widen to 64 bits.
+ Value bits = rewriter.create<LLVM::ZExtOp>(
+ loc, i64Type,
+ rewriter.create<LLVM::BitcastOp>(
+ loc,
+ rewriter.getIntegerType(
+ instrumentOp.getType().getIntOrFloatBitWidth()),
+ operands.getOperand()));
+
+ appendInstrumentationEntry(loc, instrumentOp.getBuffer(),
+ operands.getBuffer(), entryType,
+ {
+ header,
+ bits,
+ },
+ dataLayout, rewriter);
+
+ rewriter.replaceOp(instrumentOp, operands.getOperand());
+ return success();
+ }
+};
+
+struct ConvertHALInstrumentMemoryLoadOp
+ : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InstrumentMemoryLoadOp> {
+ using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern;
+ LogicalResult matchAndRewrite(
+ IREE::HAL::InstrumentMemoryLoadOp instrumentOp,
+ IREE::HAL::InstrumentMemoryLoadOpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = instrumentOp.getLoc();
+ auto dataLayout =
+ getTypeConverter()->getDataLayoutAnalysis()->getAbove(instrumentOp);
+ auto i64Type = rewriter.getI64Type();
+
+ auto entryType =
+ LLVM::LLVMStructType::getLiteral(getContext(), {
+ i64Type, // header
+ i64Type, // address
+ });
+
+ // 8 bit tag = 100 (read), 101 (write)
+ // 16 bit length
+ // 40 bit workgroup offset
+ int64_t loadSize = getMemoryAccessByteSize(instrumentOp.getType());
+ assert(loadSize <= UINT16_MAX && "16-bit length maximum");
+ Value header = rewriter.create<LLVM::OrOp>(
+ loc, operands.getWorkgroupKey(),
+ rewriter.create<LLVM::ConstantOp>(
+ loc, i64Type,
+ (loadSize << 8) | IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_LOAD));
+
+ Value loadPtr = getStridedElementPtr(
+ loc, instrumentOp.getBase().getType().cast<MemRefType>(),
+ operands.getBase(), operands.getIndices(), rewriter);
+ Value addressI64 = rewriter.create<LLVM::PtrToIntOp>(loc, i64Type, loadPtr);
+
+ appendInstrumentationEntry(loc, instrumentOp.getBuffer(),
+ operands.getBuffer(), entryType,
+ {
+ header,
+ addressI64,
+ },
+ dataLayout, rewriter);
+
+ rewriter.replaceOp(instrumentOp, operands.getLoadValue());
+ return success();
+ }
+};
+
+struct ConvertHALInstrumentMemoryStoreOp
+ : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InstrumentMemoryStoreOp> {
+ using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern;
+ LogicalResult matchAndRewrite(
+ IREE::HAL::InstrumentMemoryStoreOp instrumentOp,
+ IREE::HAL::InstrumentMemoryStoreOpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = instrumentOp.getLoc();
+ auto dataLayout =
+ getTypeConverter()->getDataLayoutAnalysis()->getAbove(instrumentOp);
+ auto i64Type = rewriter.getI64Type();
+
+ auto entryType =
+ LLVM::LLVMStructType::getLiteral(getContext(), {
+ i64Type, // header
+ i64Type, // address
+ });
+
+ // 8 bit tag = 10 (read), 11 (write)
+ // 16 bit length
+ // 40 bit workgroup offset
+ int64_t storeSize = getMemoryAccessByteSize(instrumentOp.getType());
+ assert(storeSize <= UINT16_MAX && "16-bit length maximum");
+ Value header = rewriter.create<LLVM::OrOp>(
+ loc, operands.getWorkgroupKey(),
+ rewriter.create<LLVM::ConstantOp>(
+ loc, i64Type,
+ (storeSize << 8) | IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_STORE));
+
+ Value storePtr = getStridedElementPtr(
+ loc, instrumentOp.getBase().getType().cast<MemRefType>(),
+ operands.getBase(), operands.getIndices(), rewriter);
+ Value addressI64 =
+ rewriter.create<LLVM::PtrToIntOp>(loc, i64Type, storePtr);
+
+ appendInstrumentationEntry(loc, instrumentOp.getBuffer(),
+ operands.getBuffer(), entryType,
+ {
+ header,
+ addressI64,
+ },
+ dataLayout, rewriter);
+
+ rewriter.replaceOp(instrumentOp, operands.getStoreValue());
+ return success();
+ }
+};
+
/// Rewrites calls to extern functions to dynamic library import calls.
/// The parent LLVMFuncOp must be compatible with HALDispatchABI.
///
@@ -545,7 +902,11 @@
ConvertHALInterfaceWorkgroupSizeOp,
ConvertHALInterfaceWorkgroupCountOp,
ConvertHALInterfaceConstantLoadOp,
- ConvertHALInterfaceBindingSubspanOp
+ ConvertHALInterfaceBindingSubspanOp,
+ ConvertHALInstrumentWorkgroupOp,
+ ConvertHALInstrumentValueOp,
+ ConvertHALInstrumentMemoryLoadOp,
+ ConvertHALInstrumentMemoryStoreOp
>(abi, typeConverter);
// clang-format on
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 18dd53f..b5c41f5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -69,6 +69,12 @@
llvm::cl::desc("Enables reassociation for FP reductions"),
llvm::cl::init(false));
+static llvm::cl::opt<bool> clInstrumentMemoryAccesses{
+ "iree-llvmcpu-instrument-memory-accesses",
+ llvm::cl::desc("Instruments memory accesses in dispatches when dispatch "
+ "instrumentation is enabled."),
+ llvm::cl::init(false)};
+
// MLIR file containing a top-level module that specifies the transformations to
// apply to form dispatch regions.
// Defined externally in KernelDispatch.cpp to control the codegen pass
@@ -742,6 +748,10 @@
// (HAL, IREE, Linalg, CF) -> LLVM
passManager.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
passManager.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
+ if (clInstrumentMemoryAccesses) {
+ passManager.addNestedPass<func::FuncOp>(
+ createInstrumentMemoryAccessesPass());
+ }
passManager.addPass(createConvertToLLVMPass(clEnableReassociateFpReductions));
passManager.addPass(createReconcileUnrealizedCastsPass());
diff --git a/compiler/src/iree/compiler/Codegen/Passes.h b/compiler/src/iree/compiler/Codegen/Passes.h
index 01382d3..9ffa60d 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Passes.h
@@ -221,6 +221,10 @@
std::unique_ptr<OperationPass<func::FuncOp>>
createRematerializeParallelOpsPass();
+/// Instruments memory reads and writes for address tracking.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createInstrumentMemoryAccessesPass();
+
//----------------------------------------------------------------------------//
// Common codegen patterns.
//----------------------------------------------------------------------------//
diff --git a/compiler/src/iree/compiler/Codegen/Passes.td b/compiler/src/iree/compiler/Codegen/Passes.td
index cf705b5..fc7a3d7 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Passes.td
@@ -336,6 +336,12 @@
let constructor = "mlir::iree_compiler::createRematerializeParallelOpsPass()";
}
+def InstrumentMemoryAccesses :
+ Pass<"iree-codegen-instrument-memory-accesses", "func::FuncOp"> {
+ let summary = "Instruments memory reads and writes for address tracking when dispatch instrumentation is enabled.";
+ let constructor = "mlir::iree_compiler::createInstrumentMemoryAccessesPass()";
+}
+
//------------------------------------------------------------------------------
// LLVMCPU
//------------------------------------------------------------------------------
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 038baa8..45247ff 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -2032,6 +2032,150 @@
}
//===----------------------------------------------------------------------===//
+// hal.instrument.*
+//===----------------------------------------------------------------------===//
+
+def HAL_InstrumentWorkgroupOp : HAL_Op<"instrument.workgroup", []> {
+ let summary = [{emits a dispatch workgroup instrumentation event}];
+ let description = [{
+ Emits an `iree_instrument_dispatch_workgroup_t` event into the
+ instrumentation stream. The workgroup event identifies the unique dispatch,
+ its workgroup count, and the ID of the emitting workgroup within the
+ dispatch. Optionally targets that support querying the processor ID
+ executing the workgroup can attach that information for tracking purposes.
+
+ On targets such as CPUs where entire workgroups execute as atomic units
+ only one workgroup event should be emitted. On targets such as GPUs where
+ there may be multiple invocations executing as part of a single workgroup
+ only the first invocation within the workgroup should emit the workgroup
+ event (by checking if the LocalInvocationIndex or threadIdx == 0, etc).
+
+ The resulting workgroup key is used by subsequent workgroup-specific
+ instrumentation events.
+ }];
+
+ let arguments = (ins
+ AnyMemRef:$buffer,
+ I32:$dispatchId
+ );
+ let results = (outs
+ Index:$workgroupKey
+ );
+
+ let assemblyFormat = [{
+ `` `[` $buffer `:` type($buffer) `]`
+ `dispatch` `(` $dispatchId `)`
+ attr-dict `:` type($workgroupKey)
+ }];
+}
+
+def HAL_InstrumentPrintOp : HAL_Op<"instrument.print", []> {
+ let summary = [{emits a human-readable printf-style string event}];
+ let description = [{
+ Formats a string using a limited subset of printf format specifiers and the
+ provided values and then emits an `iree_instrument_dispatch_print_t` event. Final
+ formatted string lengths may be limited to as much as 1024 characters and
+ should be kept as small as possible to avoid easily exceeding the
+ instrumentation storage buffers with redundant strings.
+ }];
+
+ let arguments = (ins
+ AnyMemRef:$buffer,
+ Index:$workgroupKey,
+ StrAttr:$format,
+ Variadic<AnyType>:$values
+ );
+
+ let assemblyFormat = [{
+ `` `[` $buffer `:` type($buffer) `for` $workgroupKey `]`
+ $format (`*` `(` $values^ `:` type($values) `)`)?
+ attr-dict
+ }];
+}
+
+def HAL_InstrumentValueOp : HAL_Op<"instrument.value", [
+ AllTypesMatch<["operand", "result"]>,
+ ]> {
+ let summary = [{emits a scalar value instrumentation event}];
+ let description = [{
+ Emits a workgroup-specific typed value with the given workgroup-relative
+ ordinal.
+
+ This op will be preserved even if the output is not used as it is only for
+ debugging purposes.
+ }];
+
+ let arguments = (ins
+ AnyMemRef:$buffer,
+ Index:$workgroupKey,
+ AnyI8Attr:$ordinal,
+ AnyType:$operand
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = [{
+ `` `[` $buffer `:` type($buffer) `for` $workgroupKey `]`
+ $ordinal `=` $operand attr-dict `:` type($operand)
+ }];
+}
+
+def HAL_InstrumentMemoryLoadOp : HAL_PureOp<"instrument.memory.load", [
+ AllTypesMatch<["loadValue", "result"]>,
+ ]> {
+ let summary = [{emits a memory load instrumentation event}];
+ let description = [{
+ Emits a workgroup-specific memory load event indicating that a number of
+ bytes from the given resolved pointer have been loaded by the workgroup.
+ }];
+
+ let arguments = (ins
+ AnyMemRef:$buffer,
+ Index:$workgroupKey,
+ AnyType:$loadValue,
+ AnyMemRef:$base,
+ Variadic<Index>:$indices
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = [{
+ `` `[` $buffer `:` type($buffer) `for` $workgroupKey `]`
+ $base `[` $indices `]` `,` $loadValue
+ attr-dict `:` type($base) `,` type($result)
+ }];
+}
+
+def HAL_InstrumentMemoryStoreOp : HAL_PureOp<"instrument.memory.store", [
+ AllTypesMatch<["storeValue", "result"]>,
+ ]> {
+ let summary = [{emits a memory store instrumentation event}];
+ let description = [{
+ Emits a workgroup-specific memory store event indicating that a number of
+ bytes have been stored to the given resolved pointer by the workgroup.
+ }];
+
+ let arguments = (ins
+ AnyMemRef:$buffer,
+ Index:$workgroupKey,
+ AnyType:$storeValue,
+ AnyMemRef:$base,
+ Variadic<Index>:$indices
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = [{
+ `` `[` $buffer `:` type($buffer) `for` $workgroupKey `]`
+ $base `[` $indices `]` `,` $storeValue
+ attr-dict `:` type($base) `,` type($result)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// hal.interface
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD
index c83c27f..ba624eb 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -24,6 +24,7 @@
"FixupLegacySync.cpp",
"InlineDeviceSwitches.cpp",
"LinkExecutables.cpp",
+ "MaterializeDispatchInstrumentation.cpp",
"MaterializeInterfaces.cpp",
"MaterializeResourceCaches.cpp",
"MemoizeDeviceQueries.cpp",
@@ -56,6 +57,8 @@
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/Utils",
+ "//runtime/src/iree/schemas/instruments",
+ "//runtime/src/iree/schemas/instruments:dispatch_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:ArithDialect",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 26164d5..10ad946 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@
"FixupLegacySync.cpp"
"InlineDeviceSwitches.cpp"
"LinkExecutables.cpp"
+ "MaterializeDispatchInstrumentation.cpp"
"MaterializeInterfaces.cpp"
"MaterializeResourceCaches.cpp"
"MemoizeDeviceQueries.cpp"
@@ -67,6 +68,8 @@
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::Utils
+ iree::schemas::instruments
+ iree::schemas::instruments::dispatch_def_c_fbs
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp
new file mode 100644
index 0000000..5426cc9
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp
@@ -0,0 +1,392 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/instruments/dispatch.h"
+#include "iree/schemas/instruments/dispatch_def_builder.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+static std::string getAttrStr(Attribute attr) {
+ if (!attr) return "";
+ std::string result;
+ llvm::raw_string_ostream os(result);
+ attr.print(os, /*elideType=*/true);
+ return result;
+}
+
+static std::string getOpStr(Operation *op) {
+ std::string result;
+ llvm::raw_string_ostream os(result);
+ OpPrintingFlags flags;
+ flags.useLocalScope();
+ flags.assumeVerified();
+ op->print(os, flags);
+ return result;
+}
+
+// Returns a data vector containing a iree_idbts_chunk_header_t with |type|.
+// The declared |contentLength| excludes padding.
+static Value createChunkHeader(Location loc, iree_idbts_chunk_type_t type,
+ uint64_t contentLength, OpBuilder &builder) {
+ iree_idbts_chunk_header_t header;
+ header.magic = IREE_IDBTS_CHUNK_MAGIC;
+ header.type = type;
+ header.version = 0;
+ header.content_length = contentLength;
+
+ auto dataAttr = DenseElementsAttr::getFromRawBuffer(
+ VectorType::get({sizeof(header)}, builder.getI8Type()),
+ ArrayRef<char>(reinterpret_cast<const char *>(&header), sizeof(header)));
+
+ return builder.create<IREE::Util::BufferConstantOp>(
+ loc, /*name=*/nullptr, dataAttr, builder.getIndexAttr(16),
+ /*mimeType=*/nullptr);
+}
+
+// Returns a zero padding vector if |unalignedLength| needs alignment or null.
+static Value createPadding(Location loc, uint64_t unalignedLength,
+ OpBuilder &builder) {
+ uint64_t padding = llvm::alignTo(unalignedLength, 16) - unalignedLength;
+ if (!padding) return nullptr;
+ auto i8Type = builder.getI8Type();
+ auto zeroAttr = IntegerAttr::get(i8Type, 0);
+ auto dataAttr = DenseElementsAttr::get(
+ VectorType::get({(int64_t)padding}, i8Type), zeroAttr);
+ return builder.create<IREE::Util::BufferConstantOp>(
+ loc, /*name=*/nullptr, dataAttr, builder.getIndexAttr(16),
+ /*mimeType=*/nullptr);
+}
+
+static void appendListItems(Location loc, Value list, ArrayRef<Value> items,
+ OpBuilder &builder) {
+ Value oldLength = builder.create<IREE::Util::ListSizeOp>(loc, list);
+ Value newLength = builder.create<arith::AddIOp>(
+ loc, oldLength,
+ builder.create<arith::ConstantIndexOp>(loc, items.size()));
+ builder.create<IREE::Util::ListResizeOp>(loc, list, newLength);
+ for (size_t i = 0; i < items.size(); ++i) {
+ Value idx = builder.create<arith::AddIOp>(
+ loc, oldLength, builder.create<arith::ConstantIndexOp>(loc, i));
+ builder.create<IREE::Util::ListSetOp>(loc, list, idx, items[i]);
+ }
+}
+
+class MaterializeDispatchInstrumentationPass
+ : public PassWrapper<MaterializeDispatchInstrumentationPass,
+ OperationPass<mlir::ModuleOp>> {
+ public:
+ MaterializeDispatchInstrumentationPass() = default;
+ MaterializeDispatchInstrumentationPass(
+ const MaterializeDispatchInstrumentationPass &pass) {}
+ explicit MaterializeDispatchInstrumentationPass(int64_t bufferSize) {
+ this->bufferSize = bufferSize;
+ }
+
+ StringRef getArgument() const override {
+ return "iree-hal-materialize-dispatch-instrumentation";
+ }
+
+ StringRef getDescription() const override {
+ return "Materializes dispatch instrumentation resources.";
+ }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<mlir::arith::ArithDialect>();
+ registry.insert<IREE::HAL::HALDialect>();
+ registry.insert<IREE::Stream::StreamDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ }
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ if (moduleOp.getBody()->empty()) return;
+
+ auto moduleBuilder = OpBuilder(&moduleOp.getBody()->front());
+ auto i8Type = moduleBuilder.getI8Type();
+ auto i32Type = moduleBuilder.getI32Type();
+ auto indexType = moduleBuilder.getIndexType();
+
+ // Used for all instrumentation.
+ auto loc = moduleBuilder.getUnknownLoc();
+
+ // Currently statically sized to avoid disturbing too much by adding
+ // additional arguments to dispatches that need to be marshaled.
+ // We need to use the base power of two size for storage then add some
+ // padding for overflows and the write head location.
+ //
+ // [power of two storage buffer]
+ // [56 bytes of padding, may get overflow data]
+ // [8 bytes of write head]
+ int64_t totalBufferSize =
+ bufferSize.value + IREE_INSTRUMENT_DISPATCH_PADDING;
+ auto bufferSizeAttr = moduleBuilder.getIndexAttr(totalBufferSize);
+ auto bufferType = MemRefType::get({totalBufferSize}, i8Type);
+
+ // Create global device-side instrumentation resource.
+ auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
+ loc, "__dispatch_instrumentation",
+ /*isMutable=*/false,
+ moduleBuilder.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External));
+ {
+ auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
+ auto initializerBuilder =
+ OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
+ Value bufferSize =
+ initializerBuilder.create<arith::ConstantOp>(loc, bufferSizeAttr);
+ Value buffer = initializerBuilder
+ .create<IREE::Stream::ResourceAllocOp>(
+ loc, globalOp.getType(), bufferSize,
+ /*uninitialized=*/true, /*affinity=*/nullptr)
+ .getResult(0);
+ initializerBuilder.create<IREE::Util::GlobalStoreOp>(loc, buffer,
+ globalOp);
+ initializerBuilder.create<IREE::Util::InitializerReturnOp>(loc);
+ }
+
+ FlatbufferBuilder metadataBuilder;
+
+ // Update all executable export signatures to include the instrumentation
+ // binding. We don't actually use it yet but ensure that it's available
+ // during translation. We keep track of which exports we instrument as some
+ // may be external declarations we can't modify.
+ SmallVector<iree_instruments_DispatchFunctionDef_ref_t>
+ dispatchFunctionRefs;
+ DenseMap<SymbolRefAttr, uint32_t> instrumentedExports;
+ auto bindingType = moduleBuilder.getType<IREE::Stream::BindingType>();
+ auto alignmentKey = moduleBuilder.getStringAttr("stream.alignment");
+ auto alignment64 = moduleBuilder.getIndexAttr(64);
+ for (auto executableOp : moduleOp.getOps<IREE::Stream::ExecutableOp>()) {
+ for (auto exportOp :
+ executableOp.getOps<IREE::Stream::ExecutableExportOp>()) {
+ auto funcOp = exportOp.lookupFunctionRef();
+ if (!funcOp) continue;
+
+ // Capture the source before we mess with it.
+ auto originalSource = getOpStr(funcOp);
+
+ // Mark as instrumented.
+ instrumentedExports[SymbolRefAttr::get(
+ executableOp.getNameAttr(), {SymbolRefAttr::get(exportOp)})] =
+ instrumentedExports.size();
+
+ // Update function signature to add the ringbuffer and dispatch ID.
+ auto funcType = funcOp.getFunctionType();
+ SmallVector<Type> argTypes(funcType.getInputs());
+ argTypes.push_back(bindingType);
+ argTypes.push_back(i32Type);
+ funcOp.setFunctionType(
+ FunctionType::get(&getContext(), argTypes, funcType.getResults()));
+ auto bindingArg = funcOp.front().addArgument(bindingType, loc);
+ auto dispatchIdArg = funcOp.front().addArgument(i32Type, loc);
+
+ // Fix up arg attrs (yuck).
+ SmallVector<DictionaryAttr> argAttrs;
+ funcOp.getAllArgAttrs(argAttrs);
+ argAttrs.push_back(DictionaryAttr::get(
+ &getContext(), {NamedAttribute(alignmentKey, alignment64)}));
+ argAttrs.push_back(DictionaryAttr::get(&getContext(), {}));
+ funcOp.setAllArgAttrs(argAttrs);
+
+ // Insert the workgroup instrumentation. Note that this happens before
+ // codegen would do tile-and-distribute, but that's ok as it should just
+ // ignore these ops when doing that. This only works because we aren't
+ // capturing workgroup ID yet with this instrumentation op and instead
+ // leave that until late in the codegen pipeline.
+ auto funcBuilder = OpBuilder::atBlockBegin(&funcOp.front());
+ Value zero = funcBuilder.create<arith::ConstantIndexOp>(loc, 0);
+ auto subspanOp = funcBuilder.create<IREE::Stream::BindingSubspanOp>(
+ loc, bufferType, bindingArg, /*byteOffset=*/zero, ValueRange{});
+ funcBuilder.create<IREE::HAL::InstrumentWorkgroupOp>(
+ loc, indexType, subspanOp.getResult(), dispatchIdArg);
+
+ // Build function metadata.
+ auto nameRef = metadataBuilder.createString(exportOp.getName());
+ auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(exportOp);
+ auto targetRef = metadataBuilder.createString(getAttrStr(targetAttr));
+ auto sourceRef = metadataBuilder.createString(originalSource);
+ iree_instruments_DispatchFunctionDef_start(metadataBuilder);
+ iree_instruments_DispatchFunctionDef_name_add(metadataBuilder, nameRef);
+ iree_instruments_DispatchFunctionDef_target_add(metadataBuilder,
+ targetRef);
+ iree_instruments_DispatchFunctionDef_source_add(metadataBuilder,
+ sourceRef);
+ dispatchFunctionRefs.push_back(
+ iree_instruments_DispatchFunctionDef_end(metadataBuilder));
+ }
+ }
+
+ // Find all dispatches to exports that we've instrumented and pass along the
+ // instrumentation buffer.
+ SmallVector<iree_instruments_DispatchSiteDef_ref_t> dispatchSiteRefs;
+ uint32_t dispatchSiteCount = 0;
+ for (auto funcLikeOp : moduleOp.getOps<FunctionOpInterface>()) {
+ funcLikeOp.walk([&](IREE::Stream::CmdExecuteOp executeOp) {
+ auto parentBuilder = OpBuilder(executeOp);
+
+ // Load the ringbuffer and capture it for use within the execute region.
+ auto loadOp =
+ parentBuilder.create<IREE::Util::GlobalLoadOp>(loc, globalOp);
+ Value zero = parentBuilder.create<arith::ConstantIndexOp>(loc, 0);
+ Value bufferSize =
+ parentBuilder.create<arith::ConstantOp>(loc, bufferSizeAttr);
+ executeOp.getResourceOperandsMutable().append(loadOp.getResult());
+ executeOp.getResourceOperandSizesMutable().append(bufferSize);
+ auto bufferArg = executeOp.getBody().addArgument(loadOp.getType(), loc);
+
+ // Walk dispatches and pass them the ringbuffer and their unique ID.
+ executeOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
+ auto it = instrumentedExports.find(dispatchOp.getEntryPoint());
+ if (it == instrumentedExports.end()) return; // not instrumented
+
+ // Append dispatch site ID to correlate this op with where it lives in
+ // the program and what is being dispatched. Note that multiple
+ // dispatch ops may reference the same dispatch function after
+ // deduplication.
+ uint32_t dispatchSiteId = dispatchSiteCount++;
+ dispatchOp.getUniformOperandsMutable().append(
+ parentBuilder
+ .create<arith::ConstantIntOp>(loc, dispatchSiteId, 32)
+ .getResult());
+
+ // Record dispatch site to the host-side metadata.
+ iree_instruments_DispatchSiteDef_start(metadataBuilder);
+ // TODO(benvanik): source loc to identify the site.
+ iree_instruments_DispatchSiteDef_function_add(metadataBuilder,
+ it->second);
+ dispatchSiteRefs.push_back(
+ iree_instruments_DispatchSiteDef_end(metadataBuilder));
+
+ // Append ringbuffer for storing the instrumentation data.
+ dispatchOp.getResourcesMutable().append(bufferArg);
+ dispatchOp.getResourceOffsetsMutable().append(zero);
+ dispatchOp.getResourceLengthsMutable().append(bufferSize);
+ dispatchOp.getResourceSizesMutable().append(bufferSize);
+ SmallVector<Attribute> accesses(
+ dispatchOp.getResourceAccesses().getValue());
+ accesses.push_back(IREE::Stream::ResourceAccessBitfieldAttr::get(
+ &getContext(), IREE::Stream::ResourceAccessBitfield::Read |
+ IREE::Stream::ResourceAccessBitfield::Write));
+ dispatchOp.setResourceAccessesAttr(
+ parentBuilder.getArrayAttr(accesses));
+ });
+ });
+ }
+
+ auto dispatchFunctionsRef = iree_instruments_DispatchFunctionDef_vec_create(
+ metadataBuilder, dispatchFunctionRefs.data(),
+ dispatchFunctionRefs.size());
+ auto dispatchSitesRef = iree_instruments_DispatchSiteDef_vec_create(
+ metadataBuilder, dispatchSiteRefs.data(), dispatchSiteRefs.size());
+ iree_instruments_DispatchInstrumentDef_start_as_root(metadataBuilder);
+ iree_instruments_DispatchInstrumentDef_version_add(metadataBuilder, 0);
+ iree_instruments_DispatchInstrumentDef_flags_add(metadataBuilder, 0);
+ iree_instruments_DispatchInstrumentDef_functions_add(metadataBuilder,
+ dispatchFunctionsRef);
+ iree_instruments_DispatchInstrumentDef_sites_add(metadataBuilder,
+ dispatchSitesRef);
+ iree_instruments_DispatchInstrumentDef_end_as_root(metadataBuilder);
+ auto metadataAttr = metadataBuilder.getBufferAttr(&getContext());
+
+ // Create query function for getting the instrumentation data.
+ auto listType = moduleBuilder.getType<IREE::Util::ListType>(
+ moduleBuilder.getType<IREE::Util::VariantType>());
+ auto queryOp = moduleBuilder.create<func::FuncOp>(
+ loc, "__query_instruments",
+ moduleBuilder.getFunctionType({listType}, {}));
+ {
+ queryOp.setPublic();
+ auto *entryBlock = queryOp.addEntryBlock();
+ auto queryBuilder = OpBuilder::atBlockBegin(entryBlock);
+ auto listArg = entryBlock->getArgument(0);
+
+ SmallVector<Value> iovecs;
+
+ iovecs.push_back(
+ createChunkHeader(loc, IREE_IDBTS_CHUNK_TYPE_DISPATCH_METADATA,
+ metadataAttr.size(), queryBuilder));
+
+ // Grab the read-only dispatch metadata.
+ iovecs.push_back(queryBuilder.create<IREE::Util::BufferConstantOp>(
+ loc, queryBuilder.getStringAttr("dispatch_instrument.fb"),
+ metadataAttr, queryBuilder.getIndexAttr(16),
+ queryBuilder.getStringAttr("application/x-flatbuffers")));
+
+ if (Value metadataPadding =
+ createPadding(loc, metadataAttr.size(), queryBuilder)) {
+ iovecs.push_back(metadataPadding);
+ }
+
+ iovecs.push_back(
+ createChunkHeader(loc, IREE_IDBTS_CHUNK_TYPE_DISPATCH_RINGBUFFER,
+ totalBufferSize, queryBuilder));
+
+ // Export the device buffer containing the instrument data.
+ Value buffer =
+ queryBuilder.create<IREE::Util::GlobalLoadOp>(loc, globalOp);
+ Value bufferSize =
+ queryBuilder.create<arith::ConstantOp>(loc, bufferSizeAttr);
+ auto bufferViewType = moduleBuilder.getType<IREE::HAL::BufferViewType>();
+ auto exportOp = queryBuilder.create<IREE::Stream::TensorExportOp>(
+ loc, bufferViewType, buffer,
+ RankedTensorType::get({totalBufferSize}, queryBuilder.getI8Type()),
+ ValueRange{}, bufferSize,
+ /*affinity=*/nullptr);
+ iovecs.push_back(exportOp.getResult());
+
+ if (Value ringbufferPadding =
+ createPadding(loc, totalBufferSize, queryBuilder)) {
+ iovecs.push_back(ringbufferPadding);
+ }
+
+ appendListItems(loc, listArg, iovecs, queryBuilder);
+ queryBuilder.create<func::ReturnOp>(loc);
+ }
+ }
+
+ private:
+ Option<llvm::cl::PowerOf2ByteSize> bufferSize{
+ *this,
+ "bufferSize",
+ llvm::cl::desc("Power-of-two byte size of the instrumentation buffer."),
+ llvm::cl::init(llvm::cl::PowerOf2ByteSize(64 * 1024 * 1024)),
+ };
+};
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createMaterializeDispatchInstrumentationPass(int64_t bufferSize) {
+ return std::make_unique<MaterializeDispatchInstrumentationPass>(bufferSize);
+}
+
+static PassRegistration<MaterializeDispatchInstrumentationPass> pass([] {
+ return std::make_unique<MaterializeDispatchInstrumentationPass>();
+});
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index afda5d2..80c20d0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "iree/compiler/Utils/OptionUtils.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
@@ -30,14 +31,18 @@
// *this, "targets", llvm::cl::desc("One or more HAL devices to target."),
// llvm::cl::ZeroOrMore};
Option<bool> serializeExecutables{
- *this, "serialize-executables",
+ *this,
+ "serialize-executables",
llvm::cl::desc("Whether to serialize hal.executable.variant ops to "
"hal.executable.binary ops."),
- llvm::cl::init(true)};
+ llvm::cl::init(true),
+ };
Option<bool> linkExecutables{
- *this, "link-executables",
+ *this,
+ "link-executables",
llvm::cl::desc("Whether to link hal.executable ops together."),
- llvm::cl::init(true)};
+ llvm::cl::init(true),
+ };
};
static llvm::cl::opt<unsigned> clBenchmarkDispatchRepeatCount{
@@ -46,7 +51,15 @@
"The number of times to repeat each hal.command_buffer.dispatch op. "
"This simply duplicates the dispatch op and inserts barriers. It's "
"meant for command buffers having linear dispatch structures."),
- llvm::cl::init(1)};
+ llvm::cl::init(1),
+};
+
+static llvm::cl::opt<llvm::cl::PowerOf2ByteSize> clInstrumentDispatchBufferSize{
+ "iree-hal-instrument-dispatches",
+ llvm::cl::desc("Enables dispatch instrumentation with a power-of-two byte "
+ "size used for storage (16mib, 64mib, 2gib, etc)."),
+ llvm::cl::init(llvm::cl::PowerOf2ByteSize(0)),
+};
static llvm::cl::list<std::string> clSubstituteExecutableSource{
"iree-hal-substitute-executable-source",
@@ -148,6 +161,13 @@
}
passManager.addPass(createVerifyTargetEnvironmentPass());
+ // Add dispatch instrumentation prior to materializing interfaces so we can
+ // more easily mutate the stream dispatch ops and exports.
+ if (auto bufferSize = clInstrumentDispatchBufferSize.getValue()) {
+ passManager.addPass(
+ createMaterializeDispatchInstrumentationPass(bufferSize.value));
+ }
+
// Each executable needs a hal.interface to specify how the host and
// device communicate across the ABI boundary.
passManager.addPass(createMaterializeInterfacesPass());
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
index f858c62..bcd949e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -155,6 +155,10 @@
// Resource initialization, caching, and optimization
//===----------------------------------------------------------------------===//
+// Materializes host and device dispatch instrumentation resources on stream IR.
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createMaterializeDispatchInstrumentationPass(int64_t bufferSize);
+
// Finds all resource lookups (such as hal.executable.lookup), materializes
// their cache storage and initialization, and rewrites the lookups to
// references.
@@ -186,6 +190,7 @@
createFixupLegacySyncPass();
createLinkExecutablesPass();
createLinkTargetExecutablesPass("");
+ createMaterializeDispatchInstrumentationPass(0);
createMaterializeInterfacesPass();
createMaterializeResourceCachesPass(targetOptions);
createMemoizeDeviceQueriesPass();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD
index aa15ea4..623f210 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD
@@ -24,6 +24,7 @@
"elide_redundant_commands.mlir",
"fixup_legacy_sync.mlir",
"inline_device_switches.mlir",
+ "materialize_dispatch_instrumentation.mlir",
"materialize_interfaces.mlir",
"materialize_resource_caches.mlir",
"memoize_device_queries.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
index 64eaf9b..dbda10b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
@@ -22,6 +22,7 @@
"elide_redundant_commands.mlir"
"fixup_legacy_sync.mlir"
"inline_device_switches.mlir"
+ "materialize_dispatch_instrumentation.mlir"
"materialize_interfaces.mlir"
"materialize_resource_caches.mlir"
"memoize_device_queries.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir
new file mode 100644
index 0000000..3ea7bf2
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_dispatch_instrumentation.mlir
@@ -0,0 +1,80 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-materialize-dispatch-instrumentation{bufferSize=64mib})' %s | FileCheck %s
+
+module attributes {hal.device.targets = [
+ #hal.device.target<"llvm-cpu", {
+ executable_targets = [
+ #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">,
+ #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
+ ]
+ }>
+]} {
+
+ // Instrumentation storage buffer allocated at startup (defaults to 64MB + footer):
+ // CHECK: util.global public @__dispatch_instrumentation : !stream.resource<external>
+ // CHECK: util.initializer
+ // CHECK: %[[DEFAULT_SIZE:.+]] = arith.constant 67112960
+ // CHECK: %[[ALLOC_BUFFER:.+]] = stream.resource.alloc uninitialized : !stream.resource<external>{%[[DEFAULT_SIZE]]}
+ // CHECK: util.global.store %[[ALLOC_BUFFER]], @__dispatch_instrumentation
+
+ // Query function used by tools to get the buffers and metadata:
+ // CHECK: func.func @__query_instruments(%[[LIST:.+]]: !util.list<?>)
+ // CHECK: %[[INTERNAL_BUFFER:.+]] = util.global.load @__dispatch_instrumentation
+ // CHECK: %[[EXPORTED_BUFFER:.+]] = stream.tensor.export %[[INTERNAL_BUFFER]]
+ // CHECK: util.list.set %[[LIST]]{{.+}}
+ // CHECK: util.list.set %[[LIST]]{{.+}}
+ // CHECK: util.list.set %[[LIST]]{{.+}}
+ // CHECK: util.list.set %[[LIST]]{{.+}}
+ // CHECK: util.list.set %[[LIST]]{{.+}}, %[[EXPORTED_BUFFER]]
+
+ stream.executable private @executable {
+ stream.executable.export public @dispatch workgroups() -> (index, index, index) {
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root
+ stream.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ // Dispatches get the instrumentation buffer and a unique dispatch site ID:
+ // CHECK: func.func @dispatch
+ // CHECK-SAME: (%arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}, %[[INSTR_BINDING:.+]]: !stream.binding {stream.alignment = 64 : index}, %[[SITE_ID:.+]]: i32)
+ func.func @dispatch(%arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}) {
+ // Default instrumentation just adds the workgroup marker.
+ // Subsequent dispatch instruments will use the workgroup key.
+ // CHECK: %[[INSTR_BUFFER:.+]] = stream.binding.subspan %[[INSTR_BINDING]]
+ // CHECK: %[[WORKGROUP_KEY:.+]] = hal.instrument.workgroup[%[[INSTR_BUFFER]] : memref<67112960xi8>] dispatch(%[[SITE_ID]]) : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 2.000000e+00 : f32
+ %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<f32>>
+ %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<f32>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
+ %3 = tensor.empty() : tensor<f32>
+ %4 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%2 : tensor<f32>) outs(%3 : tensor<f32>) {
+ ^bb0(%in: f32, %out: f32):
+ %5 = math.powf %in, %cst : f32
+ linalg.yield %5 : f32
+ } -> tensor<f32>
+ flow.dispatch.tensor.store %4, %1, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>>
+ return
+ }
+ }
+ }
+ func.func @main(%arg0: !stream.resource<external>) -> !stream.resource<external> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %ret0 = stream.resource.alloc uninitialized : !stream.resource<external>{%c128}
+ // The instrumentation buffer is captured by submissions for dispatch.
+ // Note that there's no synchronization here (no timepoint waits/etc) as
+ // all accesses to the buffer are atomic.
+ // CHECK: %[[EXECUTE_BUFFER:.+]] = util.global.load @__dispatch_instrumentation
+ // CHECK: stream.cmd.execute
+ // CHECK-SAME: %[[EXECUTE_BUFFER]] as %[[CAPTURE_BUFFER:.+]]: !stream.resource<external>{%[[DEFAULT_SIZE]]})
+ %timepoint = stream.cmd.execute with(%arg0 as %arg0_capture: !stream.resource<external>{%c128}, %ret0 as %ret0_capture: !stream.resource<external>{%c128}) {
+ // CHECK: stream.cmd.dispatch @executable::@dispatch
+ stream.cmd.dispatch @executable::@dispatch {
+ ro %arg0_capture[%c0 for %c128] : !stream.resource<external>{%c128},
+ wo %ret0_capture[%c0 for %c128] : !stream.resource<external>{%c128}
+ // CHECK: rw %[[CAPTURE_BUFFER]]
+ }
+ } => !stream.timepoint
+ %ret0_ready = stream.timepoint.await %timepoint => %ret0 : !stream.resource<external>{%c128}
+ return %ret0_ready : !stream.resource<external>
+ }
+}
diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.cpp b/compiler/src/iree/compiler/Utils/OptionUtils.cpp
index 2a2a7f2..6174e67 100644
--- a/compiler/src/iree/compiler/Utils/OptionUtils.cpp
+++ b/compiler/src/iree/compiler/Utils/OptionUtils.cpp
@@ -92,3 +92,106 @@
} // namespace iree_compiler
} // namespace mlir
+
+// Parses a byte size in |value| and returns the value in |out_size|.
+//
+// Examples:
+// 1073741824 => 1073741824
+// 1gb => 1000000000
+// 1gib => 1073741824
+static int64_t ParseByteSize(llvm::StringRef value) {
+ // TODO(benvanik): probably worth to-lowering here on the size. Having copies
+ // of all the string view utils for just this case is code size overkill. For
+ // now only accept lazy lowercase.
+ int64_t scale = 1;
+ if (value.consume_back_insensitive("kb")) {
+ scale = 1000;
+ } else if (value.consume_back_insensitive("kib")) {
+ scale = 1024;
+ } else if (value.consume_back_insensitive("mb")) {
+ scale = 1000 * 1000;
+ } else if (value.consume_back_insensitive("mib")) {
+ scale = 1024 * 1024;
+ } else if (value.consume_back_insensitive("gb")) {
+ scale = 1000 * 1000 * 1000;
+ } else if (value.consume_back_insensitive("gib")) {
+ scale = 1024 * 1024 * 1024;
+ } else if (value.consume_back_insensitive("b")) {
+ scale = 1;
+ }
+ auto terminatedStr = value.str();
+ int64_t size = std::atoll(terminatedStr.data());
+ return size * scale;
+}
+
+namespace llvm {
+namespace cl {
+template class basic_parser<ByteSize>;
+template class basic_parser<PowerOf2ByteSize>;
+} // namespace cl
+} // namespace llvm
+
+using ByteSize = llvm::cl::ByteSize;
+using PowerOf2ByteSize = llvm::cl::PowerOf2ByteSize;
+
+// Return true on error.
+bool llvm::cl::parser<ByteSize>::parse(Option &O, StringRef ArgName,
+ StringRef Arg, ByteSize &Val) {
+ Val.value = ParseByteSize(Arg);
+ return false;
+}
+
+void llvm::cl::parser<ByteSize>::printOptionDiff(const Option &O, ByteSize V,
+ const OptVal &Default,
+ size_t GlobalWidth) const {
+ printOptionName(O, GlobalWidth);
+ std::string Str;
+ {
+ llvm::raw_string_ostream SS(Str);
+ SS << V.value;
+ }
+ outs() << "= " << Str;
+ outs().indent(2) << " (default: ";
+ if (Default.hasValue()) {
+ outs() << Default.getValue().value;
+ } else {
+ outs() << "*no default*";
+ }
+ outs() << ")\n";
+}
+
+void llvm::cl::parser<ByteSize>::anchor() {}
+
+// Return true on error.
+bool llvm::cl::parser<PowerOf2ByteSize>::parse(Option &O, StringRef ArgName,
+ StringRef Arg,
+ PowerOf2ByteSize &Val) {
+ Val.value = ParseByteSize(Arg);
+ if (!llvm::isPowerOf2_64(Val.value)) {
+ return O.error("'" + Arg +
+ "' value not a power-of-two, use 16mib/64mib/2gb/etc");
+ return true;
+ }
+ return false;
+}
+
+void llvm::cl::parser<PowerOf2ByteSize>::printOptionDiff(
+ const Option &O, PowerOf2ByteSize V, const OptVal &Default,
+ size_t GlobalWidth) const {
+ printOptionName(O, GlobalWidth);
+ std::string Str;
+ {
+ llvm::raw_string_ostream SS(Str);
+ SS << V.value;
+ }
+ outs() << "= " << Str;
+ outs().indent(2) << " (default: ";
+ if (Default.hasValue()) {
+ outs() << Default.getValue().value;
+ } else {
+ outs() << "*no default*";
+ }
+ outs() << ")\n";
+}
+
+void llvm::cl::parser<PowerOf2ByteSize>::anchor() {}
diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.h b/compiler/src/iree/compiler/Utils/OptionUtils.h
index 2e7fae5..4bc2f3e 100644
--- a/compiler/src/iree/compiler/Utils/OptionUtils.h
+++ b/compiler/src/iree/compiler/Utils/OptionUtils.h
@@ -227,4 +227,47 @@
} // namespace iree_compiler
} // namespace mlir
+namespace llvm {
+namespace cl {
+
+struct ByteSize {
+ int64_t value = 0;
+ ByteSize() = default;
+ ByteSize(int64_t value) : value(value) {}
+ operator bool() const noexcept { return value != 0; }
+};
+
+struct PowerOf2ByteSize : public ByteSize {
+ using ByteSize::ByteSize;
+};
+
+extern template class basic_parser<ByteSize>;
+extern template class basic_parser<PowerOf2ByteSize>;
+
+template <>
+class parser<ByteSize> : public basic_parser<ByteSize> {
+ public:
+ parser(Option &O) : basic_parser(O) {}
+ bool parse(Option &O, StringRef ArgName, StringRef Arg, ByteSize &Val);
+ StringRef getValueName() const override { return "byte size"; }
+ void printOptionDiff(const Option &O, ByteSize V, const OptVal &Default,
+ size_t GlobalWidth) const;
+ void anchor() override;
+};
+
+template <>
+class parser<PowerOf2ByteSize> : public basic_parser<PowerOf2ByteSize> {
+ public:
+ parser(Option &O) : basic_parser(O) {}
+ bool parse(Option &O, StringRef ArgName, StringRef Arg,
+ PowerOf2ByteSize &Val);
+ StringRef getValueName() const override { return "power of two byte size"; }
+ void printOptionDiff(const Option &O, PowerOf2ByteSize V,
+ const OptVal &Default, size_t GlobalWidth) const;
+ void anchor() override;
+};
+
+} // namespace cl
+} // namespace llvm
+
#endif // IREE_COMPILER_UTILS_FLAG_UTILS_H
diff --git a/runtime/src/iree/base/string_view.c b/runtime/src/iree/base/string_view.c
index 7172620..cd92211 100644
--- a/runtime/src/iree/base/string_view.c
+++ b/runtime/src/iree/base/string_view.c
@@ -478,6 +478,8 @@
scale = 1000 * 1000 * 1000;
} else if (iree_string_view_consume_suffix(&value, IREE_SV("gib"))) {
scale = 1024 * 1024 * 1024;
+ } else if (iree_string_view_consume_suffix(&value, IREE_SV("b"))) {
+ scale = 1;
}
uint64_t size = 0;
if (!iree_string_view_atoi_uint64(value, &size)) {
diff --git a/runtime/src/iree/base/string_view_test.cc b/runtime/src/iree/base/string_view_test.cc
index 4fdffc0..5818212 100644
--- a/runtime/src/iree/base/string_view_test.cc
+++ b/runtime/src/iree/base/string_view_test.cc
@@ -537,10 +537,13 @@
EXPECT_THAT(ParseDeviceSize("0"), IsOkAndHolds(Eq(0u)));
EXPECT_THAT(ParseDeviceSize("1"), IsOkAndHolds(Eq(1u)));
EXPECT_THAT(ParseDeviceSize("10000"), IsOkAndHolds(Eq(10000u)));
+ EXPECT_THAT(ParseDeviceSize("0b"), IsOkAndHolds(Eq(0u)));
EXPECT_THAT(ParseDeviceSize("0kb"), IsOkAndHolds(Eq(0u)));
EXPECT_THAT(ParseDeviceSize("0gib"), IsOkAndHolds(Eq(0u)));
+ EXPECT_THAT(ParseDeviceSize("1b"), IsOkAndHolds(Eq(1)));
EXPECT_THAT(ParseDeviceSize("1kb"), IsOkAndHolds(Eq(1 * 1000u)));
EXPECT_THAT(ParseDeviceSize("1kib"), IsOkAndHolds(Eq(1 * 1024u)));
+ EXPECT_THAT(ParseDeviceSize("1000b"), IsOkAndHolds(Eq(1000 * 1u)));
EXPECT_THAT(ParseDeviceSize("1000kb"), IsOkAndHolds(Eq(1000 * 1000u)));
EXPECT_THAT(ParseDeviceSize("1000kib"), IsOkAndHolds(Eq(1000 * 1024u)));
diff --git a/runtime/src/iree/schemas/instruments/BUILD b/runtime/src/iree/schemas/instruments/BUILD
new file mode 100644
index 0000000..c666fc1
--- /dev/null
+++ b/runtime/src/iree/schemas/instruments/BUILD
@@ -0,0 +1,41 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_build_test", "iree_runtime_cc_library")
+load("//build_tools/bazel:iree_flatcc.bzl", "iree_flatbuffer_c_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+FLATCC_ARGS = [
+ "--reader",
+ "--builder",
+ "--verifier",
+ "--json",
+]
+
+iree_flatbuffer_c_library(
+ name = "dispatch_def_c_fbs",
+ srcs = ["dispatch_def.fbs"],
+ flatcc_args = FLATCC_ARGS,
+)
+
+iree_build_test(
+ name = "schema_build_test",
+ targets = [
+ ":dispatch_def_c_fbs",
+ ],
+)
+
+iree_runtime_cc_library(
+ name = "instruments",
+ hdrs = [
+ "dispatch.h",
+ ],
+)
diff --git a/runtime/src/iree/schemas/instruments/CMakeLists.txt b/runtime/src/iree/schemas/instruments/CMakeLists.txt
new file mode 100644
index 0000000..50370e7
--- /dev/null
+++ b/runtime/src/iree/schemas/instruments/CMakeLists.txt
@@ -0,0 +1,36 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# runtime/src/iree/schemas/instruments/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+flatbuffer_c_library(
+ NAME
+ dispatch_def_c_fbs
+ SRCS
+ "dispatch_def.fbs"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ instruments
+ HDRS
+ "dispatch.h"
+ DEPS
+
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/schemas/instruments/dispatch.h b/runtime/src/iree/schemas/instruments/dispatch.h
new file mode 100644
index 0000000..08def52
--- /dev/null
+++ b/runtime/src/iree/schemas/instruments/dispatch.h
@@ -0,0 +1,133 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_SCHEMAS_INSTRUMENTS_DISPATCH_H_
+#define IREE_SCHEMAS_INSTRUMENTS_DISPATCH_H_
+
+#include <stdint.h>
+
+//===----------------------------------------------------------------------===//
+// iree_idbts_chunk_t
+//===----------------------------------------------------------------------===//
+// Represents a single chunk of data within a larger chunked transport stream.
+// Each chunk has a type and length to allow quick scans for chunks of interest.
+// Chunks must always start at 16-byte alignment boundaries.
+//
+// TODO(benvanik): clean up and move to a dedicated idbts.h.
+// This is temporarily here while iterating on the transport stream.
+
+// Chunk magic identifier.
+// "IREE Instrumentation Database v0"
+// "IDB0" = 0x49 0x44 0x42 0x30
+typedef uint32_t iree_idbts_chunk_magic_t;
+#define IREE_IDBTS_CHUNK_MAGIC 0x42444930u
+
+// Chunk type.
+enum iree_idbts_chunk_type_e {
+ // NOTE: these will change in the real IDB spec.
+ IREE_IDBTS_CHUNK_TYPE_DISPATCH_METADATA = 0x0000u,
+ IREE_IDBTS_CHUNK_TYPE_DISPATCH_RINGBUFFER = 0x0001u,
+};
+typedef uint16_t iree_idbts_chunk_type_t;
+
+// IDB chunk format version.
+// Instruments and other embedded chunks may version themselves independently to
+// prevent entire files from being invalidated on compiler bumps.
+typedef uint16_t iree_idbts_chunk_version_t;
+
+// Header at the prefix of each chunk in the file.
+// Always aligned to 16-bytes in the file such that the trailing chunk contents
+// are 16-byte aligned.
+typedef struct {
+ // Magic header bytes; must be IREE_IDBTS_CHUNK_MAGIC.
+ iree_idbts_chunk_magic_t magic;
+ // Type of the chunk used to interpret the payload.
+ iree_idbts_chunk_type_t type;
+ // Type-specific version identifier. Usually 0.
+ iree_idbts_chunk_version_t version;
+ // Total byte length of the chunk content excluding this header.
+ uint64_t content_length;
+} iree_idbts_chunk_header_t;
+static_assert(sizeof(iree_idbts_chunk_header_t) % 16 == 0,
+ "chunk header must be 16-byte aligned");
+
+//===----------------------------------------------------------------------===//
+// Dispatch instrumentation ringbuffer transport stream
+//===----------------------------------------------------------------------===//
+
+// Total padding added to the base power-of-two ringbuffer size.
+// The last 8 bytes are the monotonically increasing ringbuffer write head.
+// Up to the (padding-8) bytes are available for the ringbuffer to spill past
+// its end.
+#define IREE_INSTRUMENT_DISPATCH_PADDING 4096
+
+typedef enum iree_instrument_dispatch_type_e {
+ IREE_INSTRUMENT_DISPATCH_TYPE_WORKGROUP = 0b00000000,
+ IREE_INSTRUMENT_DISPATCH_TYPE_PRINT = 0b00000001,
+ IREE_INSTRUMENT_DISPATCH_TYPE_VALUE = 0b00000010,
+ IREE_INSTRUMENT_DISPATCH_TYPE_RESERVED_0 = 0b00000011, // free for use
+ IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_LOAD = 0b00000100,
+ IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_STORE = 0b00000101,
+} iree_instrument_dispatch_type_t;
+
+typedef struct iree_instrument_dispatch_header_t {
+ uint32_t tag : 8;
+ uint32_t unknown : 24;
+} iree_instrument_dispatch_header_t;
+
+typedef struct iree_instrument_dispatch_workgroup_t {
+ uint32_t tag : 8; // IREE_INSTRUMENT_DISPATCH_TYPE_WORKGROUP
+ uint32_t dispatch_id : 24;
+ uint32_t workgroup_id_x;
+ uint32_t workgroup_id_y;
+ uint32_t workgroup_id_z;
+ uint32_t workgroup_count_x;
+ uint32_t workgroup_count_y;
+ uint32_t workgroup_count_z;
+ uint32_t processor_id;
+} iree_instrument_dispatch_workgroup_t;
+
+typedef struct iree_instrument_dispatch_print_t {
+ uint64_t tag : 8; // IREE_INSTRUMENT_DISPATCH_TYPE_PRINT
+ uint64_t length : 16;
+ uint64_t workgroup_offset : 40;
+ uint8_t data[]; // length, padded to ensure 16b struct, no NUL terminator
+} iree_instrument_dispatch_print_t;
+
+typedef struct iree_instrument_dispatch_memory_op_t {
+ uint64_t
+ tag : 8; // IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_LOAD / _MEMORY_STORE
+ uint64_t length : 16;
+ uint64_t workgroup_offset : 40;
+ uint64_t address;
+} iree_instrument_dispatch_memory_op_t;
+
+enum iree_instrument_dispatch_value_type_e {
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_8 = 0,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_8,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_16,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_16,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_32,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_32,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_64,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_64,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_POINTER,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_16,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_32,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_64,
+ IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_BFLOAT_16,
+};
+typedef uint64_t iree_instrument_dispatch_value_type_t;
+
+typedef struct iree_instrument_dispatch_value_t {
+ uint64_t tag : 8; // IREE_INSTRUMENT_DISPATCH_TYPE_VALUE
+ uint64_t type : 8; // iree_instrument_dispatch_value_type_t
+ uint64_t ordinal : 8;
+ uint64_t workgroup_offset : 40;
+ uint64_t bits;
+} iree_instrument_dispatch_value_t;
+
+#endif // IREE_SCHEMAS_INSTRUMENTS_DISPATCH_H_
diff --git a/runtime/src/iree/schemas/instruments/dispatch_def.fbs b/runtime/src/iree/schemas/instruments/dispatch_def.fbs
new file mode 100644
index 0000000..8201219
--- /dev/null
+++ b/runtime/src/iree/schemas/instruments/dispatch_def.fbs
@@ -0,0 +1,47 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+namespace iree.instruments;
+
+// IREE Instrument DataBase for Dispatches.
+file_identifier "IDBD";
+file_extension "fb";
+
+table DispatchFunctionDef {
+ // MLIR symbol name of the exported function.
+ name:string;
+ // Target executable configuration attribute.
+ target:string;
+ // Pipeline layout string.
+ layout:string;
+ // Function source code.
+ source:string;
+ // TODO(benvanik): other structural information from the IR like bindings.
+}
+
+// A unique dispatch site within the program.
+// Many dispatch sites may dispatch the same function.
+// Note that dispatch sites may not be unique within the instrument stream as
+// loops and repeated function calls may cause the same dispatch site to be
+// reached many times.
+table DispatchSiteDef {
+ // Dispatched function ordinal in the function table.
+ function:uint32;
+}
+
+table DispatchInstrumentDef {
+ // 0 today - not stable!
+ version:uint32;
+ // Reserved.
+ flags:uint32;
+ // All functions referenced by ordinal in their original module order.
+ functions:[DispatchFunctionDef];
+ // All unique dispatch sites within the program.
+ sites:[DispatchSiteDef];
+}
+
+root_type DispatchInstrumentDef;
+
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD
index 423aa24..2c58739 100644
--- a/runtime/src/iree/tooling/BUILD
+++ b/runtime/src/iree/tooling/BUILD
@@ -108,6 +108,21 @@
)
cc_library(
+ name = "instrument_util",
+ srcs = ["instrument_util.c"],
+ hdrs = ["instrument_util.h"],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base:tracing",
+ "//runtime/src/iree/base/internal:flags",
+ "//runtime/src/iree/hal",
+ "//runtime/src/iree/modules/hal:types",
+ "//runtime/src/iree/schemas/instruments",
+ "//runtime/src/iree/vm",
+ ],
+)
+
+cc_library(
name = "numpy_io",
srcs = ["numpy_io.c"],
hdrs = ["numpy_io.h"],
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index 2e4a9e3..423bf4d 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -121,6 +121,24 @@
iree_cc_library(
NAME
+ instrument_util
+ HDRS
+ "instrument_util.h"
+ SRCS
+ "instrument_util.c"
+ DEPS
+ iree::base
+ iree::base::internal::flags
+ iree::base::tracing
+ iree::hal
+ iree::modules::hal::types
+ iree::schemas::instruments
+ iree::vm
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
numpy_io
HDRS
"numpy_io.h"
diff --git a/runtime/src/iree/tooling/instrument_util.c b/runtime/src/iree/tooling/instrument_util.c
new file mode 100644
index 0000000..5566931
--- /dev/null
+++ b/runtime/src/iree/tooling/instrument_util.c
@@ -0,0 +1,129 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/tooling/instrument_util.h"
+
+#include <memory.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "iree/base/internal/flags.h"
+#include "iree/base/tracing.h"
+#include "iree/modules/hal/types.h"
+
+//===----------------------------------------------------------------------===//
+// Instrument data management
+//===----------------------------------------------------------------------===//
+
+IREE_FLAG(string, instrument_file, "",
+ "File to populate with instrument data from the program.");
+
+static iree_status_t iree_tooling_write_iovec(iree_vm_ref_t iovec, FILE* file) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ bool write_ok = false;
+ if (iree_vm_buffer_isa(iovec)) {
+ iree_vm_buffer_t* buffer = iree_vm_buffer_deref(iovec);
+ IREE_TRACE_ZONE_APPEND_VALUE(z0, (int64_t)iree_vm_buffer_length(buffer));
+ write_ok =
+ fwrite(iree_vm_buffer_data(buffer), 1, iree_vm_buffer_length(buffer),
+ file) == iree_vm_buffer_length(buffer);
+ } else if (iree_hal_buffer_view_isa(iovec)) {
+ iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(iovec);
+ IREE_TRACE_ZONE_APPEND_VALUE(
+ z0, (int64_t)iree_hal_buffer_view_byte_length(buffer_view));
+ iree_hal_buffer_mapping_t mapping;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_buffer_map_range(iree_hal_buffer_view_buffer(buffer_view),
+ IREE_HAL_MAPPING_MODE_SCOPED,
+ IREE_HAL_MEMORY_ACCESS_READ, 0,
+ IREE_WHOLE_BUFFER, &mapping));
+ write_ok = fwrite(mapping.contents.data, 1, mapping.contents.data_length,
+ file) == mapping.contents.data_length;
+ IREE_IGNORE_ERROR(iree_hal_buffer_unmap_range(&mapping));
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return write_ok ? iree_ok_status()
+ : iree_make_status(iree_status_code_from_errno(errno),
+ "failed to write iovec to file");
+}
+
+iree_status_t iree_tooling_process_instrument_data(
+ iree_vm_context_t* context, iree_allocator_t host_allocator) {
+ // If no flag was specified we ignore instrument data.
+ if (strlen(FLAG_instrument_file) == 0) return iree_ok_status();
+
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_TEXT(z0, FLAG_instrument_file);
+
+ // Open the file for overwriting. We do this even if there is no instrument
+ // data in the program as we'd rather have the user end up with a 0-byte file
+ // when they explicitly ask for it instead of stale data from previous runs.
+ FILE* file = fopen(FLAG_instrument_file, "wb");
+ if (!file) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(iree_status_code_from_errno(errno),
+ "failed to open instrument file '%s' for writing",
+ FLAG_instrument_file);
+ }
+
+ // Each query function pushes iovecs on to a list we provide; we create one
+ // list and use that across all of them.
+ iree_vm_list_t* iovec_list = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_vm_list_create(NULL, 8, host_allocator, &iovec_list));
+
+ iree_vm_list_t* input_list = NULL;
+ iree_status_t status =
+ iree_vm_list_create(NULL, 8, host_allocator, &input_list);
+ if (iree_status_is_ok(status)) {
+ iree_vm_ref_t iovec_list_ref = iree_vm_list_retain_ref(iovec_list);
+ status = iree_vm_list_push_ref_move(input_list, &iovec_list_ref);
+ }
+
+ // Process instrument data from all modules in the context.
+ if (iree_status_is_ok(status)) {
+ for (iree_host_size_t i = 0; i < iree_vm_context_module_count(context);
+ ++i) {
+ iree_vm_module_t* module = iree_vm_context_module_at(context, i);
+ if (!module) continue;
+
+ // Find the query function, if present.
+ iree_vm_function_t query_func;
+ iree_status_t lookup_status = iree_vm_module_lookup_function_by_name(
+ module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ IREE_SV("__query_instruments"), &query_func);
+ if (!iree_status_is_ok(lookup_status)) {
+ // Skip missing/invalid query function.
+ iree_status_ignore(lookup_status);
+ continue;
+ }
+
+ IREE_TRACE_ZONE_BEGIN(z1);
+ IREE_TRACE_ZONE_APPEND_TEXT(z1, iree_vm_module_name(module).data,
+ iree_vm_module_name(module).size);
+ status = iree_vm_invoke(context, query_func, IREE_VM_INVOCATION_FLAG_NONE,
+ NULL, input_list, NULL, host_allocator);
+ IREE_TRACE_ZONE_END(z1);
+ if (!iree_status_is_ok(status)) break;
+ }
+ }
+
+ if (iree_status_is_ok(status)) {
+ for (iree_host_size_t i = 0; i < iree_vm_list_size(iovec_list); ++i) {
+ iree_vm_ref_t iovec = iree_vm_ref_null();
+ status = iree_vm_list_get_ref_assign(iovec_list, i, &iovec);
+ if (!iree_status_is_ok(status)) break;
+ status = iree_tooling_write_iovec(iovec, file);
+ if (!iree_status_is_ok(status)) break;
+ }
+ }
+
+ iree_vm_list_release(input_list);
+ iree_vm_list_release(iovec_list);
+ fclose(file);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
diff --git a/runtime/src/iree/tooling/instrument_util.h b/runtime/src/iree/tooling/instrument_util.h
new file mode 100644
index 0000000..d592bef
--- /dev/null
+++ b/runtime/src/iree/tooling/instrument_util.h
@@ -0,0 +1,30 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_TOOLING_INSTRUMENT_UTIL_H_
+#define IREE_TOOLING_INSTRUMENT_UTIL_H_
+
+#include "iree/base/api.h"
+#include "iree/vm/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Instrument data management
+//===----------------------------------------------------------------------===//
+
+// Processes instrument data in |context| based on command line flags.
+// No-op if there's no instrument data available.
+iree_status_t iree_tooling_process_instrument_data(
+ iree_vm_context_t* context, iree_allocator_t host_allocator);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_TOOLING_INSTRUMENT_UTIL_H_
diff --git a/runtime/src/iree/vm/buffer.c b/runtime/src/iree/vm/buffer.c
index 0d94681..8427a07 100644
--- a/runtime/src/iree/vm/buffer.c
+++ b/runtime/src/iree/vm/buffer.c
@@ -181,6 +181,11 @@
return buffer->data.data_length;
}
+IREE_API_EXPORT uint8_t* iree_vm_buffer_data(const iree_vm_buffer_t* buffer) {
+ IREE_ASSERT_ARGUMENT(buffer);
+ return buffer->data.data;
+}
+
IREE_API_EXPORT iree_status_t iree_vm_buffer_copy_bytes(
const iree_vm_buffer_t* source_buffer, iree_host_size_t source_offset,
const iree_vm_buffer_t* target_buffer, iree_host_size_t target_offset,
diff --git a/runtime/src/iree/vm/buffer.h b/runtime/src/iree/vm/buffer.h
index 9cead6b..1d20984 100644
--- a/runtime/src/iree/vm/buffer.h
+++ b/runtime/src/iree/vm/buffer.h
@@ -126,8 +126,7 @@
// WARNING: this performs no validation of the access allowance on the buffer
// and the caller is responsible for all range checking. Use with caution and
// prefer the utility methods instead.
-IREE_API_EXPORT iree_byte_span_t
-iree_vm_buffer_data(const iree_vm_buffer_t* buffer);
+IREE_API_EXPORT uint8_t* iree_vm_buffer_data(const iree_vm_buffer_t* buffer);
// Copies a byte range from |source_buffer| to |target_buffer|.
IREE_API_EXPORT iree_status_t iree_vm_buffer_copy_bytes(
diff --git a/runtime/src/iree/vm/context.c b/runtime/src/iree/vm/context.c
index e3c7fe8..c287b13 100644
--- a/runtime/src/iree/vm/context.c
+++ b/runtime/src/iree/vm/context.c
@@ -395,6 +395,19 @@
return context->flags;
}
+IREE_API_EXPORT iree_host_size_t
+iree_vm_context_module_count(const iree_vm_context_t* context) {
+ IREE_ASSERT_ARGUMENT(context);
+ return context->list.count;
+}
+
+IREE_API_EXPORT iree_vm_module_t* iree_vm_context_module_at(
+ const iree_vm_context_t* context, iree_host_size_t i) {
+ IREE_ASSERT_ARGUMENT(context);
+ if (i >= context->list.count) return NULL;
+ return context->list.modules[i];
+}
+
IREE_API_EXPORT iree_status_t iree_vm_context_register_modules(
iree_vm_context_t* context, iree_host_size_t module_count,
iree_vm_module_t** modules) {
diff --git a/runtime/src/iree/vm/context.h b/runtime/src/iree/vm/context.h
index 4d629d2..1370ac5 100644
--- a/runtime/src/iree/vm/context.h
+++ b/runtime/src/iree/vm/context.h
@@ -83,6 +83,14 @@
IREE_API_EXPORT iree_vm_context_flags_t
iree_vm_context_flags(const iree_vm_context_t* context);
+// Returns the total number of modules registered in |context|.
+IREE_API_EXPORT iree_host_size_t
+iree_vm_context_module_count(const iree_vm_context_t* context);
+
+// Returns the module registered at index |i| in |context|.
+IREE_API_EXPORT iree_vm_module_t* iree_vm_context_module_at(
+ const iree_vm_context_t* context, iree_host_size_t i);
+
// Registers a list of modules with the context and resolves imports in the
// order provided.
// The modules will be retained by the context until destruction.
diff --git a/tools/BUILD b/tools/BUILD
index 30303bf..03b4fe9 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -80,6 +80,18 @@
)
cc_binary(
+ name = "iree-dump-instruments",
+ srcs = ["iree-dump-instruments-main.c"],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:file_io",
+ "//runtime/src/iree/base/internal/flatcc:parsing",
+ "//runtime/src/iree/schemas/instruments",
+ "//runtime/src/iree/schemas/instruments:dispatch_def_c_fbs",
+ ],
+)
+
+cc_binary(
name = "iree-dump-module",
srcs = ["iree-dump-module-main.c"],
deps = [
@@ -154,6 +166,7 @@
"//runtime/src/iree/tooling:comparison",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
+ "//runtime/src/iree/tooling:instrument_util",
"//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
],
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index bd01660..01111d2 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -102,6 +102,20 @@
iree_cc_binary(
NAME
+ iree-dump-instruments
+ SRCS
+ "iree-dump-instruments-main.c"
+ DEPS
+ flatcc::runtime
+ iree::base
+ iree::base::internal::file_io
+ iree::base::internal::flatcc::parsing
+ iree::schemas::instruments
+ iree::schemas::instruments::dispatch_def_c_fbs
+)
+
+iree_cc_binary(
+ NAME
iree-dump-module
SRCS
"iree-dump-module-main.c"
@@ -128,6 +142,7 @@
iree::tooling::comparison
iree::tooling::context_util
iree::tooling::device_util
+ iree::tooling::instrument_util
iree::tooling::vm_util
iree::vm
)
diff --git a/tools/iree-dump-instruments-main.c b/tools/iree-dump-instruments-main.c
new file mode 100644
index 0000000..8c219e1
--- /dev/null
+++ b/tools/iree-dump-instruments-main.c
@@ -0,0 +1,295 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <stdio.h>
+
+#include "iree/base/api.h"
+#include "iree/base/internal/file_io.h"
+#include "iree/schemas/instruments/dispatch.h"
+
+// NOTE: include order matters:
+#include "iree/base/internal/flatcc/parsing.h"
+#include "iree/schemas/instruments/dispatch_def_reader.h"
+
+typedef struct {
+ iree_instruments_DispatchFunctionDef_vec_t functions_def;
+ iree_instruments_DispatchSiteDef_vec_t dispatch_sites_def;
+} iree_dispatch_metadata_t;
+
+static iree_status_t iree_tooling_dump_dispatch_metadata(
+ const uint8_t* flatbuffer_ptr, iree_host_size_t flatbuffer_size,
+ iree_dispatch_metadata_t* out_metadata, FILE* stream) {
+ memset(out_metadata, 0, sizeof(*out_metadata));
+
+ iree_instruments_DispatchInstrumentDef_table_t instr_def =
+ iree_instruments_DispatchInstrumentDef_as_root(flatbuffer_ptr);
+
+ iree_instruments_DispatchFunctionDef_vec_t functions_def =
+ iree_instruments_DispatchInstrumentDef_functions(instr_def);
+ out_metadata->functions_def = functions_def;
+ for (iree_host_size_t i = 0;
+ i < iree_instruments_DispatchFunctionDef_vec_len(functions_def); ++i) {
+ fprintf(stream, "\n");
+ iree_instruments_DispatchFunctionDef_table_t function_def =
+ iree_instruments_DispatchFunctionDef_vec_at(functions_def, i);
+ flatbuffers_string_t name =
+ iree_instruments_DispatchFunctionDef_name(function_def);
+ fprintf(stream,
+ "//"
+ "===---------------------------------------------------------------"
+ "-------===//\n");
+ fprintf(stream, "// export[%" PRIhsz "]: %s\n", i, name);
+ fprintf(stream,
+ "//"
+ "===---------------------------------------------------------------"
+ "-------===//\n");
+ flatbuffers_string_t target =
+ iree_instruments_DispatchFunctionDef_target(function_def);
+ if (target) fprintf(stream, "// target: %s\n", target);
+ flatbuffers_string_t layout =
+ iree_instruments_DispatchFunctionDef_layout(function_def);
+ if (layout) fprintf(stream, "// layout: %s\n", layout);
+ flatbuffers_string_t source =
+ iree_instruments_DispatchFunctionDef_source(function_def);
+ if (source) fprintf(stream, "%s\n", source);
+ fprintf(stream, "\n");
+ }
+
+ fprintf(stream,
+ "//"
+ "===---------------------------------------------------------------"
+ "-------===//\n");
+ iree_instruments_DispatchSiteDef_vec_t dispatch_sites_def =
+ iree_instruments_DispatchInstrumentDef_sites(instr_def);
+ out_metadata->dispatch_sites_def = dispatch_sites_def;
+ for (iree_host_size_t i = 0;
+ i < iree_instruments_DispatchSiteDef_vec_len(dispatch_sites_def); ++i) {
+ iree_instruments_DispatchSiteDef_table_t dispatch_site_def =
+ iree_instruments_DispatchSiteDef_vec_at(dispatch_sites_def, i);
+ iree_instruments_DispatchFunctionDef_table_t function_def =
+ iree_instruments_DispatchFunctionDef_vec_at(
+ functions_def,
+ iree_instruments_DispatchSiteDef_function(dispatch_site_def));
+ flatbuffers_string_t name =
+ iree_instruments_DispatchFunctionDef_name(function_def);
+ fprintf(stream, "// dispatch site %" PRIhsz ": %s\n", i, name);
+ }
+ fprintf(stream,
+ "//"
+ "===---------------------------------------------------------------"
+ "-------===//\n\n");
+
+ return iree_ok_status();
+}
+
+static void iree_tooling_dump_print_value(
+ iree_instrument_dispatch_value_type_t type, uint64_t raw_value,
+ FILE* stream) {
+ union {
+ int8_t i8;
+ int16_t i16;
+ int32_t i32;
+ int64_t i64;
+ float f32;
+ double f64;
+ uint8_t value_storage[sizeof(uint64_t)];
+ } value = {.i64 = raw_value};
+ switch (type) {
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_8:
+ fprintf(stream, "%" PRId8, value.i8);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_8:
+ fprintf(stream, "%" PRIu8, value.i8);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_16:
+ fprintf(stream, "%" PRId16, value.i16);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_16:
+ fprintf(stream, "%" PRIu16, value.i16);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_32:
+ fprintf(stream, "%" PRId32, value.i32);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_32:
+ fprintf(stream, "%" PRIu32, value.i32);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_64:
+ fprintf(stream, "%" PRId64, value.i64);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_64:
+ fprintf(stream, "%" PRIu64, value.i64);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_POINTER:
+ fprintf(stream, "%16" PRIX64, value.i64);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_16:
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_BFLOAT_16:
+ fprintf(stream, "%4" PRIX16, value.i16);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_32:
+ fprintf(stream, "%e %f", value.f32, value.f32);
+ break;
+ case IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_64:
+ fprintf(stream, "%e %f", value.f64, value.f64);
+ break;
+ default:
+ fprintf(stream, "<<unknown type: %02X>>", (uint32_t)type);
+ break;
+ }
+}
+
+static iree_status_t iree_tooling_dump_dispatch_ringbuffer(
+ const uint8_t* data_ptr, iree_host_size_t data_size,
+ const iree_dispatch_metadata_t* metadata, FILE* stream) {
+ const uint64_t ring_size = data_size - IREE_INSTRUMENT_DISPATCH_PADDING;
+ const uint8_t* ring_data = data_ptr;
+ const uint64_t ring_head = *(const uint64_t*)(ring_data + data_size - 8);
+ const uint64_t ring_range = iree_min(ring_head, ring_size);
+
+ for (iree_host_size_t i = 0; i < ring_range;) {
+ const iree_instrument_dispatch_header_t* header =
+ (const iree_instrument_dispatch_header_t*)(ring_data + i);
+ switch (header->tag) {
+ case IREE_INSTRUMENT_DISPATCH_TYPE_WORKGROUP: {
+ const iree_instrument_dispatch_workgroup_t* workgroup =
+ (const iree_instrument_dispatch_workgroup_t*)header;
+ iree_instruments_DispatchSiteDef_table_t dispatch_site_def =
+ iree_instruments_DispatchSiteDef_vec_at(
+ metadata->dispatch_sites_def, workgroup->dispatch_id);
+ iree_instruments_DispatchFunctionDef_table_t function_def =
+ iree_instruments_DispatchFunctionDef_vec_at(
+ metadata->functions_def,
+ iree_instruments_DispatchSiteDef_function(dispatch_site_def));
+ flatbuffers_string_t name_def =
+ iree_instruments_DispatchFunctionDef_name(function_def);
+ fprintf(stream,
+ "%016" PRIX64
+ " | WORKGROUP dispatch(%u %s %ux%ux%u) %u,%u,%u pid:%u\n",
+ (uint64_t)i, workgroup->dispatch_id, name_def,
+ workgroup->workgroup_count_x, workgroup->workgroup_count_y,
+ workgroup->workgroup_count_z, workgroup->workgroup_id_x,
+ workgroup->workgroup_id_y, workgroup->workgroup_id_z,
+ workgroup->processor_id);
+ i += sizeof(*workgroup);
+ break;
+ }
+ case IREE_INSTRUMENT_DISPATCH_TYPE_PRINT: {
+ const iree_instrument_dispatch_print_t* print =
+ (const iree_instrument_dispatch_print_t*)header;
+ fprintf(stream, "%016" PRIX64 " | PRINT %.*s\n",
+ print->workgroup_offset, (int)print->length, print->data);
+ i += iree_host_align(sizeof(*print) + print->length, 16);
+ break;
+ }
+ case IREE_INSTRUMENT_DISPATCH_TYPE_VALUE: {
+ const iree_instrument_dispatch_value_t* value =
+ (const iree_instrument_dispatch_value_t*)header;
+ fprintf(stream,
+ "%016" PRIX64 " | VALUE %04u = ", value->workgroup_offset,
+ (uint32_t)value->ordinal);
+ iree_tooling_dump_print_value(value->type, value->bits, stream);
+ fputc('\n', stream);
+ i += sizeof(*value);
+ break;
+ }
+ case IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_LOAD: {
+ const iree_instrument_dispatch_memory_op_t* op =
+ (const iree_instrument_dispatch_memory_op_t*)header;
+ fprintf(stream, "%016" PRIX64 " | LOAD %016" PRIX64 " %u\n",
+ op->workgroup_offset, op->address, (int)op->length);
+ i += sizeof(*op);
+ break;
+ }
+ case IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_STORE: {
+ const iree_instrument_dispatch_memory_op_t* op =
+ (const iree_instrument_dispatch_memory_op_t*)header;
+ fprintf(stream, "%016" PRIX64 " | STORE %016" PRIX64 " %u\n",
+ op->workgroup_offset, op->address, (int)op->length);
+ i += sizeof(*op);
+ break;
+ }
+ default:
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "unimplemented dispatch instr type: %u",
+ (uint32_t)header->tag);
+ }
+ }
+
+ return iree_ok_status();
+}
+
+static iree_status_t iree_tooling_dump_instrument_file(
+ iree_const_byte_span_t file_contents, FILE* stream) {
+ const uint8_t* file_ptr = file_contents.data;
+ iree_host_size_t file_size = file_contents.data_length;
+
+ iree_dispatch_metadata_t dispatch_metadata = {0};
+ for (iree_host_size_t file_offset = 0; file_offset < file_size;) {
+ const iree_idbts_chunk_header_t* header =
+ (const iree_idbts_chunk_header_t*)(file_ptr + file_offset);
+ const uint8_t* payload = file_ptr + file_offset + sizeof(*header);
+ switch (header->type) {
+ case IREE_IDBTS_CHUNK_TYPE_DISPATCH_METADATA: {
+ IREE_RETURN_IF_ERROR(iree_tooling_dump_dispatch_metadata(
+ payload, header->content_length, &dispatch_metadata, stream));
+ break;
+ }
+ case IREE_IDBTS_CHUNK_TYPE_DISPATCH_RINGBUFFER: {
+ IREE_RETURN_IF_ERROR(iree_tooling_dump_dispatch_ringbuffer(
+ payload, header->content_length, &dispatch_metadata, stream));
+ break;
+ }
+ default:
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "unimplemented chunk type: %u",
+ (uint32_t)header->type);
+ }
+ file_offset +=
+ sizeof(*header) + iree_host_align(header->content_length, 16);
+ }
+
+ return iree_ok_status();
+}
+
+int main(int argc, char** argv) {
+ if (argc < 2) {
+ fprintf(stderr,
+ "Syntax: iree-dump-instruments instruments.bin > instruments.txt\n"
+ "Example usage:\n"
+ " $ iree-compile \\n"
+ " --iree-hal-target-backends=llvm-cpu \\n"
+ " --iree-hal-instrument-dispatches=16mib \\n"
+ " --iree-llvmcpu-instrument-memory-accesses=false \\n"
+ " runtime/src/iree/runtime/testdata/simple_mul.mlir \\n"
+ " -o=simple_mul_instr.vmfb\n"
+ " $ iree-run-module \\n"
+ " --device=local-sync \\n"
+ " --module=simple_mul_instr.vmfb \\n"
+ " --function=simple_mul \\n"
+ " --input=4xf32=2 \\n"
+ " --input=4xf32=4 \\n"
+ " --instrument_file=instrument.bin\n"
+ " $ iree-dump-instruments instrument.bin\n"
+ "\n");
+ return 1;
+ }
+
+ iree_file_contents_t* file_contents = NULL;
+ iree_status_t status =
+ iree_file_read_contents(argv[1], iree_allocator_system(), &file_contents);
+ if (iree_status_is_ok(status)) {
+ status =
+ iree_tooling_dump_instrument_file(file_contents->const_buffer, stdout);
+ }
+ iree_file_contents_free(file_contents);
+
+ if (!iree_status_is_ok(status)) {
+ iree_status_fprint(stderr, status);
+ iree_status_free(status);
+ return EXIT_FAILURE;
+ }
+ return EXIT_SUCCESS;
+}
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index 4f263bf..6063e86 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -20,6 +20,7 @@
#include "iree/tooling/comparison.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
+#include "iree/tooling/instrument_util.h"
#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
@@ -188,6 +189,10 @@
*out_exit_code = did_match ? EXIT_SUCCESS : EXIT_FAILURE;
}
+ // Grab any instrumentation data present in the module and write it to disk.
+ IREE_RETURN_IF_ERROR(
+ iree_tooling_process_instrument_data(context.get(), host_allocator));
+
// Release resources before gathering statistics.
inputs.reset();
outputs.reset();