Evolve transform dialect usage towards non-blanket-canonicalized sequences (#12465) This revision brings our usage of the transform dialect more in line with recent upstream changes related to rewriters and listeners. It is also related to #12444, for which it aims to surface potential errors more proactively. This rewrite surfaces 2 places upstream that do not take a RewriterBase and that we should fix to avoid footguns: 1. `moveLoopInvariantCode` 2. `promoteIfSingleIteration`
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD index fa01046..c0d82a6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
@@ -76,6 +76,7 @@ "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", @@ -91,6 +92,8 @@ "@llvm-project//mlir:PDLDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:TransformDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt index 6d6d6de..c217c18 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -36,6 +36,7 @@ IREELinalgTransformDialect LLVMSupport MLIRAffineDialect + MLIRAffineUtils MLIRAnalysis MLIRArithDialect MLIRArithUtils @@ -51,6 +52,8 @@ MLIRPDLDialect MLIRPass MLIRSCFDialect + MLIRSCFTransforms + MLIRSCFUtils MLIRTensorDialect MLIRTensorTransforms MLIRTransformDialect
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 00e8fa2..eefb051 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -9,6 +9,7 @@ #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h" #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" +#include "iree-dialects/Transforms/ListenerCSE.h" #include "iree-dialects/Transforms/TransformMatchers.h" #include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h" @@ -19,7 +20,9 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -30,6 +33,8 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" @@ -37,6 +42,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" using namespace mlir; using namespace mlir::iree_compiler; @@ -135,6 +141,7 @@ ADD_PATTERN(additionalIreePatterns, getAdditionalIreePatternsAttrName) ADD_PATTERN(bubbleCollapseExpand, getBubbleCollapseExpandAttrName) ADD_PATTERN(canonicalization, getCanonicalizationAttrName) + ADD_PATTERN(cse, getCseAttrName) ADD_PATTERN(eraseUnnecessaryTensorOperands, getEraseUnnecessaryTensorOperandsAttrName) ADD_PATTERN(expandMemrefStridedMetadata, @@ -142,6 +149,7 @@ ADD_PATTERN(foldMemrefAliases, 
getFoldMemrefAliasesAttrName) ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName) ADD_PATTERN(foldTensorEmptyExtract, getFoldTensorEmptyExtractAttrName) + ADD_PATTERN(licm, getLicmAttrName) ADD_PATTERN(lowerTransferOpPermutations, getLowerTransferOpPermutationsAttrName) ADD_PATTERN(rankReducingLinalg, getRankReducingLinalgAttrName) @@ -149,6 +157,7 @@ ADD_PATTERN(swapPaddingElideConditional, getSwapPaddingElideConditionalAttrName) ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName) + ADD_PATTERN(tilingCanonicalization, getTilingCanonicalizationAttrName) ADD_PATTERN(unrollVectorsGpuMmaSync, getUnrollVectorsGpuMmaSyncAttrName) ADD_PATTERN(unrollVectorsGpuWmma, getUnrollVectorsGpuWmmaAttrName) #undef ADD_PATTERN @@ -236,6 +245,11 @@ }); } +static void addTilingCanonicalizationPatterns(RewritePatternSet &patterns) { + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); +} + static Optional<SmallVector<int64_t>> getGPUTensorCoreNativeMmaSyncVectorSize( Operation *op) { return getMmaNativeVectorSize(op); @@ -294,9 +308,12 @@ } MLIRContext *ctx = target->getContext(); RewritePatternSet patterns(ctx); + if (getAdditionalIreePatterns()) addAdditionalIreePatterns(patterns); + if (getBubbleCollapseExpand()) { + linalg::populateFoldReshapeOpsByExpansionPatterns( + patterns, [](OpOperand *) { return true; }); + } if (getCanonicalization()) addAllRegisteredCanonicalizationPatterns(patterns); - if (getLowerTransferOpPermutations()) - addLowerTransferOpPermutationsPatterns(patterns); if (getEraseUnnecessaryTensorOperands()) addEraseUnnecessaryTensorOperandsPatterns(patterns); if (getExpandMemrefStridedMetadata()) @@ -304,15 +321,13 @@ if (getFoldMemrefAliases()) addFoldMemrefAliasPatterns(patterns); if (getFoldReassociativeReshapes()) addReassociativeReshapePatterns(patterns); if (getFoldTensorEmptyExtract()) addFoldTensorEmptyExtract(patterns); + if 
(getLowerTransferOpPermutations()) + addLowerTransferOpPermutationsPatterns(patterns); if (getRankReducingLinalg()) addRankReducingLinalgPatterns(patterns); if (getRankReducingVector()) addRankReducingVectorPatterns(patterns); if (getSwappingPatterns()) addSwappingPatterns(patterns, getSwapPaddingElideConditional()); - if (getAdditionalIreePatterns()) addAdditionalIreePatterns(patterns); - if (getBubbleCollapseExpand()) { - linalg::populateFoldReshapeOpsByExpansionPatterns( - patterns, [](OpOperand *) { return true; }); - } + if (getTilingCanonicalization()) addTilingCanonicalizationPatterns(patterns); if (getUnrollVectorsGpuMmaSync()) addUnrollVectorsGpuMmaSyncPatterns(patterns); if (getUnrollVectorsGpuWmma()) addUnrollVectorsGpuWmmaPatterns(patterns); @@ -328,13 +343,57 @@ }); LogicalResult result = applyOpPatternsAndFold(ops, std::move(patterns), config); - LogicalResult listenerResult = listener.checkErrorState(); if (failed(result)) { - return mlir::emitDefiniteFailure(target, - "greedy pattern application failed"); + return mlir::emitDefiniteFailure(target, "greedy patterns failed"); } + LogicalResult listenerResult = listener.checkErrorState(); if (failed(listenerResult)) - return mlir::emitDefiniteFailure(target, "listener tracking failed"); + return mlir::emitDefiniteFailure(target, "pattern listener tracker fail"); + + if (getLicm()) { + target->walk([&](func::FuncOp funcOp) { + // This assumes LICM never removes operations so we don't need tracking. + // TODO: confirm / revisit this assumption and plumb a rewriter through + // upstream moveLoopInvariantCode if necessary. + funcOp->walk([](LoopLikeOpInterface loopLike) { + moveLoopInvariantCode(loopLike); + }); + // For now, put single loop promotion as part of licm. Underlying + // implementations perform splice operations which shouldn't need + // tracking. + // TODO: confirm / revisit this assumption and plumb a rewriter through + // upstream moveLoopInvariantCode if necessary. 
+ funcOp->walk([](Operation *op) { + (void)llvm::TypeSwitch<Operation *, LogicalResult>(op) + .Case<AffineForOp, scf::ForOp>( + [](auto loop) { return promoteIfSingleIteration(loop); }) + .Default([](Operation *) { return success(); }); + }); + }); + } + + if (getCse()) { + func::FuncOp lastFuncVisited; + auto walkResult = target->walk([&](func::FuncOp funcOp) -> WalkResult { + lastFuncVisited = funcOp; + result = + eliminateCommonSubexpressions(funcOp, /*domInfo=*/nullptr, &listener); + if (failed(result)) return WalkResult::interrupt(); + listenerResult = listener.checkErrorState(); + if (failed(listenerResult)) return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) { + if (failed(result)) { + return mlir::emitDefiniteFailure(lastFuncVisited, + "greedy patterns failed"); + } + LogicalResult listenerResult = listener.checkErrorState(); + if (failed(listenerResult)) + return mlir::emitDefiniteFailure(lastFuncVisited, + "pattern listener tracker fail"); + } + } results.push_back(target); return DiagnosedSilenceableFailure::success();
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h index c0cefc1..c1064f6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
@@ -38,17 +38,20 @@ bool additionalIreePatterns = false; bool bubbleCollapseExpand = false; bool canonicalization = false; + bool cse = false; bool eraseUnnecessaryTensorOperands = false; bool expandMemrefStridedMetadata = false; bool foldMemrefAliases = false; bool foldReassociativeReshapes = false; bool foldTensorEmptyExtract = false; + bool licm = false; bool lowerTransferOpPermutations = false; bool promoteForallCaptureToShared = false; bool rankReducingLinalg = false; bool rankReducingVector = false; bool swapPaddingElideConditional = false; bool swappingPatterns = false; + bool tilingCanonicalization = false; bool unrollVectorsGpuMmaSync = false; bool unrollVectorsGpuWmma = false; };
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td index 4317058..a55c29d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -78,6 +78,10 @@ down across Linalg ops. - canonicalization: adds all the canonicalization patterns of all registered dialects and ops. + - cse: additionally apply common subexpression elimination. This must + be applied to a funcOp. This is not a set of patterns per se but is still very + convenient to apply it close to canonicalization and other greedy pattern + applications. - erase_unnecessary_tensor_operands: add patterns that erase unnecessary tensor operands. - expand_memref_strided_metadata: adds patterns that expand memref @@ -89,6 +93,10 @@ extract_slice ops with reassociative reshape ops. - fold_tensor_empty_extract: Fold tensor.empty used by extract_slice in case it is the only use of extract. + - licm: additionally apply loop-invariant code motion and single + iteration loop promotion. This is not a set of patterns per se but is still + very convenient to apply it close to canonicalization and other greedy + pattern applications. - lower_transfer_op_permutations: Lower transfer ops to transfer ops with minor identity permutations. - rank_reducing_linalg: adds patterns that results in rank-reducing @@ -102,6 +110,8 @@ tensor.extract_slice swapping pattern. This injects static information that guarantees padding is smaller than the window size which guarantees we never see a tile comprised of padding-only. + - tiling_canonicalization: adds specific tiling-related canonicalization + patterns. - unroll_vectors_gpu_mma_sync: adds patterns that unroll vectors to a native tile size for GPUs with mma operations. The size is currently hardcoded but should be refactored upstream and made pluggable. 
@@ -127,16 +137,19 @@ UnitAttr:$additional_iree_patterns, UnitAttr:$bubble_collapse_expand, UnitAttr:$canonicalization, + UnitAttr:$cse, UnitAttr:$erase_unnecessary_tensor_operands, UnitAttr:$expand_memref_strided_metadata, UnitAttr:$fold_memref_aliases, UnitAttr:$fold_reassociative_reshapes, UnitAttr:$fold_tensor_empty_extract, + UnitAttr:$licm, UnitAttr:$lower_transfer_op_permutations, UnitAttr:$rank_reducing_linalg, UnitAttr:$rank_reducing_vector, UnitAttr:$swap_padding_elide_conditional, UnitAttr:$swapping_patterns, + UnitAttr:$tiling_canonicalization, UnitAttr:$unroll_vectors_gpu_mma_sync, UnitAttr:$unroll_vectors_gpu_wmma); let results = (outs PDL_Operation:$result);
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir index 360ba82..8811b76 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir
@@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): transform.iree.register_match_callbacks
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reductions_match_spec.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reductions_match_spec.mlir index 7ba482c..22311d6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/reductions_match_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/reductions_match_spec.mlir
@@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): transform.iree.register_match_callbacks
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/repeated_matcher_use.mlir b/compiler/src/iree/compiler/Codegen/Common/test/repeated_matcher_use.mlir index d4dd18d..4c63f9a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/repeated_matcher_use.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/repeated_matcher_use.mlir
@@ -1,7 +1,7 @@ // RUN: iree-opt %s --iree-transform-dialect-interpreter --verify-diagnostics --split-input-file module { - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): transform.iree.register_match_callbacks @@ -54,7 +54,7 @@ // expected-error @below {{transform dialect interpreter failed}} module { - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): transform.iree.register_match_callbacks @@ -109,7 +109,7 @@ // expected-error @below {{transform dialect interpreter failed}} module { - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): transform.iree.register_match_callbacks @@ -161,7 +161,7 @@ // ----- module { - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): transform.iree.register_match_callbacks
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir index bf02a3a..9268b17 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir
@@ -33,7 +33,7 @@ } } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op %variant_op_3 = transform.iree.bufferize %variant_op_2
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir index 94e984e..d9cce5d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_iree_tile_to_forall.mlir
@@ -47,7 +47,7 @@ } } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %original_matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!pdl.operation) -> !pdl.operation @@ -56,4 +56,9 @@ transform.iree.tile_to_forall_and_workgroup_count_region %original_matmul num_threads [32] ( mapping = [#gpu.block<x>] ) + + // Late canonicalizations to cleanup and pass the checks. + // Needs to occur on the whole variant to perform cse on the workgroup_count region + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 1156f48..c2d1c45 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -404,6 +404,7 @@ void transform_dialect::VectorWarpDistributionOp::build(OpBuilder &builder, OperationState &result, Value target) { + result.addTypes(pdl::OperationType::get(builder.getContext())); result.addOperands(target); } @@ -671,15 +672,10 @@ target, "warp execute on lane 0 to scf patterns failed to apply"); } + results.push_back(target); return DiagnosedSilenceableFailure::success(); } -void transform_dialect::VectorWarpDistributionOp::getEffects( - SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { - transform::onlyReadsHandle(getTarget(), effects); - transform::modifiesPayload(effects); -} - void transform_dialect::VectorToMMAConversionOp::build(OpBuilder &builder, OperationState &result, Value target) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td index 33584a9..81f41c2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -213,9 +213,10 @@ } def VectorWarpDistributionOp : Op<Transform_Dialect, "iree.vector.warp_distribute", - [TransformEachOpTrait, - TransformOpInterface, - DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { + [FunctionalStyleTransformOpTrait, + MemoryEffectsOpInterface, + TransformEachOpTrait, + TransformOpInterface]> { let description = [{ Given a vector.warp_execute_on_lane_0, apply the patterns to rewrite into distributed form with warp synchronization. This produces IR that runs @@ -302,10 +303,10 @@ ``` }]; - let arguments = (ins PDL_Operation:$target); - let results = (outs); + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$result); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = "$target attr-dict `:` functional-type($target, results)"; let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect"; let skipDefaultBuilders = 1;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/create_async_groups.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/create_async_groups.mlir index a4f4cde..e96b768 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/create_async_groups.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/create_async_groups.mlir
@@ -22,7 +22,7 @@ return } - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation %transformed_func = transform.iree.create_async_groups %top_level_func {use_mma_sync = true} : (!pdl.operation) -> (!pdl.operation) @@ -54,7 +54,7 @@ return } - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation %transformed_func = transform.iree.create_async_groups %top_level_func {use_mma_sync = false} : (!pdl.operation) -> (!pdl.operation) @@ -82,7 +82,7 @@ return } - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation %vector_transfer = transform.structured.match ops{["memref.alloc"]} in %top_level_func : (!pdl.operation) -> !pdl.operation
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir index dea32e2..02ec5bb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir
@@ -27,7 +27,7 @@ } } - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir index e4d0557..87d7bb3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir
@@ -1,4 +1,4 @@ -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op %variant_op_3 = transform.iree.bufferize %variant_op_2
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir index 8ee8556..4314ef1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -1,4 +1,4 @@ -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %0 = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation %forall, %tiled_fill = transform.structured.tile_to_forall_op %0 num_threads [5, 1] @@ -8,13 +8,23 @@ %forall_2, %tiled_matmul = transform.structured.tile_to_forall_op %1 num_threads [7, 9] ( mapping = [#gpu.thread<x>, #gpu.thread<y>] ) + // Canonicalization/CSE is needed before bufferization otherwise unnecessary + // allocs will be created. + %func = transform.structured.match ops{["func.func"]} in %variant_op + : (!pdl.operation) -> !pdl.operation + transform.iree.apply_patterns %func + { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op %variant_op_3 = transform.iree.bufferize %variant_op_2 %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation transform.iree.erase_hal_descriptor_type_from_memref %memref_func // Get the function to which to apply to. - %2 = transform.structured.match ops{["linalg.matmul"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation - %func = transform.get_closest_isolated_parent %2 : (!pdl.operation) -> !pdl.operation - transform.iree.map_nested_forall_to_gpu_threads %func { workgroup_size = [10, 11]} + %func_2 = transform.structured.match ops{["linalg.matmul"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %func_3 = transform.get_closest_isolated_parent %func_2 : (!pdl.operation) -> !pdl.operation + %func_4 = transform.iree.map_nested_forall_to_gpu_threads %func_3 { workgroup_size = [10, 11]} + + // Late canonicalizations to cleanup and pass the checks + %func_5 = transform.iree.apply_patterns %func_4 + { canonicalization, tiling_canonicalization, licm, cse } }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir index d52461a..7f438da 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
@@ -1,7 +1,14 @@ -transform.structured.canonicalized_sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %if_op = transform.structured.match ops{["scf.if"]} in %arg1 : (!pdl.operation) -> !pdl.operation +transform.sequence failures(propagate) { +^bb1(%variant_op: !pdl.operation): + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op + : (!pdl.operation) -> !pdl.operation %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } - %isolated = transform.get_closest_isolated_parent %warp : (!pdl.operation) -> !pdl.operation + %isolated = transform.get_closest_isolated_parent %warp + : (!pdl.operation) -> !pdl.operation transform.iree.vector.warp_distribute %isolated + : (!pdl.operation) -> !pdl.operation + + // Late canonicalizations to cleanup and pass the checks. + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir index 26cb9bf..ac01b43 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
@@ -1,5 +1,10 @@ -transform.structured.canonicalized_sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %if_op = transform.structured.match ops{["scf.if"]} in %arg1 : (!pdl.operation) -> !pdl.operation +transform.sequence failures(propagate) { +^bb1(%variant_op: !pdl.operation): + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op + : (!pdl.operation) -> !pdl.operation transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } + + // Late canonicalizations to cleanup and pass the checks. + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir index 0e18f1d..2657759 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
@@ -30,13 +30,17 @@ } } - transform.structured.canonicalized_sequence failures(propagate) { + transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!pdl.operation) -> !pdl.operation %promoted_matmul, %alloc_0, %alloc_1 = transform.iree.promote_operands %matmul [0, 1] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) + + // Late canonicalizations to cleanup and pass the checks. + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } } }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir index a527c53..c9539f3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
@@ -41,10 +41,16 @@ return } module { - transform.structured.canonicalized_sequence failures(propagate) { - ^bb0(%arg0: !pdl.operation): - %17 = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation + transform.sequence failures(propagate) { + ^bb0(%variant_op: !pdl.operation): + %17 = transform.structured.match ops{["func.func"]} in %variant_op + : (!pdl.operation) -> !pdl.operation %18 = transform.iree.map_nested_forall_to_gpu_threads %17 {workgroup_size = [256, 1, 1]} + + // Late canonicalizations to cleanup and pass the checks. + // Needs to occur on the whole variant to perform cse on the workgroup_count region + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } } } }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_gpu_pipelining.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_gpu_pipelining.mlir index 23d14c7..9f99b28 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_gpu_pipelining.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_gpu_pipelining.mlir
@@ -52,7 +52,7 @@ return } } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %for = transform.structured.match ops{["scf.for"]} in %variant_op : (!pdl.operation) -> !pdl.operation %1 = transform.cast %for : !pdl.operation to !transform.op<"scf.for">
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_vector_to_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_vector_to_mma.mlir index c89d080..cf54d3d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_vector_to_mma.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_vector_to_mma.mlir
@@ -47,10 +47,15 @@ return } } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation %func_2 = transform.iree.apply_patterns %func { unroll_vectors_gpu_wmma } - transform.iree.vector.vector_to_mma_conversion %func_2 { use_wmma } + %func_3 = transform.iree.vector.vector_to_mma_conversion %func_2 { use_wmma } + + // Apply canonicalization post-hoc to trigger DCE and pass the test + // (i.e. all vector.contract are dead). + // TODO: consider having the vector_to_mma_conversion do the DCE automatically. + %func_4 = transform.iree.apply_patterns %func_3 { canonicalization } } }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir index 6539fd4..c98fdf1 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir
@@ -1,4 +1,4 @@ -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation %foreach_op, %tiled_op = transform.structured.tile_to_forall_op %0 num_threads [42, 67]
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp index cf3cd22..264f47e 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -327,7 +327,9 @@ /// Find the linalg op that defines all values in range, potentially /// transitively through tensor casts. -static linalg::LinalgOp findSingleLinalgOpDefiningAll(ValueRange range) { +static FailureOr<linalg::LinalgOp> +findSingleLinalgOpDefiningAll(ValueRange range) { + LLVM_DEBUG(DBGS() << "Start findSingleLinalgOpDefiningAll\n"); linalg::LinalgOp sourceOp = nullptr; for (Value value : range) { // See through tensor casts and reshape ops. @@ -357,35 +359,41 @@ } LLVM_DEBUG( - DBGS() << "different source linalg ops for replacing one op: \n" + DBGS() << "--different source linalg ops for replacing one op: \n" << sourceOp << "\n" << currentSourceOp << "\n"); - return nullptr; + return failure(); } - LLVM_DEBUG(DBGS() << "replacing linalg op with unknown non-linalg op:\n" + LLVM_DEBUG(DBGS() << "--replacing linalg op with unknown non-linalg op:\n" << *value.getDefiningOp() << "\n"); - return nullptr; + return failure(); } return sourceOp; } /// Find the scf "for" op that defines all values in the range. -static scf::ForOp findSingleForOpDefiningAll(ValueRange range) { - scf::ForOp forOp = nullptr; +/// Take into account that the op may just disappear when it is replaced by its +/// body, in the case of a single iteration loop. +// It is unclear atm how to account for this properly. +static FailureOr<Operation *> findSingleForOpDefiningAll(ValueRange range) { + LLVM_DEBUG(DBGS() << "Start findSingleForOpDefiningAll\n"); + Operation *forOp = nullptr; for (Value value : range) { - if (auto currentSourceOp = value.getDefiningOp<scf::ForOp>()) { - if (!forOp || forOp == currentSourceOp) { - forOp = currentSourceOp; - continue; - } - LLVM_DEBUG( - DBGS() << "different source scf.for ops when replacing one op\n"); - return nullptr; + LLVM_DEBUG(DBGS() << "--find for: " << value << "\n"); + // Block arguments are just dropped. 
+ auto currentSourceOp = value.getDefiningOp(); + if (!currentSourceOp) { + LLVM_DEBUG(DBGS() << "--replacing tracked op with bbarg -> SKIP\n"); + continue; } - + auto currentForOp = dyn_cast<scf::ForOp>(currentSourceOp); + if (!forOp || (currentForOp && forOp == currentForOp)) { + forOp = currentSourceOp; + continue; + } + LLVM_DEBUG(DBGS() << "---no single scf.for replacement found -> SKIP\n"); LLVM_DEBUG( - DBGS() - << "could not find a source scf.for when replacing another scf.for\n"); + DBGS() << "---WARNING: this will drop tracking of the scf.for\n"); return nullptr; } return forOp; @@ -420,15 +428,12 @@ return llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(replacedOp) .Case<linalg::LinalgOp>([&](linalg::LinalgOp) -> FailureOr<Operation *> { auto op = findSingleLinalgOpDefiningAll(range); - if (!op) + if (failed(op)) return failure(); - return op.getOperation(); + return op->getOperation(); }) .Case<scf::ForOp>([&](scf::ForOp) -> FailureOr<Operation *> { - auto op = findSingleForOpDefiningAll(range); - if (!op) - return failure(); - return op.getOperation(); + return findSingleForOpDefiningAll(range); }) .Default([&](Operation *) -> FailureOr<Operation *> { return findSingleOpDefiningAll(range); @@ -443,15 +448,26 @@ // Exit early if the op is not tracked. SmallVector<Value> handles; - if (failed(getTransformState().getHandlesForPayloadOp(op, handles))) + if (failed(getTransformState().getHandlesForPayloadOp(op, handles))) { + LLVM_DEBUG(DBGS() << "no tracking handle to remove\n"); return; + } FailureOr<Operation *> replacement = findSingleDefiningOp(op, newValues); if (failed(replacement)) { + LLVM_DEBUG(DBGS() << "could not find replacement for tracked op\n"); emitError(op) << "could not find replacement for tracked op"; return; } + // If this would cause an error with replacement, drop instead. 
+ if (*replacement && (*replacement)->getNumResults() != op->getNumResults()) { + LLVM_DEBUG(DBGS() << "failsafe error tracking activated due to mismatched " + "number of results for op: " + << op << " and replacement " << *replacement << "\n"); + replacement = nullptr; + } + if (*replacement == nullptr) { // TODO: Check if the handle is dead. Otherwise, the op should not be // dropped. This needs a change in the transform dialect interpreter. @@ -472,16 +488,17 @@ op->walk([&](Operation *op) { // Exit early if the op is not tracked. SmallVector<Value> handles; - if (failed(getTransformState().getHandlesForPayloadOp(op, handles))) + if (failed(getTransformState().getHandlesForPayloadOp(op, handles))) { + LLVM_DEBUG(DBGS() << "no tracking handle to remove\n"); return; - + } LLVM_DEBUG(DBGS() << "removing tracked @" << op << " : " << *op << "\n"); mayFail(replacePayloadOp(op, nullptr)); }); } void mlir::TrackingListener::removeMappings(Operation *op) { - // Bail if in error state. + // Bail out if in error state. if (hadErrors) return;
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir index eb70e14..9439fd5 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir
@@ -1,8 +1,5 @@ // RUN: iree-dialects-opt %s --transform-dialect-interpreter --split-input-file | FileCheck %s -// CHECK-DAG: #[[$MUL_MAP:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> -// CHECK-DAG: #[[$SUB_MAP:.*]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> -// CHECK-DAG: #[[$ID1_MAP:.*]] = affine_map<(d0) -> (d0)> #map0 = affine_map<(d0)[s0] -> (d0 ceildiv s0)> #map1 = affine_map<(d0)[s0] -> (d0 * s0)> #map2 = affine_map<(d0, d1) -> (d0 - d1)> @@ -50,7 +47,7 @@ return } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %0 = transform.structured.match ops{["scf.forall"]} in %module_op : (!pdl.operation) -> !pdl.operation %1 = forall_to_async %0
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir index 0aef6ab..a9a565d 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
@@ -1,8 +1,5 @@ // RUN: iree-dialects-opt %s --transform-dialect-interpreter --split-input-file | FileCheck %s -// CHECK-DAG: #[[$MUL_MAP:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> -// CHECK-DAG: #[[$SUB_MAP:.*]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> -// CHECK-DAG: #[[$ID1_MAP:.*]] = affine_map<(d0) -> (d0)> #map0 = affine_map<(d0)[s0] -> (d0 ceildiv s0)> #map1 = affine_map<(d0)[s0] -> (d0 * s0)> #map2 = affine_map<(d0, d1) -> (d0 - d1)> @@ -22,7 +19,7 @@ // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref<?xf32> - // CHECK: scf.for %[[IV:.*]] = {{.*}} step %[[C1]] { + // CHECK: scf.for %[[IV:.*]] = {{.*}} scf.forall (%arg3) in (%1) shared_outs() -> () { %3 = affine.apply #map1(%arg3)[%arg0] %4 = affine.apply #map2(%0, %3) @@ -44,7 +41,7 @@ return } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %0 = transform.structured.match ops{["scf.forall"]} in %module_op : (!pdl.operation) -> !pdl.operation %1 = forall_to_scf_for %0
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir index b790dbc..6e7b6bc 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
@@ -65,7 +65,7 @@ %2 = operation "scf.forall"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>) rewrite %2 with "transform.dialect" } - transform.structured.canonicalized_sequence %arg0 failures(propagate) { + transform.sequence %arg0 failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @match_elemwise in %arg1 : (!pdl.operation) -> !pdl.operation %1, %fusedOps:2 = fuse_producers %0 {operands_to_fuse=[0, 1]} @@ -131,7 +131,7 @@ %2 = operation "scf.forall"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>) rewrite %2 with "transform.dialect" } - transform.structured.canonicalized_sequence %arg0 failures(propagate) { + transform.sequence %arg0 failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @match_elemwise in %arg1 : (!pdl.operation) -> !pdl.operation %1, %fusedOps = fuse_producers %0 {operands_to_fuse=[0]}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir index dd6e3f8..59e2032 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir
@@ -22,8 +22,8 @@ rewrite %0 with "transform.apply" } - // CHECK-NOT: canonicalized_sequence - transform.structured.canonicalized_sequence %arg0 failures(propagate) { + // CHECK-NOT: sequence + transform.sequence %arg0: !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.tile %0 [4, 4, 4]
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir index d16c75d..c678ec4 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
@@ -15,7 +15,7 @@ rewrite %0 with "transform.dialect" } - transform.structured.canonicalized_sequence %arg0 failures(propagate) { + transform.sequence %arg0: !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @some_operation in %arg1 : (!pdl.operation) -> !pdl.operation // Make sure we don't crash on wrong operation type.
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir index 8b8abc1..4904f2e 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir
@@ -1,6 +1,6 @@ // RUN: iree-dialects-opt %s --split-input-file -verify-diagnostics -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error@below {{expects iterator_interchange to be a permutation, found 1, 1}} @@ -9,7 +9,7 @@ // ----- -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error@below {{expected 'tile_sizes' attribute}} @@ -18,7 +18,7 @@ // ----- -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error@below {{expects interchange to be a permutation, found [1, 1]}} @@ -27,7 +27,7 @@ // ----- -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error@below {{expects pack_paddings to contain booleans (0/1), found [1, 7]}} @@ -36,7 +36,7 @@ // ----- -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): %0 = pdl_match @match in %arg0 : (!pdl.operation) -> !pdl.operation // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir index 01a22ea..3926f9c 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
@@ -1,7 +1,7 @@ // RUN: iree-dialects-opt %s | FileCheck %s -// CHECK: transform.structured.canonicalized_sequence -transform.structured.canonicalized_sequence failures(propagate) { +// CHECK: transform.sequence +transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // CHECK: %[[OPS:.*]] = pdl_match @match1 in %{{.*}} %0 = pdl_match @match1 in %arg0 : (!pdl.operation) -> !pdl.operation
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir index 09dc29d..f5b7a52 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
@@ -14,7 +14,7 @@ } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!pdl.operation) -> !pdl.operation %1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
diff --git a/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir b/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir index f891e49..88d1ec4 100644 --- a/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir +++ b/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir
@@ -1,4 +1,4 @@ -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %variant_op_2 = transform.iree.bufferize %variant_op %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation
diff --git a/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir b/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir index 065c2a2..6e13b45 100644 --- a/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir +++ b/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir
@@ -1,4 +1,4 @@ -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation %foreach_op, %tiled_op = transform.structured.tile_to_forall_op %0 num_threads [13, 33]
diff --git a/tests/transform_dialect/cpu/matmul.mlir b/tests/transform_dialect/cpu/matmul.mlir index 41e1a2c..601aa53 100644 --- a/tests/transform_dialect/cpu/matmul.mlir +++ b/tests/transform_dialect/cpu/matmul.mlir
@@ -38,7 +38,7 @@ // CODEGEN-CUSTOM-DISPATCH-FORMATION: hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset({{.*}}) : memref<3x3xf32> // CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.assume_alignment %{{.*}}, 64 : memref<3x3xf32> // CODEGEN-CUSTOM-DISPATCH-FORMATION: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index -// CODEGEN-CUSTOM-DISPATCH-FORMATION: affine.apply {{.*}}()[%workgroup_id_x] +// CODEGEN-CUSTOM-DISPATCH-FORMATION: affine.apply {{.*}}(%workgroup_id_x) // CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 5] [1, 1] : memref<3x5xf32> to memref<?x5xf32, strided<[5, 1], offset: ?>> // CODEGEN-CUSTOM-DISPATCH-FORMATION: memref.subview %{{.*}}[%{{.*}}, 0] [%{{.*}}, 3] [1, 1] : memref<3x3xf32> to memref<?x3xf32, strided<[3, 1], offset: ?>> // CODEGEN-CUSTOM-DISPATCH-FORMATION: linalg.matmul ins(%{{.*}}, %{{.*}} : memref<?x5xf32, strided<[5, 1], offset: ?>>, memref<5x3xf32>) outs(%{{.*}} : memref<?x3xf32, strided<[3, 1], offset: ?>>)
diff --git a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir index 2ffb880..2029702 100644 --- a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir +++ b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
@@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!pdl.operation) -> !pdl.operation @@ -9,10 +9,16 @@ // TODO: IREE needs own workgroup mapping attribute. ( mapping = [#gpu.block<x>] ) - %1 = transform.iree.bufferize %variant_op - %memref_func = transform.structured.match ops{["func.func"]} in %1 : (!pdl.operation) -> !pdl.operation - transform.iree.erase_hal_descriptor_type_from_memref %memref_func + // Canonicalization/CSE is needed before bufferization otherwise unnecessary + // allocs will be created. + %variant_op_2 = transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, cse } + %variant_op_3 = transform.iree.bufferize %variant_op_2 + %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation + %memref_func_2 = transform.iree.erase_hal_descriptor_type_from_memref %memref_func + %memref_func_3 = transform.iree.forall_to_workgroup %memref_func_2 - %func = transform.structured.match ops{["func.func"]} in %1 : (!pdl.operation) -> !pdl.operation - transform.iree.forall_to_workgroup %func + // CSE is needed on the workgroup_count region to pass this particular test. + transform.iree.apply_patterns %variant_op_3 { cse } }
diff --git a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir index 2d13cb7..c7f1491 100644 --- a/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir +++ b/tests/transform_dialect/cpu/matmul_codegen_default_spec.mlir
@@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %matmul = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!pdl.operation) -> !pdl.operation
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir index 25a08b2..5d13de7 100644 --- a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
@@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir index 29f3860..69af78b 100644 --- a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
@@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation
diff --git a/tests/transform_dialect/cuda/mma.mlir b/tests/transform_dialect/cuda/mma.mlir index d7cdeec..6dc2e7d 100644 --- a/tests/transform_dialect/cuda/mma.mlir +++ b/tests/transform_dialect/cuda/mma.mlir
@@ -27,12 +27,17 @@ return } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%module: !pdl.operation): %func = transform.structured.match ops{["func.func"]} in %module : (!pdl.operation) -> !pdl.operation %func_2 = transform.iree.apply_patterns %func { unroll_vectors_gpu_wmma } - transform.iree.vector.vector_to_mma_conversion %func_2 { use_wmma } + %func_3 = transform.iree.vector.vector_to_mma_conversion %func_2 { use_wmma } + + // Apply canonicalization post-hoc to trigger DCE and pass the test + // (i.e. all vector.contract are dead). + // TODO: consider having the vector_to_mma_conversion do the DCE automatically. + %func_4 = transform.iree.apply_patterns %func_3 { canonicalization } } // ----- @@ -62,10 +67,15 @@ return } -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%module: !pdl.operation): %func = transform.structured.match ops{["func.func"]} in %module : (!pdl.operation) -> !pdl.operation %func_2 = transform.iree.apply_patterns %func { unroll_vectors_gpu_mma_sync } - transform.iree.vector.vector_to_mma_conversion %func_2 { use_mma_sync } + %func_3 = transform.iree.vector.vector_to_mma_conversion %func_2 { use_mma_sync } + + // Apply canonicalization post-hoc to trigger DCE and pass the test + // (i.e. all vector.contract are dead). + // TODO: consider having the vector_to_mma_conversion do the DCE automatically. + %func_4 = transform.iree.apply_patterns %func_3 { canonicalization } }
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir index ddf3fc7..e4c2f0a 100644 --- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation @@ -21,15 +21,22 @@ // Step 3. Second level of tiling + fusion parallelizes to threads. // =========================================================================== - %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation + %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op + : (!pdl.operation) -> !pdl.operation %forall_block_combiner_op, %block_combiner_op = transform.structured.tile_to_forall_op %grid_combiner_op tile_sizes [1] ( mapping = [#gpu.thread<z>] ) transform.structured.fuse_into_containing_op %fill_1d into %forall_block_combiner_op - %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op : (!pdl.operation) -> !pdl.operation + // Canonicalizations. 
+ transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } + + %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op + : (!pdl.operation) -> !pdl.operation %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]} - attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation + attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op + : (!pdl.operation) -> !pdl.operation %forall_block_more_parallel_op, %block_more_parallel_op = transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] ( mapping = [#gpu.thread<z>, #gpu.thread<y>] ) @@ -37,7 +44,8 @@ // Step 4. Rank-reduce and vectorize. // =========================================================================== - %func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation + %func = transform.structured.match ops{["func.func"]} in %variant_op + : (!pdl.operation) -> !pdl.operation %func_2 = transform.iree.apply_patterns %func { rank_reducing_linalg, rank_reducing_vector } %func_3 = transform.structured.vectorize %func_2 @@ -46,12 +54,14 @@ %func_4 = transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2 - %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation transform.iree.erase_hal_descriptor_type_from_memref %memref_func // Step 6. Post-bufferization mapping to blocks and threads. 
// =========================================================================== - %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation %func_6 = transform.iree.forall_to_workgroup %func_5 %func_7 = transform.iree.map_nested_forall_to_gpu_threads %func_6 { workgroup_size = [32, 2, 1] } @@ -59,7 +69,8 @@ // Step 7. Post-bufferization vector distribution with rank-reduction. // =========================================================================== %func_8 = transform.iree.apply_patterns %func_7 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } - %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0) // at this point. transform.sequence %variant_op_3 : !pdl.operation failures(suppress) { @@ -67,4 +78,10 @@ transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } } transform.iree.vector.warp_distribute %func_8 + : (!pdl.operation) -> !pdl.operation + + + // Late Canonicalizations. + transform.iree.apply_patterns %variant_op_3 + { canonicalization, tiling_canonicalization, licm, cse } }
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir index c7213ee..164cd31 100644 --- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -1,44 +1,68 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): - %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation + %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op + : (!pdl.operation) -> !pdl.operation // Step 1. Split the reduction to get meatier (size(red) / 2)-way parallelism. // =========================================================================== - %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation - %reduction, %eltwise = transform.split_handles %0 in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation) + %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op + : (!pdl.operation) -> !pdl.operation + %reduction, %eltwise = transform.split_handles %0 in [2] + : (!pdl.operation) -> (!pdl.operation, !pdl.operation) %init_or_alloc_op, %more_parallel_fill_op, %more_parallel_op, %combiner_op = transform.structured.split_reduction %reduction { split_factor = 2, insert_split_dimension = 1 } + // Canonicalizations. + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } + // Step 2. First level of tiling + fusion parallelizes to blocks. Tile the // trailing elementwise the same way we want to tile the reduction. 
// =========================================================================== %grid_loop, %eltwise_grid_op = transform.iree.tile_to_forall_and_workgroup_count_region %eltwise tile_sizes [1] (mapping = [#gpu.block<x>]) - %not_eltwise = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op, %combiner_op : !pdl.operation + %not_eltwise = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op, %combiner_op + : !pdl.operation transform.structured.fuse_into_containing_op %not_eltwise into %grid_loop + // Canonicalizations. + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } + // Step 3. Second level of tiling + fusion parallelizes to threads. // =========================================================================== - %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op : (!pdl.operation) -> !pdl.operation + %fill_1d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1xf32> in %variant_op + : (!pdl.operation) -> !pdl.operation %eltwise_block_loop, %eltwise_block_op = transform.structured.tile_to_forall_op %eltwise_grid_op tile_sizes [1] ( mapping = [#gpu.thread<z>] ) %block_combiner_op = transform.structured.match ops{["linalg.generic"]} - attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation + attributes {iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op + : (!pdl.operation) -> !pdl.operation %combined_and_fill = transform.merge_handles %fill_1d, %block_combiner_op : !pdl.operation transform.structured.fuse_into_containing_op %combined_and_fill into %eltwise_block_loop - %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op : (!pdl.operation) -> !pdl.operation + // Canonicalizations. 
+ transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } + + %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op + : (!pdl.operation) -> !pdl.operation %grid_more_parallel_op = transform.structured.match ops{["linalg.generic"]} - attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op : (!pdl.operation) -> !pdl.operation + attributes{iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]} in %variant_op + : (!pdl.operation) -> !pdl.operation %forall_block_more_parallel_op, %block_more_parallel_op = transform.structured.tile_to_forall_op %grid_more_parallel_op tile_sizes [1, 1] ( mapping = [#gpu.thread<z>, #gpu.thread<y>] ) transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op + // Canonicalizations. + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } + // Step 4. Rank-reduce and vectorize. // =========================================================================== %func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation @@ -50,7 +74,8 @@ %func_4 = transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2 - %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation transform.iree.erase_hal_descriptor_type_from_memref %memref_func // Step 6. Post-bufferization mapping to blocks and threads. @@ -63,7 +88,8 @@ // Step 7. 
Post-bufferization vector distribution with rank-reduction. // =========================================================================== %func_8 = transform.iree.apply_patterns %func_7 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } - %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0) // at this point. transform.sequence %variant_op_3 : !pdl.operation failures(suppress) { @@ -71,4 +97,10 @@ transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } } transform.iree.vector.warp_distribute %func_8 + : (!pdl.operation) -> !pdl.operation + + + // Late canonicalizations. + transform.iree.apply_patterns %variant_op_3 + { canonicalization, tiling_canonicalization, licm, cse } }
diff --git a/tests/transform_dialect/cuda/reduction_v2.mlir b/tests/transform_dialect/cuda/reduction_v2.mlir index 0578a84..6f4af93 100644 --- a/tests/transform_dialect/cuda/reduction_v2.mlir +++ b/tests/transform_dialect/cuda/reduction_v2.mlir
@@ -46,19 +46,21 @@ // CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x128xf32, #gpu.address_space<workgroup>> // CHECK: %[[TIDX:.]] = gpu.thread_id x - // CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]] + // CHECK: %[[IDX_0:.*]] = affine.apply{{.*}}()[%[[TIDX]]] // CHECK: gpu.barrier + // TODO: Properly produce/CSE IDX_1 vs IDX_0 + // CHECK: %[[IDX_1:.*]] = affine.apply{{.*}}(%[[TIDX]]) // Local per-thread scf.for-based reduction. // CHECK: scf.for // CHECK: vector.transfer_read - // CHECK: vector.transfer_read %[[SHMEM_ALLOC]][%[[C0]], %[[IDX]]] + // CHECK: vector.transfer_read %[[SHMEM_ALLOC]][%[[C0]], %[[IDX_1]]] // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> - // CHECK: vector.transfer_write %{{.*}}, %[[SHMEM_ALLOC]][%[[C0]], %[[IDX]]] + // CHECK: vector.transfer_write %{{.*}}, %[[SHMEM_ALLOC]][%[[C0]], %[[IDX_1]]] // TODO: remote unnecessary barrier within the loop // CHECK: gpu.barrier // Distributed reduction: everyone loads then 5 xor + addf expected - // CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[IDX]]] + // CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[IDX_0]]] // CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf // CHECK: %[[RES:.*]] = arith.addf %{{.*}}
diff --git a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir index 8951d6a..bda3fcb 100644 --- a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
@@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation %reduction = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation @@ -40,7 +40,10 @@ // Step 5. Bufferize and drop HAL decriptor from memref ops. // =========================================================================== - %func_4 = transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } + // Canonicalization/CSE is needed before bufferization otherwise unnecessary + // allocs will be created. + %func_4 = transform.iree.apply_patterns %func_3 + { fold_reassociative_reshapes, canonicalization, tiling_canonicalization, cse } %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation %func_6 = transform.iree.apply_patterns %func_5 { erase_unnecessary_tensor_operands } @@ -58,7 +61,13 @@ // Step 7. Post-bufferization vector distribution with rank-reduction. 
// =========================================================================== %func_10 = transform.iree.apply_patterns %func_9 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } - %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } - transform.iree.vector.warp_distribute %func_10 + %func_11 = transform.iree.vector.warp_distribute %func_10 + : (!pdl.operation) -> !pdl.operation + + // Late canonicalizations to cleanup and pass the checks + %func_12 = transform.iree.apply_patterns %func_11 + { canonicalization, tiling_canonicalization, licm, cse } }
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir index 967d37a..53afe4d 100644 --- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
@@ -1,9 +1,14 @@ // RUN: iree-opt %s +// TODO: port this test to transform.sequence, atm some canonicalization patterns +// ping-pong into oblivion. +// transform.sequence failures(propagate) { transform.structured.canonicalized_sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): - %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!pdl.operation) -> !pdl.operation - %reduction = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation + %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op + : (!pdl.operation) -> !pdl.operation + %reduction = transform.structured.match ops{["linalg.generic"]} in %variant_op + : (!pdl.operation) -> !pdl.operation // Step 1. First level of tiling + fusion parallelizes to blocks. // =========================================================================== @@ -31,25 +36,30 @@ // Step 3. Rank-reduce and vectorize. // =========================================================================== - %func = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation + %func = transform.structured.match ops{["func.func"]} in %variant_op + : (!pdl.operation) -> !pdl.operation // TODO: masked vectorization on block_more_parallel_op_2 if we want // vector<4> to work as intended. - %func_2 = transform.iree.apply_patterns %func { rank_reducing_linalg, rank_reducing_vector } + %func_2 = transform.iree.apply_patterns %func + { rank_reducing_linalg, rank_reducing_vector } %func_3 = transform.structured.vectorize %func_2 // Step 4. Bufferize and drop HAL descriptor from memref ops. 
// =========================================================================== %func_4 = transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } %variant_op_2 = transform.iree.eliminate_empty_tensors %variant_op - %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_2 : (!pdl.operation) -> !pdl.operation + %func_5 = transform.structured.match ops{["func.func"]} in %variant_op_2 + : (!pdl.operation) -> !pdl.operation %func_6 = transform.iree.apply_patterns %func_5 { erase_unnecessary_tensor_operands } %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op_2 - %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation transform.iree.erase_hal_descriptor_type_from_memref %memref_func // Step 5. Post-bufferization mapping to blocks and threads. // =========================================================================== - %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation %func_8 = transform.iree.forall_to_workgroup %func_7 %func_9 = transform.iree.map_nested_forall_to_gpu_threads %func_8 { workgroup_size = [1024, 1, 1] } @@ -57,7 +67,9 @@ // Step 6. Post-bufferization vector distribution with rank-reduction. 
// =========================================================================== %func_10 = transform.iree.apply_patterns %func_9 { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } - %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } transform.iree.vector.warp_distribute %func_10 + : (!pdl.operation) -> !pdl.operation }
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir index df7464d..64eda2a 100644 --- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -1,7 +1,7 @@ // RUN: iree-opt %s // Codegen -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
diff --git a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir index efa5e5d..3084926 100644 --- a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir +++ b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
@@ -1,7 +1,7 @@ // RUN: iree-opt %s // Dispatch softmax. -transform.structured.canonicalized_sequence failures(propagate){ +transform.sequence failures(propagate){ ^bb1(%variant_op: !pdl.operation): %ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation
diff --git a/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir index ba90328..3e7d96a 100644 --- a/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
@@ -1,7 +1,7 @@ // RUN: iree-opt %s // Codegen -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): // Step 1. First level of tiling + fusion parallelizes to blocks.
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir index 89f18df..9860764 100644 --- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -1,7 +1,7 @@ // RUN: iree-opt %s // Codegen -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): %ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation @@ -31,6 +31,11 @@ transform.iree.share_forall_operands %forall_with_type : (!transform.op<"scf.forall">) -> !transform.op<"scf.forall"> + // Canonicalizations. + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } + + // Step 2. Second level of tiling + fusion parallelizes to threads. // ================================================================ %tiled_ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]} @@ -59,6 +64,10 @@ transform.structured.tile_to_forall_op %parallel_linalg_ops num_threads [1, 4, 32] ( mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>] ) + // Canonicalizations. + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } + // Step 3. Rank-reduce and vectorize. // ================================== %funcx_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!pdl.operation) -> !pdl.operation @@ -81,9 +90,17 @@ // Step 6. Post-bufferization vector distribution with rank-reduction. 
// =================================================================== - %end_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %end_func = transform.structured.match ops{["func.func"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing_linalg, rank_reducing_vector, fold_memref_aliases } - %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 : (!pdl.operation) -> !pdl.operation + %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3 + : (!pdl.operation) -> !pdl.operation %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } transform.iree.vector.warp_distribute %end_func_2 + : (!pdl.operation) -> !pdl.operation + + + // Late canonicalizations. + transform.iree.apply_patterns %variant_op_3 + { canonicalization, tiling_canonicalization, licm, cse } }
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir index 5748777..ee202ad 100644 --- a/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec.mlir
@@ -1,4 +1,4 @@ -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): // Step 1. Find three linalg.generics and tile to GPU thread blocks. // ===========================================================================
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir index f651ed1..8686968 100644 --- a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir +++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
@@ -1,7 +1,13 @@ -transform.structured.canonicalized_sequence failures(propagate) { +transform.sequence failures(propagate) { ^bb1(%variant_op: !pdl.operation): - %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op : (!pdl.operation) -> !pdl.operation + %generics = transform.structured.match ops{["linalg.generic"]} in %variant_op + : (!pdl.operation) -> !pdl.operation // Tile only one dimension, skip the other one. transform.iree.tile_to_forall_and_workgroup_count_region %generics tile_sizes [0, 3] ( mapping = [#gpu.block<z>]) + + // Late canonicalizations to cleanup and pass the checks. + // Needs to occur on the whole variant to perform cse on the workgroup_count region + transform.iree.apply_patterns %variant_op + { canonicalization, tiling_canonicalization, licm, cse } }