[GPU] Add a pass to convert accumulating GEMMs to GEMMs (#19587)
Converts accumulating GEMMs that read and write their accumulator in place
within a dispatch into a GEMM followed by an elementwise add.
This is needed for the TileAndFuse path until we find a more permanent
fix for https://github.com/iree-org/iree/issues/19546.
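The pass gives the GEMM a fresh zero-filled init tensor and adds the loaded
accumulator back afterwards with an elementwise linalg.generic. A minimal
before/after sketch of the rewrite (shapes, value names, and the
single-binding layout are illustrative, not copied from the lit tests):

```mlir
#layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>

// Before: the GEMM accumulates directly into the readwrite binding.
func.func @acc_gemm_before(%lhs : tensor<64x32xf32>, %rhs : tensor<32x64xf32>) {
  %c0 = arith.constant 0 : index
  %binding = hal.interface.binding.subspan layout(#layout) binding(0) alignment(64) offset(%c0)
      : !flow.dispatch.tensor<readwrite:tensor<64x64xf32>>
  %acc = flow.dispatch.tensor.load %binding, offsets = [0, 0], sizes = [64, 64], strides = [1, 1]
      : !flow.dispatch.tensor<readwrite:tensor<64x64xf32>> -> tensor<64x64xf32>
  %gemm = linalg.matmul ins(%lhs, %rhs : tensor<64x32xf32>, tensor<32x64xf32>)
      outs(%acc : tensor<64x64xf32>) -> tensor<64x64xf32>
  flow.dispatch.tensor.store %gemm, %binding, offsets = [0, 0], sizes = [64, 64], strides = [1, 1]
      : tensor<64x64xf32> -> !flow.dispatch.tensor<readwrite:tensor<64x64xf32>>
  return
}

// After: the GEMM writes into a zero-filled init and the loaded accumulator
// is added back with an elementwise linalg.generic before the store.
func.func @acc_gemm_after(%lhs : tensor<64x32xf32>, %rhs : tensor<32x64xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %binding = hal.interface.binding.subspan layout(#layout) binding(0) alignment(64) offset(%c0)
      : !flow.dispatch.tensor<readwrite:tensor<64x64xf32>>
  %acc = flow.dispatch.tensor.load %binding, offsets = [0, 0], sizes = [64, 64], strides = [1, 1]
      : !flow.dispatch.tensor<readwrite:tensor<64x64xf32>> -> tensor<64x64xf32>
  %empty = tensor.empty() : tensor<64x64xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<64x64xf32>) -> tensor<64x64xf32>
  %gemm = linalg.matmul ins(%lhs, %rhs : tensor<64x32xf32>, tensor<32x64xf32>)
      outs(%fill : tensor<64x64xf32>) -> tensor<64x64xf32>
  %sum = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%gemm, %acc : tensor<64x64xf32>, tensor<64x64xf32>)
      outs(%fill : tensor<64x64xf32>) {
  ^bb0(%a : f32, %b : f32, %out : f32):
    %0 = arith.addf %a, %b : f32
    linalg.yield %0 : f32
  } -> tensor<64x64xf32>
  flow.dispatch.tensor.store %sum, %binding, offsets = [0, 0], sizes = [64, 64], strides = [1, 1]
      : tensor<64x64xf32> -> !flow.dispatch.tensor<readwrite:tensor<64x64xf32>>
  return
}
```

The binding is still loaded and stored in place; only the init of the GEMM
changes, so the dispatch no longer accumulates into the loaded tensor.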
---------
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index f95b0fa..8582cf9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -94,6 +94,7 @@
"CleanupBufferAllocViewPass.cpp",
"ConcretizePadResultShape.cpp",
"ConfigTrackingCanonicalizer.cpp",
+ "ConvertAccGEMMToGEMMPass.cpp",
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index af3c557..1dd9f91 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -86,6 +86,7 @@
"CleanupBufferAllocViewPass.cpp"
"ConcretizePadResultShape.cpp"
"ConfigTrackingCanonicalizer.cpp"
+ "ConvertAccGEMMToGEMMPass.cpp"
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
new file mode 100644
index 0000000..40e17d9
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
@@ -0,0 +1,125 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- ConvertAccGEMMToGEMMPass.cpp ----------------------------------===//
+//
+// Converts an accumulating GEMM (or convolution) to a GEMM + elementwise add.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_CONVERTACCGEMMTOGEMMPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+namespace {
+
+struct ConvertAccGEMMtoGEMM final
+ : OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+ PatternRewriter &rewriter) const override {
+ if (!linalg::isaContractionOpInterface(linalgOp) &&
+ !isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
+ return failure();
+ }
+ if (!linalgOp.hasPureTensorSemantics())
+ return failure();
+
+    // Collect the output (DPS init) operands of the op.
+ SmallVector<OpOperand *> outputOperands;
+ if (!linalgOp.hasPureBufferSemantics()) {
+ outputOperands = llvm::to_vector(
+ llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
+ }
+
+ Value outputOperand = outputOperands.front()->get();
+
+ auto outsDefiningOp =
+ outputOperand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ if (!outsDefiningOp) {
+      // Bail out if the accumulator is not loaded from a dispatch binding.
+ return failure();
+ }
+ auto outputType = cast<RankedTensorType>(outputOperand.getType());
+ if (!outputType.getElementType().isIntOrFloat())
+ return failure();
+ auto elementType = outputType.getElementType();
+
+ Location loc = linalgOp.getLoc();
+
+    // The output tensor access must be a projected permutation.
+ if (!linalgOp.getMatchingIndexingMap(outputOperands.front())
+ .isProjectedPermutation()) {
+ return rewriter.notifyMatchFailure(
+ linalgOp, "Output indexing map must be a projected permutation.");
+ }
+
+ int64_t outputRank = outputType.getRank();
+ SmallVector<utils::IteratorType> iterators(outputRank,
+ utils::IteratorType::parallel);
+ SmallVector<AffineMap> maps(3, rewriter.getMultiDimIdentityMap(outputRank));
+
+    // Create a zero-filled tensor to use as the new output operand of the
+    // Linalg op.
+ SmallVector<OpFoldResult> mixedSizes =
+ tensor::getMixedSizes(rewriter, loc, outputOperand);
+ auto initOp =
+ rewriter.create<tensor::EmptyOp>(loc, mixedSizes, elementType);
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(elementType));
+ Value fill =
+ rewriter.create<linalg::FillOp>(loc, zero, initOp.getResult()).result();
+
+    // Update the op to use the new zero-filled tensor as its output operand.
+ rewriter.modifyOpInPlace(linalgOp,
+ [&]() { linalgOp.setDpsInitOperand(0, fill); });
+
+ // Create a generic op to add back the original output tensor operand.
+ rewriter.setInsertionPointAfter(linalgOp);
+ auto genericOp = rewriter.create<linalg::GenericOp>(
+ loc, outputType, ValueRange{linalgOp->getResult(0), outputOperand},
+ fill, maps, iterators,
+ [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+ Value result;
+ if (llvm::isa<FloatType>(elementType)) {
+ result = b.create<arith::AddFOp>(nestedLoc, args[0], args[1]);
+ } else {
+ result = b.create<arith::AddIOp>(nestedLoc, args[0], args[1]);
+ }
+ b.create<linalg::YieldOp>(nestedLoc, result);
+ });
+ linalgOp->getResult(0).replaceAllUsesExcept(genericOp->getResult(0),
+ genericOp);
+ return success();
+ }
+};
+
+struct ConvertAccGEMMToGEMMPass final
+ : impl::ConvertAccGEMMToGEMMPassBase<ConvertAccGEMMToGEMMPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ patterns.add<ConvertAccGEMMtoGEMM>(&getContext());
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
+ }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 5cc0d55..1854279 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -66,6 +66,11 @@
"implements OffsetSizeAndStrideOpInterface.";
}
+def ConvertAccGEMMToGEMMPass :
+ Pass<"iree-convert-accgemm-to-gemm", ""> {
+  let summary = "Convert accumulating GEMMs to GEMMs after dispatch creation.";
+}
+
def ConvertBf16ArithToF32Pass : Pass<"iree-convert-bf16-arith-to-f32", ""> {
let summary = "Convert bf16 arithmetic operations to f32";
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 45d93be..edbb5d8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -25,6 +25,7 @@
"bubble_up_ordinal_ops.mlir",
"bufferize_copy_only_dispatches.mlir",
"canonicalize_interface_load_store.mlir",
+ "convert_accgemm_to_gemm.mlir",
"convert_bf16_to_uint16_buffers.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 74566ef..e240940 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -21,6 +21,7 @@
"bubble_up_ordinal_ops.mlir"
"bufferize_copy_only_dispatches.mlir"
"canonicalize_interface_load_store.mlir"
+ "convert_accgemm_to_gemm.mlir"
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
new file mode 100644
index 0000000..07b8bfa
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
@@ -0,0 +1,82 @@
+// RUN: iree-opt --split-input-file --iree-convert-accgemm-to-gemm %s | FileCheck %s
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>
+]>
+
+func.func @accumulate_gemm(%1 : tensor<512x128xi8>, %2 : tensor<512x128xi8>) {
+ %c0 = arith.constant 0 : index
+ %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>>
+ %4 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> -> tensor<512x512xi32>
+ %5 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%1, %2 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%4 : tensor<512x512xi32>) {
+ ^bb0(%in: i8, %in_0: i8, %out: i32):
+ %6 = arith.extsi %in : i8 to i32
+ %7 = arith.extsi %in_0 : i8 to i32
+ %8 = arith.muli %6, %7 : i32
+ %9 = arith.addi %out, %8 : i32
+ linalg.yield %9 : i32
+ } -> tensor<512x512xi32>
+ flow.dispatch.tensor.store %5, %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<readwrite:tensor<512x512xi32>>
+ return
+}
+
+// CHECK-LABEL: func.func @accumulate_gemm
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<512x512xi32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<512x512xi32>) -> tensor<512x512xi32>
+// CHECK: %[[GEMM:.+]] = linalg.generic {{.*}} outs(%[[FILL]] : tensor<512x512xi32>) {
+// CHECK: %[[ADD:.+]] = linalg.generic {{.+}} ins(%[[GEMM]]
+// CHECK: flow.dispatch.tensor.store %[[ADD]]
+
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>
+]>
+
+func.func @acc_conv_nchw(%1 : tensor<1x64x58x58xf32>, %2 : tensor<64x64x3x3xf32>) {
+ %c0 = arith.constant 0 : index
+ %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>>
+ %4 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>> -> tensor<1x64x56x56xf32>
+ %5 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ ins(%1, %2 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%4 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+ flow.dispatch.tensor.store %5, %3, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : tensor<1x64x56x56xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>>
+ return
+}
+
+// CHECK-LABEL: func.func @acc_conv_nchw
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x64x56x56xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[EMPTY]] : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nchw_fchw {{.*}} outs(%[[FILL]] : tensor<1x64x56x56xf32>)
+// CHECK: %[[ADD:.+]] = linalg.generic {{.+}} ins(%[[CONV]]
+// CHECK: flow.dispatch.tensor.store %[[ADD]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>
+]>
+
+
+func.func @nonacc_gemm(%1 : tensor<512x128xi8>, %2 : tensor<512x128xi8>) {
+ %c0_i32 = arith.constant 0 : i32
+ %c0 = arith.constant 0 : index
+ %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<512x512xi32>>
+ %empty = tensor.empty() : tensor<512x512xi32>
+ %fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<512x512xi32>) -> tensor<512x512xi32>
+ %5 = linalg.matmul_transpose_b
+ ins(%1, %2 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%fill : tensor<512x512xi32>) -> tensor<512x512xi32>
+ flow.dispatch.tensor.store %5, %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<writeonly:tensor<512x512xi32>>
+ return
+}
+
+// CHECK-LABEL: func.func @nonacc_gemm
+// CHECK: linalg.matmul_transpose_b
+// CHECK-NOT: linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f8399d3..d0a269e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -347,7 +347,9 @@
if (pipelineOptions.useIgemmConvolution) {
funcPassManager.addPass(createConvolutionToIGEMMPass());
}
-
+  // TODO(nirvedhmeshram): Remove this pass once
+  // https://github.com/iree-org/iree/issues/19546 is fixed.
+ funcPassManager.addPass(createConvertAccGEMMToGEMMPass());
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
/*convertToDpsOptions=*/std::nullopt);