Disabling fusion of mutable resource bindings. (#7716)

Inducing aliasing of mutable bindings breaks autovectorization in LLVM
and may cause issues in other backends as well. This workaround disables
fusion of any mutable resources so that they still look unique to the
compiler.

This is not ideal: it results in many more bindings being tracked than
would otherwise be needed. Hopefully once we stop using the
autovectorizer we can re-enable fusion of mutable bindings. At least
with this change we have a way (the `alias-mutable-bindings` pass
option) to compare the performance on various backends while we try to
get them working.
diff --git a/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp b/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp
index 90e057a..5359ee7 100644
--- a/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp
@@ -75,7 +75,8 @@
 // points if we had minor divergence in order to gain more fusion in the common
 // cases.
 static SmallVector<Binding> findCorrelatedBindings(
-    unsigned bindingCount, ArrayRef<IREE::Stream::CmdDispatchOp> dispatchOps) {
+    unsigned bindingCount, ArrayRef<IREE::Stream::CmdDispatchOp> dispatchOps,
+    bool aliasMutableBindings) {
   // For each dispatch build equivalence classes indicating which bindings are
   // from the same base resource. Note that not all dispatches will have the
   // same duplicate bindings (though we hope they do!).
@@ -84,15 +85,32 @@
   for (auto dispatchOp : dispatchOps) {
     llvm::EquivalenceClasses<unsigned> ec;
     DenseMap<Value, unsigned> leaders;
-    for (auto resource : llvm::enumerate(dispatchOp.resources())) {
-      auto it = leaders.find(resource.value());
-      if (it == leaders.end()) {
+    for (auto it : llvm::enumerate(llvm::zip(dispatchOp.resources(),
+                                             dispatchOp.resource_accesses()))) {
+      auto resource = std::get<0>(it.value());
+
+      // If the resource is mutable and we were told not to alias mutable
+      // bindings we always put the resource into its own class.
+      auto resourceAccess =
+          std::get<1>(it.value())
+              .cast<IREE::Stream::ResourceAccessBitfieldAttr>();
+      if (!aliasMutableBindings &&
+          bitEnumContains(resourceAccess.getValue(),
+                          IREE::Stream::ResourceAccessBitfield::Write)) {
+        ec.insert(it.index());
+        leaders.insert(std::make_pair(resource, it.index()));
+        continue;
+      }
+
+      // Find or create a class for equivalent aliasable resource bindings.
+      auto ecIt = leaders.find(resource);
+      if (ecIt == leaders.end()) {
         // New unique value.
-        ec.insert(resource.index());
-        leaders.insert(std::make_pair(resource.value(), resource.index()));
+        ec.insert(it.index());
+        leaders.insert(std::make_pair(resource, it.index()));
       } else {
         // Found existing; union with leader.
-        ec.unionSets(it->second, resource.index());
+        ec.unionSets(ecIt->second, it.index());
       }
     }
     ecs.push_back(std::move(ec));
@@ -293,7 +311,7 @@
     IREE::Stream::ExecutableOp executableOp,
     IREE::Stream::ExecutableExportOp exportOp,
     ArrayRef<IREE::Stream::CmdDispatchOp> dispatchOps,
-    MemoizedCmdZeros &memoizedZeros) {
+    bool aliasMutableBindings, MemoizedCmdZeros &memoizedZeros) {
   if (dispatchOps.empty()) return;  // no-op if no dispatches
   auto anyDispatchOp = dispatchOps.front();
   unsigned bindingCount = anyDispatchOp.resources().size();
@@ -310,7 +328,8 @@
   });
 
   // Analysis to find which bindings we can fuse together based on dispatches.
-  auto bindings = findCorrelatedBindings(bindingCount, dispatchOps);
+  auto bindings =
+      findCorrelatedBindings(bindingCount, dispatchOps, aliasMutableBindings);
 
   // TODO(benvanik): canonicalize bindings and bail early here. Today this
   // rebasing will widen access modes and pass in the offset across the bindings
@@ -424,7 +443,7 @@
       for (auto exportOp :
            executableOp.getOps<IREE::Stream::ExecutableExportOp>()) {
         fuseDispatchBindings(executableOp, exportOp, entryDispatchMap[exportOp],
-                             memoizedZeros);
+                             aliasMutableBindings, memoizedZeros);
       }
     }
   }
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index 8425eef..fa4f63c 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -234,6 +234,9 @@
     passManager.addPass(IREE::Stream::createFoldUniformOperandsPass());
 
     // Only want to specialize after we've added all the operands we need above.
