blob: 616026a9dcf182e3f016934508c7f9052880df86 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <memory>
#include <utility>
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Utils/IndexSet.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Transforms/Passes.h"
// NOTE: every binding is given its own region of the benchmark buffer, so
// bindings that alias in the original program will not alias in the benchmark
// and caching behavior may differ.
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
// Alignment, in bytes, applied to each binding's region within the benchmark
// buffer. This is a fixed value for now; the resource constraints in the
// module could be used instead once they are available.
static constexpr int64_t kBufferAlignment = 256;

// Static (x, y, z) workload of a dispatch.
using Vec3 = std::tuple<unsigned, unsigned, unsigned>;

// A single interface binding of a dispatch along with its required size.
struct Binding {
  unsigned set{0};      // descriptor set ordinal
  unsigned binding{0};  // binding ordinal within |set|
  int64_t size{0};      // minimum byte length required by the dispatch
};
// Everything known about the dispatches of one entry point that share a
// particular static workload.
struct DispatchParams {
  // Source locations of every dispatch site folded into this entry.
  SmallVector<Location> locs;
  // Static workload passed to the entry point's workgroup count calculation.
  Vec3 workload;
  // Per-binding minimum sizes as analyzed from the dispatch sites.
  SmallVector<Binding> bindings;
};

// Maps an entry point symbol to its unique static workloads; the MapVector
// keeps workloads in the order they were first encountered.
using DispatchParamsMap =
    llvm::DenseMap<SymbolRefAttr, llvm::MapVector<Vec3, DispatchParams>>;
