Emit better reduction schedule from JIT (#11548)
Emit the reduction schedule splitting the reduction directly to
"scf.for" rather than to a generic over inputs with expanded shape. This
allows us to ensure the generation of vector loads and stores, and to
target the vector size appropriate to the input. Also plug in the
matcher into the schedule instead of the fragile reliance on the order
of operations.
This is a rebase of #11498 on main.
Co-authored-by: Oleksandr "Alex" Zinenko <zinenko@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
index 99dcc27..0984be7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
@@ -126,81 +126,85 @@
/// appended in order.
// TODO: apply forwarding pattern.
template <typename TilingTransformOp, typename TileOrNumThreadSpec>
-static Value buildTileAndFuseAndDistributeImpl(
- ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> tileSizesOrNumThreads, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+static iree_compiler::TileAndFuseAndDistributeResult
+buildTileAndFuseAndDistributeImpl(ImplicitLocOpBuilder &b, Value rootH,
+ ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizesOrNumThreads,
+ ArrayAttr threadDimMapping) {
+ iree_compiler::TileAndFuseAndDistributeResult result;
auto tileToForeachOp = b.create<TilingTransformOp>(
rootH, tileSizesOrNumThreads, TileOrNumThreadSpec(), threadDimMapping);
- Value foreachThreadH = tileToForeachOp.getForeachThreadOp();
- // Batch fusion.
- Value mergedOpsH = b.create<MergeHandlesOp>(opsHToFuse, /*deduplicate=*/true);
- b.create<FuseIntoContainingOp>(mergedOpsH, foreachThreadH);
- assert(!resultingFusedOpsHandles && "Handle needs unpacking");
- return foreachThreadH;
+ result.foreachThreadH = tileToForeachOp.getForeachThreadOp();
+ result.tiledOpH = tileToForeachOp.getTiledOp();
+
+ // Batch fusion if requested.
+ if (opsHToFuse.size() > 1) {
+ Value mergedOpsH =
+ b.create<MergeHandlesOp>(opsHToFuse, /*deduplicate=*/true);
+ b.create<FuseIntoContainingOp>(mergedOpsH, result.foreachThreadH);
+ } else if (opsHToFuse.size() == 1) {
+ Value fusedH = b.create<FuseIntoContainingOp>(opsHToFuse.front(),
+ result.foreachThreadH);
+ result.resultingFusedOpsHandles.push_back(fusedH);
+ }
+ return result;
}
// TODO: if someone knows how to properly export templates go for it ..
// sigh.
template <typename TilingTransformOp>
-static Value buildTileFuseDistWithTileSizes(
- ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+static iree_compiler::TileAndFuseAndDistributeResult
+buildTileFuseDistWithTileSizes(ImplicitLocOpBuilder &b, Value rootH,
+ ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> tileSizes,
+ ArrayAttr threadDimMapping) {
return buildTileAndFuseAndDistributeImpl<TilingTransformOp,
transform::TileSizesSpec>(
- b, rootH, opsHToFuse, tileSizes, threadDimMapping,
- resultingFusedOpsHandles);
+ b, rootH, opsHToFuse, tileSizes, threadDimMapping);
}
-Value mlir::iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
+iree_compiler::TileAndFuseAndDistributeResult
+mlir::iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
return buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
- b, rootH, opsHToFuse, tileSizes, threadDimMapping,
- resultingFusedOpsHandles);
+ b, rootH, opsHToFuse, tileSizes, threadDimMapping);
}
-Value mlir::iree_compiler::
+iree_compiler::TileAndFuseAndDistributeResult mlir::iree_compiler::
buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
return buildTileFuseDistWithTileSizes<
- TileToForeachThreadAndWorkgroupCountRegionOp>(b, rootH, opsHToFuse,
- tileSizes, threadDimMapping,
- resultingFusedOpsHandles);
+ TileToForeachThreadAndWorkgroupCountRegionOp>(
+ b, rootH, opsHToFuse, tileSizes, threadDimMapping);
}
/// Call buildTileAndFuseAndDistributeImpl with ArrayRef<int64_t> numThreads.
// TODO: if someone knows how to properly export templates go for it ..
// sigh.
template <typename TilingTransformOp>
-static Value buildTileFuseDistWithNumThreads(
- ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles) {
+static iree_compiler::TileAndFuseAndDistributeResult
+buildTileFuseDistWithNumThreads(ImplicitLocOpBuilder &b, Value rootH,
+ ValueRange opsHToFuse,
+ ArrayRef<OpFoldResult> numThreads,
+ ArrayAttr threadDimMapping) {
return buildTileAndFuseAndDistributeImpl<TilingTransformOp,
transform::NumThreadsSpec>(
- b, rootH, opsHToFuse, numThreads, threadDimMapping,
- resultingFusedOpsHandles);
+ b, rootH, opsHToFuse, numThreads, threadDimMapping);
}
-Value mlir::iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
+iree_compiler::TileAndFuseAndDistributeResult
+mlir::iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles) {
- return buildTileFuseDistWithTileSizes<TileToForeachThreadOp>(
- b, rootH, opsHToFuse, tileSizes, threadDimMapping,
- resultingFusedOpsHandles);
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
+ return buildTileFuseDistWithNumThreads<TileToForeachThreadOp>(
+ b, rootH, opsHToFuse, tileSizes, threadDimMapping);
}
-Value mlir::iree_compiler::
+iree_compiler::TileAndFuseAndDistributeResult mlir::iree_compiler::
buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles) {
- return buildTileFuseDistWithTileSizes<
- TileToForeachThreadAndWorkgroupCountRegionOp>(b, rootH, opsHToFuse,
- tileSizes, threadDimMapping,
- resultingFusedOpsHandles);
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping) {
+ return buildTileFuseDistWithNumThreads<
+ TileToForeachThreadAndWorkgroupCountRegionOp>(
+ b, rootH, opsHToFuse, tileSizes, threadDimMapping);
}
/// Apply patterns and vectorize (for now always applies rank-reduction).
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
index c7ca9fb..5c72d68 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
@@ -20,13 +20,26 @@
/// Prints `handles` in order. Prints the whole IR if `handles` is empty.
static void buildPrint(ImplicitLocOpBuilder &b, ValueRange handles = {});
+/// Result of the combined transform performing tiling, fusion and distribution
+/// to parallel constructs.
+struct TileAndFuseAndDistributeResult {
+ /// Outer `scf.foreach_thread` loop containing the tiled and fused operations.
+ Value foreachThreadH;
+ /// Handles to fused operations other than the final consumer operation. May
+ /// be empty if fusion was not performed iteratively.
+ // TODO: support returning handles from `fuse_into_containing_op` and remove
+ // the restriction above.
+ SmallVector<Value> resultingFusedOpsHandles;
+ /// Handle to the tiled final consumer operation.
+ Value tiledOpH;
+};
+
/// Performs the following transformations:
/// 1. Tiles `rootH` to scf.foreach_thread to with `tileSizesOrNumThreads`
/// according to whether spec is a TileSizesSpec or a NumThreadsSpec.
/// 2. Maps the resulting scf.foreach_thread to threads according to
/// `threadDimMapping`.
/// 3. Iterates over `opsHToFuse` in order and fuses into the containing op.
-/// Returns a handle to the resulting scf.foreach_thread.
///
/// Fusion operates in batch mode: a single fusion command is issued and a
/// topological sort is automatically computed by the fusion.
@@ -37,28 +50,23 @@
/// providing the fusion order and has interleaved canonicalization / cse /
/// enabling transform will be introduced and may result in better fusions.
///
-/// If `resultingFusedOpsHandles` is a non-null pointer, the fused operation are
-/// appended in order.
-///
// TODO: if someone knows how to properly export templates go for it .. sigh.
-Value buildTileFuseDistToForeachThreadWithTileSizes(
+TileAndFuseAndDistributeResult buildTileFuseDistToForeachThreadWithTileSizes(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
-Value buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping);
+TileAndFuseAndDistributeResult
+buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+ ArrayRef<OpFoldResult> tileSizes, ArrayAttr threadDimMapping);
/// See buildTileFuseDistWithTileSizes.
-Value buildTileFuseDistToForeachThreadWithNumThreads(
+TileAndFuseAndDistributeResult buildTileFuseDistToForeachThreadWithNumThreads(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
-Value buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
+ ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping);
+TileAndFuseAndDistributeResult
+buildTileFuseDistToForeachThreadAndWorgroupCountWithNumThreads(
ImplicitLocOpBuilder &b, Value rootH, ValueRange opsHToFuse,
- ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping,
- SmallVectorImpl<Value> *resultingFusedOpsHandles = nullptr);
+ ArrayRef<OpFoldResult> numThreads, ArrayAttr threadDimMapping);
/// Apply patterns and vectorize (for now always applies rank-reduction).
/// Takes a handle to a func.func and returns an updated handle to a
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
index c4e119d..50c1e32 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
@@ -75,9 +75,12 @@
using transform_ext::AllDims;
using transform_ext::IsPermutation;
using transform_ext::m_StructuredOp;
+using transform_ext::MatchCallbackOp;
using transform_ext::NumEqualsTo;
+using transform_ext::RegisterMatchCallbacksOp;
using transform_ext::ShapeKind;
using transform_ext::StructuredOpMatcher;
+using transform_ext::TakeFirstOp;
/// Matches `args` within `targetH` and unpacks a number of handles `N`.
/// Assumes there are exactly `N` matched ops (but could be relaxed).
@@ -99,105 +102,131 @@
// user-friendliness.
//===----------------------------------------------------------------------===//
-// TODO: consider passing a problem-specific struct to control information.
-static Value createReductionStrategyThreadDistributionPart(
- ImplicitLocOpBuilder &b, Value variantH, ArrayRef<int64_t> tileSizes1Fill,
- ArrayRef<int64_t> tileSizes1Generic, bool hasLeadingEltwise,
- bool hasTrailingEltwise) {
- // TODO: Relying on ordering is brittle, harden this.
- Value matchedH = b.create<MatchOp>(
- variantH, ArrayRef<StringRef>{linalg::GenericOp::getOperationName(),
- linalg::FillOp::getOperationName()});
- auto split = b.create<SplitHandlesOp>(
- matchedH,
- /*numResultHandles=*/4 + hasLeadingEltwise + hasTrailingEltwise);
- Value firstFusionRootH = split.getResults()[1 + hasLeadingEltwise];
- SmallVector<Value> firstFusionGroupHs =
- split.getResults().take_front(1 + hasLeadingEltwise);
- Value secondFusionRootH = split.getResults().back();
- SmallVector<Value> secondFusionGroupHs =
- split.getResults().drop_front(2 + hasLeadingEltwise).drop_back();
-
- auto z = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
- ::mlir::gpu::Threads::DimZ);
- auto y = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
- ::mlir::gpu::Threads::DimY);
-
- // clang-format off
- iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(b,
- /*rootH=*/secondFusionRootH,
- /*opsHToFuse=*/secondFusionGroupHs,
- /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Fill)),
- /*threadDimMapping=*/b.getArrayAttr({z}));
- iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(b,
- /*rootH=*/firstFusionRootH,
- /*opsHToFuse=*/firstFusionGroupHs,
- /*tileSizes=*/getAsOpFoldResult(b.getI64ArrayAttr(tileSizes1Generic)),
- /*threadDimMapping=*/b.getArrayAttr({z,y}));
- // clang-format on
- return variantH;
-}
-
/// Structure to hold the parameters related to GPU reduction strategy.
struct GPUReductionStrategyInfos {
std::array<int64_t, 3> workgroupSize;
SmallVector<int64_t> workgroupTileSizes;
SmallVector<int64_t> fillSecondTileSizes;
SmallVector<int64_t> genericSecondTileSizes;
- bool hasLeadingEltwise;
- bool hasTrailingEltwise;
+ int64_t reductionDimensionSize;
};
-/// Returns a triple of handles: the leading elementwise operation, the
-/// reduction operation and the fusion root. The leading elementwise and the
-/// fusion root may be null. If the fusion root is null, the reduction operation
-/// should be used as fusion root instead.
-// TODO: consider passing a problem-specific struct to control information.
-static std::tuple<Value, Value, Value>
-createMatchReductionBlockDistributionHandles(ImplicitLocOpBuilder &b,
- Value variantH,
- bool hasLeadingEltwise,
- bool hasTrailingEltwise) {
- Value originalGenericH =
- b.create<MatchOp>(variantH, linalg::GenericOp::getOperationName());
- auto op = b.create<SplitHandlesOp>(
- originalGenericH,
- /*numResultHandles=*/1 + hasLeadingEltwise + hasTrailingEltwise);
- return std::make_tuple(hasLeadingEltwise ? op.getResults().front() : Value(),
- op.getResults().drop_front(hasLeadingEltwise).front(),
- hasTrailingEltwise ? op.getResults().back() : Value());
+static std::pair<Value, Value> createReductionStrategyBlockDistribution(
+ ImplicitLocOpBuilder &b, Value maybeLeadingH, Value fillH, Value reductionH,
+ Value maybeTrailingH) {
+ auto pdlOperation = pdl::OperationType::get(b.getContext());
+ auto fusionTargetSelector = b.create<TakeFirstOp>(
+ pdlOperation, pdlOperation, ArrayRef<Value>{maybeTrailingH, reductionH});
+ Value fusionTargetH = fusionTargetSelector.getFirst();
+ Value fusionGroupH = fusionTargetSelector.getRest();
+ auto blockX = mlir::gpu::GPUBlockMappingAttr::get(b.getContext(),
+ mlir::gpu::Blocks::DimX);
+ iree_compiler::TileAndFuseAndDistributeResult tileResult = iree_compiler::
+ buildTileFuseDistToForeachThreadAndWorgroupCountWithTileSizes(
+ b, fusionTargetH, fusionGroupH,
+ getAsOpFoldResult(b.getI64ArrayAttr({1})), b.getArrayAttr(blockX));
+ Value foreachThreadH =
+ b.create<FuseIntoContainingOp>(fillH, tileResult.foreachThreadH);
+ foreachThreadH =
+ b.create<FuseIntoContainingOp>(maybeLeadingH, foreachThreadH);
+ auto gridReductionSelector = b.create<TakeFirstOp>(
+ pdlOperation, pdlOperation,
+ ArrayRef<Value>(
+ {tileResult.resultingFusedOpsHandles.front(), tileResult.tiledOpH}));
+
+ return std::make_pair(gridReductionSelector.getFirst(),
+ gridReductionSelector.getRest());
}
-// TODO: generalize and automate over and over.
-// TODO: significantly shrink this down.
-// TODO: consider passing a problem-specific struct to control information.
+static void createReductionStrategyThreadDistribution(
+ ImplicitLocOpBuilder &b, Value gridReductionH, Value maybeTiledTrailingH,
+ int64_t reductionDimensionSize) {
+ // Select tile sizes. Perfectly tile by:
+ // - 128 to obtain 32 threads working on vector<4xf32> when possible;
+ // - 64 to obtain 32 threads working on vector<2xf32> when possible;
+ // - 32 otherwise.
+ // TODO: refine sizes based on the bitwidth of the elemental type.
+ int64_t firstReductionSize = iree_compiler::kCudaWarpSize;
+ int64_t vectorTileSize = 1;
+ if (reductionDimensionSize % (4 * iree_compiler::kCudaWarpSize) == 0) {
+ firstReductionSize = 4 * iree_compiler::kCudaWarpSize;
+ vectorTileSize = 4;
+ } else if (reductionDimensionSize % (2 * iree_compiler::kCudaWarpSize) == 0) {
+ firstReductionSize = 2 * iree_compiler::kCudaWarpSize;
+ vectorTileSize = 2;
+ }
+
+ auto pdlOperation = pdl::OperationType::get(b.getContext());
+ auto threadX = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
+ mlir::gpu::Threads::DimX);
+ auto threadY = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
+ mlir::gpu::Threads::DimY);
+
+ // Split the reduction into a parallel and combiner part, then tile the
+ // parallel part and map it to a full warp so it works on vectors.
+ auto tileReduction = b.create<transform::TileReductionUsingScfOp>(
+ pdlOperation, pdlOperation, pdlOperation, pdlOperation, gridReductionH,
+ b.getI64ArrayAttr({0, firstReductionSize}));
+ Value blockParallelFillH = tileReduction.getFillOp();
+ Value blockParallelOpH = tileReduction.getSplitLinalgOp();
+ Value blockCombinerOpH = tileReduction.getCombiningLinalgOp();
+ iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
+ b, blockParallelOpH, {},
+ getAsOpFoldResult(b.getI64ArrayAttr({0, iree_compiler::kCudaWarpSize})),
+ b.getArrayAttr(threadX));
+
+ // Tile the fill so it maps to vectors.
+ // TODO: fuse once the support is available
+ // (https://reviews.llvm.org/D139844).
+ iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
+ b, blockParallelFillH, {},
+ getAsOpFoldResult(b.getI64ArrayAttr({0, vectorTileSize})),
+ b.getArrayAttr(threadX));
+
+ // Map the combiner reduction to one thread along y so it can be mapped
+ // further via predication. Fuse it into the trailing elementwise if present.
+ auto selector = b.create<TakeFirstOp>(
+ pdlOperation, pdlOperation,
+ ArrayRef<Value>({maybeTiledTrailingH, blockCombinerOpH}));
+ Value fusionRootH = selector.getFirst();
+ Value fusionGroupH = selector.getRest();
+ iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
+ b, fusionRootH, fusionGroupH, getAsOpFoldResult(b.getI64ArrayAttr({1})),
+ b.getArrayAttr(threadY));
+}
+
+/// Builds the transform IR tiling reductions for CUDA targets. Supports
+/// reductions in the last dimension with static shape divisible by 32 (CUDA
+/// warp size), with optional leading and trailing elementwise operations.
static void createReductionCudaStrategy(
ImplicitLocOpBuilder &b, Value variantH,
const GPUReductionStrategyInfos &infos) {
- // Step 0. Match the ops.
- Value originalFillH =
- b.create<MatchOp>(variantH, linalg::FillOp::getOperationName());
- auto [leadingH, reductionH, fusionRootH] =
- createMatchReductionBlockDistributionHandles(
- b, variantH, infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+ // Step 1. Call the matcher. Note that this is the same matcher as used to
+ // trigger this compilation path, so it must always apply.
+ b.create<RegisterMatchCallbacksOp>();
+ SmallVector<Type> matchedTypes(4, pdl::OperationType::get(b.getContext()));
+ auto match = b.create<MatchCallbackOp>(
+ matchedTypes, "reduction", transform::FailurePropagationMode::Propagate,
+ variantH);
+ Value maybeLeadingH = match.getResult(0);
+ Value fillH = match.getResult(1);
+ Value reductionH = match.getResult(2);
+ Value maybeTrailingH = match.getResult(3);
- // Step 1: Distribute to blocks using the current IREE lowering config.
- variantH = iree_compiler::createReductionStrategyBlockDistributionPart(
- b, variantH, originalFillH, reductionH, fusionRootH,
- getAsOpFoldResult(b.getI64ArrayAttr(infos.workgroupTileSizes)),
- infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+ // Step 2. Use tiling to introduce a single-iteration loop mapped to a single
+ // block/workgroup. Keep everything fused.
+ auto [gridReductionH, maybeTiledTrailingH] =
+ createReductionStrategyBlockDistribution(b, maybeLeadingH, fillH,
+ reductionH, maybeTrailingH);
- // Step 2. Second level of tiling + fusion parallelizes to threads.
- variantH = createReductionStrategyThreadDistributionPart(
- b, variantH, infos.fillSecondTileSizes, infos.genericSecondTileSizes,
- infos.hasLeadingEltwise, infos.hasTrailingEltwise);
+ // Step 3. Split the reduction and tile the pieces to ensure vector
+ // load/stores and mapping to a single warp with shuffles.
+ createReductionStrategyThreadDistribution(
+ b, gridReductionH, maybeTiledTrailingH, infos.reductionDimensionSize);
- // Step 3. Rank-reduce and vectorize.
- // TODO: assumes a single func::FuncOp to transform, may need hardening.
+ // Step 4. Bufferize and drop HAL decriptor from memref ops.
Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
funcH = iree_compiler::buildVectorize(b, funcH);
-
- // Step 4. Bufferize and drop HAL descriptor from memref ops.
variantH = iree_compiler::buildBufferize(b, variantH, /*targetGpu=*/true);
// Step 5. Post-bufferization mapping to blocks and threads.
@@ -216,14 +245,12 @@
GPUReductionStrategyInfos &info) {
// TODO: match the sequence the strategy supports.
StructuredOpMatcher pattern, fill, leadingEltwise, trailingEltwise;
- makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
+ makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise,
+ info.reductionDimensionSize);
if (!matchPattern(op, pattern)) return false;
- info.hasLeadingEltwise = leadingEltwise.getCaptured() != nullptr;
- info.hasTrailingEltwise = trailingEltwise.getCaptured() != nullptr;
-
- // Hardcoded workagroup size, this could be deduced from the reduction dim.
- info.workgroupSize = {32, 2, 1};
+ // Hardcoded workgroup size, this could be deduced from the reduction dim.
+ info.workgroupSize = {32, 1, 1};
SmallVector<unsigned> partitionedLoops =
cast<iree_compiler::PartitionableLoopsInterface>(op.getOperation())
.getPartitionableLoops(iree_compiler::kNumMaxParallelDims);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
index 25393c6..dad48b6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
@@ -27,37 +27,35 @@
}
}
}
+
// CHECK-LABEL: func.func @group_reduction
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2xf32, 3>
+// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x
-// CHECK-DAG: %[[TIDY:.]] = gpu.thread_id y
-// CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id z
-// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
+// CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
+// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<2xf32, strided<[1], offset: ?>, 3>
+// CHECK: gpu.barrier
-// Distributed reduction: everyone loads then 5 xor + addf expected
-// CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
+// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_read
+// CHECK: arith.addf %{{.*}} : vector<2xf32>
+// CHECK: vector.transfer_write
+// CHECK: gpu.barrier
+
+// Distributed reduction: everyone loads then 5 xor + addf expected.
+// CHECK: %[[TIDY:.]] = gpu.thread_id y
+// CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
+
// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
// CHECK: scf.if %[[CONDXIS0]]
-// CHECK: vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][]
+// CHECK: vector.transfer_write %[[RES_VEC]]
// CHECK: gpu.barrier
-
-// Last part is not distributed atm and is only ran by threadIdx.x == 0 and threadIdx.y == 0.
-// CHECK: %[[CONDYIS0:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
-// TODO: cond eq 0 and cond ult 1 do not CSE atm.
-// CHECK: %[[CONXANDYARE0:.*]] = arith.andi %{{.*}}, %[[CONDYIS0]] : i1
-// CHECK: scf.if %[[CONXANDYARE0]] {
-// CHECK: vector.transfer_read
-// CHECK: vector.reduction <add>
-// CHECK: vector.transfer_write
-// CHECK: gpu.barrier
-// CHECK: memref.dealloc %[[SHMEM_ALLOC]] : memref<1x2xf32, 3>
+// CHECK: memref.dealloc %[[SHMEM_ALLOC]]
// -----
@@ -96,37 +94,35 @@
// CHECK-LABEL: func.func @group_reduction_elementwise
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2xf32, 3>
+// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x
-// CHECK-DAG: %[[TIDY:.]] = gpu.thread_id y
-// CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id z
-// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
+// CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
+// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<2xf32, strided<[1], offset: ?>, 3>
+// CHECK: gpu.barrier
-// Distributed reduction: everyone loads then 5 xor + addf expected
-// CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
+// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_read
+// CHECK: arith.addf %{{.*}} : vector<2xf32>
+// CHECK: vector.transfer_write
+// CHECK: gpu.barrier
+
+// Distributed reduction: everyone loads then 5 xor + addf expected.
+// CHECK: %[[TIDY:.]] = gpu.thread_id y
+// CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
-// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
+
+// CHECK: %[[PARTIAL:.*]] = arith.addf %{{.*}}
+// CHECK: %[[PARTIAL_VEC:.*]] = vector.broadcast %[[PARTIAL]] : f32 to vector<f32>
+// CHECK: %[[ELEM:.*]] = vector.extractelement %[[PARTIAL_VEC]][]
+// CHECK: %[[RES:.*]] = math.sqrt %[[ELEM]]
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
// CHECK: scf.if %[[CONDXIS0]]
-// CHECK: vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][]
+// CHECK: vector.transfer_write %[[RES_VEC]]
// CHECK: gpu.barrier
-
-// Last part is not distributed atm and is only ran by threadIdx.x == 0 and threadIdx.y == 0.
-// It should contain the fused elementwise operation.
-// CHECK: %[[CONDYIS0:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
-// TODO: cond eq 0 and cond ult 1 do not CSE atm.
-// CHECK: %[[CONXANDYARE0:.*]] = arith.andi %{{.*}}, %[[CONDYIS0]] : i1
-// CHECK: scf.if %[[CONXANDYARE0]] {
-// CHECK: vector.transfer_read
-// CHECK: vector.reduction <add>
-// CHECK: math.sqrt
-// CHECK: vector.transfer_write
-// CHECK: gpu.barrier
-// CHECK: memref.dealloc %[[SHMEM_ALLOC]] : memref<1x2xf32, 3>
+// CHECK: memref.dealloc %[[SHMEM_ALLOC]]
// -----
@@ -162,40 +158,35 @@
// CHECK-LABEL: func.func @group_elementwise_reduction
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2xf32, 3>
+// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x
-// CHECK-DAG: %[[TIDY:.]] = gpu.thread_id y
-// CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id z
+// CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
+// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<2xf32, strided<[1], offset: ?>, 3>
+// CHECK: gpu.barrier
-// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
+// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_read
+// CHECK: %[[PARTIAL_1:.*]] = arith.addf %[[ARG:.*]], %[[ARG]]
+// CHECK: %[[PARTIAL_2:.*]] = arith.addf %[[PARTIAL_1]], %[[PARTIAL_1]]
+// CHECK: arith.addf %[[PARTIAL_2]], %{{.*}} : vector<2xf32>
+// CHECK: vector.transfer_write
+// CHECK: gpu.barrier
-// Distributed reduction: everyone loads, does the elementwise then 5 xor + addf expected
-// CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
-// CHECK: arith.addf
-// CHECK: arith.addf
+// Distributed reduction: everyone loads then 5 xor + addf expected.
+// CHECK: %[[TIDY:.]] = gpu.thread_id y
+// CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
-// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
+// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
// CHECK: scf.if %[[CONDXIS0]]
-// CHECK: vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][]
+// CHECK: vector.transfer_write %[[RES_VEC]]
// CHECK: gpu.barrier
-
-// Last part is not distributed atm and is only ran by threadIdx.x == 0 and threadIdx.y == 0.
-// CHECK: %[[CONDYIS0:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
-// TODO: cond eq 0 and cond ult 1 do not CSE atm.
-// CHECK: %[[CONXANDYARE0:.*]] = arith.andi %{{.*}}, %[[CONDYIS0]] : i1
-// CHECK: scf.if %[[CONXANDYARE0]] {
-// CHECK: vector.transfer_read
-// CHECK: vector.reduction <add>
-// CHECK: vector.transfer_write
-// CHECK: gpu.barrier
-// CHECK: memref.dealloc %[[SHMEM_ALLOC]] : memref<1x2xf32, 3>
+// CHECK: memref.dealloc %[[SHMEM_ALLOC]]
// -----
@@ -236,38 +227,93 @@
// CHECK-LABEL: func.func @group_elementwise_reduction_elementwise
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x2xf32, 3>
+// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x
-// CHECK-DAG: %[[TIDY:.]] = gpu.thread_id y
-// CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id z
+// CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
+// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<2xf32, strided<[1], offset: ?>, 3>
+// CHECK: gpu.barrier
-// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
+// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_read
+// CHECK: %[[PARTIAL_1:.*]] = arith.addf %[[ARG:.*]], %[[ARG]]
+// CHECK: %[[PARTIAL_2:.*]] = arith.addf %[[PARTIAL_1]], %[[PARTIAL_1]]
+// CHECK: arith.addf %[[PARTIAL_2]], %{{.*}} : vector<2xf32>
+// CHECK: vector.transfer_write
+// CHECK: gpu.barrier
-// Distributed reduction: everyone loads, does the elementwise then 5 xor + addf expected
-// CHECK: vector.transfer_read %{{.*}}[%[[TIDZ]], %[[TIDY]], %[[TIDX]]]
-// CHECK: arith.addf
-// CHECK: arith.addf
+// Distributed reduction: everyone loads then 5 xor + addf expected.
+// CHECK: %[[TIDY:.]] = gpu.thread_id y
+// CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
-// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
-
+// CHECK: %[[PARTIAL:.*]] = arith.addf %{{.*}}
+// CHECK: %[[PARTIAL_VEC:.*]] = vector.broadcast %[[PARTIAL]] : f32 to vector<f32>
+// CHECK: %[[ELEM:.*]] = vector.extractelement %[[PARTIAL_VEC]][]
+// CHECK: %[[RES:.*]] = math.sqrt %[[ELEM]]
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
// CHECK: scf.if %[[CONDXIS0]]
-// CHECK: vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][]
+// CHECK: vector.transfer_write %[[RES_VEC]]
+// CHECK: gpu.barrier
+// CHECK: memref.dealloc %[[SHMEM_ALLOC]]
+
+// -----
+
+hal.executable @group_reduction_larger {
+hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
+ hal.executable.export public @group_reduction_larger ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @group_reduction_larger() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant -0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<33x256xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<33xf32>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [33, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<33x256xf32>> -> tensor<33x256xf32>
+ %3 = tensor.empty() : tensor<33xf32>
+ %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<33xf32>) -> tensor<33xf32>
+ %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<33x256xf32>) outs(%4 : tensor<33xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %6 = arith.addf %in, %out : f32
+ linalg.yield %6 : f32
+ } -> tensor<33xf32>
+ flow.dispatch.tensor.store %5, %1, offsets = [0], sizes = [8], strides = [1] : tensor<33xf32> -> !flow.dispatch.tensor<writeonly:tensor<33xf32>>
+ return
+ }
+ }
+}
+}
+
+// CHECK-LABEL: func.func @group_reduction_larger
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
+// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x128xf32, 3>
+// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x
+// CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
+// CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<4xf32, strided<[1], offset: ?>, 3>
// CHECK: gpu.barrier
-// Last part is not distributed atm and is only ran by threadIdx.x == 0 and threadIdx.y == 0.
-// CHECK: %[[CONDYIS0:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
-// TODO: cond eq 0 and cond ult 1 do not CSE atm.
-// CHECK: %[[CONXANDYARE0:.*]] = arith.andi %{{.*}}, %[[CONDYIS0]] : i1
-// CHECK: scf.if %[[CONXANDYARE0]] {
-// CHECK: vector.transfer_read
-// CHECK: vector.reduction <add>
-// CHECK: math.sqrt
-// CHECK: vector.transfer_write
+// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_read
+// CHECK: arith.addf %{{.*}} : vector<4xf32>
+// CHECK: vector.transfer_write
// CHECK: gpu.barrier
-// CHECK: memref.dealloc %[[SHMEM_ALLOC]] : memref<1x2xf32, 3>
+
+// Distributed reduction: everyone loads then 5 xor + addf expected.
+// CHECK: %[[TIDY:.]] = gpu.thread_id y
+// CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
+// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
+
+// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
+// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
+// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
+// CHECK: scf.if %[[CONDXIS0]]
+// CHECK: vector.transfer_write %[[RES_VEC]]
+// CHECK: gpu.barrier
+// CHECK: memref.dealloc %[[SHMEM_ALLOC]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index 05c5f7a..e25a9d5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -28,17 +28,115 @@
}
}
-// CHECK-LABEL: func.func @group_reduction
-// CHECK: transform.structured.canonicalized_sequence failures(propagate)
-// CHECK: transform.structured.match ops{["linalg.fill"]} in %{{.+}}
-// CHECK: transform.structured.match ops{["linalg.generic"]} in %{{.+}}
-// CHECK: transform.structured.split_reduction %{{.+}} {insert_split_dimension = 1 : i64, split_factor = 2 : i64}
-// CHECK: transform.iree.tile_to_foreach_thread_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [1](mapping = [#gpu.block<x>])
-// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [] tile_sizes [1, 0, 0](mapping = [#gpu.thread<z>])
-// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [] tile_sizes [1, 1, 0](mapping = [#gpu.thread<z>, #gpu.thread<y>])
-// CHECK: transform.iree.bufferize {target_gpu}
-// CHECK: transform.iree.foreach_thread_to_workgroup
-// CHECK: transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.+}} {workgroup_size = [32, 2, 1]}
-// CHECK: transform.structured.match ops{["scf.if"]} in %{{.+}}
-// CHECK: transform.iree.vector.to_warp_execute_on_lane_0 %{{.+}} {warp_size = 32 : i64}
-// CHECK: transform.iree.vector.warp_distribute %{{.+}}
+// CHECK-LABEL: func.func @group_reduction
+// CHECK: transform.structured.canonicalized_sequence failures(propagate)
+// CHECK: transform.iree.match_callback failures(propagate) "reduction"(%{{.+}})
+// CHECK: transform.iree.take_first
+// CHECK: transform.iree.tile_to_foreach_thread_and_workgroup_count_region {{.*}} tile_sizes [1](mapping = [#gpu.block<x>])
+// CHECK-COUNT-3: transform.structured.fuse_into_containing_op
+// CHECK: transform.iree.take_first
+// CHECK: transform.structured.tile_reduction_using_scf %{{.*}} {tile_sizes = [0, 64]}
+// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [0, 32]
+// CHECK-SAME: (mapping = [#gpu.thread<x>])
+// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [0, 2](mapping = [#gpu.thread<x>])
+// CHECK: transform.iree.take_first
+// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [1](mapping = [#gpu.thread<y>])
+// CHECK: transform.structured.fuse_into_containing_op
+// CHECK: transform.structured.match ops{["func.func"]} in %arg0
+// CHECK: transform.iree.apply_patterns %{{.*}} {rank_reducing}
+// CHECK: transform.structured.vectorize
+// CHECK: transform.iree.bufferize {target_gpu}
+// CHECK: transform.structured.match ops{["func.func"]} in %{{.*}}
+// CHECK: transform.iree.erase_hal_descriptor_type_from_memref
+// CHECK: transform.structured.match ops{["func.func"]} in %{{.*}}
+// CHECK: transform.iree.foreach_thread_to_workgroup
+// CHECK: transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [32, 1, 1]}
+// CHECK: transform.iree.apply_patterns %{{.*}} {rank_reducing}
+// CHECK: transform.structured.match ops{["scf.if"]} in %{{.*}}
+// CHECK: sequence {{.*}} failures(suppress) {
+// CHECK: transform.iree.vector.to_warp_execute_on_lane_0 %{{.*}} {warp_size = 32 : i64}
+// CHECK: }
+// CHECK: transform.iree.vector.warp_distribute
+
+
+// -----
+
+
+hal.executable @group_reduction_128 {
+hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
+ hal.executable.export public @group_reduction_128 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @group_reduction_128() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant -0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<8x128xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<8xf32>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [8, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<8x128xf32>> -> tensor<8x128xf32>
+ %3 = tensor.empty() : tensor<8xf32>
+ %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8xf32>) -> tensor<8xf32>
+ %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<8x128xf32>) outs(%4 : tensor<8xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %6 = arith.addf %in, %out : f32
+ linalg.yield %6 : f32
+ } -> tensor<8xf32>
+ flow.dispatch.tensor.store %5, %1, offsets = [0], sizes = [8], strides = [1] : tensor<8xf32> -> !flow.dispatch.tensor<writeonly:tensor<8xf32>>
+ return
+ }
+ }
+}
+}
+
+// Overall, the schedule is same as above, but with larger tile sizes.
+// Checking only the tile sizes.
+
+// CHECK-LABEL: func.func @group_reduction_128
+// CHECK: transform.structured.canonicalized_sequence failures(propagate)
+// CHECK: transform.structured.tile_reduction_using_scf %{{.*}} {tile_sizes = [0, 128]}
+// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [0, 32]
+// CHECK-SAME: (mapping = [#gpu.thread<x>])
+// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [0, 4](mapping = [#gpu.thread<x>])
+
+// -----
+
+
+hal.executable @group_reduction_32 {
+hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
+ hal.executable.export public @group_reduction_32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @group_reduction_32() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant -0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<8x32xf32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<8xf32>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<8x32xf32>> -> tensor<8x32xf32>
+ %3 = tensor.empty() : tensor<8xf32>
+ %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8xf32>) -> tensor<8xf32>
+ %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<8x32xf32>) outs(%4 : tensor<8xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %6 = arith.addf %in, %out : f32
+ linalg.yield %6 : f32
+ } -> tensor<8xf32>
+ flow.dispatch.tensor.store %5, %1, offsets = [0], sizes = [8], strides = [1] : tensor<8xf32> -> !flow.dispatch.tensor<writeonly:tensor<8xf32>>
+ return
+ }
+ }
+}
+}
+
+// Overall, the schedule is same as above, but with larger tile sizes.
+// Checking only the tile sizes.
+
+// CHECK-LABEL: func.func @group_reduction_32
+// CHECK: transform.structured.canonicalized_sequence failures(propagate)
+// CHECK: transform.structured.tile_reduction_using_scf %{{.*}} {tile_sizes = [0, 32]}
+// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [0, 32]
+// CHECK-SAME: (mapping = [#gpu.thread<x>])
+// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [0, 1](mapping = [#gpu.thread<x>])
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 430d879..661aece 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -40,6 +40,11 @@
/// for all operands of the relevant kind.
struct AllOperands {};
+struct CaptureDim {
+ explicit CaptureDim(int64_t &value) : value(value) {}
+ int64_t &value;
+};
+
/// A tag indicating to look for any user of the operation's result that would
/// satisfy the predicate.
struct HasAnyUse {};
@@ -153,6 +158,8 @@
/// (i.e. Python-style).
StructuredOpMatcher &dim(int64_t dimension, DivisibleBy divisibleBy);
+ StructuredOpMatcher &dim(int64_t dimension, CaptureDim capture);
+
/// Adds a predicate checking that the structured op has the given number of
/// inputs.
StructuredOpMatcher &input(NumEqualsTo num) {
@@ -445,7 +452,8 @@
void makeReductionMatcher(StructuredOpMatcher &reduction,
StructuredOpMatcher &fill,
StructuredOpMatcher &leading,
- StructuredOpMatcher &trailing);
+ StructuredOpMatcher &trailing,
+ int64_t &reductionDimensionSize);
/// Creates a group of matchers for:
///
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 2afd477..3638464 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -1233,7 +1233,8 @@
transform_ext::StructuredOpMatcher pattern, fill, leadingEltwise,
trailingEltwise;
- makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise);
+ int64_t ignore;
+ makeReductionMatcher(pattern, fill, leadingEltwise, trailingEltwise, ignore);
// TODO: need a mechanism for this to go around the entire IR,
// potentially with list matches for each group.
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index eda5d55..25b81c0 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -98,6 +98,21 @@
}
transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) {
+ predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
+ unsigned rank = linalgOp.getNumLoops();
+ int64_t transformedDimension =
+ dimension >= 0 ? dimension : rank + dimension;
+ if (transformedDimension >= rank)
+ return false;
+
+ capture.value = linalgOp.getStaticLoopRanges()[transformedDimension];
+ return true;
+ });
+ return *this;
+}
+
+transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) {
predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool {
// all_of with a lambda requires const-casting dance, so using a loop.
@@ -292,7 +307,8 @@
transform_ext::StructuredOpMatcher &reduction,
transform_ext::StructuredOpMatcher &fill,
transform_ext::StructuredOpMatcher &leading,
- transform_ext::StructuredOpMatcher &trailing) {
+ transform_ext::StructuredOpMatcher &trailing,
+ int64_t &reductionDimensionSize) {
fill = m_StructuredOp<linalg::FillOp>();
trailing = m_StructuredOp<linalg::GenericOp>()
.input(AllOperands(), IsPermutation())
@@ -304,6 +320,7 @@
.dim(AllDims(), ShapeKind::Static)
.dim(-1, utils::IteratorType::reduction)
.dim(-1, DivisibleBy(kCudaWarpSize))
+ .dim(-1, CaptureDim(reductionDimensionSize))
// Can be extended to projected permutation with broadcast.
.input(AllOperands(), IsPermutation())
// TODO: we want to accept any input position here.
diff --git a/tests/transform_dialect/cuda/reduction_v2.mlir b/tests/transform_dialect/cuda/reduction_v2.mlir
index 4a3b2a2..da9a90c 100644
--- a/tests/transform_dialect/cuda/reduction_v2.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2.mlir
@@ -11,7 +11,7 @@
affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) {
- ^bb0(%arg3: f32, %arg4: f32):
+ ^bb0(%arg3: f32, %arg4: f32):
%3 = arith.addf %arg3, %arg4 : f32
linalg.yield %3 : f32
} -> !out_tensor_t
@@ -27,11 +27,26 @@
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefix=CHECK
+// RUN: iree-opt %s --iree-hal-target-backends=cuda \
+// RUN: --iree-abi-transformation-pipeline \
+// RUN: --iree-flow-transformation-pipeline \
+// RUN: --iree-stream-transformation-pipeline \
+// RUN: --iree-hal-configuration-pipeline | \
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
+// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit | \
+// RUN: FileCheck %s --check-prefix=CHECK
+
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
// RUN: iree-run-module --entry_function=reduce --device=cuda --function_input="33x1024xf32=1" |\
// RUN: FileCheck %s --check-prefix=EXEC
+// RUN: iree-compile %s --iree-hal-target-backends=cuda \
+// RUN: --iree-codegen-llvmgpu-enable-transform-dialect-jit | \
+// RUN: iree-run-module --entry_function=reduce --device=cuda --function_input="33x1024xf32=1" |\
+// RUN: FileCheck %s --check-prefix=EXEC
+
+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>