Move TransformMatchers to llvm-external-projects (#11472)
This is a followup commit to #11422 addressing code move requests in the
review.
Co-authored-by: Thomas Raoux <thomasraoux@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD b/compiler/src/iree/compiler/Codegen/Common/BUILD
index 814298b..66512d4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD
@@ -100,14 +100,19 @@
srcs = [
"TransformDialectJitterPass.cpp",
"TransformDialectStrategies.cpp",
+ "TransformDialectStrategiesCPU.cpp",
+ "TransformDialectStrategiesGPU.cpp",
],
hdrs = [
"TransformDialectStrategies.h",
+ "TransformDialectStrategiesCPU.h",
+ "TransformDialectStrategiesGPU.h",
],
deps = [
# Dialects
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
+ "//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 201d5fc..fc6b489 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -72,10 +72,15 @@
TransformDialectJitterPass
HDRS
"TransformDialectStrategies.h"
+ "TransformDialectStrategiesCPU.h"
+ "TransformDialectStrategiesGPU.h"
SRCS
"TransformDialectJitterPass.cpp"
"TransformDialectStrategies.cpp"
+ "TransformDialectStrategiesCPU.cpp"
+ "TransformDialectStrategiesGPU.cpp"
DEPS
+ IREEDialectsTransforms
IREELinalgExtDialect
IREELinalgExtTransformOps
IREELinalgTransformDialect
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
index c2c0a5d..99dcc27 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
@@ -10,8 +10,8 @@
#include <type_traits>
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
-#include "iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h"
#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
@@ -26,6 +26,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -46,27 +47,18 @@
using namespace mlir;
-namespace mlir {
-namespace iree_compiler {
-
#define DEBUG_TYPE "iree-transform-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// TODO: significantly better namespacing.
-using iree_compiler::IREE::transform_dialect::AllDims;
using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
using iree_compiler::IREE::transform_dialect::ConfigExtractPart;
using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
using iree_compiler::IREE::transform_dialect::
IREEEraseHALDescriptorTypeFromMemRefOp;
-using iree_compiler::IREE::transform_dialect::IsPermutation;
-using iree_compiler::IREE::transform_dialect::m_StructuredOp;
using iree_compiler::IREE::transform_dialect::
MapNestedForeachThreadToGpuThreadsOp;
-using iree_compiler::IREE::transform_dialect::NumEqualsTo;
-using iree_compiler::IREE::transform_dialect::ShapeKind;
-using iree_compiler::IREE::transform_dialect::StructuredOpMatcher;
using iree_compiler::IREE::transform_dialect::
TileToForeachThreadAndWorkgroupCountRegionOp;
using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
@@ -80,6 +72,12 @@
using transform::SplitReductionOp;
using transform::TileToForeachThreadOp;
using transform::VectorizeOp;
+using transform_ext::AllDims;
+using transform_ext::IsPermutation;
+using transform_ext::m_StructuredOp;
+using transform_ext::NumEqualsTo;
+using transform_ext::ShapeKind;
+using transform_ext::StructuredOpMatcher;
/// Matches `args` within `targetH` and unpacks a number of handles `N`.
/// Assumes there are exactly `N` matched ops (but could be relaxed).
@@ -101,7 +99,8 @@
//===----------------------------------------------------------------------===//
/// Prints `handles` in order. Prints the whole IR if `handles` is empty.
-static void buildPrint(ImplicitLocOpBuilder &b, ValueRange handles = {}) {
+void mlir::iree_compiler::buildPrint(ImplicitLocOpBuilder &b,
+ ValueRange handles) {
if (handles.empty()) b.create<PrintOp>();
for (auto h : handles) b.create<PrintOp>(h);
}
@@ -141,56 +140,94 @@
return foreachThreadH;
}
-/// Call buildTileAndFuseAndDistributeImpl with ArrayRef<int64_t> tilesSizes.
-template <typename TilingTransformOp = TileToForeachThreadOp>
+// TODO: if someone knows how to properly export templates go for it ..
+// sigh.
+template <typename TilingTransformOp>
static Value buildTileFuseDistWithTileSizes(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr) {
+ SmallVectorImpl<Value> *resultingFusedOpsHandles) {
return buildTileAndFuseAndDistributeImpl<TilingTransformOp,
transform::TileSizesSpec>(
b, rootH, opsHToFuse, tileSizes, threadDimMapping,
resultingFusedOpsHandles);
}
-
-/// Call buildTileAndFuseAndDistributeImpl with ArrayRef<int64_t> numThreads.
-template <typename TilingTransformOp = TileToForeachThreadOp>
-static Value buildTileFuseDistWithNumThreads(
+Value mlir::iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<int64_t> numThreads, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr) {
- return buildTileAndFuseAndDistributeImpl<TilingTransformOp,
- transform::NumThreadsSpec>(
- b, rootH, opsHToFuse, getAsOpFoldResult(b.getI64ArrayAttr(numThreads)),
- threadDimMapping, resultingFusedOpsHandles);
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+ return buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
+ b, rootH, opsHToFuse, tileSizes, threadDimMapping,
+ resultingFusedOpsHandles);
+}
+Value mlir::iree_compiler::
+ buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
+ ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+ return buildTileFuseDistWithTileSizes<
+ TileToForeachThreadAndWorkgroupCountRegionOp>(b, rootH, opsHToFuse,
+ tileSizes, threadDimMapping,
+ resultingFusedOpsHandles);
}
-/// Call buildTileAndFuseAndDistributeImpl with a handle to multiple numThreads.
-template <typename TilingTransformOp = TileToForeachThreadOp>
+/// Call buildTileAndFuseAndDistributeImpl with ArrayRef<int64_t> numThreads.
+// TODO: if someone knows how to properly export templates go for it ..
+// sigh.
+template <typename TilingTransformOp>
static Value buildTileFuseDistWithNumThreads(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- Value numThreads, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr) {
+ ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles) {
return buildTileAndFuseAndDistributeImpl<TilingTransformOp,
transform::NumThreadsSpec>(
- b, rootH, opsHToFuse, ArrayRef<OpFoldResult>{numThreads},
- threadDimMapping, resultingFusedOpsHandles);
+ b, rootH, opsHToFuse, numThreads, threadDimMapping,
+ resultingFusedOpsHandles);
+}
+Value mlir::iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
+ ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+ return buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
+ b, rootH, opsHToFuse, tileSizes, threadDimMapping,
+ resultingFusedOpsHandles);
+}
+Value mlir::iree_compiler::
+ buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
+ ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+ return buildTileFuseDistWithTileSizes<
+ TileToForeachThreadAndWorkgroupCountRegionOp>(b, rootH, opsHToFuse,
+ tileSizes, threadDimMapping,
+ resultingFusedOpsHandles);
}
/// Apply patterns and vectorize (for now always applies rank-reduction).
/// Takes a handle to a func.func and returns an updated handle to a
/// func.func.
// TODO: configure patterns.
-static Value buildVectorizeStrategy(ImplicitLocOpBuilder &b, Value funcH) {
+Value mlir::iree_compiler::buildVectorize(ImplicitLocOpBuilder &b,
+ Value funcH) {
funcH = b.create<ApplyPatternsOp>(funcH, /*rankReducing=*/true);
return b.create<VectorizeOp>(funcH);
}
+/// Bufferize and drop HAL descriptor from memref ops.
+Value mlir::iree_compiler::buildBufferize(ImplicitLocOpBuilder &b,
+ Value variantH, bool targetGpu) {
+ variantH = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
+ Value memrefFunc =
+ b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+ b.create<IREEEraseHALDescriptorTypeFromMemRefOp>(memrefFunc);
+ return variantH;
+}
+
/// Post-bufferization mapping to blocks and threads.
/// Takes a handle to a func.func and returns an updated handle to a
/// func.func.
-static Value buildMapToBlockAndThreads(ImplicitLocOpBuilder &b, Value funcH,
- ArrayRef<int64_t> blockSize) {
+Value mlir::iree_compiler::buildMapToBlockAndThreads(
+ ImplicitLocOpBuilder &b, Value funcH, ArrayRef<int64_t> blockSize) {
funcH = b.create<ForeachThreadToWorkgroupOp>(funcH);
return b.create<MapNestedForeachThreadToGpuThreadsOp>(funcH, blockSize);
}
@@ -200,9 +237,9 @@
/// Post-bufferization vector distribution with rank-reduction.
/// Takes a handle to a func.func and returns an updated handle to a
/// func.func.
-static Value buildDistributeVectors(ImplicitLocOpBuilder &b, Value variantH,
- Value funcH,
- int64_t warpSize = kCudaWarpSize) {
+Value mlir::iree_compiler::buildDistributeVectors(ImplicitLocOpBuilder &b,
+ Value variantH, Value funcH,
+ int64_t warpSize) {
funcH = b.create<ApplyPatternsOp>(funcH, /*rankReducing=*/true);
Value ifH = b.create<MatchOp>(funcH, scf::IfOp::getOperationName());
// Locally suppress failures for this op only because it doesn't cover the
@@ -220,11 +257,7 @@
return funcH;
}
-//===----------------------------------------------------------------------===//
-// Higher-level problem-specific strategy creation APIs, these should favor
-// user-friendliness.
-//===----------------------------------------------------------------------===//
-
+namespace {
/// Various handles produced by reduction splitting.
struct ReductionSplitResult {
/// Handle to the leading elementwise operation, may be null if no such
@@ -237,13 +270,14 @@
Value splitLinalgH;
/// Handle to the final reduction.
Value combinerH;
- /// Handle to the original fill operation, may be null if the operation was
- /// not re-matched.
+ /// Handle to the original fill operation, may be null if the operation
+ /// was not re-matched.
Value originalFillH;
- /// Handle to the trailing fill operation, may be null if the operation was
- /// not re-matched.
+ /// Handle to the trailing fill operation, may be null if the operation
+ /// was not re-matched.
Value trailingEltwiseH;
};
+} // namespace
/// Builds transform IR requesting to bubble up the "expand_shape" operation
/// produced as parent of reduction splitting if necessary for fusion of the
@@ -280,11 +314,11 @@
/// Distribute to blocks using the current IREE lowering config.
// TODO: consider passing a problem-specific struct to control information.
-static Value createReductionStrategyBlockDistributionPart(
+Value mlir::iree_compiler::createReductionStrategyBlockDistributionPart(
ImplicitLocOpBuilder &b, Value variantH, Value originalFillH,
Value reductionH, Value optionalFusionRootH,
- ArrayRef<OpFoldResult> tileSizes0Generic, bool hasLeadingEltwise = false,
- bool hasTrailingEltwise = false) {
+ ArrayRef<OpFoldResult> tileSizes0Generic, bool hasLeadingEltwise,
+ bool hasTrailingEltwise) {
// Step 1. Split the reduction to get meatier parallelism.
// TODO: use a scf.foreach_thread for this.
auto splitReductionTransformOp =
@@ -302,8 +336,8 @@
auto x = mlir::gpu::GPUBlockMappingAttr::get(b.getContext(),
::mlir::gpu::Blocks::DimX);
// Step 2. First level of tiling + fusion parallelizes to blocks using
- // `tileSizes`. If the fusion root was the reduction op, update it to be the
- // combiner op. Otherwise, fuse the combiner op into root.
+ // `tileSizes`. If the fusion root was the reduction op, update it to be
+ // the combiner op. Otherwise, fuse the combiner op into root.
SmallVector<Value> opsHToFuse(
{rs.originalFillH ? rs.originalFillH : originalFillH, rs.splitFillH,
rs.splitLinalgH});
@@ -318,192 +352,25 @@
opsHToFuse.push_back(rs.leadingEltwiseH);
}
- // The presence of leading elementwise operation implies that dispatch region
- // formation happened using another transform dialect script and doesn't need
- // the workgroup count part.
+ // The presence of leading elementwise operation implies that dispatch
+ // region formation happened using another transform dialect script and
+ // doesn't need the workgroup count part.
if (hasLeadingEltwise) {
- buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
+ iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
b, optionalFusionRootH, opsHToFuse, tileSizes0Generic,
b.getArrayAttr({x}));
} else {
- buildTileFuseDistWithTileSizes<
- TileToForeachThreadAndWorkgroupCountRegionOp>(
- b, optionalFusionRootH, opsHToFuse, tileSizes0Generic,
- b.getArrayAttr({x}));
+ iree_compiler::
+ buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
+ b, optionalFusionRootH, opsHToFuse, tileSizes0Generic,
+ b.getArrayAttr({x}));
}
return variantH;
}
-// TODO: consider passing a problem-specific struct to control information.
-static Value createReductionStrategyThreadDistributionPart(
- ImplicitLocOpBuilder &b, Value variantH, ArrayRef<int64_t> tileSizes1Fill,
- ArrayRef<int64_t> tileSizes1Generic, bool hasLeadingEltwise,
- bool hasTrailingEltwise) {
- // TODO: Relying on ordering is brittle, harden this.
- Value matchedH = b.create<MatchOp>(
- variantH, ArrayRef<StringRef>{linalg::GenericOp::getOperationName(),
- linalg::FillOp::getOperationName()});
- auto split = b.create<SplitHandlesOp>(
- matchedH,
- /*numResultHandles=*/4 + hasLeadingEltwise + hasTrailingEltwise);
- Value firstFusionRootH = split.getResults()[1 + hasLeadingEltwise];
- SmallVector<Value> firstFusionGroupHs =
- split.getResults().take_front(1 + hasLeadingEltwise);
- Value secondFusionRootH = split.getResults().back();
- SmallVector<Value> secondFusionGroupHs =
- split.getResults().drop_front(2 + hasLeadingEltwise).drop_back();
-
- auto z = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
- ::mlir::gpu::Threads::DimZ);
- auto y = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
- ::mlir::gpu::Threads::DimY);
-
- // clang-format off
- buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(b,
- /*rootH=*/secondFusionRootH,
- /*opsHToFuse=*/secondFusionGroupHs,
- /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Fill)),
- /*threadDimMapping=*/b.getArrayAttr({z}));
- buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(b,
- /*rootH=*/firstFusionRootH,
- /*opsHToFuse=*/firstFusionGroupHs,
- /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Generic)),
- /*threadDimMapping=*/b.getArrayAttr({z,y}));
- // clang-format on
- return variantH;
-}
-
-/// Structure to hold the parameters related to GPU reduction strategy.
-struct GPUReductionStrategyInfos {
- std::array<int64_t, 3> workgroupSize;
- SmallVector<int64_t> workgroupTileSizes;
- SmallVector<int64_t> fillSecondTileSizes;
- SmallVector<int64_t> genericSecondTileSizes;
- bool hasLeadingEltwise;
- bool hasTrailingEltwise;
-};
-
-/// Returns a triple of handles: the leading elementwise operation, the
-/// reduction operation and the fusion root. The leading elementwise and the
-/// fusion root may be null. If the fusion root is null, the reduction operation
-/// should be used as fusion root instead.
-// TODO: consider passing a problem-specific struct to control information.
-static std::tuple<Value, Value, Value>
-createMatchReductionBlockDistributionHandles(ImplicitLocOpBuilder &b,
- Value variantH,
- bool hasLeadingEltwise,
- bool hasTrailingEltwise) {
- Value originalGenericH =
- b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
- auto op = b.create<SplitHandlesOp>(
- originalGenericH,
- /*numResultHandles=*/1 + hasLeadingEltwise + hasTrailingEltwise);
- return std::make_tuple(hasLeadingEltwise ? op.getResults().front() : Value(),
- op.getResults().drop_front(hasLeadingEltwise).front(),
- hasTrailingEltwise ? op.getResults().back() : Value());
-}
-
-// TODO: generalize and automate over and over.
-// TODO: significantly shrink this down.
-// TODO: consider passing a problem-specific struct to control information.
-static void createReductionCudaStrategy(
- ImplicitLocOpBuilder &b, Value variantH,
- const GPUReductionStrategyInfos &infos) {
- // Step 0. Match the ops.
- Value originalFillH =
- b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
- auto [leadingH, reductionH, fusionRootH] =
- createMatchReductionBlockDistributionHandles(
- b, variantH, infos.hasLeadingEltwise, infos.hasTrailingEltwise);
-
- // Step 1: Distribute to blocks using the current IREE lowering config.
- variantH = createReductionStrategyBlockDistributionPart(
- b, variantH, originalFillH, reductionH, fusionRootH,
- getAsOpFoldResult(b.getI64ArrayAttr(infos.workgroupTileSizes)),
- infos.hasLeadingEltwise, infos.hasTrailingEltwise);
-
- // Step 2. Second level of tiling + fusion parallelizes to threads.
- variantH = createReductionStrategyThreadDistributionPart(
- b, variantH, infos.fillSecondTileSizes, infos.genericSecondTileSizes,
- infos.hasLeadingEltwise, infos.hasTrailingEltwise);
-
- // Step 3. Rank-reduce and vectorize.
- // TODO: assumes a single func::FuncOp to transform, may need hardening.
- Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
- funcH = buildVectorizeStrategy(b, funcH);
-
- // Step 4. Bufferize and drop HAL decriptor from memref ops.
- variantH = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
- Value memrefFunc =
- b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
- b.create<IREEEraseHALDescriptorTypeFromMemRefOp>(memrefFunc);
-
- // Step 5. Post-bufferization mapping to blocks and threads.
- // Need to match again since bufferize invalidated all handles.
- // TODO: assumes a single func::FuncOp to transform, may need hardening.
- funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
- funcH = buildMapToBlockAndThreads(b, funcH, infos.workgroupSize);
-
- // Step 6. Post-bufferization vector distribution with rank-reduction.
- buildDistributeVectors(b, variantH, funcH);
-}
-
-// TODO: consider passing a problem-specific struct to control information.
-static bool matchGPUReduction(linalg::LinalgOp op,
- GPUReductionStrategyInfos &info) {
- // TODO: match the sequence the strategy supports.
- StructuredOpMatcher pattern, fill, leadingEltwise, trailingEltwise;
- makeGPUReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
- if (!matchPattern(op, pattern)) return false;
-
- info.hasLeadingEltwise = leadingEltwise.getCaptured() != nullptr;
- info.hasTrailingEltwise = trailingEltwise.getCaptured() != nullptr;
-
- // Hardcoded workagroup size, this could be deduced from the reduction dim.
- info.workgroupSize = {32, 2, 1};
- SmallVector<unsigned> partitionedLoops =
- cast<PartitionableLoopsInterface>(op.getOperation())
- .getPartitionableLoops(kNumMaxParallelDims);
- size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
- // Tile all the parallel dimension to 1.
- info.workgroupTileSizes.append(numLoops, 1);
- info.fillSecondTileSizes = {1, 0, 0};
- info.genericSecondTileSizes = {1, 1, 0};
- return true;
-}
-
-/// Structure to hold the parameters related to GPU reduction strategy.
-struct CPUReductionStrategyInfos {
- int64_t workgroupSize;
- SmallVector<int64_t> tileSizes;
-};
-
-static bool matchCPUReduction(linalg::LinalgOp op,
- CPUReductionStrategyInfos &infos) {
- // TODO: match the sequence the strategy supports.
- auto fill = m_StructuredOp<linalg::FillOp>();
- auto pattern = m_StructuredOp()
- .dim(AllDims(), ShapeKind::Static)
- .dim(-1, utils::IteratorType::reduction)
- .output(NumEqualsTo(1))
- .output(0, fill);
-
- // TODO: set the right config as expected by the strategy.
- infos.workgroupSize = 1;
- SmallVector<unsigned> partitionedLoops =
- cast<PartitionableLoopsInterface>(op.getOperation())
- .getPartitionableLoops(kNumMaxParallelDims);
- size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
- // Tile all the parallel dimension to 1.
- infos.tileSizes.append(numLoops, 1);
- return true;
-}
-
-using StrategyBuilderFn = std::function<void(ImplicitLocOpBuilder &, Value)>;
-
-static void createTransformRegion(func::FuncOp entryPoint,
- StrategyBuilderFn buildStrategy) {
+void mlir::iree_compiler::createTransformRegion(
+ func::FuncOp entryPoint, StrategyBuilderFn buildStrategy) {
MLIRContext *ctx = entryPoint.getContext();
Location loc = entryPoint.getLoc();
OpBuilder b(ctx);
@@ -524,71 +391,3 @@
<< "\n");
LLVM_DEBUG(sequence.print(DBGS()));
}
-
-// TODO: generalize and automate over and over.
-// TODO: significantly shrink this down.
-static LogicalResult createReductionCpuStrategy(
- ImplicitLocOpBuilder &b, Value variantH,
- const CPUReductionStrategyInfos &info) {
- // Step 0. Fetch transform information from the config and materialize it in
- // the payload IR.
- // TODO: this still requires specific knowledge of ops present in the IR
- // and is very brittle.
- Value originalFillH =
- b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
- Value originalGenericH =
- b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
-
- // Step 1: Distribute to blocks using the current IREE lowering config.
- variantH = createReductionStrategyBlockDistributionPart(
- b, variantH, originalFillH, originalGenericH, Value(),
- getAsOpFoldResult(b.getI64ArrayAttr(info.tileSizes)));
-
- // Step 2. Rank-reduce and buildVectorizeStrategy.
- // TODO: assumes a single func::FuncOp to transform, may need hardening.
- Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
- funcH = buildVectorizeStrategy(b, funcH);
-
- // Step 3. Bufferize and drop HAL decriptor from memref ops.
- variantH = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
- Value memrefFunc =
- b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
- b.create<IREEEraseHALDescriptorTypeFromMemRefOp>(memrefFunc);
-
- // Step 4. Post-bufferization mapping to blocks only.
- // Need to match again since bufferize invalidated all handles.
- // TODO: assumes a single func::FuncOp to transform, may need hardening.
- funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
- funcH = b.create<ForeachThreadToWorkgroupOp>(funcH);
-
- return success();
-}
-
-LogicalResult matchAndSetGPUReductionTransformStrategy(func::FuncOp entryPoint,
- linalg::LinalgOp op) {
- // 1. Match
- GPUReductionStrategyInfos infos;
- if (!matchGPUReduction(op, infos)) return failure();
- auto strategyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
- return createReductionCudaStrategy(b, variant, infos);
- };
- // 2. Add the strategy.
- createTransformRegion(entryPoint, strategyBuilder);
- return success();
-}
-
-LogicalResult matchAndSetCPUReductionTransformStrategy(func::FuncOp entryPoint,
- linalg::LinalgOp op) {
- // 1. Match
- CPUReductionStrategyInfos infos;
- if (!matchCPUReduction(op, infos)) return failure();
- auto startegyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
- return createReductionCpuStrategy(b, variant, infos);
- };
- // 2. Add the strategy.
- createTransformRegion(entryPoint, startegyBuilder);
- return success();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
index 83ddb85..c7ca9fb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
@@ -3,7 +3,8 @@
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_CODEGEN_LLVMGPU_GPUTRANSFORMDIALECT_STRATEGIES_H_
+
+#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_H_
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -12,15 +13,95 @@
namespace mlir {
namespace iree_compiler {
-/// Return success if the IR matches what the GPU reduction strategy can handle.
-/// If it is success it will append the transform dialect after the entry point
-/// module.
-LogicalResult matchAndSetGPUReductionTransformStrategy(func::FuncOp entryPoint,
- linalg::LinalgOp op);
+//===----------------------------------------------------------------------===//
+// Low-level reusable builder APIs, these should follow MLIR-style builders.
+//===----------------------------------------------------------------------===//
-LogicalResult matchAndSetCPUReductionTransformStrategy(func::FuncOp entryPoint,
- linalg::LinalgOp op);
+/// Prints `handles` in order. Prints the whole IR if `handles` is empty.
+static void buildPrint(ImplicitLocOpBuilder &b, ValueRange handles = {});
+
+/// Performs the following transformations:
+/// 1. Tiles `rootH` to scf.foreach_thread to with `tileSizesOrNumThreads`
+/// according to whether spec is a TileSizesSpec or a NumThreadsSpec.
+/// 2. Maps the resulting scf.foreach_thread to threads according to
+/// `threadDimMapping`.
+/// 3. Iterates over `opsHToFuse` in order and fuses into the containing op.
+/// Returns a handle to the resulting scf.foreach_thread.
+///
+/// Fusion operates in batch mode: a single fusion command is issued and a
+/// topological sort is automatically computed by the fusion.
+/// Since this applies a single fusion, no interleaved canonicalization / cse /
+/// enabling transformation occurs and the resulting fusion may not be as good.
+///
+/// In the future, an iterative mode in which the user is responsible for
+/// providing the fusion order and has interleaved canonicalization / cse /
+/// enabling transform will be introduced and may result in better fusions.
+///
+/// If `resultingFusedOpsHandles` is a non-null pointer, the fused operation are
+/// appended in order.
+///
+// TODO: if someone knows how to properly export templates go for it .. sigh.
+Value buildTileFuseDistToForeachThreadWithTileSizes(
+ ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+Value buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
+ ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+
+/// See buildTileFuseDistWithTileSizes.
+Value buildTileFuseDistToForeachThreadWithNumThreads(
+ ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+Value buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
+ ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
+ SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+
+/// Apply patterns and vectorize (for now always applies rank-reduction).
+/// Takes a handle to a func.func and returns an updated handle to a
+/// func.func.
+Value buildVectorize(ImplicitLocOpBuilder &b, Value funcH);
+
+/// Bufferize and drop HAL decriptor from memref ops.
+/// Takes a handle variantOp and returns a handle to the same variant op.
+Value buildBufferize(ImplicitLocOpBuilder &b, Value variantH,
+ bool targetGpu = false);
+
+/// Post-bufferization mapping to blocks and threads.
+/// Takes a handle to a func.func and returns an updated handle to a
+/// func.func.
+Value buildMapToBlockAndThreads(ImplicitLocOpBuilder &b, Value funcH,
+ ArrayRef<int64_t> blockSize);
+
+static constexpr unsigned kCudaWarpSize = 32;
+
+/// Post-bufferization vector distribution with rank-reduction.
+/// Takes a handle to a func.func and returns an updated handle to a
+/// func.func.
+Value buildDistributeVectors(ImplicitLocOpBuilder &b, Value variantH,
+ Value funcH, int64_t warpSize = kCudaWarpSize);
+
+using StrategyBuilderFn = std::function<void(ImplicitLocOpBuilder &, Value)>;
+
+void createTransformRegion(func::FuncOp entryPoint,
+ StrategyBuilderFn buildStrategy);
+
+//===----------------------------------------------------------------------===//
+// Higher-level problem-specific strategy creation APIs, these should favor
+// user-friendliness.
+//===----------------------------------------------------------------------===//
+/// Distribute to blocks using the current IREE lowering config.
+// TODO: consider passing a problem-specific struct to control information.
+Value createReductionStrategyBlockDistributionPart(
+ ImplicitLocOpBuilder &b, Value variantH, Value originalFillH,
+ Value reductionH, Value optionalFusionRootH,
+ ArrayRef<OpFoldResult> tileSizes0Generic, bool hasLeadingEltwise = false,
+ bool hasTrailingEltwise = false);
+
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_CODEGEN_LLVMGPU_GPUTRANSFORMDIALECT_STRATEGIES_H_
+#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.cpp
new file mode 100644
index 0000000..64b3d09
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.cpp
@@ -0,0 +1,178 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h"
+
+#include <numeric>
+#include <type_traits>
+
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
+#include "iree/compiler/Codegen/Common/TransformDialectStrategies.h"
+#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
+#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "iree-transform-builder"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+// TODO: significantly better namespacing.
+using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
+using iree_compiler::IREE::transform_dialect::ConfigExtractPart;
+using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
+using iree_compiler::IREE::transform_dialect::
+ IREEEraseHALDescriptorTypeFromMemRefOp;
+using iree_compiler::IREE::transform_dialect::
+ MapNestedForeachThreadToGpuThreadsOp;
+using iree_compiler::IREE::transform_dialect::
+ TileToForeachThreadAndWorkgroupCountRegionOp;
+using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
+using iree_compiler::IREE::transform_dialect::VectorWarpDistributionOp;
+using transform::FuseIntoContainingOp;
+using transform::MatchOp;
+using transform::MergeHandlesOp;
+using transform::PrintOp;
+using transform::SequenceOp;
+using transform::SplitHandlesOp;
+using transform::SplitReductionOp;
+using transform::TileToForeachThreadOp;
+using transform::VectorizeOp;
+using transform_ext::AllDims;
+using transform_ext::IsPermutation;
+using transform_ext::m_StructuredOp;
+using transform_ext::NumEqualsTo;
+using transform_ext::ShapeKind;
+using transform_ext::StructuredOpMatcher;
+
+/// Matches `args` within `targetH` and unpacks a number of handles `N`.
+/// Assumes there are exactly `N` matched ops (but could be relaxed).
+/// Returns the tuple of handles.
+template <int N, typename... MatchingArgs>
+auto matchAndUnpack(ImplicitLocOpBuilder &b, Value targetH,
+ MatchingArgs... args) {
+ Value matchedH = b.create<MatchOp>(targetH, args...);
+ auto matchOp = b.create<SplitHandlesOp>(matchedH,
+ /*numHandles=*/N);
+ assert(matchOp->getNumResults() == N && "Unexpected number of results");
+ std::array<Value, N> a;
+ for (int64_t i = 0; i < N; ++i) a[i] = matchOp->getResult(i);
+ return std::tuple_cat(a);
+}
+
+//===----------------------------------------------------------------------===//
+// Higher-level problem-specific strategy creation APIs, these should favor
+// user-friendliness.
+//===----------------------------------------------------------------------===//
+
+// TODO: consider lifting and exposing.
+
+/// Structure to hold the parameters related to GPU reduction strategy.
+struct CPUReductionStrategyInfos {
+ int64_t workgroupSize;
+ SmallVector<int64_t> tileSizes;
+};
+
+static bool matchCPUReduction(linalg::LinalgOp op,
+ CPUReductionStrategyInfos &infos) {
+ // TODO: match the sequence the strategy supports.
+ auto fill = m_StructuredOp<linalg::FillOp>();
+ auto pattern = m_StructuredOp()
+ .dim(AllDims(), ShapeKind::Static)
+ .dim(-1, utils::IteratorType::reduction)
+ .output(NumEqualsTo(1))
+ .output(0, fill);
+
+ // TODO: set the right config as expected by the strategy.
+ infos.workgroupSize = 1;
+ SmallVector<unsigned> partitionedLoops =
+ cast<iree_compiler::PartitionableLoopsInterface>(op.getOperation())
+ .getPartitionableLoops(iree_compiler::kNumMaxParallelDims);
+ size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
+ // Tile all the parallel dimension to 1.
+ infos.tileSizes.append(numLoops, 1);
+ return true;
+}
+
+// TODO: generalize and automate over and over.
+// TODO: significantly shrink this down.
+static LogicalResult createReductionCpuStrategy(
+ ImplicitLocOpBuilder &b, Value variantH,
+ const CPUReductionStrategyInfos &info) {
+ // Step 0. Fetch transform information from the config and materialize it in
+ // the payload IR.
+ // TODO: this still requires specific knowledge of ops present in the IR
+ // and is very brittle.
+ Value originalFillH =
+ b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
+ Value originalGenericH =
+ b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
+
+ // Step 1: Distribute to blocks using the current IREE lowering config.
+ variantH = iree_compiler::createReductionStrategyBlockDistributionPart(
+ b, variantH, originalFillH, originalGenericH, Value(),
+ getAsOpFoldResult(b.getI64ArrayAttr(info.tileSizes)));
+
+ // Step 2. Rank-reduce and buildVectorize.
+ // TODO: assumes a single func::FuncOp to transform, may need hardening.
+ Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+ funcH = iree_compiler::buildVectorize(b, funcH);
+
+ // Step 3. Bufferize and drop HAL descriptor from memref ops.
+ variantH = iree_compiler::buildBufferize(b, variantH, /*targetGpu=*/true);
+
+ // Step 4. Post-bufferization mapping to blocks only.
+ // Need to match again since bufferize invalidated all handles.
+ // TODO: assumes a single func::FuncOp to transform, may need hardening.
+ funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+ funcH = b.create<ForeachThreadToWorkgroupOp>(funcH);
+
+ return success();
+}
+
+LogicalResult iree_compiler::matchAndSetCPUReductionTransformStrategy(
+ func::FuncOp entryPoint, linalg::LinalgOp op) {
+ // 1. Match
+ CPUReductionStrategyInfos infos;
+ if (!matchCPUReduction(op, infos)) return failure();
+ auto startegyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
+ return createReductionCpuStrategy(b, variant, infos);
+ };
+ // 2. Add the strategy.
+ createTransformRegion(entryPoint, startegyBuilder);
+ return success();
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h
new file mode 100644
index 0000000..965a366
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h
@@ -0,0 +1,23 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_CPU_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+/// Return success if the IR matches what the GPU reduction strategy can handle.
+/// If it is success it will append the transform dialect after the entry point
+/// module.
+LogicalResult matchAndSetCPUReductionTransformStrategy(func::FuncOp entryPoint,
+ linalg::LinalgOp op);
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_CPU_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
new file mode 100644
index 0000000..c4e119d
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
@@ -0,0 +1,249 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h"
+
+#include <numeric>
+#include <type_traits>
+
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
+#include "iree/compiler/Codegen/Common/TransformDialectStrategies.h"
+#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
+#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "iree-transform-builder"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+// TODO: significantly better namespacing.
+using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
+using iree_compiler::IREE::transform_dialect::ConfigExtractPart;
+using iree_compiler::IREE::transform_dialect::ForeachThreadToWorkgroupOp;
+using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
+using iree_compiler::IREE::transform_dialect::
+ IREEEraseHALDescriptorTypeFromMemRefOp;
+using iree_compiler::IREE::transform_dialect::
+ MapNestedForeachThreadToGpuThreadsOp;
+using iree_compiler::IREE::transform_dialect::
+ TileToForeachThreadAndWorkgroupCountRegionOp;
+using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
+using iree_compiler::IREE::transform_dialect::VectorWarpDistributionOp;
+using transform::FuseIntoContainingOp;
+using transform::MatchOp;
+using transform::MergeHandlesOp;
+using transform::PrintOp;
+using transform::SequenceOp;
+using transform::SplitHandlesOp;
+using transform::SplitReductionOp;
+using transform::TileToForeachThreadOp;
+using transform::VectorizeOp;
+using transform_ext::AllDims;
+using transform_ext::IsPermutation;
+using transform_ext::m_StructuredOp;
+using transform_ext::NumEqualsTo;
+using transform_ext::ShapeKind;
+using transform_ext::StructuredOpMatcher;
+
+/// Matches `args` within `targetH` and unpacks a number of handles `N`.
+/// Assumes there are exactly `N` matched ops (but could be relaxed).
+/// Returns the tuple of handles.
+template <int N, typename... MatchingArgs>
+auto matchAndUnpack(ImplicitLocOpBuilder &b, Value targetH,
+ MatchingArgs... args) {
+ Value matchedH = b.create<MatchOp>(targetH, args...);
+ auto matchOp = b.create<SplitHandlesOp>(matchedH,
+ /*numHandles=*/N);
+ assert(matchOp->getNumResults() == N && "Unexpected number of results");
+ std::array<Value, N> a;
+ for (int64_t i = 0; i < N; ++i) a[i] = matchOp->getResult(i);
+ return std::tuple_cat(a);
+}
+
+//===----------------------------------------------------------------------===//
+// Higher-level problem-specific strategy creation APIs, these should favor
+// user-friendliness.
+//===----------------------------------------------------------------------===//
+
+// TODO: consider passing a problem-specific struct to control information.
+static Value createReductionStrategyThreadDistributionPart(
+ ImplicitLocOpBuilder &b, Value variantH, ArrayRef<int64_t> tileSizes1Fill,
+ ArrayRef<int64_t> tileSizes1Generic, bool hasLeadingEltwise,
+ bool hasTrailingEltwise) {
+ // TODO: Relying on ordering is brittle, harden this.
+ Value matchedH = b.create<MatchOp>(
+ variantH, ArrayRef<StringRef>{linalg::GenericOp::getOperationName(),
+ linalg::FillOp::getOperationName()});
+ auto split = b.create<SplitHandlesOp>(
+ matchedH,
+ /*numResultHandles=*/4 + hasLeadingEltwise + hasTrailingEltwise);
+ Value firstFusionRootH = split.getResults()[1 + hasLeadingEltwise];
+ SmallVector<Value> firstFusionGroupHs =
+ split.getResults().take_front(1 + hasLeadingEltwise);
+ Value secondFusionRootH = split.getResults().back();
+ SmallVector<Value> secondFusionGroupHs =
+ split.getResults().drop_front(2 + hasLeadingEltwise).drop_back();
+
+ auto z = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
+ ::mlir::gpu::Threads::DimZ);
+ auto y = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
+ ::mlir::gpu::Threads::DimY);
+
+ // clang-format off
+ iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(b,
+ /*rootH=*/secondFusionRootH,
+ /*opsHToFuse=*/secondFusionGroupHs,
+ /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Fill)),
+ /*threadDimMapping=*/b.getArrayAttr({z}));
+ iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(b,
+ /*rootH=*/firstFusionRootH,
+ /*opsHToFuse=*/firstFusionGroupHs,
+ /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Generic)),
+ /*threadDimMapping=*/b.getArrayAttr({z,y}));
+ // clang-format on
+ return variantH;
+}
+
+/// Structure to hold the parameters related to GPU reduction strategy.
+struct GPUReductionStrategyInfos {
+ std::array<int64_t, 3> workgroupSize;
+ SmallVector<int64_t> workgroupTileSizes;
+ SmallVector<int64_t> fillSecondTileSizes;
+ SmallVector<int64_t> genericSecondTileSizes;
+ bool hasLeadingEltwise;
+ bool hasTrailingEltwise;
+};
+
+/// Returns a triple of handles: the leading elementwise operation, the
+/// reduction operation and the fusion root. The leading elementwise and the
+/// fusion root may be null. If the fusion root is null, the reduction operation
+/// should be used as fusion root instead.
+// TODO: consider passing a problem-specific struct to control information.
+static std::tuple<Value, Value, Value>
+createMatchReductionBlockDistributionHandles(ImplicitLocOpBuilder &b,
+ Value variantH,
+ bool hasLeadingEltwise,
+ bool hasTrailingEltwise) {
+ Value originalGenericH =
+ b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
+ auto op = b.create<SplitHandlesOp>(
+ originalGenericH,
+ /*numResultHandles=*/1 + hasLeadingEltwise + hasTrailingEltwise);
+ return std::make_tuple(hasLeadingEltwise ? op.getResults().front() : Value(),
+ op.getResults().drop_front(hasLeadingEltwise).front(),
+ hasTrailingEltwise ? op.getResults().back() : Value());
+}
+
+// TODO: generalize and automate over and over.
+// TODO: significantly shrink this down.
+// TODO: consider passing a problem-specific struct to control information.
+static void createReductionCudaStrategy(
+ ImplicitLocOpBuilder &b, Value variantH,
+ const GPUReductionStrategyInfos &infos) {
+ // Step 0. Match the ops.
+ Value originalFillH =
+ b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
+ auto [leadingH, reductionH, fusionRootH] =
+ createMatchReductionBlockDistributionHandles(
+ b, variantH, infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+
+ // Step 1: Distribute to blocks using the current IREE lowering config.
+ variantH = iree_compiler::createReductionStrategyBlockDistributionPart(
+ b, variantH, originalFillH, reductionH, fusionRootH,
+ getAsOpFoldResult(b.getI64ArrayAttr(infos.workgroupTileSizes)),
+ infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+
+ // Step 2. Second level of tiling + fusion parallelizes to threads.
+ variantH = createReductionStrategyThreadDistributionPart(
+ b, variantH, infos.fillSecondTileSizes, infos.genericSecondTileSizes,
+ infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+
+ // Step 3. Rank-reduce and vectorize.
+ // TODO: assumes a single func::FuncOp to transform, may need hardening.
+ Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+ funcH = iree_compiler::buildVectorize(b, funcH);
+
+ // Step 4. Bufferize and drop HAL descriptor from memref ops.
+ variantH = iree_compiler::buildBufferize(b, variantH, /*targetGpu=*/true);
+
+ // Step 5. Post-bufferization mapping to blocks and threads.
+ // Need to match again since bufferize invalidated all handles.
+ // TODO: assumes a single func::FuncOp to transform, may need hardening.
+ funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
+ funcH =
+ iree_compiler::buildMapToBlockAndThreads(b, funcH, infos.workgroupSize);
+
+ // Step 6. Post-bufferization vector distribution with rank-reduction.
+ iree_compiler::buildDistributeVectors(b, variantH, funcH);
+}
+
+// TODO: consider passing a problem-specific struct to control information.
+static bool matchGPUReduction(linalg::LinalgOp op,
+ GPUReductionStrategyInfos &info) {
+ // TODO: match the sequence the strategy supports.
+ StructuredOpMatcher pattern, fill, leadingEltwise, trailingEltwise;
+ makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
+ if (!matchPattern(op, pattern)) return false;
+
+ info.hasLeadingEltwise = leadingEltwise.getCaptured() != nullptr;
+ info.hasTrailingEltwise = trailingEltwise.getCaptured() != nullptr;
+
+ // Hardcoded workagroup size, this could be deduced from the reduction dim.
+ info.workgroupSize = {32, 2, 1};
+ SmallVector<unsigned> partitionedLoops =
+ cast<iree_compiler::PartitionableLoopsInterface>(op.getOperation())
+ .getPartitionableLoops(iree_compiler::kNumMaxParallelDims);
+ size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
+ // Tile all the parallel dimension to 1.
+ info.workgroupTileSizes.append(numLoops, 1);
+ info.fillSecondTileSizes = {1, 0, 0};
+ info.genericSecondTileSizes = {1, 1, 0};
+ return true;
+}
+
+LogicalResult iree_compiler::matchAndSetGPUReductionTransformStrategy(
+ func::FuncOp entryPoint, linalg::LinalgOp op) {
+ // 1. Match
+ GPUReductionStrategyInfos infos;
+ if (!matchGPUReduction(op, infos)) return failure();
+ auto strategyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
+ return createReductionCudaStrategy(b, variant, infos);
+ };
+ // 2. Add the strategy.
+ createTransformRegion(entryPoint, strategyBuilder);
+ return success();
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h
new file mode 100644
index 0000000..f83d0d6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h
@@ -0,0 +1,25 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_GPU_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Return success if the IR matches what the GPU reduction strategy can handle.
+/// If it is success it will append the transform dialect after the entry point
+/// module.
+LogicalResult matchAndSetGPUReductionTransformStrategy(func::FuncOp entryPoint,
+ linalg::LinalgOp op);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECT_STRATEGIES_GPU_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
index 2b4d3e8..41376fd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD
@@ -57,12 +57,10 @@
srcs = [
"CommonExtensions.cpp",
"CommonExtensionsOps.cpp.inc",
- "TransformMatchers.cpp",
],
hdrs = [
"CommonExtensions.h",
"CommonExtensionsOps.h.inc",
- "TransformMatchers.h",
],
deps = [
":CommonExtensionsOpGen",
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
index 2fd84a7..04ee20b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt
@@ -26,11 +26,9 @@
HDRS
"CommonExtensions.h"
"CommonExtensionsOps.h.inc"
- "TransformMatchers.h"
SRCS
"CommonExtensions.cpp"
"CommonExtensionsOps.cpp.inc"
- "TransformMatchers.cpp"
DEPS
::CommonExtensionsOpGen
IREEDialectsTransforms
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 2ecd1a2..1e1e953 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -10,7 +10,7 @@
#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
-#include "iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Codegen/Passes.h"
@@ -1038,229 +1038,5 @@
return DiagnosedSilenceableFailure::success();
}
-//===---------------------------------------------------------------------===//
-// RegisterMatchCallbacksOp
-//===---------------------------------------------------------------------===//
-
-/// Match callback for "_test_match_callback" hook. Matches any payload
-/// operations associated with operand handles unless they have the
-/// "test.iree_transform_do_not_match" attribute, in which case produces a
-/// silenceable failure.
-static DiagnosedSilenceableFailure testMatchCallbackCallback(
- transform_dialect::MatchCallbackResult &res, Location loc,
- const transform::TransformState &state, ValueRange handles) {
- bool hadFailures = false;
- for (Value handle : handles) {
- if (llvm::any_of(state.getPayloadOps(handle), [](Operation *op) {
- return op->hasAttr("test.iree_transform_do_not_match");
- })) {
- res.addPayloadGroup(ArrayRef<Operation *>());
- hadFailures = true;
- } else {
- res.addPayloadGroup(state.getPayloadOps(handle));
- }
- }
- if (hadFailures) return emitSilenceableFailure(loc) << "failed to match";
- return DiagnosedSilenceableFailure::success();
-}
-
-/// Match callback for a reduction with optional leading and trailing
-/// elementwise operations. Matches *the first* occurrence of such a reduction
-/// within an op associated with the given handle.
-///
-/// Input handles:
-///
-/// - container op, must be associated with one operation.
-///
-/// Output handles:
-///
-/// - leading elementwise op, if any;
-/// - the "fill" op preceding the reduction;
-/// - reduction op;
-/// - trailing elementwise op, if any.
-static DiagnosedSilenceableFailure reductionCallback(
- transform_dialect::MatchCallbackResult &res, Location loc,
- const transform::TransformState &state, ValueRange handles) {
- if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) {
- return emitSilenceableFailure(loc)
- << "expected one handle to one operation";
- }
-
- transform_dialect::StructuredOpMatcher pattern, fill, leadingEltwise,
- trailingEltwise;
- makeGPUReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
-
- // TODO: need a mechanism for this to go around the entire IR,
- // potentially with list matches for each group.
- Operation *root = state.getPayloadOps(handles[0])[0];
- WalkResult walkResult = root->walk([&](Operation *op) {
- pattern.resetCapture();
- if (!matchPattern(op, pattern)) return WalkResult::advance();
-
- res.addPotentiallyEmptyPayloadGroup(leadingEltwise.getCaptured());
- res.addPayloadGroup({fill.getCaptured()});
- res.addPayloadGroup({pattern.getCaptured()});
- res.addPotentiallyEmptyPayloadGroup(trailingEltwise.getCaptured());
- return WalkResult::interrupt();
- });
-
- if (walkResult.wasInterrupted())
- return DiagnosedSilenceableFailure::success();
- return emitSilenceableFailure(loc) << "failed to match";
-}
-
-/// Match callback for a reduction after splitting with optional leading and
-/// trailing elementwise operations. Matches *the first* occurrence of such a
-/// reduction within an op associated with the given handle.
-///
-/// Input handles:
-///
-/// - container op, must be associated with one operation.
-///
-/// Output handles:
-///
-/// - leading elementwise op, if any;
-/// - the "fill" op preceding the original reduction;
-/// - the "fill" op preceding the split, more parallel reduction;
-/// - the split, more parallel reduction op;
-/// - reduction op;
-/// - trailing elementwise op, if any.
-static DiagnosedSilenceableFailure splitReductionCallback(
- transform_dialect::MatchCallbackResult &res, Location loc,
- const transform::TransformState &state, ValueRange handles) {
- if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) {
- return emitSilenceableFailure(loc)
- << "expected one handle to one operation";
- }
-
- transform_dialect::StructuredOpMatcher parallel_reduction, combiner_reduction,
- parallel_fill, original_fill, leading, trailing;
- makeGPUSplitReductionMatcher(parallel_reduction, combiner_reduction,
- parallel_fill, original_fill, leading, trailing);
-
- // TODO: need a mechanism for this to go around the entire IR,
- // potentially with list matches for each group.
- Operation *root = state.getPayloadOps(handles[0])[0];
- WalkResult walkResult = root->walk([&](Operation *op) {
- combiner_reduction.resetCapture();
- if (!matchPattern(op, combiner_reduction)) return WalkResult::advance();
-
- res.addPotentiallyEmptyPayloadGroup(leading.getCaptured());
- res.addPayloadGroup({original_fill.getCaptured()});
- res.addPayloadGroup({parallel_fill.getCaptured()});
- res.addPayloadGroup({parallel_reduction.getCaptured()});
- res.addPayloadGroup({combiner_reduction.getCaptured()});
- res.addPotentiallyEmptyPayloadGroup(trailing.getCaptured());
- return WalkResult::interrupt();
- });
-
- if (walkResult.wasInterrupted())
- return DiagnosedSilenceableFailure::success();
- return emitSilenceableFailure(loc) << "failed to match";
-}
-
-DiagnosedSilenceableFailure transform_dialect::RegisterMatchCallbacksOp::apply(
- transform::TransformResults &results, transform::TransformState &state) {
- auto ®istry = state.addExtension<MatchCallbacksRegistry>();
- registry.registerCallback("_test_match_callback", testMatchCallbackCallback);
- registry.registerCallback("reduction", reductionCallback);
- registry.registerCallback("split_reduction", splitReductionCallback);
- return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::RegisterMatchCallbacksOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- // TODO: it doesn't really modify the payload, we need a separate resource for
- // this mapping.
- transform::modifiesPayload(effects);
-}
-
-//===---------------------------------------------------------------------===//
-// MatchCallbackOp
-//===---------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure transform_dialect::MatchCallbackOp::apply(
- transform::TransformResults &results, transform::TransformState &state) {
- auto setEmptyResults = [&results, this] {
- for (OpResult value : getResults()) {
- results.set(value, {});
- }
- };
- auto errorOut = [this, &setEmptyResults] {
- setEmptyResults();
- return emitSilenceableError();
- };
-
- auto *registry = state.getExtension<MatchCallbacksRegistry>();
- if (!registry) return errorOut() << "match registry not available";
-
- const MatchCallbacksRegistry::MatchCallbackFn *callback =
- registry->get(getCallbackName());
- if (!callback) {
- return errorOut() << "callback '" << getCallbackName()
- << "' not found in the registry";
- }
-
- MatchCallbackResult result;
- DiagnosedSilenceableFailure status =
- (*callback)(result, getLoc(), state, getInputs());
- if (!status.succeeded()) {
- setEmptyResults();
- if (status.isDefiniteFailure()) return status;
- if (getFailurePropagationMode() ==
- transform::FailurePropagationMode::Propagate) {
- return emitSilenceableError() << "failed to match";
- } else {
- return DiagnosedSilenceableFailure::success();
- }
- }
- if (getNumResults() != result.getNumPayloadGroups()) {
- return errorOut()
- << "callback produced a different number of handles than expected ( "
- << result.getNumPayloadGroups() << " vs " << getNumResults() << " )";
- }
-
- for (OpResult value : getResults()) {
- results.set(value, result.getPayloadGroup(value.getResultNumber()));
- }
- return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::MatchCallbackOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getInputs(), effects);
- transform::producesHandle(getOutputs(), effects);
- // TODO: it doesn't really modify the payload, we need a separate resource for
- // this mapping.
- transform::modifiesPayload(effects);
-}
-
-DiagnosedSilenceableFailure transform_dialect::TakeFirstOp::apply(
- transform::TransformResults &results, transform::TransformState &state) {
- SmallVector<Operation *> concatenated;
- bool found = false;
- for (Value handle : getInputs()) {
- ArrayRef<Operation *> payloads = state.getPayloadOps(handle);
- if (payloads.empty()) continue;
- if (!found) {
- results.set(getFirst().cast<OpResult>(), payloads);
- found = true;
- } else {
- llvm::append_range(concatenated, payloads);
- }
- }
-
- if (!found) results.set(getFirst().cast<OpResult>(), {});
- results.set(getRest().cast<OpResult>(), concatenated);
- return DiagnosedSilenceableFailure::success();
-}
-
-void transform_dialect::TakeFirstOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getInputs(), effects);
- transform::producesHandle(getFirst(), effects);
- transform::producesHandle(getRest(), effects);
-}
-
#define GET_OP_CLASSES
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.cpp.inc"
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index f012f51..8f7cf12 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -323,91 +323,4 @@
}];
}
-def RegisterMatchCallbacksOp :
- Op<Transform_Dialect, "iree.register_match_callbacks",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let description = [{
- Registers named structured op matcher callbacks specific for IREE to use
- with `transform.iree.match_callback`. This should be called before first
- `match_callback` may be executed following the transform dialect control
- flow.
-
- The callbacks must have a unique name and a signature compatible with
- `MatchCallbacksRegistry::MatchCallbackFn`, which currently means
- `DiagnosedSilenceableFailure(MatchCallbackResult &, Location,
- const TransformState &, ValueRange)`. The callback receives a "result",
- followed by a location at which errors should be reported, a transform
- state at the moment of the _match_ (not registration) and a list of
- handle values passed as operands to the `match_callback` operation.
- It is expected to populate the "result" object with lists of payload
- operations that will be bound to the handles produced by the
- `match_callback` operation. The callback may fail, at which point
- it should produce a silenceable error. The callback currently is not
- allowed to modify the payload IR (though this may be revised in the
- future for the purpose of communicating the properties of the IR
- captured by the match). Therefore, it should not have a reason to
- produce a definite error.
- }];
-
- let arguments = (ins);
- let results = (outs);
- let assemblyFormat = "attr-dict";
- let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-}
-
-def MatchCallbackOp :
- Op<Transform_Dialect, "iree.match_callback",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let description = [{
- Performs payload IR matching using a C++ callback registered beforehand.
- The callback is identified by name and is passed the current transform
- state and the list of handle operands, along with information necessary
- for error propagation. See `register_match_callbacks` for the description
- of the callback contract.
-
- If `failure_propagation_mode` is set to `suppress`, any silenceable errors
- in the callback (typically, "failure to match") will be ignored and the
- resulting handles will be associated with empty lists of payload
- operations. Otherwise, silenceable failures are propagated.
- }];
-
- let arguments = (ins StrAttr:$callback_name,
- FailurePropagationMode:$failure_propagation_mode,
- Variadic<TransformTypeInterface>:$inputs);
- let results = (outs Variadic<TransformTypeInterface>:$outputs);
- let assemblyFormat = "`failures` `(` $failure_propagation_mode `)` "
- "$callback_name `(` $inputs `)` attr-dict "
- "`:` functional-type($inputs, $outputs)";
- let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-}
-
-def TakeFirstOp :
- Op<Transform_Dialect, "iree.take_first",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let description = [{
- Given an arbitrary list of handles associated with potentially empty lists
- of payload operations, produces two new handles:
-
- - a handle pointing to the same payload operations as the first operand
- handle with a non-empty list of payload operations;
- - a handle pointing to the concatenated list of payload operations
- associated with any other handle.
-
- Note that this does not perform any deduplication.
-
- This operation is useful to select a single target after some potentially
- unsuccessful matches.
- }];
-
- let arguments = (ins Variadic<TransformTypeInterface>:$inputs);
- let results = (outs TransformTypeInterface:$first,
- TransformTypeInterface:$rest);
- let assemblyFormat =
- "$inputs attr-dict `:` functional-type($inputs, results)";
- let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
-}
-
#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
index 6dfeffd..2f4b4a3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -38,6 +38,7 @@
deps = [
"//compiler/src/iree/compiler/Codegen:PassHeaders",
"//compiler/src/iree/compiler/Codegen/Common",
+ "//compiler/src/iree/compiler/Codegen/Common:TransformDialectJitterPass",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
"//compiler/src/iree/compiler/Codegen/Sandbox",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 17b2e78..fcd81f6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -81,6 +81,7 @@
MLIRVectorToSCF
MLIRVectorTransforms
iree::compiler::Codegen::Common
+ iree::compiler::Codegen::Common::TransformDialectJitterPass
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
iree::compiler::Codegen::PassHeaders
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 4ec9218..321020d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -10,7 +10,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Codegen/Common/LinalgOpInfo.h"
-#include "iree/compiler/Codegen/Common/TransformDialectStrategies.h"
+#include "iree/compiler/Codegen/Common/TransformDialectStrategiesCPU.h"
#include "iree/compiler/Codegen/Common/UserConfig.h"
#include "iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
index 2243acf..b2d35d9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -40,6 +40,7 @@
deps = [
"//compiler/src/iree/compiler/Codegen:PassHeaders",
"//compiler/src/iree/compiler/Codegen/Common",
+ "//compiler/src/iree/compiler/Codegen/Common:TransformDialectJitterPass",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions",
"//compiler/src/iree/compiler/Codegen/Transforms",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 7d2cb6b..1aa5a15 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -88,6 +88,7 @@
MLIRVectorToSCF
MLIRVectorTransforms
iree::compiler::Codegen::Common
+ iree::compiler::Codegen::Common::TransformDialectJitterPass
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
iree::compiler::Codegen::PassHeaders
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index cc3317c..6423570 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -10,7 +10,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Codegen/Common/LinalgOpInfo.h"
-#include "iree/compiler/Codegen/Common/TransformDialectStrategies.h"
+#include "iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.h"
#include "iree/compiler/Codegen/Common/UserConfig.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/LLVMGPU/TransposeUtils.h"
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 1821fd3..46b5217 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -145,8 +145,13 @@
includes = ["include"],
deps = [
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:Rewrite",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
index 07a4cfd..954deaa 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
@@ -73,6 +73,7 @@
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc"
+namespace mlir {
namespace transform_ext {
class StructuredTransformOpsExtension
: public mlir::transform::TransformDialectExtension<
@@ -81,5 +82,6 @@
StructuredTransformOpsExtension();
};
} // namespace transform_ext
+} // namespace mlir
#endif // IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
index d8bfe7c..e2b716f 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
@@ -55,7 +55,7 @@
static ::llvm::StringRef getDefaultDialect() { return "transform"; }
}];
- let cppNamespace = "transform_ext";
+ let cppNamespace = "mlir::transform_ext";
let hasVerifier = 1;
}
@@ -67,7 +67,7 @@
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let description = [{Indicates that the entire module should be bufferized.}];
let assemblyFormat = "attr-dict";
- let cppNamespace = "transform_ext";
+ let cppNamespace = "mlir::transform_ext";
}
def LowerVectorsOp : Op<Transform_Dialect, "lower_vectors",
@@ -89,7 +89,7 @@
);
let assemblyFormat = "attr-dict";
- let cppNamespace = "transform_ext";
+ let cppNamespace = "mlir::transform_ext";
}
def LowerToLLVMOp : Op<Transform_Dialect, "lower_to_llvm",
@@ -110,7 +110,95 @@
DefaultValuedAttr<BoolAttr, "false">:$enable_async);
let assemblyFormat = "attr-dict";
- let cppNamespace = "transform_ext";
+ let cppNamespace = "mlir::transform_ext";
+}
+
+
+def RegisterMatchCallbacksOp :
+ Op<Transform_Dialect, "iree.register_match_callbacks",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{
+ Registers named structured op matcher callbacks specific for IREE to use
+ with `transform.iree.match_callback`. This should be called before first
+ `match_callback` may be executed following the transform dialect control
+ flow.
+
+ The callbacks must have a unique name and a signature compatible with
+ `MatchCallbacksRegistry::MatchCallbackFn`, which currently means
+ `DiagnosedSilenceableFailure(MatchCallbackResult &, Location,
+ const TransformState &, ValueRange)`. The callback receives a "result",
+ followed by a location at which errors should be reported, a transform
+ state at the moment of the _match_ (not registration) and a list of
+ handle values passed as operands to the `match_callback` operation.
+ It is expected to populate the "result" object with lists of payload
+ operations that will be bound to the handles produced by the
+ `match_callback` operation. The callback may fail, at which point
+ it should produce a silenceable error. The callback currently is not
+ allowed to modify the payload IR (though this may be revised in the
+ future for the purpose of communicating the properties of the IR
+ captured by the match). Therefore, it should not have a reason to
+ produce a definite error.
+ }];
+
+ let arguments = (ins);
+ let results = (outs);
+ let assemblyFormat = "attr-dict";
+ let cppNamespace = "mlir::transform_ext";
+}
+
+def MatchCallbackOp :
+ Op<Transform_Dialect, "iree.match_callback",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{
+ Performs payload IR matching using a C++ callback registered beforehand.
+ The callback is identified by name and is passed the current transform
+ state and the list of handle operands, along with information necessary
+ for error propagation. See `register_match_callbacks` for the description
+ of the callback contract.
+
+ If `failure_propagation_mode` is set to `suppress`, any silenceable errors
+ in the callback (typically, "failure to match") will be ignored and the
+ resulting handles will be associated with empty lists of payload
+ operations. Otherwise, silenceable failures are propagated.
+ }];
+
+ let arguments = (ins StrAttr:$callback_name,
+ FailurePropagationMode:$failure_propagation_mode,
+ Variadic<TransformTypeInterface>:$inputs);
+ let results = (outs Variadic<TransformTypeInterface>:$outputs);
+ let assemblyFormat = "`failures` `(` $failure_propagation_mode `)` "
+ "$callback_name `(` $inputs `)` attr-dict "
+ "`:` functional-type($inputs, $outputs)";
+ let cppNamespace = "mlir::transform_ext";
+}
+
+def TakeFirstOp :
+ Op<Transform_Dialect, "iree.take_first",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{
+ Given an arbitrary list of handles associated with potentially empty lists
+ of payload operations, produces two new handles:
+
+ - a handle pointing to the same payload operations as the first operand
+ handle with a non-empty list of payload operations;
+ - a handle pointing to the concatenated list of payload operations
+ associated with any other handle.
+
+ Note that this does not perform any deduplication.
+
+ This operation is useful to select a single target after some potentially
+ unsuccessful matches.
+ }];
+
+ let arguments = (ins Variadic<TransformTypeInterface>:$inputs);
+ let results = (outs TransformTypeInterface:$first,
+ TransformTypeInterface:$rest);
+ let assemblyFormat =
+ "$inputs attr-dict `:` functional-type($inputs, results)";
+ let cppNamespace = "mlir::transform_ext";
}
#endif // STRUCTURED_TRANSFORM_OPS_EXT
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
similarity index 92%
rename from compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 4e467d5..430d879 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -11,14 +11,15 @@
#include <cstdint>
#include <functional>
-#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/StringMap.h"
-namespace mlir::iree_compiler::IREE::transform_dialect {
+namespace mlir {
+namespace transform_ext {
//===---------------------------------------------------------------------===//
// StructuredOpMatcher and predicates.
@@ -93,7 +94,7 @@
namespace detail {
template <typename T>
using has_reset_capture_t = decltype(std::declval<T>().resetCapture());
-} // namespace detail
+} // namespace detail
/// Structured op matcher with additional predicates attachable through the
/// fluent, a.k.a. chainable, API. Note that public API must *not* accept
@@ -112,7 +113,7 @@
predicates.push_back(std::move(firstPredicate));
}
- public:
+public:
/// Matches any structured operation, i.e., operation with LinalgOp interface.
StructuredOpMatcher() {}
@@ -178,15 +179,18 @@
&operandMatcher](linalg::LinalgOp linalgOp) -> bool {
int64_t transformedPosition =
position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
- if (transformedPosition >= linalgOp.getNumDpsInputs()) return false;
+ if (transformedPosition >= linalgOp.getNumDpsInputs())
+ return false;
Operation *definingOp = linalgOp.getDpsInputOperand(transformedPosition)
->get()
.getDefiningOp();
- if (!definingOp) return optional.value;
+ if (!definingOp)
+ return optional.value;
// We MUST run the matcher at this point, even if the match is optional,
// to allow for capture.
- if (operandMatcher.match(definingOp)) return true;
+ if (operandMatcher.match(definingOp))
+ return true;
return optional.value;
});
recordNestedMatcher(operandMatcher);
@@ -246,15 +250,18 @@
&operandMatcher](linalg::LinalgOp linalgOp) -> bool {
int64_t transformedPosition =
position >= 0 ? position : linalgOp.getNumDpsInits() + position;
- if (transformedPosition >= linalgOp.getNumDpsInits()) return false;
+ if (transformedPosition >= linalgOp.getNumDpsInits())
+ return false;
Operation *definingOp = linalgOp.getDpsInitOperand(transformedPosition)
->get()
.getDefiningOp();
- if (!definingOp) return optional.value;
+ if (!definingOp)
+ return optional.value;
// We MUST run the matcher at this point, even if the match is optional,
// to allow for capture.
- if (operandMatcher.match(definingOp)) return true;
+ if (operandMatcher.match(definingOp))
+ return true;
return optional.value;
});
recordNestedMatcher(operandMatcher);
@@ -276,7 +283,8 @@
position](linalg::LinalgOp linalgOp) -> bool {
int64_t transformedPosition =
position >= 0 ? position : linalgOp->getNumResults() + position;
- if (transformedPosition >= linalgOp->getNumResults()) return false;
+ if (transformedPosition >= linalgOp->getNumResults())
+ return false;
// We MUST run the matcher at this point, even if the match is optional,
// to allow for capture.
@@ -305,10 +313,11 @@
/// for optional nested predicates from the previous application.
void resetCapture() {
captured = nullptr;
- for (const CaptureResetFn &fn : captureResetFns) fn();
+ for (const CaptureResetFn &fn : captureResetFns)
+ fn();
}
- private:
+private:
/// Informs the matcher that it has another, nested matcher. Practically,
/// records the captured value cleanup function so it runs when required.
template <typename T>
@@ -348,7 +357,7 @@
/// transform operation. Conceptually, a list of lists of payload operations to
/// be associated with each result handle.
class MatchCallbackResult {
- public:
+public:
/// Returns the number of lists of payload operations.
unsigned getNumPayloadGroups() const { return payloadGroupLengths.size(); }
@@ -380,7 +389,7 @@
addPayloadGroup(ArrayRef<Operation *>(op));
}
- private:
+private:
/// The flat list of all payload opreations. `payloadGroupLengths` can be used
/// to compute the sublist that corresponds to one nested list.
// TODO: if somebody implements such a flattened vector generically, use it.
@@ -391,7 +400,7 @@
/// A transform state extension that maintains the mapping between callback
/// names as strings usable in `match_callback` and their implementations.
class MatchCallbacksRegistry : public transform::TransformState::Extension {
- public:
+public:
using MatchCallbackFn = std::function<DiagnosedSilenceableFailure(
MatchCallbackResult &, Location, const transform::TransformState &,
ValueRange)>;
@@ -414,11 +423,12 @@
/// name, or null if it is not present in the registry.
const MatchCallbackFn *get(StringRef name) const {
auto iter = callbacks.find(name);
- if (iter == callbacks.end()) return nullptr;
+ if (iter == callbacks.end())
+ return nullptr;
return &iter->getValue();
}
- private:
+private:
llvm::StringMap<MatchCallbackFn> callbacks;
};
@@ -432,10 +442,10 @@
///
/// where trailing and leading are elementwise operations whose presence is
/// optional. Each matcher will capture the corresponding operation.
-void makeGPUReductionMatcher(StructuredOpMatcher &reduction,
- StructuredOpMatcher &fill,
- StructuredOpMatcher &leading,
- StructuredOpMatcher &trailing);
+void makeReductionMatcher(StructuredOpMatcher &reduction,
+ StructuredOpMatcher &fill,
+ StructuredOpMatcher &leading,
+ StructuredOpMatcher &trailing);
/// Creates a group of matchers for:
///
@@ -447,13 +457,14 @@
/// where trailing and leading are elementwise operations whose presence is
/// optional, and with subsetting ops potentially present on the operand use-def
/// chains.
-void makeGPUSplitReductionMatcher(StructuredOpMatcher ¶llel_reduction,
- StructuredOpMatcher &combiner_reduction,
- StructuredOpMatcher ¶llel_fill,
- StructuredOpMatcher &original_fill,
- StructuredOpMatcher &leading,
- StructuredOpMatcher &trailing);
+void makeSplitReductionMatcher(StructuredOpMatcher ¶llel_reduction,
+ StructuredOpMatcher &combiner_reduction,
+ StructuredOpMatcher ¶llel_fill,
+ StructuredOpMatcher &original_fill,
+ StructuredOpMatcher &leading,
+ StructuredOpMatcher &trailing);
-} // namespace mlir::iree_compiler::IREE::transform_dialect
+} // namespace transform_ext
+} // namespace mlir
-#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_
+#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index 4a4594a..576e626 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -62,7 +62,7 @@
DialectRegistry registry;
registry.addExtensions<
mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension,
- transform_ext::StructuredTransformOpsExtension>();
+ mlir::transform_ext::StructuredTransformOpsExtension>();
ctx->appendDialectRegistry(registry);
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index dc013b2..2afd477 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -12,6 +12,7 @@
#include "iree-dialects/Transforms/Listener.h"
#include "iree-dialects/Transforms/ListenerCSE.h"
#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
@@ -317,7 +318,7 @@
// StructuredTransformOpsExtension
//===----------------------------------------------------------------------===//
-transform_ext::StructuredTransformOpsExtension::
+mlir::transform_ext::StructuredTransformOpsExtension::
StructuredTransformOpsExtension() {
registerTransformOps<
#define GET_OP_LIST
@@ -1115,3 +1116,244 @@
// TODO: make composable...
return DiagnosedSilenceableFailure::success();
}
+
+//===---------------------------------------------------------------------===//
+// MatchCallbackOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform_ext::MatchCallbackOp::apply(
+ mlir::transform::TransformResults &results,
+ mlir::transform::TransformState &state) {
+ auto setEmptyResults = [&results, this] {
+ for (OpResult value : getResults()) {
+ results.set(value, {});
+ }
+ };
+ auto errorOut = [this, &setEmptyResults] {
+ setEmptyResults();
+ return emitSilenceableError();
+ };
+
+ auto *registry = state.getExtension<transform_ext::MatchCallbacksRegistry>();
+ if (!registry)
+ return errorOut() << "match registry not available";
+
+ const transform_ext::MatchCallbacksRegistry::MatchCallbackFn *callback =
+ registry->get(getCallbackName());
+ if (!callback) {
+ return errorOut() << "callback '" << getCallbackName()
+ << "' not found in the registry";
+ }
+
+ MatchCallbackResult result;
+ DiagnosedSilenceableFailure status =
+ (*callback)(result, getLoc(), state, getInputs());
+ if (!status.succeeded()) {
+ setEmptyResults();
+ if (status.isDefiniteFailure())
+ return status;
+ if (getFailurePropagationMode() ==
+ mlir::transform::FailurePropagationMode::Propagate) {
+ return emitSilenceableError() << "failed to match";
+ } else {
+ return DiagnosedSilenceableFailure::success();
+ }
+ }
+ if (getNumResults() != result.getNumPayloadGroups()) {
+ return errorOut()
+ << "callback produced a different number of handles than expected ( "
+ << result.getNumPayloadGroups() << " vs " << getNumResults() << " )";
+ }
+
+ for (OpResult value : getResults()) {
+ results.set(value, result.getPayloadGroup(value.getResultNumber()));
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform_ext::MatchCallbackOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ mlir::transform::onlyReadsHandle(getInputs(), effects);
+ mlir::transform::producesHandle(getOutputs(), effects);
+ // TODO: it doesn't really modify the payload, we need a separate resource for
+ // this mapping.
+ mlir::transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
+// RegisterMatchCallbacksOp
+//===---------------------------------------------------------------------===//
+
+/// Match callback for "_test_match_callback" hook. Matches any payload
+/// operations associated with operand handles unless they have the
+/// "test.iree_transform_do_not_match" attribute, in which case produces a
+/// silenceable failure.
+static DiagnosedSilenceableFailure
+testMatchCallbackCallback(transform_ext::MatchCallbackResult &res, Location loc,
+ const mlir::transform::TransformState &state,
+ ValueRange handles) {
+ bool hadFailures = false;
+ for (Value handle : handles) {
+ if (llvm::any_of(state.getPayloadOps(handle), [](Operation *op) {
+ return op->hasAttr("test.iree_transform_do_not_match");
+ })) {
+ res.addPayloadGroup(ArrayRef<Operation *>());
+ hadFailures = true;
+ } else {
+ res.addPayloadGroup(state.getPayloadOps(handle));
+ }
+ }
+ if (hadFailures)
+ return emitSilenceableFailure(loc) << "failed to match";
+ return DiagnosedSilenceableFailure::success();
+}
+
+/// Match callback for a reduction with optional leading and trailing
+/// elementwise operations. Matches *the first* occurrence of such a reduction
+/// within an op associated with the given handle.
+///
+/// Input handles:
+///
+/// - container op, must be associated with one operation.
+///
+/// Output handles:
+///
+/// - leading elementwise op, if any;
+/// - the "fill" op preceding the reduction;
+/// - reduction op;
+/// - trailing elementwise op, if any.
+static DiagnosedSilenceableFailure
+reductionCallback(transform_ext::MatchCallbackResult &res, Location loc,
+ const mlir::transform::TransformState &state,
+ ValueRange handles) {
+ if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) {
+ return emitSilenceableFailure(loc)
+ << "expected one handle to one operation";
+ }
+
+ transform_ext::StructuredOpMatcher pattern, fill, leadingEltwise,
+ trailingEltwise;
+ makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
+
+ // TODO: need a mechanism for this to go around the entire IR,
+ // potentially with list matches for each group.
+ Operation *root = state.getPayloadOps(handles[0])[0];
+ WalkResult walkResult = root->walk([&](Operation *op) {
+ pattern.resetCapture();
+ if (!matchPattern(op, pattern))
+ return WalkResult::advance();
+
+ res.addPotentiallyEmptyPayloadGroup(leadingEltwise.getCaptured());
+ res.addPayloadGroup({fill.getCaptured()});
+ res.addPayloadGroup({pattern.getCaptured()});
+ res.addPotentiallyEmptyPayloadGroup(trailingEltwise.getCaptured());
+ return WalkResult::interrupt();
+ });
+
+ if (walkResult.wasInterrupted())
+ return DiagnosedSilenceableFailure::success();
+ return emitSilenceableFailure(loc) << "failed to match";
+}
+
+/// Match callback for a reduction after splitting with optional leading and
+/// trailing elementwise operations. Matches *the first* occurrence of such a
+/// reduction within an op associated with the given handle.
+///
+/// Input handles:
+///
+/// - container op, must be associated with one operation.
+///
+/// Output handles:
+///
+/// - leading elementwise op, if any;
+/// - the "fill" op preceding the original reduction;
+/// - the "fill" op preceding the split, more parallel reduction;
+/// - the split, more parallel reduction op;
+/// - reduction op;
+/// - trailing elementwise op, if any.
+static DiagnosedSilenceableFailure
+splitReductionCallback(transform_ext::MatchCallbackResult &res, Location loc,
+ const mlir::transform::TransformState &state,
+ ValueRange handles) {
+ if (handles.size() != 1 || state.getPayloadOps(handles[0]).size() != 1) {
+ return emitSilenceableFailure(loc)
+ << "expected one handle to one operation";
+ }
+
+ transform_ext::StructuredOpMatcher parallel_reduction, combiner_reduction,
+ parallel_fill, original_fill, leading, trailing;
+ makeSplitReductionMatcher(parallel_reduction, combiner_reduction,
+ parallel_fill, original_fill, leading, trailing);
+
+ // TODO: need a mechanism for this to go around the entire IR,
+ // potentially with list matches for each group.
+ Operation *root = state.getPayloadOps(handles[0])[0];
+ WalkResult walkResult = root->walk([&](Operation *op) {
+ combiner_reduction.resetCapture();
+ if (!matchPattern(op, combiner_reduction))
+ return WalkResult::advance();
+
+ res.addPotentiallyEmptyPayloadGroup(leading.getCaptured());
+ res.addPayloadGroup({original_fill.getCaptured()});
+ res.addPayloadGroup({parallel_fill.getCaptured()});
+ res.addPayloadGroup({parallel_reduction.getCaptured()});
+ res.addPayloadGroup({combiner_reduction.getCaptured()});
+ res.addPotentiallyEmptyPayloadGroup(trailing.getCaptured());
+ return WalkResult::interrupt();
+ });
+
+ if (walkResult.wasInterrupted())
+ return DiagnosedSilenceableFailure::success();
+ return emitSilenceableFailure(loc) << "failed to match";
+}
+
+DiagnosedSilenceableFailure transform_ext::RegisterMatchCallbacksOp::apply(
+ mlir::transform::TransformResults &results,
+ mlir::transform::TransformState &state) {
+ auto ®istry = state.addExtension<transform_ext::MatchCallbacksRegistry>();
+ registry.registerCallback("_test_match_callback", testMatchCallbackCallback);
+ registry.registerCallback("reduction", reductionCallback);
+ registry.registerCallback("split_reduction", splitReductionCallback);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform_ext::RegisterMatchCallbacksOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ // TODO: it doesn't really modify the payload, we need a separate resource for
+ // this mapping.
+ mlir::transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
+// TakeFirstOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform_ext::TakeFirstOp::apply(mlir::transform::TransformResults &results,
+ mlir::transform::TransformState &state) {
+ SmallVector<Operation *> concatenated;
+ bool found = false;
+ for (Value handle : getInputs()) {
+ ArrayRef<Operation *> payloads = state.getPayloadOps(handle);
+ if (payloads.empty())
+ continue;
+ if (!found) {
+ results.set(getFirst().cast<OpResult>(), payloads);
+ found = true;
+ } else {
+ llvm::append_range(concatenated, payloads);
+ }
+ }
+
+ if (!found)
+ results.set(getFirst().cast<OpResult>(), {});
+ results.set(getRest().cast<OpResult>(), concatenated);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform_ext::TakeFirstOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ mlir::transform::onlyReadsHandle(getInputs(), effects);
+ mlir::transform::producesHandle(getFirst(), effects);
+ mlir::transform::producesHandle(getRest(), effects);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
index 8047e74..00eda6f 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
Listener.cpp
ListenerCSE.cpp
ListenerGreedyPatternRewriteDriver.cpp
+ TransformMatchers.cpp
LINK_LIBS PRIVATE
# TODO: break dialect dependency by implementing the transformation separately
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
similarity index 75%
rename from compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.cpp
rename to llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 4085a16..eda5d55 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -4,22 +4,21 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Codegen/Common/TransformExtensions/TransformMatchers.h"
+#include "iree-dialects/Transforms/TransformMatchers.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
using namespace mlir;
-using namespace mlir::iree_compiler;
-using namespace mlir::iree_compiler::IREE;
//===---------------------------------------------------------------------===//
// StructuredOpMatcher and friends.
//===---------------------------------------------------------------------===//
-bool transform_dialect::StructuredOpMatcher::match(Operation *op) {
+bool transform_ext::StructuredOpMatcher::match(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
- if (!linalgOp) return false;
+ if (!linalgOp)
+ return false;
if (!llvm::all_of(predicates, [linalgOp](const PredicateFn &fn) {
return fn(linalgOp);
@@ -31,21 +30,22 @@
return true;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(int64_t dimension, ShapeKind kind) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(int64_t dimension, ShapeKind kind) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges();
int64_t transformedDimension =
dimension >= 0 ? dimension : shape.size() + dimension;
- if (transformedDimension >= shape.size()) return false;
+ if (transformedDimension >= shape.size())
+ return false;
return ShapedType::isDynamic(shape[transformedDimension]) ^
(kind == ShapeKind::Static);
});
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges();
return llvm::all_of(shape, [=](int64_t dimension) {
@@ -55,14 +55,15 @@
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(int64_t dimension,
- utils::IteratorType kind) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(int64_t dimension,
+ utils::IteratorType kind) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
unsigned rank = linalgOp.getNumLoops();
int64_t transformedDimension =
dimension >= 0 ? dimension : rank + dimension;
- if (transformedDimension >= rank) return false;
+ if (transformedDimension >= rank)
+ return false;
utils::IteratorType iteratorKind =
linalgOp.getIteratorTypesArray()[transformedDimension];
@@ -70,9 +71,8 @@
});
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(AllDims tag,
- utils::IteratorType kind) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
return llvm::all_of(
linalgOp.getIteratorTypesArray(),
@@ -81,14 +81,15 @@
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::dim(int64_t dimension,
- DivisibleBy divisibleBy) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(int64_t dimension,
+ DivisibleBy divisibleBy) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
unsigned rank = linalgOp.getNumLoops();
int64_t transformedDimension =
dimension >= 0 ? dimension : rank + dimension;
- if (transformedDimension >= rank) return false;
+ if (transformedDimension >= rank)
+ return false;
int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension];
return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0);
@@ -96,8 +97,8 @@
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::input(AllOperands tag, IsPermutation) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
// all_of with a lambda requires const-casting dance, so using a loop.
for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
@@ -151,7 +152,8 @@
auto it = llvm::find_if(range, [&](BlockArgument bbarg) {
return loop.getTiedOpOperand(bbarg) != &use;
});
- if (it == range.end()) return user;
+ if (it == range.end())
+ return user;
val = *it;
continue;
}
@@ -164,9 +166,8 @@
} while (true);
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::input(int64_t position,
- SubsetOf subset) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(int64_t position, SubsetOf subset) {
// Implementation note: SubsetOf must *not* be passed by-reference because
// it is typically a temporary constructed within the argument of a function
// call, but it will be used in the lambda that outlives the temporary. The
@@ -174,7 +175,8 @@
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
int64_t transformedPosition =
position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
- if (transformedPosition >= linalgOp.getNumDpsInputs()) return false;
+ if (transformedPosition >= linalgOp.getNumDpsInputs())
+ return false;
Operation *producer = traverseSubsetsBackwards(
linalgOp.getDpsInputOperand(transformedPosition)->get());
@@ -184,8 +186,8 @@
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::output(AllOperands tag, IsPermutation) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
if (!linalgOp.getMatchingIndexingMap(operand).isPermutation())
@@ -196,13 +198,14 @@
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::output(int64_t position,
- ElementTypeBitWidth width) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(int64_t position,
+ ElementTypeBitWidth width) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
int64_t updatedPosition =
position >= 0 ? position : linalgOp.getNumDpsInits() + position;
- if (updatedPosition >= linalgOp.getNumDpsInits()) return false;
+ if (updatedPosition >= linalgOp.getNumDpsInits())
+ return false;
auto shapedType = linalgOp.getDpsInitOperand(updatedPosition)
->get()
.getType()
@@ -213,13 +216,14 @@
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::output(int64_t position,
- SingleCombinerReduction tag) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(int64_t position,
+ SingleCombinerReduction tag) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
int64_t updatedPosition =
position >= 0 ? position : linalgOp.getNumDpsInits() + position;
- if (updatedPosition >= linalgOp.getNumDpsInits()) return false;
+ if (updatedPosition >= linalgOp.getNumDpsInits())
+ return false;
SmallVector<Operation *> combinerOps;
return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition,
combinerOps) &&
@@ -228,9 +232,8 @@
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::output(int64_t position,
- SubsetOf subset) {
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::output(int64_t position, SubsetOf subset) {
// Implementation note: SubsetOf must *not* be passed by-reference because
// it is typically a temporary constructed within the argument of a function
// call, but it will be used in the lambda that outlives the temporary. The
@@ -238,7 +241,8 @@
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
int64_t transformedPosition =
position >= 0 ? position : linalgOp.getNumDpsInputs() + position;
- if (transformedPosition >= linalgOp.getNumDpsInputs()) return false;
+ if (transformedPosition >= linalgOp.getNumDpsInputs())
+ return false;
Operation *producer = traverseSubsetsBackwards(
linalgOp.getDpsInitOperand(transformedPosition)->get());
@@ -248,14 +252,13 @@
return *this;
}
-transform_dialect::StructuredOpMatcher &
-transform_dialect::StructuredOpMatcher::result(int64_t position, HasAnyUse tag,
- SubsetOf subset,
- OptionalMatch optional) {
+transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::result(
+ int64_t position, HasAnyUse tag, SubsetOf subset, OptionalMatch optional) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
int64_t transformedPosition =
position >= 0 ? position : linalgOp->getNumResults() + position;
- if (transformedPosition >= linalgOp->getNumResults()) return false;
+ if (transformedPosition >= linalgOp->getNumResults())
+ return false;
Operation *user =
traverseSubsetsForwardAnyUse(linalgOp->getResult(transformedPosition));
@@ -268,8 +271,8 @@
// MatchCallbackResult.
//===---------------------------------------------------------------------===//
-ArrayRef<Operation *> transform_dialect::MatchCallbackResult::getPayloadGroup(
- unsigned position) const {
+ArrayRef<Operation *>
+transform_ext::MatchCallbackResult::getPayloadGroup(unsigned position) const {
assert(position < payloadGroupLengths.size());
int64_t start = 0;
for (unsigned i = 0; i < position; ++i) {
@@ -285,11 +288,11 @@
static constexpr unsigned kCudaWarpSize = 32;
-void transform_dialect::makeGPUReductionMatcher(
- transform_dialect::StructuredOpMatcher &reduction,
- transform_dialect::StructuredOpMatcher &fill,
- transform_dialect::StructuredOpMatcher &leading,
- transform_dialect::StructuredOpMatcher &trailing) {
+void transform_ext::makeReductionMatcher(
+ transform_ext::StructuredOpMatcher &reduction,
+ transform_ext::StructuredOpMatcher &fill,
+ transform_ext::StructuredOpMatcher &leading,
+ transform_ext::StructuredOpMatcher &trailing) {
fill = m_StructuredOp<linalg::FillOp>();
trailing = m_StructuredOp<linalg::GenericOp>()
.input(AllOperands(), IsPermutation())
@@ -314,13 +317,13 @@
.result(0, HasAnyUse(), trailing, OptionalMatch());
}
-void transform_dialect::makeGPUSplitReductionMatcher(
- transform_dialect::StructuredOpMatcher ¶llel_reduction,
- transform_dialect::StructuredOpMatcher &combiner_reduction,
- transform_dialect::StructuredOpMatcher ¶llel_fill,
- transform_dialect::StructuredOpMatcher &original_fill,
- transform_dialect::StructuredOpMatcher &leading,
- transform_dialect::StructuredOpMatcher &trailing) {
+void transform_ext::makeSplitReductionMatcher(
+ transform_ext::StructuredOpMatcher ¶llel_reduction,
+ transform_ext::StructuredOpMatcher &combiner_reduction,
+ transform_ext::StructuredOpMatcher ¶llel_fill,
+ transform_ext::StructuredOpMatcher &original_fill,
+ transform_ext::StructuredOpMatcher &leading,
+ transform_ext::StructuredOpMatcher &trailing) {
original_fill = m_StructuredOp<linalg::FillOp>();
parallel_fill = m_StructuredOp<linalg::FillOp>();
trailing = m_StructuredOp<linalg::GenericOp>()