Make TorchIndexSelectOp being able to fuse with consumers. (#4170)

This patch relaxes the torch_index_select opt to live with consumers in
a dispatch function.
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 3581d9b..c3c3415 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -223,15 +223,15 @@
 // TODO(b/144530470): replace with tablegen attributes/interfaces.
 bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
   return isa<linalg::IndexedGenericOp, linalg::GenericOp, mhlo::ConcatenateOp,
-             mhlo::ConvOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
-             mhlo::TorchIndexSelectOp>(op) ||
+             mhlo::ConvOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp>(
+             op) ||
          (!clEnableConsumerOnlyFusion &&
           isa<mhlo::DotOp, mhlo::DotGeneralOp>(op)) ||
          isRootOnlyOp(op);
 }
 
 bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
-  return isa<mhlo::SliceOp>(op);
+  return isa<mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op);
 }
 
 }  // namespace Flow
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
index c324ef7..d551d13 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -162,3 +162,29 @@
 //       CHECK: flow.dispatch.region
 //  CHECK-NEXT:   mhlo.slice
 //  CHECK-NEXT:   mhlo.multiply
+
+// -----
+
+module {
+  func @torch_index_select_producer(%arg0: tensor<5x1x5xi32>,
+                                    %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
+    %c10 = constant 0 : index
+    %0 = flow.dispatch.region[%c10 : index](%arg2 = %arg0 : tensor<5x1x5xi32>,
+                                            %arg3 = %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> {
+      %1 = "mhlo.torch_index_select"(%arg2, %arg3) {
+        dim = 0 : i64,
+        batch_dims = 0 : i64
+      } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
+      flow.return %1 : tensor<2x1x5xi32>
+    }
+    %1 = flow.dispatch.region[%c10 : index](%arg2 = %0 : tensor<2x1x5xi32>) -> tensor<2x1x5xi32> {
+      %2 = mhlo.add %arg2, %arg2 : tensor<2x1x5xi32>
+      flow.return %2 : tensor<2x1x5xi32>
+    }
+    return %1 : tensor<2x1x5xi32>
+  }
+}
+// CHECK-LABEL: func @torch_index_select_producer
+//       CHECK: flow.dispatch.region
+//  CHECK-NEXT:   mhlo.torch_index_select
+//  CHECK-NEXT:   mhlo.add
diff --git a/iree/test/e2e/structural/BUILD b/iree/test/e2e/structural/BUILD
index 1dec4be..fd9d7af 100644
--- a/iree/test/e2e/structural/BUILD
+++ b/iree/test/e2e/structural/BUILD
@@ -31,6 +31,7 @@
 iree_check_single_backend_test_suite(
     name = "check_vulkan-spirv_vulkan",
     srcs = [
+        "gather_add.mlir",
         "matmul_add.mlir",
         "slice_add.mlir",
     ],
@@ -41,6 +42,7 @@
 iree_check_single_backend_test_suite(
     name = "check_dylib-llvm-aot_dylib",
     srcs = [
+        "gather_add.mlir",
         "matmul_add.mlir",
         "slice_add.mlir",
     ],
diff --git a/iree/test/e2e/structural/CMakeLists.txt b/iree/test/e2e/structural/CMakeLists.txt
index a0bf677..ce86c16 100644
--- a/iree/test/e2e/structural/CMakeLists.txt
+++ b/iree/test/e2e/structural/CMakeLists.txt
@@ -30,6 +30,7 @@
   NAME
     check_vulkan-spirv_vulkan
   SRCS
+    "gather_add.mlir"
     "matmul_add.mlir"
     "slice_add.mlir"
   TARGET_BACKEND
@@ -42,6 +43,7 @@
   NAME
     check_dylib-llvm-aot_dylib
   SRCS
+    "gather_add.mlir"
     "matmul_add.mlir"
     "slice_add.mlir"
   TARGET_BACKEND
diff --git a/iree/test/e2e/structural/gather_add.mlir b/iree/test/e2e/structural/gather_add.mlir
new file mode 100644
index 0000000..676b3a7
--- /dev/null
+++ b/iree/test/e2e/structural/gather_add.mlir
@@ -0,0 +1,24 @@
+func @torch_select_index_0() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[
+    [[01, 02, 03, 04, 05]],
+    [[06, 07, 08, 09, 10]],
+    [[11, 12, 13, 14, 15]],
+    [[16, 17, 18, 19, 20]],
+    [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32>
+  %indices = iree.unfoldable_constant dense<[0, 2]> : tensor<2xi32>
+  %workload = constant 10 : index
+  %result = flow.dispatch.region[%workload: index](
+      %arg0 = %input : tensor<5x1x5xi32>,
+      %arg1 = %indices : tensor<2xi32>) -> tensor<2x1x5xi32> {
+    %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
+      dim = 0 : i64,
+      batch_dims = 0 : i64
+    } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
+    %1 = mhlo.add %0, %0 : tensor<2x1x5xi32>
+    flow.return %1 : tensor<2x1x5xi32>
+  }
+
+  check.expect_eq_const(%result, dense<[[[02, 04, 06, 08, 10]],
+                                        [[22, 24, 26, 28, 30]]]> : tensor<2x1x5xi32>) : tensor<2x1x5xi32>
+  return
+}