[gpu] Enable fusing input producers after tiling reduction loops (#13806)
This commit enables fusing input producers into the
serial loop created when tiling the matmul K dimension.
This lets weight dequantization reuse the GPU shared
memory already promoted for the A/B matrix slices,
without introducing additional shared memory
allocations during bufferization. This is needed to
support int4 weight-quantized matmuls in LLMs.
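
As a rough usage sketch (hypothetical caller, not part of this
patch): the new entry point declared in CommonGPUPasses.h below
could be driven from a pass roughly as follows, assuming the
declaration lives in the usual mlir::iree_compiler namespace.

  // Hypothetical helper; it only exercises the entry point added by this patch.
  #include "iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h"
  #include "mlir/Dialect/Func/IR/FuncOps.h"

  namespace mlir {
  namespace iree_compiler {

  static LogicalResult tileKAndFuseProducers(func::FuncOp funcOp) {
    // /*fuseInputProducer=*/false keeps the previous tiling-only behavior;
    // true additionally fuses fusable input producers (e.g. a weight
    // dequantization linalg.generic) into the serial K loop.
    return tileReductionToSerialLoops(funcOp, /*fuseInputProducer=*/true);
  }

  }  // namespace iree_compiler
  }  // namespace mlir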
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 3d5e870..72c6525 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -47,6 +47,7 @@
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:DestinationStyleOpInterface",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUTransformOps",
@@ -64,6 +65,7 @@
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
"@llvm-project//mlir:VectorToSCF",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 518372f..a717af0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -39,6 +39,7 @@
MLIRAffineUtils
MLIRArithDialect
MLIRBufferizationDialect
+ MLIRDestinationStyleOpInterface
MLIRFuncDialect
MLIRGPUDialect
MLIRGPUTransformOps
@@ -56,6 +57,7 @@
MLIRSCFUtils
MLIRSideEffectInterfaces
MLIRSupport
+ MLIRTensorDialect
MLIRTransforms
MLIRVectorDialect
MLIRVectorToSCF
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h b/compiler/src/iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h
index 3eed3d5..fe66198 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h
@@ -45,8 +45,11 @@
RewriterBase &rewriter, scf::ForOp forOp,
PipeliningSchedulingStrategy startegy, bool peelEpilogue, int64_t depth);
-/// Tiles Linalg ops in the given `funcOp` to serial loops without distribution.
-LogicalResult tileToSerialLoops(func::FuncOp funcOp, bool onlyReduction = true);
+/// Tiles Linalg ops in the given `funcOp` along reduction dimensions to serial
+/// loops without distribution. If `fuseInputProducer` is true, input producers
+/// will be fused into the serial loop.
+LogicalResult tileReductionToSerialLoops(func::FuncOp funcOp,
+ bool fuseInputProducer = false);
//===----------------------------------------------------------------------===//
// Passes
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
index 37793fb..27fcf6b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
@@ -8,9 +8,12 @@
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#define DEBUG_TYPE "iree-codegen-gpu-tensor-alloc"
@@ -89,6 +92,53 @@
}
namespace {
+/// Swaps bufferization.alloc_tensor with the copied linalg op result when the
+/// linalg op does not use the output initial value during calculation.
+///
+/// This converts the following IR:
+/// ```
+/// %linalg = linalg ... ins(...) outs(...)
+/// %val = bufferization.alloc_tensor() copy(%linalg)
+/// ```
+/// Into
+/// ```
+/// %alloc = bufferization.alloc_tensor()
+/// %val = linalg ... ins(...) outs(%alloc)
+/// ```
+struct SwapAllocTensorPattern final
+ : OpRewritePattern<bufferization::AllocTensorOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(bufferization::AllocTensorOp allocOp,
+ PatternRewriter &rewriter) const override {
+ if (!allocOp.getCopy()) return failure();
+ auto linalgOp = allocOp.getCopy().getDefiningOp<linalg::LinalgOp>();
+ if (!linalgOp) return failure();
+
+ // Make sure the linalg op's payload does not use the initial value of the
+ // output operand whose result is copied into the tensor allocation.
+ unsigned resultNumber = cast<OpResult>(allocOp.getCopy()).getResultNumber();
+ OpOperand *initOperand = linalgOp.getDpsInitOperand(resultNumber);
+ if (linalgOp.payloadUsesValueFromOperand(initOperand)) return failure();
+
+ rewriter.setInsertionPoint(linalgOp);
+ std::optional<Attribute> memorySpace = allocOp.getMemorySpace();
+ auto newAllocOp = rewriter.create<bufferization::AllocTensorOp>(
+ allocOp.getLoc(), allocOp.getType(), allocOp.getDynamicSizes(),
+ /*copy=*/Value(),
+ memorySpace ? cast<IntegerAttr>(*memorySpace) : IntegerAttr());
+ newAllocOp->setAttr(bufferization::BufferizationDialect::kEscapeAttrName,
+ rewriter.getBoolArrayAttr({false}));
+ rewriter.updateRootInPlace(linalgOp, [&]() {
+ linalgOp->setOperand(linalgOp.getNumDpsInputs() + resultNumber,
+ newAllocOp);
+ });
+ rewriter.replaceOp(allocOp, linalgOp->getResult(resultNumber));
+
+ return success();
+ }
+};
+
struct GPUTensorAllocPass : public GPUTensorAllocBase<GPUTensorAllocPass> {
private:
GPUPromoteSharedMemPattern promoteSharedMemPattern =
@@ -104,10 +154,17 @@
auto funcOp = getOperation();
// Tile the reduction first to reduce the alloc size.
- if (failed(tileToSerialLoops(funcOp))) {
+ if (failed(
+ tileReductionToSerialLoops(funcOp, /*fuseInputProducer=*/true))) {
return signalPassFailure();
}
+ LLVM_DEBUG({
+ llvm::dbgs() << "// --- After tiling to serial loops ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
SmallVector<Operation *> opsToPromote;
funcOp.walk([&](Operation *op) {
switch (promoteSharedMemPattern) {
@@ -152,6 +209,24 @@
break;
}
}
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "// --- After promotion ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ // Move tensor allocations earlier and use them for linalg init operands
+ // when possible. This change cleans up the IR to avoid bufferization
+ // creating extra buffers in later stages.
+ {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ patterns.add<SwapAllocTensorPattern>(context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
}
};
} // namespace
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
index 822c117..8094e91 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
@@ -21,7 +21,9 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns;
@@ -31,31 +33,118 @@
namespace mlir {
namespace iree_compiler {
+class TileConsumerAndFuseInputProducer final
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ public:
+ TileConsumerAndFuseInputProducer(
+ MLIRContext *context, IREE::LinalgExt::LinalgTransformationFilter filter,
+ bool fuseInputProducer, PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ filter(std::move(filter)),
+ fuseInputProducer(fuseInputProducer) {}
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter.checkAndNotify(rewriter, op))) return failure();
+
+ // Make sure we have a PartitionableLoopsInterface op here and query the
+ // tile sizes from the partitionable loops.
+ auto plOp = dyn_cast<PartitionableLoopsInterface>(*op);
+ if (!plOp) return failure();
+ auto partitionedLoops = plOp.getPartitionableLoops(kNumMaxParallelDims);
+ SmallVector<int64_t, 4> tileSizes = getTileSizes(op, 0);
+ if (tileSizes.empty()) return failure();
+ // Mask out non-reduction dimensions.
+ for (unsigned depth : partitionedLoops) {
+ if (depth < tileSizes.size()) tileSizes[depth] = 0;
+ }
+
+ // Make sure we have a tile size for each dimension.
+ // TODO: This is currently needed for LLVMGPU, where we propagate the
+ // lowering configuration to all linalg ops. Some linalg ops may not have
+ // the same rank, e.g., the configuration for a matmul attached to a
+ // producer linalg.fill op. It implicitly assumes that the leading
+ // dimensions of different linalg ops match, which is the current status
+ // but may not hold true in the long term.
+ tileSizes.resize(op.getLoopIteratorTypes().size());
+
+ if (llvm::all_of(tileSizes, [](int64_t s) { return s == 0; })) {
+ return failure();
+ }
+
+ // Tile the current op and fuse its immediate input operands.
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ tileConsumerAndFuseInputProducer(rewriter, op, tileSizes);
+ if (failed(tilingResult)) {
+ return rewriter.notifyMatchFailure(op, "failed to tile consumer");
+ }
+
+ // Replace the tiled op with replacements.
+ rewriter.replaceOp(op, tilingResult->replacements);
+ filter.replaceLinalgTransformationFilter(rewriter,
+ tilingResult->tiledOps.front());
+ return success();
+ }
+
+ private:
+ FailureOr<scf::SCFTilingResult> tileConsumerAndFuseInputProducer(
+ RewriterBase &rewriter, TilingInterface consumer,
+ ArrayRef<int64_t> tileSizes) const {
+ // First tile the current op as the consumer op.
+ auto tilingOptions = scf::SCFTilingOptions().setTileSizes(tileSizes);
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ tileUsingSCFForOp(rewriter, consumer, tilingOptions);
+ if (failed(tilingResult)) {
+ return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
+ }
+
+ if (!fuseInputProducer) return tilingResult;
+ // If there are no loops generated, fusion is immaterial.
+ if (tilingResult->loops.empty()) return tilingResult;
+
+ // Collect immediate input operands that are fusable into the tiled loop.
+ // These are tensor.extract_slice ops taking slices of the untiled op.
+ //
+ // Note that this excludes init operands for correctness. Input operands are
+ // fine to fuse, though at the cost of recomputation.
+ SmallVector<tensor::ExtractSliceOp> candidates;
+ assert(tilingResult->tiledOps.size() == 1);
+ Operation *tiledOp = tilingResult->tiledOps.front();
+ auto dsOp = dyn_cast<DestinationStyleOpInterface>(tiledOp);
+ if (!dsOp) return tilingResult;
+ for (OpOperand *operand : dsOp.getDpsInputOperands()) {
+ auto sliceOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!sliceOp) continue;
+ auto linalgOp = sliceOp.getSource().getDefiningOp<linalg::LinalgOp>();
+ if (!linalgOp) continue;
+ // Restrict to fully parallel linalg ops for now for simplicity.
+ auto isParallel = [](utils::IteratorType it) {
+ return linalg::isParallelIterator(it);
+ };
+ if (llvm::all_of(linalgOp.getIteratorTypesArray(), isParallel)) {
+ candidates.push_back(sliceOp);
+ }
+ }
+
+ // Fuse the candidate immediate operands into the tiled loop.
+ OpBuilder::InsertionGuard guard(rewriter);
+ while (!candidates.empty()) {
+ tensor::ExtractSliceOp sliceOp = candidates.back();
+ candidates.pop_back();
+ tileAndFuseProducerOfSlice(rewriter, sliceOp, tilingResult->loops);
+ }
+ return tilingResult;
+ }
+
+ IREE::LinalgExt::LinalgTransformationFilter filter;
+ bool fuseInputProducer;
+};
+
/// Patterns for workgroup level tiling. Workgroup tiling is done at the flow
/// level but we may have extra tiling for the reduction dimension. Therefore we
/// tile again without distributing.
static void populateTilingPatterns(RewritePatternSet &patterns,
- bool onlyReduction) {
- auto tileSizesFn = [onlyReduction](OpBuilder &builder,
- Operation *op) -> SmallVector<Value, 4> {
- auto interfaceOp = cast<PartitionableLoopsInterface>(*op);
- auto partitionedLoops =
- interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
- SmallVector<Value, 4> tileSizes = getTileSizes(builder, op, 0);
- if (onlyReduction) {
- auto zero = builder.create<arith::ConstantIndexOp>(op->getLoc(), 0);
- for (unsigned depth : partitionedLoops) {
- if (depth < tileSizes.size()) {
- tileSizes[depth] = zero;
- }
- }
- }
- return tileSizes;
- };
-
- auto tilingOptions = linalg::LinalgTilingOptions()
- .setLoopType(linalg::LinalgTilingLoopType::Loops)
- .setTileSizeComputationFunction(tileSizesFn);
+ bool fuseInputProducer) {
MLIRContext *context = patterns.getContext();
IREE::LinalgExt::LinalgTransformationFilter filter(
@@ -63,19 +152,19 @@
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getWorkgroupKTiledMarker()));
filter.setMatchByDefault();
- TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp, linalg::GenericOp,
- linalg::Conv2DNhwcHwcfOp,
- linalg::Conv2DNchwFchwOp>::insert(patterns, tilingOptions,
- filter);
+
+ patterns.add<TileConsumerAndFuseInputProducer>(context, filter,
+ fuseInputProducer);
}
-LogicalResult tileToSerialLoops(func::FuncOp funcOp, bool onlyReduction) {
+LogicalResult tileReductionToSerialLoops(func::FuncOp funcOp,
+ bool fuseInputProducer) {
{
// Tile again at the workgroup level since redution dimension were
// ignored. Dimensions already tiled will be ignore since we tile to the
// same size.
RewritePatternSet wgTilingPatterns(funcOp.getContext());
- populateTilingPatterns(wgTilingPatterns, onlyReduction);
+ populateTilingPatterns(wgTilingPatterns, fuseInputProducer);
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(wgTilingPatterns)))) {
return failure();
@@ -233,9 +322,7 @@
// Tile to serial loops to the wg tile size to handle reductions and other
// dimension that have not been distributed.
- if (failed(tileToSerialLoops(funcOp, /*onlyReduction=*/true))) {
- return signalPassFailure();
- }
+ if (failed(tileReductionToSerialLoops(funcOp))) return signalPassFailure();
LLVM_DEBUG({
llvm::dbgs() << "--- After tile reductions:";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
index 4cf4ebb..c8000b3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
@@ -75,12 +75,12 @@
// test corner case where the promoted value has multiple uses.
// CHECK-LABEL: func.func @matmul_multi_uses
-// CHECK: %[[C:.*]] = flow.dispatch.tensor.load
-// CHECK: %[[A:.*]] = flow.dispatch.tensor.load
-// CHECK: %[[B:.*]] = flow.dispatch.tensor.load
+// CHECK: %[[C:.*]] = flow.dispatch.tensor.load {{.+}} -> tensor<32x128xf32>
+// CHECK: %[[A:.*]] = flow.dispatch.tensor.load {{.+}} -> tensor<32x1024xf32>
+// CHECK: %[[B:.*]] = flow.dispatch.tensor.load {{.+}} -> tensor<1024x128xf32>
// CHECK: %[[PA:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {bufferization.escape = [false]} : tensor<32x1024xf32>
// CHECK: %[[PB:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {bufferization.escape = [false]} : tensor<1024x128xf32>
-// CHECK: %[[M:.*]] = linalg.matmul {{.*}} ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%{{.*}} : tensor<32x128xf32>) -> tensor<32x128xf32>
+// CHECK: %[[M:.*]] = linalg.matmul ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%{{.*}} : tensor<32x128xf32>) -> tensor<32x128xf32>
// CHECK: "some_use"(%[[A]]) : (tensor<32x1024xf32>) -> ()
// -----
@@ -129,3 +129,68 @@
// CHECK-LABEL: func.func @matmul_33x33x903168_f32
// CHECK-NOT: bufferization.alloc_tensor()
+// -----
+
+func.func @weight_dequant_matmul() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x128x2048xi4>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x2048xf32>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x2048xi4>>
+ %3 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>>
+ %4 = hal.interface.binding.subspan set(0) binding(5) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %6 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+ %7 = flow.dispatch.tensor.load %4, offsets = [%5, %6], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>> -> tensor<32x128xf32>
+ %8 = flow.dispatch.tensor.load %3, offsets = [%5, 0, 0], sizes = [32, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>> -> tensor<32x86x128xf32>
+ %9 = flow.dispatch.tensor.load %0, offsets = [0, 0, %6], sizes = [86, 128, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<86x128x2048xi4>> -> tensor<86x128x128xi4>
+ %10 = flow.dispatch.tensor.load %1, offsets = [0, %6], sizes = [86, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<86x2048xf32>> -> tensor<86x128xf32>
+ %11 = flow.dispatch.tensor.load %2, offsets = [0, %6], sizes = [86, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<86x2048xi4>> -> tensor<86x128xi4>
+ %12 = tensor.empty() : tensor<86x128x128xf32>
+ %13 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } ins(%9, %10, %11 : tensor<86x128x128xi4>, tensor<86x128xf32>, tensor<86x128xi4>) outs(%12 : tensor<86x128x128xf32>) {
+ ^bb0(%in: i4, %in_0: f32, %in_1: i4, %out: f32):
+ %16 = arith.extsi %in : i4 to i32
+ %17 = arith.extsi %in_1 : i4 to i32
+ %18 = arith.subi %16, %17 : i32
+ %19 = arith.sitofp %18 : i32 to f32
+ %20 = arith.mulf %19, %in_0 : f32
+ linalg.yield %20 : f32
+ } -> tensor<86x128x128xf32>
+ %14 = linalg.fill ins(%cst : f32) outs(%7 : tensor<32x128xf32>) -> tensor<32x128xf32>
+ %15 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+ } ins(%8, %13 : tensor<32x86x128xf32>, tensor<86x128x128xf32>) outs(%14 : tensor<32x128xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 128, 1, 32]]>} {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %16 = arith.mulf %in, %in_0 : f32
+ %17 = arith.addf %out, %16 : f32
+ linalg.yield %17 : f32
+ } -> tensor<32x128xf32>
+ flow.dispatch.tensor.store %15, %4, offsets = [%5, %6], sizes = [32, 128], strides = [1, 1] : tensor<32x128xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+ return
+}
+
+// CHECK-LABEL: func.func @weight_dequant_matmul()
+// CHECK: %[[LHS_LD:.+]] = flow.dispatch.tensor.load {{.+}} : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>> -> tensor<32x86x128xf32>
+// Check that the linalg.fill producing the matmul init value is not fused into the serial loops.
+// CHECK: %[[FILL:.+]] = linalg.fill
+// Check that two serial loops are materialized for reductions.
+// CHECK: scf.for %{{.+}} = %c0 to %c86 step %c1 iter_args(%[[ARG1:.+]] = %[[FILL]]) -> (tensor<32x128xf32>)
+// CHECK: scf.for %{{.+}} = %c0 to %c128 step %c32 iter_args(%[[ARG2:.+]] = %[[ARG1]]) -> (tensor<32x128xf32>)
+// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS_LD]]
+// Check that we have a bufferization.alloc_tensor() for in-place bufferization later.
+// CHECK: %[[RHS_ALLOC:.+]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<1x32x128xf32>
+// Check that the weight dequant linalg.generic is fused inside the serial loops.
+// CHECK: %[[RHS:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[RHS_ALLOC]] : tensor<1x32x128xf32>)
+// CHECK: %[[LHS_ALLOC:.+]] = bufferization.alloc_tensor() copy(%[[LHS_SLICE]]) {bufferization.escape = [false]} : tensor<32x1x32xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[LHS_ALLOC]], %[[RHS]] : tensor<32x1x32xf32>, tensor<1x32x128xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<32x128xf32>)
+// CHECK: scf.yield
+// CHECK: scf.yield
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index ee00f42..6637fda 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -40,8 +41,7 @@
/// level but we may have extra tiling for the reduction dimension. Therefore we
/// tile again without distributing.
static void populateTilingReductionPatterns(RewritePatternSet &patterns) {
- auto tileSizesFn = [&](OpBuilder &builder,
- Operation *op) -> SmallVector<Value, 4> {
+ auto tileSizesFn = [](OpBuilder &builder, Operation *op) {
auto interfaceOp = cast<PartitionableLoopsInterface>(*op);
auto partitionedLoops =
interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
@@ -65,8 +65,38 @@
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getWorkgroupKTiledMarker()));
filter.setMatchByDefault();
- TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp,
- linalg::GenericOp>::insert(patterns, tilingOptions, filter);
+ TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp, linalg::GenericOp,
+ linalg::Conv2DNhwcHwcfOp,
+ linalg::Conv2DNchwFchwOp>::insert(patterns, tilingOptions,
+ filter);
+}
+
+static LogicalResult tileToSerialLoops(func::FuncOp funcOp) {
+ {
+ // Tile again at the workgroup level since reduction dimensions were
+ // ignored. Dimensions already tiled will be ignored since we tile to the
+ // same size.
+ RewritePatternSet wgTilingPatterns(funcOp.getContext());
+ populateTilingReductionPatterns(wgTilingPatterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(wgTilingPatterns)))) {
+ return failure();
+ }
+ }
+
+ {
+ RewritePatternSet wgTilingCanonicalizationPatterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(funcOp.getContext());
+ populateAffineMinSCFCanonicalizationPattern(
+ wgTilingCanonicalizationPatterns);
+ scf::populateSCFForLoopCanonicalizationPatterns(
+ wgTilingCanonicalizationPatterns);
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(wgTilingCanonicalizationPatterns)))) {
+ return failure();
+ }
+ return success();
+ }
}
/// Return the tile size associated to one thread or warp based on the number of
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index ebe52b6..bba0dc4 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -41,6 +41,7 @@
"emulate_i64.mlir",
"erase_storage_buffer_static_shape.mlir",
"illegal_configuration.mlir",
+ "lowering_matmul_fusion.mlir",
"lowering_matmul_promotion.mlir",
"lowering_reduction.mlir",
"map_memref_storage_class.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 20d6e8f..026cd5d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -37,6 +37,7 @@
"emulate_i64.mlir"
"erase_storage_buffer_static_shape.mlir"
"illegal_configuration.mlir"
+ "lowering_matmul_fusion.mlir"
"lowering_matmul_promotion.mlir"
"lowering_reduction.mlir"
"map_memref_storage_class.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
new file mode 100644
index 0000000..09d8f0f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
@@ -0,0 +1,118 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))' %s | FileCheck %s
+
+#compilation = #iree_codegen.compilation_info<
+ lowering_config = <tile_sizes = [[32, 128, 1, 32]]>,
+ translation_info = <SPIRVMatmulPromoteVectorize pipeline_depth = 1>,
+ workgroup_size = [32, 8, 1]>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>,
+ #hal.descriptor_set.binding<3, storage_buffer>,
+ #hal.descriptor_set.binding<4, storage_buffer>
+ ]>
+]>
+
+hal.executable @matmul_i4_quant_weight {
+ hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<
+ max_compute_shared_memory_size = 49152,
+ max_compute_workgroup_invocations = 1024,
+ max_compute_workgroup_size = [65535, 65535, 65535],
+ subgroup_size = 32>>}> {
+ hal.executable.export public @matmul_i4_quant_weight ordinal(0) layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @matmul_i4_quant_weight() {
+ %c32 = arith.constant 32 : index
+ %c128 = arith.constant 128 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x128x2048xi4>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x2048xf32>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x2048xi4>>
+ %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>>
+ %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+ %6 = flow.dispatch.tensor.load %3, offsets = [%5, 0, 0], sizes = [%c32, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>> -> tensor<?x86x128xf32>
+ %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+ %8 = flow.dispatch.tensor.load %0, offsets = [0, 0, %7], sizes = [86, 128, %c128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<86x128x2048xi4>> -> tensor<86x128x?xi4>
+ %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+ %10 = flow.dispatch.tensor.load %1, offsets = [0, %9], sizes = [86, %c128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<86x2048xf32>> -> tensor<86x?xf32>
+ %11 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+ %12 = flow.dispatch.tensor.load %2, offsets = [0, %11], sizes = [86, %c128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<86x2048xi4>> -> tensor<86x?xi4>
+ %13 = tensor.empty() : tensor<86x128x128xf32>
+ %cast = tensor.cast %8 : tensor<86x128x?xi4> to tensor<86x128x128xi4>
+ %cast_0 = tensor.cast %10 : tensor<86x?xf32> to tensor<86x128xf32>
+ %cast_1 = tensor.cast %12 : tensor<86x?xi4> to tensor<86x128xi4>
+ %14 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } ins(%cast, %cast_0, %cast_1 : tensor<86x128x128xi4>, tensor<86x128xf32>, tensor<86x128xi4>) outs(%13 : tensor<86x128x128xf32>) {
+ ^bb0(%in: i4, %in_4: f32, %in_5: i4, %out: f32):
+ %20 = arith.extsi %in : i4 to i32
+ %21 = arith.extsi %in_5 : i4 to i32
+ %22 = arith.subi %20, %21 : i32
+ %23 = arith.sitofp %22 : i32 to f32
+ %24 = arith.mulf %23, %in_4 : f32
+ linalg.yield %24 : f32
+ } -> tensor<86x128x128xf32>
+ %15 = tensor.empty() : tensor<32x128xf32>
+ %16 = linalg.fill ins(%cst : f32) outs(%15 : tensor<32x128xf32>) -> tensor<32x128xf32>
+ %cast_2 = tensor.cast %6 : tensor<?x86x128xf32> to tensor<32x86x128xf32>
+ %17 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+ } ins(%cast_2, %14 : tensor<32x86x128xf32>, tensor<86x128x128xf32>) outs(%16 : tensor<32x128xf32>) attrs = {compilation_info = #compilation} {
+ ^bb0(%in: f32, %in_4: f32, %out: f32):
+ %20 = arith.mulf %in, %in_4 : f32
+ %21 = arith.addf %out, %20 : f32
+ linalg.yield %21 : f32
+ } -> tensor<32x128xf32>
+ %cast_3 = tensor.cast %17 : tensor<32x128xf32> to tensor<?x?xf32>
+ %18 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+ %19 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+ flow.dispatch.tensor.store %cast_3, %4, offsets = [%18, %19], sizes = [%c32, %c128], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: func.func @matmul_i4_quant_weight()
+// CHECK: %[[A_ALLOC:.+]] = memref.alloc() : memref<32x1x36xf32, #gpu.address_space<workgroup>>
+// CHECK: %[[B_ALLOC:.+]] = memref.alloc() : memref<1x32x132xf32, #gpu.address_space<workgroup>>
+// CHECK: %[[WEIGHT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK: %[[SCALE_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK: %[[ZP_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK: scf.for %arg0 = %c0 to %c86 step %c1 iter_args({{.+}}) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
+// CHECK: %[[SCALE:.+]] = vector.transfer_read %[[SCALE_BINDING]]
+// CHECK: %[[ZP:.+]] = vector.transfer_read %[[ZP_BINDING]]
+// CHECK: %[[ZP_EXT:.+]] = arith.extsi %[[ZP]] : vector<4xi4> to vector<4xi32>
+// CHECK: scf.for %arg5 = %c0 to %c96 step %c32 iter_args({{.+}}) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
+// CHECK-COUNT-4: vector.transfer_read %[[WEIGHT_BINDING]]
+// CHECK-COUNT-4: arith.extsi %{{.+}} : vector<4xi4> to vector<4xi32>
+// CHECK-COUNT-4: arith.subi %{{.+}}, %[[ZP_EXT]] : vector<4xi32>
+// CHECK-COUNT-4: arith.sitofp %{{.+}} : vector<4xi32> to vector<4xf32>
+// CHECK-COUNT-4: arith.mulf %{{.+}}, %[[SCALE]] : vector<4xf32>
+// CHECK-COUNT-4: vector.transfer_write %{{.+}}, %[[B_ALLOC]]
+// CHECK: gpu.barrier
+// CHECK: vector.transfer_write %{{.+}}, %[[A_ALLOC]]
+// CHECK: gpu.barrier
+// CHECK-COUNT-32: vector.transfer_read %[[A_ALLOC]]
+// CHECK-COUNT-32: vector.transfer_read %[[B_ALLOC]]
+// CHECK-COUNT-128: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
+// CHECK-COUNT-2: scf.yield