// Walks |moduleOp| and gathers all of the dispatches to each executable entry
// point.
// Dispatch parameters are deduplicated by workload so that there's only ever
// one entry for all dispatches with a given workgroup count; when several
// dispatches share a workload the binding sizes of the last one visited win.
// Dispatches are ignored if they have a dynamic workload or any dynamically
// sized resources.
static DispatchParamsMap gatherDispatchParams(mlir::ModuleOp moduleOp) {
  DispatchParamsMap map;
  for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
    funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
      // NOTE: we could record operands here if we wanted real push constants.
      // TODO(benvanik): typed accessors for bindings.
      auto bindingAttrs = dispatchOp->getAttr("hal.interface.bindings")
                              .dyn_cast_or_null<ArrayAttr>();
      assert(bindingAttrs &&
             "interface materialization must annotate dispatch sites");

      // Only fully static workloads can be reproduced by a standalone
      // benchmark.
      auto workloadValues = dispatchOp.workgroup_count();
      APInt workloadX, workloadY, workloadZ;
      if (!matchPattern(workloadValues[0], m_ConstantInt(&workloadX)) ||
          !matchPattern(workloadValues[1], m_ConstantInt(&workloadY)) ||
          !matchPattern(workloadValues[2], m_ConstantInt(&workloadZ))) {
        // Non-constant workload; skip this dispatch.
        return;
      }
      Vec3 workload =
          std::make_tuple(workloadX.getSExtValue(), workloadY.getSExtValue(),
                          workloadZ.getSExtValue());

      // Pair each interface binding with the static length of the resource
      // bound to it.
      SmallVector<Binding> bindings;
      bindings.reserve(bindingAttrs.size());
      for (auto it : llvm::zip(bindingAttrs, dispatchOp.resource_lengths())) {
        auto bindingAttr =
            std::get<0>(it).cast<IREE::HAL::InterfaceBindingAttr>();
        APInt resourceLength;
        if (!matchPattern(std::get<1>(it), m_ConstantInt(&resourceLength))) {
          // Non-constant resource length; skip this dispatch.
          return;
        }
        bindings.push_back({(unsigned)bindingAttr.getSet(),
                            (unsigned)bindingAttr.getBinding(),
                            resourceLength.getSExtValue()});
      }

      auto &dispatchParamsSet = map[dispatchOp.entry_point()];
      auto &dispatchParams = dispatchParamsSet[workload];
      dispatchParams.locs.push_back(dispatchOp.getLoc());
      dispatchParams.workload = workload;
      // |bindings| is dead after this point; move rather than copy it.
      dispatchParams.bindings = std::move(bindings);
    });
  }
  return map;
}
// Creates a private, mutable util.global of !hal.buffer type named
// `<baseName>_buffer`, plus a util.initializer that allocates it with room for
// every binding in |dispatchParams|, each padded to kBufferAlignment.
// Returns the new global op.
static IREE::Util::GlobalOp appendGlobalBuffer(
    Location loc, StringRef baseName, const DispatchParams &dispatchParams,
    OpBuilder &moduleBuilder) {
  auto bufferGlobalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
      loc, (baseName + "_buffer").str(),
      /*isMutable=*/true,
      IREE::HAL::BufferType::get(moduleBuilder.getContext()));
  bufferGlobalOp.setPrivate();

  // Each binding begins at an aligned offset, so the allocation must cover
  // every binding size padded up to the alignment.
  int64_t allocationSize = 0;
  for (const Binding &entry : dispatchParams.bindings) {
    allocationSize += entry.size;
    allocationSize = IREE::Util::align(allocationSize, kBufferAlignment);
  }

  // Allocate the buffer from an initializer and store it into the global.
  auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
  OpBuilder initBuilder =
      OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
  IndexSet indexSet(loc, initBuilder);

  // TODO(benvanik): real device lookup.
  auto deviceValue = initBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
  Value allocatorValue =
      initBuilder.create<IREE::HAL::DeviceAllocatorOp>(loc, deviceValue)
          .result();
  auto allocateOp = initBuilder.create<IREE::HAL::AllocatorAllocateOp>(
      loc, bufferGlobalOp.type(), allocatorValue,
      IREE::HAL::MemoryTypeBitfield::DeviceLocal,
      IREE::HAL::BufferUsageBitfield::Dispatch, indexSet.get(allocationSize));
  initBuilder.create<IREE::Util::GlobalStoreOp>(loc, allocateOp.result(),
                                                bufferGlobalOp.getNameAttr());
  initBuilder.create<IREE::Util::InitializerReturnOp>(loc);
  return bufferGlobalOp;
}
// Appends a function calling the given |entryPointOp| with |dispatchParams|.
// This will add a global value for the resources required.
//
// Expects the runner to pass an i32 value indicating the number of dispatches
// to be made in one submission.
static void appendDispatchBenchmark(
IREE::HAL::ExecutableOp executableOp,
IREE::HAL::ExecutableVariantOp variantOp,
IREE::HAL::ExecutableEntryPointOp entryPointOp,
DispatchParams dispatchParams, OpBuilder &moduleBuilder) {
auto loc = FusedLoc::get(executableOp.getContext(), dispatchParams.locs);
std::string baseName =
(executableOp.getName() + "_" + variantOp.getName() + "_" +
entryPointOp.getName() + "_" +
std::to_string(std::get<0>(dispatchParams.workload)) + "x" +
std::to_string(std::get<1>(dispatchParams.workload)) + "x" +
std::to_string(std::get<2>(dispatchParams.workload)))
.str();
// Add a global variable holding an initialized buffer for the dispatch IO.
auto bufferGlobalOp =
appendGlobalBuffer(loc, baseName, dispatchParams, moduleBuilder);
// Create an exported benchmark function that runs the dispatches.
auto funcType =
moduleBuilder.getFunctionType({moduleBuilder.getI32Type()}, {});
auto funcOp = moduleBuilder.create<mlir::FuncOp>(loc, baseName, funcType);
funcOp.setVisibility(SymbolTable::Visibility::Public);
// Mark the function as being a dispatch benchmark.
// This tells iree-benchmark-module to pass in the arguments we need.
funcOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
funcOp->setAttr(
"iree.reflection",
moduleBuilder.getDictionaryAttr({
moduleBuilder.getNamedAttr("iree.benchmark",
moduleBuilder.getStringAttr("dispatch")),
}));
// Build the function that runs the dispatches.
auto *entryBlock = funcOp.addEntryBlock();
OpBuilder funcBuilder = OpBuilder::atBlockBegin(entryBlock);
IndexSet indexSet(loc, funcBuilder);
auto batchSizeArg = funcBuilder.create<arith::IndexCastOp>(
loc, funcBuilder.getIndexType(), entryBlock->getArgument(0));
// TODO(benvanik): real device lookup.
auto device = funcBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
// Create and begin command buffer.
// TODO(benvanik): reuse the command buffer (initialize once and store).
auto commandBufferModes =
IREE::HAL::CommandBufferModeBitfield::OneShot |
IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution;
auto commandBuffer =
funcBuilder
.create<IREE::HAL::CommandBufferCreateOp>(
loc, funcBuilder.getType<IREE::HAL::CommandBufferType>(), device,
commandBufferModes, IREE::HAL::CommandCategoryBitfield::Dispatch)
.result();
funcBuilder.create<IREE::HAL::CommandBufferBeginOp>(loc, commandBuffer);
// Get the layout required to set up the dispatches.
auto layoutAttr = entryPointOp.layoutAttr();
auto executableLayout =
funcBuilder
.create<IREE::HAL::ExecutableLayoutLookupOp>(
loc, IREE::HAL::ExecutableLayoutType::get(loc.getContext()),
device, layoutAttr)
.result();
// Push constant values.
// TODO(benvanik): use push constants the program used? can help with
// specialization that may have been applied in the streams dialect.
if (int64_t pushConstantCount = layoutAttr.getPushConstants()) {
int pushConstantBase = 0; // always 0 today
SmallVector<Value> pushConstants(pushConstantCount,
funcBuilder.create<arith::ConstantIntOp>(
loc, 0, funcBuilder.getI32Type()));
funcBuilder.create<IREE::HAL::CommandBufferPushConstantsOp>(
loc, commandBuffer, executableLayout,
funcBuilder.getIndexAttr(pushConstantBase), pushConstants);
}
// Push descriptor sets.
auto buffer =
funcBuilder.create<IREE::Util::GlobalLoadOp>(loc, bufferGlobalOp)
.result();
int64_t currentSet = -1;
SmallVector<IREE::HAL::DescriptorSetBindingValue> bindingValues;
auto flushSet = [&]() {
funcBuilder.create<IREE::HAL::CommandBufferPushDescriptorSetOp>(
loc, commandBuffer, executableLayout, currentSet, bindingValues);
bindingValues.clear();
};
int64_t bufferOffset = 0;
for (auto binding : dispatchParams.bindings) {
if (currentSet != -1 && currentSet != binding.set) flushSet();
currentSet = binding.set;
IREE::HAL::DescriptorSetBindingValue bindingValue;
bindingValue.ordinal =
funcBuilder.create<arith::ConstantIndexOp>(loc, binding.binding);
bindingValue.buffer = buffer;
bindingValue.byteOffset = indexSet.get(bufferOffset);
bindingValue.byteLength = indexSet.get(binding.size);
bindingValues.push_back(bindingValue);
bufferOffset =
IREE::Util::align(bufferOffset + binding.size, kBufferAlignment);
}
if (currentSet != -1) flushSet();
// Compute the workgroup parameters.
auto workgroupCount = entryPointOp.calculateWorkgroupCount(
loc,
{
indexSet.get(std::get<0>(dispatchParams.workload)),
indexSet.get(std::get<1>(dispatchParams.workload)),
indexSet.get(std::get<2>(dispatchParams.workload)),
},
funcBuilder);
// Loop around dispatches based on batch size.
// Note that we insert a barrier between each dispatch - we could make this
// optional so that concurrent utilization is measured.
funcBuilder.create<scf::ForOp>(
loc, indexSet.get(0), batchSizeArg, indexSet.get(1), ValueRange{},
[&](OpBuilder &forBuilder, Location loc, Value iv, ValueRange iters) {
// Dispatch.
auto symbolRefAttr = SymbolRefAttr::get(
executableOp.getNameAttr(),
{
SymbolRefAttr::get(variantOp.getNameAttr()),
SymbolRefAttr::get(entryPointOp.getNameAttr()),
});
forBuilder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
loc, commandBuffer, symbolRefAttr, workgroupCount[0],
workgroupCount[1], workgroupCount[2]);
// Barrier following the dispatch to block the next dispatch.
auto sourceStage = IREE::HAL::ExecutionStageBitfield::CommandRetire |
IREE::HAL::ExecutionStageBitfield::Dispatch;
auto targetStage = IREE::HAL::ExecutionStageBitfield::CommandIssue |
IREE::HAL::ExecutionStageBitfield::Dispatch;
auto barrierFlags = IREE::HAL::ExecutionBarrierFlagBitfield::None;
forBuilder.create<IREE::HAL::CommandBufferExecutionBarrierOp>(
loc, commandBuffer, sourceStage, targetStage, barrierFlags);
forBuilder.create<scf::YieldOp>(loc);
});
// Submit command buffer.
funcBuilder.create<IREE::HAL::CommandBufferEndOp>(loc, commandBuffer);
funcBuilder.create<IREE::HAL::ExSubmitAndWaitOp>(loc, device, commandBuffer);
funcBuilder.create<mlir::func::ReturnOp>(loc);
}
// Builds a module exporting one benchmark function for each dispatch
// configuration of the entry points in |sourceVariantOp| of
// |sourceExecutableOp|. Returns a null module when no benchmarks could be
// generated or when cleaning up the generated IR fails.
static mlir::OwningOpRef<mlir::ModuleOp> buildBenchmarkModule(
    IREE::HAL::ExecutableOp sourceExecutableOp,
    IREE::HAL::ExecutableVariantOp sourceVariantOp,
    const DispatchParamsMap &dispatchParamsMap) {
  // Empty module with default name.
  // We could use the original module name here to make tracking nicer.
  mlir::OwningOpRef<mlir::ModuleOp> moduleOp =
      mlir::ModuleOp::create(sourceExecutableOp.getLoc());
  auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp->getBody());

  // Copy over the device targets from the original module. When the source
  // module has none, skip this rather than setting a null attribute.
  // TODO(benvanik): filter this by the target of the variant.
  if (auto parentModuleOp =
          sourceExecutableOp->getParentOfType<mlir::ModuleOp>()) {
    if (auto deviceTargetsAttr =
            parentModuleOp->getAttr("hal.device.targets")) {
      moduleOp->getOperation()->setAttr("hal.device.targets",
                                        deviceTargetsAttr);
    }
  }

  // Clone the executable variant into the new module.
  auto executableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
      sourceExecutableOp.getLoc(), sourceExecutableOp.getName());
  executableOp.setVisibility(sourceExecutableOp.getVisibility());
  auto variantOp = cast<IREE::HAL::ExecutableVariantOp>(
      OpBuilder::atBlockBegin(executableOp.getBody())
          .clone(*sourceVariantOp.getOperation()));

  // Add functions to test each entry point with its various dispatch
  // parameters.
  bool hasAnyBenchmarks = false;
  for (auto entryPointOp :
       variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
    // Look up by @executable::@entry_point. This presumably matches the
    // dispatch sites' entry_point symbols used as keys by gatherDispatchParams.
    auto symbolRefAttr = SymbolRefAttr::get(
        executableOp.getNameAttr(),
        {FlatSymbolRefAttr::get(entryPointOp.getNameAttr())});
    auto dispatchParamsSet = dispatchParamsMap.find(symbolRefAttr);
    if (dispatchParamsSet == dispatchParamsMap.end()) continue;
    // Bind by reference: each entry holds SmallVectors that are costly to
    // copy.
    for (const auto &workloadAndParams : dispatchParamsSet->second) {
      appendDispatchBenchmark(executableOp, variantOp, entryPointOp,
                              workloadAndParams.second, moduleBuilder);
      hasAnyBenchmarks = true;
    }
  }

  // Skip the file when we could not generate any benchmarks.
  if (!hasAnyBenchmarks) return {};

  // Run CSE and the canonicalizer to pretty up the output.
  PassManager passManager(moduleOp->getContext());
  passManager.addPass(mlir::createCanonicalizerPass());
  passManager.addPass(mlir::createCSEPass());
  if (failed(passManager.run(*moduleOp))) {
    moduleOp->emitError("failed to run canonicalizer; malformed output");
    return {};
  }
  return moduleOp;
}
// Prints |moduleOp| to |os| and terminates the output with a newline so it is
// a well-formed text file. |fileName| is not currently used by the printer.
static void dumpModuleToStream(mlir::ModuleOp moduleOp, StringRef fileName,
                               llvm::raw_ostream &os) {
  // Local scope keeps the printed IR readable; global scope is possible but
  // makes the output noisy very quickly.
  moduleOp.print(os, OpPrintingFlags().useLocalScope());
  os << '\n';
}
// Pass that writes one standalone benchmark module per hal.executable variant.
// Each module has an exported function for every statically-known dispatch
// configuration found at stream.cmd.dispatch sites. Output goes to a
// directory given by the `path` option, or to stdout when the option is empty
// or "-".
class DumpExecutableBenchmarksPass
    : public PassWrapper<DumpExecutableBenchmarksPass,
                         OperationPass<ModuleOp>> {
 public:
  DumpExecutableBenchmarksPass() = default;
  // NOTE(review): |path| is not copied here; presumably the pass
  // infrastructure copies option values when cloning — confirm.
  DumpExecutableBenchmarksPass(const DumpExecutableBenchmarksPass &pass) {}
  DumpExecutableBenchmarksPass(StringRef path) { this->path = path.str(); }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::HAL::HALDialect>();
    registry.insert<arith::ArithmeticDialect>();
    registry.insert<scf::SCFDialect>();
  }

  StringRef getArgument() const override {
    return "iree-hal-dump-executable-benchmarks";
  }

  StringRef getDescription() const override {
    return "Dumps standalone hal.executable benchmarks to a path.";
  }

  void runOnOperation() override {
    auto moduleOp = getOperation();
    auto moduleName = moduleOp.getName().getValueOr("module");

    // Analyze the module to find dispatch parameters.
    // This is a full walk of all stream.cmd.dispatch ops and will handle
    // filtering out dispatches that have dynamic parameters we don't
    // currently support.
    auto dispatchParamsMap = gatherDispatchParams(moduleOp);
    if (dispatchParamsMap.empty()) return;

    // An empty path or "-" sends every module to stdout.
    bool dumpToStdout = path.empty() || path == "-";

    // Help people out and mkdir if needed. Report a failure here instead of
    // letting it surface later as a less specific file-open error.
    if (!dumpToStdout) {
      if (auto ec = llvm::sys::fs::create_directories(path)) {
        std::string message = ec.message();
        moduleOp.emitError() << "failed to create benchmark directory "
                             << path << ": " << message;
        return signalPassFailure();
      }
    }

    // Produce one file per executable variant containing all entry points.
    for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
      for (auto variantOp :
           executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
        auto benchmarkModuleOp =
            buildBenchmarkModule(executableOp, variantOp, dispatchParamsMap);
        if (!benchmarkModuleOp) continue;
        auto fileName = (moduleName + "_" + executableOp.getName() + "_" +
                         variantOp.getName() + ".mlir")
                            .str();
        if (dumpToStdout) {
          dumpModuleToStream(*benchmarkModuleOp, fileName, llvm::outs());
        } else {
          auto filePath =
              (path + llvm::sys::path::get_separator() + fileName).str();
          std::string error;
          auto file = mlir::openOutputFile(filePath, &error);
          if (!file) {
            // Name the exact file so the failure is actionable.
            executableOp.emitError()
                << "while dumping to " << filePath << ": " << error;
            return signalPassFailure();
          }
          dumpModuleToStream(*benchmarkModuleOp, fileName, file->os());
          file->keep();
        }
      }
    }
  }

 private:
  Option<std::string> path{
      *this, "path",
      llvm::cl::desc("Path to write hal.executable benchmarks into.")};
};
// Creates the benchmark-dumping pass. The pass writes one module per
// executable variant into |path|, or to stdout when |path| is empty or "-".
std::unique_ptr<OperationPass<ModuleOp>> createDumpExecutableBenchmarksPass(
    StringRef path) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<DumpExecutableBenchmarksPass>(path);
  return pass;
}
// Registers the pass with the global pass registry under its getArgument()
// name. Registered instances are default-constructed.
static PassRegistration<DumpExecutableBenchmarksPass> pass([]() {
  auto defaultPass = std::make_unique<DumpExecutableBenchmarksPass>();
  return defaultPass;
});
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir