Adding redundant command buffer barrier elision. (#7600)
Super simple since we only generate barriers in one way today.
diff --git a/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp b/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp
index 92c24e9..4cbf737 100644
--- a/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp
@@ -51,6 +51,14 @@
SmallVector<Value, 32> pushConstants;
SmallVector<DescriptorSetState, 4> descriptorSets;
+ // Set after we know a full barrier has been issued; any subsequent barrier
+ // until a real operation is redundant. We could track more fine-grained state
+ // here such as which stages are being waited on.
+ // Note that we assume no barriers by default, as the command buffer may have
+ // been passed as a function/branch argument and we don't have visibility.
+ // We need to use IPO to track that.
+ IREE::HAL::CommandBufferExecutionBarrierOp previousFullBarrier;
+
Value &getPushConstant(int64_t index) {
if (index >= pushConstants.size()) {
pushConstants.resize(index + 1);
@@ -76,6 +84,31 @@
} // namespace
+static void processOp(IREE::HAL::CommandBufferExecutionBarrierOp op,
+ CommandBufferState &state) {
+ if (state.previousFullBarrier) {
+ // We are following a full barrier - this is a no-op (issuing two barriers
+ // doesn't make the device barrier any harder).
+ op.erase();
+ return;
+ }
+
+ // See if this is a full barrier. These are all we emit today so this simple
+ // analysis can remain simple by pattern matching.
+ if (bitEnumContains(op.source_stage_mask(),
+ IREE::HAL::ExecutionStageBitfield::CommandRetire |
+ IREE::HAL::ExecutionStageBitfield::Transfer |
+ IREE::HAL::ExecutionStageBitfield::Dispatch) &&
+ bitEnumContains(op.target_stage_mask(),
+ IREE::HAL::ExecutionStageBitfield::CommandRetire |
+ IREE::HAL::ExecutionStageBitfield::Transfer |
+ IREE::HAL::ExecutionStageBitfield::Dispatch)) {
+ state.previousFullBarrier = op;
+ } else {
+ state.previousFullBarrier = {};
+ }
+}
+
static LogicalResult processOp(IREE::HAL::CommandBufferPushConstantsOp op,
CommandBufferState &state) {
// Today we only eat constants from the beginning or end of the range
@@ -200,51 +233,64 @@
auto invalidateState = [&](Value commandBuffer) {
stateMap[commandBuffer] = {};
};
+ auto resetCommandBufferBarrierBit = [&](Operation *op) {
+ assert(op->getNumOperands() > 0 && "must be a command buffer op");
+ auto commandBuffer = op->getOperand(0);
+ assert(commandBuffer.getType().isa<IREE::HAL::CommandBufferType>() &&
+ "operand 0 must be a command buffer");
+ stateMap[commandBuffer].previousFullBarrier = {};
+ };
for (auto &op : llvm::make_early_inc_range(block.getOperations())) {
- if (op.getDialect())
- TypeSwitch<Operation *>(&op)
- .Case([&](IREE::HAL::CommandBufferBeginOp op) {
+ if (!op.getDialect()) continue;
+ TypeSwitch<Operation *>(&op)
+ .Case([&](IREE::HAL::CommandBufferBeginOp op) {
+ invalidateState(op.command_buffer());
+ })
+ .Case([&](IREE::HAL::CommandBufferEndOp op) {
+ invalidateState(op.command_buffer());
+ })
+ .Case([&](IREE::HAL::CommandBufferExecutionBarrierOp op) {
+ processOp(op, stateMap[op.command_buffer()]);
+ })
+ .Case([&](IREE::HAL::CommandBufferPushConstantsOp op) {
+ resetCommandBufferBarrierBit(op);
+ if (failed(processOp(op, stateMap[op.command_buffer()]))) {
invalidateState(op.command_buffer());
- })
- .Case([&](IREE::HAL::CommandBufferEndOp op) {
+ }
+ })
+ .Case([&](IREE::HAL::CommandBufferPushDescriptorSetOp op) {
+ resetCommandBufferBarrierBit(op);
+ if (failed(processOp(op, stateMap[op.command_buffer()]))) {
invalidateState(op.command_buffer());
- })
- .Case([&](IREE::HAL::CommandBufferPushConstantsOp op) {
- if (failed(processOp(op, stateMap[op.command_buffer()]))) {
- invalidateState(op.command_buffer());
- }
- })
- .Case([&](IREE::HAL::CommandBufferPushDescriptorSetOp op) {
- if (failed(processOp(op, stateMap[op.command_buffer()]))) {
- invalidateState(op.command_buffer());
- }
- })
- .Case([&](IREE::HAL::CommandBufferBindDescriptorSetOp op) {
- if (failed(processOp(op, stateMap[op.command_buffer()]))) {
- invalidateState(op.command_buffer());
- }
- })
- .Case<IREE::HAL::CommandBufferDeviceOp,
- IREE::HAL::CommandBufferBeginDebugGroupOp,
- IREE::HAL::CommandBufferEndDebugGroupOp,
- IREE::HAL::CommandBufferExecutionBarrierOp,
- IREE::HAL::CommandBufferFillBufferOp,
- IREE::HAL::CommandBufferCopyBufferOp,
- IREE::HAL::CommandBufferDispatchSymbolOp,
- IREE::HAL::CommandBufferDispatchOp,
- IREE::HAL::CommandBufferDispatchIndirectSymbolOp,
- IREE::HAL::CommandBufferDispatchIndirectOp>(
- [&](Operation *op) {
- // Ok - don't impact state.
- })
- .Default([&](Operation *op) {
- // Unknown op - discard state cache.
- // This is to avoid correctness issues with region ops (like
- // scf.if) that we don't analyze properly here. We could
- // restrict this a bit by only discarding on use of the
- // command buffer.
- stateMap.clear();
- });
+ }
+ })
+ .Case([&](IREE::HAL::CommandBufferBindDescriptorSetOp op) {
+ resetCommandBufferBarrierBit(op);
+ if (failed(processOp(op, stateMap[op.command_buffer()]))) {
+ invalidateState(op.command_buffer());
+ }
+ })
+ .Case<IREE::HAL::CommandBufferDeviceOp,
+ IREE::HAL::CommandBufferBeginDebugGroupOp,
+ IREE::HAL::CommandBufferEndDebugGroupOp,
+ IREE::HAL::CommandBufferFillBufferOp,
+ IREE::HAL::CommandBufferCopyBufferOp,
+ IREE::HAL::CommandBufferDispatchSymbolOp,
+ IREE::HAL::CommandBufferDispatchOp,
+ IREE::HAL::CommandBufferDispatchIndirectSymbolOp,
+ IREE::HAL::CommandBufferDispatchIndirectOp>(
+ [&](Operation *op) {
+ // Ok - don't impact state.
+ resetCommandBufferBarrierBit(op);
+ })
+ .Default([&](Operation *op) {
+ // Unknown op - discard state cache.
+ // This is to avoid correctness issues with region ops (like
+ // scf.if) that we don't analyze properly here. We could
+ // restrict this a bit by only discarding on use of the
+ // command buffer.
+ stateMap.clear();
+ });
}
}
}
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/elide_redundant_commands.mlir b/iree/compiler/Dialect/HAL/Transforms/test/elide_redundant_commands.mlir
index 16b532a..a80ebfd 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/elide_redundant_commands.mlir
+++ b/iree/compiler/Dialect/HAL/Transforms/test/elide_redundant_commands.mlir
@@ -1,5 +1,27 @@
// RUN: iree-opt -split-input-file -pass-pipeline='builtin.func(iree-hal-elide-redundant-commands)' %s | IreeFileCheck %s
+// Tests that redundant barriers are elided but barriers gaurding ops are not.
+
+// CHECK-LABEL: @elideRedundantBarriers
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, %[[LAYOUT:.+]]: !hal.executable_layout)
+func @elideRedundantBarriers(%cmd: !hal.command_buffer, %executable_layout: !hal.executable_layout) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c42_i32 = arith.constant 42 : i32
+ // CHECK: hal.command_buffer.execution_barrier
+ hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|Transfer|CommandRetire") target("CommandIssue|Dispatch|Transfer") flags("None")
+ // CHECK-NOT: hal.command_buffer.execution_barrier
+ hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|Transfer|CommandRetire") target("CommandIssue|Dispatch|Transfer") flags("None")
+ // CHECK: hal.command_buffer.push_constants
+ hal.command_buffer.push_constants<%cmd : !hal.command_buffer> layout(%executable_layout : !hal.executable_layout) offset(0) values([%c42_i32]) : i32
+ // CHECK: hal.command_buffer.execution_barrier
+ hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|Transfer|CommandRetire") target("CommandIssue|Dispatch|Transfer") flags("None")
+ // CHECK: return
+ return
+}
+
+// -----
+
// CHECK-LABEL: @elidePushConstants
func @elidePushConstants(%cmd: !hal.command_buffer, %executable_layout: !hal.executable_layout) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0