+    // TODO(benvanik): make codegen more efficient with the specialized
+    // constants. The lookup tables inserted are currently extremely slow on
+    // some backends.
     // passManager.addPass(IREE::Stream::createSpecializeDispatchesPass());
 
     // TODO(benvanik): canonicalize bindings: we should sort the bindings by
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.td b/iree/compiler/Dialect/Stream/Transforms/Passes.td
index dfc4c89..7aeddaa 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -166,6 +166,11 @@
   let constructor = [{
     mlir::iree_compiler::IREE::Stream::createFuseDispatchBindingsPass()
   }];
+  let options = [
+    Option<"aliasMutableBindings", "alias-mutable-bindings",
+           "bool", /*default=*/"false",
+           "Fuses bindings that are mutable instead of leaving them split.">
+  ];
 }
 
 def SpecializeDispatches :
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/BUILD b/iree/compiler/Dialect/Stream/Transforms/test/BUILD
index 787319c..59e9ff9 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Stream/Transforms/test/BUILD
@@ -23,6 +23,7 @@
             "encode_tensors.mlir",
             "fold_uniform_operands.mlir",
             "fuse_dispatch_bindings.mlir",
+            "fuse_dispatch_bindings_noalias.mlir",
             "layout_slices.mlir",
             "materialize_copy_on_write.mlir",
             "outline_constants.mlir",
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
index 354d5e5..416b77f 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -20,6 +20,7 @@
     "encode_tensors.mlir"
     "fold_uniform_operands.mlir"
     "fuse_dispatch_bindings.mlir"
