[CodeGen] Move Linalg patterns and filters from LinalgExt to Codegen/ (#16619)
- The linalg filters are moved to Codegen/Utils/MarkerUtils
- The filter-based LinalgPromotionPattern is moved to Codegen/Transforms
- Update includes and deps
- Remove related unnecessary includes and deps
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
index fb5a82a..f3ffe7d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
@@ -62,7 +62,6 @@
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
- "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
"//runtime/src/iree/builtins/ukernel:exported_bits",
"@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
index c20c0d3..6f219ac 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
@@ -82,7 +82,6 @@
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
- iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::LinalgExt::Utils
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 8d6b4aa..1610a15 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -87,8 +87,6 @@
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
- "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
- "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
"//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AMDGPUDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index d582e0b..d1e9f7d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -113,8 +113,6 @@
iree::compiler::Codegen::Utils
iree::compiler::Codegen::Utils::VectorOpUtils
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::LinalgExt::IR
- iree::compiler::Dialect::LinalgExt::Transforms
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
index 0a90282..729865a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
@@ -13,7 +13,6 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -98,7 +97,7 @@
.setTileSizeComputationFunction(wgCopyTileSizeFn)
.setDistributionOptions(copyInvocationDistributionOptions);
- auto filter = IREE::LinalgExt::LinalgTransformationFilter(
+ auto filter = LinalgTransformationFilter(
{StringAttr::get(funcOp.getContext(), getCopyToWorkgroupMemoryMarker())},
StringAttr::get(funcOp.getContext(), getVectorizeMarker()));
return distributeLinalgOpsWithFilter(funcOp, tilingOptions, filter);
@@ -178,7 +177,7 @@
.setTileSizeComputationFunction(wgCopyTileSizeFn);
MLIRContext *context = funcOp.getContext();
- auto filter = IREE::LinalgExt::LinalgTransformationFilter(
+ auto filter = LinalgTransformationFilter(
{StringAttr::get(context, getCopyToWorkgroupMemoryMarker())},
StringAttr::get(context, kCopyToDistribute));
return distributeLinalgOpsWithFilter(funcOp, tilingOptions, filter);
@@ -259,7 +258,7 @@
.setTileSizeComputationFunction(wgCopyTileSizeFn)
.setDistributionOptions(copyInvocationDistributionOptions);
- auto filter = IREE::LinalgExt::LinalgTransformationFilter(
+ auto filter = LinalgTransformationFilter(
{StringAttr::get(funcOp.getContext(), kCopyToDistribute)},
StringAttr::get(funcOp.getContext(), kCopyDistributed));
return distributeLinalgOpsWithFilter(funcOp, tilingOptions, filter);
@@ -271,7 +270,7 @@
vectorizeCopyToWorkgroupMemoryOps(mlir::FunctionOpInterface funcOp) {
MLIRContext *context = funcOp.getContext();
IRRewriter rewriter(context);
- auto filter = IREE::LinalgExt::LinalgTransformationFilter(
+ auto filter = LinalgTransformationFilter(
{StringAttr::get(context, getCopyToWorkgroupMemoryMarker()),
StringAttr::get(context, kCopyDistributed)},
std::nullopt);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
index b71efb1..58e0dcb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
@@ -10,7 +10,6 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp
index 675e02d..97c3dbc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp
@@ -6,9 +6,9 @@
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -181,10 +181,6 @@
}
};
-template <typename T>
-using LinalgPromotionPattern =
- mlir::iree_compiler::IREE::LinalgExt::LinalgPromotionPattern<T>;
-
/// Returns true if op is appropriate contract for promotion.
static LogicalResult contractOpFilter(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -240,7 +236,7 @@
.setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
.setOperandsToPromote(operandsToPromote)
.setUseFullTileBuffers({false, false}),
- IREE::LinalgExt::LinalgTransformationFilter(
+ LinalgTransformationFilter(
{StringAttr::get(context, getWorkgroupKTiledMarker())},
StringAttr::get(context, getWorkgroupMemoryMarker()))
.setMatchByDefault()
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
index 0a3ed27..9a38cbc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
@@ -12,7 +12,6 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -33,9 +32,10 @@
class TileConsumerAndFuseInputProducer final
: public OpInterfaceRewritePattern<TilingInterface> {
public:
- TileConsumerAndFuseInputProducer(
- MLIRContext *context, IREE::LinalgExt::LinalgTransformationFilter filter,
- bool fuseInputProducer, PatternBenefit benefit = 1)
+ TileConsumerAndFuseInputProducer(MLIRContext *context,
+ LinalgTransformationFilter filter,
+ bool fuseInputProducer,
+ PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
filter(std::move(filter)), fuseInputProducer(fuseInputProducer) {}
@@ -146,7 +146,7 @@
// Mark the fused input producer for distribution when writing to shared
// memory. We cannot use the current matmul op's tiling scheme here
// given dimensions are different.
- IREE::LinalgExt::LinalgTransformationFilter f(
+ LinalgTransformationFilter f(
ArrayRef<StringAttr>(),
rewriter.getStringAttr(getCopyToWorkgroupMemoryMarker()));
f.replaceLinalgTransformationFilter(
@@ -156,7 +156,7 @@
return tilingResult;
}
- IREE::LinalgExt::LinalgTransformationFilter filter;
+ LinalgTransformationFilter filter;
bool fuseInputProducer;
};
@@ -167,7 +167,7 @@
bool fuseInputProducer) {
MLIRContext *context = patterns.getContext();
- IREE::LinalgExt::LinalgTransformationFilter filter(
+ LinalgTransformationFilter filter(
ArrayRef<StringAttr>{
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getWorkgroupKTiledMarker()));
@@ -221,8 +221,7 @@
StringAttr::get(funcOp.getContext(), getCopyToWorkgroupMemoryMarker());
for (TilingInterface tilingOp : computeOps) {
- auto attr = tilingOp->getAttr(
- IREE::LinalgExt::LinalgTransforms::kLinalgTransformMarker);
+ auto attr = tilingOp->getAttr(LinalgTransforms::kLinalgTransformMarker);
if (attr == marker)
continue;
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupSpecializationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupSpecializationPass.cpp
index 6f4e094..1621d36 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupSpecializationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupSpecializationPass.cpp
@@ -31,7 +31,6 @@
#include "iree/compiler/Codegen/Common/GPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index e0d84f3..1fb5479f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -10,7 +10,6 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
namespace mlir::bufferization {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VerifyLinalgTransformLegality.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VerifyLinalgTransformLegality.cpp
index 0f3fac0..c09a958 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VerifyLinalgTransformLegality.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VerifyLinalgTransformLegality.cpp
@@ -7,7 +7,6 @@
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
@@ -25,8 +24,7 @@
auto moduleOp = getOperation();
// For now only check that there are no Linalg transform markers.
auto walkResult = moduleOp.walk([](linalg::LinalgOp op) -> WalkResult {
- if (op->hasAttr(
- IREE::LinalgExt::LinalgTransforms::kLinalgTransformMarker)) {
+ if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker)) {
return op.emitError("expected no Linalg transform markers");
}
return WalkResult::advance();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
index c10ddfc..3ed9b70 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
@@ -12,7 +12,6 @@
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -34,8 +33,7 @@
static void vectorizeLinalgOps(mlir::FunctionOpInterface funcOp) {
MLIRContext *context = funcOp.getContext();
IRRewriter rewriter(context);
- IREE::LinalgExt::LinalgTransformationFilter f(
- StringAttr::get(context, getVectorizeMarker()));
+ LinalgTransformationFilter f(StringAttr::get(context, getVectorizeMarker()));
funcOp.walk([&](Operation *op) {
if (failed(f.checkAndNotify(rewriter, op)) ||
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index 3cf3332..d9b0b99 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
@@ -58,7 +59,7 @@
scf::SCFTilingOptions().setTileSizeComputationFunction(tileSizesFn);
MLIRContext *context = funcOp.getContext();
- IREE::LinalgExt::LinalgTransformationFilter filter(
+ LinalgTransformationFilter filter(
ArrayRef<StringAttr>{
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getWorkgroupKTiledMarker()));
@@ -154,7 +155,7 @@
.setTileSizeComputationFunction(getInnerTileSizeFn)
.setDistributionOptions(warpDistributionOptions);
MLIRContext *context = funcOp.getContext();
- IREE::LinalgExt::LinalgTransformationFilter filter(
+ LinalgTransformationFilter filter(
{StringAttr::get(context, getWorkgroupKTiledMarker()),
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getVectorizeMarker()));
@@ -185,7 +186,7 @@
.setDistributionOptions(invocationDistributionOptions);
MLIRContext *context = funcOp.getContext();
- IREE::LinalgExt::LinalgTransformationFilter f(
+ LinalgTransformationFilter f(
{StringAttr::get(context, getWorkgroupKTiledMarker()),
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getVectorizeMarker()));
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
index 7392c33..d55c488 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
@@ -18,7 +18,6 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -148,8 +147,7 @@
static LogicalResult
tileReduction(mlir::FunctionOpInterface funcOp,
const scf::SCFTileSizeComputationFunction &computeFn) {
- auto filter =
- IREE::LinalgExt::LinalgTransformationFilter().setMatchByDefault();
+ auto filter = LinalgTransformationFilter().setMatchByDefault();
auto options =
scf::SCFTilingOptions().setTileSizeComputationFunction(computeFn);
auto result = tileLinalgOpsWithFilter(funcOp, options, filter);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
index 6a63691..5bdca3f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
@@ -19,7 +19,6 @@
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -69,8 +68,7 @@
MLIRContext *context = funcOp.getContext();
IRRewriter rewriter(context);
auto marker = StringAttr::get(context, getTileReductionMarker());
- auto filter = IREE::LinalgExt::LinalgTransformationFilter(
- ArrayRef<StringAttr>(), marker);
+ auto filter = LinalgTransformationFilter(ArrayRef<StringAttr>(), marker);
SmallVector<TilingInterface> candidates;
funcOp.walk([&](TilingInterface op) { candidates.push_back(op); });
@@ -98,7 +96,7 @@
const scf::SCFTileSizeComputationFunction &computeFn) {
MLIRContext *context = funcOp.getContext();
IRRewriter rewriter(context);
- auto filter = IREE::LinalgExt::LinalgTransformationFilter(
+ auto filter = LinalgTransformationFilter(
StringAttr::get(context, getTileReductionMarker()), std::nullopt);
auto options =
scf::SCFTilingOptions().setTileSizeComputationFunction(computeFn);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
index 005676f..beba1e2 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
@@ -20,7 +20,6 @@
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -43,7 +42,7 @@
static LogicalResult
tileReductionLoops(mlir::FunctionOpInterface funcOp,
- IREE::LinalgExt::LinalgTransformationFilter filter,
+ LinalgTransformationFilter filter,
const scf::SCFTileSizeComputationFunction &computeFn) {
auto options =
scf::SCFTilingOptions().setTileSizeComputationFunction(computeFn);
@@ -56,7 +55,7 @@
static LogicalResult
tileToInvocation(mlir::FunctionOpInterface funcOp,
- IREE::LinalgExt::LinalgTransformationFilter filter,
+ LinalgTransformationFilter filter,
const linalg::TileSizeComputationFunction &computeFn) {
auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
@@ -80,10 +79,6 @@
static const char promoteBothMarker[] = "promote_lhs_and_rhs";
-template <typename T>
-using LinalgPromotionPattern =
- mlir::iree_compiler::IREE::LinalgExt::LinalgPromotionPattern<T>;
-
static void populatePromotionPatterns(RewritePatternSet &patterns,
StringAttr replaceMarker) {
MLIRContext *context = patterns.getContext();
@@ -95,7 +90,7 @@
.setUseFullTileBuffers({false, false});
auto promoteBothOptions = baseOptions.setOperandsToPromote({0, 1});
- IREE::LinalgExt::LinalgTransformationFilter promoteBothFilter(
+ LinalgTransformationFilter promoteBothFilter(
{StringAttr::get(context, promoteBothMarker)}, replaceMarker);
patterns.insert<LinalgPromotionPattern<linalg::MatmulOp>,
@@ -163,14 +158,13 @@
if (failed(doPromoteCMatrix(funcOp)))
return signalPassFailure();
- StringLiteral markerAttrName =
- IREE::LinalgExt::LinalgTransforms::kLinalgTransformMarker;
+ StringLiteral markerAttrName = LinalgTransforms::kLinalgTransformMarker;
auto workgroupMarker = StringAttr::get(context, getWorkgroupMemoryMarker());
auto kTiledMarker = StringAttr::get(context, getWorkgroupKTiledMarker());
{ // Tile reduction dimensions.
RewritePatternSet patterns(context);
- IREE::LinalgExt::LinalgTransformationFilter filter(
+ LinalgTransformationFilter filter(
// Going through C matrix promotion we will have the marker..
{workgroupMarker}, kTiledMarker);
// Not going through C matrix promotion we will have no marker..
@@ -258,8 +252,7 @@
});
if (!skipThreadLevel) { // Tile and distribute to invocations.
- IREE::LinalgExt::LinalgTransformationFilter filter({workgroupMarker},
- std::nullopt);
+ LinalgTransformationFilter filter({workgroupMarker}, std::nullopt);
if (failed(tileToInvocation(funcOp, filter, *threadTileComputeFn))) {
funcOp.emitOpError() << "failed tiling and distributing to invocations";
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index abfa011..8fb91b7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -24,7 +24,6 @@
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
@@ -164,7 +163,7 @@
.setTileSizeComputationFunction(setTileSizesFn)
.setDistributionOptions(distributionOptions);
- IREE::LinalgExt::LinalgTransformationFilter filter(
+ LinalgTransformationFilter filter(
{StringAttr::get(context, getWorkgroupKTiledMarker()),
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getVectorizeMarker()));
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Transforms/BUILD.bazel
index 341edb2..3699c8e 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Transforms/BUILD.bazel
@@ -27,7 +27,6 @@
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
- "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Transforms/CMakeLists.txt
index 5a96e4d..48b9f48 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Transforms/CMakeLists.txt
@@ -41,7 +41,6 @@
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::LinalgExt::Transforms
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 32d3f4e..b425e2b 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -764,10 +764,9 @@
}
}
-LogicalResult
-tileLinalgOpsWithFilter(mlir::FunctionOpInterface funcOp,
- scf::SCFTilingOptions options,
- IREE::LinalgExt::LinalgTransformationFilter filter) {
+LogicalResult tileLinalgOpsWithFilter(mlir::FunctionOpInterface funcOp,
+ scf::SCFTilingOptions options,
+ LinalgTransformationFilter filter) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<Operation *> candidates;
funcOp.walk([&](linalg::LinalgOp op) {
@@ -800,9 +799,10 @@
return success();
}
-LogicalResult distributeLinalgOpsWithFilter(
- mlir::FunctionOpInterface funcOp, linalg::LinalgTilingOptions tilingOptions,
- IREE::LinalgExt::LinalgTransformationFilter filter) {
+LogicalResult
+distributeLinalgOpsWithFilter(mlir::FunctionOpInterface funcOp,
+ linalg::LinalgTilingOptions tilingOptions,
+ LinalgTransformationFilter filter) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<linalg::LinalgOp> candidates;
funcOp.walk([&](linalg::LinalgOp op) {
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
index 504df40..99a2069 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
@@ -12,10 +12,10 @@
#ifndef IREE_COMPILER_CODEGEN_TRANSFORMS_TRANSFORMS_H_
#define IREE_COMPILER_CODEGEN_TRANSFORMS_TRANSFORMS_H_
+#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
@@ -150,15 +150,100 @@
ArrayRef<OpFoldResult> workgroupCount,
int maxWorkgroupParallelDims = kNumMaxParallelDims);
+//===----------------------------------------------------------------------===//
+// Transformations exposed as patterns, moved from upstream MLIR as IREE still
+// heavily relies on patterns that compose through filters.
+// TODO: Deprecate all the code below.
+//===----------------------------------------------------------------------===//
+///
+/// Linalg promotion patterns.
+///
+/// Apply the `promoteSubViews` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `promoteSubViews` for more details.
+struct LinalgBasePromotionPattern : public RewritePattern {
+ /// Entry point to match any LinalgOp
+ /// OpInterface. MatchAnyOpTag-based constructor
+ /// with a mandatory `filter`.
+ LinalgBasePromotionPattern(
+ MLIRContext *context, LinalgTransformationFilter f,
+ linalg::LinalgPromotionOptions options = linalg::LinalgPromotionOptions(),
+ PatternBenefit benefit = 1)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+ filter(std::move(f)), options(std::move(options)) {}
+ /// Entry point to match a specific Linalg op.
+ LinalgBasePromotionPattern(
+ StringRef opName, MLIRContext *context,
+ linalg::LinalgPromotionOptions options,
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
+ options(std::move(options)) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+ if (failed(promoteSubviewsPrecondition(op, options)))
+ return failure();
+
+ // TODO: We cannot use root update here. This
+ // pattern is creating other ops, so if the
+ // promotion fails, those need to be cleaned
+ // up, which doesnt seem to be happening here.
+ // So to fail properly, we should be cloning
+ // the op and deleting the previous op. This
+ // needs more investigation.
+ rewriter.startOpModification(op);
+ std::optional<linalg::LinalgOp> promotedOp =
+ promoteSubViews(rewriter, cast<linalg::LinalgOp>(op), options);
+ if (!promotedOp) {
+ rewriter.cancelOpModification(op);
+ return op->emitError("subview promotion failed");
+ }
+ rewriter.finalizeOpModification(op);
+ filter.replaceLinalgTransformationFilter(rewriter, op);
+ return success();
+ }
+
+private:
+ /// LinalgTransformMarker handles special
+ /// attribute manipulations.
+ LinalgTransformationFilter filter;
+ /// Promotion options.
+ linalg::LinalgPromotionOptions options;
+};
+
+template <typename OpTy>
+struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
+ /// SFINAE: This constructor can only trigger for
+ /// concrete ops that have a static
+ /// `getOperationName` method.
+ template <typename ConcreateOpTy = OpTy>
+ LinalgPromotionPattern(
+ MLIRContext *context, linalg::LinalgPromotionOptions options,
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
+ f, benefit) {}
+ /// This constructor is available to anyone.
+ LinalgPromotionPattern(
+ StringRef opName, MLIRContext *context,
+ linalg::LinalgPromotionOptions options,
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
+};
+
/// Tiles LinalgOp ops that match filter.
-LogicalResult
-tileLinalgOpsWithFilter(mlir::FunctionOpInterface funcOp,
- scf::SCFTilingOptions options,
- IREE::LinalgExt::LinalgTransformationFilter filter);
+LogicalResult tileLinalgOpsWithFilter(mlir::FunctionOpInterface funcOp,
+ scf::SCFTilingOptions options,
+ LinalgTransformationFilter filter);
/// Distributes LinalgOp ops that match filter.
-LogicalResult distributeLinalgOpsWithFilter(
- mlir::FunctionOpInterface funcOp, linalg::LinalgTilingOptions tilingOptions,
- IREE::LinalgExt::LinalgTransformationFilter filter);
+LogicalResult
+distributeLinalgOpsWithFilter(mlir::FunctionOpInterface funcOp,
+ linalg::LinalgTilingOptions tilingOptions,
+ LinalgTransformationFilter filter);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp
index 7a18e2d..077305e 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp
@@ -6,13 +6,88 @@
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
namespace mlir::iree_compiler {
+// Marker used as attribute name in generated Linalg rewriting transformations.
+const StringLiteral LinalgTransforms::kLinalgTransformMarker =
+ "__internal_linalg_transform__";
+
+LinalgTransformationFilter::LinalgTransformationFilter(
+ ArrayRef<StringAttr> matchDisjunction,
+ std::optional<StringAttr> replacement)
+ : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
+ replacement(replacement), matchByDefault(false) {}
+
+LinalgTransformationFilter::LinalgTransformationFilter(
+ const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
+ std::optional<StringAttr> replacement)
+ : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
+ replacement(replacement), matchByDefault(false) {
+ if (f) {
+ filters.push_back(f);
+ }
+}
+
+LogicalResult LinalgTransformationFilter::checkAndNotify(RewriterBase &rewriter,
+ Operation *op) const {
+ if (llvm::any_of(filters,
+ [&](const FilterFunction &f) { return failed(f(op)); })) {
+ return failure();
+ }
+
+ auto attr = op->template getAttrOfType<StringAttr>(
+ LinalgTransforms::kLinalgTransformMarker);
+
+ if (!attr) {
+ // 1. Has no filter case and matchDisjunction is empty.
+ if (matchDisjunction.empty() || matchByDefault) {
+ return success();
+ }
+
+ // 2. Has no filter but was expecting a filter.
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << " does not have any filter from list: ";
+ interleaveComma(matchDisjunction, diag);
+ });
+ }
+
+ // 4. Match explicit filter.
+ for (auto filter : matchDisjunction) {
+ if (attr.getValue() == filter) {
+ return success();
+ }
+ }
+
+ // 5. Fail to match.
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << " does not have any filter from list: ";
+ interleaveComma(matchDisjunction, diag);
+ });
+}
+
+void LinalgTransformationFilter::replaceLinalgTransformationFilter(
+ RewriterBase &rewriter, Operation *op) const {
+ if (replacement.has_value()) {
+ op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
+ } else {
+ op->removeAttr(
+ rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
+ }
+}
+
+bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const {
+ if (!replacement) {
+ return false;
+ }
+ auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
+ .dyn_cast<StringAttr>();
+ return attr && attr == *replacement;
+}
+
struct VectorTransforms {
static const StringLiteral kVectorTransformMarker;
};
@@ -46,16 +121,16 @@
StringRef getDeleteMarker() { return "delete"; }
StringRef getMarkerOrNull(Operation *op) {
- StringAttr attr = op->getAttrOfType<StringAttr>(
- IREE::LinalgExt::LinalgTransforms::kLinalgTransformMarker);
+ StringAttr attr =
+ op->getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker);
if (!attr)
return "";
return attr.getValue();
}
bool hasMarker(Operation *op, ArrayRef<StringRef> marker) {
- StringAttr attr = op->getAttrOfType<StringAttr>(
- IREE::LinalgExt::LinalgTransforms::kLinalgTransformMarker);
+ StringAttr attr =
+ op->getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker);
return attr && (marker.empty() ||
llvm::any_of(marker, [&attr](StringRef markerValue) {
return attr.getValue() == markerValue;
@@ -63,7 +138,7 @@
}
void setMarker(Operation *op, StringRef marker) {
- op->setAttr(IREE::LinalgExt::LinalgTransforms::kLinalgTransformMarker,
+ op->setAttr(LinalgTransforms::kLinalgTransformMarker,
StringAttr::get(op->getContext(), marker));
}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h
index 782bd87..40dc001 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h
@@ -16,10 +16,73 @@
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace mlir::iree_compiler {
+// Marker used as attribute name in generated Linalg rewriting transformations.
+struct LinalgTransforms {
+ static const StringLiteral kLinalgTransformMarker;
+};
+
+/// Helper class to control application of linalg transformation patterns.
+/// Control comes in 2 forms:
+/// 1. attribute matching and setting behavior using the attribute named
+/// `kLinalgTransformMarker`. This can be used to build a state machine
+/// using attributes and incrementally applying patterns to advance states.
+/// 2. filter function, which is a simple lambda on the Operation* that
+/// returns a LogicalResult.
+struct LinalgTransformationFilter {
+ using FilterFunction = std::function<LogicalResult(Operation *)>;
+
+ explicit LinalgTransformationFilter(
+ ArrayRef<StringAttr> matchDisjunction = {},
+ std::optional<StringAttr> replacement = std::nullopt);
+
+ explicit LinalgTransformationFilter(
+ const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
+ std::optional<StringAttr> replacement = std::nullopt);
+
+ LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
+ LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
+ LogicalResult checkAndNotify(RewriterBase &rewriter, Operation *op) const;
+ void replaceLinalgTransformationFilter(RewriterBase &rewriter,
+ Operation *op) const;
+ bool hasReplacementFilter(Operation *op) const;
+
+ LinalgTransformationFilter &addFilter(const FilterFunction &f) {
+ if (f)
+ filters.push_back(f);
+ return *this;
+ }
+
+ template <typename... OpTypes>
+ LinalgTransformationFilter &addOpFilter() {
+ return addFilter(
+ [](Operation *op) { return success(isa<OpTypes...>(op)); });
+ }
+
+ LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
+ return addFilter([opName](Operation *op) {
+ return success(op->getName().getStringRef() == opName);
+ });
+ }
+
+ LinalgTransformationFilter &setMatchByDefault() {
+ matchByDefault = true;
+ return *this;
+ }
+
+private:
+ SmallVector<FilterFunction> filters;
+ SmallVector<StringAttr> matchDisjunction;
+ std::optional<StringAttr> replacement;
+ /// When set to true, if the attribute is not set, it will be treated as
+ /// a match. Default is false.
+ bool matchByDefault;
+};
+
/// Marker to denote that a linalg operation has been partitioned to
/// workgroups and tiled along reduction dimennsions.
StringRef getWorkgroupKTiledMarker();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp
index 3cd40a3..2f10d96 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp
@@ -13,7 +13,6 @@
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
index b696e65..3533493 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp
@@ -6,7 +6,7 @@
#include "iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
index 7bf06dd..5875577 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
@@ -42,7 +42,6 @@
hdrs = [
"Passes.h",
"Passes.h.inc",
- "Transforms.h",
],
deps = [
":PassesIncGen",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 1c669e4..84cf073e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -25,7 +25,6 @@
HDRS
"Passes.h"
"Passes.h.inc"
- "Transforms.h"
SRCS
"ConvertConv2DToWinograd.cpp"
"ConvertToLoops.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.cpp
index 737450d..6eab1dc 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.cpp
@@ -13,82 +13,6 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
-// Marker used as attribute name in generated Linalg rewriting transformations.
-const StringLiteral LinalgTransforms::kLinalgTransformMarker =
- "__internal_linalg_transform__";
-
-LinalgTransformationFilter::LinalgTransformationFilter(
- ArrayRef<StringAttr> matchDisjunction,
- std::optional<StringAttr> replacement)
- : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
- replacement(replacement), matchByDefault(false) {}
-
-LinalgTransformationFilter::LinalgTransformationFilter(
- const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
- std::optional<StringAttr> replacement)
- : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
- replacement(replacement), matchByDefault(false) {
- if (f) {
- filters.push_back(f);
- }
-}
-
-LogicalResult LinalgTransformationFilter::checkAndNotify(RewriterBase &rewriter,
- Operation *op) const {
- if (llvm::any_of(filters,
- [&](const FilterFunction &f) { return failed(f(op)); })) {
- return failure();
- }
-
- auto attr = op->template getAttrOfType<StringAttr>(
- LinalgTransforms::kLinalgTransformMarker);
-
- if (!attr) {
- // 1. Has no filter case and matchDisjunction is empty.
- if (matchDisjunction.empty() || matchByDefault) {
- return success();
- }
-
- // 2. Has no filter but was expecting a filter.
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << " does not have any filter from list: ";
- interleaveComma(matchDisjunction, diag);
- });
- }
-
- // 4. Match explicit filter.
- for (auto filter : matchDisjunction) {
- if (attr.getValue() == filter) {
- return success();
- }
- }
-
- // 5. Fail to match.
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << " does not have any filter from list: ";
- interleaveComma(matchDisjunction, diag);
- });
-}
-
-void LinalgTransformationFilter::replaceLinalgTransformationFilter(
- RewriterBase &rewriter, Operation *op) const {
- if (replacement.has_value()) {
- op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
- } else {
- op->removeAttr(
- rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
- }
-}
-
-bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const {
- if (!replacement) {
- return false;
- }
- auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
- .dyn_cast<StringAttr>();
- return attr && attr == *replacement;
-}
-
namespace detail {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc" // IWYU pragma: export
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 0cffa82..c99cf07 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -16,71 +16,6 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
-class ConversionTarget;
-class TypeConverter;
-
-// Marker used as attribute name in generated Linalg rewriting transformations.
-struct LinalgTransforms {
- static const StringLiteral kLinalgTransformMarker;
-};
-
-/// Helper class to control application of linalg transformation patterns.
-/// Control comes in 2 forms:
-/// 1. attribute matching and setting behavior using the attribute named
-/// `kLinalgTransformMarker`. This can be used to build a state machine
-/// using attributes and incrementally applying patterns to advance states.
-/// 2. filter function, which is a simple lambda on the Operation* that
-/// returns a LogicalResult.
-struct LinalgTransformationFilter {
- using FilterFunction = std::function<LogicalResult(Operation *)>;
-
- explicit LinalgTransformationFilter(
- ArrayRef<StringAttr> matchDisjunction = {},
- std::optional<StringAttr> replacement = std::nullopt);
-
- explicit LinalgTransformationFilter(
- const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
- std::optional<StringAttr> replacement = std::nullopt);
-
- LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
- LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
- LogicalResult checkAndNotify(RewriterBase &rewriter, Operation *op) const;
- void replaceLinalgTransformationFilter(RewriterBase &rewriter,
- Operation *op) const;
- bool hasReplacementFilter(Operation *op) const;
-
- LinalgTransformationFilter &addFilter(const FilterFunction &f) {
- if (f)
- filters.push_back(f);
- return *this;
- }
-
- template <typename... OpTypes>
- LinalgTransformationFilter &addOpFilter() {
- return addFilter(
- [](Operation *op) { return success(isa<OpTypes...>(op)); });
- }
-
- LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
- return addFilter([opName](Operation *op) {
- return success(op->getName().getStringRef() == opName);
- });
- }
-
- LinalgTransformationFilter &setMatchByDefault() {
- matchByDefault = true;
- return *this;
- }
-
-private:
- SmallVector<FilterFunction> filters;
- SmallVector<StringAttr> matchDisjunction;
- std::optional<StringAttr> replacement;
- /// When set to true, if the attribute is not set, it will be treated as
- /// a match. Default is false.
- bool matchByDefault;
-};
-
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLinalgExtToLoopsPass();
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
deleted file mode 100644
index 964f718..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
+++ /dev/null
@@ -1,107 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
-#define IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
-
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir::iree_compiler::IREE::LinalgExt {
-
-//===----------------------------------------------------------------------===//
-// Transformations exposed as patterns, moved from upstream MLIR as IREE still
-// heavily relies on patterns that compose through filters.
-// TODO: Deprecate all the code below.
-//===----------------------------------------------------------------------===//
-///
-/// Linalg promotion patterns.
-///
-/// Apply the `promoteSubViews` transformation as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `promoteSubViews` for more details.
-struct LinalgBasePromotionPattern : public RewritePattern {
- /// Entry point to match any LinalgOp
- /// OpInterface. MatchAnyOpTag-based constructor
- /// with a mandatory `filter`.
- LinalgBasePromotionPattern(
- MLIRContext *context, LinalgTransformationFilter f,
- linalg::LinalgPromotionOptions options = linalg::LinalgPromotionOptions(),
- PatternBenefit benefit = 1)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- filter(std::move(f)), options(std::move(options)) {}
- /// Entry point to match a specific Linalg op.
- LinalgBasePromotionPattern(
- StringRef opName, MLIRContext *context,
- linalg::LinalgPromotionOptions options,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
- options(std::move(options)) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- if (failed(filter.checkAndNotify(rewriter, op)))
- return failure();
- if (failed(promoteSubviewsPrecondition(op, options)))
- return failure();
-
- // TODO: We cannot use root update here. This
- // pattern is creating other ops, so if the
- // promotion fails, those need to be cleaned
- // up, which doesnt seem to be happening here.
- // So to fail properly, we should be cloning
- // the op and deleting the previous op. This
- // needs more investigation.
- rewriter.startOpModification(op);
- std::optional<linalg::LinalgOp> promotedOp =
- promoteSubViews(rewriter, cast<linalg::LinalgOp>(op), options);
- if (!promotedOp) {
- rewriter.cancelOpModification(op);
- return op->emitError("subview promotion failed");
- }
- rewriter.finalizeOpModification(op);
- filter.replaceLinalgTransformationFilter(rewriter, op);
- return success();
- }
-
-private:
- /// LinalgTransformMarker handles special
- /// attribute manipulations.
- LinalgTransformationFilter filter;
- /// Promotion options.
- linalg::LinalgPromotionOptions options;
-};
-
-template <typename OpTy>
-struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
- /// SFINAE: This constructor can only trigger for
- /// concrete ops that have a static
- /// `getOperationName` method.
- template <typename ConcreateOpTy = OpTy>
- LinalgPromotionPattern(
- MLIRContext *context, linalg::LinalgPromotionOptions options,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
- f, benefit) {}
- /// This constructor is available to anyone.
- LinalgPromotionPattern(
- StringRef opName, MLIRContext *context,
- linalg::LinalgPromotionOptions options,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
-};
-
-} // namespace mlir::iree_compiler::IREE::LinalgExt
-
-#endif // IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_