[GPU] Add SwapExpandShapeWithSlice pattern to loop fusion pass (#19729)
This PR moves the `SwapExpandShapeWithSlicePattern` to
Codegen/Common/Transforms, and adds the pattern to the
FuseAndHoistParallelLoops pass.
This pattern is generally useful for tiling fusion because it exposes
more producer fusion opportunities when there are reshapes in the IR,
but it is especially useful in combination with the pattern introduced
in https://github.com/iree-org/iree/pull/19295. That pattern creates an
expanded parallel_insert_slice, and an expand_shape on the
corresponding init block arg in the forall loop body, which makes the
slice on the init argument lower-dimensional than the
parallel_insert_slice at the end of the loop. Bufferization works
better when these slices match, and this pattern makes that happen by
bubbling the slice of the init arg up through the expand_shape,
increasing its dimensionality to match the parallel_insert_slice.
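
As a rough illustration of the core rewrite (the shapes and slice
parameters below are made up for this sketch and are not taken from a
test in this PR), the pattern turns an extract_slice of an expand_shape
into an extract_slice of the collapsed source followed by an
expand_shape:

```mlir
// Before: the slice is taken from the expanded tensor, which hides the
// producer of %in from tile-and-fuse.
%expanded = tensor.expand_shape %in [[0, 1], [2, 3]] output_shape [2, 4, 4, 8]
    : tensor<8x32xf32> into tensor<2x4x4x8xf32>
%slice = tensor.extract_slice %expanded[1, 0, 2, 0] [1, 4, 2, 8] [1, 1, 1, 1]
    : tensor<2x4x4x8xf32> to tensor<1x4x2x8xf32>

// After: the sliced region is contiguous within each reassociation group,
// so it is extracted directly from the collapsed source and re-expanded.
%new_slice = tensor.extract_slice %in[4, 16] [4, 16] [1, 1]
    : tensor<8x32xf32> to tensor<4x16xf32>
%new_expanded = tensor.expand_shape %new_slice [[0, 1], [2, 3]] output_shape [1, 4, 2, 8]
    : tensor<4x16xf32> into tensor<1x4x2x8xf32>
```

The rewrite only applies when the extracted region is contiguous within
each reassociation group and the slice has unit strides and is not
rank-reducing; otherwise the pattern fails to match.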
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 940477d..17510c5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -155,6 +155,7 @@
"TileDispatchUsingInterface.cpp",
"TileLargeTensors.cpp",
"TileSizeSelection.cpp",
+ "Transforms.cpp",
"TypePropagationPass.cpp",
"UnrollAnnotatedLoops.cpp",
"UserConfig.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index d6e2528..5029cff 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -147,6 +147,7 @@
"TileDispatchUsingInterface.cpp"
"TileLargeTensors.cpp"
"TileSizeSelection.cpp"
+ "Transforms.cpp"
"TypePropagationPass.cpp"
"UnrollAnnotatedLoops.cpp"
"UserConfig.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
index 62a8ac2..68ee271 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
@@ -41,156 +42,6 @@
};
} // namespace
-/// Pattern to convert `tensor.extract_slice(tensor.expand_shape)` to
-/// `tensor.expand_shape(tensor.extract_slice)`.
-static LogicalResult
-swapExpandShapeWithSlice(RewriterBase &rewriter,
- tensor::ExpandShapeOp expandShapeOp,
- tensor::ExtractSliceOp sliceOp) {
- SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
- if (sliceOp.getResultType().getRank() != sizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
-
- // Helper variables and function for accumulating the new offset and length
- // values.
- Location loc = expandShapeOp->getLoc();
- AffineExpr d0, d1, d2;
- bindDims(rewriter.getContext(), d0, d1, d2);
- // Multiply two integers.
- auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
- auto mulMap = AffineMap::get(2, 0, {d0 * d1});
- return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
- {v1, v2});
- };
- auto mulAdd = [&](OpFoldResult v1, OpFoldResult v2, OpFoldResult v3) {
- auto mulMap = AffineMap::get(3, 0, {d0 * d1 + d2});
- return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
- {v1, v2, v3});
- };
-
- SmallVector<OpFoldResult> outputShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- auto isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize,
- OpFoldResult size) {
- if (!isConstantIntValue(offset, 0))
- return false;
- FailureOr<bool> maybeEqual =
- ValueBoundsConstraintSet::areEqual(sliceSize, size);
- return llvm::succeeded(maybeEqual) && maybeEqual.value();
- };
-
- // First verify that this is a full slice of the expanded tensor.
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- int64_t i = 0;
- int64_t e = indices.size();
- // Find the first expanded dim after the first dim with non-unit extracted
- // size.
- for (; i < e; ++i) {
- if (!isConstantIntValue(sizes[indices[i]], 1)) {
- // +1 to skip the first non-unit size dim.
- i++;
- break;
- }
- }
-
- // Verify that all subsequent dimensions extract the full size of the
- // source tensor.
- for (; i < e; ++i) {
- int64_t expandedDim = indices[i];
- if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
- outputShape[expandedDim])) {
- return rewriter.notifyMatchFailure(
- sliceOp, "Not a contiguous slice of the expanded tensor.");
- }
- }
- }
-
- // Compute new offsets, lengths, and strides.
- SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- OpFoldResult newOffset = rewriter.getIndexAttr(0);
- OpFoldResult newSize = rewriter.getIndexAttr(1);
-
- int64_t i = 0;
- int64_t e = indices.size();
- // Offset = cumulative product of leading unit extracted dims.
- for (; i < e; ++i) {
- int64_t expandedDim = indices[i];
- if (!isConstantIntValue(sizes[expandedDim], 1))
- break;
-
- newOffset =
- mulAdd(newOffset, outputShape[expandedDim], offsets[expandedDim]);
- }
-
- if (i != e) {
- int64_t expandedDim = indices[i];
- newOffset =
- mulAdd(newOffset, outputShape[expandedDim], offsets[expandedDim]);
- newSize = sizes[expandedDim];
- i++;
- }
-
- for (; i < e; ++i) {
- OpFoldResult fullSize = outputShape[indices[i]];
- newOffset = mul(newOffset, fullSize);
- newSize = mul(newSize, fullSize);
- }
-
- newOffsets.push_back(newOffset);
- newLengths.push_back(newSize);
-
- // Only unit stride supported.
- newStrides.push_back(rewriter.getIndexAttr(1));
- }
-
- // The shape of the result can be obtained from the sizes passed in.
- SmallVector<Value> dynDims;
- SmallVector<int64_t> shape;
- dispatchIndexOpFoldResults(sizes, dynDims, shape);
- RankedTensorType resultType = RankedTensorType::get(
- shape, expandShapeOp.getResultType().getElementType());
-
- // Create a new ExtractSliceOp and ExpandShapeOp.
- Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
- loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
- auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc, resultType, newSliceOp, expandShapeOp.getReassociationIndices(),
- sizes);
- rewriter.replaceOp(sliceOp, newExpandShapeOp);
- return success();
-}
-
-/// tensor.empty does not define any tensor contents, so an unpadded pack
-/// can be folded away.
-struct SwapExpandShapeWithSlicePattern
- : public OpRewritePattern<tensor::ExtractSliceOp> {
- using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
- PatternRewriter &rewriter) const override {
- auto expandOp = sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
- if (!expandOp) {
- return failure();
- }
-
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unsupported: non-unit stride");
- }
-
- return swapExpandShapeWithSlice(rewriter, expandOp, sliceOp);
- }
-};
-
/// This collects the set of operations to tile + fuse starting from the given
/// root |op| and walking up to its producers. Stops at operations given by
/// |exclude| which are expected to receive their own independent tiling for the
@@ -348,7 +199,7 @@
tensor::DimOp::getCanonicalizationPatterns(cleanupPatterns, context);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
cleanupPatterns);
- cleanupPatterns.add<SwapExpandShapeWithSlicePattern>(context);
+ populateSwapExtractWithExpandPattern(cleanupPatterns);
}
tileAndFuseOptions.cleanupPatterns =
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
index f974e51..cb61cf4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
@@ -4,10 +4,11 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
-#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -374,6 +375,7 @@
patterns.add<FuseTilableDestinationProducers>(context);
patterns.add<FuseUnitLoopDestination>(context);
patterns.add<FuseTilableForallConsumers>(context);
+ populateSwapExtractWithExpandPattern(patterns);
tensor::populateFoldTensorEmptyPatterns(patterns);
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
new file mode 100644
index 0000000..79bf739
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
@@ -0,0 +1,191 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+
+#define DEBUG_TYPE "iree-codegen-common-transforms"
+
+namespace mlir::iree_compiler {
+
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. If the transformation
+/// is not possible, or if the slice is rank reducing, the function returns
+/// failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+static LogicalResult
+swapExpandShapeWithSlice(RewriterBase &rewriter,
+ tensor::ExpandShapeOp expandShapeOp,
+ tensor::ExtractSliceOp sliceOp) {
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+ if (sliceOp.getResultType().getRank() != sizes.size()) {
+ return rewriter.notifyMatchFailure(sliceOp,
+ "unimplemented: rank reducing slice");
+ }
+
+ // Helper variables and function for accumulating the new offset and length
+ // values.
+ Location loc = expandShapeOp->getLoc();
+ AffineExpr d0, d1, d2;
+ bindDims(rewriter.getContext(), d0, d1, d2);
+ // Multiply two integers.
+ auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+ auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+ {v1, v2});
+ };
+ auto mulAdd = [&](OpFoldResult v1, OpFoldResult v2, OpFoldResult v3) {
+ auto mulMap = AffineMap::get(3, 0, {d0 * d1 + d2});
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+ {v1, v2, v3});
+ };
+
+ SmallVector<OpFoldResult> outputShape =
+ getMixedValues(expandShapeOp.getStaticOutputShape(),
+ expandShapeOp.getOutputShape(), rewriter);
+
+ auto isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize,
+ OpFoldResult size) {
+ if (!isConstantIntValue(offset, 0))
+ return false;
+ FailureOr<bool> maybeEqual =
+ ValueBoundsConstraintSet::areEqual(sliceSize, size);
+ return llvm::succeeded(maybeEqual) && maybeEqual.value();
+ };
+
+ // First verify that this is a full slice of the expanded tensor.
+ for (const ReassociationIndices &indices :
+ expandShapeOp.getReassociationIndices()) {
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Find the first expanded dim after the first dim with non-unit extracted
+ // size.
+ for (; i < e; ++i) {
+ if (!isConstantIntValue(sizes[indices[i]], 1)) {
+ // +1 to skip the first non-unit size dim.
+ i++;
+ break;
+ }
+ }
+
+ // Verify that all subsequent dimensions extract the full size of the
+ // source tensor.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+ outputShape[expandedDim])) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "Not a contiguous slice of the expanded tensor.");
+ }
+ }
+ }
+
+ // Compute new offsets, lengths, and strides.
+ SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+ for (const ReassociationIndices &indices :
+ expandShapeOp.getReassociationIndices()) {
+ OpFoldResult newOffset = rewriter.getIndexAttr(0);
+ OpFoldResult newSize = rewriter.getIndexAttr(1);
+
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Offset = cumulative product of leading unit extracted dims.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isConstantIntValue(sizes[expandedDim], 1))
+ break;
+
+ newOffset =
+ mulAdd(newOffset, outputShape[expandedDim], offsets[expandedDim]);
+ }
+
+ if (i != e) {
+ int64_t expandedDim = indices[i];
+ newOffset =
+ mulAdd(newOffset, outputShape[expandedDim], offsets[expandedDim]);
+ newSize = sizes[expandedDim];
+ i++;
+ }
+
+ for (; i < e; ++i) {
+ OpFoldResult fullSize = outputShape[indices[i]];
+ newOffset = mul(newOffset, fullSize);
+ newSize = mul(newSize, fullSize);
+ }
+
+ newOffsets.push_back(newOffset);
+ newLengths.push_back(newSize);
+
+ // Only unit stride supported.
+ newStrides.push_back(rewriter.getIndexAttr(1));
+ }
+
+ // The shape of the result can be obtained from the sizes passed in.
+ SmallVector<Value> dynDims;
+ SmallVector<int64_t> shape;
+ dispatchIndexOpFoldResults(sizes, dynDims, shape);
+ RankedTensorType resultType = RankedTensorType::get(
+ shape, expandShapeOp.getResultType().getElementType());
+
+ // Create a new ExtractSliceOp and ExpandShapeOp.
+ Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
+ auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc, resultType, newSliceOp, expandShapeOp.getReassociationIndices(),
+ sizes);
+ rewriter.replaceOp(sliceOp, newExpandShapeOp);
+ return success();
+}
+
+namespace {
+
+struct SwapExpandShapeWithSlicePattern
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto expandOp = sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandOp) {
+ return failure();
+ }
+
+ if (!sliceOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(sliceOp,
+ "unsupported: non-unit stride");
+ }
+
+ return swapExpandShapeWithSlice(rewriter, expandOp, sliceOp);
+ }
+};
+
+} // namespace
+
+void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns) {
+ patterns.add<SwapExpandShapeWithSlicePattern>(patterns.getContext());
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index 98c7478..53c2907 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -91,6 +91,8 @@
/// for maximumf/minimumf ops, e.g. LLVM NVIDIA-PTX.
void populateReplaceSlowMinMaxOpsPatterns(RewritePatternSet &patterns);
+void populateSwapExtractWithExpandPattern(RewritePatternSet &patterns);
+
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
index 3466bf8..ee06c4f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_igemm_tile_and_fuse.mlir
@@ -143,7 +143,7 @@
// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C721]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<4xf16>
-// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<4x1x1xf16>
+// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<4xf16>
// CHECK-COUNT-1: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
// CHECK: %[[LOOP_T:.+]] = vector.shape_cast %[[LOOP]] : vector<1x1x1x1x4x1xf32> to vector<4x1x1xf32>
// CHECK: vector.transfer_write %[[LOOP_T]]