Integrate llvm-project at 88f07a31 (#14165)
diff --git a/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp
index 40962d4..fed1423 100644
--- a/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp
+++ b/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp
@@ -43,13 +43,10 @@
using namespace llvm;
using namespace llvm::sys;
-enum Flavor {
- Invalid,
- Gnu, // -flavor gnu
- WinLink, // -flavor link
- Darwin, // -flavor darwin
- Wasm, // -flavor wasm
-};
+LLD_HAS_DRIVER(coff)
+LLD_HAS_DRIVER(elf)
+LLD_HAS_DRIVER(macho)
+LLD_HAS_DRIVER(wasm)
[[noreturn]] static void die(const Twine &s) {
llvm::errs() << s << "\n";
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 29f1dd1..54ae591 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -107,15 +107,13 @@
DiagnosedSilenceableFailure
transform_dialect::ApplyBufferOptimizationsOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Apply store to load forwarding and dead store elimination.
- IRRewriter rewriter(target->getContext());
- ErrorCheckingTrackingListener listener(state, *this);
- rewriter.setListener(&listener);
vector::transferOpflowOpt(rewriter, target);
eraseDeadAllocAndStores(rewriter, target);
- return listener.checkAndResetError();
+ return DiagnosedSilenceableFailure::success();
}
void transform_dialect::ApplyBufferOptimizationsOp::getEffects(
@@ -301,7 +299,8 @@
DiagnosedSilenceableFailure
transform_dialect::ApplyCommonSubexpressionEliminationOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
ErrorCheckingTrackingListener listener(state, *this);
Operation *lastOpVisited = nullptr;
@@ -337,7 +336,8 @@
DiagnosedSilenceableFailure
transform_dialect::ApplyLoopIndependentCodeMotionOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
ErrorCheckingTrackingListener listener(state, *this);
target->walk([&](func::FuncOp funcOp) {
@@ -373,14 +373,12 @@
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform_dialect::HoistStaticAllocOp::applyToOne(
- func::FuncOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(target->getContext());
- ErrorCheckingTrackingListener listener(state, *this);
- rewriter.setListener(&listener);
mlir::iree_compiler::hoistStaticallyBoundAllocationsInFunc<memref::AllocOp>(
rewriter, target);
- return listener.checkAndResetError();
+ return DiagnosedSilenceableFailure::success();
}
void transform_dialect::HoistStaticAllocOp::getEffects(
@@ -395,9 +393,9 @@
DiagnosedSilenceableFailure
transform_dialect::ShareForallOperandsOp::applyToOne(
- scf::ForallOp forallOp, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, scf::ForallOp forallOp,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(getContext());
SmallVector<int64_t> shareOperands(getShareOperands());
// Empty case: consider all operands need to be shared.
if (shareOperands.empty()) {
@@ -563,7 +561,8 @@
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform_dialect::ForallToWorkgroupOp::applyToOne(
- func::FuncOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
return mlir::emitDefiniteFailure(
@@ -596,14 +595,11 @@
target, "could not find a unique topLevel scf.forall");
}
- IRRewriter rewriter(topLevelForallOp->getContext());
rewriter.setInsertionPoint(topLevelForallOp);
- ErrorCheckingTrackingListener listener(state, *this);
- rewriter.setListener(&listener);
if (failed(rewriteForallToWorkgroup(rewriter, topLevelForallOp, exportOp)))
return mlir::emitDefiniteFailure(target, "rewriteForallToWorkgroup failed");
- return listener.checkAndResetError();
+ return DiagnosedSilenceableFailure::success();
}
void transform_dialect::ForallToWorkgroupOp::getEffects(
@@ -624,7 +620,8 @@
DiagnosedSilenceableFailure transform_dialect::
IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
auto forAllOp = dyn_cast<scf::ForallOp>(target);
if (!forAllOp) {
@@ -635,7 +632,6 @@
return mlir::emitDefiniteFailure(state.getTopLevel(),
"Expect the for op to be normalized");
}
- IRRewriter rewriter(target->getContext());
auto workgroupCount =
getMixedValues(forAllOp.getStaticUpperBound(),
forAllOp.getDynamicUpperBound(), rewriter);
@@ -837,6 +833,7 @@
} // namespace
DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
+ transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
auto payload = state.getPayloadOps(getTarget());
if (!llvm::hasSingleElement(payload) ||
@@ -933,17 +930,14 @@
DiagnosedSilenceableFailure
transform_dialect::IREEEliminateEmptyTensorsOp::applyToOne(
- ::mlir::Operation *target,
+ transform::TransformRewriter &rewriter, ::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state) {
- IRRewriter rewriter(target->getContext());
- ErrorCheckingTrackingListener listener(state, *this);
- rewriter.setListener(&listener);
if (failed(
eliminateEmptyTensors(rewriter, target, getBufferizationOptions())))
return emitDefaultDefiniteFailure(target)
<< "failed to eliminate tensor.empty ops";
- return listener.checkAndResetError();
+ return DiagnosedSilenceableFailure::success();
}
void transform_dialect::IREEEliminateEmptyTensorsOp::getEffects(
@@ -958,7 +952,7 @@
DiagnosedSilenceableFailure
transform_dialect::IREEEraseHALDescriptorTypeFromMemRefOp::applyToOne(
- ::mlir::Operation *target,
+ transform::TransformRewriter &rewriter, ::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state) {
if (!isa<func::FuncOp>(target)) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 9becc08..eae6909 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -20,7 +20,8 @@
Op<Transform_Dialect, "iree.apply_buffer_optimizations",
[TransformEachOpTrait,
TransformOpInterface,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
This applies memory optimization on memref. In particular it does store to
load forwarding, dead store elimination and dead alloc elimination.
@@ -45,6 +46,7 @@
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -53,7 +55,8 @@
def ApplyBubbleCollapsePatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.bubble_collapse",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns to fold an expanding tensor.expand_shape operation with
its producer generic operation by collapsing the dimensions of the generic
@@ -66,7 +69,8 @@
def ApplyBubbleExpandPatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.bubble_expand",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns to fold an expanding (collapsing) tensor_reshape
operation with its producer (consumer) generic operation by expanding
@@ -79,7 +83,8 @@
def ApplyBubblePackUnpackPatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.bubble_pack_unpack",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns to bubble up or down data layout ops across other
operations.
@@ -91,7 +96,8 @@
def ApplyFoldFillIntoPadPatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.fold_fill_into_pad",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populates a pattern that folds
"tensor.pad(cst, tensor.extract*(linalg.fill(cst)))" into
@@ -105,7 +111,8 @@
def ApplyFoldReshapeIntoTensorHalInterfacePatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.fold_reshape_into_tensor_hal_interface",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns that fold tensor.expand_shape/tensor.collapse_shape into
the source hal.interface.binding.subspan op.
@@ -117,7 +124,8 @@
def ApplyIreeLinalgElementwiseGreedyFusionPatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.linalg_elementwise_greedy_fusion",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns to fuse `linalg.generic` -> `linalg.generic` operations
when both operations are fusable elementwise operations.
@@ -132,7 +140,8 @@
def ApplyPrepareVectorToMMAPatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.prepare_vector_to_mma",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns that transform vector ops into a canonical form to
convert to MMA matrix operations. If `useNvGpu` is true, then the patterns
@@ -147,7 +156,8 @@
def ApplyUnrollVectorsGpuMmaSyncPatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.unroll_vectors_gpu_mma_sync",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns that unroll vectors. TODO: better documentation.
}];
@@ -158,7 +168,8 @@
def ApplyUnrollVectorsGpuWmmaSyncPatternsOp : Op<Transform_Dialect,
"apply_patterns.iree.unroll_vectors_gpu_wmma_sync",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate patterns that unroll vectors. TODO: better documentation.
}];
@@ -170,7 +181,8 @@
def ApplyCommonSubexpressionEliminationOp : Op<Transform_Dialect, "iree.apply_cse",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Apply common subexpression elimination. This transform is applied to all
ops within the target that are isolated from above.
@@ -187,6 +199,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -196,7 +209,8 @@
def ApplyLoopIndependentCodeMotionOp : Op<Transform_Dialect, "iree.apply_licm",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Apply loop-independent code motion and single iteration loop promotion.
This transform is applied to all FuncOps within the target.
@@ -213,6 +227,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -222,7 +237,8 @@
def HoistStaticAllocOp : Op<Transform_Dialect, "iree.hoist_static_alloc",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let summary = "Hoist static allocations";
let description = [{
Find static allocations and hoist them to the top level.
@@ -241,6 +257,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::func::FuncOp funcOp,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -250,7 +267,8 @@
def IREEBufferizeOp : Op<Transform_Dialect, "iree.bufferize",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Target the whole hal.executable_variant op and call upstream comprehensive
bufferize with extra IREE hooks.
@@ -296,7 +314,8 @@
Transform_Dialect, "iree.eliminate_empty_tensors",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
This is a pre-processing pass for iree.bufferize. It tries to remove
tensor.empty ops by replacing them with suitable destination tensors,
@@ -317,6 +336,7 @@
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -327,7 +347,8 @@
"iree.erase_hal_descriptor_type_from_memref",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Erase #hal.descriptor_type from MemRef memory space to ignore all IREE
memory space planning. This is meant to ease transitioning given that
@@ -352,6 +373,7 @@
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -362,7 +384,8 @@
"iree.forall_to_workgroup",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Target the whole hal.executable_variant op and rewrite the unique topLevel
scf.forall to distributed workgroup_id and workgroup_count.
@@ -407,6 +430,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::func::FuncOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -418,7 +442,8 @@
FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Target a single scf.forall op and shares all uses of the specified
`share_operands` operand indices.
@@ -478,6 +503,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::scf::ForallOp forallOp,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -488,7 +514,8 @@
Op<Transform_Dialect, "iree.populate_workgroup_count_region_using_num_threads_slice",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Populate the workgroup_count region on the `hal.executable.export` op.
@@ -509,6 +536,7 @@
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
index dafc91b..de46976 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
@@ -254,7 +254,9 @@
vectorSizes.append(ty.getShape().begin(), ty.getShape().end());
}
}
- (void)linalg::vectorize(rewriter, op, vectorSizes, vectorizeGatherAccesses);
+ SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
+ (void)linalg::vectorize(rewriter, op, vectorSizes, scalableVecDims,
+ vectorizeGatherAccesses);
};
// TODO: Move this down the pipeline once we have the ODM-based masking
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index c8bc183..6593fda 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -75,7 +75,8 @@
// TODO: synchronizations for imperfectly nested stuff.
DiagnosedSilenceableFailure
transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne(
- func::FuncOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
state.getTopLevel()->emitOpError(
@@ -97,9 +98,6 @@
auto transformOp = cast<transform::TransformOpInterface>(getOperation());
- IRRewriter rewriter(target->getContext());
- ErrorCheckingTrackingListener listener(state, *this);
- rewriter.setListener(&listener);
rewriter.setInsertionPointToStart(&target.getBody().front());
DiagnosedSilenceableFailure diag =
mlir::transform::gpu::mapNestedForallToThreadsImpl(
@@ -110,7 +108,7 @@
rewriter.startRootUpdate(exportOp);
exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
rewriter.finalizeRootUpdate(exportOp);
- return listener.checkAndResetError();
+ return DiagnosedSilenceableFailure::success();
}
void transform_dialect::MapNestedForallToGpuThreadsOp::getEffects(
@@ -285,7 +283,8 @@
DiagnosedSilenceableFailure
transform_dialect::VectorToWarpExecuteOnLane0Op::applyToOne(
- scf::IfOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, scf::IfOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
results.assign(1, nullptr);
@@ -339,10 +338,7 @@
}
Location loc = target->getLoc();
- IRRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
- ErrorCheckingTrackingListener listener(state, *this);
- rewriter.setListener(&listener);
FailureOr<VectorDistributionResult> vectorDistributionResult =
rewriteScfIfAsWarpExecuteOnLane0(rewriter, loc, target, workgroupSizeX,
warpSize);
@@ -357,7 +353,7 @@
}
results.push_back(vectorDistributionResult->warpOp);
- return listener.checkAndResetError();
+ return DiagnosedSilenceableFailure::success();
}
//===---------------------------------------------------------------------===//
@@ -588,7 +584,8 @@
DiagnosedSilenceableFailure
transform_dialect::VectorWarpDistributionOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
target->emitOpError(
@@ -658,7 +655,8 @@
DiagnosedSilenceableFailure
transform_dialect::VectorToMMAConversionOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
target->emitOpError(
@@ -702,8 +700,6 @@
<< *target << "\n";
});
- IRRewriter rewriter(target->getContext());
- rewriter.setListener(&listener);
auto diag = DiagnosedSilenceableFailure::success();
if (getUseWmma()) {
if (failed(convertVectorToMMAOps(rewriter, target)))
@@ -740,10 +736,10 @@
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform_dialect::PromoteOperandsOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
Location loc = target->getLoc();
- IRRewriter rewriter(getContext());
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(target);
SmallVector<int64_t> indices = llvm::to_vector(getIndices());
@@ -774,9 +770,9 @@
DiagnosedSilenceableFailure
transform_dialect::PipelineSharedMemoryCopiesOp::applyToOne(
- scf::ForOp forOp, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, scf::ForOp forOp,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(getContext());
int64_t depth(getDepth());
auto schedule = getUseMmaSync()
? PipeliningSchedulingStrategy::nvidiaTensorCore
@@ -799,9 +795,9 @@
}
DiagnosedSilenceableFailure transform_dialect::SynchronizeLoopOp::applyToOne(
- scf::ForOp forOp, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, scf::ForOp forOp,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(getContext());
rewriter.setInsertionPointAfter(forOp);
rewriter.create<gpu::BarrierOp>(forOp.getLoc());
return DiagnosedSilenceableFailure::success();
@@ -818,14 +814,12 @@
}
DiagnosedSilenceableFailure transform_dialect::CreateAsyncGroupsOp::applyToOne(
- func::FuncOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(target->getContext());
- ErrorCheckingTrackingListener listener(state, *this);
- rewriter.setListener(&listener);
iree_compiler::createAsyncGroups(rewriter, cast<func::FuncOp>(target),
getUseMmaSync());
- return listener.checkAndResetError();
+ return DiagnosedSilenceableFailure::success();
}
//===---------------------------------------------------------------------===//
@@ -833,9 +827,9 @@
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform_dialect::LayoutAnalysisAndDistributionOp::applyToOne(
- func::FuncOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(getContext());
iree_compiler::doLayoutAnalysisAndDistribution(rewriter,
cast<func::FuncOp>(target));
results.push_back(target);
@@ -846,9 +840,9 @@
// ReorderTransposeOp
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform_dialect::ReorderTransposeOp::applyToOne(
- func::FuncOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(getContext());
iree_compiler::reorderTranspose(rewriter, cast<func::FuncOp>(target));
results.push_back(target);
return DiagnosedSilenceableFailure::success();
@@ -1430,7 +1424,8 @@
DiagnosedSilenceableFailure
transform_dialect::EliminateGpuBarriersOp::applyToOne(
- func::FuncOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
RewritePatternSet patterns(target.getContext());
patterns.insert<BarrierElimination>(getContext());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
index 767860b..6399b5a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td
@@ -17,7 +17,8 @@
Op<Transform_Dialect, "iree.map_nested_forall_to_gpu_threads",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Target the whole hal.executable_variant op and rewrite all scf.forall
to distributed gpu.thread_id and translation_info attribute.
@@ -102,6 +103,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::func::FuncOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -112,7 +114,8 @@
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Given an scf.if target predicated by `if (threadIdx.x == 0)`, rewrite its
body to vector.execute_on_lane_0 running ***on a single warp***.
@@ -207,6 +210,7 @@
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::scf::IfOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -216,7 +220,8 @@
def VectorWarpDistributionOp : Op<Transform_Dialect, "iree.vector.warp_distribute",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Given a vector.warp_execute_on_lane_0, apply the patterns to rewrite into
distributed form with warp synchronization. This produces IR that runs
@@ -319,6 +324,7 @@
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -328,7 +334,8 @@
def VectorToMMAConversionOp : Op<Transform_Dialect, "iree.vector.vector_to_mma_conversion",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
This converts slices of operations containing vector.contract op into
mma operations, targetting warp level tensorcore operations. If the vector
@@ -356,6 +363,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -367,7 +375,8 @@
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
This op promotes the specified operands of the provided target handle.
@@ -388,6 +397,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -399,7 +409,8 @@
FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
This applies software pipelining to a given scf.for loop. The pipelining
strategy will look for a copy to shared memory and pipeline it to overlap
@@ -433,6 +444,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::scf::ForOp forOp,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -443,7 +455,8 @@
Transform_Dialect, "iree.synchronize_loop", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
This inserts a gpu.barrier after a given scf.for loop.
@@ -466,6 +479,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::scf::ForOp forOp,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -476,7 +490,8 @@
Op<Transform_Dialect, "iree.create_async_groups",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Convert copies to shared memory to async copies. This creates groups
of consecutive copies and emit wait operation right after.
@@ -502,6 +517,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::func::FuncOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -513,7 +529,8 @@
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Targets the whole func op and does the following:
@@ -540,6 +557,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::func::FuncOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -552,7 +570,8 @@
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Targets the whole func op and finds transpose ops whose source
comes from an elementwise op. For each of those transpose ops,
@@ -584,6 +603,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::func::FuncOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -596,7 +616,8 @@
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Removes unnecessary GPU barriers from the function. If a barrier does not
enforce any conflicting pair of memory effects, including a pair that is
@@ -620,6 +641,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::func::FuncOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPULayoutAnalysisAndDistribution.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPULayoutAnalysisAndDistribution.cpp
index 180ff4c..0671b2b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPULayoutAnalysisAndDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPULayoutAnalysisAndDistribution.cpp
@@ -240,7 +240,7 @@
// before the op.
static void createLayoutConflictOp(Value value, Layout targetLayout,
DenseMap<Value, Layout> &layoutMap,
- Operation *op, IRRewriter &rewriter) {
+ Operation *op, RewriterBase &rewriter) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
vector::ShapeCastOp conflictOp = rewriter.create<vector::ShapeCastOp>(
@@ -253,7 +253,7 @@
static void setMMALayout(Value aMatrix, Value bMatrix, Value cMatrix,
Value dMatrix, DenseMap<Value, Layout> &layoutMap,
- Operation *op, IRRewriter &rewriter) {
+ Operation *op, RewriterBase &rewriter) {
// First determine which variant of MMA this op is most suitable for
auto aType = llvm::cast<ShapedType>(aMatrix.getType());
auto bType = llvm::cast<ShapedType>(aMatrix.getType());
@@ -977,7 +977,7 @@
static void distributeFor(scf::ForOp forOp, DenseMap<Value, Layout> &layoutMap,
DenseMap<Value, Value> &simdToSimtMap,
- IRRewriter &rewriter,
+ RewriterBase &rewriter,
llvm::SetVector<Operation *> &ops) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(forOp);
@@ -996,7 +996,7 @@
static void distributeYield(scf::YieldOp yieldOp,
DenseMap<Value, Layout> &layoutMap,
DenseMap<Value, Value> &simdToSimtMap,
- IRRewriter &rewriter,
+ RewriterBase &rewriter,
llvm::SetVector<Operation *> &ops) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(yieldOp);
@@ -1018,7 +1018,7 @@
static void distributeConstants(arith::ConstantOp constantOp,
DenseMap<Value, Layout> &layoutMap,
DenseMap<Value, Value> &simdToSimtMap,
- IRRewriter &rewriter,
+ RewriterBase &rewriter,
llvm::SetVector<Operation *> &ops) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(constantOp);
@@ -1044,7 +1044,7 @@
static void distributeElementwise(Operation *op,
DenseMap<Value, Layout> &layoutMap,
DenseMap<Value, Value> &simdToSimtMap,
- IRRewriter &rewriter,
+ RewriterBase &rewriter,
llvm::SetVector<Operation *> &ops) {
if (!OpTrait::hasElementwiseMappableTraits(op)) return;
if (op->getNumResults() != 1) return;
@@ -1066,7 +1066,7 @@
static Value resolveBatchConflict(SmallVectorImpl<int> &mismatchedDims,
Value vector, const Layout &targetLayout,
const Layout &currentLayout,
- IRRewriter &rewriter, Location loc) {
+ RewriterBase &rewriter, Location loc) {
assert(mismatchedDims.size() == 1);
int batchDim = mismatchedDims[0];
VectorType vectorType = llvm::cast<VectorType>(vector.getType());
@@ -1106,7 +1106,7 @@
Value vector,
const Layout &targetLayout,
const Layout &currentLayout,
- IRRewriter &rewriter, Location loc) {
+ RewriterBase &rewriter, Location loc) {
int numMismatchedVecDims{0};
int vecDim, batchDim;
for (auto dimType : mismatchedDims) {
@@ -1155,7 +1155,7 @@
static void distributeLayoutConflicts(vector::ShapeCastOp op,
DenseMap<Value, Layout> &layoutMap,
DenseMap<Value, Value> &simdToSimtMap,
- IRRewriter &rewriter,
+ RewriterBase &rewriter,
llvm::SetVector<Operation *> &ops) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
@@ -1204,7 +1204,7 @@
}
static void eraseOps(llvm::SetVector<Operation *> &opsToErase,
- IRRewriter &rewriter) {
+ RewriterBase &rewriter) {
for (int i = opsToErase.size() - 1; i >= 0; i--) {
assert(opsToErase[i]->getUses().empty());
rewriter.eraseOp(opsToErase[i]);
@@ -1240,7 +1240,7 @@
return maps == infer({{m, k}, {n, k}, {m, n}});
}
-void doLayoutAnalysisAndDistribution(IRRewriter &rewriter,
+void doLayoutAnalysisAndDistribution(RewriterBase &rewriter,
func::FuncOp funcOp) {
// First walk through all the MMA ops and set their layouts
DenseMap<Value, Layout> layoutMap;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
index 6a30b5d..6301db9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
@@ -267,7 +267,7 @@
}
}
-void reorderTranspose(IRRewriter& rewriter, func::FuncOp funcOp) {
+void reorderTranspose(RewriterBase& rewriter, func::FuncOp funcOp) {
SmallVector<vector::TransposeOp> transposeOps;
funcOp.walk([&](Operation* op) {
if (auto transposeOp = dyn_cast<vector::TransposeOp>(op)) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
index aa9cdb7..31c4b1c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
@@ -19,10 +19,11 @@
bool useMMASync);
/// Function to do layout analysis and distribution.
-void doLayoutAnalysisAndDistribution(IRRewriter &rewriter, func::FuncOp funcOp);
+void doLayoutAnalysisAndDistribution(RewriterBase &rewriter,
+ func::FuncOp funcOp);
/// Function to reorder transposes and elementwise ops.
-void reorderTranspose(IRRewriter &rewriter, func::FuncOp funcOp);
+void reorderTranspose(RewriterBase &rewriter, func::FuncOp funcOp);
} // namespace iree_compiler
} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index f758e27..0ed735b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -151,18 +151,17 @@
// CHECK: %[[D19:.+]] = vector.broadcast %[[D18]] : vector<128xf32> to vector<128x128xf32>
// CHECK: %[[D20:.+]] = vector.transpose %[[D19]], [1, 0] : vector<128x128xf32> to vector<128x128xf32>
// CHECK: %[[D21:.+]] = arith.divf %[[D14]], %[[D20]] : vector<128x128xf32>
-// CHECK: %[[D22:.+]] = vector.broadcast %[[D17]] : vector<128xf32> to vector<64x128xf32>
-// CHECK: %[[D23:.+]] = vector.broadcast %[[D18]] : vector<128xf32> to vector<64x128xf32>
-// CHECK: %[[D24:.+]] = arith.divf %[[D22]], %[[D23]] : vector<64x128xf32>
-// CHECK: %[[D25:.+]] = vector.transpose %[[D24]], [1, 0] : vector<64x128xf32> to vector<128x64xf32>
-// CHECK: %[[D26:.+]] = arith.mulf %[[D25]], %[[ARG3]] : vector<128x64xf32>
-// CHECK: %[[D27:.+]] = vector.transfer_read %[[D2]][%[[WORKGROUP_ID_X]], %[[ARG0]], %[[C0]]], %[[CST_2]]
+// CHECK: %[[D22:.+]] = arith.divf %[[D17]], %[[D18]] : vector<128xf32>
+// CHECK: %[[D23:.+]] = vector.broadcast %[[D22]] : vector<128xf32> to vector<64x128xf32>
+// CHECK: %[[D24:.+]] = vector.transpose %[[D23]], [1, 0] : vector<64x128xf32> to vector<128x64xf32>
+// CHECK: %[[D25:.+]] = arith.mulf %[[D24]], %[[ARG3]] : vector<128x64xf32>
+// CHECK: %[[D26:.+]] = vector.transfer_read %[[D2]][%[[WORKGROUP_ID_X]], %[[ARG0]], %[[C0]]], %[[CST_2]]
// CHECK-SAME: {in_bounds = [true, true]} : memref<192x1024x64xf32>, vector<128x64xf32>
-// CHECK: %[[D28:.+]] = vector.contract {indexing_maps = [#[[MAP1]], #[[MAP4]], #[[MAP3]]],
+// CHECK: %[[D27:.+]] = vector.contract {indexing_maps = [#[[MAP1]], #[[MAP4]], #[[MAP3]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #[[VECTOR]].kind<add>}
-// CHECK-SAME: %[[D21]], %[[D27]], %[[D26]] : vector<128x128xf32>, vector<128x64xf32> into
+// CHECK-SAME: %[[D21]], %[[D26]], %[[D25]] : vector<128x128xf32>, vector<128x64xf32> into
// CHECK-SAME: vector<128x64xf32>
-// CHECK: scf.yield %[[D10]], %[[D18]], %[[D28]] : vector<128xf32>, vector<128xf32>, vector<128x64xf32>
+// CHECK: scf.yield %[[D10]], %[[D18]], %[[D27]] : vector<128xf32>, vector<128xf32>, vector<128x64xf32>
// CHECK: }
// CHECK: vector.transfer_write %[[D6]]#[[D2:.+]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds =
// CHECK-SAME: [true, true]} : vector<128x64xf32>, memref<1x128x64xf32, strided<[65536, 64, 1], offset: ?>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
index 443d7d2..40d1abf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
@@ -10,12 +10,14 @@
: (!transform.any_op) -> ()
// Late canonicalizations to cleanup and pass the checks.
- transform.apply_patterns to %variant_op {
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
index 43f82f7..93a8ae6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
@@ -6,9 +6,11 @@
: (!transform.any_op) -> !transform.any_op
// Late canonicalizations to cleanup and pass the checks.
- transform.apply_patterns to %variant_op {
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
index 19cc283..cbf3e44 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
@@ -37,14 +37,15 @@
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Late canonicalizations to cleanup and pass the checks.
- transform.apply_patterns to %variant_op {
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.iree.fold_fill_into_pad
+ transform.apply_patterns.linalg.tiling_canonicalization
+ transform.apply_patterns.scf.for_loop_canonicalization
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
}
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
index e40cca7..1dd83ee 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
@@ -53,11 +53,12 @@
// Late canonicalizations to cleanup and pass the checks.
// Needs to occur on the whole variant to perform cse on the workgroup_count region
- transform.apply_patterns to %variant_op {
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
}
}
}
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
index ca6ac61..ec7d0a9 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
@@ -128,19 +128,18 @@
/// tiling-related canonicalization patterns, canonicalization, licm and cse
/// (in this order).
void mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- ImplicitLocOpBuilder &b, Value variantH,
+ ImplicitLocOpBuilder &b, Value funcH,
ApplyPatternsOpBodyBuilderFn populatePatternsFn) {
- b.create<transform::ApplyPatternsOp>(
- variantH, [&](OpBuilder &b, Location loc) {
- b.create<transform::ApplyTilingCanonicalizationPatternsOp>(loc);
- b.create<IREE::transform_dialect::ApplyFoldFillIntoPadPatternsOp>(loc);
- b.create<transform::ApplyForLoopCanonicalizationPatternsOp>(loc);
- b.create<transform::ApplyCanonicalizationPatternsOp>(loc);
- if (populatePatternsFn) populatePatternsFn(b, loc);
- });
- b.create<IREE::transform_dialect::ApplyLoopIndependentCodeMotionOp>(variantH);
+ b.create<transform::ApplyPatternsOp>(funcH, [&](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyTilingCanonicalizationPatternsOp>(loc);
+ b.create<IREE::transform_dialect::ApplyFoldFillIntoPadPatternsOp>(loc);
+ b.create<transform::ApplyForLoopCanonicalizationPatternsOp>(loc);
+ b.create<transform::ApplyCanonicalizationPatternsOp>(loc);
+ if (populatePatternsFn) populatePatternsFn(b, loc);
+ });
+ b.create<IREE::transform_dialect::ApplyLoopIndependentCodeMotionOp>(funcH);
b.create<IREE::transform_dialect::ApplyCommonSubexpressionEliminationOp>(
- variantH);
+ funcH);
}
/// Dynamically selects the first non-empty handle; i.e. if (h1, h2) is:
@@ -161,7 +160,7 @@
mlir::iree_compiler::TileToScfForAndFuseResult
mlir::iree_compiler::buildTileFuseToScfFor(ImplicitLocOpBuilder &b,
- Value isolatedParentOpH, Value rootH,
+ Value variantH, Value rootH,
ValueRange opsHToFuse,
ArrayRef<OpFoldResult> tileSizes,
bool canonicalize) {
@@ -177,8 +176,9 @@
// matmuls.
// TODO: Make padding less brittle so that this toggle is unnecessary.
if (canonicalize) {
- mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, isolatedParentOpH);
+ Value funcH = b.create<transform::MatchOp>(
+ variantH, func::FuncOp::getOperationName());
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
}
return result;
}
@@ -205,9 +205,8 @@
// TODO: apply forwarding pattern.
template <typename TileOrNumThreadSpec>
static iree_compiler::TileToForallAndFuseAndDistributeResult
-buildTileAndFuseAndDistributeImpl(ImplicitLocOpBuilder &b,
- Value isolatedParentOpH, Value rootH,
- ValueRange opsHToFuse,
+buildTileAndFuseAndDistributeImpl(ImplicitLocOpBuilder &b, Value variantH,
+ Value rootH, ValueRange opsHToFuse,
ArrayRef<OpFoldResult> tileSizesOrNumThreads,
ArrayAttr threadDimMapping) {
iree_compiler::TileToForallAndFuseAndDistributeResult result;
@@ -218,8 +217,9 @@
result.tiledOpH = tileToForeachOp.getTiledOp();
// Perform a pass of canonicalization + enabling after tiling.
- mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, isolatedParentOpH);
+ Value funcH =
+ b.create<transform::MatchOp>(variantH, func::FuncOp::getOperationName());
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
// Batch fusion if requested.
if (opsHToFuse.size() > 1) {
@@ -239,11 +239,10 @@
// sigh.
iree_compiler::TileToForallAndFuseAndDistributeResult
mlir::iree_compiler::buildTileFuseDistToForallWithTileSizes(
- ImplicitLocOpBuilder &b, Value isolatedParentOpH, Value rootH,
- ValueRange opsHToFuse, ArrayRef<OpFoldResult> tileSizes,
- ArrayAttr threadDimMapping) {
+ ImplicitLocOpBuilder &b, Value variantH, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
return buildTileAndFuseAndDistributeImpl<transform::TileSizesSpec>(
- b, isolatedParentOpH, rootH, opsHToFuse, tileSizes, threadDimMapping);
+ b, variantH, rootH, opsHToFuse, tileSizes, threadDimMapping);
}
/// Call buildTileAndFuseAndDistributeImpl with ArrayRef<int64_t> numThreads.
@@ -251,11 +250,10 @@
// sigh.
iree_compiler::TileToForallAndFuseAndDistributeResult
mlir::iree_compiler::buildTileFuseDistToForallWithNumThreads(
- ImplicitLocOpBuilder &b, Value isolatedParentOpH, Value rootH,
- ValueRange opsHToFuse, ArrayRef<OpFoldResult> numThreads,
- ArrayAttr threadDimMapping) {
+ ImplicitLocOpBuilder &b, Value variantH, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping) {
return buildTileAndFuseAndDistributeImpl<transform::NumThreadsSpec>(
- b, isolatedParentOpH, rootH, opsHToFuse, numThreads, threadDimMapping);
+ b, variantH, rootH, opsHToFuse, numThreads, threadDimMapping);
}
/// Build the transform IR to pad an op `opH`.
@@ -287,42 +285,37 @@
return funcH;
}
-Value mlir::iree_compiler::buildLowerMaskedTransfersAndCleanup(
- ImplicitLocOpBuilder &b, Value containingOpH, bool cleanup) {
+void mlir::iree_compiler::buildLowerMaskedTransfersAndCleanup(
+ ImplicitLocOpBuilder &b, Value funcH, bool cleanup) {
// TODO: avoid functional style transform so we can apply to the variant.
- b.create<transform::ApplyPatternsOp>(
- containingOpH, [](OpBuilder &b, Location loc) {
- b.create<transform::ApplyLowerMaskedTransfersPatternsOp>(loc);
- });
+ b.create<transform::ApplyPatternsOp>(funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyLowerMaskedTransfersPatternsOp>(loc);
+ });
if (cleanup) {
- b.create<transform::ApplyPatternsOp>(
- containingOpH, [](OpBuilder &b, Location loc) {
- b.create<transform::ApplyCastAwayVectorLeadingOneDimPatternsOp>(loc);
- b.create<transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp>(loc);
- b.create<IREE::transform_dialect::
- ApplyFoldReshapeIntoTensorHalInterfacePatternsOp>(loc);
- });
+ b.create<transform::ApplyPatternsOp>(funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyCastAwayVectorLeadingOneDimPatternsOp>(loc);
+ b.create<transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp>(loc);
+ b.create<IREE::transform_dialect::
+ ApplyFoldReshapeIntoTensorHalInterfacePatternsOp>(loc);
+ });
}
- return containingOpH;
}
Value mlir::iree_compiler::buildLowerVectorMasksAndCleanup(
- ImplicitLocOpBuilder &b, Value containingOpH, bool cleanup) {
- b.create<transform::ApplyPatternsOp>(
- containingOpH, [](OpBuilder &b, Location loc) {
- b.create<transform::ApplyLowerMasksPatternsOp>(loc);
- });
- b.create<transform::ApplyPatternsOp>(
- containingOpH, [](OpBuilder &b, Location loc) {
- b.create<transform::ApplyMaterializeMasksPatternsOp>(loc);
- });
+ ImplicitLocOpBuilder &b, Value funcH, bool cleanup) {
+ b.create<transform::ApplyPatternsOp>(funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyLowerMasksPatternsOp>(loc);
+ });
+ b.create<transform::ApplyPatternsOp>(funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyMaterializeMasksPatternsOp>(loc);
+ });
if (cleanup) {
iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, containingOpH, [](OpBuilder &b, Location loc) {
+ b, funcH, [](OpBuilder &b, Location loc) {
b.create<transform::ApplyFoldMemrefAliasOpsPatternsOp>(loc);
});
}
- return containingOpH;
+ return funcH;
}
/// Hoist redundant subet ops.
@@ -335,8 +328,10 @@
Value variantH, bool targetGpu) {
// Perform a pass of canonicalization + enabling before bufferization to avoid
// spurious allocations.
+ Value funcH =
+ b.create<transform::MatchOp>(variantH, func::FuncOp::getOperationName());
buildCanonicalizationAndEnablingTransforms(
- b, variantH, [](OpBuilder &b, Location loc) {
+ b, funcH, [](OpBuilder &b, Location loc) {
b.create<transform::ApplyReassociativeReshapeFoldingPatternsOp>(loc);
b.create<transform::ApplyFoldTensorSliceIntoTransferPatternsOp>(loc);
});
@@ -457,7 +452,7 @@
TileToForallAndFuseAndDistributeResult tileResult =
buildTileFuseDistToForallWithTileSizes(
/*builder=*/b,
- /*isolatedParentOpH=*/variantH,
+ /*variantH=*/variantH,
/*rootH=*/fusionTargetH,
/*opsToFuseH=*/fusionGroupH,
/*tileSizes=*/
@@ -475,7 +470,9 @@
.getFusedOp();
// Perform a pass of canonicalization + enabling after fusion.
- mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(b, variantH);
+ Value funcH =
+ b.create<transform::MatchOp>(variantH, func::FuncOp::getOperationName());
+ mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
// Step 3. Normalize to reorder results irrespective of emptiness.
auto [blockReductionH, maybeBlockTrailingH] = buildSelectFirstNonEmpty(
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h
index 3018566..66a7ef1 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h
@@ -79,7 +79,7 @@
/// In addition to the specified transform, perform the following ones:
/// canonicalization, tiling_canonicalization, licm and cse (in this order).
void buildCanonicalizationAndEnablingTransforms(
- ImplicitLocOpBuilder &b, Value variantH,
+ ImplicitLocOpBuilder &b, Value funcH,
ApplyPatternsOpBodyBuilderFn populatePatternsFn = nullptr);
/// Build transform IR to dynamically selects the first non-empty handle; i.e.
@@ -112,9 +112,8 @@
/// Build transform IR to perform multi-level tile and fuse into an scf.for op.
/// Note: fusion is currently unsupported.
TileToScfForAndFuseResult buildTileFuseToScfFor(
- ImplicitLocOpBuilder &b, Value isolatedParentOpH, Value rootH,
- ValueRange opsHToFuse, ArrayRef<OpFoldResult> tileSizes,
- bool canonicalize = true);
+ ImplicitLocOpBuilder &b, Value variantH, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes, bool canonicalize = true);
/// Result of the combined transform performing tiling, fusion and
/// distribution to parallel constructs.
@@ -154,16 +153,14 @@
///
// TODO: if someone knows how to properly export templates go for it .. sigh.
TileToForallAndFuseAndDistributeResult buildTileFuseDistToForallWithTileSizes(
- ImplicitLocOpBuilder &b, Value isolatedParentOpH, Value rootH,
- ValueRange opsHToFuse, ArrayRef<OpFoldResult> tileSizes,
- ArrayAttr threadDimMapping);
+ ImplicitLocOpBuilder &b, Value variantH, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping);
/// Similar to `buildTileFuseDistWithTileSizes` but using `numThreads` instead
/// of `tileSizes`.
TileToForallAndFuseAndDistributeResult buildTileFuseDistToForallWithNumThreads(
- ImplicitLocOpBuilder &b, Value isolatedParentOpH, Value rootH,
- ValueRange opsHToFuse, ArrayRef<OpFoldResult> numThreads,
- ArrayAttr threadDimMapping);
+ ImplicitLocOpBuilder &b, Value variantH, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping);
/// Build transform IR to split the reduction into a parallel and combiner part.
/// Then tile the parallel part and map it to `tileSize` threads, each reducing
@@ -195,16 +192,15 @@
/// operations and subsequent cleanup patterns (fold-memref-aliases).
/// Takes a handle to a containing op and returns an updated handle to the
/// containing op.
-Value buildLowerMaskedTransfersAndCleanup(ImplicitLocOpBuilder &b,
- Value containingOpH,
- bool cleanup = true);
+void buildLowerMaskedTransfersAndCleanup(ImplicitLocOpBuilder &b, Value funcH,
+ bool cleanup = true);
/// Build transform IR that applies vector mask lowering and subsequent cleanup
/// patterns (fold-memref-aliases).
/// Takes a handle to a containing op and returns an updated handle to the
/// containing op.
-Value buildLowerVectorMasksAndCleanup(ImplicitLocOpBuilder &b,
- Value containingOpH, bool cleanup = true);
+Value buildLowerVectorMasksAndCleanup(ImplicitLocOpBuilder &b, Value funcH,
+ bool cleanup = true);
/// Build transform IR to hoist redundant subset operations.
void buildHoisting(ImplicitLocOpBuilder &b, Value funcH);
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp
index 39dede0..efae0ad 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp
@@ -181,9 +181,9 @@
//===----------------------------------------------------------------------===//
void mlir::iree_compiler::gpu::
build1DSplittingStrategyWithOptionalThreadMapping(
- ImplicitLocOpBuilder &b, Value isolatedParentOpH, Value opH,
- int64_t rank, int64_t mostMinorDim, SmallVector<int64_t> opSizes,
- int64_t numThreads, Attribute mappingAttr, int64_t maxVectorSize) {
+ ImplicitLocOpBuilder &b, Value variantH, Value opH, int64_t rank,
+ int64_t mostMinorDim, SmallVector<int64_t> opSizes, int64_t numThreads,
+ Attribute mappingAttr, int64_t maxVectorSize) {
// Poor man's handling of optionality in C++. Will need to be converted to
// proper transform dialect filters or handling of emptiness.
if (rank == 0) return;
@@ -216,7 +216,7 @@
if (vectorSize > 1) {
auto res = iree_compiler::buildTileFuseToScfFor(
/*b=*/b,
- /*isolatedParentOpH=*/isolatedParentOpH,
+ /*variantH=*/variantH,
/*rootH=*/opH,
/*opsHToFuse=*/{},
/*tileSizes=*/
@@ -230,7 +230,7 @@
assert(mappingAttr && "must specify a mapping attribute");
iree_compiler::buildTileFuseDistToForallWithNumThreads(
/*b=*/b,
- /*isolatedParentOpH=*/isolatedParentOpH,
+ /*variantH=*/variantH,
/*rootH=*/opH,
/*opsHToFuse=*/{},
/*numThreads=*/getAsOpFoldResult(b.getI64ArrayAttr(foreachTileSizes)),
@@ -243,7 +243,7 @@
if (vectorSize > 1) {
auto res = iree_compiler::buildTileFuseToScfFor(
/*b=*/b,
- /*isolatedParentOpH=*/isolatedParentOpH,
+ /*variantH=*/variantH,
/*rootH=*/opH,
/*opsHToFuse=*/{},
/*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr({scfForTileSizes})));
@@ -253,7 +253,7 @@
assert(mappingAttr && "must specify a mapping attribute");
iree_compiler::buildTileFuseDistToForallWithNumThreads(
/*b=*/b,
- /*isolatedParentOpH=*/isolatedParentOpH,
+ /*variantH=*/variantH,
/*rootH=*/opH,
/*opsHToFuse=*/{},
/*numThreads=*/getAsOpFoldResult(b.getI64ArrayAttr(foreachTileSizes)),
@@ -298,7 +298,7 @@
// Step N. Perform a final pass of canonicalization + enabling before
// returning.
mlir::iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, variantH, [](OpBuilder &b, Location loc) {
+ b, funcH, [](OpBuilder &b, Location loc) {
b.create<transform::ApplyFoldTensorEmptyPatternsOp>(loc);
});
return std::make_pair(variantH, funcH);
@@ -333,8 +333,10 @@
// Perform a pass of canonicalization cleanups + folding fill + pad into pad
// by applying `foldTensorSubsets` and `tilingCanonicalization`.
{
+ Value funcH = b.create<transform::MatchOp>(
+ variantH, func::FuncOp::getOperationName());
iree_compiler::buildCanonicalizationAndEnablingTransforms(
- b, variantH, [](OpBuilder &b, Location loc) {
+ b, funcH, [](OpBuilder &b, Location loc) {
b.create<transform::ApplyFoldTensorSubsetOpsPatternsOp>(loc);
b.create<
transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp>(
@@ -365,7 +367,7 @@
TileToForallAndFuseAndDistributeResult res =
buildTileFuseDistToForallWithTileSizes(
/*builder=*/b,
- /*isolatedParentOpH=*/variantH,
+ /*variantH=*/variantH,
/*rootH=*/copyOpH,
/*opsToFuseH=*/{},
/*tileSizes=*/
@@ -393,7 +395,7 @@
TileToForallAndFuseAndDistributeResult res =
buildTileFuseDistToForallWithNumThreads(
/*builder=*/b,
- /*isolatedParentOpH=*/variantH,
+ /*variantH=*/variantH,
/*rootH=*/copyOpH,
/*opsToFuseH=*/{},
/*numThreads=*/
@@ -479,7 +481,9 @@
// Also, no canonicalization is allowed after vector masking and before we
// lower the masks: masks are currently quite brittle and do not like
// canonicalization or anything else that may insert an op in their region.
- iree_compiler::buildCanonicalizationAndEnablingTransforms(b, variantH);
+ Value funcH =
+ b.create<transform::MatchOp>(variantH, func::FuncOp::getOperationName());
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
// Apply vector masking.
if (!strategy.alignedLhs()) {
@@ -501,9 +505,9 @@
// Lower all masked vector transfers at this point, as they make
// canonicalization generate incorrect IR.
// TODO: don't rematch, apply on the variant op directly.
- Value funcH =
+ funcH =
b.create<transform::MatchOp>(variantH, func::FuncOp::getOperationName());
- funcH = buildLowerMaskedTransfersAndCleanup(b, funcH, /*cleanup=*/false);
+ buildLowerMaskedTransfersAndCleanup(b, funcH, /*cleanup=*/false);
// Apply vectorization + cleanups to what remains.
funcH = iree_compiler::buildVectorize(b, funcH, /*applyCleanups=*/true);
@@ -646,13 +650,14 @@
Value mlir::iree_compiler::gpu::buildBufferize(ImplicitLocOpBuilder &b,
Value variantH) {
- b.create<transform::ApplyPatternsOp>(
- variantH, [](OpBuilder &b, Location loc) {
- b.create<transform::ApplyCanonicalizationPatternsOp>(loc);
- });
- b.create<IREE::transform_dialect::ApplyLoopIndependentCodeMotionOp>(variantH);
+ Value funcH =
+ b.create<transform::MatchOp>(variantH, func::FuncOp::getOperationName());
+ b.create<transform::ApplyPatternsOp>(funcH, [](OpBuilder &b, Location loc) {
+ b.create<transform::ApplyCanonicalizationPatternsOp>(loc);
+ });
+ b.create<IREE::transform_dialect::ApplyLoopIndependentCodeMotionOp>(funcH);
b.create<IREE::transform_dialect::ApplyCommonSubexpressionEliminationOp>(
- variantH);
+ funcH);
b.create<IREEEliminateEmptyTensorsOp>(variantH);
auto bufferizeOp = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
bufferizeOp.setTargetGpu(true);
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h
index f1b64ba..ce2a0fd 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.h
@@ -112,7 +112,7 @@
// in schedule complexity and can be handled with simple padding of the
// underlying allocation.
void build1DSplittingStrategyWithOptionalThreadMapping(
- ImplicitLocOpBuilder &b, Value isolatedParentOpH, Value opH, int64_t rank,
+ ImplicitLocOpBuilder &b, Value variantH, Value opH, int64_t rank,
int64_t mostMinorDim, SmallVector<int64_t> opSizes, int64_t numThreads,
Attribute mappingAttr = Attribute(), int64_t maxVectorSize = 4);
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
index fd9fd84..568e2f1 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
@@ -180,7 +180,7 @@
TileToForallAndFuseAndDistributeResult tileResult =
buildTileFuseDistToForallWithTileSizes(
/*builder=*/b,
- /*isolatedParentOpH=*/variantH,
+ /*variantH=*/variantH,
/*rootH=*/matmulH,
/*opsToFuseH=*/fillH,
/*tileSizes=*/
@@ -233,7 +233,9 @@
// Running canonicalization is required here to enable aligned pads to become
// linalg.copy ops when rewriting in DPS.
- iree_compiler::buildCanonicalizationAndEnablingTransforms(b, variantH);
+ Value funcH =
+ b.create<transform::MatchOp>(variantH, func::FuncOp::getOperationName());
+ iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
// Step 4. Distribute pad and copies: SIMT programming model.
auto [lhsCopyOpH, rhsCopyOpH, copyBackOpH] =
@@ -261,7 +263,7 @@
// Step 8. Post-bufferization mapping to blocks and threads.
// Need to match again since bufferize invalidated all handles.
// TODO: assumes a single func::FuncOp to transform, needs hardening.
- Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+ funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
funcH = buildMapToBlockAndThreads(b, funcH, strategy.numThreads,
strategy.numWarps);
funcH = b.create<EliminateGpuBarriersOp>(funcH);
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.cpp
index be3cead..dc0b691 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.cpp
@@ -34,8 +34,6 @@
using iree_compiler::blockY;
using iree_compiler::blockZ;
using iree_compiler::buildPad;
-using iree_compiler::buildTileFuseDistToForallWithNumThreads;
-using iree_compiler::buildTileFuseDistToForallWithTileSizes;
using iree_compiler::TileToForallAndFuseAndDistributeResult;
using iree_compiler::gpu::buildBufferize;
using iree_compiler::gpu::buildConvertToAsyncCopies;
@@ -123,7 +121,7 @@
// TODO: don't rematch, apply on the variant op directly.
Value funcH =
b.create<transform::MatchOp>(variantH, func::FuncOp::getOperationName());
- funcH = buildLowerMaskedTransfersAndCleanup(b, funcH);
+ buildLowerMaskedTransfersAndCleanup(b, funcH);
// Step 5. Vectorize the rest of func normally.
funcH = buildVectorize(b, funcH, /*applyCleanups=*/true);
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.cpp
index c4cd687..c664c0e 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.cpp
@@ -110,7 +110,7 @@
iree_compiler::TileToForallAndFuseAndDistributeResult tileResult =
iree_compiler::buildTileFuseDistToForallWithNumThreads(
/*builder=*/b,
- /*isolatedParentOpH=*/variantH,
+ /*variantH=*/variantH,
/*rootH=*/fusionTargetH,
/*opsToFuseH=*/fusionGroupH,
/*numThreads=*/
@@ -137,7 +137,7 @@
// part.
build1DSplittingStrategyWithOptionalThreadMapping(
/*b=*/b,
- /*isolatedParentOpH=*/variantH,
+ /*variantH=*/variantH,
/*opH=*/blockReductionH,
/*rank=*/strategy.captures.reductionRank,
// TODO: capture and generalize mostMinorDim.
@@ -150,7 +150,7 @@
// mapping part.
build1DSplittingStrategyWithOptionalThreadMapping(
/*b=*/b,
- /*isolatedParentOpH=*/variantH,
+ /*variantH=*/variantH,
/*opH=*/maybeBlockTrailingH,
/*rank=*/strategy.captures.maybeTrailingRank,
// TODO: capture and generalize mostMinorDim.
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/StagedReductionStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/StagedReductionStrategy.cpp
index 030842c..6dbc487 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/StagedReductionStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/StagedReductionStrategy.cpp
@@ -123,7 +123,7 @@
}
static void buildStagedReductionStrategyThreadLevel(
- ImplicitLocOpBuilder &b, Value isolatedParentOpH, Value gridReductionH,
+ ImplicitLocOpBuilder &b, Value variantH, Value gridReductionH,
Value gridFillH, Value maybeTiledLeadingH, Value maybeTiledTrailingH,
const StagedReductionStrategy &strategy) {
MLIRContext *ctx = b.getContext();
@@ -136,7 +136,7 @@
assert((vectorSize & (vectorSize - 1)) == 0 && "size must be power of 2");
build1DSplittingStrategyWithOptionalThreadMapping(
/*b=*/b,
- /*isolatedParentOpH=*/isolatedParentOpH,
+ /*variantH=*/variantH,
/*opH=*/maybeTiledLeadingH,
/*rank=*/strategy.captures.maybeLeadingRank,
// TODO: capture and generalize mostMinorDim.
@@ -151,7 +151,7 @@
auto [blockParallelForallOp, blockParallelFillH, blockCombinerOpH] =
buildTileReductionUsingScfForeach(
/*b=*/b,
- /*isolatedParentOpH=*/isolatedParentOpH,
+ /*isolatedParentOpH=*/variantH,
/*reductionH=*/gridReductionH,
/*reductionRank=*/strategy.captures.reductionRank,
/*tileSize=*/strategy.getNumThreadsInBlock().front(),
@@ -185,7 +185,7 @@
}
iree_compiler::buildTileFuseDistToForallWithTileSizes(
/*b=*/b,
- /*isolatedParentOpH=*/isolatedParentOpH,
+ /*variantH=*/variantH,
/*rootH=*/root,
/*opsToFuse=*/opsToFuse,
/*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr({1})),
@@ -199,7 +199,7 @@
strategy.captures.maybeTrailingOutputElementalTypeBitWidth;
build1DSplittingStrategyWithOptionalThreadMapping(
/*b=*/b,
- /*isolatedParentOpH=*/isolatedParentOpH,
+ /*variantH=*/variantH,
/*opH=*/maybeTiledTrailingH,
/*rank=*/strategy.captures.maybeTrailingRank,
// TODO: capture and generalize mostMinorDim.
@@ -228,10 +228,9 @@
// Step 2. Split the reduction and tile the pieces to ensure vector
// load/stores and mapping to a single warp with shuffles.
- buildStagedReductionStrategyThreadLevel(
- b,
- /*isolatedParentOpH=*/variantH, gridReductionH, gridFillH,
- maybeLeadingHBlock, maybeTiledTrailingHBlock, strategy);
+ buildStagedReductionStrategyThreadLevel(b, variantH, gridReductionH,
+ gridFillH, maybeLeadingHBlock,
+ maybeTiledTrailingHBlock, strategy);
// Step 3. Make sure we don't create allocation by sharing forall
// output. This amounts to injecting user-defined static information that each
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
index 1c860d7..0764687 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
@@ -457,11 +457,11 @@
DiagnosedSilenceableFailure
transform_dialect::ForeachThreadToFlowDispatchWorkgroupsOp::applyToOne(
- scf::ForallOp target, transform::ApplyToEachResultList &results,
- transform::TransformState &) {
- SimplePatternRewriter rewriter(target->getContext());
+ transform::TransformRewriter &rewriter, scf::ForallOp target,
+ transform::ApplyToEachResultList &results, transform::TransformState &) {
+ SimplePatternRewriter patternRewriter(target->getContext());
FailureOr<Flow::DispatchWorkgroupsOp> result =
- rewriteForeachThreadToFlowDispatchWorkgroups(target, rewriter);
+ rewriteForeachThreadToFlowDispatchWorkgroups(target, patternRewriter);
if (failed(result)) return emitDefaultDefiniteFailure(target);
results.push_back(*result);
return DiagnosedSilenceableFailure::success();
@@ -475,9 +475,8 @@
}
DiagnosedSilenceableFailure transform_dialect::RegionToWorkgroupsOp::applyToOne(
- Flow::DispatchRegionOp target, transform::ApplyToEachResultList &results,
- transform::TransformState &) {
- IRRewriter rewriter(target->getContext());
+ transform::TransformRewriter &rewriter, Flow::DispatchRegionOp target,
+ transform::ApplyToEachResultList &results, transform::TransformState &) {
FailureOr<Flow::DispatchWorkgroupsOp> result =
rewriteFlowDispatchRegionToFlowDispatchWorkgroups(target, rewriter);
if (failed(result)) return emitDefaultDefiniteFailure(target);
@@ -494,6 +493,7 @@
DiagnosedSilenceableFailure
transform_dialect::ClonePrecedingOpIntoDispatchRegionOp::apply(
+ transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) {
auto targetOps = state.getPayloadOps(getTarget());
@@ -522,7 +522,6 @@
assert(sortResult && "unable to sort topologically");
SmallVector<Operation *> orderedTargets =
llvm::to_vector(llvm::reverse(targetOps));
- IRRewriter rewriter(regionOp->getContext());
SmallVector<Operation *> clonedTargets;
for (Operation *target : orderedTargets) {
FailureOr<Operation *> clonedTarget =
@@ -545,6 +544,7 @@
DiagnosedSilenceableFailure
transform_dialect::MovePrecedingOpIntoDispatchRegionOp::apply(
+ transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) {
auto targetOps = state.getPayloadOps(getTarget());
@@ -573,7 +573,6 @@
assert(sortResult && "unable to sort topologically");
SmallVector<Operation *> orderedTargets =
llvm::to_vector(llvm::reverse(targetOps));
- IRRewriter rewriter(regionOp->getContext());
for (Operation *target : orderedTargets) {
auto newRegionOp =
movePrecedingOpsIntoDispatchRegion(rewriter, target, regionOp);
@@ -744,6 +743,7 @@
DiagnosedSilenceableFailure
transform_dialect::CloneSucceedingOpIntoDispatchRegionOp::apply(
+ transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) {
auto targetOps = state.getPayloadOps(getTarget());
@@ -760,7 +760,6 @@
bool sortResult = computeTopologicalSorting(orderedTargets);
(void)sortResult;
assert(sortResult && "unable to sort topologically");
- IRRewriter rewriter(regionOp->getContext());
SmallVector<Operation *> newTargets;
for (Operation *target : orderedTargets) {
auto newTarget =
@@ -784,6 +783,7 @@
DiagnosedSilenceableFailure
transform_dialect::MoveSucceedingOpIntoDispatchRegionOp::apply(
+ transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) {
auto targetOps = state.getPayloadOps(getTarget());
@@ -802,7 +802,6 @@
bool sortResult = computeTopologicalSorting(orderedTargets);
(void)sortResult;
assert(sortResult && "unable to sort topologically");
- IRRewriter rewriter(regionOp->getContext());
for (Operation *target : orderedTargets) {
auto newRegionOp =
moveSucceedingOpIntoDispatchRegion(rewriter, target, regionOp);
@@ -826,9 +825,9 @@
DiagnosedSilenceableFailure
transform_dialect::WrapInDispatchRegionOp::applyToOne(
- Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(target->getContext());
auto regionOp = Flow::wrapOpInDispatchRegion(rewriter, target);
if (failed(regionOp)) return emitDefaultDefiniteFailure(target);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
index bed7a0d..c915fcf 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
@@ -16,7 +16,8 @@
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Rewrite an scf.forall to Flow::DispatchWorkgroups.
@@ -44,6 +45,7 @@
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::scf::ForallOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -54,7 +56,8 @@
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface,
- TransformEachOpTrait]> {
+ TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Convert a flow.dispatch.region op into a flow.dispatch.workgroups op.
@@ -75,6 +78,7 @@
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::iree_compiler::IREE::Flow::DispatchRegionOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -86,7 +90,8 @@
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface,
- TransformEachOpTrait]> {
+ TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Wrap the `target` op in a new `dispatch.region` op. All uses of target op
are replaces with the results of the newly generated `dispach.region` op.
@@ -108,6 +113,7 @@
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -117,7 +123,8 @@
def ClonePrecedingOpIntoDispatchRegionOp : Op<
Transform_Dialect, "iree.clone_preceding_op_into_dispatch_region", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformOpInterface]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Clone the `target` op into the given dispatch region op. The dispatch region
handle must be mapped to exactly one payload op.
@@ -140,17 +147,13 @@
`:` functional-type(operands, results)
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure apply(
- ::mlir::transform::TransformResults &transformResults,
- ::mlir::transform::TransformState &state);
- }];
}
def MovePrecedingOpIntoDispatchRegionOp : Op<
Transform_Dialect, "iree.move_preceding_op_into_dispatch_region", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformOpInterface]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Move the `target` op into the given dispatch region op. The dispatch region
handle must be mapped to exactly one payload op.
@@ -176,17 +179,13 @@
`:` functional-type(operands, results)
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure apply(
- ::mlir::transform::TransformResults &transformResults,
- ::mlir::transform::TransformState &state);
- }];
}
def CloneSucceedingOpIntoDispatchRegionOp : Op<
Transform_Dialect, "iree.clone_succeeding_op_into_dispatch_region", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformOpInterface]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Clone the `target` op into the given dispatch region op. The dispatch region
handle must be mapped to exactly one payload op.
@@ -212,17 +211,13 @@
`:` functional-type(operands, results)
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure apply(
- ::mlir::transform::TransformResults &transformResults,
- ::mlir::transform::TransformState &state);
- }];
}
def MoveSucceedingOpIntoDispatchRegionOp : Op<
Transform_Dialect, "iree.move_succeeding_op_into_dispatch_region", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformOpInterface]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Move the `target` op into the given dispatch region op. The dispatch region
handle must be mapped to exactly one payload op.
@@ -248,11 +243,6 @@
`:` functional-type(operands, results)
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure apply(
- ::mlir::transform::TransformResults &transformResults,
- ::mlir::transform::TransformState &state);
- }];
}
#endif // IREE_COMPILER_DIALECT_FLOW_TRANSFORMEXTENSIONS_FLOWEXTENSIONS
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 883134f..1f4f70d 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -824,7 +824,7 @@
funcOp.setType(rewriter.getFunctionType(newTypes, {}));
}
- rewriter.replaceOp(flowOp, {});
+ rewriter.eraseOp(flowOp);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index b7c3905..c7b203e 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -54,7 +54,7 @@
for (Block &block : llvm::make_early_inc_range(blockRange)) {
rewriter.eraseBlock(&block);
}
- rewriter.replaceOp(srcOp, {});
+ rewriter.eraseOp(srcOp);
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToEnd(&newModuleOp.getBodyRegion().front());
rewriter.create<IREE::VM::ModuleTerminatorOp>(srcOp.getLoc());
@@ -153,7 +153,7 @@
// vm.export ops.
newFuncOp.setPrivate();
- rewriter.replaceOp(srcOp, std::nullopt);
+ rewriter.eraseOp(srcOp);
return success();
}
};
@@ -228,7 +228,7 @@
// Retain function attributes in the allowlist.
copyImportAttrs(srcOp, importOp);
- rewriter.replaceOp(srcOp, std::nullopt);
+ rewriter.eraseOp(srcOp);
return success();
}
};
diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
index f947606..5c8c5e5 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -194,7 +194,7 @@
return failure();
}
- rewriter.replaceOp(srcOp, std::nullopt);
+ rewriter.eraseOp(srcOp);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 1bd4185..b1dcbce 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -110,6 +110,7 @@
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:EmitCDialect",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:FuncToSPIRV",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToSPIRV",
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt
index fd1a433..f1e31ca 100644
--- a/compiler/src/iree/compiler/Tools/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -130,6 +130,7 @@
MLIRBufferizationDialect
MLIRComplexDialect
MLIRControlFlowDialect
+ MLIRFuncInlinerExtension
MLIRGPUDialect
MLIRGPUToSPIRV
MLIRIR
diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
index 2b4f034..114b756 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -71,6 +72,7 @@
transform::TransformDialect,
shape::ShapeDialect>();
// clang-format on
+ func::registerInlinerExtension(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel
index 2ed114a..9242932 100644
--- a/llvm-external-projects/iree-dialects/BUILD.bazel
+++ b/llvm-external-projects/iree-dialects/BUILD.bazel
@@ -690,6 +690,7 @@
"@llvm-project//mlir:BufferizationTransformOps",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransformOps",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index f39f5f6..2005bbe 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -170,7 +170,7 @@
// Transform dialect version of tile and decompose attention
SmallVector<Operation *>
tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
- IRRewriter &rewriter);
+ RewriterBase &rewriter);
// Creates a pass to convert the attention op into a sequence of
// linalg ops.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
index 70e96dc..2c62d6e 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
@@ -15,7 +15,8 @@
def FuseProducersOp : Op<Transform_Dialect, "fuse_producers",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{Fuses the producers for the operands to fuse.}];
let arguments = (ins TransformHandleTypeInterface:$target,
@@ -33,7 +34,8 @@
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Rewrite a bufferized scf.forall op to the async dialect.
@@ -58,6 +60,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::scf::ForallOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -69,7 +72,8 @@
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
- TransformOpInterface]> {
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Rewrite a bufferized scf.forall to a sequential scf.for.
@@ -94,6 +98,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::scf::ForallOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
@@ -104,7 +109,8 @@
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
- TransformEachOpTrait]> {
+ TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Target iree_linalg_ext.attention ops and tile and decompose them.
This transform consumes the target handle and produces a result handle.
@@ -128,6 +134,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
index 399e467..9d54633 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
@@ -18,7 +18,8 @@
def LowerToLLVMOp : Op<Transform_Dialect, "lower_to_llvm",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Indicates that the entire targeted module should be converted
to the LLVM dialect. This is expected to be the last transformation in
@@ -47,7 +48,8 @@
def RegisterMatchCallbacksOp :
Op<Transform_Dialect, "iree.register_match_callbacks",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Registers named structured op matcher callbacks specific for IREE to use
with `transform.iree.match_callback`. This should be called before first
@@ -80,7 +82,8 @@
def MatchCallbackOp :
Op<Transform_Dialect, "iree.match_callback",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Performs payload IR matching using a C++ callback registered beforehand.
The callback is identified by name and is passed the current transform
@@ -107,7 +110,8 @@
def TakeFirstOp :
Op<Transform_Dialect, "iree.take_first",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Given an arbitrary list of handles associated with potentially empty lists
of payload operations, produces two new handles:
@@ -134,7 +138,8 @@
def EmitRemarkOp :
Op<Transform_Dialect, "iree.emit_remark",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformOpInterface, TransformEachOpTrait]> {
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Emits a diagnostic remark with the given message located at payload ops
associated with the given handle. This can be used, e.g., for debugging.
@@ -147,6 +152,7 @@
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
index 051cdbd..6c8862b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/TileAndDecomposeAttention.cpp
@@ -267,7 +267,7 @@
SmallVector<Operation *>
tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
- IRRewriter &rewriter) {
+ RewriterBase &rewriter) {
SmallVector<Operation *> ops;
Location loc = attnOp.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
index 6647610..95ef281 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -46,7 +46,8 @@
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-LinalgExt::FuseProducersOp::apply(transform::TransformResults &transformResults,
+LinalgExt::FuseProducersOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &transformResults,
transform::TransformState &state) {
SmallVector<int64_t> operandsToFuse = extractI64Array(getOperandsToFuse());
LinalgExt::LinalgExtFusionPattern pattern(getContext(), operandsToFuse);
@@ -56,10 +57,10 @@
SmallVector<SmallVector<Operation *>> fusedOps(numProducers);
for (Operation *target : state.getPayloadOps(getTarget())) {
// Apply the pattern.
- SimplePatternRewriter rewriter(target);
+ SimplePatternRewriter patternRewriter(target);
FailureOr<LinalgExt::FusionResult> result =
pattern.returningMatchAndRewrite(cast<TilingInterface>(target),
- rewriter);
+ patternRewriter);
if (failed(result))
return emitDefaultDefiniteFailure(target);
@@ -129,12 +130,13 @@
}
DiagnosedSilenceableFailure LinalgExt::RewriteForallToAsyncOp::applyToOne(
- scf::ForallOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, scf::ForallOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LinalgExt::ForallOpToAsyncRewriter pattern(this->getContext());
- SimplePatternRewriter rewriter(target);
+ SimplePatternRewriter patternRewriter(target);
FailureOr<Operation *> result =
- pattern.returningMatchAndRewrite(target, rewriter);
+ pattern.returningMatchAndRewrite(target, patternRewriter);
if (failed(result))
return emitDefaultDefiniteFailure(target);
results.push_back(*result);
@@ -142,12 +144,13 @@
}
DiagnosedSilenceableFailure LinalgExt::RewriteForallToScfForOp::applyToOne(
- scf::ForallOp target, transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, scf::ForallOp target,
+ transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LinalgExt::ForallOpToScfForRewriter pattern(this->getContext());
- SimplePatternRewriter rewriter(target);
+ SimplePatternRewriter patternRewriter(target);
FailureOr<Operation *> result =
- pattern.returningMatchAndRewrite(target, rewriter);
+ pattern.returningMatchAndRewrite(target, patternRewriter);
if (failed(result))
return emitDefaultDefiniteFailure(target);
results.push_back(*result);
@@ -159,10 +162,9 @@
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure LinalgExt::TileAndDecomposeAttentionOp::applyToOne(
- LinalgExt::AttentionOp attentionOp,
+ transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- IRRewriter rewriter(getContext());
SmallVector<Operation *> ops =
LinalgExt::tileAndDecomposeAttention(attentionOp, rewriter);
for (auto op : ops)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index 247292c..a344d54 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -121,7 +121,8 @@
if (options.enableVectorMasking)
vectorSizes.append(options.vectorSizeComputationFunction(
linalgOp, options.canonicalVectorSizes));
- return vectorize(rewriter, linalgOp, vectorSizes,
+ SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
+ return vectorize(rewriter, linalgOp, vectorSizes, scalableVecDims,
options.vectorizeGatherAccesses);
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index c6c8cf6..7e2963b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -373,9 +373,10 @@
// LowerToLLVMOp
//===---------------------------------------------------------------------===//
-DiagnosedSilenceableFailure
-transform_ext::LowerToLLVMOp::apply(mlir::transform::TransformResults &result,
- mlir::transform::TransformState &state) {
+DiagnosedSilenceableFailure transform_ext::LowerToLLVMOp::apply(
+ mlir::transform::TransformRewriter &rewriter,
+ mlir::transform::TransformResults &result,
+ mlir::transform::TransformState &state) {
auto payloadOps = state.getPayloadOps(getTarget());
if (!llvm::hasSingleElement(payloadOps) ||
!isa<ModuleOp>(*payloadOps.begin()))
@@ -475,6 +476,7 @@
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform_ext::MatchCallbackOp::apply(
+ mlir::transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &results,
mlir::transform::TransformState &state) {
auto setEmptyResults = [&results, this] {
@@ -907,6 +909,7 @@
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform_ext::RegisterMatchCallbacksOp::apply(
+ mlir::transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &results,
mlir::transform::TransformState &state) {
auto &registry = state.addExtension<transform_ext::MatchCallbacksRegistry>();
@@ -939,7 +942,8 @@
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-transform_ext::TakeFirstOp::apply(mlir::transform::TransformResults &results,
+transform_ext::TakeFirstOp::apply(mlir::transform::TransformRewriter &rewriter,
+ mlir::transform::TransformResults &results,
mlir::transform::TransformState &state) {
SmallVector<Operation *> concatenated;
bool found = false;
@@ -973,7 +977,8 @@
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform_ext::EmitRemarkOp::applyToOne(
- Operation *target, mlir::transform::ApplyToEachResultList &results,
+ transform::TransformRewriter &rewriter, Operation *target,
+ mlir::transform::ApplyToEachResultList &results,
mlir::transform::TransformState &state) {
for (Operation *payload : state.getPayloadOps(getHandle())) {
payload->emitRemark(getMessage());
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index 6500e5e..fe4625d 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -17,6 +17,7 @@
MLIRControlFlowDialect
MLIRDialect
MLIRFuncDialect
+ MLIRFuncInlinerExtension
MLIRIndexToLLVM
MLIRLinalgDialect
MLIRLinalgTransformOps
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index b5cc299..692f0c5 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
@@ -92,6 +93,7 @@
mlir::test_ext::registerTestListenerPasses();
// External models.
+ mlir::func::registerInlinerExtension(registry);
mlir::linalg::registerTilingInterfaceExternalModels(registry);
registry.addExtensions<IREE::LinalgExt::LinalgExtTransformOpsExtension,
diff --git a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
index 5c445b9..b400114 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
@@ -14,13 +14,15 @@
// Canonicalization/CSE is needed before bufferization otherwise unnecessary
// allocs will be created.
- transform.apply_patterns to %variant_op {
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
%variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op)
%memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
: (!transform.any_op) -> !transform.any_op
diff --git a/tests/transform_dialect/cuda/BUILD.bazel b/tests/transform_dialect/cuda/BUILD.bazel
index ca1c2dc..8615fab 100644
--- a/tests/transform_dialect/cuda/BUILD.bazel
+++ b/tests/transform_dialect/cuda/BUILD.bazel
@@ -31,7 +31,6 @@
"reduction_eltwise.mlir",
"reduction_v2.mlir",
"reduction_v2_uneven.mlir",
- "reduction_v3.mlir",
"softmax.mlir",
"softmax_v2.mlir",
# First few ops of softmax only, acts as a proxy example.
@@ -48,7 +47,6 @@
"reduction_codegen_spec.mlir",
"reduction_eltwise_codegen_spec.mlir",
"reduction_v2_codegen_spec.mlir",
- "reduction_v3_codegen_spec.mlir",
"softmax_codegen_spec.mlir",
"softmax_v2_codegen_spec.mlir",
#
diff --git a/tests/transform_dialect/cuda/CMakeLists.txt b/tests/transform_dialect/cuda/CMakeLists.txt
index 2ccf09e..b2d4d1d 100644
--- a/tests/transform_dialect/cuda/CMakeLists.txt
+++ b/tests/transform_dialect/cuda/CMakeLists.txt
@@ -23,7 +23,6 @@
"reduction_eltwise.mlir"
"reduction_v2.mlir"
"reduction_v2_uneven.mlir"
- "reduction_v3.mlir"
"softmax.mlir"
"softmax_partial.mlir"
"softmax_v2.mlir"
@@ -40,7 +39,6 @@
reduction_codegen_spec.mlir
reduction_eltwise_codegen_spec.mlir
reduction_v2_codegen_spec.mlir
- reduction_v3_codegen_spec.mlir
softmax_codegen_spec.mlir
softmax_dispatch_spec.mlir
softmax_partial_codegen_spec.mlir
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index 3cb09ed..6d46e42 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -33,13 +33,15 @@
transform.structured.fuse_into_containing_op %fill_1d into %forall_block_combiner_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
- transform.apply_patterns to %variant_op {
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
%fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op
: (!transform.any_op) -> !transform.any_op
@@ -102,11 +104,13 @@
// Late Canonicalizations.
- transform.apply_patterns to %variant_op_3 {
+ %func_op_3 = transform.structured.match ops{["func.func"]} in %variant_op_3
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op_3 {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op_3 : !transform.any_op
- transform.iree.apply_cse %variant_op_3 : !transform.any_op
+ transform.iree.apply_licm %func_op_3 : !transform.any_op
+ transform.iree.apply_cse %func_op_3 : !transform.any_op
}
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
index 186fd08..409b16e 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -17,14 +17,16 @@
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
// Canonicalizations.
- transform.apply_patterns to %variant_op {
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
// Step 2. First level of tiling + fusion parallelizes to blocks. Tile the
// trailing elementwise the same way we want to tile the reduction.
@@ -38,14 +40,14 @@
transform.structured.fuse_into_containing_op %not_eltwise into %grid_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
- transform.apply_patterns to %variant_op {
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
// Step 3. Second level of tiling + fusion parallelizes to threads.
// ===========================================================================
@@ -62,14 +64,14 @@
transform.structured.fuse_into_containing_op %combined_and_fill into %eltwise_block_loop : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
- transform.apply_patterns to %variant_op {
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
%fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op
: (!transform.any_op) -> !transform.any_op
@@ -83,14 +85,14 @@
transform.structured.fuse_into_containing_op %fill_2d into %forall_block_more_parallel_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
- transform.apply_patterns to %variant_op {
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
// Step 4. Rank-reduce and vectorize.
// ===========================================================================
@@ -141,12 +143,14 @@
// Late canonicalizations.
- transform.apply_patterns to %variant_op_3 {
+ %func_op_3 = transform.structured.match ops{["func.func"]} in %variant_op_3
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op_3 {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op_3 : !transform.any_op
- transform.iree.apply_cse %variant_op_3 : !transform.any_op
+ transform.iree.apply_licm %func_op_3 : !transform.any_op
+ transform.iree.apply_cse %func_op_3 : !transform.any_op
}
diff --git a/tests/transform_dialect/cuda/reduction_v3.mlir b/tests/transform_dialect/cuda/reduction_v3.mlir
deleted file mode 100644
index c49691a..0000000
--- a/tests/transform_dialect/cuda/reduction_v3.mlir
+++ /dev/null
@@ -1,69 +0,0 @@
-!in_tensor_t = tensor<?x?xf32>
-!out_tensor_t = tensor<?xf32>
-
-func.func @reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant -0.000000e+00 : f32
-
- %d0 = tensor.dim %arg, %c0 : !in_tensor_t
- %0 = tensor.empty(%d0) : !out_tensor_t
- %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
- %2 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"]}
- ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) {
- ^bb0(%arg3: f32, %arg4: f32):
- %3 = arith.addf %arg3, %arg4 : f32
- linalg.yield %3 : f32
- } -> !out_tensor_t
- return %2 : !out_tensor_t
-}
-
-// RUN: iree-opt %s --iree-hal-target-backends=cuda \
-// RUN: --iree-abi-transformation-pipeline \
-// RUN: --iree-flow-transformation-pipeline \
-// RUN: --iree-stream-transformation-pipeline \
-// RUN: --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
-// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v3_codegen_spec.mlir | \
-// RUN: FileCheck %s --check-prefix=CHECK
-
-// RUN: iree-compile %s --iree-hal-target-backends=cuda \
-// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit=false \
-// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v3_codegen_spec.mlir | \
-// RUN: iree-run-module --module=- --function=reduce --device=cuda --input="123x4567xf32=1" |\
-// RUN: FileCheck %s --check-prefix=EXEC
-
-// RUN: iree-compile %s --iree-hal-target-backends=cuda | \
-// RUN: iree-run-module --module=- --function=reduce --device=cuda --input="123x4567xf32=1" |\
-// RUN: FileCheck %s --check-prefix=EXEC
-
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
- // CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x1024xf32, #gpu.address_space<workgroup>>
-
- // CHECK: %[[TIDX:.]] = gpu.thread_id x
- // Local per-thread scf.for-based reduction.
- // CHECK: %[[v:.*]] = scf.for {{.*}} -> (vector<f32>) {
- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<f32>
- // CHECK: arith.addf {{.*}} : f32
- // CHECK: }
- // CHECK: vector.transfer_write %[[v]], %[[SHMEM_ALLOC]][%[[C0]], %[[TIDX]]] : vector<f32>, memref<1x1024xf32, #gpu.address_space<workgroup>>
-
- // Distributed reduction: everyone loads then 5 xor + addf expected
- // CHECK: vector.transfer_read %{{.*}}[%[[C0]], %{{.*}}]
- // CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
-
- // CHECK: %[[RES:.*]] = arith.addf %{{.*}}
-
- // CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
- // CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
- // CHECK: scf.if %[[CONDXIS0]]
- // CHECK: vector.transfer_write %[[RES_VEC]]
- // CHECK: gpu.barrier
-
-// only checking the first 6 of 123
-// EXEC: result[0]: hal.buffer_view
-// EXEC-NEXT: 123xf32=4567 4567 4567 4567 4567 4567
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
deleted file mode 100644
index 055a551..0000000
--- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
+++ /dev/null
@@ -1,143 +0,0 @@
-// RUN: iree-opt %s
-
-transform.sequence failures(propagate) {
-^bb1(%variant_op: !transform.any_op):
- %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
- : (!transform.any_op) -> !transform.any_op
- %reduction = transform.structured.match ops{["linalg.generic"]} in %variant_op
- : (!transform.any_op) -> !transform.any_op
-
- // Step 1. First level of tiling + fusion parallelizes to blocks.
- // ===========================================================================
- %forall_grid, %grid_reduction =
- transform.structured.tile_to_forall_op %reduction tile_sizes [1]
- ( mapping = [#gpu.block<x>] )
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> ()
-
- transform.structured.fuse_into_containing_op %fill into %forall_grid : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- // Canonicalizations.
- transform.apply_patterns to %variant_op {
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
-
- // Step 2. Split the reduction to get meatier parallelism.
- // This also parallelizes to threads.
- // ===========================================================================
- %forall, %block_more_parallel_fill_op_2, %block_more_parallel_op_2, %block_combiner_op_2 =
- transform.structured.tile_reduction_using_forall %grid_reduction
- by num_threads = [0, 1024], tile_sizes = [0, 1], mapping = [#gpu.thread<x>]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
- // Fuse the fill and pointwise to privatize them.
- transform.structured.fuse_into_containing_op %block_more_parallel_fill_op_2
- into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- // block_combiner_op_2 op is [parallel, reduction] of 1x384 that cannot fuse.
- // map the 1-dim to threadIdx.y to trigger mapping of the reduction to
- // threadIdx.x via predication via `if (x==0)`.
- transform.structured.tile_to_forall_op %block_combiner_op_2 num_threads [1]
- ( mapping = [#gpu.thread<y>] )
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
- // Canonicalizations.
- transform.apply_patterns to %variant_op {
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
-
- // Step 3. Rank-reduce and vectorize.
- // ===========================================================================
- %func = transform.structured.match ops{["func.func"]} in %variant_op
- : (!transform.any_op) -> !transform.any_op
- // TODO: masked vectorization on block_more_parallel_op_2 if we want
- // vector<4> to work as intended.
- transform.apply_patterns to %func {
- transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
- transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
- transform.apply_patterns.vector.cast_away_vector_leading_one_dim
- } : !transform.any_op
- %func_3 = transform.structured.vectorize %func : (!transform.any_op) -> !transform.any_op
-
- // Canonicalizations is necessary to get rid of some tensor.cast that block
- // hoisting.
- transform.apply_patterns to %variant_op {
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
- transform.structured.hoist_redundant_tensor_subsets %func_3
- : (!transform.any_op) -> ()
-
-
- // Step 4. Bufferize and drop HAL descriptor from memref ops.
- // ===========================================================================
- // Canonicalizations required before bufferization to avoid unnecessary allocs.
- transform.apply_patterns to %variant_op {
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
- transform.apply_patterns to %func_3 {
- transform.apply_patterns.tensor.reassociative_reshape_folding
- } : !transform.any_op
- transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
- %func_6 = transform.structured.match ops{["func.func"]} in %variant_op
- : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func_6 {
- transform.apply_patterns.linalg.erase_unnecessary_inputs
- } : !transform.any_op
- %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op
- : (!transform.any_op) -> !transform.any_op
- %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
- : (!transform.any_op) -> !transform.any_op
- transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!transform.any_op) -> ()
-
- // Step 5. Post-bufferization mapping to blocks and threads.
- // ===========================================================================
- %func_m = transform.structured.match ops{["func.func"]} in %variant_op_3
- : (!transform.any_op) -> !transform.any_op
- transform.iree.forall_to_workgroup %func_m : (!transform.any_op) -> ()
- transform.iree.map_nested_forall_to_gpu_threads %func_m
- workgroup_dims = [1024, 1, 1] : (!transform.any_op) -> ()
-
- // Step 6. Post-bufferization vector distribution with rank-reduction.
- // ===========================================================================
- transform.apply_patterns to %func_m {
- transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
- transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
- transform.apply_patterns.memref.fold_memref_alias_ops
- transform.apply_patterns.vector.cast_away_vector_leading_one_dim
- } : !transform.any_op
- %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
- : (!transform.any_op) -> !transform.any_op
- %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } : (!transform.any_op) -> !transform.any_op
- transform.iree.vector.warp_distribute %func_m
- : (!transform.any_op) -> ()
-
- // Late canonicalizations.
- transform.apply_patterns to %variant_op_3 {
- transform.apply_patterns.iree.fold_fill_into_pad
- transform.apply_patterns.linalg.tiling_canonicalization
- transform.apply_patterns.scf.for_loop_canonicalization
- transform.apply_patterns.canonicalization
- } : !transform.any_op
- transform.iree.apply_licm %variant_op_3 : !transform.any_op
- transform.iree.apply_cse %variant_op_3 : !transform.any_op
-}
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
index 6254ea2..052dc25 100644
--- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -35,14 +35,16 @@
: (!transform.op<"scf.forall">) -> !transform.op<"scf.forall">
// Canonicalizations.
- transform.apply_patterns to %variant_op {
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
// Step 2. Second level of tiling + fusion parallelizes to threads.
@@ -76,24 +78,23 @@
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Canonicalizations.
- transform.apply_patterns to %variant_op {
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
// Step 3. Rank-reduce and vectorize.
// ==================================
- %funcx_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %funcx_2 {
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface
transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
transform.apply_patterns.vector.cast_away_vector_leading_one_dim
} : !transform.any_op
- transform.structured.vectorize %funcx_2 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %func_op : (!transform.any_op) -> !transform.any_op
// Step 4. Bufferize and drop HAL decriptor from memref ops.
// =========================================================
@@ -125,12 +126,14 @@
// Late canonicalizations.
- transform.apply_patterns to %variant_op_3 {
+ %func_op_3 = transform.structured.match ops{["func.func"]} in %variant_op_3
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op_3 {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op_3 : !transform.any_op
- transform.iree.apply_cse %variant_op_3 : !transform.any_op
+ transform.iree.apply_licm %func_op_3 : !transform.any_op
+ transform.iree.apply_cse %func_op_3 : !transform.any_op
}
diff --git a/tests/transform_dialect/cuda/vecadd2d.mlir b/tests/transform_dialect/cuda/vecadd2d.mlir
index da3abe7..9279355 100644
--- a/tests/transform_dialect/cuda/vecadd2d.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d.mlir
@@ -75,8 +75,9 @@
// CHECK-PARTIAL-TILE: hal.executable.export
// CHECK-PARTIAL-TILE: bb0(%[[DEV:.*]]: !hal.device):
// CHECK-PARTIAL-TILE: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-PARTIAL-TILE: %[[C1_2:.*]] = arith.constant 1 : index
// CHECK-PARTIAL-TILE: %[[C171:.*]] = arith.constant 171 : index
-// CHECK-PARTIAL-TILE: hal.return %[[C1]], %[[C1]], %[[C171]] : index, index, index
+// CHECK-PARTIAL-TILE: hal.return %[[C1]], %[[C1_2]], %[[C171]] : index, index, index
// EXEC: EXEC @vecadd2d
// EXEC: result[0]: hal.buffer_view
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
index 12b6ef1..2184cf6 100644
--- a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
@@ -11,12 +11,14 @@
// Late canonicalizations to cleanup and pass the checks.
// Needs to occur on the whole variant to perform cse on the workgroup_count region
- transform.apply_patterns to %variant_op {
+ %func_op = transform.structured.match ops{["func.func"]} in %variant_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
transform.apply_patterns.iree.fold_fill_into_pad
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
- transform.iree.apply_licm %variant_op : !transform.any_op
- transform.iree.apply_cse %variant_op : !transform.any_op
+ transform.iree.apply_licm %func_op : !transform.any_op
+ transform.iree.apply_cse %func_op : !transform.any_op
}
diff --git a/third_party/llvm-project b/third_party/llvm-project
index f6f4276..88f07a3 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit f6f42760a47365120dd088f595074cf8e84617a2
+Subproject commit 88f07a311947f88de82ad2de9b2d6a26eba21343