Make TorchIndexSelectOp being able to fuse with consumers. (#4170)
This patch relaxes the torch_index_select opt to live with consumers in
a dispatch function.
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 3581d9b..c3c3415 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -223,15 +223,15 @@
// TODO(b/144530470): replace with tablegen attributes/interfaces.
bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
return isa<linalg::IndexedGenericOp, linalg::GenericOp, mhlo::ConcatenateOp,
- mhlo::ConvOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
- mhlo::TorchIndexSelectOp>(op) ||
+ mhlo::ConvOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp>(
+ op) ||
(!clEnableConsumerOnlyFusion &&
isa<mhlo::DotOp, mhlo::DotGeneralOp>(op)) ||
isRootOnlyOp(op);
}
bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
- return isa<mhlo::SliceOp>(op);
+ return isa<mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op);
}
} // namespace Flow
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
index c324ef7..d551d13 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -162,3 +162,29 @@
// CHECK: flow.dispatch.region
// CHECK-NEXT: mhlo.slice
// CHECK-NEXT: mhlo.multiply
+
+// -----
+
+module {
+ func @torch_index_select_producer(%arg0: tensor<5x1x5xi32>,
+ %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
+ %c10 = constant 0 : index
+ %0 = flow.dispatch.region[%c10 : index](%arg2 = %arg0 : tensor<5x1x5xi32>,
+ %arg3 = %arg1 : tensor<2xi32>) -> tensor<2x1x5xi32> {
+ %1 = "mhlo.torch_index_select"(%arg2, %arg3) {
+ dim = 0 : i64,
+ batch_dims = 0 : i64
+ } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
+ flow.return %1 : tensor<2x1x5xi32>
+ }
+ %1 = flow.dispatch.region[%c10 : index](%arg2 = %0 : tensor<2x1x5xi32>) -> tensor<2x1x5xi32> {
+ %2 = mhlo.add %arg2, %arg2 : tensor<2x1x5xi32>
+ flow.return %2 : tensor<2x1x5xi32>
+ }
+ return %1 : tensor<2x1x5xi32>
+ }
+}
+// CHECK-LABEL: func @torch_index_select_producer
+// CHECK: flow.dispatch.region
+// CHECK-NEXT: mhlo.torch_index_select
+// CHECK-NEXT: mhlo.add
diff --git a/iree/test/e2e/structural/BUILD b/iree/test/e2e/structural/BUILD
index 1dec4be..fd9d7af 100644
--- a/iree/test/e2e/structural/BUILD
+++ b/iree/test/e2e/structural/BUILD
@@ -31,6 +31,7 @@
iree_check_single_backend_test_suite(
name = "check_vulkan-spirv_vulkan",
srcs = [
+ "gather_add.mlir",
"matmul_add.mlir",
"slice_add.mlir",
],
@@ -41,6 +42,7 @@
iree_check_single_backend_test_suite(
name = "check_dylib-llvm-aot_dylib",
srcs = [
+ "gather_add.mlir",
"matmul_add.mlir",
"slice_add.mlir",
],
diff --git a/iree/test/e2e/structural/CMakeLists.txt b/iree/test/e2e/structural/CMakeLists.txt
index a0bf677..ce86c16 100644
--- a/iree/test/e2e/structural/CMakeLists.txt
+++ b/iree/test/e2e/structural/CMakeLists.txt
@@ -30,6 +30,7 @@
NAME
check_vulkan-spirv_vulkan
SRCS
+ "gather_add.mlir"
"matmul_add.mlir"
"slice_add.mlir"
TARGET_BACKEND
@@ -42,6 +43,7 @@
NAME
check_dylib-llvm-aot_dylib
SRCS
+ "gather_add.mlir"
"matmul_add.mlir"
"slice_add.mlir"
TARGET_BACKEND
diff --git a/iree/test/e2e/structural/gather_add.mlir b/iree/test/e2e/structural/gather_add.mlir
new file mode 100644
index 0000000..676b3a7
--- /dev/null
+++ b/iree/test/e2e/structural/gather_add.mlir
@@ -0,0 +1,24 @@
+func @torch_select_index_0() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[
+ [[01, 02, 03, 04, 05]],
+ [[06, 07, 08, 09, 10]],
+ [[11, 12, 13, 14, 15]],
+ [[16, 17, 18, 19, 20]],
+ [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32>
+ %indices = iree.unfoldable_constant dense<[0, 2]> : tensor<2xi32>
+ %workload = constant 10 : index
+ %result = flow.dispatch.region[%workload: index](
+ %arg0 = %input : tensor<5x1x5xi32>,
+ %arg1 = %indices : tensor<2xi32>) -> tensor<2x1x5xi32> {
+ %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
+ dim = 0 : i64,
+ batch_dims = 0 : i64
+ } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
+ %1 = mhlo.add %0, %0 : tensor<2x1x5xi32>
+ flow.return %1 : tensor<2x1x5xi32>
+ }
+
+ check.expect_eq_const(%result, dense<[[[02, 04, 06, 08, 10]],
+ [[22, 24, 26, 28, 30]]]> : tensor<2x1x5xi32>) : tensor<2x1x5xi32>
+ return
+}