Drop the tensor.pack/unpack -> LinalgExt lowering from transform dialect (#12401)
The implementation of the pack/unpack ops has been fully upstreamed and is plumbed
through the IREE software stack, and the LinalgExt version is going to be
deprecated. This commit switches the transform dialect usage over to the
upstream version.
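
For reference, the upstream ops carry the same packing information as their
LinalgExt counterparts. A minimal sketch of the op forms the dispatches below
are now expected to contain, assuming the tile sizes from the packing schedule
in the updated test (8x32 for the LHS, 8x16 for the result); SSA names are
illustrative, not taken from compiler output:

    %cst = arith.constant 0.0 : f32
    // Pad-and-pack the LHS into 8x32 inner tiles.
    %pack = tensor.pack %lhs padding_value(%cst : f32)
        inner_dims_pos = [0, 1] inner_tiles = [8, 32]
        into %lhs_dest : tensor<1234x567xf32> -> tensor<155x18x8x32xf32>
    // Unpack the packed result back into the original row-major layout.
    %unpack = tensor.unpack %packed_res
        inner_dims_pos = [0, 1] inner_tiles = [8, 16]
        into %res_dest : tensor<155x56x8x16xf32> -> tensor<1234x890xf32>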
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 3ffb549..65e2bcc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -146,7 +146,6 @@
getLowerTransferOpPermutationsAttrName)
ADD_PATTERN(rankReducingLinalg, getRankReducingLinalgAttrName)
ADD_PATTERN(rankReducingVector, getRankReducingVectorAttrName)
- ADD_PATTERN(rewritePackOps, getRewritePackOpsAttrName)
ADD_PATTERN(swapPaddingElideConditional,
getSwapPaddingElideConditionalAttrName)
ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName)
@@ -195,32 +194,6 @@
return success();
}
};
-
-/// Trivial 1-1 pattern to retire once IREE adopts tensor.pack.
-struct TensorPackToLinalgExt : public OpRewritePattern<tensor::PackOp> {
- using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(tensor::PackOp packOp,
- PatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<LinalgExt::PackOp>(
- packOp, packOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
- packOp.getMixedTiles(), packOp.getPaddingValue(),
- packOp.getOuterDimsPerm());
- return success();
- }
-};
-
-/// Trivial 1-1 pattern to retire once IREE adopts tensor.unpack.
-struct TensorUnPackToLinalgExt : public OpRewritePattern<tensor::UnPackOp> {
- using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
- PatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<LinalgExt::UnPackOp>(
- unPackOp, unPackOp.getSource(), unPackOp.getDest(),
- unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
- unPackOp.getOuterDimsPerm());
- return success();
- }
-};
} // namespace
static void addLowerTransferOpPermutationsPatterns(
@@ -254,11 +227,6 @@
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
}
-static void addRewritePackOpsPatterns(RewritePatternSet &patterns) {
- patterns.add<TensorPackToLinalgExt, TensorUnPackToLinalgExt>(
- patterns.getContext());
-}
-
static void addSwappingPatterns(RewritePatternSet &patterns,
bool swapPaddingElideCornerCase) {
patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
@@ -338,7 +306,6 @@
if (getFoldTensorEmptyExtract()) addFoldTensorEmptyExtract(patterns);
if (getRankReducingLinalg()) addRankReducingLinalgPatterns(patterns);
if (getRankReducingVector()) addRankReducingVectorPatterns(patterns);
- if (getRewritePackOps()) addRewritePackOpsPatterns(patterns);
if (getSwappingPatterns())
addSwappingPatterns(patterns, getSwapPaddingElideConditional());
if (getAdditionalIreePatterns()) addAdditionalIreePatterns(patterns);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
index 6ef6aa7..c0cefc1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
@@ -47,7 +47,6 @@
bool promoteForallCaptureToShared = false;
bool rankReducingLinalg = false;
bool rankReducingVector = false;
- bool rewritePackOps = false;
bool swapPaddingElideConditional = false;
bool swappingPatterns = false;
bool unrollVectorsGpuMmaSync = false;
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 2e9de43..60da8db 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -95,8 +95,6 @@
behavior on subset-based linalg operations.
- rank_reducing_vector: adds patterns that results in rank-reducing
behavior on subset-based vector operations.
- - rewrite_pack_ops: rewrite tensor.pack/unpack to linalg_ext.pack/unpack.
- This is a temporary pattern that is needed to connect to IREE until it
- adopts the upstream version.
- swapping_patterns: adds patterns that swap operations for a better outcome.
This is a catch all that can be refined further if/when needed.
@@ -137,7 +135,6 @@
UnitAttr:$lower_transfer_op_permutations,
UnitAttr:$rank_reducing_linalg,
UnitAttr:$rank_reducing_vector,
- UnitAttr:$rewrite_pack_ops,
UnitAttr:$swap_padding_elide_conditional,
UnitAttr:$swapping_patterns,
UnitAttr:$unroll_vectors_gpu_mma_sync,
diff --git a/tests/transform_dialect/cpu/contraction-packing-and-dispatch.mlir b/tests/transform_dialect/cpu/contraction-packing-and-dispatch.mlir
index 6c37451..6e08a3f 100644
--- a/tests/transform_dialect/cpu/contraction-packing-and-dispatch.mlir
+++ b/tests/transform_dialect/cpu/contraction-packing-and-dispatch.mlir
@@ -14,7 +14,7 @@
// 'memref.alloca' op all stack allocations need to be hoisted to the entry block of the function
//
// R-UN: iree-opt %s --iree-transform-dialect-interpreter --transform-dialect-drop-schedule | \
-// R-UN: iree-compile --iree-hal-target-backends=llvm-cpu
+// R-UN: iree-compile --iree-hal-target-backends=llvm-cpu
!a_tensor_t = tensor<1234x567xf32>
!b_tensor_t = tensor<567x890xf32>
@@ -27,25 +27,25 @@
// CHECK-DAG: #[[$map_rhs:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d1, d3, d5)>
// CHECK-DAG: #[[$map_res:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @_matmul_dispatch_0
+// CHECK-LABEL: func.func @matmul_dispatch_0
// CHECK: tensor.empty() : tensor<155x18x8x32xf32>
-// CHECK: iree_linalg_ext.pack
+// CHECK: tensor.pack
-// CHECK-LABEL: func.func @_matmul_dispatch_1
+// CHECK-LABEL: func.func @matmul_dispatch_1
// CHECK: arith.constant dense<1.000000e-01> : tensor<567x890xf32>
// CHECK: tensor.empty() : tensor<18x56x16x32xf32>
-// CHECK: iree_linalg_ext.pack
+// CHECK: tensor.pack
-// CHECK-LABEL: func.func @_matmul_dispatch_2
+// CHECK-LABEL: func.func @matmul_dispatch_2
// CHECK: tensor.empty() : tensor<155x56x8x16xf32>
-// CHECK: iree_linalg_ext.pack
+// CHECK: tensor.pack
-// CHECK-LABEL: func.func @_matmul_dispatch_3
+// CHECK-LABEL: func.func @matmul_dispatch_3
func.func @matmul(%arg0: !a_tensor_t, %arg2: !c_tensor_t) -> !c_tensor_t {
%c0 = arith.constant dense<0.1> : !b_tensor_t
// CHECK-NOT: pack
- // CHECK: linalg.generic {indexing_maps = [#[[$map_lhs]], #[[$map_rhs]], #[[$map_res]]],
- // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+ // CHECK: linalg.generic {indexing_maps = [#[[$map_lhs]], #[[$map_rhs]], #[[$map_res]]],
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.*}} : tensor<155x18x8x32xf32>, tensor<18x56x16x32xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<155x56x8x16xf32>)
@@ -55,8 +55,8 @@
return %0 : !c_tensor_t
}
-// CHECK-LABEL: func.func @_matmul_dispatch_4
-// CHECK: iree_linalg_ext.unpack
+// CHECK-LABEL: func.func @matmul_dispatch_4
+// CHECK: tensor.unpack
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
@@ -66,9 +66,4 @@
transform.structured.pack_greedily %matmul
gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [0, 1, 2]
: (!pdl.operation) -> !transform.op<"linalg.generic">
-
- // TODO: Remove once IREE adopts tensor.pack/unpack.
- %func = transform.structured.match ops{["func.func"]} in %module_op
- : (!pdl.operation) -> (!pdl.operation)
- transform.iree.apply_patterns %func { rewrite_pack_ops }
}
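
With the interim lowering gone, the schedule in this test reduces to the greedy
packing alone. A sketch of the resulting transform sequence, reconstructed from
the context lines above (the linalg.matmul match line is outside the hunk shown
and is assumed here):

    transform.sequence failures(propagate) {
    ^bb1(%module_op: !pdl.operation):
      // Assumed: match the matmul to feed into the packing transform.
      %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op
        : (!pdl.operation) -> (!pdl.operation)
      // Greedily pack the contraction; dispatches then carry tensor.pack/unpack.
      transform.structured.pack_greedily %matmul
          gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [0, 1, 2]
        : (!pdl.operation) -> !transform.op<"linalg.generic">
    }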