Adding redundant command buffer barrier elision. (#7600)

Super simple since we only generate barriers in one way today.
diff --git a/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp b/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp
index 92c24e9..4cbf737 100644
--- a/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp
@@ -51,6 +51,14 @@
   SmallVector<Value, 32> pushConstants;
   SmallVector<DescriptorSetState, 4> descriptorSets;
 
+  // Set after we know a full barrier has been issued; any subsequent barrier
+  // until a real operation is redundant. We could track more fine-grained state
+  // here such as which stages are being waited on.
+  // Note that we assume no barriers by default, as the command buffer may have
+  // been passed as a function/branch argument and we don't have visibility.
+  // We need to use IPO to track that.
+  IREE::HAL::CommandBufferExecutionBarrierOp previousFullBarrier;
+
   Value &getPushConstant(int64_t index) {
     if (index >= pushConstants.size()) {
       pushConstants.resize(index + 1);
@@ -76,6 +84,31 @@
 
 }  // namespace
 
+static void processOp(IREE::HAL::CommandBufferExecutionBarrierOp op,
+                      CommandBufferState &state) {
+  if (state.previousFullBarrier) {
+    // We are following a full barrier - this is a no-op (issuing two barriers
+    // doesn't make the device barrier any harder).
+    op.erase();
+    return;
+  }
+
+  // See if this is a full barrier. These are all we emit today so this simple
+  // analysis can remain simple by pattern matching.
+  if (bitEnumContains(op.source_stage_mask(),
+                      IREE::HAL::ExecutionStageBitfield::CommandRetire |
+                          IREE::HAL::ExecutionStageBitfield::Transfer |
+                          IREE::HAL::ExecutionStageBitfield::Dispatch) &&
+      bitEnumContains(op.target_stage_mask(),
+                      IREE::HAL::ExecutionStageBitfield::CommandRetire |
+                          IREE::HAL::ExecutionStageBitfield::Transfer |
+                          IREE::HAL::ExecutionStageBitfield::Dispatch)) {
+    state.previousFullBarrier = op;
+  } else {
+    state.previousFullBarrier = {};
+  }
+}
+
 static LogicalResult processOp(IREE::HAL::CommandBufferPushConstantsOp op,
                                CommandBufferState &state) {
   // Today we only eat constants from the beginning or end of the range
@@ -200,51 +233,64 @@
         auto invalidateState = [&](Value commandBuffer) {
           stateMap[commandBuffer] = {};
         };
+        auto resetCommandBufferBarrierBit = [&](Operation *op) {
+          assert(op->getNumOperands() > 0 && "must be a command buffer op");
+          auto commandBuffer = op->getOperand(0);
+          assert(commandBuffer.getType().isa<IREE::HAL::CommandBufferType>() &&
+                 "operand 0 must be a command buffer");
+          stateMap[commandBuffer].previousFullBarrier = {};
+        };
         for (auto &op : llvm::make_early_inc_range(block.getOperations())) {
-          if (op.getDialect())
-            TypeSwitch<Operation *>(&op)
-                .Case([&](IREE::HAL::CommandBufferBeginOp op) {
+          if (!op.getDialect()) continue;
+          TypeSwitch<Operation *>(&op)
+              .Case([&](IREE::HAL::CommandBufferBeginOp op) {
+                invalidateState(op.command_buffer());
+              })
+              .Case([&](IREE::HAL::CommandBufferEndOp op) {
+                invalidateState(op.command_buffer());
+              })
+              .Case([&](IREE::HAL::CommandBufferExecutionBarrierOp op) {
+                processOp(op, stateMap[op.command_buffer()]);
+              })
+              .Case([&](IREE::HAL::CommandBufferPushConstantsOp op) {
+                resetCommandBufferBarrierBit(op);
+                if (failed(processOp(op, stateMap[op.command_buffer()]))) {
                   invalidateState(op.command_buffer());
-                })
-                .Case([&](IREE::HAL::CommandBufferEndOp op) {
+                }
+              })
+              .Case([&](IREE::HAL::CommandBufferPushDescriptorSetOp op) {
+                resetCommandBufferBarrierBit(op);
+                if (failed(processOp(op, stateMap[op.command_buffer()]))) {
                   invalidateState(op.command_buffer());
-                })
-                .Case([&](IREE::HAL::CommandBufferPushConstantsOp op) {
-                  if (failed(processOp(op, stateMap[op.command_buffer()]))) {
-                    invalidateState(op.command_buffer());
-                  }
-                })
-                .Case([&](IREE::HAL::CommandBufferPushDescriptorSetOp op) {
-                  if (failed(processOp(op, stateMap[op.command_buffer()]))) {
-                    invalidateState(op.command_buffer());
-                  }
-                })
-                .Case([&](IREE::HAL::CommandBufferBindDescriptorSetOp op) {
-                  if (failed(processOp(op, stateMap[op.command_buffer()]))) {
-                    invalidateState(op.command_buffer());
-                  }
-                })
-                .Case<IREE::HAL::CommandBufferDeviceOp,
-                      IREE::HAL::CommandBufferBeginDebugGroupOp,
-                      IREE::HAL::CommandBufferEndDebugGroupOp,
-                      IREE::HAL::CommandBufferExecutionBarrierOp,
-                      IREE::HAL::CommandBufferFillBufferOp,
-                      IREE::HAL::CommandBufferCopyBufferOp,
-                      IREE::HAL::CommandBufferDispatchSymbolOp,
-                      IREE::HAL::CommandBufferDispatchOp,
-                      IREE::HAL::CommandBufferDispatchIndirectSymbolOp,
-                      IREE::HAL::CommandBufferDispatchIndirectOp>(
-                    [&](Operation *op) {
-                      // Ok - don't impact state.
-                    })
-                .Default([&](Operation *op) {
-                  // Unknown op - discard state cache.
-                  // This is to avoid correctness issues with region ops (like
-                  // scf.if) that we don't analyze properly here. We could
-                  // restrict this a bit by only discarding on use of the
-                  // command buffer.
-                  stateMap.clear();
-                });
+                }
+              })
+              .Case([&](IREE::HAL::CommandBufferBindDescriptorSetOp op) {
+                resetCommandBufferBarrierBit(op);
+                if (failed(processOp(op, stateMap[op.command_buffer()]))) {
+                  invalidateState(op.command_buffer());
+                }
+              })
+              .Case<IREE::HAL::CommandBufferDeviceOp,
+                    IREE::HAL::CommandBufferBeginDebugGroupOp,
+                    IREE::HAL::CommandBufferEndDebugGroupOp,
+                    IREE::HAL::CommandBufferFillBufferOp,
+                    IREE::HAL::CommandBufferCopyBufferOp,
+                    IREE::HAL::CommandBufferDispatchSymbolOp,
+                    IREE::HAL::CommandBufferDispatchOp,
+                    IREE::HAL::CommandBufferDispatchIndirectSymbolOp,
+                    IREE::HAL::CommandBufferDispatchIndirectOp>(
+                  [&](Operation *op) {
+                    // Ok - don't impact state.
+                    resetCommandBufferBarrierBit(op);
+                  })
+              .Default([&](Operation *op) {
+                // Unknown op - discard state cache.
+                // This is to avoid correctness issues with region ops (like
+                // scf.if) that we don't analyze properly here. We could
+                // restrict this a bit by only discarding on use of the
+                // command buffer.
+                stateMap.clear();
+              });
         }
       }
     }
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/elide_redundant_commands.mlir b/iree/compiler/Dialect/HAL/Transforms/test/elide_redundant_commands.mlir
index 16b532a..a80ebfd 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/elide_redundant_commands.mlir
+++ b/iree/compiler/Dialect/HAL/Transforms/test/elide_redundant_commands.mlir
@@ -1,5 +1,27 @@
 // RUN: iree-opt -split-input-file -pass-pipeline='builtin.func(iree-hal-elide-redundant-commands)' %s | IreeFileCheck %s
 
+// Tests that redundant barriers are elided but barriers gaurding ops are not.
+
+// CHECK-LABEL: @elideRedundantBarriers
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, %[[LAYOUT:.+]]: !hal.executable_layout)
+func @elideRedundantBarriers(%cmd: !hal.command_buffer, %executable_layout: !hal.executable_layout) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c42_i32 = arith.constant 42 : i32
+  // CHECK: hal.command_buffer.execution_barrier
+  hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|Transfer|CommandRetire") target("CommandIssue|Dispatch|Transfer") flags("None")
+  // CHECK-NOT: hal.command_buffer.execution_barrier
+  hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|Transfer|CommandRetire") target("CommandIssue|Dispatch|Transfer") flags("None")
+  // CHECK: hal.command_buffer.push_constants
+  hal.command_buffer.push_constants<%cmd : !hal.command_buffer> layout(%executable_layout : !hal.executable_layout) offset(0) values([%c42_i32]) : i32
+  // CHECK: hal.command_buffer.execution_barrier
+  hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|Transfer|CommandRetire") target("CommandIssue|Dispatch|Transfer") flags("None")
+  // CHECK: return
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @elidePushConstants
 func @elidePushConstants(%cmd: !hal.command_buffer, %executable_layout: !hal.executable_layout) {
   // CHECK-DAG: %[[C0:.+]] = arith.constant 0