Inline IndexCastOp and int/float ops into the dispatch region. (#5865)
Allowing index_cast and other integer/float operations to fold into the
dispatch region avoids round-tripping to the host to get values from a
tensor on the device that are needed for some index computations.
Note: this only works for the MHLO lowering path. A more nuanced solution
with explicit white-listing of operations might be needed in general, but
that requires more use cases.
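For illustration, a condensed sketch of the pattern this enables (adapted
from the lit test added below; %c1_i32 is a scalar constant): the scalar
chain computing subtensor offsets is now cloned into the
flow.dispatch.workgroups region instead of being evaluated on the host:

  %0 = tensor.extract %arg1[] : tensor<i32>
  %1 = cmpi slt, %0, %c1_i32 : i32
  %2 = select %1, %0, %c1_i32 : i32
  %3 = index_cast %2 : i32 to index
  %4 = subtensor %arg0[%3, 0] [1, %arg3] [1, 1]
         : tensor<?x?xi32> to tensor<1x?xi32>

Previously only the tensor.extract (and scalar constants) were cloned, so
the cmpi/select/index_cast results had to be computed on the host and passed
into the dispatch as operands.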
Fixes #5620
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index f5bf515..e36c374 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -239,12 +239,21 @@
}
static bool isAlwaysClonedIntoDispatchOp(Operation *op) {
- if (isa<linalg::InitTensorOp, tensor::ExtractOp>(op)) {
+ if (isa<IndexCastOp, linalg::InitTensorOp, tensor::ExtractOp>(op)) {
return true;
}
if (auto constantOp = dyn_cast<ConstantOp>(op)) {
return constantOp.getResult().getType().isIntOrIndexOrFloat();
}
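+  // Any op whose operands and results are all scalar integer/float values is
+  // cheap to recompute, so clone it into the dispatch region rather than
+  // passing its results in from the host.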
+ if (llvm::all_of(op->getOperands(),
+ [&](Value v) { return v.getType().isIntOrFloat(); }) &&
+ llvm::all_of(op->getResults(),
+ [&](Value v) { return v.getType().isIntOrFloat(); })) {
+ return true;
+ }
return false;
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index aa9119b..43a5c0b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -726,3 +726,48 @@
// CHECK-DAG: flow.dispatch.tensor.store %[[OP_RESULT]]#0, %[[ARG5]]
// CHECK-DAG: flow.dispatch.tensor.store %[[OP_RESULT]]#1, %[[ARG6]]
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
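+// Tests that the scalar cmpi/select/index_cast chain computing the subtensor
+// offsets is cloned into the dispatch region instead of staying on the host.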
+func @dynamic_slice(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3 : index) -> tensor<1x?xi32> {
+ %c1_i32 = constant 1 : i32
+ %c0_i32 = constant 0 : i32
+ %0 = tensor.extract %arg1[] : tensor<i32>
+ %1 = cmpi slt, %0, %c1_i32 : i32
+ %2 = select %1, %0, %c1_i32 : i32
+ %3 = cmpi sgt, %2, %c0_i32 : i32
+ %4 = select %3, %2, %c0_i32 : i32
+ %5 = index_cast %4 : i32 to index
+ %6 = tensor.extract %arg2[] : tensor<i32>
+ %7 = cmpi slt, %6, %c0_i32 : i32
+ %8 = select %7, %6, %c0_i32 : i32
+ %9 = cmpi sgt, %8, %c0_i32 : i32
+ %10 = select %9, %8, %c0_i32 : i32
+ %11 = index_cast %10 : i32 to index
+ %12 = subtensor %arg0[%5, %11] [1, %arg3] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
+ return %12 : tensor<1x?xi32>
+}
+// CHECK-LABEL: func @dynamic_slice(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<i32>
+// CHECK-SAME: %[[ARG3:.+]]: index
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
+// CHECK-SAME: [%[[ARG3]], %[[C1]], %[[C1]]]
+// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]])
+// CHECK-DAG: cmpi
+// CHECK-DAG: select
+// CHECK-DAG: cmpi
+// CHECK-DAG: select
+// CHECK-DAG: cmpi
+// CHECK-DAG: cmpi
+// CHECK-DAG: select
+// CHECK-DAG: select
+// CHECK-DAG: index_cast
+// CHECK-DAG: index_cast
+// CHECK: subtensor
+// CHECK: flow.return
+// CHECK: return %[[RESULT]]
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 7e028a0..c4502e8 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -44,6 +44,7 @@
"divide.mlir",
"dot.mlir",
"dot_general.mlir",
+ "dynamic_slice.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"floor.mlir",
@@ -93,6 +94,7 @@
"divide.mlir",
"dot.mlir",
"dot_general.mlir",
+ "dynamic_slice.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"fft.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index f66d8d7..ff3fd0b 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -30,6 +30,7 @@
"divide.mlir"
"dot.mlir"
"dot_general.mlir"
+ "dynamic_slice.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"fft.mlir"
@@ -87,6 +88,7 @@
"divide.mlir"
"dot.mlir"
"dot_general.mlir"
+ "dynamic_slice.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"floor.mlir"
@@ -140,6 +142,7 @@
"divide.mlir"
"dot.mlir"
"dot_general.mlir"
+ "dynamic_slice.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"fft.mlir"
diff --git a/iree/test/e2e/xla_ops/dynamic_slice.mlir b/iree/test/e2e/xla_ops/dynamic_slice.mlir
new file mode 100644
index 0000000..5124951
--- /dev/null
+++ b/iree/test/e2e/xla_ops/dynamic_slice.mlir
@@ -0,0 +1,32 @@
+func @dynamic_slice() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]]> : tensor<3x4xi32>
+ %start1 = iree.unfoldable_constant dense<1> : tensor<i64>
+ %start2 = iree.unfoldable_constant dense<2> : tensor<i64>
+ %result = "mhlo.dynamic-slice"(%input, %start1, %start2) {
+ slice_sizes = dense<[2, 2]> : tensor<2xi64>
+ } : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
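+  // Slice starts at (1, 2): rows 1..2, columns 2..3 of %input.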
+ check.expect_eq_const(%result, dense<[
+ [7, 8],
+ [11, 12]]> : tensor<2x2xi32>) : tensor<2x2xi32>
+ return
+}
+
+func @dynamic_unit_slice() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]]> : tensor<3x4xi32>
+ %start1 = iree.unfoldable_constant dense<1> : tensor<i64>
+ %start2 = iree.unfoldable_constant dense<2> : tensor<i64>
+ %result = "mhlo.dynamic-slice"(%input, %start1, %start2) {
+ slice_sizes = dense<[1, 2]> : tensor<2xi64>
+ } : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x2xi32>
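+  // Slice starts at (1, 2) with a unit leading size: row 1, columns 2..3.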
+ check.expect_eq_const(%result, dense<[
+ [7, 8]]> : tensor<1x2xi32>) : tensor<1x2xi32>
+ return
+}