[CodeGen] Do not fuse parallel ops if they directly write to destination. (#21837)

The pass was mainly introduced for softmax dispatch, so it's okay to
limit the scope of the fusion. If we unconditionally fuse the ops, it
may result in independent compute ops. In this context, there is more
than one root op; codegen does not expect this case. It is basically a
result of two dispatches that get formed into a single dispatch.

Fixes https://github.com/iree-org/iree/issues/21836

It is a step towards https://github.com/iree-org/iree/issues/21828.
There are other issues about domination in some dispatches.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp b/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp
index 1f2a7b7..3565fd9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp
@@ -6,6 +6,8 @@
 
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
+#include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -25,6 +27,11 @@
   return t.isIntOrIndexOrFloat();
 }
 
+static bool hasDirectWriteResult(Operation *op) {
+  return llvm::any_of(op->getUsers(),
+                      llvm::IsaPred<IREE::TensorExt::DispatchTensorStoreOp>);
+}
+
 /// Rematerialize all parallel elementwise operations into its users within a
 /// `flow.dispatch.region`.
 struct RematerializeParallelOpsPattern
@@ -51,6 +58,9 @@
       if (producer && hasExternalCapture(producer)) {
         continue;
       }
+      if (producer && hasDirectWriteResult(producer)) {
+        continue;
+      }
       FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
           linalg::fuseElementwiseOps(rewriter, &opOperand);
       if (succeeded(fusionResult)) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir
index ae5603c..dc09413 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir
@@ -169,3 +169,46 @@
 // CHECK-LABEL: func @no_external_capture_fusion(
 //       CHECK:   linalg.generic
 //       CHECK:   linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#pipeline_layout = #hal.pipeline.layout<
+  bindings = [
+    #hal.pipeline.binding<storage_buffer, Indirect>,
+    #hal.pipeline.binding<storage_buffer, Indirect>
+  ]>
+func.func @producer_has_direct_write(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c64) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<3x5xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c128) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<3x4x5xf32>>
+  %2 = tensor.empty() : tensor<3x5xf32>
+  %3 = tensor.empty() : tensor<3x4x5xf32>
+  %4 = linalg.fill ins(%cst : f32) outs(%2 : tensor<3x5xf32>) -> tensor<3x5xf32>
+  %5 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<3x4x5xf32>, tensor<3x5xf32>) outs(%3 : tensor<3x4x5xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %7 = arith.subf %in, %in_0 : f32
+    linalg.yield %7 : f32
+  } -> tensor<3x4x5xf32>
+  %6 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%5 : tensor<3x4x5xf32>) outs(%4 : tensor<3x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %7 = math.exp %in : f32
+    %8 = arith.addf %7, %out : f32
+    linalg.yield %8 : f32
+  } -> tensor<3x5xf32>
+  iree_tensor_ext.dispatch.tensor.store %6, %0, offsets = [0, 0], sizes = [3, 5], strides = [1, 1] : tensor<3x5xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<3x5xf32>>
+  iree_tensor_ext.dispatch.tensor.store %5, %1, offsets = [0, 0, 0], sizes = [3, 4, 5], strides = [1, 1, 1] : tensor<3x4x5xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<3x4x5xf32>>
+  return
+}
+// CHECK-LABEL: func.func @producer_has_direct_write
+//       CHECK:   %[[ELEM:.+]] = linalg.generic
+//       CHECK:   %[[REDUCTION:.+]] = linalg.generic
+//  CHECK-SAME:     ins(%[[ELEM]]
+//   CHECK-DAG:   iree_tensor_ext.dispatch.tensor.store %[[REDUCTION]]
+//   CHECK-DAG:   iree_tensor_ext.dispatch.tensor.store %[[ELEM]]