[DT] Fuse encoding ops more aggressively for multi-result, gather, and slice ops. (#21830)

The fusion constraint on multi-result producer dispatches is only required
by the SetEncoding pass, because it has to move consumer dispatches around.
It is not required by encoding fusion, which merely moves a SetEncoding op
into its producer dispatch.

The revision also allows the fusion when the dispatch region contains
tensor.extract_slice and iree_linalg_ext.gather ops. It reduces the
number of dispatches to 644 in the llama fp8 model, the same as without
data tiling. The latency drops by 25ms, from 378ms to 353ms.

| | No Data Tiling | Data Tiling w/o the revision | Data Tiling w/ the
revision |
| ------------- | ------------- | ------------- | ------------- |
| Benchmark latency | 243ms  |  378ms | 353ms |
| Memory usage (HIP unpooled)  | 15.9GB  |  31.14GB | 31.11GB |
| Number of dispatches | 644  | 741 | 644 |

| | No Data Tiling (ms) | Data Tiling w/o the revision | Data Tiling w/
the revision |
| ------------- | ------------- | ------------- | ------------- |
| dispatch_15_attention_4x8x4xDx128xf8  | 62.29 |  55.35 | 59.21 |
| dispatch_20_matmul_like_Dx14336x4096_f8xf8xf32 | 40.13 | 89.14 |
93.72|
| dispatch_19_matmul_like_Dx14336x4096_f8xf8xf32 | 28.01 | 44.78 | 44.59
|
| dispatch_21_matmul_like_Dx4096x14336_f8xf8xf32 | 27.25 | 40.18 | 39.99
|
| dispatch_643_matmul_like_Dx128256x4096_f16xf16xf32 | 17.1 | 29.76 |
29.21 |
| dispatch_16_matmul_like_Dx4096x4096_f8xf8xf32 | 8.83 | 17.92 | 17.91 |
| dispatch_23_matmul_like_Dx4096x4096_f8xf8xf32 | 9.27 | 16.69 | 16.59 |
| encoding_10_encode_Dx4096xf8_to_Dx4096xf8 | - |  32.15 | - |
| encoding_6_encode_Dx14336xf32_to_Dx14336xf32 | - |  0.318 | - |

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp
index 9d5e876..cb1efe2 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FuseEncodingOpsIntoDispatchRegions.cpp
@@ -31,12 +31,13 @@
 namespace {
 
 /// Return true if the op is fusable with a SetEncodingOp consumer. The op's
-/// containing dispatch must contain only reshape ops, encoding ops, linalg ops,
-/// and attention ops. Non ShapedType ops (like arith ops, dim ops, etc.) are
-/// also allowed.
-/// TODO(#20179): It should be done by interface methods.
-static bool isFusableWithSetEncoding(Operation *op) {
-  auto parentRegion = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
+/// containing dispatch must contain only:
+///   - Reshape ops, encoding ops, linalg ops, gather ops, and attention ops.
+///   - Non ShapedType ops, e.g., like arith ops, dim ops, etc.
+///   - tensor::ExtractSliceOp is allowed as they can be folded into dispatch
+///     tensor load ops.
+static bool isFusableWithSetEncoding(Operation *target) {
+  auto parentRegion = target->getParentOfType<IREE::Flow::DispatchRegionOp>();
   // Make sure the dispatch region has only one block.
   if (!llvm::hasSingleElement(parentRegion.getBody())) {
     return false;
@@ -49,8 +50,9 @@
       continue;
     }
     if (!isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::EmptyOp,
-             IREE::Encoding::SetEncodingOp, IREE::Encoding::UnsetEncodingOp,
-             linalg::LinalgOp, IREE::LinalgExt::AttentionOp>(op)) {
+             tensor::ExtractSliceOp, IREE::Encoding::SetEncodingOp,
+             IREE::Encoding::UnsetEncodingOp, linalg::LinalgOp,
+             IREE::LinalgExt::AttentionOp, IREE::LinalgExt::GatherOp>(op)) {
       return false;
     }
   }
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index 1a321e4..cf51973 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -145,13 +145,8 @@
 
   auto producerDispatch =
       dyn_cast<IREE::Flow::DispatchRegionOp>(producerValue.getOwner());
-  // TODO(MaheshRavishankar): Multi-result producer dispatches can be supported.
-  // Will require to move the consumer dispatch immediately after the producer
-  // instead of what is done below and move other operands of the consumer
-  // dispatch before the producer dispatch.
   if (!producerDispatch ||
-      !llvm::hasSingleElement(producerDispatch.getBody()) ||
-      producerDispatch->getNumResults() != 1) {
+      !llvm::hasSingleElement(producerDispatch.getBody())) {
     return std::nullopt;
   }
   if (!llvm::hasSingleElement(producerValue.getUses())) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
index f422449..a0b483f 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -28,8 +28,8 @@
 
 /// Returns the closest producer dispatch region op result and the chain of
 /// operations being looked past during the traversal to find the producer
-/// dispatch. Returns std::nullopt if the dispatch or any ops in the chain have
-/// multiple uses.
+/// dispatch. Returns std::nullopt if the dispatch can not be found in the
+/// chain or any op in the chain is not a reshape-like op.
 std::optional<std::pair<OpResult, SmallVector<Operation *>>>
 getProducerDispatchValueAndOpChain(Value operand);
 
diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
index b2f9a56..1c44973 100644
--- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -392,16 +392,24 @@
     OpOperand &operand = op->getOpOperand(operandNum);
     std::optional<std::pair<OpResult, SmallVector<Operation *>>>
         dispatchAndOpChain = getProducerDispatchValueAndOpChain(operand.get());
-    if (dispatchAndOpChain.has_value()) {
-      auto producerDispatch = cast<IREE::Flow::DispatchRegionOp>(
-          dispatchAndOpChain->first.getOwner());
-      WalkResult res =
-          producerDispatch->walk([&](IREE::LinalgExt::AttentionOp op) {
-            return WalkResult::interrupt();
-          });
-      if (res.wasInterrupted()) {
-        return {};
-      }
+    if (!dispatchAndOpChain.has_value()) {
+      continue;
+    }
+    auto producerDispatch = cast<IREE::Flow::DispatchRegionOp>(
+        dispatchAndOpChain->first.getOwner());
+    // TODO(MaheshRavishankar): Multi-result producer dispatches can be
+    // supported. Will require to move the consumer dispatch immediately after
+    // the producer instead of what is done below and move other operands of the
+    // consumer dispatch before the producer dispatch.
+    if (producerDispatch->getNumResults() != 1) {
+      continue;
+    }
+    WalkResult res =
+        producerDispatch->walk([&](IREE::LinalgExt::AttentionOp op) {
+          return WalkResult::interrupt();
+        });
+    if (res.wasInterrupted()) {
+      return {};
     }
   }
 
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir
index 1d12291..d044636 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_encoding_ops_into_dispatch_regions.mlir
@@ -281,6 +281,24 @@
 // -----
 
 #encoding = #iree_encoding.testing<>
+util.func public @extract_slice_fusion(%arg0: tensor<192x1024x64xf32>) -> tensor<96x512x32xf32, #encoding> {
+  %0 = flow.dispatch.region -> (tensor<96x512x32xf32>) {
+    %1 = tensor.extract_slice %arg0[0, 0, 0] [96, 512, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<96x512x32xf32>
+    flow.return %1 : tensor<96x512x32xf32>
+  }
+  %2 = iree_encoding.set_encoding %0 : tensor<96x512x32xf32> -> tensor<96x512x32xf32, #encoding>
+  util.return %2 : tensor<96x512x32xf32, #encoding>
+}
+// CHECK-LABEL: @extract_slice_fusion
+// CHECK:       %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK:         tensor.extract_slice
+// CHECK:         %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
+// CHECK:         flow.return %[[SET_ENCODING]] :
+// CHECK:       util.return %[[DISPATCH0]]
+
+// -----
+
+#encoding = #iree_encoding.testing<>
 util.func public @attention_fusion(
     %query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>,
     %value: tensor<192x1024x64xf32>, %scale: f32) -> tensor<192x1024x64xf32, #encoding> {
@@ -309,3 +327,55 @@
 // CHECK:         flow.return %[[SET_ENCODING]] :
 // CHECK:       }
 // CHECK:       util.return %[[DISPATCH0]]
+
+// -----
+
+#encoding = #iree_encoding.testing<>
+util.func public @gather_fusion(
+    %source: tensor<10x10xi32>, %indices: tensor<1xi32>) -> tensor<1x10xi32, #encoding> {
+  %empty = tensor.empty() : tensor<1x10xi32>
+  %1 = flow.dispatch.region -> (tensor<1x10xi32>) {
+    %3 = iree_linalg_ext.gather dimension_map = [0]
+           ins(%source, %indices : tensor<10x10xi32>, tensor<1xi32>)
+           outs(%empty : tensor<1x10xi32>) -> tensor<1x10xi32>
+    flow.return %3 : tensor<1x10xi32>
+  }
+  %2 = iree_encoding.set_encoding %1 : tensor<1x10xi32> -> tensor<1x10xi32, #encoding>
+  util.return %2 : tensor<1x10xi32, #encoding>
+}
+// CHECK-LABEL: @gather_fusion
+// CHECK:       %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK:         iree_linalg_ext.gather
+// CHECK:         %[[SET_ENCODING:.+]] = iree_encoding.set_encoding
+// CHECK:         flow.return %[[SET_ENCODING]] :
+// CHECK:       util.return %[[DISPATCH0]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#encoding = #iree_encoding.testing<>
+util.func public @multi_result_fusion(%arg0: tensor<123x456xf32>) -> (tensor<123x456xf32>, tensor<123x456xf32, #encoding>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<123x456xf32>
+  %1:2 = flow.dispatch.region -> (tensor<123x456xf32>, tensor<123x456xf32>) {
+    %3:2 = linalg.generic {
+        indexing_maps = [#map, #map, #map, #map],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%arg0, %arg0 : tensor<123x456xf32>, tensor<123x456xf32>)
+        outs(%0, %0 : tensor<123x456xf32>, tensor<123x456xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32, %out2: f32):
+      %4 = arith.addf %in, %in_0 : f32
+      %5 = arith.mulf %in, %in_0 : f32
+      linalg.yield %4, %5 : f32, f32
+    } -> (tensor<123x456xf32>, tensor<123x456xf32>)
+    flow.return %3#0, %3#1 : tensor<123x456xf32>, tensor<123x456xf32>
+  }
+  %2 = iree_encoding.set_encoding %1#1 : tensor<123x456xf32> -> tensor<123x456xf32, #encoding>
+  util.return %1#0, %2 : tensor<123x456xf32>, tensor<123x456xf32, #encoding>
+}
+// CHECK-LABEL: @multi_result_fusion
+// CHECK:       %[[DISPATCH0:.+]]:2 = flow.dispatch.region
+// CHECK:         %[[ELEM:.+]]:2 = linalg.generic
+// CHECK:         %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[ELEM]]#1
+// CHECK:         flow.return %[[ELEM]]#0, %[[SET_ENCODING]]
+// CHECK:       util.return %[[DISPATCH0]]#0, %[[DISPATCH0]]#1