Add an example of generalized packing (#12076)
This revision demonstrates how the generalized packing transformation provides
a one-size-fits-all implementation
for tiling linalg ops of rank N into linalg ops of higher rank.
The tensor.pack / tensor.unpack representation allows us to complete our
panoply of composable linalg transformations.
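As a reminder, tensor.pack expresses the relayout of a tensor into a tiled
form. A minimal sketch with hypothetical, evenly divisible sizes (%src and
%dest are illustrative names):

  // Relayout a 128x256 tensor into 8x16 tiles: the two outer dimensions
  // index the tiles, the two inner dimensions index within a tile.
  %dest = tensor.empty() : tensor<16x16x8x16xf32>
  %packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
      into %dest : tensor<128x256xf32> -> tensor<16x16x8x16xf32>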
Now, one can do one of the following:
1. lower a linalg op to loops and use classical loop-based tiling
techniques (e.g. Allen & Kennedy, polyhedral/affine, etc.).
2. tile an N-D linalg op to N-D loops surrounding an N-D linalg op. This
often preserves the name of the linalg op and forms the basis of the
TilingInterface.
3. (this PR) tile an N-D linalg op to a 2*N-D linalg.generic. This step
requires that the tile sizes divide the problem sizes; tensor.pack /
tensor.unpack provide this guarantee (padding when needed).
Step 3 can easily be adapted to produce a new named op (e.g. mmt4d) when
relevant; the point of this PR is to demonstrate generality. A sketch of the
produced IR follows.
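For concreteness, here is a minimal sketch of the IR that
transform.structured.pack_greedily produces for a linalg.matmul packed with
sizes (8, 16, 32). The shapes, indexing maps and iterator types are taken
from the new contraction-packing.mlir test below; the value names and the
exact padding value are illustrative:

  %cst = arith.constant 0.000000e+00 : f32
  %lhs_empty = tensor.empty() : tensor<155x18x8x32xf32>
  %lhs_packed = tensor.pack %lhs padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [8, 32]
      into %lhs_empty : tensor<1234x567xf32> -> tensor<155x18x8x32xf32>
  %rhs_empty = tensor.empty() : tensor<18x56x16x32xf32>
  %rhs_packed = tensor.pack %rhs padding_value(%cst : f32)
      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
      into %rhs_empty : tensor<567x890xf32> -> tensor<18x56x16x32xf32>
  %acc_empty = tensor.empty() : tensor<155x56x8x16xf32>
  %acc_packed = tensor.pack %acc padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [8, 16]
      into %acc_empty : tensor<1234x890xf32> -> tensor<155x56x8x16xf32>
  // 2*3-D contraction: loops (M, N, K, m, n, k) with (m, n, k) = (8, 16, 32).
  %packed_matmul = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>,
        affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>,
        affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>],
      iterator_types = ["parallel", "parallel", "reduction",
                        "parallel", "parallel", "reduction"]}
      ins(%lhs_packed, %rhs_packed
          : tensor<155x18x8x32xf32>, tensor<18x56x16x32xf32>)
      outs(%acc_packed : tensor<155x56x8x16xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %m = arith.mulf %a, %b : f32
      %s = arith.addf %c, %m : f32
      linalg.yield %s : f32
  } -> tensor<155x56x8x16xf32>
  %res = tensor.unpack %packed_matmul
      inner_dims_pos = [0, 1] inner_tiles = [8, 16]
      into %acc : tensor<155x56x8x16xf32> -> tensor<1234x890xf32>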
This is related to discussion #12075.
An additional pattern converts tensor.pack/unpack to
linalg_ext.pack/unpack until IREE adopts the upstream variants.
With it, dispatch regions can be formed without lowering failures.
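The bridge is exercised from the transform script like any other pattern set;
the new contraction-packing-and-dispatch.mlir test below invokes it as:

  %func = transform.structured.match ops{["func.func"]} in %module_op
    : (!pdl.operation) -> (!pdl.operation)
  transform.iree.apply_patterns %func { rewrite_pack_ops }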
One thing to note in the IREE pass pipeline is that the
InterchangeGenericOps pass breaks the normalization property of the packing.
The normalized form can be recovered after the fact, but it would be better
to disable the pass in such cases if possible.
At this time, `iree-compile` fails on the `iree_linalg_ext` ops because it
requires statically allocated buffers to be hoisted to the entry block of
the function.
Both of these issues are left for a follow-up investigation.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index a694d3b..b2f41d1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -65,11 +65,16 @@
#define ADD_PATTERN(NAME, ATTR) \
if (patterns.NAME) \
result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr);
+ ///
+ /// When touching something here, do not forget to update CommonExtensions.h.
+ ///
ADD_PATTERN(additionalIreePatterns, getAdditionalIreePatternsAttrName)
ADD_PATTERN(bubbleCollapseExpand, getBubbleCollapseExpandAttrName)
ADD_PATTERN(canonicalization, getCanonicalizationAttrName)
ADD_PATTERN(eraseUnnecessaryTensorOperands,
getEraseUnnecessaryTensorOperandsAttrName)
+ ADD_PATTERN(expandMemrefStridedMetadata,
+ getExpandMemrefStridedMetadataAttrName)
ADD_PATTERN(foldMemrefAliases, getFoldMemrefAliasesAttrName)
ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName)
ADD_PATTERN(foldTensorEmptyExtract, getFoldTensorEmptyExtractAttrName)
@@ -77,8 +82,7 @@
getLowerTransferOpPermutationsAttrName)
ADD_PATTERN(rankReducingLinalg, getRankReducingLinalgAttrName)
ADD_PATTERN(rankReducingVector, getRankReducingVectorAttrName)
- ADD_PATTERN(expandMemrefStridedMetadata,
- getExpandMemrefStridedMetadataAttrName)
+ ADD_PATTERN(rewritePackOps, getRewritePackOpsAttrName)
ADD_PATTERN(swapPaddingElideConditional,
getSwapPaddingElideConditionalAttrName)
ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName)
@@ -125,6 +129,32 @@
return success();
}
};
+
+/// Trivial 1-1 pattern to retire once IREE adopts tensor.pack.
+struct TensorPackToLinalgExt : public OpRewritePattern<tensor::PackOp> {
+ using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(tensor::PackOp packOp,
+ PatternRewriter &rewriter) const final {
+ rewriter.replaceOpWithNewOp<LinalgExt::PackOp>(
+ packOp, packOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
+ packOp.getMixedTiles(), packOp.getPaddingValue(),
+ packOp.getOuterDimsPerm());
+ return success();
+ }
+};
+
+/// Trivial 1-1 pattern to retire once IREE adopts tensor.unpack.
+struct TensorUnPackToLinalgExt : public OpRewritePattern<tensor::UnPackOp> {
+ using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
+ PatternRewriter &rewriter) const final {
+ rewriter.replaceOpWithNewOp<LinalgExt::UnPackOp>(
+ unPackOp, unPackOp.getSource(), unPackOp.getDest(),
+ unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+ unPackOp.getOuterDimsPerm());
+ return success();
+ }
+};
} // namespace
static void addLowerTransferOpPermutationsPatterns(
@@ -158,6 +188,11 @@
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
}
+static void addRewritePackOpsPatterns(RewritePatternSet &patterns) {
+ patterns.add<TensorPackToLinalgExt, TensorUnPackToLinalgExt>(
+ patterns.getContext());
+}
+
static void addSwappingPatterns(RewritePatternSet &patterns,
bool swapPaddingElideCornerCase) {
patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
@@ -196,13 +231,14 @@
addLowerTransferOpPermutationsPatterns(patterns);
if (getEraseUnnecessaryTensorOperands())
addEraseUnnecessaryTensorOperandsPatterns(patterns);
+ if (getExpandMemrefStridedMetadata())
+ memref::populateExpandStridedMetadataPatterns(patterns);
if (getFoldMemrefAliases()) addFoldMemrefAliasPatterns(patterns);
if (getFoldReassociativeReshapes()) addReassociativeReshapePatterns(patterns);
if (getFoldTensorEmptyExtract()) addFoldTensorEmptyExtract(patterns);
if (getRankReducingLinalg()) addRankReducingLinalgPatterns(patterns);
if (getRankReducingVector()) addRankReducingVectorPatterns(patterns);
- if (getExpandMemrefStridedMetadata())
- memref::populateExpandStridedMetadataPatterns(patterns);
+ if (getRewritePackOps()) addRewritePackOpsPatterns(patterns);
if (getSwappingPatterns())
addSwappingPatterns(patterns, getSwapPaddingElideConditional());
if (getAdditionalIreePatterns()) addAdditionalIreePatterns(patterns);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
index e59d825..87cea4f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
@@ -39,6 +39,7 @@
bool bubbleCollapseExpand = false;
bool canonicalization = false;
bool eraseUnnecessaryTensorOperands = false;
+ bool expandMemrefStridedMetadata = false;
bool foldMemrefAliases = false;
bool foldReassociativeReshapes = false;
bool foldTensorEmptyExtract = false;
@@ -46,7 +47,7 @@
bool promoteForeachThreadCaptureToShared = false;
bool rankReducingLinalg = false;
bool rankReducingVector = false;
- bool expandMemrefStridedMetadata = false;
+ bool rewritePackOps = false;
bool swapPaddingElideConditional = false;
bool swappingPatterns = false;
};
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index fbee078..1dcd352 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -44,6 +44,9 @@
registered dialects and ops.
- erase_unnecessary_tensor_operands: add patterns that erase unnecessary
tensor operands.
+ - expand_memref_strided_metadata: adds patterns that expand memref
+ operations into extract_strided_metadata operations and a materialization
+ of their effect on the metadata (sizes, offset, strides).
- fold_memref_aliases: adds patterns for folding ops such as
memref.subview.
- fold_reassociative_reshapes: adds patterns that fold insert_slice/
@@ -56,9 +59,9 @@
behavior on subset-based linalg operations.
- rank_reducing_vector: adds patterns that results in rank-reducing
behavior on subset-based vector operations.
- - expand_memref_strided_metadata: adds patterns that expand memref
- operations into extract_strided_metadata operations and a materialization
- of their effect on the metadata (sizes, offset, strides).
+ - rewrite_pack_ops: rewrite tensor.pack/unpack to linalg_ext.pack/unpack.
+ This is a temporary pattern that is needed to connect to IREE until it
+ adopts the upstream version.
- swapping_patterns: adds patterns that swap operations for a better outcome.
This is a catch all that can be refined further if/when needed.
- swap_padding_elide_conditional: refines the tensor.pad +
@@ -86,13 +89,14 @@
UnitAttr:$bubble_collapse_expand,
UnitAttr:$canonicalization,
UnitAttr:$erase_unnecessary_tensor_operands,
+ UnitAttr:$expand_memref_strided_metadata,
UnitAttr:$fold_memref_aliases,
UnitAttr:$fold_reassociative_reshapes,
UnitAttr:$fold_tensor_empty_extract,
UnitAttr:$lower_transfer_op_permutations,
UnitAttr:$rank_reducing_linalg,
UnitAttr:$rank_reducing_vector,
- UnitAttr:$expand_memref_strided_metadata,
+ UnitAttr:$rewrite_pack_ops,
UnitAttr:$swap_padding_elide_conditional,
UnitAttr:$swapping_patterns);
let results = (outs PDL_Operation:$result);
diff --git a/tests/transform_dialect/cpu/BUILD b/tests/transform_dialect/cpu/BUILD
index 3527606..5f996ac 100644
--- a/tests/transform_dialect/cpu/BUILD
+++ b/tests/transform_dialect/cpu/BUILD
@@ -16,6 +16,8 @@
iree_lit_test_suite(
name = "lit",
srcs = [
+ "contraction-packing.mlir",
+ "contraction-packing-and-dispatch.mlir",
"eltwise_reduction_eltwise.mlir",
"matmul.mlir",
],
diff --git a/tests/transform_dialect/cpu/CMakeLists.txt b/tests/transform_dialect/cpu/CMakeLists.txt
index 4f21549..1de40c1 100644
--- a/tests/transform_dialect/cpu/CMakeLists.txt
+++ b/tests/transform_dialect/cpu/CMakeLists.txt
@@ -14,6 +14,8 @@
NAME
lit
SRCS
+ "contraction-packing-and-dispatch.mlir"
+ "contraction-packing.mlir"
"eltwise_reduction_eltwise.mlir"
"matmul.mlir"
TOOLS
diff --git a/tests/transform_dialect/cpu/contraction-packing-and-dispatch.mlir b/tests/transform_dialect/cpu/contraction-packing-and-dispatch.mlir
new file mode 100644
index 0000000..6c37451
--- /dev/null
+++ b/tests/transform_dialect/cpu/contraction-packing-and-dispatch.mlir
@@ -0,0 +1,74 @@
+
+// Preprocessing with generalized packing.
+//
+// RUN: iree-opt %s --iree-transform-dialect-interpreter --transform-dialect-drop-schedule | \
+// RUN: iree-opt --iree-hal-target-backends=llvm-cpu \
+// RUN: --iree-abi-transformation-pipeline \
+// RUN: --iree-flow-transformation-pipeline \
+// RUN: --iree-stream-transformation-pipeline \
+// RUN: --iree-hal-configuration-pipeline | \
+// RUN: FileCheck %s
+
+// Check that compilation runs all the way to the end.
+// TODO: this currently fails with:
+// 'memref.alloca' op all stack allocations need to be hoisted to the entry block of the function
+//
+// R-UN: iree-opt %s --iree-transform-dialect-interpreter --transform-dialect-drop-schedule | \
+// R-UN: iree-compile --iree-hal-target-backends=llvm-cpu
+
+!a_tensor_t = tensor<1234x567xf32>
+!b_tensor_t = tensor<567x890xf32>
+!c_tensor_t = tensor<1234x890xf32>
+
+// Note: the normalization in these maps is gone due to InterchangeGenericOps.
+// When using generalized packing, it would be better to drop that pass.
+
+// CHECK-DAG: #[[$map_lhs:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d2, d5)>
+// CHECK-DAG: #[[$map_rhs:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d1, d3, d5)>
+// CHECK-DAG: #[[$map_res:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: func.func @_matmul_dispatch_0
+// CHECK: tensor.empty() : tensor<155x18x8x32xf32>
+// CHECK: iree_linalg_ext.pack
+
+// CHECK-LABEL: func.func @_matmul_dispatch_1
+// CHECK: arith.constant dense<1.000000e-01> : tensor<567x890xf32>
+// CHECK: tensor.empty() : tensor<18x56x16x32xf32>
+// CHECK: iree_linalg_ext.pack
+
+// CHECK-LABEL: func.func @_matmul_dispatch_2
+// CHECK: tensor.empty() : tensor<155x56x8x16xf32>
+// CHECK: iree_linalg_ext.pack
+
+// CHECK-LABEL: func.func @_matmul_dispatch_3
+func.func @matmul(%arg0: !a_tensor_t, %arg2: !c_tensor_t) -> !c_tensor_t {
+ %c0 = arith.constant dense<0.1> : !b_tensor_t
+ // CHECK-NOT: pack
+ // CHECK: linalg.generic {indexing_maps = [#[[$map_lhs]], #[[$map_rhs]], #[[$map_res]]],
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<155x18x8x32xf32>, tensor<18x56x16x32xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<155x56x8x16xf32>)
+
+ %0 = linalg.matmul
+ ins(%arg0, %c0: !a_tensor_t, !b_tensor_t)
+ outs(%arg2: !c_tensor_t) -> !c_tensor_t
+ return %0 : !c_tensor_t
+}
+
+// CHECK-LABEL: func.func @_matmul_dispatch_4
+// CHECK: iree_linalg_ext.unpack
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %matmul = transform.structured.match interface{LinalgOp} in %module_op
+ : (!pdl.operation) -> (!pdl.operation)
+
+ transform.structured.pack_greedily %matmul
+ gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [0, 1, 2]
+ : (!pdl.operation) -> !transform.op<"linalg.generic">
+
+ // TODO: Remove once IREE adopts tensor.pack/unpack.
+ %func = transform.structured.match ops{["func.func"]} in %module_op
+ : (!pdl.operation) -> (!pdl.operation)
+ transform.iree.apply_patterns %func { rewrite_pack_ops }
+}
diff --git a/tests/transform_dialect/cpu/contraction-packing.mlir b/tests/transform_dialect/cpu/contraction-packing.mlir
new file mode 100644
index 0000000..828f2c6
--- /dev/null
+++ b/tests/transform_dialect/cpu/contraction-packing.mlir
@@ -0,0 +1,151 @@
+
+// Preprocessing with generalized packing.
+//
+// RUN: iree-opt %s --iree-transform-dialect-interpreter --transform-dialect-drop-schedule | \
+// RUN: FileCheck %s
+
+!a_tensor_t = tensor<1234x567xf32>
+!at_tensor_t = tensor<567x1234xf32>
+!b_tensor_t = tensor<567x890xf32>
+!bt_tensor_t = tensor<890x567xf32>
+!c_tensor_t = tensor<1234x890xf32>
+!ct_tensor_t = tensor<890x1234xf32>
+
+// CHECK-DAG: #[[$map_lhs:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$map_rhs:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+// CHECK-DAG: #[[$map_res:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// CHECK-DAG: #[[$map_tlhs:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK-DAG: #[[$map_trhs:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$map_tres:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+
+// CHECK-LABEL: func.func @matmul_nnn
+func.func @matmul_nnn(%arg0: !a_tensor_t, %arg2: !c_tensor_t) -> !c_tensor_t {
+ %c0 = arith.constant dense<0.1> : !b_tensor_t
+
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 16]
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$map_lhs]], #[[$map_rhs]], #[[$map_res]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<155x18x8x32xf32>, tensor<18x56x16x32xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<155x56x8x16xf32>)
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 16]
+ %0 = linalg.matmul
+ ins(%arg0, %c0: !a_tensor_t, !b_tensor_t)
+ outs(%arg2: !c_tensor_t) -> !c_tensor_t
+ return %0 : !c_tensor_t
+}
+
+#matmul_tnn_trait = {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (k, m)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func.func @matmul_tnn
+func.func @matmul_tnn(%arg0: !at_tensor_t, %arg2: !c_tensor_t) -> !c_tensor_t {
+ %c0 = arith.constant dense<0.1> : !b_tensor_t
+
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [8, 32]
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 16]
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$map_tlhs]], #[[$map_rhs]], #[[$map_res]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<18x155x8x32xf32>, tensor<18x56x16x32xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<155x56x8x16xf32>)
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 16]
+ %0 = linalg.generic #matmul_tnn_trait
+ ins(%arg0, %c0: !at_tensor_t, !b_tensor_t)
+ outs(%arg2: !c_tensor_t) {
+ ^bb(%a: f32, %b: f32, %c: f32) :
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ } -> !c_tensor_t
+ return %0 : !c_tensor_t
+}
+
+#matmul_ntn_trait = {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func.func @matmul_ntn
+func.func @matmul_ntn(%arg0: !a_tensor_t, %arg2: !c_tensor_t) -> !c_tensor_t {
+ %c0 = arith.constant dense<0.1> : !bt_tensor_t
+
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [16, 32]
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 16]
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$map_lhs]], #[[$map_trhs]], #[[$map_res]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<155x18x8x32xf32>, tensor<56x18x16x32xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<155x56x8x16xf32>)
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 16]
+ %0 = linalg.generic #matmul_ntn_trait
+ ins(%arg0, %c0: !a_tensor_t, !bt_tensor_t)
+ outs(%arg2: !c_tensor_t) {
+ ^bb(%a: f32, %b: f32, %c: f32) :
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ } -> !c_tensor_t
+ return %0 : !c_tensor_t
+}
+
+#matmul_nnt_trait = {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func.func @matmul_nnt
+func.func @matmul_nnt(%arg0: !a_tensor_t, %arg2: !ct_tensor_t) -> !ct_tensor_t {
+ %c0 = arith.constant dense<0.1> : !b_tensor_t
+
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+ // CHECK: tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [8, 16]
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$map_lhs]], #[[$map_rhs]], #[[$map_tres]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+ // CHECK-SAME: ins(%{{.*}} : tensor<155x18x8x32xf32>, tensor<18x56x16x32xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<56x155x8x16xf32>)
+ // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [8, 16]
+ %0 = linalg.generic #matmul_nnt_trait
+ ins(%arg0, %c0: !a_tensor_t, !b_tensor_t)
+ outs(%arg2: !ct_tensor_t) {
+ ^bb(%a: f32, %b: f32, %c: f32) :
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ } -> !ct_tensor_t
+ return %0 : !ct_tensor_t
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %matmul = transform.structured.match interface{LinalgOp} in %module_op
+ : (!pdl.operation) -> (!pdl.operation)
+
+ // Generalized packing rewrite extracts a gemm from any linalg op that contains
+ // one. This acts as a powerful normalization step: after this point, we have a
+ // gemm (i.e. 3-D contraction with (m,n,k)=(8,16,32) ) on the 3 most minor
+ // dimensions.
+ transform.structured.pack_greedily %matmul
+ gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [0, 1, 2]
+ : (!pdl.operation) -> !transform.op<"linalg.generic">
+}