+    "fuse_dispatch_bindings_noalias.mlir"
     "layout_slices.mlir"
     "materialize_copy_on_write.mlir"
     "outline_constants.mlir"
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir b/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir
index 609fe85..1ec4588 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir
+++ b/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-stream-fuse-dispatch-bindings %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='iree-stream-fuse-dispatch-bindings{alias-mutable-bindings=true}' %s | IreeFileCheck %s
 
 // Test that bindings that are unique are rebased to the widest possible access
 // and to have a 0 offset by passing in the actual offset as operands. The
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings_noalias.mlir b/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings_noalias.mlir
new file mode 100644
index 0000000..a07bee7
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/test/fuse_dispatch_bindings_noalias.mlir
@@ -0,0 +1,79 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='iree-stream-fuse-dispatch-bindings{alias-mutable-bindings=false}' %s | IreeFileCheck %s
+
+// TODO(benvanik): remove this file when aliasing mutable bindings is fixed.
+
+// Tests that bindings that are duplicated at all dispatch sites are folded
+// so long as they are not mutable.
+
+// CHECK-LABEL: @deduplicateBindingsEx
+stream.executable private @deduplicateBindingsEx {
+  stream.executable.export public @dispatch
+  builtin.module  {
+    // CHECK: func @dispatch(%[[BINDING_A:.+]]: !stream.binding, %[[BINDING_C:.+]]: !stream.binding,
+    // CHECK-SAME:           %[[OFFSET_A:.+]]: index, %[[OFFSET_B:.+]]: index, %[[OFFSET_C:.+]]: index, %[[OPERAND:.+]]: index)
+    func @dispatch(%binding_a: !stream.binding, %binding_b: !stream.binding, %binding_c: !stream.binding, %operand: index) {
+      %c0 = arith.constant 0 : index
+      %c20 = arith.constant 20 : index
+      %c40 = arith.constant 40 : index
+
+      // CHECK: %[[SUM_OFFSET_A:.+]] = arith.addi %c0, %[[OFFSET_A]]
+      // CHECK: %[[SUBSPAN_A:.+]] = stream.binding.subspan %[[BINDING_A]][%[[SUM_OFFSET_A]]]
+      %subspan_a = stream.binding.subspan %binding_a[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:20xi8>{%c20}
+      // CHECK-NEXT: util.do_not_optimize(%[[SUBSPAN_A]])
+      util.do_not_optimize(%subspan_a) : !flow.dispatch.tensor<readwrite:20xi8>
+
+      // CHECK: %[[SUM_OFFSET_B:.+]] = arith.addi %c20, %[[OFFSET_B]]
+      // CHECK-NEXT: %[[SUBSPAN_B:.+]] = stream.binding.subspan %[[BINDING_A]][%[[SUM_OFFSET_B]]]
+      %subspan_b = stream.binding.subspan %binding_b[%c20] : !stream.binding -> !flow.dispatch.tensor<readwrite:20xi8>{%c20}
+      // CHECK-NEXT: util.do_not_optimize(%[[SUBSPAN_B]])
+      util.do_not_optimize(%subspan_b) : !flow.dispatch.tensor<readwrite:20xi8>
+
+      // CHECK: %[[SUM_OFFSET_C:.+]] = arith.addi %c40, %[[OFFSET_C]]
+      // CHECK-NEXT: %[[SUBSPAN_C:.+]] = stream.binding.subspan %[[BINDING_C]][%[[SUM_OFFSET_C]]]
+      %subspan_c = stream.binding.subspan %binding_c[%c40] : !stream.binding -> !flow.dispatch.tensor<readwrite:20xi8>{%c20}
+      // CHECK-NEXT: util.do_not_optimize(%[[SUBSPAN_C]])
+      util.do_not_optimize(%subspan_c) : !flow.dispatch.tensor<readwrite:20xi8>
+
+      // CHECK-NEXT: util.do_not_optimize(%[[OPERAND]]) : index
+      util.do_not_optimize(%operand) : index
+      return
+    }
+  }
+}
+// CHECK: func @deduplicateBindings(%[[OPERAND:.+]]: index)
+func @deduplicateBindings(%operand: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c20 = arith.constant 20 : index
+  %c40 = arith.constant 40 : index
+  %c80 = arith.constant 80 : index
+  %c120 = arith.constant 120 : index
+  %c160 = arith.constant 160 : index
+  %c200 = arith.constant 200 : index
+  // CHECK: %[[ALLOC0:.+]] = stream.resource.alloc
+  %alloc0 = stream.resource.alloc uninitialized : !stream.resource<transient>{%c200}
+  // CHECK: stream.cmd.execute
+  %result_timepoint = stream.cmd.execute
+      // CHECK-SAME: with(%[[ALLOC0]] as %[[CAPTURE0:.+]]: !stream.resource<transient>{%c200})
+      with(%alloc0 as %capture0: !stream.resource<transient>{%c200}) {
+    // CHECK: stream.cmd.dispatch {{.+}}(%c40, %c80, %c0, %[[OPERAND]] : index, index, index, index)
+    stream.cmd.dispatch @deduplicateBindingsEx::@dispatch[%c1, %c1, %c1](%operand : index) {
+      // CHECK-NEXT: ro %[[CAPTURE0]][%c0
+      ro %capture0[%c40 for %c20] : !stream.resource<transient>{%c200},
+      // CHECK-NOT: ro %[[CAPTURE0]][%c0
+      ro %capture0[%c80 for %c20] : !stream.resource<transient>{%c200},
+      // CHECK-NEXT: rw %[[CAPTURE0]]
+      rw %capture0[%c0 for %c20] : !stream.resource<transient>{%c200}
+    }
+    // CHECK: stream.cmd.dispatch {{.+}}(%c120, %c160, %c20, %[[OPERAND]] : index, index, index, index)
+    stream.cmd.dispatch @deduplicateBindingsEx::@dispatch[%c1, %c1, %c1](%operand : index) {
+      // CHECK-NEXT: ro %[[CAPTURE0]][%c0
+      ro %capture0[%c120 for %c20] : !stream.resource<transient>{%c200},
+      // CHECK-NOT: ro %[[CAPTURE0]][%c0
+      ro %capture0[%c160 for %c20] : !stream.resource<transient>{%c200},
+      // CHECK-NEXT: rw %[[CAPTURE0]]
+      rw %capture0[%c20 for %c20] : !stream.resource<transient>{%c200}
+    }
+  } => !stream.timepoint
+  return
+}