[Flow] Always permute the accesses on inputs of an elementwise consumer fed by a named-op/reduction producer. (#17663)
For dispatch formation, the current logic (and a lot of the code generation)
works much better if the consumer uses an identity indexing map for the
producer's result. The dispatch region formation flow already has a pass that
does this for just a convolution op; make it apply to more general cases.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
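
To illustrate the intent, here is a minimal sketch of the kind of IR the generalized pattern now targets (not part of the patch; the matmul producer, the static shapes, and the function name @interchange_example are illustrative assumptions): a named-op producer whose single elementwise consumer reads the result through a non-identity (transposed) indexing map.

util.func @interchange_example(%a : tensor<4x8xf32>, %b : tensor<8x16xf32>,
    %acc : tensor<4x16xf32>, %init : tensor<16x4xf32>) -> tensor<16x4xf32> {
  // Named-op producer with a single use.
  %0 = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
      outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>
  // Elementwise consumer that accesses %0 through a transposed map.
  %1 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%0 : tensor<4x16xf32>) outs(%init : tensor<16x4xf32>) {
    ^bb0(%in: f32, %out: f32):
      %neg = arith.negf %in : f32
      linalg.yield %neg : f32
  } -> tensor<16x4xf32>
  util.return %1 : tensor<16x4xf32>
}

The expected effect of the pass is to interchange the consumer's loops so that the producer operand is accessed through the identity map and the permutation moves to the output map, i.e. roughly indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], matching what the new test below checks for the quantized convolution case.
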
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeTransposeGenericOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeTransposeGenericOps.cpp
index ba33194..7a03bc8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeTransposeGenericOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeTransposeGenericOps.cpp
@@ -39,8 +39,11 @@
std::optional<AffineMap> mapForInterchange;
for (auto operand : genericOp.getDpsInputOperands()) {
- auto producer = operand->get().getDefiningOp<linalg::Conv2DNhwcHwcfOp>();
- if (!producer || !llvm::hasSingleElement(producer->getUsers()))
+ // Check that the producer is a named op or a reduction op (i.e., not an
+ // elementwise op) with a single use.
+ auto producer = operand->get().getDefiningOp<linalg::LinalgOp>();
+ if (!producer || !llvm::hasSingleElement(producer->getUsers()) ||
+ linalg::isElementwise(producer))
continue;
// check if the generic op has a non-identity map for the operand.
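
As a sketch of the case the new isElementwise guard excludes (the shapes and names below are assumptions, not from the patch): when the producer is itself an elementwise linalg.generic, the pattern skips that operand and leaves the consumer's transposed access alone instead of interchanging its loops.

util.func @skipped_elementwise_producer(%in : tensor<4x16xf32>,
    %tmp : tensor<4x16xf32>, %out : tensor<16x4xf32>) -> tensor<16x4xf32> {
  // Elementwise producer: the generalized pattern skips this operand.
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<4x16xf32>) outs(%tmp : tensor<4x16xf32>) {
    ^bb0(%a: f32, %b: f32):
      %abs = math.absf %a : f32
      linalg.yield %abs : f32
  } -> tensor<4x16xf32>
  // Elementwise consumer with a transposed access map; left unchanged
  // because the producer is neither a named op nor a reduction.
  %1 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%0 : tensor<4x16xf32>) outs(%out : tensor<16x4xf32>) {
    ^bb0(%a: f32, %b: f32):
      %neg = arith.negf %a : f32
      linalg.yield %neg : f32
  } -> tensor<16x4xf32>
  util.return %1 : tensor<16x4xf32>
}
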
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/interchange_transpose_generic_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/interchange_transpose_generic_ops.mlir
index cc9fe7a..9da0a96 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/interchange_transpose_generic_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/interchange_transpose_generic_ops.mlir
@@ -20,9 +20,38 @@
} -> tensor<2x320x128x128xf16>
util.return %truncf : tensor<2x320x128x128xf16>
}
-// CHECK-LABEL: func public @supported_conv
+// CHECK-LABEL: func public @supported_conv(
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>]
// CHECK-SAME: ins(%[[CONV]] :
// CHECK: return %[[GENERIC]]
+
+// -----
+
+util.func @generalize_to_any_linalg_op(%arg0 : tensor<?x?x?x?xi8>, %arg1 : tensor<?x?x?x?xi8>,
+ %arg2 : tensor<?x?x?x?xi64>, %arg3 : tensor<?x?x?x?xi64>, %arg4 : tensor<?x?x?x?xi8>) -> tensor<?x?x?x?xi8> {
+ %c0_i64 = arith.constant 0 : i64
+ %0 = linalg.conv_2d_nhwc_hwcf_q {
+ dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ ins(%arg0, %arg1, %c0_i64, %c0_i64 : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i64, i64)
+ outs(%arg2 : tensor<?x?x?x?xi64>) -> tensor<?x?x?x?xi64>
+ %2 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<?x?x?x?xi64>) outs(%arg4 : tensor<?x?x?x?xi8>) {
+ ^bb0(%in: i64, %out: i8):
+ %3 = arith.trunci %in : i64 to i32
+ %4 = arith.sitofp %3 : i32 to f32
+ %5 = arith.fptosi %4 : f32 to i8
+ linalg.yield %5 : i8
+ } -> tensor<?x?x?x?xi8>
+ util.return %2 : tensor<?x?x?x?xi8>
+}
+// CHECK-LABEL: func public @generalize_to_any_linalg_op(
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>]
+// CHECK: return %[[RESULT]]