Inline IndexCastOp and int/float ops into the dispatch region. (#5865)
Allowing folding index_cast operation and other integer/float
operations into the dispatch region avoids round-tripping to host to
get values from a tensor on device required for some index
computation.
Note: This only works for MHLO lowering path. A more nuanced solution
might be needed in general with explicit white-listing of operations,
but that requires more use cases.
Fixes #5620
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index f5bf515..e36c374 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -239,12 +239,18 @@
 }
 
 static bool isAlwaysClonedIntoDispatchOp(Operation *op) {
-  if (isa<linalg::InitTensorOp, tensor::ExtractOp>(op)) {
+  if (isa<IndexCastOp, linalg::InitTensorOp, tensor::ExtractOp>(op)) {
     return true;
   }
   if (auto constantOp = dyn_cast<ConstantOp>(op)) {
     return constantOp.getResult().getType().isIntOrIndexOrFloat();
   }
+  if (llvm::all_of(op->getOperands(),
+                   [&](Value v) { return v.getType().isIntOrFloat(); }) &&
+      llvm::all_of(op->getResults(),
+                   [&](Value v) { return v.getType().isIntOrFloat(); })) {
+    return true;
+  }
   return false;
 }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index aa9119b..43a5c0b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -726,3 +726,46 @@
 //   CHECK-DAG:     flow.dispatch.tensor.store %[[OP_RESULT]]#0, %[[ARG5]]
 //   CHECK-DAG:     flow.dispatch.tensor.store %[[OP_RESULT]]#1, %[[ARG6]]
 //       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @dynamic_slice(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3 : index) -> tensor<1x?xi32> {
+  %c1_i32 = constant 1 : i32
+  %c0_i32 = constant 0 : i32
+  %0 = tensor.extract %arg1[] : tensor<i32>
+  %1 = cmpi slt, %0, %c1_i32 : i32
+  %2 = select %1, %0, %c1_i32 : i32
+  %3 = cmpi sgt, %2, %c0_i32 : i32
+  %4 = select %3, %2, %c0_i32 : i32
+  %5 = index_cast %4 : i32 to index
+  %6 = tensor.extract %arg2[] : tensor<i32>
+  %7 = cmpi slt, %6, %c0_i32 : i32
+  %8 = select %7, %6, %c0_i32 : i32
+  %9 = cmpi sgt, %8, %c0_i32 : i32
+  %10 = select %9, %8, %c0_i32 : i32
+  %11 = index_cast %10 : i32 to index
+  %12 = subtensor %arg0[%5, %11] [1, %arg3] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
+  return %12 : tensor<1x?xi32>
+}
+// CHECK-LABEL: func @dynamic_slice(
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xi32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<i32>
+//  CHECK-SAME:   %[[ARG3:.+]]: index
+//       CHECK:   %[[C1:.+]] = constant 1 : index
+//       CHECK:   %[[RESULT:.+]] = flow.dispatch.workgroups
+//  CHECK-SAME:     [%[[ARG3]], %[[C1]], %[[C1]]]
+//  CHECK-SAME:     (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]])
+//   CHECK-DAG:     cmpi
+//   CHECK-DAG:     select
+//   CHECK-DAG:     cmpi
+//   CHECK-DAG:     select
+//   CHECK-DAG:     cmpi
+//   CHECK-DAG:     cmpi
+//   CHECK-DAG:     select
+//   CHECK-DAG:     select
+//   CHECK-DAG:     index_cast
+//   CHECK-DAG:     index_cast
+//       CHECK:     subtensor
+//       CHECK:     flow.return
+//       CHECK:   return %[[RESULT]]
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 7e028a0..c4502e8 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -44,6 +44,7 @@
     "divide.mlir",
     "dot.mlir",
     "dot_general.mlir",
+    "dynamic_slice.mlir",
     "exponential.mlir",
     "exponential_minus_one.mlir",
     "floor.mlir",
@@ -93,6 +94,7 @@
             "divide.mlir",
             "dot.mlir",
             "dot_general.mlir",
+            "dynamic_slice.mlir",
             "exponential.mlir",
             "exponential_minus_one.mlir",
             "fft.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index f66d8d7..ff3fd0b 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -30,6 +30,7 @@
     "divide.mlir"
     "dot.mlir"
     "dot_general.mlir"
+    "dynamic_slice.mlir"
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "fft.mlir"
@@ -87,6 +88,7 @@
     "divide.mlir"
     "dot.mlir"
     "dot_general.mlir"
+    "dynamic_slice.mlir"
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "floor.mlir"
@@ -140,6 +142,7 @@
     "divide.mlir"
     "dot.mlir"
     "dot_general.mlir"
+    "dynamic_slice.mlir"
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "fft.mlir"
diff --git a/iree/test/e2e/xla_ops/dynamic_slice.mlir b/iree/test/e2e/xla_ops/dynamic_slice.mlir
new file mode 100644
index 0000000..5124951
--- /dev/null
+++ b/iree/test/e2e/xla_ops/dynamic_slice.mlir
@@ -0,0 +1,30 @@
+func @dynamic_slice() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]]> : tensor<3x4xi32>
+  %start1 = iree.unfoldable_constant dense<1> : tensor<i64>
+  %start2 = iree.unfoldable_constant dense<2> : tensor<i64>
+  %result = "mhlo.dynamic-slice"(%input, %start1, %start2) {
+    slice_sizes = dense<[2, 2]> : tensor<2xi64>
+  } : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
+  check.expect_eq_const(%result, dense<[
+      [7, 8],
+      [11, 12]]> : tensor<2x2xi32>) : tensor<2x2xi32>
+  return
+}
+
+func @dynamic_unit_slice() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[
+    [01, 02, 03, 04],
+    [05, 06, 07, 08],
+    [09, 10, 11, 12]]> : tensor<3x4xi32>
+  %start1 = iree.unfoldable_constant dense<1> : tensor<i64>
+  %start2 = iree.unfoldable_constant dense<2> : tensor<i64>
+  %result = "mhlo.dynamic-slice"(%input, %start1, %start2) {
+    slice_sizes = dense<[1, 2]> : tensor<2xi64>
+  } : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x2xi32>
+  check.expect_eq_const(%result, dense<[
+      [7, 8]]> : tensor<1x2xi32>) : tensor<1x2xi32>
+  return
+}