[GPU] Add a pass to convert accumulating GEMMs to GEMMs (#19587)

Converts dispatches with accumulating GEMMs that are doing in place
read/write to GEMM + elementwise add.
This is needed for the TileAndFuse path until we find a more permanent
fix for https://github.com/iree-org/iree/issues/19546

---------

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index f95b0fa..8582cf9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -94,6 +94,7 @@
         "CleanupBufferAllocViewPass.cpp",
         "ConcretizePadResultShape.cpp",
         "ConfigTrackingCanonicalizer.cpp",
+        "ConvertAccGEMMToGEMMPass.cpp",
         "ConvertBf16ArithToF32.cpp",
         "ConvertBf16ToUInt16Buffers.cpp",
         "ConvertToDestinationPassingStylePass.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index af3c557..1dd9f91 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -86,6 +86,7 @@
     "CleanupBufferAllocViewPass.cpp"
     "ConcretizePadResultShape.cpp"
     "ConfigTrackingCanonicalizer.cpp"
+    "ConvertAccGEMMToGEMMPass.cpp"
     "ConvertBf16ArithToF32.cpp"
     "ConvertBf16ToUInt16Buffers.cpp"
     "ConvertToDestinationPassingStylePass.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
new file mode 100644
index 0000000..40e17d9
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
@@ -0,0 +1,125 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- ConvertAccGEMMtoGEMMpass.cpp ----------------------------------===//
+//
+// Converts Accumulating GEMM to GEMM + elementwise add.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_CONVERTACCGEMMTOGEMMPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+namespace {
+
+struct ConvertAccGEMMtoGEMM final
+    : OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (!linalg::isaContractionOpInterface(linalgOp) &&
+        !isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
+      return failure();
+    }
+    if (!linalgOp.hasPureTensorSemantics())
+      return failure();
+
+    // Nothing to do if the output tensor operand is already a fill op.
+    SmallVector<OpOperand *> outputOperands;
+    if (!linalgOp.hasPureBufferSemantics()) {
+      outputOperands = llvm::to_vector(
+          llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
+    }
+
+    Value outputOperand = outputOperands.front()->get();
+
+    auto outsDefiningOp =
+        outputOperand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+    if (!outsDefiningOp) {
+      // If not DispatchTensorLoadOp then do nothing.
+      return failure();
+    }
+    auto outputType = cast<RankedTensorType>(outputOperand.getType());
+    if (!outputType.getElementType().isIntOrFloat())
+      return failure();
+    auto elementType = outputType.getElementType();
+
+    Location loc = linalgOp.getLoc();
+
+    // Check if the output tensor access is a projected permutation
+    if (!linalgOp.getMatchingIndexingMap(outputOperands.front())
+             .isProjectedPermutation()) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "Output indexing map must be a projected permutation.");
+    }
+
+    int64_t outputRank = outputType.getRank();
+    SmallVector<utils::IteratorType> iterators(outputRank,
+                                               utils::IteratorType::parallel);
+    SmallVector<AffineMap> maps(3, rewriter.getMultiDimIdentityMap(outputRank));
+
+    // Create a zero tensor as the new output tensor operand to the Linalg
+    // contraction op.
+    SmallVector<OpFoldResult> mixedSizes =
+        tensor::getMixedSizes(rewriter, loc, outputOperand);
+    auto initOp =
+        rewriter.create<tensor::EmptyOp>(loc, mixedSizes, elementType);
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+    Value fill =
+        rewriter.create<linalg::FillOp>(loc, zero, initOp.getResult()).result();
+
+    // Update the contraction op to use the new zero tensor as output operand.
+    rewriter.modifyOpInPlace(linalgOp,
+                             [&]() { linalgOp.setDpsInitOperand(0, fill); });
+
+    // Create a generic op to add back the original output tensor operand.
+    rewriter.setInsertionPointAfter(linalgOp);
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, outputType, ValueRange{linalgOp->getResult(0), outputOperand},
+        fill, maps, iterators,
+        [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+          Value result;
+          if (llvm::isa<FloatType>(elementType)) {
+            result = b.create<arith::AddFOp>(nestedLoc, args[0], args[1]);
+          } else {
+            result = b.create<arith::AddIOp>(nestedLoc, args[0], args[1]);
+          }
+          b.create<linalg::YieldOp>(nestedLoc, result);
+        });
+    linalgOp->getResult(0).replaceAllUsesExcept(genericOp->getResult(0),
+                                                genericOp);
+    return success();
+  }
+};
+
+struct ConvertAccGEMMToGEMMPass final
+    : impl::ConvertAccGEMMToGEMMPassBase<ConvertAccGEMMToGEMMPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<ConvertAccGEMMtoGEMM>(&getContext());
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 5cc0d55..1854279 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -66,6 +66,11 @@
       "implements OffsetSizeAndStrideOpInterface.";
 }
 
