Use upstreamed ApplyPatternsOp in buildCanonicalizationAndEnablingTransforms (#14026)
Canonicalization, CSE, and LICM are still applied via the old
ApplyPatternsOp. The `tiling_canonicalization` flag is replaced by the
upstream `transform.apply_patterns.linalg.tiling_canonicalization` and
`transform.apply_patterns.scf.for_loop_canonicalization` descriptors plus
the new `transform.apply_patterns.iree.fold_fill_into_pad` op; the
`fold_tensor_empty_extract` flag is superseded by the upstream
ApplyFoldTensorEmptyPatternsOp, and `additional_iree_patterns` is dropped.
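
For example (a sketch mirroring the test updates below; `%func` stands in
for whichever handle a given spec matches), a spec that previously wrote

  transform.iree.apply_patterns %func
    { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()

now spells the tiling-related patterns out through the upstream op and
keeps the remaining flags on the old one:

  transform.apply_patterns to %func {
    transform.apply_patterns.iree.fold_fill_into_pad
    transform.apply_patterns.linalg.tiling_canonicalization
    transform.apply_patterns.scf.for_loop_canonicalization
  } : !transform.any_op
  transform.iree.apply_patterns %func
    { canonicalization, licm, cse } : (!transform.any_op) -> ()
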
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
index bc7f79a..39b11d2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
@@ -107,6 +107,7 @@
linalg::registerTransformDialectExtension(registry);
memref::registerTransformDialectExtension(registry);
scf::registerTransformDialectExtension(registry);
+ tensor::registerTransformDialectExtension(registry);
vector::registerTransformDialectExtension(registry);
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 86d935b..c3ac3b2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -104,6 +104,7 @@
//===---------------------------------------------------------------------===//
// ApplyBufferOptimizationsOp
//===---------------------------------------------------------------------===//
+
DiagnosedSilenceableFailure
transform_dialect::ApplyBufferOptimizationsOp::applyToOne(
Operation *target, transform::ApplyToEachResultList &results,
@@ -129,124 +130,10 @@
}
//===---------------------------------------------------------------------===//
-// ApplyPatternsOp
+// Apply...PatternsOp
//===---------------------------------------------------------------------===//
-void transform_dialect::ApplyPatternsOp::build(
- OpBuilder &builder, OperationState &result, Value target,
- const ApplyPatternsOpPatterns &patterns) {
- result.addOperands(target);
-
- auto unitAttr = builder.getUnitAttr();
-
-#define ADD_PATTERN(NAME, ATTR) \
- if (patterns.NAME) \
- result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr);
- ///
- /// When touching something here, do not forget to update CommonExtensions.h.
- ///
- ADD_PATTERN(additionalIreePatterns, getAdditionalIreePatternsAttrName)
- ADD_PATTERN(bubbleCollapse, getBubbleCollapseAttrName)
- ADD_PATTERN(bubbleExpand, getBubbleExpandAttrName)
- ADD_PATTERN(bubblePackUnPack, getBubblePackUnPackAttrName)
- ADD_PATTERN(canonicalization, getCanonicalizationAttrName)
- ADD_PATTERN(cse, getCseAttrName)
- ADD_PATTERN(eraseUnnecessaryTensorOperands,
- getEraseUnnecessaryTensorOperandsAttrName)
- ADD_PATTERN(expandMemrefStridedMetadata,
- getExpandMemrefStridedMetadataAttrName)
- ADD_PATTERN(extractAddressComputations, getExtractAddressComputationsAttrName)
- ADD_PATTERN(foldMemrefAliases, getFoldMemrefAliasesAttrName)
- ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName)
- ADD_PATTERN(foldTensorEmptyExtract, getFoldTensorEmptyExtractAttrName)
- ADD_PATTERN(foldTensorSubsets, getFoldTensorSubsetsAttrName)
- ADD_PATTERN(foldVectorTransferTensorSlice,
- getFoldVectorTransferTensorSliceAttrName)
- ADD_PATTERN(licm, getLicmAttrName)
- ADD_PATTERN(linalgElementwiseGreedyFusion,
- getLinalgElementwiseGreedyFusionAttrName)
- ADD_PATTERN(lowerTransferOpPermutations,
- getLowerTransferOpPermutationsAttrName)
- ADD_PATTERN(lowerVectorMasks, getLowerVectorMasksAttrName)
- ADD_PATTERN(prepareVectorToMma, getPrepareVectorToMmaAttrName)
- ADD_PATTERN(rankReducingLinalg, getRankReducingLinalgAttrName)
- ADD_PATTERN(rankReducingLinalgViaReshapes,
- getRankReducingLinalgViaReshapesAttrName)
- ADD_PATTERN(rankReducingVector, getRankReducingVectorAttrName)
- ADD_PATTERN(swapPaddingElideConditional,
- getSwapPaddingElideConditionalAttrName)
- ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName)
- ADD_PATTERN(tilingCanonicalization, getTilingCanonicalizationAttrName)
- ADD_PATTERN(unrollVectorsGpuMmaSync, getUnrollVectorsGpuMmaSyncAttrName)
- ADD_PATTERN(unrollVectorsGpuWmma, getUnrollVectorsGpuWmmaAttrName)
-#undef ADD_PATTERN
-}
-
-static void addOperands(Operation *op, SetVector<Value> &operandSet) {
- if (!op) return;
- TypeSwitch<Operation *, void>(op)
- .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
- SmallVector<Value> inputOperands{linalgOp.getDpsInputOperands()};
- operandSet.insert(inputOperands.begin(), inputOperands.end());
- })
- .Default([&](Operation *operation) {
- operandSet.insert(operation->operand_begin(), operation->operand_end());
- });
-}
-
-template <int limit = 3>
-static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
- Operation *producer = fusedOperand->get().getDefiningOp();
- if (!producer) return false;
- Operation *consumer = fusedOperand->getOwner();
- SetVector<Value> fusedOpOperands;
- if (producer->getNumResults() != 1) return false;
- addOperands(consumer, fusedOpOperands);
- fusedOpOperands.remove(producer->getResult(0));
- addOperands(producer, fusedOpOperands);
- return fusedOpOperands.size() <= limit;
-}
namespace {
-/// Rewrite a tensor.generate as an arith.constant when possible.
-struct GenerateToConstant : public OpRewritePattern<tensor::GenerateOp> {
- using OpRewritePattern<tensor::GenerateOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::GenerateOp generateOp,
- PatternRewriter &rewriter) const final {
- auto tensorType =
- llvm::cast<RankedTensorType>(generateOp.getResult().getType());
- if (!tensorType.hasStaticShape()) return failure();
- auto terminatorOp =
- cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
- if (terminatorOp->getNumOperands() > 1) return failure();
- auto constantOp =
- terminatorOp->getOperand(0).getDefiningOp<arith::ConstantOp>();
- if (!constantOp) return failure();
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- generateOp, tensorType,
- DenseElementsAttr::get(tensorType, constantOp.getValueAttr()));
- return success();
- }
-};
-
-/// Fold tensor.empty used by extract_slice if this the only use of
-/// extract_slice and the result is static.
-struct FoldTensorEmptyExtract
- : public OpRewritePattern<tensor::ExtractSliceOp> {
- using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
- PatternRewriter &rewriter) const final {
- auto tensorEmpty = extractOp.getSource().getDefiningOp<tensor::EmptyOp>();
- if (!tensorEmpty || !extractOp.getType().hasStaticShape() ||
- !tensorEmpty->hasOneUse())
- return failure();
- rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
- extractOp, extractOp.getType().getShape(),
- extractOp.getType().getElementType());
- return success();
- }
-};
-
/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into
/// `linalg.fill(cst, empty)` when the padding constant and the fill constant
/// are the same.
@@ -292,6 +179,87 @@
};
} // namespace
+void transform_dialect::ApplyFoldFillIntoPadPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ patterns.insert<FoldFillIntoPad>(patterns.getContext());
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyPatternsOp
+//===---------------------------------------------------------------------===//
+
+void transform_dialect::ApplyPatternsOp::build(
+ OpBuilder &builder, OperationState &result, Value target,
+ const ApplyPatternsOpPatterns &patterns) {
+ result.addOperands(target);
+
+ auto unitAttr = builder.getUnitAttr();
+
+#define ADD_PATTERN(NAME, ATTR) \
+ if (patterns.NAME) \
+ result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr);
+ ///
+ /// When touching something here, do not forget to update CommonExtensions.h.
+ ///
+ ADD_PATTERN(bubbleCollapse, getBubbleCollapseAttrName)
+ ADD_PATTERN(bubbleExpand, getBubbleExpandAttrName)
+ ADD_PATTERN(bubblePackUnPack, getBubblePackUnPackAttrName)
+ ADD_PATTERN(canonicalization, getCanonicalizationAttrName)
+ ADD_PATTERN(cse, getCseAttrName)
+ ADD_PATTERN(eraseUnnecessaryTensorOperands,
+ getEraseUnnecessaryTensorOperandsAttrName)
+ ADD_PATTERN(expandMemrefStridedMetadata,
+ getExpandMemrefStridedMetadataAttrName)
+ ADD_PATTERN(extractAddressComputations, getExtractAddressComputationsAttrName)
+ ADD_PATTERN(foldMemrefAliases, getFoldMemrefAliasesAttrName)
+ ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName)
+ ADD_PATTERN(foldTensorSubsets, getFoldTensorSubsetsAttrName)
+ ADD_PATTERN(foldVectorTransferTensorSlice,
+ getFoldVectorTransferTensorSliceAttrName)
+ ADD_PATTERN(licm, getLicmAttrName)
+ ADD_PATTERN(linalgElementwiseGreedyFusion,
+ getLinalgElementwiseGreedyFusionAttrName)
+ ADD_PATTERN(lowerTransferOpPermutations,
+ getLowerTransferOpPermutationsAttrName)
+ ADD_PATTERN(lowerVectorMasks, getLowerVectorMasksAttrName)
+ ADD_PATTERN(prepareVectorToMma, getPrepareVectorToMmaAttrName)
+ ADD_PATTERN(rankReducingLinalg, getRankReducingLinalgAttrName)
+ ADD_PATTERN(rankReducingLinalgViaReshapes,
+ getRankReducingLinalgViaReshapesAttrName)
+ ADD_PATTERN(rankReducingVector, getRankReducingVectorAttrName)
+ ADD_PATTERN(swapPaddingElideConditional,
+ getSwapPaddingElideConditionalAttrName)
+ ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName)
+ ADD_PATTERN(unrollVectorsGpuMmaSync, getUnrollVectorsGpuMmaSyncAttrName)
+ ADD_PATTERN(unrollVectorsGpuWmma, getUnrollVectorsGpuWmmaAttrName)
+#undef ADD_PATTERN
+}
+
+static void addOperands(Operation *op, SetVector<Value> &operandSet) {
+ if (!op) return;
+ TypeSwitch<Operation *, void>(op)
+ .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
+ SmallVector<Value> inputOperands{linalgOp.getDpsInputOperands()};
+ operandSet.insert(inputOperands.begin(), inputOperands.end());
+ })
+ .Default([&](Operation *operation) {
+ operandSet.insert(operation->operand_begin(), operation->operand_end());
+ });
+}
+
+template <int limit = 3>
+static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
+ Operation *producer = fusedOperand->get().getDefiningOp();
+ if (!producer) return false;
+ Operation *consumer = fusedOperand->getOwner();
+ SetVector<Value> fusedOpOperands;
+ if (producer->getNumResults() != 1) return false;
+ addOperands(consumer, fusedOpOperands);
+ fusedOpOperands.remove(producer->getResult(0));
+ addOperands(producer, fusedOpOperands);
+ return fusedOpOperands.size() <= limit;
+}
+
static void addLowerTransferOpPermutationsPatterns(
RewritePatternSet &patterns) {
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
@@ -309,10 +277,6 @@
memref::populateFoldMemRefAliasOpPatterns(patterns);
}
-static void addFoldTensorEmptyExtract(RewritePatternSet &patterns) {
- patterns.add<FoldTensorEmptyExtract>(patterns.getContext());
-}
-
static void addReassociativeReshapePatterns(RewritePatternSet &patterns) {
tensor::populateReassociativeReshapeFoldingPatterns(patterns);
}
@@ -361,14 +325,6 @@
});
}
-static void addTilingCanonicalizationPatterns(RewritePatternSet &patterns) {
- linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
- scf::populateSCFForLoopCanonicalizationPatterns(patterns);
- /// This seems generally desirable as a folding but may be too intrusive, so
- /// we only apply it selectively for now.
- patterns.add<FoldFillIntoPad>(patterns.getContext());
-}
-
static std::optional<SmallVector<int64_t>>
getGPUTensorCoreNativeMmaSyncVectorSize(Operation *op) {
return getMmaNativeVectorSize(op);
@@ -403,10 +359,6 @@
.setUnrollTraversalOrderFn(unrollOrder));
}
-static void addAdditionalIreePatterns(RewritePatternSet &patterns) {
- patterns.add<GenerateToConstant>(patterns.getContext());
-}
-
static void addAllRegisteredCanonicalizationPatterns(
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
@@ -427,7 +379,6 @@
}
MLIRContext *ctx = target->getContext();
RewritePatternSet patterns(ctx);
- if (getAdditionalIreePatterns()) addAdditionalIreePatterns(patterns);
if (getBubbleCollapse()) {
linalg::populateFoldReshapeOpsByCollapsingPatterns(
patterns, [](OpOperand *) { return true; });
@@ -448,7 +399,6 @@
addExtractAddressComputationsPatterns(patterns);
if (getFoldMemrefAliases()) addFoldMemrefAliasPatterns(patterns);
if (getFoldReassociativeReshapes()) addReassociativeReshapePatterns(patterns);
- if (getFoldTensorEmptyExtract()) addFoldTensorEmptyExtract(patterns);
if (getFoldTensorSubsets()) addFoldTensorSubsetsPatterns(patterns);
if (getFoldVectorTransferTensorSlice())
addFoldVectorTransferTensorExtractPatterns(patterns);
@@ -465,7 +415,6 @@
if (getRankReducingVector()) addRankReducingVectorPatterns(patterns);
if (getSwappingPatterns())
addSwappingPatterns(patterns, getSwapPaddingElideConditional());
- if (getTilingCanonicalization()) addTilingCanonicalizationPatterns(patterns);
if (getUnrollVectorsGpuMmaSync())
addUnrollVectorsGpuMmaSyncPatterns(patterns);
if (getUnrollVectorsGpuWmma()) addUnrollVectorsGpuWmmaPatterns(patterns);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
index c0b3887..98e3a91 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
@@ -34,7 +34,6 @@
namespace transform_dialect {
/// Selected patterns for ApplyPatternsOp.
struct ApplyPatternsOpPatterns {
- bool additionalIreePatterns = false;
bool bubbleCollapse = false;
bool bubbleExpand = false;
bool bubblePackUnPack = false;
@@ -45,7 +44,6 @@
bool extractAddressComputations = false;
bool foldMemrefAliases = false;
bool foldReassociativeReshapes = false;
- bool foldTensorEmptyExtract = false;
bool foldTensorSubsets = false;
bool foldVectorTransferTensorSlice = false;
bool licm = false;
@@ -58,7 +56,6 @@
bool rankReducingVector = false;
bool swapPaddingElideConditional = false;
bool swappingPatterns = false;
- bool tilingCanonicalization = false;
bool unrollVectorsGpuMmaSync = false;
bool unrollVectorsGpuWmma = false;
};
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index c415eff..33bd021 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -51,6 +51,20 @@
}];
}
+def ApplyFoldFillIntoPadPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.iree.fold_fill_into_pad",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Populates a pattern that folds
+ "tensor.pad(cst, tensor.extract*(linalg.fill(cst)))" into
+ "linalg.fill(cst, empty)" when the padding constant and the fill constant
+ are the same.
+ }];
+
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyPatternsOp : Op<Transform_Dialect, "iree.apply_patterns",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
@@ -70,8 +84,6 @@
The following additive attributes can be set; they add patterns in an
unspecified order:
- - additional_iree_patterns: fancy patterns we shortcut into the system,
- will need to be sliced out better in the future.
- bubble_collapse: bubble `collapse_shape` down across Linalg ops. This
must be applied separately from `bubble_expand` patterns because of some
upstream pattern interference issue atm.
@@ -97,8 +109,6 @@
memref.subview.
- fold_reassociative_reshapes: adds patterns that fold insert_slice/
extract_slice ops with reassociative reshape ops.
- - fold_tensor_empty_extract: Fold tensor.empty used by extract_slice in
- case it is the only use of extract.
- fold_tensor_subsets: adds patterns for folding tensor subset ops into
their producer and consumers.
- licm: additionally apply loop-invariant code motion and single
@@ -125,8 +135,6 @@
tensor.extract_slice swapping pattern. This injects static information
that guarantees padding is smaller than the window size, which in turn
guarantees we never see a tile composed entirely of padding.
- - tiling_canonicalization: adds specific tiling-related canonicalization
- patterns.
- unroll_vectors_gpu_mma_sync: adds patterns that unroll vectors to a native tile
size for GPUs with mma operations. The size is currently hardcoded but
should be refactored upstream and made pluggable.
@@ -151,7 +159,6 @@
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- UnitAttr:$additional_iree_patterns,
UnitAttr:$bubble_collapse,
UnitAttr:$bubble_expand,
UnitAttr:$bubble_pack_un_pack,
@@ -162,7 +169,6 @@
UnitAttr:$extract_address_computations,
UnitAttr:$fold_memref_aliases,
UnitAttr:$fold_reassociative_reshapes,
- UnitAttr:$fold_tensor_empty_extract,
UnitAttr:$fold_tensor_subsets,
UnitAttr:$fold_vector_transfer_tensor_slice,
UnitAttr:$licm,
@@ -175,7 +181,6 @@
UnitAttr:$rank_reducing_vector,
UnitAttr:$swap_padding_elide_conditional,
UnitAttr:$swapping_patterns,
- UnitAttr:$tiling_canonicalization,
UnitAttr:$unroll_vectors_gpu_mma_sync,
UnitAttr:$unroll_vectors_gpu_wmma);
let results = (outs);
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
index 5abedc7..0e777df 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
@@ -112,7 +112,11 @@
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.iree.apply_patterns %0 { tiling_canonicalization } : (!transform.any_op) -> ()
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
}
// -----
@@ -141,7 +145,11 @@
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.iree.apply_patterns %0 { tiling_canonicalization } : (!transform.any_op) -> ()
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
}
// -----
@@ -171,7 +179,11 @@
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.iree.apply_patterns %0 { tiling_canonicalization } : (!transform.any_op) -> ()
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
}
// -----
@@ -203,7 +215,11 @@
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.iree.apply_patterns %0 { tiling_canonicalization } : (!transform.any_op) -> ()
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
}
// -----
@@ -233,5 +249,9 @@
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.iree.apply_patterns %0 { tiling_canonicalization } : (!transform.any_op) -> ()
+ transform.apply_patterns to %0 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
}
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
index 1311721..6784b01 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
@@ -67,6 +67,7 @@
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFTransformOps",
"@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TensorTransformOps",
"@llvm-project//mlir:VectorTransformOps",
],
)
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index 3a40f12..10bf352 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -38,6 +38,7 @@
MLIRSCFDialect
MLIRSCFTransformOps
MLIRTensorDialect
+ MLIRTensorTransformOps
MLIRVectorTransformOps
iree::compiler::Codegen::Common::TransformExtensions::CommonExtensions
iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
index 1699604..d65f8ae 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
@@ -31,6 +31,7 @@
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
namespace mlir {
@@ -61,6 +62,7 @@
memref::registerValueBoundsOpInterfaceExternalModels(registry);
scf::registerTransformDialectExtension(registry);
scf::registerValueBoundsOpInterfaceExternalModels(registry);
+ tensor::registerTransformDialectExtension(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
vector::registerTransformDialectExtension(registry);
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index 1171461..bd6af02 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -58,8 +58,13 @@
// Bufferization
// ==========================================
+ transform.apply_patterns to %func_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %func_3
- { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } : (!transform.any_op) -> ()
+ { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
%variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index 739cefc..cf2ec79 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -85,7 +85,11 @@
// CHECK: transform.iree.pipeline_shared_memory_copies %{{.*}} {depth = 3 : i64}
// CHECK: transform.apply_patterns.vector.lower_masks
// CHECK: transform.apply_patterns.vector.materialize_masks
-// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, fold_memref_aliases, licm, tiling_canonicalization}
+// CHECK: apply_patterns to %{{.*}} {
+// CHECK: transform.apply_patterns.linalg.tiling_canonicalization
+// CHECK: transform.apply_patterns.memref.fold_memref_alias_ops
+// CHECK: } : !transform.any_op
+// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
// WITH_OPTIONS-LABEL: func @matmul
@@ -141,7 +145,11 @@
// WITH_OPTIONS: transform.iree.pipeline_shared_memory_copies %{{.*}} {depth = 5 : i64}
// WITH_OPTIONS: transform.apply_patterns.vector.lower_masks
// WITH_OPTIONS: transform.apply_patterns.vector.materialize_masks
-// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, fold_memref_aliases, licm, tiling_canonicalization}
+// WITH_OPTIONS: apply_patterns to %{{.*}} {
+// WITH_OPTIONS: transform.apply_patterns.linalg.tiling_canonicalization
+// WITH_OPTIONS: transform.apply_patterns.memref.fold_memref_alias_ops
+// WITH_OPTIONS: } : !transform.any_op
+// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
// -----
@@ -260,7 +268,7 @@
// CHECK-SAME: pack_paddings = [1, 1, 1]
// CHECK-SAME: padding_dimensions = [0, 1, 2]
// CHECK-SAME: padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
-// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm, tiling_canonicalization}
+// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
// CHECK: %[[RES_PAD:.+]] = get_producer_of_operand %{{.*}}[2]
// CHECK: %[[RES_COPY:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RES_PAD]]
// CHECK: %[[LHS_PAD:.+]] = get_producer_of_operand %{{.*}}[0]
@@ -273,7 +281,7 @@
// CHECK: transform.scf.take_assumed_branch %{{.*}} take_else_branch
// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
-// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm, tiling_canonicalization}
+// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
// alignLhs
// CHECK: transform.structured.masked_vectorize %[[TILED_LHS]] vector_sizes [4, 4]
@@ -322,7 +330,7 @@
// CHECK-SAME: padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
// Canonicalization is currently required here to enable pad to dps to produce linalg.copy ops.
-// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm, tiling_canonicalization}
+// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
// CHECK: %[[RES_PAD:.+]] = get_producer_of_operand %{{.*}}[2]
// CHECK: %[[RES_COPY:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RES_PAD]]
// CHECK: %[[LHS_PAD:.+]] = get_producer_of_operand %{{.*}}[0]
@@ -333,7 +341,7 @@
// CHECK: transform.structured.tile_to_forall_op %[[RHS_COPY]] num_threads [4, 32] tile_sizes [](mapping = [#gpu.linear<y>, #gpu.linear<x>])
// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
// CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<y>, #gpu.warp<x>])
-// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm, tiling_canonicalization}
+// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
// Verify we don't go down the path without the flag.
// WITH_OPTIONS-LABEL: func @aligned_matmul
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir
index 874be430..d739675 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir
@@ -48,12 +48,12 @@
// CHECK: transform.iree.register_match_callbacks
// CHECK: {{.*}} = transform.iree.match_callback failures(propagate) "pad"({{.*}}) : (!transform.any_op) -> !transform.any_op
// CHECK: transform.structured.tile_to_forall_op {{.*}} num_threads [] tile_sizes [64, 64](mapping = [#gpu.block<y>, #gpu.block<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-// CHECK: transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm, tiling_canonicalization} : (!transform.any_op) -> ()
+// CHECK: transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm} : (!transform.any_op) -> ()
// CHECK: {{.*}} = transform.structured.match ops{["scf.if"]} in {{.*}} : (!transform.any_op) -> !transform.any_op
// CHECK: transform.scf.take_assumed_branch {{.*}} take_else_branch : (!transform.any_op) -> ()
// CHECK: transform.iree.populate_workgroup_count_region_using_num_threads_slice {{.*}} : (!transform.any_op) -> ()
// CHECK: {{.*}} = transform.structured.tile_to_forall_op {{.*}} num_threads [16, 16] tile_sizes [](mapping = [#gpu.thread<y>, #gpu.thread<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-// CHECK: transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm, tiling_canonicalization} : (!transform.any_op) -> ()
+// CHECK: transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm} : (!transform.any_op) -> ()
// CHECK: {{.*}} = transform.structured.match ops{["scf.if"]} in {{.*}} : (!transform.any_op) -> !transform.any_op
// CHECK: transform.scf.take_assumed_branch {{.*}} take_else_branch : (!transform.any_op) -> ()
// CHECK: transform.structured.masked_vectorize {{.*}} vector_sizes [4, 4] : !transform.any_op
@@ -61,7 +61,7 @@
// CHECK: transform.apply_patterns.vector.lower_masked_transfers
// CHECK: transform.iree.apply_patterns {{.*}} {rank_reducing_linalg, rank_reducing_vector} : (!transform.any_op) -> ()
// CHECK: {{.*}} = transform.structured.vectorize {{.*}} : (!transform.any_op) -> !transform.any_op
-// CHECK: transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm, tiling_canonicalization} : (!transform.any_op) -> ()
+// CHECK: transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm} : (!transform.any_op) -> ()
// CHECK: transform.iree.eliminate_empty_tensors {{.*}} : (!transform.any_op) -> ()
// CHECK: {{.*}} = transform.iree.bufferize {target_gpu} {{.*}} : (!transform.any_op) -> !transform.any_op
// CHECK: {{.*}} = transform.structured.match ops{["func.func"]} in {{.*}} : (!transform.any_op) -> !transform.any_op
@@ -72,7 +72,11 @@
// CHECK: transform.iree.map_nested_forall_to_gpu_threads {{.*}} workgroup_dims = [16, 16, 1] warp_dims = [] : (!transform.any_op) -> ()
// CHECK: transform.apply_patterns.vector.lower_masks
// CHECK: transform.apply_patterns.vector.materialize_masks
-// CHECK: transform.iree.apply_patterns {{.*}} {canonicalization, cse, fold_memref_aliases, licm, tiling_canonicalization} : (!transform.any_op) -> ()
+// CHECK: apply_patterns to %{{.*}} {
+// CHECK: transform.apply_patterns.linalg.tiling_canonicalization
+// CHECK: transform.apply_patterns.memref.fold_memref_alias_ops
+// CHECK: } : !transform.any_op
+// CHECK: transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm} : (!transform.any_op) -> ()
// WITH_OPTIONS-LABEL: func @pad
// WITH_OPTIONS: transform.structured.tile_to_forall_op {{.*}} num_threads [] tile_sizes [32, 16](mapping = [#gpu.block<y>, #gpu.block<x>])
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
index 4c15d1d..a094335 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -14,8 +14,13 @@
// allocs will be created.
%func = transform.structured.match ops{["func.func"]} in %variant_op
: (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %func
- { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } : (!transform.any_op) -> ()
+ { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
%variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op)
%memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
@@ -24,6 +29,11 @@
workgroup_dims = [10, 11, 1] : (!transform.any_op) -> ()
// Late canonicalizations to clean up and pass the checks
+ transform.apply_patterns to %memref_func {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %memref_func
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
index e3850dd..c5dc594 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
@@ -10,6 +10,11 @@
: (!transform.any_op) -> ()
// Late canonicalizations to clean up and pass the checks.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
index 06d9572..268bbbd 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
@@ -37,8 +37,13 @@
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Late canonicalizations to clean up and pass the checks.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
}
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
index 6b2dcb1..06640f9 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/CPU/Common.cpp
@@ -105,18 +105,19 @@
Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
// Step N-5. Fold tensor.empty to avoid large allocations.
- ApplyPatternsOpPatterns configuration;
- configuration.foldTensorEmptyExtract = true;
-
// Step N-4. Perform a pass of canonicalization + enabling after tiling.
- funcH = mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, funcH);
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyFoldTensorEmptyPatternsOp>(loc);
+ });
funcH = iree_compiler::buildVectorize(b, funcH);
// Step N-3. Perform a pass of canonicalization + enabling after vectorization
// as well as hoisting subset operations such as vector.transfer_read/write.
- funcH = mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, funcH);
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyFoldTensorEmptyPatternsOp>(loc);
+ });
iree_compiler::buildHoisting(b, funcH);
// Step N-2. Bufferize and drop HAL descriptor from memref ops.
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/BUILD.bazel
index 6da45f5..72f32e8 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/BUILD.bazel
@@ -55,8 +55,11 @@
# Transforms (needed mostly for the BufferizableOpInterfaceImpl)
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MemRefTransformOps",
"@llvm-project//mlir:SCFTransforms",
+ "@llvm-project//mlir:SCFTransformOps",
"@llvm-project//mlir:TensorTransforms",
+ "@llvm-project//mlir:TensorTransformOps",
"@llvm-project//mlir:VectorTransforms",
"@llvm-project//mlir:VectorTransformOps",
# Other Stuff
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/CMakeLists.txt
index 2e5471c..560279c 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/CMakeLists.txt
@@ -39,16 +39,19 @@
MLIRLLVMDialect
MLIRLinalgDialect
MLIRLinalgTransforms
+ MLIRMemRefTransformOps
MLIRPDLDialect
MLIRPDLInterpDialect
MLIRParser
MLIRPass
MLIRRewrite
MLIRSCFDialect
+ MLIRSCFTransformOps
MLIRSCFTransforms
MLIRSCFUtils
MLIRSupport
MLIRTensorDialect
+ MLIRTensorTransformOps
MLIRTensorTransforms
MLIRTransformDialect
MLIRVectorDialect
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
index cc42018..204c560 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.cpp
@@ -12,6 +12,9 @@
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
@@ -123,18 +126,24 @@
/// Create an ApplyPatternsOp that performs a set of key canonicalizations and
/// so-called enabling transformations to normalize the IR.
-/// Take an existing configuration by copy (cheap object) that will be augmented
-/// locally to additionally perform:
-/// canonicalization, tiling_canonicalization, licm and cse (in this order).
-Value mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- ImplicitLocOpBuilder &b, ApplyPatternsOpPatterns configuration,
- Value variantH) {
+/// In addition to the patterns populated by the optional callback, perform:
+/// tiling-related canonicalization patterns, canonicalization, licm and cse
+/// (in this order).
+void mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ ImplicitLocOpBuilder &b, Value variantH,
+ ApplyPatternsOpBodyBuilderFn populatePatternsFn) {
+ b.create<transform::ApplyPatternsOp>(
+ variantH, [&](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyTilingCanonicalizationPatternsOp>(loc);
+ b.create<IREE::transform_dialect::ApplyFoldFillIntoPadPatternsOp>(loc);
+ b.create<transform::ApplyForLoopCanonicalizationPatternsOp>(loc);
+ if (populatePatternsFn) populatePatternsFn(b, loc);
+ });
+ ApplyPatternsOpPatterns configuration;
configuration.canonicalization = true;
configuration.cse = true;
configuration.licm = true;
- configuration.tilingCanonicalization = true;
b.create<ApplyPatternsOp>(variantH, configuration);
- return variantH;
}
/// Dynamically selects the first non-empty handle; i.e. if (h1, h2) is:
@@ -171,10 +180,8 @@
// matmuls.
// TODO: Make padding less brittle so that this toggle is unnecessary.
if (canonicalize) {
- ApplyPatternsOpPatterns configuration;
- isolatedParentOpH =
- mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, isolatedParentOpH);
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, isolatedParentOpH);
}
return result;
}
@@ -214,10 +221,8 @@
result.tiledOpH = tileToForeachOp.getTiledOp();
// Perform a pass of canonicalization + enabling after tiling.
- ApplyPatternsOpPatterns configuration;
- isolatedParentOpH =
- mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, isolatedParentOpH);
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, isolatedParentOpH);
// Batch fusion if requested.
if (opsHToFuse.size() > 1) {
@@ -280,9 +285,7 @@
bool applyCleanups) {
funcH = b.create<VectorizeOp>(funcH);
if (applyCleanups) {
- ApplyPatternsOpPatterns configuration;
- funcH = iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, funcH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
}
return funcH;
}
@@ -313,10 +316,10 @@
b.create<transform::ApplyMaterializeMasksPatternsOp>(loc);
});
{
- ApplyPatternsOpPatterns config;
- config.foldMemrefAliases = true;
- iree_compiler::buildCanonicalizationAndEnablingTransforms(b, config,
- containingOpH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, containingOpH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyFoldMemrefAliasOpsPatternsOp>(loc);
+ });
}
return containingOpH;
}
@@ -331,11 +334,11 @@
Value variantH, bool targetGpu) {
// Perform a pass of canonicalization + enabling before bufferization to avoid
// spurious allocations.
- ApplyPatternsOpPatterns configuration;
- configuration.foldReassociativeReshapes = true;
- configuration.foldVectorTransferTensorSlice = true;
- variantH =
- buildCanonicalizationAndEnablingTransforms(b, configuration, variantH);
+ buildCanonicalizationAndEnablingTransforms(
+ b, variantH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyReassociativeReshapeFoldingPatternsOp>(loc);
+ b.create<transform::ApplyFoldTensorSliceIntoTransferPatternsOp>(loc);
+ });
b.create<IREEEliminateEmptyTensorsOp>(variantH);
variantH = b.create<IREEBufferizeOp>(variantH, targetGpu);
Value memrefFunc =
@@ -469,9 +472,7 @@
.getFusedOp();
// Perform a pass of canonicalization + enabling after fusion.
- ApplyPatternsOpPatterns configuration;
- variantH = mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, variantH);
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(b, variantH);
// Step 3. Normalize to reorder results irrespective of emptiness.
auto [blockReductionH, maybeBlockTrailingH] = buildSelectFirstNonEmpty(
@@ -483,11 +484,15 @@
Value mlir::iree_compiler::buildMemoryOptimizations(ImplicitLocOpBuilder &b,
Value funcH) {
ApplyPatternsOpPatterns configuration;
- configuration.lowerTransferOpPermutations = true;
configuration.rankReducingVector = true;
// Apply canonicalizations and enablings twice as they enable each other.
- buildCanonicalizationAndEnablingTransforms(b, configuration, funcH);
- buildCanonicalizationAndEnablingTransforms(b, configuration, funcH);
+ for (int i = 0; i < 2; ++i) {
+ buildCanonicalizationAndEnablingTransforms(
+ b, funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyTransferPermutationPatternsOp>(loc);
+ b.create<transform::ApplyCastAwayVectorLeadingOneDimPatternsOp>(loc);
+ });
+ }
b.create<ApplyBufferOptimizationsOp>(funcH);
return funcH;
}
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h
index 9e7ea6c..67d855c 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/Common/Common.h
@@ -78,15 +78,15 @@
/// `handles` is empty.
void buildPrint(ImplicitLocOpBuilder &b, ValueRange handles = {});
+using ApplyPatternsOpBodyBuilderFn = std::function<void(OpBuilder &, Location)>;
+
/// Create an ApplyPatternsOp that performs a set of key canonicalizations and
/// so-called enabling transformations to normalize the IR.
-/// Take an existing configuration by copy (cheap object) that will be augmented
-/// locally to additionally perform:
+/// In addition to the patterns populated by the optional callback, perform:
/// canonicalization, tiling_canonicalization, licm and cse (in this order).
-Value buildCanonicalizationAndEnablingTransforms(
- ImplicitLocOpBuilder &b,
- IREE::transform_dialect::ApplyPatternsOpPatterns configuration,
- Value variantH);
+void buildCanonicalizationAndEnablingTransforms(
+ ImplicitLocOpBuilder &b, Value variantH,
+ ApplyPatternsOpBodyBuilderFn populatePatternsFn = nullptr);
/// Build transform IR to dynamically select the first non-empty handle; i.e.
/// if (h1, h2) is:
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
index c59e7b3..47d7130 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/Common.cpp
@@ -269,18 +269,19 @@
Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
// Step N-5. Fold tensor.empty to avoid large allocations.
- ApplyPatternsOpPatterns configuration;
- configuration.foldTensorEmptyExtract = true;
-
// Step N-4. Perform a pass of canonicalization + enabling after tiling.
- funcH = mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, funcH);
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyFoldTensorEmptyPatternsOp>(loc);
+ });
funcH = iree_compiler::buildVectorize(b, funcH);
// Step N-3. Perform a pass of canonicalization + enabling after vectorization
// as well as hoisting subset operations such as vector.transfer_read/write.
- funcH = mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, funcH);
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyFoldTensorEmptyPatternsOp>(loc);
+ });
iree_compiler::buildHoisting(b, funcH);
// Step N-2. Bufferize and drop HAL descriptor from memref ops.
@@ -294,8 +295,10 @@
// Step N. Perform a final pass of canonicalization + enabling before
// returning.
- variantH = mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, variantH);
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, variantH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyFoldTensorEmptyPatternsOp>(loc);
+ });
return std::make_pair(variantH, funcH);
}
@@ -328,11 +331,13 @@
// Perform a pass of canonicalization cleanups + folding fill + pad into pad
// by applying `foldTensorSubsets` and `tilingCanonicalization`.
{
- ApplyPatternsOpPatterns config;
- config.foldTensorSubsets = true;
- config.tilingCanonicalization = true;
- iree_compiler::buildCanonicalizationAndEnablingTransforms(b, config,
- variantH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, variantH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyFoldTensorSubsetOpsPatternsOp>(loc);
+ b.create<
+ transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp>(
+ loc);
+ });
}
// The canonicalization above should have rewritten hoistPad into a FillOp.
@@ -475,9 +480,7 @@
// Also, no canonicalization is allowed after vector masking and before we
// lower the masks: masks are currently quite brittle and do not like
// canonicalization or anything else that may insert an op in their region.
- ApplyPatternsOpPatterns configuration;
- variantH = iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, configuration, variantH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, variantH);
// Apply vector masking.
if (!strategy.alignedLhs()) {
@@ -525,8 +528,7 @@
ImplicitLocOpBuilder &b, Value funcH,
const AbstractGemmLikeStrategy &strategy) {
// TODO: Fewer canonicalizations.
- iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, ApplyPatternsOpPatterns(), funcH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
b.create<iree_compiler::IREE::transform_dialect::HoistStaticAllocOp>(funcH);
{
ApplyPatternsOpPatterns config;
@@ -538,8 +540,7 @@
config.extractAddressComputations = true;
b.create<ApplyPatternsOp>(funcH, config);
}
- iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, ApplyPatternsOpPatterns(), funcH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
{
ApplyPatternsOpPatterns config;
if (strategy.useMmaSync)
@@ -551,8 +552,7 @@
// TODO: this is not a functional-style transform; avoid returning funcH.
funcH = b.create<transform::HoistRedundantVectorTransfersOp>(
transform::AnyOpType::get(b.getContext()), funcH);
- iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, ApplyPatternsOpPatterns(), funcH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
b.create<ApplyBufferOptimizationsOp>(funcH);
auto vectorToMMaConversionOp =
b.create<iree_compiler::IREE::transform_dialect::VectorToMMAConversionOp>(
@@ -568,8 +568,7 @@
void mlir::iree_compiler::gpu::buildMultiBuffering(
ImplicitLocOpBuilder &b, Value funcH,
const AbstractGemmLikeStrategy &strategy) {
- iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, ApplyPatternsOpPatterns(), funcH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
ApplyPatternsOpPatterns config;
config.foldMemrefAliases = true;
b.create<ApplyPatternsOp>(funcH, config);
@@ -600,16 +599,16 @@
transferToScfOp.setMaxTransferRank(1);
transferToScfOp.setFullUnroll(true);
});
- iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, ApplyPatternsOpPatterns(), funcH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
auto createAsyncGroupOp =
b.create<iree_compiler::IREE::transform_dialect::CreateAsyncGroupsOp>(
TypeRange{}, funcH);
// TODO: proper builder instead of setting post-hoc.
createAsyncGroupOp.setUseMmaSync(strategy.useMmaSync);
- ApplyPatternsOpPatterns config;
- config.foldMemrefAliases = true;
- iree_compiler::buildCanonicalizationAndEnablingTransforms(b, config, funcH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(
+ b, funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyFoldMemrefAliasOpsPatternsOp>(loc);
+ });
return funcH;
}
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.cpp
index fade172..44acdfe 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectStrategies/GPU/MatmulTensorCoreStrategy.cpp
@@ -245,9 +245,7 @@
// Running canonicalization is required here to enable aligned pads to become
// linalg.copy ops when rewriting in DPS.
- ApplyPatternsOpPatterns config;
- iree_compiler::buildCanonicalizationAndEnablingTransforms(b, config,
- variantH);
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, variantH);
// Step 4. Distribute pad and copies: SIMT programming model.
auto [lhsCopyOpH, rhsCopyOpH, copyBackOpH] =
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index 5d5b194..b5cc299 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -101,6 +101,7 @@
mlir::scf::registerTransformDialectExtension(registry);
mlir::tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
registry);
+ mlir::tensor::registerTransformDialectExtension(registry);
mlir::vector::registerTransformDialectExtension(registry);
// Dialect extensions.
diff --git a/tests/transform_dialect/cpu/attention_codegen_spec.mlir b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
index 37511c2..a91ef15 100644
--- a/tests/transform_dialect/cpu/attention_codegen_spec.mlir
+++ b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
@@ -26,8 +26,13 @@
%func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.apply_patterns %func { rank_reducing_linalg, rank_reducing_vector } : (!transform.any_op) -> ()
%func_3 = transform.structured.vectorize %func : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
// Bufferization
// ==========================================
diff --git a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
index 3641acb..61fe536 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
@@ -14,8 +14,13 @@
// Canonicalization/CSE is needed before bufferization; otherwise unnecessary
// allocs will be created.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, cse } : (!transform.any_op) -> ()
+ { canonicalization, cse } : (!transform.any_op) -> ()
%variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op)
%memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
: (!transform.any_op) -> !transform.any_op
diff --git a/tests/transform_dialect/cuda/double_mma_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/double_mma_layout_analysis_codegen_spec.mlir
index 9c0af91..a2d9a75 100644
--- a/tests/transform_dialect/cuda/double_mma_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/double_mma_layout_analysis_codegen_spec.mlir
@@ -28,8 +28,13 @@
// Step 4. Bufferize
// ===========================================================================
+ transform.apply_patterns to %func_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %func_3
- { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } : (!transform.any_op) -> ()
+ { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
%variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
diff --git a/tests/transform_dialect/cuda/mma_elemwise_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/mma_elemwise_layout_analysis_codegen_spec.mlir
index 6eaa73f..04777a4 100644
--- a/tests/transform_dialect/cuda/mma_elemwise_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_elemwise_layout_analysis_codegen_spec.mlir
@@ -26,8 +26,13 @@
// Step 4. Bufferize
// ===========================================================================
+ transform.apply_patterns to %func_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %func_3
- { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } : (!transform.any_op) -> ()
+ { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
%variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
diff --git a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
index 0a4e35c..957cc8d 100644
--- a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
@@ -27,8 +27,13 @@
// Step 4. Bufferize
// ===========================================================================
+ transform.apply_patterns to %func_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %func_3
- { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } : (!transform.any_op) -> ()
+ { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
%variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
diff --git a/tests/transform_dialect/cuda/mma_using_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/mma_using_layout_analysis_codegen_spec.mlir
index dc91a60..25914a8 100644
--- a/tests/transform_dialect/cuda/mma_using_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_using_layout_analysis_codegen_spec.mlir
@@ -31,8 +31,13 @@
// Step 4. Bufferize
// ===========================================================================
+ transform.apply_patterns to %func_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %func_3
- { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } : (!transform.any_op) -> ()
+ { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
%variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index 69b054b..e54b8d8 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -33,8 +33,13 @@
transform.structured.fuse_into_containing_op %fill_1d into %forall_block_combiner_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
%fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op
: (!transform.any_op) -> !transform.any_op
@@ -86,6 +91,11 @@
// Late Canonicalizations.
+ transform.apply_patterns to %variant_op_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op_3
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
index 5ec779b..ff4a48b 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -17,8 +17,13 @@
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
// Step 2. First level of tiling + fusion parallelizes to blocks. Tile the
// trailing elementwise the same way we want to tile the reduction.
@@ -32,8 +37,13 @@
transform.structured.fuse_into_containing_op %not_eltwise into %grid_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
// Step 3. Second level of tiling + fusion parallelizes to threads.
// ===========================================================================
@@ -50,8 +60,13 @@
transform.structured.fuse_into_containing_op %combined_and_fill into %eltwise_block_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
%fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op
: (!transform.any_op) -> !transform.any_op
@@ -65,8 +80,13 @@
transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
// Step 4. Rank-reduce and vectorize.
// ===========================================================================
@@ -106,6 +126,11 @@
// Late canonicalizations.
+ transform.apply_patterns to %variant_op_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op_3
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
diff --git a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
index f4cb8e3..e8fa225 100644
--- a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
@@ -48,8 +48,13 @@
// ===========================================================================
// Canonicalization/CSE is needed before bufferization; otherwise unnecessary
// allocs will be created.
+ transform.apply_patterns to %func_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %func_3
- { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } : (!transform.any_op) -> ()
+ { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
%func_5 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.apply_patterns %func_5 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
@@ -74,6 +79,11 @@
: (!transform.any_op) -> ()
// Late canonicalizations to clean up and pass the checks
+ transform.apply_patterns to %func_7 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %func_7
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
index a4b0ebe..1740b10 100644
--- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
@@ -18,8 +18,13 @@
transform.structured.fuse_into_containing_op %fill into %forall_grid : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
// Step 2. Split the reduction to get meatier parallelism.
// This also parallelizes to threads.
@@ -41,8 +46,13 @@
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
// Step 3. Rank-reduce and vectorize.
// ===========================================================================
@@ -56,8 +66,13 @@
// Canonicalization is necessary to get rid of some tensor.cast ops that block
// hoisting.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
transform.structured.hoist_redundant_tensor_subsets %func_3
: (!transform.any_op) -> ()
@@ -65,8 +80,13 @@
// Step 4. Bufferize and drop HAL descriptor from memref ops.
// ===========================================================================
// Canonicalizations required before bufferization to avoid unnecessary allocs.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } : (!transform.any_op) -> ()
transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
%func_6 = transform.structured.match ops{["func.func"]} in %variant_op
@@ -96,6 +116,11 @@
: (!transform.any_op) -> ()
// Late canonicalizations.
+ transform.apply_patterns to %variant_op_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op_3
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
index 0edf60d..93115ab 100644
--- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -35,8 +35,13 @@
: (!transform.op<"scf.forall">) -> !transform.op<"scf.forall">
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
// Step 2. Second level of tiling + fusion parallelizes to threads.
@@ -70,8 +75,13 @@
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
// Step 3. Rank-reduce and vectorize.
// ==================================
@@ -104,6 +114,11 @@
// Late canonicalizations.
+ transform.apply_patterns to %variant_op_3 {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op_3
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
index 93b4370..2fff8d8 100644
--- a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
@@ -11,6 +11,11 @@
// Late canonicalizations to clean up and pass the checks.
// Needs to occur on the whole variant to perform cse on the workgroup_count region
+ transform.apply_patterns to %variant_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ } : !transform.any_op
transform.iree.apply_patterns %variant_op
- { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+ { canonicalization, licm, cse } : (!transform.any_op) -> ()
}
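As the CUDA specs above show, both the old and the new op can be scoped to a handle narrower than the whole variant, typically a func.func handle obtained via transform.structured.match. A minimal sketch of that scoping (handle names follow the specs above):

  %func = transform.structured.match ops{["func.func"]} in %variant_op
      : (!transform.any_op) -> !transform.any_op
  // Patterns are applied only within the matched function.
  transform.apply_patterns to %func {
    transform.apply_patterns.linalg.tiling_canonicalization
  } : !transform.any_op

The vecadd2d spec deliberately keeps the whole-variant handle instead, since its late canonicalizations must also cover the workgroup_count region.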