Adding `-iree-scheduling-dump-statistics-*` flag and pass. (#8167)
This dumps scheduling information derived from the program after the full
stream dialect transformation pipeline has run. This is before HAL target
backends have had a chance to translate executables and still contains
the post-flow ops (linalg, etc).
Use in iree-translate with
`-iree-scheduling-dump-statistics-format=csv` (or `pretty`, TBD)
By default the output goes to stderr, but it can be redirected to a file with
`-iree-scheduling-dump-statistics-file=path` for easier automation.
Use in iree-opt with
`-pass-pipeline=iree-stream-transformation-pipeline{dump-statistics-format=csv}`.
diff --git a/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
index 9d05af4..6c884f0 100644
--- a/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
+++ b/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
@@ -57,10 +57,7 @@
static ExecutableLayout deriveExportLayout(
IREE::Stream::ExecutableExportOp exportOp,
SmallVector<IREE::Stream::CmdDispatchOp> &dispatchOps) {
- auto executableOp = exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
- assert(executableOp && "unnested export");
- auto funcOp = executableOp.getInnerModule().lookupSymbol<mlir::FuncOp>(
- exportOp.function_ref());
+ auto funcOp = exportOp.getFunctionRef();
assert(funcOp && "export target not found");
// TODO(#3502): a real derivation based on dispatch sites.
@@ -130,6 +127,7 @@
executableLayout.setLayouts.push_back(setLayout);
LLVM_DEBUG({
+ auto executableOp = exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
llvm::dbgs() << "deriveExportLayout(@" << executableOp.sym_name() << "::@"
<< exportOp.sym_name() << "):\n";
executableLayout.print(llvm::dbgs());
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 6ad56ee..c211eb6 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -1963,6 +1963,14 @@
builder.getStringAttr(sym_name), function_ref);
}
+// Returns the function named by `function_ref`, looked up in the inner module
+// of the enclosing stream.executable. Returns a null op when this export is
+// not nested within an executable.
+::mlir::FuncOp ExecutableExportOp::getFunctionRef() {
+  auto *thisOp = this->getOperation();
+  if (auto parentExecutableOp =
+          thisOp->getParentOfType<IREE::Stream::ExecutableOp>()) {
+    auto innerModuleOp = parentExecutableOp.getInnerModule();
+    return innerModuleOp.lookupSymbol<::mlir::FuncOp>(function_ref());
+  }
+  return {};
+}
+
//===----------------------------------------------------------------------===//
// stream.binding.subspan
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.td b/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 755d86a..b647a6e 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -2759,6 +2759,10 @@
"FlatSymbolRefAttr":$function_ref
)>,
];
+
+ let extraClassDeclaration = [{
+ ::mlir::FuncOp getFunctionRef();
+ }];
}
def Stream_BindingSubspanOp : Stream_PureOp<"binding.subspan", [
diff --git a/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp b/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp
index 5df4c17..0165822 100644
--- a/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp
@@ -464,8 +464,7 @@
// Operands/resources on the func are in an arbitrary order; get maps that
// lets us go from dispatch site operand/resource to function argument.
- auto funcOp = executableOp.getInnerModule().lookupSymbol<mlir::FuncOp>(
- exportOp.function_refAttr());
+ auto funcOp = exportOp.getFunctionRef();
auto operandToArgMap =
IREE::Stream::CmdDispatchOp::makeOperandToArgMap(funcOp);
auto resourceToArgMap =
diff --git a/iree/compiler/Dialect/Stream/Transforms/BUILD b/iree/compiler/Dialect/Stream/Transforms/BUILD
index 770c6f5..31a1b30 100644
--- a/iree/compiler/Dialect/Stream/Transforms/BUILD
+++ b/iree/compiler/Dialect/Stream/Transforms/BUILD
@@ -17,6 +17,7 @@
srcs = [
"AnnotateDispatchArguments.cpp",
"ConvertToStream.cpp",
+ "DumpStatistics.cpp",
"ElideAsyncCopies.cpp",
"EncodeTensors.cpp",
"FoldUniformOperands.cpp",
diff --git a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 23dc718..cc2f99e 100644
--- a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@
SRCS
"AnnotateDispatchArguments.cpp"
"ConvertToStream.cpp"
+ "DumpStatistics.cpp"
"ElideAsyncCopies.cpp"
"EncodeTensors.cpp"
"FoldUniformOperands.cpp"
diff --git a/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp b/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp
new file mode 100644
index 0000000..b5e9830
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp
@@ -0,0 +1,555 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Usage analysis
+//===----------------------------------------------------------------------===//
+
+// Module-wide index of the ops the statistics dumps are derived from.
+// All members are populated by a single call to analyze().
+struct UsageInfo {
+  // util.globals holding resources mapped by name.
+  llvm::MapVector<StringRef, IREE::Util::GlobalOp> resourceGlobalOps;
+
+  // stream.executable ops mapped by name.
+  llvm::MapVector<StringRef, IREE::Stream::ExecutableOp> executableOps;
+  // stream.executable exported function -> dispatches to it.
+  llvm::MapVector<mlir::FuncOp, SmallVector<IREE::Stream::CmdDispatchOp>>
+      exportDispatchOps;
+
+  // TODO(benvanik): resource allocations.
+
+  // stream.cmd.execute ops containing all relevant device commands.
+  SmallVector<IREE::Stream::CmdExecuteOp> executeOps;
+  // ResourceAllocaOps found while walking function-like ops; Statistics sums
+  // their constant storage sizes as the transient size.
+  SmallVector<IREE::Stream::ResourceAllocaOp> allocaOps;
+
+  // stream.timepoint.await ops indicating host/device synchronization.
+  SmallVector<IREE::Stream::TimepointAwaitOp> awaitOps;
+
+  // Walks |moduleOp| and records:
+  //  - top-level util.globals whose type is a stream resource,
+  //  - top-level stream.executables,
+  //  - alloca/execute/await ops nested anywhere within function-like ops,
+  //  - dispatches nested within the recorded execute ops, grouped by the
+  //    exported function they target.
+  void analyze(mlir::ModuleOp moduleOp) {
+    SymbolTable symbolTable(moduleOp);
+    for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOp>()) {
+      if (globalOp.type().isa<IREE::Stream::ResourceType>()) {
+        resourceGlobalOps[globalOp.getName()] = globalOp;
+      }
+    }
+    for (auto executableOp : moduleOp.getOps<IREE::Stream::ExecutableOp>()) {
+      executableOps[executableOp.getName()] = executableOp;
+    }
+    for (auto &funcLikeOp : moduleOp.getOps()) {
+      if (!funcLikeOp.hasTrait<OpTrait::FunctionLike>()) continue;
+      funcLikeOp.walk([&](Operation *op) {
+        TypeSwitch<Operation *>(op)
+            .Case<IREE::Stream::ResourceAllocaOp>(
+                [&](auto op) { allocaOps.push_back(op); })
+            .Case<IREE::Stream::CmdExecuteOp>(
+                [&](auto op) { executeOps.push_back(op); })
+            .Case<IREE::Stream::TimepointAwaitOp>(
+                [&](auto op) { awaitOps.push_back(op); });
+      });
+    }
+    for (auto executeOp : executeOps) {
+      executeOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
+        // dyn_cast_or_null rather than cast: cast<> must not be handed a null
+        // lookup result, which made the assert below unreachable and left
+        // builds without assertions dereferencing a null op.
+        auto exportOp = dyn_cast_or_null<IREE::Stream::ExecutableExportOp>(
+            symbolTable.lookupSymbolIn(moduleOp, dispatchOp.entry_point()));
+        assert(exportOp && "missing executable/export");
+        if (!exportOp) return;  // Skip unresolved dispatches.
+        auto funcOp = exportOp.getFunctionRef();
+        assert(funcOp && "missing exported function");
+        if (!funcOp) return;  // Never key the map on a null function.
+        exportDispatchOps[funcOp].push_back(dispatchOp);
+      });
+    }
+  }
+};
+
+// TODO(benvanik): StaticSize helper or something for the dynamic bit.
+// Aggregate counters derived from a UsageInfo by analyze().
+// transientSizeDynamic is set when at least one alloca has a non-constant
+// storage size; transientSize then only sums the constant ones.
+struct Statistics {
+  // Globals:
+  size_t constantCount = 0;
+  int64_t constantSize = 0;
+  bool constantSizeDynamic = false;
+  size_t variableCount = 0;
+  int64_t variableSize = 0;
+  bool variableSizeDynamic = false;
+
+  // Synchronization:
+  size_t awaitCount = 0;
+
+  // Execution:
+  size_t submissionCount = 0;
+  int64_t transientSize = 0;
+  bool transientSizeDynamic = false;
+  // TODO(benvanik): add fill/copy sizes (when possible).
+  size_t fillCount = 0;
+  size_t copyCount = 0;
+  size_t dispatchCount = 0;
+
+  // Executables:
+  size_t executableCount = 0;
+
+  // Fills in the counters from |usageInfo|.
+  void analyze(const UsageInfo &usageInfo) {
+    // Globals: only counted by lifetime; the size fields are left untouched.
+    // TODO(benvanik): analyze size in UsageInfo.
+    for (auto globalEntry : usageInfo.resourceGlobalOps) {
+      auto resourceType =
+          globalEntry.second.type().dyn_cast<IREE::Stream::ResourceType>();
+      if (!resourceType) continue;
+      const auto lifetime = resourceType.getLifetime();
+      if (lifetime == IREE::Stream::Lifetime::Constant) {
+        ++constantCount;
+      } else if (lifetime == IREE::Stream::Lifetime::Variable) {
+        ++variableCount;
+      }
+    }
+
+    // Synchronization:
+    awaitCount = usageInfo.awaitOps.size();
+
+    // Execution:
+    submissionCount = usageInfo.executeOps.size();
+    for (auto allocaOp : usageInfo.allocaOps) {
+      APInt staticSize;
+      const bool isStatic =
+          matchPattern(allocaOp.storage_size(), m_ConstantInt(&staticSize));
+      if (!isStatic) {
+        transientSizeDynamic = true;
+        continue;
+      }
+      transientSize += staticSize.getSExtValue();
+    }
+    for (auto executeOp : usageInfo.executeOps) {
+      executeOp.walk([&](Operation *nestedOp) {
+        if (isa<IREE::Stream::CmdFillOp>(nestedOp)) {
+          ++fillCount;
+        } else if (isa<IREE::Stream::CmdCopyOp>(nestedOp)) {
+          ++copyCount;
+        } else if (isa<IREE::Stream::CmdDispatchOp>(nestedOp)) {
+          ++dispatchCount;
+        }
+      });
+    }
+
+    // Executables:
+    executableCount = usageInfo.executableOps.size();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pretty printing
+//===----------------------------------------------------------------------===//
+
+// Prints |op| preceded by all of its parent ops, outermost first, separated
+// by " > ". Each entry is the op name followed by " @<name>" when the op
+// implements SymbolOpInterface.
+static void prettyPrintOpBreadcrumb(Operation *op, llvm::raw_fd_ostream &os) {
+  // Gather |op| and its ancestors innermost-first, then emit them reversed.
+  SmallVector<Operation *> lineage;
+  lineage.push_back(op);
+  while (auto *enclosingOp = lineage.back()->getParentOp()) {
+    lineage.push_back(enclosingOp);
+  }
+  for (auto it = lineage.rbegin(), end = lineage.rend(); it != end; ++it) {
+    Operation *currentOp = *it;
+    if (it != lineage.rbegin()) os << " > ";
+    os << currentOp->getName();
+    if (auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp)) {
+      os << " @" << symbolOp.getName();
+    }
+  }
+}
+
+// Prints |header| as a comment line framed above and below by a `=` rule.
+static void prettyPrintSectionHeader(llvm::Twine header,
+                                     llvm::raw_fd_ostream &os) {
+  const char *rule =
+      "//"
+      "======================================================================"
+      "======//\n";
+  os << rule;
+  os << "// " << header << "\n";
+  os << rule;
+}
+
+// Prints |header| as a comment line framed above and below by a `-` rule.
+static void prettyPrintItemHeader(llvm::Twine header,
+                                  llvm::raw_fd_ostream &os) {
+  const char *rule =
+      "//"
+      "----------------------------------------------------------------------"
+      "------//\n";
+  os << rule;
+  os << "// " << header << "\n";
+  os << rule;
+}
+
+// Prints the "Aggregate Statistics" section computed from |usageInfo|.
+// A size is prefixed with "minimum " when its matching dynamic flag is set.
+static void prettyPrintStatistics(const UsageInfo &usageInfo,
+                                  llvm::raw_fd_ostream &os) {
+  prettyPrintSectionHeader("Aggregate Statistics (static, whole-program)", os);
+  os << "//\n";
+
+  Statistics stats;
+  stats.analyze(usageInfo);
+
+  os << llvm::formatv("// Constants: {0}, ", stats.constantCount);
+  os << llvm::formatv(
+      "{0}{1} B ({2:F2} MiB)\n", stats.constantSizeDynamic ? "minimum " : "",
+      stats.constantSize, stats.constantSize / (1 * 1024 * 1024.0f));
+  os << llvm::formatv("// Variables: {0}, ", stats.variableCount);
+  os << llvm::formatv(
+      "{0}{1} B ({2:F2} MiB)\n", stats.variableSizeDynamic ? "minimum " : "",
+      stats.variableSize, stats.variableSize / (1 * 1024 * 1024.0f));
+
+  os << llvm::formatv("// D->H Syncs: {0}\n", stats.awaitCount);
+
+  os << llvm::formatv("// Submissions: {0}, using cumulative ",
+                      stats.submissionCount);
+  os << llvm::formatv(
+      "{0}{1} B ({2:F2} MiB)\n", stats.transientSizeDynamic ? "minimum " : "",
+      stats.transientSize, stats.transientSize / (1 * 1024 * 1024.0f));
+
+  os << llvm::formatv("// DMA Fills: {0}\n", stats.fillCount);
+  os << llvm::formatv("// DMA Copies: {0}\n", stats.copyCount);
+  os << llvm::formatv("// Dispatches: {0}\n", stats.dispatchCount);
+
+  // Reuse is only defined when there is at least one dispatch. Without this
+  // guard the division yields NaN/inf for a program with no dispatches and
+  // converting that to int is undefined behavior; report 0% instead.
+  int reusePercent = 0;
+  if (stats.dispatchCount > 0) {
+    reusePercent = (int)std::roundf(
+        (1.0f - (stats.executableCount / (float)stats.dispatchCount)) *
+        100.0f);
+  }
+  os << llvm::formatv("// Executables: {0}, {1}% reuse\n",
+                      stats.executableCount, reusePercent);
+
+  os << "//\n";
+}
+
+// Prints the "Constants / Variables" section. Currently a placeholder:
+// neither |usageInfo| nor |verbose| is consulted yet.
+static void prettyPrintGlobalInfo(const UsageInfo &usageInfo, bool verbose,
+                                  llvm::raw_fd_ostream &os) {
+  prettyPrintSectionHeader("Constants / Variables", os);
+
+  // TODO(benvanik): print global information:
+  // - number of resource globals: constants/variables
+  // - util.byte_buffer.constant sizes (fed into stream.resource.try_map/map)
+  // - variable allocation sizes
+  os << "//\n"
+     << "// TODO\n"
+     << "//\n";
+}
+
+// Prints the "Synchronization" section. Currently a placeholder: neither
+// |usageInfo| nor |verbose| is consulted yet.
+static void prettyPrintSyncInfo(const UsageInfo &usageInfo, bool verbose,
+                                llvm::raw_fd_ostream &os) {
+  prettyPrintSectionHeader("Synchronization", os);
+
+  // TODO(benvanik): print host <-> device information:
+  // - number of stream.timepoint.awaits
+  // - staging buffer allocation sizes
+  // - number of buffer mapping operations
+  // - estimated number of submissions (execution with await in the middle)
+  os << "//\n"
+     << "// TODO\n"
+     << "//\n";
+}
+
+// Prints the item for a single |executeOp|: an item header, the breadcrumb of
+// ops enclosing it, and (for now) a TODO placeholder.
+static void prettyPrintStreamInfo(const UsageInfo &usageInfo,
+                                  IREE::Stream::CmdExecuteOp executeOp,
+                                  llvm::raw_fd_ostream &os) {
+  // The header used to be built with formatv from a format string without any
+  // replacement field, so the extra argument (the parent op name) was never
+  // printed. The literal is passed directly instead; the output is unchanged
+  // and the parent chain is still shown by the breadcrumb below.
+  prettyPrintItemHeader("stream.cmd.execute", os);
+  os << "// ";
+  prettyPrintOpBreadcrumb(executeOp, os);
+  os << "\n";
+  os << "//\n";
+
+  // TODO(benvanik): print stream information (for each stream.cmd.execute):
+  // - number of unique resources captured
+  // - number of commands of each type
+  // - % concurrently executable
+  os << "// TODO\n";
+}
+
+// Prints the "Streams" section: a TODO placeholder for the aggregate stats
+// followed by one item per recorded stream.cmd.execute op. |verbose| is not
+// consulted yet.
+static void prettyPrintAllStreamInfo(const UsageInfo &usageInfo, bool verbose,
+                                     llvm::raw_fd_ostream &os) {
+  prettyPrintSectionHeader("Streams", os);
+
+  // TODO(benvanik): aggregate stats:
+  // - number of streams
+  // - (eventually) number of streams per affinity
+  // - average commands per stream
+  // - streams with host dependencies/device dependencies (awaits/etc)
+  os << "//\n"
+     << "// TODO\n"
+     << "//\n";
+
+  for (auto streamOp : usageInfo.executeOps) {
+    prettyPrintStreamInfo(usageInfo, streamOp, os);
+    os << "//\n";
+  }
+}
+
+// Prints the item for a single |exportOp| of |executableOp|: an item header
+// naming the export, the breadcrumb of the exported function, and (for now)
+// a TODO placeholder.
+static void prettyPrintExecutableExportInfo(
+    const UsageInfo &usageInfo, IREE::Stream::ExecutableOp executableOp,
+    IREE::Stream::ExecutableExportOp exportOp, llvm::raw_fd_ostream &os) {
+  auto funcOp = exportOp.getFunctionRef();
+  prettyPrintItemHeader(
+      llvm::formatv("stream.executable.export @{0}::@{1}",
+                    executableOp.getName(), exportOp.getName()),
+      os);
+  os << "// ";
+  if (funcOp) {
+    prettyPrintOpBreadcrumb(funcOp, os);
+  } else {
+    // getFunctionRef() can return a null op; the breadcrumb printer
+    // dereferences its argument, so print a marker instead.
+    os << "<exported function not found>";
+  }
+  // Terminate the breadcrumb line. This used to emit "//\n", which glued a
+  // stray "//" onto the end of the breadcrumb instead of ending the line
+  // (compare with prettyPrintStreamInfo).
+  os << "\n";
+  os << "//\n";
+
+  // TODO(benvanik): interface and usage stats:
+  // - operand info
+  // - binding info
+  // - misaligned/unaligned/etc - big warning
+  // - incoming dispatches
+  // - workload params
+
+  // TODO(benvanik): ask codegen team if they want anything like a list of
+  // linalg named ops, etc.
+
+  os << "// TODO\n";
+}
+
+// Prints one item per stream.executable.export directly nested in
+// |executableOp|.
+static void prettyPrintExecutableInfo(const UsageInfo &usageInfo,
+                                      IREE::Stream::ExecutableOp executableOp,
+                                      llvm::raw_fd_ostream &os) {
+  // Today we pretty much have one export per executable here as we are
+  // performing linking in the HAL. Once we link/deduplicate/etc in streams then
+  // we'll want to make this segmentation nicer.
+  auto exportOps = executableOp.getOps<IREE::Stream::ExecutableExportOp>();
+  for (auto entryPointOp : exportOps) {
+    prettyPrintExecutableExportInfo(usageInfo, executableOp, entryPointOp, os);
+  }
+}
+
+// Prints the "Executables" section: a TODO placeholder for the aggregate
+// stats followed by the export items of every recorded executable. |verbose|
+// is not consulted yet.
+static void prettyPrintAllExecutableInfo(const UsageInfo &usageInfo,
+                                         bool verbose,
+                                         llvm::raw_fd_ostream &os) {
+  prettyPrintSectionHeader("Executables", os);
+
+  // TODO(benvanik): aggregate stats:
+  // - number of executables
+  // - total number of exports
+  // - average bindings/operands per export
+  os << "//\n"
+     << "// TODO\n"
+     << "//\n";
+
+  for (auto nameAndExecutable : usageInfo.executableOps) {
+    prettyPrintExecutableInfo(usageInfo, nameAndExecutable.second, os);
+    os << "//\n";
+  }
+}
+
+// Emits the full pretty-printed report to |os|, one section after another:
+// aggregate statistics, globals, synchronization, streams and executables.
+static void prettyPrintUsageInfo(const UsageInfo &usageInfo, bool verbose,
+                                 llvm::raw_fd_ostream &os) {
+  // Whole-program summary first.
+  prettyPrintStatistics(usageInfo, os);
+
+  // Detailed sections; each receives the |verbose| flag.
+  prettyPrintGlobalInfo(usageInfo, verbose, os);
+  prettyPrintSyncInfo(usageInfo, verbose, os);
+  prettyPrintAllStreamInfo(usageInfo, verbose, os);
+  prettyPrintAllExecutableInfo(usageInfo, verbose, os);
+}
+
+//===----------------------------------------------------------------------===//
+// CSV tables
+//===----------------------------------------------------------------------===//
+
+// Writes the whole-program aggregate table: one header row, one row of values
+// computed by Statistics, and a trailing blank line.
+static void dumpAggregateCSVTable(const UsageInfo &usageInfo,
+                                  llvm::raw_fd_ostream &os) {
+  Statistics totals;
+  totals.analyze(usageInfo);
+
+  os << R"("Constants","Constant Size","Variables","Variable Size","Awaits","Submissions","Transient Size","Fills","Copies","Dispatches","Executables")"
+     << "\n";
+
+  // Column order: globals, synchronization, execution, executables.
+  os << llvm::formatv(
+      "{0},{1},{2},{3},{4},{5},{6},{7},{8},{9},{10}\n\n",
+      totals.constantCount, totals.constantSize, totals.variableCount,
+      totals.variableSize, totals.awaitCount, totals.submissionCount,
+      totals.transientSize, totals.fillCount, totals.copyCount,
+      totals.dispatchCount, totals.executableCount);
+}
+
+// Writes one CSV table for |executeOp|: a `; <breadcrumb>` comment line, the
+// column header row, one row per fill/copy/dispatch command found in the
+// execute body, and a trailing blank line. "Depth" is the number of enclosing
+// CmdSerialOp/CmdConcurrentOp ops. Lengths and workgroup counts that are not
+// constants are reported as 0.
+static void dumpExecutionCSVTable(const UsageInfo &usageInfo,
+                                  IREE::Stream::CmdExecuteOp executeOp,
+                                  llvm::raw_fd_ostream &os) {
+  os << "; ";
+  prettyPrintOpBreadcrumb(executeOp, os);
+  os << "\n";
+  os << R"("Depth","Command","Symbol","Length","Invocations","X","Y","Z","Operands","Resources")";
+  os << "\n";
+
+  // Returns the constant integer defining |value|, or 0 when |value| is not a
+  // constant. Plain integers are used instead of default-constructed APInts:
+  // multiplying APInts of different bit widths is invalid, which was hit when
+  // only some of the workgroup counts were constants.
+  auto getStaticValueOrZero = [](Value value) -> int64_t {
+    APInt constantValue;
+    if (!matchPattern(value, m_ConstantInt(&constantValue))) return 0;
+    return constantValue.getSExtValue();
+  };
+
+  std::function<void(Operation *)> dumpRow;
+  int depth = 0;
+  dumpRow = [&](Operation *op) {
+    TypeSwitch<Operation *>(op)
+        .Case<IREE::Stream::CmdSerialOp>([&](auto op) {
+          ++depth;
+          for (auto &nestedOp : op.body().front()) dumpRow(&nestedOp);
+          --depth;
+        })
+        .Case<IREE::Stream::CmdConcurrentOp>([&](auto op) {
+          ++depth;
+          for (auto &nestedOp : op.body().front()) dumpRow(&nestedOp);
+          --depth;
+        })
+        .Case<IREE::Stream::CmdFillOp>([&](auto op) {
+          const int64_t length = getStaticValueOrZero(op.target_length());
+          os << llvm::formatv(R"({0},"fill",,{1},,,,,,)", depth, length);
+          os << "\n";
+        })
+        .Case<IREE::Stream::CmdCopyOp>([&](auto op) {
+          const int64_t length = getStaticValueOrZero(op.length());
+          os << llvm::formatv(R"({0},"copy",,{1},,,,,,)", depth, length);
+          os << "\n";
+        })
+        .Case<IREE::Stream::CmdDispatchOp>([&](auto op) {
+          auto workload = op.workgroup_count();
+          const int64_t workloadX = getStaticValueOrZero(workload[0]);
+          const int64_t workloadY = getStaticValueOrZero(workload[1]);
+          const int64_t workloadZ = getStaticValueOrZero(workload[2]);
+          // Multiply as unsigned so that an overflowing product wraps rather
+          // than being undefined behavior.
+          const int64_t invocations =
+              static_cast<int64_t>(static_cast<uint64_t>(workloadX) *
+                                   static_cast<uint64_t>(workloadY) *
+                                   static_cast<uint64_t>(workloadZ));
+          os << llvm::formatv(
+              R"({0},"dispatch","{1}",,{2},{3},{4},{5},{6},{7})", depth,
+              op.entry_point(), invocations, workloadX, workloadY, workloadZ,
+              op.operands().size(), op.resources().size());
+          os << "\n";
+        });
+  };
+  for (auto &op : executeOp.body().front()) {
+    dumpRow(&op);
+  }
+  os << "\n";
+}
+
+// Writes all CSV tables to |os|: the aggregate table followed by one
+// execution table per recorded stream.cmd.execute op. Each group is
+// introduced by a `;`-prefixed banner.
+static void dumpCSVTables(const UsageInfo &usageInfo,
+                          llvm::raw_fd_ostream &os) {
+  os << ";\n"
+     << "; Aggregate Statistics (static, whole-program)\n"
+     << ";\n\n";
+  dumpAggregateCSVTable(usageInfo, os);
+
+  // TODO(benvanik): globals/syncs/streams/etc.
+
+  os << ";\n"
+     << "; Execution\n"
+     << ";\n\n";
+  for (auto streamOp : usageInfo.executeOps) {
+    dumpExecutionCSVTable(usageInfo, streamOp, os);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-dump-statistics
+//===----------------------------------------------------------------------===//
+
+// Opens a canonical |filePath| for text output.
+// An empty path can be used to target stderr and `-` will go to stdout.
+// If the file cannot be opened stderr will be used.
+// Returns the stream the statistics are written to for |filePath|:
+//  - empty path: stderr,
+//  - `-`: stdout,
+//  - anything else: the file at that path, opened for text output.
+// If the file cannot be opened the reason is reported on stderr and stderr is
+// returned instead. The stderr/stdout streams are created with the
+// should-close flag set to false.
+static std::unique_ptr<llvm::raw_fd_ostream> openOutputFile(
+    StringRef filePath) {
+  if (filePath.empty()) {
+    return std::make_unique<llvm::raw_fd_ostream>(2, false);  // stderr
+  }
+  if (filePath == "-") {
+    return std::make_unique<llvm::raw_fd_ostream>(1, false);  // stdout
+  }
+  std::error_code ec;
+  auto result = std::make_unique<llvm::raw_fd_ostream>(
+      filePath, ec, llvm::sys::fs::OF_TextWithCRLF);
+  if (!ec) return result;
+  // Include the error description so the failure can be diagnosed.
+  llvm::errs() << "Error opening iree-stream-dump-statistics output file '"
+               << filePath << "': " << ec.message() << "\n";
+  return std::make_unique<llvm::raw_fd_ostream>(2, false);  // stderr.
+}
+
+// Module pass that analyzes the module and dumps stream usage statistics in
+// the format selected by `outputFormat` to the destination selected by
+// `outputFile` (see openOutputFile). Does nothing when the format is None.
+class DumpStatisticsPass : public DumpStatisticsBase<DumpStatisticsPass> {
+ public:
+  DumpStatisticsPass() = default;
+  DumpStatisticsPass(DumpOutputFormat outputFormat, std::string outputFile) {
+    this->outputFormat = outputFormat;
+    this->outputFile = outputFile;
+  }
+
+  // Fixed: the parameter was mis-encoded as `®istry` (a mangled `&registry`).
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Stream::StreamDialect>();
+    registry.insert<IREE::Util::UtilDialect>();
+  }
+
+  void runOnOperation() override {
+    if (outputFormat == DumpOutputFormat::None) return;
+
+    // Open the output file we'll be streaming to.
+    // Since we are processing the entire module at once we overwrite the file.
+    auto os = openOutputFile(outputFile);
+
+    // Walk the module once to accumulate everything we care about.
+    auto moduleOp = getOperation();
+    UsageInfo usageInfo;
+    usageInfo.analyze(moduleOp);
+
+    switch (outputFormat) {
+      case DumpOutputFormat::Pretty:
+      case DumpOutputFormat::Verbose:
+        // Both formats share the pretty printer; Verbose only sets its flag.
+        prettyPrintUsageInfo(usageInfo,
+                             outputFormat == DumpOutputFormat::Verbose, *os);
+        break;
+      case DumpOutputFormat::CSV:
+        dumpCSVTables(usageInfo, *os);
+        break;
+      default:
+        break;
+    }
+
+    os->flush();
+  }
+};
+
+} // namespace
+
+// Creates a DumpStatisticsPass writing in |outputFormat| to |outputFile|.
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createDumpStatisticsPass(
+    DumpOutputFormat outputFormat, std::string outputFile) {
+  // |outputFile| is a by-value sink parameter: move it on instead of copying.
+  return std::make_unique<DumpStatisticsPass>(outputFormat,
+                                              std::move(outputFile));
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp b/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp
index 3ecc3e1..034d569 100644
--- a/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp
@@ -282,8 +282,7 @@
auto &dispatchOps = entryDispatchMap[exportOp];
if (dispatchOps.empty()) continue; // no-op if no dispatches
- auto funcOp = dyn_cast<mlir::FuncOp>(SymbolTable::lookupSymbolIn(
- executableOp.getInnerModule(), exportOp.function_refAttr()));
+ auto funcOp = exportOp.getFunctionRef();
// Deduplicate operands that are correlated at all dispatch sites.
// We do this first so that we know all constants passed in are unique
diff --git a/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp b/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp
index 5359ee7..59137cb 100644
--- a/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp
@@ -387,8 +387,7 @@
// can do it for everything.
// Update the executable function to use the new bindings.
- auto funcOp = executableOp.getInnerModule().lookupSymbol<mlir::FuncOp>(
- exportOp.function_refAttr());
+ auto funcOp = exportOp.getFunctionRef();
assert(funcOp && "entry func not found");
updateExecutableSignature(executableOp, exportOp, funcOp, bindings);
diff --git a/iree/compiler/Dialect/Stream/Transforms/PassDetail.h b/iree/compiler/Dialect/Stream/Transforms/PassDetail.h
index 6fab181..a0e3567 100644
--- a/iree/compiler/Dialect/Stream/Transforms/PassDetail.h
+++ b/iree/compiler/Dialect/Stream/Transforms/PassDetail.h
@@ -14,6 +14,19 @@
namespace IREE {
namespace Stream {
+// TODO(benvanik): find a way to share this with IREEVM.h w/o circular deps.
+// Defines the output format of a dump pass.
+// Used as the type of the statistics dump pass `output-format` option and of
+// the pipeline `dump-statistics-format` option.
+enum class DumpOutputFormat {
+ // Dumping disabled.
+ // The statistics pass returns without opening any output for this value and
+ // the transformation pipeline does not add the pass at all.
+ None = 0,
+ // Human-readable pretty printing.
+ Pretty = 1,
+ // Pretty printing with additional information that can result in large dumps.
+ // Handled by the same printer as Pretty with its `verbose` flag set.
+ Verbose = 2,
+ // Comma separated values for throwing into Sheets.
+ CSV = 3,
+};
+
#define GEN_PASS_CLASSES
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" // IWYU pragma: keep
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index 6215aa1..bd69d45 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -279,6 +279,15 @@
buildStreamAsyncPassPipeline(passManager, transformOptions);
buildStreamCmdPassPipeline(passManager, transformOptions);
+ // Dump statistics before the deeper optimizations happen.
+ // Optimizations such as dispatch operand fusion remove information we can use
+ // to determine memory usage by dispatches.
+ if (transformOptions.dumpStatisticsFormat != DumpOutputFormat::None) {
+ passManager.addPass(IREE::Stream::createDumpStatisticsPass(
+ transformOptions.dumpStatisticsFormat,
+ transformOptions.dumpStatisticsFile));
+ }
+
//----------------------------------------------------------------------------
// Optimizations (may be required by some targets)
//----------------------------------------------------------------------------
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.h b/iree/compiler/Dialect/Stream/Transforms/Passes.h
index fd5c3ce..1d5100b 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_DIALECT_STREAM_TRANSFORMS_PASSES_H_
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
@@ -27,10 +28,33 @@
// TODO(benvanik): options for async/sync overrides.
Option<bool> optimizeBindings{
- *this, "optimize-bindings",
+ *this,
+ "optimize-bindings",
llvm::cl::desc(
"Enables binding fusion and dispatch site specialization."),
- llvm::cl::init(true)};
+ llvm::cl::init(true),
+ };
+
+ Option<DumpOutputFormat> dumpStatisticsFormat{
+ *this,
+ "dump-statistics-format",
+ llvm::cl::desc("Dumps statistics in the specified output format."),
+ llvm::cl::init(DumpOutputFormat::None),
+ llvm::cl::values(
+ clEnumValN(IREE::Stream::DumpOutputFormat::Pretty, "pretty",
+ "Human-readable pretty printed output."),
+ clEnumValN(IREE::Stream::DumpOutputFormat::Verbose, "verbose",
+ "Pretty printed output with additional IR."),
+ clEnumValN(IREE::Stream::DumpOutputFormat::CSV, "csv",
+ "Comma separated values.")),
+ };
+ Option<std::string> dumpStatisticsFile{
+ *this,
+ "dump-statistics-file",
+ llvm::cl::desc(
+ "File path to write to; or `` for stderr or `-` for stdout."),
+ llvm::cl::init(""),
+ };
};
// Adds a set of passes to the given pass manager that run the required flow
@@ -133,6 +157,10 @@
// Diagnostics
//===----------------------------------------------------------------------===//
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createDumpStatisticsPass(
+ DumpOutputFormat outputFormat = DumpOutputFormat::Pretty,
+ std::string outputFile = "");
+
std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyInputPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createVerifyLoweringToTensorsPass();
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.td b/iree/compiler/Dialect/Stream/Transforms/Passes.td
index 50645af..794041c 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -206,16 +206,30 @@
}
//===----------------------------------------------------------------------===//
-// Stream memoization
-//===----------------------------------------------------------------------===//
-
-// TODO(benvanik): outline streams (ala dispatch regions).
-// TODO(benvanik): deduplicate outlined streams.
-
-//===----------------------------------------------------------------------===//
// Diagnostics
//===----------------------------------------------------------------------===//
+// Diagnostic pass on mlir::ModuleOp that dumps stream usage statistics.
+// When run standalone the defaults are the `pretty` format and an empty
+// output file path, which the option description maps to stderr.
+def DumpStatistics :
+ Pass<"iree-stream-dump-statistics", "mlir::ModuleOp"> {
+ let summary = "Dumps stream dialect usage information to a file.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createDumpStatisticsPass()
+ }];
+ let options = [
+ Option<"outputFormat", "output-format", "IREE::Stream::DumpOutputFormat",
+ "IREE::Stream::DumpOutputFormat::Pretty",
+ "Specifies the output format to produce.",
+ [{::llvm::cl::values(
+ clEnumValN(IREE::Stream::DumpOutputFormat::Pretty, "pretty", "Human-readable pretty printed output."),
+ clEnumValN(IREE::Stream::DumpOutputFormat::Verbose, "verbose", "Pretty printed output with additional IR."),
+ clEnumValN(IREE::Stream::DumpOutputFormat::CSV, "csv", "Comma separated values.")
+ )}]>,
+ Option<"outputFile", "output-file",
+ "std::string", /*default=*/"std::string()",
+ "File path to write to; or `` for stderr or `-` for stdout.">
+ ];
+}
+
def VerifyInput :
Pass<"iree-stream-verify-input", "mlir::ModuleOp"> {
let summary = "Verifies that input dialects are supported by the streams dialect.";
diff --git a/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp b/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
index 1079efd..aea5e5f 100644
--- a/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
@@ -290,8 +290,7 @@
MemoizedCmdConstants &memoizedConstants) {
if (dispatchOps.empty()) return; // no-op if no dispatches
- auto funcOp = executableOp.getInnerModule().lookupSymbol<mlir::FuncOp>(
- exportOp.function_refAttr());
+ auto funcOp = exportOp.getFunctionRef();
// Build a constant table for unique per-dispatch constant values.
auto constantTable = buildConstantTable(funcOp, dispatchOps);
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/BUILD b/iree/compiler/Dialect/Stream/Transforms/test/BUILD
index 3856619..1142f08 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Stream/Transforms/test/BUILD
@@ -19,6 +19,7 @@
[
"annotate_dispatch_arguments.mlir",
"convert_to_stream.mlir",
+ "dump_statistics.mlir",
"elide_async_copies.mlir",
"encode_device_tensors.mlir",
"encode_host_tensors.mlir",
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
index f3ad456..04f7193 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -16,6 +16,7 @@
SRCS
"annotate_dispatch_arguments.mlir"
"convert_to_stream.mlir"
+ "dump_statistics.mlir"
"elide_async_copies.mlir"
"encode_device_tensors.mlir"
"encode_host_tensors.mlir"
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir b/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir
new file mode 100644
index 0000000..73047b7
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir
@@ -0,0 +1,153 @@
+// RUN: iree-opt -split-input-file -pass-pipeline=iree-stream-dump-statistics{output-format=pretty} %s 2>&1 | FileCheck %s -check-prefix=CHECK-PRETTY
+// RUN: iree-opt -split-input-file -pass-pipeline=iree-stream-dump-statistics{output-format=csv} %s 2>&1 | FileCheck %s -check-prefix=CHECK-CSV
+
+// CHECK-PRETTY: Aggregate Statistics
+// CHECK-PRETTY: Constants: 1, 0 B
+// CHECK-PRETTY: Variables: 0, 0 B
+// CHECK-PRETTY: D->H Syncs: 2
+// CHECK-PRETTY: Submissions: 3, using cumulative 0 B
+// CHECK-PRETTY: DMA Fills: 0
+// CHECK-PRETTY: DMA Copies: 2
+// CHECK-PRETTY: Dispatches: 3
+// CHECK-PRETTY: Executables: 2, 33% reuse
+
+// CHECK-CSV: ; Aggregate Statistics
+// CHECK-CSV: "Constants","Constant Size","Variables","Variable Size","Awaits","Submissions","Transient Size","Fills","Copies","Dispatches","Executables"
+// CHECK-CSV: 1,0,0,0,2,3,0,0,2,3,2
+
+util.global private mutable @_constant__timepoint = #stream.timepoint<immediate>
+util.global private @_constant : !stream.resource<constant>
+util.initializer {
+ %c0 = arith.constant 0 : index
+ %c192 = arith.constant 192 : index
+ %0 = stream.timepoint.immediate => !stream.timepoint
+ %1 = util.byte_buffer.constant {alignment = 32 : i64} : !util.byte_buffer = #util.composite<192xi8, [
+ dense<[5, 6, 7, 8]> : tensor<4xi32>,
+ dense<0> : vector<16xi8>,
+ dense<[5, 6, 3, 8]> : tensor<4xi32>,
+ dense<0> : vector<16xi8>,
+ dense<[1, 6, 7, 8]> : tensor<4xi32>,
+ dense<0> : vector<16xi8>,
+ dense<[5, 6, 7]> : tensor<3xi32>,
+ dense<0> : vector<20xi8>,
+ dense<[5, 6, 3]> : tensor<3xi32>,
+ dense<0> : vector<20xi8>,
+ dense<[1, 6, 7]> : tensor<3xi32>,
+ dense<0> : vector<20xi8>,
+ ]>
+ %did_map, %result = stream.resource.try_map %1[%c0] : !util.byte_buffer -> i1, !stream.resource<constant>{%c192}
+ %2:2 = scf.if %did_map -> (!stream.resource<constant>, !stream.timepoint) {
+ scf.yield %result, %0 : !stream.resource<constant>, !stream.timepoint
+ } else {
+ %3 = stream.resource.map %1[%c0] : !util.byte_buffer -> !stream.resource<staging>{%c192}
+ %4 = stream.resource.alloc uninitialized : !stream.resource<constant>{%c192}
+ %5 = stream.cmd.execute with(%3 as %arg0: !stream.resource<staging>{%c192}, %4 as %arg1: !stream.resource<constant>{%c192}) {
+ stream.cmd.copy %arg0[%c0], %arg1[%c0], %c192 : !stream.resource<staging>{%c192} -> !stream.resource<constant>{%c192}
+ } => !stream.timepoint
+ scf.yield %4, %5 : !stream.resource<constant>, !stream.timepoint
+ }
+ util.global.store %2#0, @_constant : !stream.resource<constant>
+ util.global.store %2#1, @_constant__timepoint : !stream.timepoint
+ util.initializer.return
+}
+
+stream.executable private @func_a_ex_0 {
+ stream.executable.export public @dispatch_0
+ builtin.module {
+ func @dispatch_0(%arg0: !stream.binding {stream.alignment = 32 : index}, %arg1: !stream.binding {stream.alignment = 32 : index}, %arg2: !stream.binding {stream.alignment = 32 : index}) {
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:4xi32>
+ %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:4xi32>
+ %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:4xi32>
+ %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
+ %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
+ %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
+ %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_0, %workgroup_size_0]
+ %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_0, %workgroup_size_0]
+ scf.for %arg3 = %3 to %c4 step %4 {
+ %5 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg3)[%workgroup_size_0]
+ %6 = flow.dispatch.tensor.load %0, offsets = [%arg3], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<?xi32>
+ %7 = flow.dispatch.tensor.load %1, offsets = [%arg3], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readonly:4xi32> -> tensor<?xi32>
+ %8 = linalg.init_tensor [%5] : tensor<?xi32>
+ %9 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %7 : tensor<?xi32>, tensor<?xi32>) outs(%8 : tensor<?xi32>) {
+ ^bb0(%arg4: i32, %arg5: i32, %arg6: i32): // no predecessors
+ %10 = arith.maxsi %arg4, %arg5 : i32
+ linalg.yield %10 : i32
+ } -> tensor<?xi32>
+ flow.dispatch.tensor.store %9, %2, offsets = [%arg3], sizes = [%5], strides = [1] : tensor<?xi32> -> !flow.dispatch.tensor<writeonly:4xi32>
+ }
+ return
+ }
+ }
+}
+
+stream.executable private @func_a_ex_1 {
+ stream.executable.export public @dispatch_1
+ builtin.module {
+ func @dispatch_1(%arg0: !stream.binding {stream.alignment = 32 : index}, %arg1: !stream.binding {stream.alignment = 32 : index}, %arg2: !stream.binding {stream.alignment = 32 : index}) {
+ %c3 = arith.constant 3 : index
+ %c0 = arith.constant 0 : index
+ %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:3xi32>
+ %1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:3xi32>
+ %2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:3xi32>
+ %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
+ %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
+ %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
+ %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_0, %workgroup_size_0]
+ %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_0, %workgroup_size_0]
+ scf.for %arg3 = %3 to %c3 step %4 {
+ %5 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg3)[%workgroup_size_0]
+ %6 = flow.dispatch.tensor.load %0, offsets = [%arg3], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readonly:3xi32> -> tensor<?xi32>
+ %7 = flow.dispatch.tensor.load %1, offsets = [%arg3], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readonly:3xi32> -> tensor<?xi32>
+ %8 = linalg.init_tensor [%5] : tensor<?xi32>
+ %9 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %7 : tensor<?xi32>, tensor<?xi32>) outs(%8 : tensor<?xi32>) {
+ ^bb0(%arg4: i32, %arg5: i32, %arg6: i32): // no predecessors
+ %10 = arith.maxsi %arg4, %arg5 : i32
+ linalg.yield %10 : i32
+ } -> tensor<?xi32>
+ flow.dispatch.tensor.store %9, %2, offsets = [%arg3], sizes = [%5], strides = [1] : tensor<?xi32> -> !flow.dispatch.tensor<writeonly:3xi32>
+ }
+ return
+ }
+ }
+}
+
+func public @func_a() -> (tensor<4xi32>, tensor<4xi32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c192 = arith.constant 192 : index
+ %_constant__timepoint = util.global.load @_constant__timepoint : !stream.timepoint
+ %_constant = util.global.load @_constant : !stream.resource<constant>
+ %0 = stream.resource.alloc uninitialized : !stream.resource<external>{%c16}
+ %1 = stream.cmd.execute await(%_constant__timepoint) => with(%_constant as %arg0: !stream.resource<constant>{%c192}, %0 as %arg1: !stream.resource<external>{%c16}) {
+ stream.cmd.copy %arg0[%c0], %arg1[%c0], %c16 : !stream.resource<constant>{%c192} -> !stream.resource<external>{%c16}
+ } => !stream.timepoint
+ %2 = stream.resource.alloc uninitialized : !stream.resource<external>{%c16}
+ %3 = stream.cmd.execute await(%_constant__timepoint) => with(%_constant as %arg0: !stream.resource<constant>{%c192}, %2 as %arg1: !stream.resource<external>{%c16}) {
+ stream.cmd.dispatch @func_a_ex_0::@dispatch_0[%c4, %c1, %c1] {
+ ro %arg0[%c64 for %c16] : !stream.resource<constant>{%c192},
+ ro %arg0[%c32 for %c16] : !stream.resource<constant>{%c192},
+ wo %arg1[%c0 for %c16] : !stream.resource<external>{%c16}
+ }
+ stream.cmd.dispatch @func_a_ex_0::@dispatch_0[%c4, %c1, %c1] {
+ ro %arg0[%c64 for %c16] : !stream.resource<constant>{%c192},
+ ro %arg0[%c32 for %c16] : !stream.resource<constant>{%c192},
+ wo %arg1[%c0 for %c16] : !stream.resource<external>{%c16}
+ }
+ stream.cmd.dispatch @func_a_ex_1::@dispatch_1[%c4, %c1, %c1] {
+ ro %arg0[%c64 for %c16] : !stream.resource<constant>{%c192},
+ ro %arg0[%c32 for %c16] : !stream.resource<constant>{%c192},
+ wo %arg1[%c0 for %c16] : !stream.resource<external>{%c16}
+ }
+ } => !stream.timepoint
+ %4 = stream.timepoint.await %3 => %2 : !stream.resource<external>{%c16}
+ %5 = stream.tensor.export %4 : tensor<4xi32> in !stream.resource<external>{%c16} -> tensor<4xi32>
+ %6 = stream.timepoint.await %1 => %0 : !stream.resource<external>{%c16}
+ %7 = stream.tensor.export %6 : tensor<4xi32> in !stream.resource<external>{%c16} -> tensor<4xi32>
+ return %5, %7 : tensor<4xi32>, tensor<4xi32>
+}
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index 41694e4..3a94ac1 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -33,57 +33,58 @@
void BindingOptions::bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory bindingOptionsCategory(
- "IREE translation binding support options");
+ "IREE translation binding support options.");
+
binder.opt<bool>(
"iree-native-bindings-support", native,
llvm::cl::desc(
- "Include runtime support for native IREE ABI-compatible bindings"),
+ "Include runtime support for native IREE ABI-compatible bindings."),
llvm::cl::cat(bindingOptionsCategory));
- binder.opt<bool>(
- "iree-tflite-bindings-support", tflite,
- llvm::cl::desc(
- "Include runtime support for the IREE TFLite compatibility bindings"),
- llvm::cl::cat(bindingOptionsCategory));
+ binder.opt<bool>("iree-tflite-bindings-support", tflite,
+ llvm::cl::desc("Include runtime support for the IREE TFLite "
+ "compatibility bindings."),
+ llvm::cl::cat(bindingOptionsCategory));
}
void InputDialectOptions::bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory inputDialectOptions(
- "IREE options for controlling the input transformations to apply");
+ "IREE options for controlling the input transformations to apply.");
binder.opt<InputDialectOptions::Type>(
- "iree-input-type", type, llvm::cl::desc("IREE input type"),
+ "iree-input-type", type,
+ llvm::cl::desc("Specifies the input program representation."),
llvm::cl::values(
clEnumValN(InputDialectOptions::Type::none, "none",
- "No input dialect transformation"),
+ "No input dialect transformation."),
clEnumValN(InputDialectOptions::Type::tosa, "tosa",
- "Legalize from TOSA ops"),
+ "Legalize from TOSA ops."),
clEnumValN(InputDialectOptions::Type::mhlo, "mhlo",
- "Legalize from MHLO ops"),
+ "Legalize from MHLO ops."),
clEnumValN(
InputDialectOptions::Type::xla, "xla",
- "Legalize from MHLO ops (with XLA cleanup preprocessing)")),
+ "Legalize from MHLO ops (with XLA cleanup preprocessing).")),
llvm::cl::cat(inputDialectOptions));
}
void HighLevelOptimizationOptions::bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category(
- "IREE options for controlling high level optimizations");
+ "IREE options for controlling high level optimizations.");
binder.opt<bool>(
"iree-opt-const-eval", constEval,
llvm::cl::desc("Enables eager evaluation of constants using the full "
- "compiler and runtime"),
+ "compiler and runtime."),
llvm::cl::cat(category));
binder.opt<bool>(
"iree-opt-const-expr-hoisting", constExprHoisting,
llvm::cl::desc(
"Hoists the results of latent constant expressions into immutable "
- "global initializers for evaluation at program load"),
+ "global initializers for evaluation at program load."),
llvm::cl::cat(category));
binder.opt<bool>(
"iree-opt-numeric-precision-reduction", numericPrecisionReduction,
llvm::cl::desc(
- "Reduces numeric precision to lower bit depths where possible"),
+ "Reduces numeric precision to lower bit depths where possible."),
llvm::cl::cat(category));
binder.opt<bool>("iree-opt-strip-assertions", stripAssertions,
llvm::cl::desc("Strips debug assertions after any useful "
@@ -91,9 +92,31 @@
llvm::cl::cat(category));
}
+void SchedulingOptions::bindOptions(OptionsBinder &binder) {
+ static llvm::cl::OptionCategory category(
+ "IREE options for controlling host/device scheduling.");
+
+ binder.opt<DumpOutputFormat>(
+ "iree-scheduling-dump-statistics-format", dumpStatisticsFormat,
+ llvm::cl::desc("Dumps statistics in the specified output format."),
+ llvm::cl::cat(category),
+ llvm::cl::values(
+ clEnumValN(DumpOutputFormat::Pretty, "pretty",
+ "Human-readable pretty printed output."),
+ clEnumValN(DumpOutputFormat::Verbose, "verbose",
+ "Pretty printed output with additional IR."),
+ clEnumValN(DumpOutputFormat::CSV, "csv", "Comma separated values.")));
+ binder.opt<std::string>("iree-scheduling-dump-statistics-file",
+ dumpStatisticsFile,
+ llvm::cl::desc("File path to write statistics to; or "
+ "`` for stderr or `-` for stdout."),
+ llvm::cl::cat(category));
+}
+
void buildIREEVMTransformPassPipeline(
BindingOptions bindingOptions, InputDialectOptions inputOptions,
HighLevelOptimizationOptions highLevelOptimizationOptions,
+ SchedulingOptions schedulingOptions,
IREE::HAL::TargetOptions executableOptions,
IREE::VM::TargetOptions targetOptions, OpPassManager &passManager) {
// Input pipelines can result in changes to the exported functions and types
@@ -141,8 +164,13 @@
passManager.addPass(IREE::Util::createStripDebugOpsPass());
}
- IREE::Flow::buildFlowTransformPassPipeline(passManager, flowOptions);
IREE::Stream::TransformOptions streamOptions;
+ // TODO(benvanik): find a way to share the enums w/o circular deps.
+ streamOptions.dumpStatisticsFormat =
+ (IREE::Stream::DumpOutputFormat)schedulingOptions.dumpStatisticsFormat;
+ streamOptions.dumpStatisticsFile = schedulingOptions.dumpStatisticsFile;
+
+ IREE::Flow::buildFlowTransformPassPipeline(passManager, flowOptions);
IREE::Stream::buildStreamTransformPassPipeline(passManager, streamOptions);
IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions);
IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
@@ -153,6 +181,7 @@
buildIREEVMTransformPassPipeline(
BindingOptions::FromFlags::get(), InputDialectOptions::FromFlags::get(),
HighLevelOptimizationOptions::FromFlags::get(),
+ SchedulingOptions::FromFlags::get(),
IREE::HAL::TargetOptions::FromFlags::get(),
IREE::VM::TargetOptions::FromFlags::get(), passManager);
}
@@ -173,6 +202,7 @@
ModuleOp moduleOp, BindingOptions bindingOptions,
InputDialectOptions inputOptions,
HighLevelOptimizationOptions highLevelOptimizationOptions,
+ SchedulingOptions schedulingOptions,
IREE::HAL::TargetOptions executableOptions,
IREE::VM::TargetOptions targetOptions) {
PassManager passManager(moduleOp.getContext());
@@ -181,7 +211,7 @@
passManager.addInstrumentation(std::make_unique<PassTracing>());
buildIREEVMTransformPassPipeline(
bindingOptions, inputOptions, highLevelOptimizationOptions,
- executableOptions, targetOptions, passManager);
+ schedulingOptions, executableOptions, targetOptions, passManager);
if (failed(passManager.run(moduleOp))) {
return moduleOp.emitError() << "conversion from source -> vm failed";
}
@@ -202,13 +232,14 @@
auto inputOptions = InputDialectOptions::FromFlags::get();
auto highLevelOptimizationOptions =
HighLevelOptimizationOptions::FromFlags::get();
+ auto schedulingOptions = SchedulingOptions::FromFlags::get();
auto halTargetOptions = IREE::HAL::TargetOptions::FromFlags::get();
auto vmTargetOptions = IREE::VM::TargetOptions::FromFlags::get();
auto bytecodeTargetOptions =
IREE::VM::BytecodeTargetOptions::FromFlags::get();
- auto result = translateFromMLIRToVM(moduleOp, bindingOptions, inputOptions,
- highLevelOptimizationOptions,
- halTargetOptions, vmTargetOptions);
+ auto result = translateFromMLIRToVM(
+ moduleOp, bindingOptions, inputOptions, highLevelOptimizationOptions,
+ schedulingOptions, halTargetOptions, vmTargetOptions);
if (failed(result)) {
return result;
}
@@ -227,12 +258,13 @@
auto inputOptions = InputDialectOptions::FromFlags::get();
auto highLevelOptimizationOptions =
HighLevelOptimizationOptions::FromFlags::get();
+ auto schedulingOptions = SchedulingOptions::FromFlags::get();
auto halTargetOptions = IREE::HAL::TargetOptions::FromFlags::get();
auto vmTargetOptions = IREE::VM::TargetOptions::FromFlags::get();
auto cTargetOptions = IREE::VM::getCTargetOptionsFromFlags();
- auto result = translateFromMLIRToVM(moduleOp, bindingOptions, inputOptions,
- highLevelOptimizationOptions,
- halTargetOptions, vmTargetOptions);
+ auto result = translateFromMLIRToVM(
+ moduleOp, bindingOptions, inputOptions, highLevelOptimizationOptions,
+ schedulingOptions, halTargetOptions, vmTargetOptions);
if (failed(result)) {
return result;
}
diff --git a/iree/compiler/Translation/IREEVM.h b/iree/compiler/Translation/IREEVM.h
index 766712b..e007a6d 100644
--- a/iree/compiler/Translation/IREEVM.h
+++ b/iree/compiler/Translation/IREEVM.h
@@ -46,13 +46,10 @@
// Applies no input transformation. Only supported core and extension ops
// are supported.
none,
-
// Legalizes input defined over TOSA ops.
tosa,
-
// Legalizes input defined over MHLO ops.
mhlo,
-
// Special case of 'mhlo' legalization which also performs some XLA
// cleanup activities.
xla,
@@ -82,6 +79,35 @@
using FromFlags = OptionsFromFlags<HighLevelOptimizationOptions>;
};
+// Options controlling scheduling across host/device.
+struct SchedulingOptions {
+ // TODO(benvanik): find a way to share this with
+ // Stream/Transforms/PassDetail.h w/o circular deps.
+ // Defines the output format of a dump pass.
+ enum class DumpOutputFormat {
+ // Dumping disabled.
+ None = 0,
+ // Human-readable pretty printing.
+ Pretty = 1,
+ // Pretty printing with additional information that can result in large
+ // dumps.
+ Verbose = 2,
+ // Comma separated values for throwing into Sheets.
+ CSV = 3,
+ };
+ // Enables and specifies the format for a stream statistics dump.
+ DumpOutputFormat dumpStatisticsFormat = DumpOutputFormat::None;
+ // File path to write statistics to; or `` for stderr or `-` for stdout.
+ std::string dumpStatisticsFile = "";
+
+ // TODO(benvanik): favor size/speed/etc for partitioning.
+ // TODO(benvanik): execution model to optimize for (unified/discrete memory,
+ // single/multiple processors, etc).
+
+ void bindOptions(OptionsBinder &binder);
+ using FromFlags = OptionsFromFlags<SchedulingOptions>;
+};
+
// Builds the translation pipeline with defaults.
void buildDefaultIREEVMTransformPassPipeline(OpPassManager &passManager);
@@ -89,6 +115,7 @@
void buildIREEVMTransformPassPipeline(
BindingOptions bindingOptions, InputDialectOptions inputOptions,
HighLevelOptimizationOptions highLevelOptimizationOptions,
+ SchedulingOptions schedulingOptions,
IREE::HAL::TargetOptions executableOptions,
IREE::VM::TargetOptions targetOptions, OpPassManager &passManager);
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp b/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
index 27b658d..756c04a 100644
--- a/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
@@ -37,7 +37,8 @@
BindingOptions bindingOptions;
InputDialectOptions inputDialectOptions;
HighLevelOptimizationOptions highLevelOptimizationOptions;
- HALTargetOptions executableOptions;
+ SchedulingOptions schedulingOptions;
+ HALTargetOptions halTargetOptions;
VMTargetOptions vmTargetOptions;
VMBytecodeTargetOptions vmBytecodeTargetOptions;
@@ -47,7 +48,8 @@
bindingOptions.bindOptions(binder);
inputDialectOptions.bindOptions(binder);
highLevelOptimizationOptions.bindOptions(binder);
- executableOptions.bindOptions(binder);
+ schedulingOptions.bindOptions(binder);
+ halTargetOptions.bindOptions(binder);
vmTargetOptions.bindOptions(binder);
vmBytecodeTargetOptions.bindOptions(binder);
}
@@ -96,7 +98,7 @@
void ireeCompilerOptionsAddTargetBackend(IreeCompilerOptions options,
const char *targetBackend) {
- unwrap(options)->executableOptions.targets.push_back(
+ unwrap(options)->halTargetOptions.targets.push_back(
std::string(targetBackend));
}
@@ -133,8 +135,9 @@
auto *passManagerCpp = unwrap(passManager);
buildIREEVMTransformPassPipeline(
optionsCpp->bindingOptions, optionsCpp->inputDialectOptions,
- optionsCpp->highLevelOptimizationOptions, optionsCpp->executableOptions,
- optionsCpp->vmTargetOptions, *passManagerCpp);
+ optionsCpp->highLevelOptimizationOptions, optionsCpp->schedulingOptions,
+ optionsCpp->halTargetOptions, optionsCpp->vmTargetOptions,
+ *passManagerCpp);
}
// Translates a module op derived from the ireeCompilerBuildIREEVMPassPipeline