+def ConvertAccGEMMToGEMMPass :
+    Pass<"iree-convert-accgemm-to-gemm", ""> {
+  let summary = "Convert accumulating GEMMs to GEMMs post dispatch creation.";
+}
+
 def ConvertBf16ArithToF32Pass : Pass<"iree-convert-bf16-arith-to-f32", ""> {
   let summary = "Convert bf16 arithmetic operations to f32";
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 45d93be..edbb5d8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -25,6 +25,7 @@
             "bubble_up_ordinal_ops.mlir",
             "bufferize_copy_only_dispatches.mlir",
             "canonicalize_interface_load_store.mlir",
+            "convert_accgemm_to_gemm.mlir",
             "convert_bf16_to_uint16_buffers.mlir",
             "convert_bf16_arith_to_f32.mlir",
             "convert_to_destination_passing_style.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 74566ef..e240940 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -21,6 +21,7 @@
     "bubble_up_ordinal_ops.mlir"
     "bufferize_copy_only_dispatches.mlir"
     "canonicalize_interface_load_store.mlir"
+    "convert_accgemm_to_gemm.mlir"
     "convert_bf16_arith_to_f32.mlir"
     "convert_bf16_to_uint16_buffers.mlir"
     "convert_to_destination_passing_style.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
new file mode 100644
index 0000000..07b8bfa
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
@@ -0,0 +1,82 @@
+// RUN: iree-opt --split-input-file --iree-convert-accgemm-to-gemm %s | FileCheck %s
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+func.func @accumulate_gemm(%1 : tensor<512x128xi8>, %2 : tensor<512x128xi8>) {
+  %c0 = arith.constant 0 : index
+  %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>>
+  %4 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> -> tensor<512x512xi32>
+  %5 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                     affine_map<(d0, d1, d2) -> (d1, d2)>,
+                     affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%1, %2 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%4 : tensor<512x512xi32>) {
+      ^bb0(%in: i8, %in_0: i8, %out: i32):
+        %6 = arith.extsi %in : i8 to i32
+        %7 = arith.extsi %in_0 : i8 to i32
+        %8 = arith.muli %6, %7 : i32
+        %9 = arith.addi %out, %8 : i32
+        linalg.yield %9 : i32
+      } -> tensor<512x512xi32>
+  flow.dispatch.tensor.store %5, %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<readwrite:tensor<512x512xi32>>
+  return
+}
+
+// CHECK-LABEL: func.func @accumulate_gemm
+//   CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+//   CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<512x512xi32>
+//       CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<512x512xi32>) -> tensor<512x512xi32>
+//       CHECK: %[[GEMM:.+]] = linalg.generic {{.*}} outs(%[[FILL]] : tensor<512x512xi32>) {
+//       CHECK: %[[ADD:.+]] = linalg.generic {{.+}} ins(%[[GEMM]]
+//       CHECK: flow.dispatch.tensor.store %[[ADD]]
+
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+func.func @acc_conv_nchw(%1 : tensor<1x64x58x58xf32>, %2 : tensor<64x64x3x3xf32>) {
+  %c0 = arith.constant 0 : index
+  %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>>
+  %4 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>> -> tensor<1x64x56x56xf32>
+  %5 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+      ins(%1, %2 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%4 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  flow.dispatch.tensor.store %5, %3, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : tensor<1x64x56x56xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x64x56x56xf32>>
+  return
+}
+
+// CHECK-LABEL: func.func @acc_conv_nchw
+//   CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x64x56x56xf32>
+//       CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[EMPTY]] : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+//       CHECK: %[[CONV:.+]] = linalg.conv_2d_nchw_fchw {{.*}} outs(%[[FILL]] : tensor<1x64x56x56xf32>)
+//       CHECK: %[[ADD:.+]] = linalg.generic {{.+}} ins(%[[CONV]]
+//       CHECK: flow.dispatch.tensor.store %[[ADD]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+
+func.func @nonacc_gemm(%1 : tensor<512x128xi8>, %2 : tensor<512x128xi8>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<512x512xi32>>
+  %empty = tensor.empty() : tensor<512x512xi32>
+  %fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<512x512xi32>) -> tensor<512x512xi32>
+  %5 = linalg.matmul_transpose_b
+    ins(%1, %2 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%fill : tensor<512x512xi32>) -> tensor<512x512xi32>
+  flow.dispatch.tensor.store %5, %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<writeonly:tensor<512x512xi32>>
+  return
+}
+
+// CHECK-LABEL: func.func @nonacc_gemm
+//       CHECK: linalg.matmul_transpose_b
+//   CHECK-NOT: linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f8399d3..d0a269e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -347,7 +347,9 @@
   if (pipelineOptions.useIgemmConvolution) {
     funcPassManager.addPass(createConvolutionToIGEMMPass());
   }
-
+  // TODO (nirvedhmeshram) : Can remove this pass after
+  // https://github.com/iree-org/iree/issues/19546 is fixed.
+  funcPassManager.addPass(createConvertAccGEMMToGEMMPass());
   tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
                                /*convertToDpsOptions=*/std::nullopt);