Adding a flag to force indirect command buffers on in non-reusable cases. (#18945)
Includes various fixes found during testing.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 05ab667..308c880 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -160,6 +160,17 @@
}
};
+// Returns the "abi" configuration string, or an empty string if unspecified.
+static StringRef getABI(IREE::HAL::ExecutableTargetAttr targetAttr) {
+  if (targetAttr) {
+    if (auto config = targetAttr.getConfiguration()) {
+      auto abiAttr = config.getAs<StringAttr>("abi");
+      return abiAttr ? abiAttr.getValue() : "";
+    }
+  }
+  return "";
+}
+
static void dumpModuleToPath(StringRef path, StringRef baseName,
StringRef suffix, StringRef extension,
llvm::Module &module) {
@@ -585,8 +596,7 @@
// Wrap the HSACO ELF binary in a Flatbuffers container.
FailureOr<DenseIntElementsAttr> binaryContainer;
- if (targetAttr.getConfiguration() &&
- targetAttr.getConfiguration().getAs<StringAttr>("abi") == "amdgpu") {
+ if (getABI(targetAttr) == "amdgpu") {
binaryContainer = serializeAMDGPUBinaryContainer(
serializationOptions, variantOp, exportOps, targetHSACO);
} else {
diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
index 7f83a5e..d12dbc5 100644
--- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
+++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
@@ -1100,8 +1100,12 @@
return new Error("not a valid HAL executable");
}
auto binaryOp = binaryOps.front();
- auto rawData = binaryOp.getData().getRawData();
- output.outputStream->write(rawData.data(), rawData.size());
+ if (failed(cast<IREE::Util::SerializableAttrInterface>(binaryOp.getData())
+ .serializeToStream(binaryOp.getLoc(), llvm::endianness::little,
+ *output.outputStream))) {
+ return new Error(
+ "data attribute failed to serialize: unsupported format or encoding");
+ }
output.outputStream->flush();
return output.getWriteError();
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 4ca2fdb..491bd87 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -32,6 +32,15 @@
llvm::cl::init(true),
};
+// Opt-in override (default off): when set, the check for dynamic uniform values
+// captured by an execute region is skipped and treated as false.
+// TODO(benvanik): remove when we support capturing dynamic values for reuse.
+static llvm::cl::opt<bool> clForceIndirectCommandBuffers{
+ "iree-hal-force-indirect-command-buffers",
+ llvm::cl::desc("Forces indirect command buffers when they would otherwise "
+ "not be chosen due to the values they capture. They may not "
+ "be reusable but will still be outlined."),
+ llvm::cl::init(false),
+};
+
struct ContextResolveOpPattern
: public StreamConversionPattern<IREE::Stream::ContextResolveOp> {
using StreamConversionPattern::StreamConversionPattern;
@@ -1002,7 +1011,9 @@
// changes dispatches to use them for any dispatch we can - note that there
// may still be some that slip through due to custom executables.
const bool capturesDynamicUniformValues =
- regionCapturesDynamicUniformValues(executeOp);
+ clForceIndirectCommandBuffers
+ ? false
+ : regionCapturesDynamicUniformValues(executeOp);
// Calculate the indirect buffer references used within the command buffer
// by analyzing captured resources. This analysis will be used by subsequent
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
index 2b2f23c..3f1e811 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -150,8 +150,6 @@
def HAL_OrdinalAttr : Util_IndexAttrBase<"size_t">;
def HAL_OrdinalArrayAttr : TypedArrayAttrBase<HAL_OrdinalAttr, "Array of index ordinal attributes">;
-def HAL_ExecutableDataAttr : SignlessIntElementsAttr<8>;
-
def HAL_ElementType : TypeAlias<I32>;
def HAL_ElementTypeAttr : SignlessIntegerAttrBase<
I32, "element type attribute">;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index e3f05fc..c1d9a4e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -2593,7 +2593,7 @@
OptionalAttr<StrAttr>:$sym_visibility,
SymbolNameAttr:$sym_name,
StrAttr:$format,
- HAL_ExecutableDataAttr:$data,
+ Util_AnySerializableAttr:$data,
OptionalAttr<StrAttr>:$mime_type
// TODO(benvanik): add compatibility and versioning attributes.
);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
index 3044786..19b8d49 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Utils/StringUtils.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,6 +25,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
+#define DEBUG_TYPE "iree-hal-outline-memoize-regions"
+
namespace mlir::iree_compiler::IREE::HAL {
#define GEN_PASS_DEF_OUTLINEMEMOIZEREGIONSPASS
@@ -153,6 +156,8 @@
name, funcType);
moduleSymbolTable.insert(funcOp);
funcOp.setVisibility(SymbolTable::Visibility::Private);
+ funcOp.setInliningPolicyAttr(
+ moduleBuilder.getAttr<IREE::Util::InlineNeverAttr>());
auto funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock());
// Remap any captured operands that have corresponding function arguments.
@@ -521,8 +526,11 @@
// If we can't memoize the resources at initialization time then we need
// to do it on-demand.
if (!memoizeAnalysis.canRunAtInitializationTime()) {
- memoizeOp.emitWarning(
- "memoization failed: dynamic values captured at the call site");
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "memoization failed: dynamic values captured at the call site\n";
+ memoizeOp.dump();
+ });
replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp);
return;
}
@@ -532,8 +540,11 @@
auto deviceGlobals =
deviceAnalysis.lookupDeviceGlobals(memoizeOp.getDevice());
if (!deviceGlobals) {
- memoizeOp.emitWarning("memoization failed: unable to analyze devices "
- "that may be used with memoized region");
+ LLVM_DEBUG({
+ llvm::dbgs() << "memoization failed: unable to analyze devices that may "
+ "be used with memoized region\n";
+ memoizeOp.dump();
+ });
replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp);
return;
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
index 9690bca..a71011b 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
@@ -84,7 +84,6 @@
// Note that we want to remove entire chains of unused calls and run this
// as a pattern application.
RewritePatternSet patterns(&getContext());
- // patterns
patterns.insert<EraseUnusedCallOp<IREE::VM::CallOp>,
EraseUnusedCallOp<IREE::VM::CallVariadicOp>>(
&getContext(), noSideEffectsSymbols);
diff --git a/tests/compiler_driver/streams.mlir b/tests/compiler_driver/streams.mlir
index 03ebbc3..9e30a2e 100644
--- a/tests/compiler_driver/streams.mlir
+++ b/tests/compiler_driver/streams.mlir
@@ -51,10 +51,10 @@
}
}
}
-// CHECK: vm.func private @simple_mul
+// CHECK: vm.func private @__simple_mul_memoize_apply
+// CHECK: vm.call.variadic @hal.command_buffer.dispatch
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%c4 = arith.constant 4 : index
- // CHECK: vm.call.variadic @hal.command_buffer.dispatch
%ret0 = flow.dispatch @executable_0::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %ret0 : tensor<4xf32>
}
@@ -98,10 +98,10 @@
}
}
}
-// CHECK: vm.func private @simple_mul_inplace
+// CHECK: vm.func private @__simple_mul_inplace_memoize_apply
+// CHECK: vm.call.variadic @hal.command_buffer.dispatch
func.func @simple_mul_inplace(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%c4 = arith.constant 4 : index
- // CHECK: vm.call.variadic @hal.command_buffer.dispatch
%ret0 = flow.dispatch @executable_1::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> %arg0
return %ret0 : tensor<4xf32>
}