[Flow] Make the output indexing_map of elementwise ops identity. (#17583)
Continues https://github.com/iree-org/iree/pull/17262; this change just moves the
logic into fusion preprocessing.
---------
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Co-authored-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
index bd5bf56..68f6d7f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionPreprocessing.cpp
@@ -70,6 +70,31 @@
};
//===----------------------------------------------------------------------===//
+// ElementwiseOpInterchangePattern
+//===----------------------------------------------------------------------===//
+
+/// Interchanges the loops of a single-result elementwise `linalg.generic` so
+/// that the indexing map of its output becomes the identity map.
+struct ElementwiseOpInterchangePattern
+    : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (genericOp.getNumResults() != 1 || !linalg::isElementwise(genericOp))
+      return failure();
+    // The last indexing map is treated as the map of the single output.
+    AffineMap outputMap = genericOp.getIndexingMapsArray().back();
+    if (outputMap.isIdentity())
+      return failure();
+    // The dim positions of the output map results form the interchange vector.
+    llvm::SmallVector<unsigned> interchange;
+    for (AffineExpr resultExpr : outputMap.getResults())
+      interchange.push_back(cast<AffineDimExpr>(resultExpr).getPosition());
+    return linalg::interchangeGenericOp(rewriter, genericOp, interchange);
+  }
+};
+
+//===----------------------------------------------------------------------===//
// FoldSuccessiveTensorInsertSliceOps
//===----------------------------------------------------------------------===//
@@ -209,7 +234,8 @@
FusionPreprocessingPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- patterns.add<FoldSuccessiveTensorInsertSliceOps,
+ patterns.add<ElementwiseOpInterchangePattern,
+ FoldSuccessiveTensorInsertSliceOps,
GenericOpInterchangePattern, GatherFusionPattern>(
&getContext());
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir
index f0aa650..04713f0 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_preprocessing.mlir
@@ -138,3 +138,23 @@
// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32
// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
// CHECK-NEXT: linalg.yield %[[RES4]] : f32
+
+// -----
+// Elementwise generic with a permuted output map; the expected output below has an identity output map.
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
+util.func @output_transpose_map(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> {
+ %0 = tensor.empty() : tensor<2x320x128x128xf16>
+ %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) {
+ ^bb0(%in: f32, %out: f16):
+ %2 = arith.truncf %in : f32 to f16
+ linalg.yield %2 : f16
+ } -> tensor<2x320x128x128xf16>
+ util.return %1 : tensor<2x320x128x128xf16>
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: util.func public @output_transpose_map
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]