[gpu] Enable fusing input producers after tiling reduction loops (#13806)

This commit enables fusing input producers into the
serial loop after tiling the matmul K dimension.
This lets weight dequantization reuse the GPU shared
memory already promoted for the A/B matrix slices,
without introducing additional shared memory
allocations during bufferization. This is needed to
support int4 weight-quantized matmuls in LLMs.
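
For illustration, here is a rough sketch of the loop
structure this enables for a dequantize + matmul pair,
written in the same schematic IR notation as the pattern
documentation below (shapes, indexing maps, and most
operands are elided, so this is not verbatim IR from a
test):

  // Before: dequantization materializes the full weight
  // tensor outside the serial K loop; each iteration only
  // consumes a slice of its result.
  %b_full = linalg.generic ... ins(%weight_i4, %scale, %zp) outs(%empty)
  %result = scf.for %k = ... iter_args(%acc = %fill) {
    %a_slice = tensor.extract_slice %lhs ...
    %b_slice = tensor.extract_slice %b_full ...
    %mm = linalg.generic ... ins(%a_slice, %b_slice) outs(%acc)
    scf.yield %mm
  }

  // After (with fusion and shared memory promotion): the
  // dequantization producer is recomputed inside the serial
  // K loop and writes into the promoted allocation, so no
  // extra buffers are created during bufferization.
  %result = scf.for %k = ... iter_args(%acc = %fill) {
    %a_slice = tensor.extract_slice %lhs ...
    %w_slice = tensor.extract_slice %weight_i4 ...
    %b_slice = linalg.generic ... ins(%w_slice, %scale, %zp) outs(%alloc)
    %mm = linalg.generic ... ins(%a_slice, %b_slice) outs(%acc)
    scf.yield %mm
  }
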
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 3d5e870..72c6525 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -47,6 +47,7 @@
         "@llvm-project//mlir:AffineUtils",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:BufferizationDialect",
+        "@llvm-project//mlir:DestinationStyleOpInterface",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUTransformOps",
@@ -64,6 +65,7 @@
         "@llvm-project//mlir:SCFUtils",
         "@llvm-project//mlir:SideEffectInterfaces",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorDialect",
         "@llvm-project//mlir:VectorToSCF",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 518372f..a717af0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -39,6 +39,7 @@
     MLIRAffineUtils
     MLIRArithDialect
     MLIRBufferizationDialect
+    MLIRDestinationStyleOpInterface
     MLIRFuncDialect
     MLIRGPUDialect
     MLIRGPUTransformOps
@@ -56,6 +57,7 @@
     MLIRSCFUtils
     MLIRSideEffectInterfaces
     MLIRSupport
+    MLIRTensorDialect
     MLIRTransforms
     MLIRVectorDialect
     MLIRVectorToSCF
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h b/compiler/src/iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h
index 3eed3d5..fe66198 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CommonGPUPasses.h
@@ -45,8 +45,11 @@
     RewriterBase &rewriter, scf::ForOp forOp,
     PipeliningSchedulingStrategy startegy, bool peelEpilogue, int64_t depth);
 
-/// Tiles Linalg ops in the given `funcOp` to serial loops without distribution.
-LogicalResult tileToSerialLoops(func::FuncOp funcOp, bool onlyReduction = true);
+/// Tiles Linalg ops in the given `funcOp` along reduction dimensions to serial
+/// loops without distribution. If `fuseInputProducer` is true, input producers
+/// will be fused into the serial loop.
+LogicalResult tileReductionToSerialLoops(func::FuncOp funcOp,
+                                         bool fuseInputProducer = false);
 
 //===----------------------------------------------------------------------===//
 // Passes
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
index 37793fb..27fcf6b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
@@ -8,9 +8,12 @@
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 #define DEBUG_TYPE "iree-codegen-gpu-tensor-alloc"
@@ -89,6 +92,53 @@
 }
 
 namespace {
+/// Swaps bufferization.alloc_tensor with the copied linalg op result when the
+/// linalg op does not use the output initial value during calculation.
+///
+/// This converts the following IR:
+/// ```
+/// %linalg = linalg ... ins(...) outs(...)
+/// %val = bufferization.alloc_tensor() copy(%linalg)
+/// ```
+/// Into
+/// ```
+/// %alloc = bufferization.alloc_tensor()
+/// %val = linalg ... ins(...) outs(%alloc)
+/// ```
+struct SwapAllocTensorPattern final
+    : OpRewritePattern<bufferization::AllocTensorOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(bufferization::AllocTensorOp allocOp,
+                                PatternRewriter &rewriter) const override {
+    if (!allocOp.getCopy()) return failure();
+    auto linalgOp = allocOp.getCopy().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp) return failure();
+
+    // Make sure the linalg op does not use the initial value of the output we
+    // are copying; otherwise we cannot substitute a fresh allocation for it.
+    unsigned resultNumber = cast<OpResult>(allocOp.getCopy()).getResultNumber();
+    OpOperand *initOperand = linalgOp.getDpsInitOperand(resultNumber);
+    if (linalgOp.payloadUsesValueFromOperand(initOperand)) return failure();
+
+    rewriter.setInsertionPoint(linalgOp);
+    std::optional<Attribute> memorySpace = allocOp.getMemorySpace();
+    auto newAllocOp = rewriter.create<bufferization::AllocTensorOp>(
+        allocOp.getLoc(), allocOp.getType(), allocOp.getDynamicSizes(),
+        /*copy=*/Value(),
+        memorySpace ? cast<IntegerAttr>(*memorySpace) : IntegerAttr());
+    newAllocOp->setAttr(bufferization::BufferizationDialect::kEscapeAttrName,
+                        rewriter.getBoolArrayAttr({false}));
+    rewriter.updateRootInPlace(linalgOp, [&]() {
+      linalgOp->setOperand(linalgOp.getNumDpsInputs() + resultNumber,
+                           newAllocOp);
+    });
+    rewriter.replaceOp(allocOp, linalgOp->getResult(resultNumber));
+
+    return success();
+  }
+};
+
 struct GPUTensorAllocPass : public GPUTensorAllocBase<GPUTensorAllocPass> {
  private:
   GPUPromoteSharedMemPattern promoteSharedMemPattern =
@@ -104,10 +154,17 @@
     auto funcOp = getOperation();
 
     // Tile the reduction first to reduce the alloc size.
-    if (failed(tileToSerialLoops(funcOp))) {
+    if (failed(
+            tileReductionToSerialLoops(funcOp, /*fuseInputProducer=*/true))) {
       return signalPassFailure();
     }
 
+    LLVM_DEBUG({
+      llvm::dbgs() << "// --- After tiling to serial loops ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+
     SmallVector<Operation *> opsToPromote;
     funcOp.walk([&](Operation *op) {
       switch (promoteSharedMemPattern) {
@@ -152,6 +209,24 @@
           break;
       }
     }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "// --- After promotion ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+
+    // Move tensor allocations earlier and use them for linalg init operands
+    // when possible. This change cleans up the IR to avoid bufferization
+    // creating extra buffers in later stages.
+    {
+      MLIRContext *context = &getContext();
+      RewritePatternSet patterns(context);
+      patterns.add<SwapAllocTensorPattern>(context);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
   }
 };
 }  // namespace
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
index 822c117..8094e91 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp
@@ -21,7 +21,9 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns;
@@ -31,31 +33,118 @@
 namespace mlir {
 namespace iree_compiler {
 
+class TileConsumerAndFuseInputProducer final
+    : public OpInterfaceRewritePattern<TilingInterface> {
+ public:
+  TileConsumerAndFuseInputProducer(
+      MLIRContext *context, IREE::LinalgExt::LinalgTransformationFilter filter,
+      bool fuseInputProducer, PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        filter(std::move(filter)),
+        fuseInputProducer(fuseInputProducer) {}
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, op))) return failure();
+
+    // Make sure we have a PartitionableLoopsInterface op here and query the
+    // tile sizes from the partitionable loops.
+    auto plOp = dyn_cast<PartitionableLoopsInterface>(*op);
+    if (!plOp) return failure();
+    auto partitionedLoops = plOp.getPartitionableLoops(kNumMaxParallelDims);
+    SmallVector<int64_t, 4> tileSizes = getTileSizes(op, 0);
+    if (tileSizes.empty()) return failure();
+    // Mask out non-reduction dimensions.
+    for (unsigned depth : partitionedLoops) {
+      if (depth < tileSizes.size()) tileSizes[depth] = 0;
+    }
+
+    // Make sure we have a tile size for each dimension.
+    // TODO: This is currently needed for LLVMGPU, where we propagate the
+    // lowering configuration to all linalg ops. Some linalg ops may not have
+    // the same rank, e.g., a matmul's configuration is also attached to its
+    // producer linalg.fill op. This implicitly assumes that the leading
+    // dimensions of the different linalg ops match, which holds today but
+    // may not remain true in the long term.
+    tileSizes.resize(op.getLoopIteratorTypes().size());
+
+    if (llvm::all_of(tileSizes, [](int64_t s) { return s == 0; })) {
+      return failure();
+    }
+
+    // Tile the current op and fuse its immediate input operands.
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        tileConsumerAndFuseInputProducer(rewriter, op, tileSizes);
+    if (failed(tilingResult)) {
+      return rewriter.notifyMatchFailure(op, "failed to tile consumer");
+    }
+
+    // Replace the tiled op with replacements.
+    rewriter.replaceOp(op, tilingResult->replacements);
+    filter.replaceLinalgTransformationFilter(rewriter,
+                                             tilingResult->tiledOps.front());
+    return success();
+  }
+
+ private:
+  FailureOr<scf::SCFTilingResult> tileConsumerAndFuseInputProducer(
+      RewriterBase &rewriter, TilingInterface consumer,
+      ArrayRef<int64_t> tileSizes) const {
+    // First tile the current op as the consumer op.
+    auto tilingOptions = scf::SCFTilingOptions().setTileSizes(tileSizes);
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        tileUsingSCFForOp(rewriter, consumer, tilingOptions);
+    if (failed(tilingResult)) {
+      return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
+    }
+
+    if (!fuseInputProducer) return tilingResult;
+    // If no loops were generated, fusion is immaterial.
+    if (tilingResult->loops.empty()) return tilingResult;
+
+    // Collect immediate input operands that are fusable into the tiled loop.
+    // These show up as tensor.extract_slice ops taking slices of untiled
+    // producer results.
+    // Note that this excludes init operands for correctness; input operands
+    // are fine to fuse, though at the cost of recomputation.
+    SmallVector<tensor::ExtractSliceOp> candidates;
+    assert(tilingResult->tiledOps.size() == 1);
+    Operation *tiledOp = tilingResult->tiledOps.front();
+    auto dsOp = dyn_cast<DestinationStyleOpInterface>(tiledOp);
+    if (!dsOp) return tilingResult;
+    for (OpOperand *operand : dsOp.getDpsInputOperands()) {
+      auto sliceOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+      if (!sliceOp) continue;
+      auto linalgOp = sliceOp.getSource().getDefiningOp<linalg::LinalgOp>();
+      if (!linalgOp) continue;
+      // Restrict to fully parallel linalg ops for now for simplicity.
+      auto isParallel = [](utils::IteratorType it) {
+        return linalg::isParallelIterator(it);
+      };
+      if (llvm::all_of(linalgOp.getIteratorTypesArray(), isParallel)) {
+        candidates.push_back(sliceOp);
+      }
+    }
+
+    // Fuse the candidate immediate operands into the tiled loop.
+    OpBuilder::InsertionGuard guard(rewriter);
+    while (!candidates.empty()) {
+      tensor::ExtractSliceOp sliceOp = candidates.back();
+      candidates.pop_back();
+      tileAndFuseProducerOfSlice(rewriter, sliceOp, tilingResult->loops);
+    }
+    return tilingResult;
+  }
+
+  IREE::LinalgExt::LinalgTransformationFilter filter;
+  bool fuseInputProducer;
+};
+
 /// Patterns for workgroup level tiling. Workgroup tiling is done at the flow
 /// level but we may have extra tiling for the reduction dimension. Therefore we
 /// tile again without distributing.
 static void populateTilingPatterns(RewritePatternSet &patterns,
-                                   bool onlyReduction) {
-  auto tileSizesFn = [onlyReduction](OpBuilder &builder,
-                                     Operation *op) -> SmallVector<Value, 4> {
-    auto interfaceOp = cast<PartitionableLoopsInterface>(*op);
-    auto partitionedLoops =
-        interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-    SmallVector<Value, 4> tileSizes = getTileSizes(builder, op, 0);
-    if (onlyReduction) {
-      auto zero = builder.create<arith::ConstantIndexOp>(op->getLoc(), 0);
-      for (unsigned depth : partitionedLoops) {
-        if (depth < tileSizes.size()) {
-          tileSizes[depth] = zero;
-        }
-      }
-    }
-    return tileSizes;
-  };
-
-  auto tilingOptions = linalg::LinalgTilingOptions()
-                           .setLoopType(linalg::LinalgTilingLoopType::Loops)
-                           .setTileSizeComputationFunction(tileSizesFn);
+                                   bool fuseInputProducer) {
   MLIRContext *context = patterns.getContext();
 
   IREE::LinalgExt::LinalgTransformationFilter filter(
@@ -63,19 +152,19 @@
           StringAttr::get(context, getWorkgroupMemoryMarker())},
       StringAttr::get(context, getWorkgroupKTiledMarker()));
   filter.setMatchByDefault();
-  TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp, linalg::GenericOp,
-                 linalg::Conv2DNhwcHwcfOp,
-                 linalg::Conv2DNchwFchwOp>::insert(patterns, tilingOptions,
-                                                   filter);
+
+  patterns.add<TileConsumerAndFuseInputProducer>(context, filter,
+                                                 fuseInputProducer);
 }
 
-LogicalResult tileToSerialLoops(func::FuncOp funcOp, bool onlyReduction) {
+LogicalResult tileReductionToSerialLoops(func::FuncOp funcOp,
+                                         bool fuseInputProducer) {
   {
     // Tile again at the workgroup level since redution dimension were
     // ignored. Dimensions already tiled will be ignore since we tile to the
     // same size.
     RewritePatternSet wgTilingPatterns(funcOp.getContext());
-    populateTilingPatterns(wgTilingPatterns, onlyReduction);
+    populateTilingPatterns(wgTilingPatterns, fuseInputProducer);
     if (failed(applyPatternsAndFoldGreedily(funcOp,
                                             std::move(wgTilingPatterns)))) {
       return failure();
@@ -233,9 +322,7 @@
 
     // Tile to serial loops to the wg tile size to handle reductions and other
     // dimension that have not been distributed.
-    if (failed(tileToSerialLoops(funcOp, /*onlyReduction=*/true))) {
-      return signalPassFailure();
-    }
+    if (failed(tileReductionToSerialLoops(funcOp))) return signalPassFailure();
 
     LLVM_DEBUG({
       llvm::dbgs() << "--- After tile reductions:";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
index 4cf4ebb..c8000b3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir
@@ -75,12 +75,12 @@
 
 // test corner case where the promoted value has multiple uses.
 //    CHECK-LABEL: func.func @matmul_multi_uses
-//         CHECK:    %[[C:.*]] = flow.dispatch.tensor.load
-//         CHECK:    %[[A:.*]] = flow.dispatch.tensor.load
-//         CHECK:    %[[B:.*]] = flow.dispatch.tensor.load
+//         CHECK:    %[[C:.*]] = flow.dispatch.tensor.load {{.+}} -> tensor<32x128xf32>
+//         CHECK:    %[[A:.*]] = flow.dispatch.tensor.load {{.+}} -> tensor<32x1024xf32>
+//         CHECK:    %[[B:.*]] = flow.dispatch.tensor.load {{.+}} -> tensor<1024x128xf32>
 //         CHECK:    %[[PA:.*]] = bufferization.alloc_tensor() copy(%[[A]]) {bufferization.escape = [false]} : tensor<32x1024xf32>
 //         CHECK:    %[[PB:.*]] = bufferization.alloc_tensor() copy(%[[B]]) {bufferization.escape = [false]} : tensor<1024x128xf32>
-//         CHECK:    %[[M:.*]] = linalg.matmul {{.*}} ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%{{.*}} : tensor<32x128xf32>) -> tensor<32x128xf32>
+//         CHECK:    %[[M:.*]] = linalg.matmul ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%{{.*}} : tensor<32x128xf32>) -> tensor<32x128xf32>
 //         CHECK:    "some_use"(%[[A]]) : (tensor<32x1024xf32>) -> ()
 
 // -----
@@ -129,3 +129,68 @@
 // CHECK-LABEL: func.func @matmul_33x33x903168_f32
 // CHECK-NOT: bufferization.alloc_tensor()
 
+// -----
+
+func.func @weight_dequant_matmul() {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x128x2048xi4>>
+  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x2048xf32>>
+  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x2048xi4>>
+  %3 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>>
+  %4 = hal.interface.binding.subspan set(0) binding(5) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %6 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+  %7 = flow.dispatch.tensor.load %4, offsets = [%5, %6], sizes = [32, 128], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>> -> tensor<32x128xf32>
+  %8 = flow.dispatch.tensor.load %3, offsets = [%5, 0, 0], sizes = [32, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>> -> tensor<32x86x128xf32>
+  %9 = flow.dispatch.tensor.load %0, offsets = [0, 0, %6], sizes = [86, 128, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<86x128x2048xi4>> -> tensor<86x128x128xi4>
+  %10 = flow.dispatch.tensor.load %1, offsets = [0, %6], sizes = [86, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<86x2048xf32>> -> tensor<86x128xf32>
+  %11 = flow.dispatch.tensor.load %2, offsets = [0, %6], sizes = [86, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<86x2048xi4>> -> tensor<86x128xi4>
+  %12 = tensor.empty() : tensor<86x128x128xf32>
+  %13 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%9, %10, %11 : tensor<86x128x128xi4>, tensor<86x128xf32>, tensor<86x128xi4>) outs(%12 : tensor<86x128x128xf32>) {
+  ^bb0(%in: i4, %in_0: f32, %in_1: i4, %out: f32):
+    %16 = arith.extsi %in : i4 to i32
+    %17 = arith.extsi %in_1 : i4 to i32
+    %18 = arith.subi %16, %17 : i32
+    %19 = arith.sitofp %18 : i32 to f32
+    %20 = arith.mulf %19, %in_0 : f32
+    linalg.yield %20 : f32
+  } -> tensor<86x128x128xf32>
+  %14 = linalg.fill ins(%cst : f32) outs(%7 : tensor<32x128xf32>) -> tensor<32x128xf32>
+  %15 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+  } ins(%8, %13 : tensor<32x86x128xf32>, tensor<86x128x128xf32>) outs(%14 : tensor<32x128xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 128, 1, 32]]>} {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %16 = arith.mulf %in, %in_0 : f32
+    %17 = arith.addf %out, %16 : f32
+    linalg.yield %17 : f32
+  } -> tensor<32x128xf32>
+  flow.dispatch.tensor.store %15, %4, offsets = [%5, %6], sizes = [32, 128], strides = [1, 1] : tensor<32x128xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+  return
+}
+
+// CHECK-LABEL: func.func @weight_dequant_matmul()
+//       CHECK:   %[[LHS_LD:.+]] = flow.dispatch.tensor.load {{.+}} : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>> -> tensor<32x86x128xf32>
+// Check that the linalg.fill producing the matmul init value is not fused into the serial loops.
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+// Check that two serial loops are materialized for reductions.
+//       CHECK:   scf.for %{{.+}} = %c0 to %c86 step %c1 iter_args(%[[ARG1:.+]] = %[[FILL]]) -> (tensor<32x128xf32>)
+//       CHECK:     scf.for %{{.+}} = %c0 to %c128 step %c32 iter_args(%[[ARG2:.+]] = %[[ARG1]]) -> (tensor<32x128xf32>)
+//       CHECK:       %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS_LD]]
+// Check that we have a bufferization.alloc_tensor() for in-place bufferization later.
+//       CHECK:       %[[RHS_ALLOC:.+]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<1x32x128xf32>
+// Check that the weight dequant linalg.generic is fused inside the serial loops.
+//       CHECK:       %[[RHS:.+]] = linalg.generic
+//  CHECK-SAME:        outs(%[[RHS_ALLOC]] : tensor<1x32x128xf32>)
+//       CHECK:       %[[LHS_ALLOC:.+]] = bufferization.alloc_tensor() copy(%[[LHS_SLICE]]) {bufferization.escape = [false]} : tensor<32x1x32xf32>
+//       CHECK:       linalg.generic
+//  CHECK-SAME:         ins(%[[LHS_ALLOC]], %[[RHS]] : tensor<32x1x32xf32>, tensor<1x32x128xf32>)
+//  CHECK-SAME:         outs(%[[ARG2]] : tensor<32x128xf32>)
+//       CHECK:       scf.yield
+//       CHECK:     scf.yield
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index ee00f42..6637fda 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -40,8 +41,7 @@
 /// level but we may have extra tiling for the reduction dimension. Therefore we
 /// tile again without distributing.
 static void populateTilingReductionPatterns(RewritePatternSet &patterns) {
-  auto tileSizesFn = [&](OpBuilder &builder,
-                         Operation *op) -> SmallVector<Value, 4> {
+  auto tileSizesFn = [](OpBuilder &builder, Operation *op) {
     auto interfaceOp = cast<PartitionableLoopsInterface>(*op);
     auto partitionedLoops =
         interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
@@ -65,8 +65,38 @@
           StringAttr::get(context, getWorkgroupMemoryMarker())},
       StringAttr::get(context, getWorkgroupKTiledMarker()));
   filter.setMatchByDefault();
-  TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp,
-                 linalg::GenericOp>::insert(patterns, tilingOptions, filter);
+  TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp, linalg::GenericOp,
+                 linalg::Conv2DNhwcHwcfOp,
+                 linalg::Conv2DNchwFchwOp>::insert(patterns, tilingOptions,
+                                                   filter);
+}
+
+static LogicalResult tileToSerialLoops(func::FuncOp funcOp) {
+  {
+    // Tile again at the workgroup level since reduction dimensions were
+    // ignored. Dimensions already tiled will be ignored since we tile to the
+    // same size.
+    RewritePatternSet wgTilingPatterns(funcOp.getContext());
+    populateTilingReductionPatterns(wgTilingPatterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp,
+                                            std::move(wgTilingPatterns)))) {
+      return failure();
+    }
+  }
+
+  {
+    RewritePatternSet wgTilingCanonicalizationPatterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(funcOp.getContext());
+    populateAffineMinSCFCanonicalizationPattern(
+        wgTilingCanonicalizationPatterns);
+    scf::populateSCFForLoopCanonicalizationPatterns(
+        wgTilingCanonicalizationPatterns);
+    if (failed(applyPatternsAndFoldGreedily(
+            funcOp, std::move(wgTilingCanonicalizationPatterns)))) {
+      return failure();
+    }
+  }
+  return success();
 }
 
 /// Return the tile size associated to one thread or warp based on the number of
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index ebe52b6..bba0dc4 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -41,6 +41,7 @@
             "emulate_i64.mlir",
             "erase_storage_buffer_static_shape.mlir",
             "illegal_configuration.mlir",
+            "lowering_matmul_fusion.mlir",
             "lowering_matmul_promotion.mlir",
             "lowering_reduction.mlir",
             "map_memref_storage_class.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 20d6e8f..026cd5d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -37,6 +37,7 @@
     "emulate_i64.mlir"
     "erase_storage_buffer_static_shape.mlir"
     "illegal_configuration.mlir"
+    "lowering_matmul_fusion.mlir"
     "lowering_matmul_promotion.mlir"
     "lowering_reduction.mlir"
     "map_memref_storage_class.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
new file mode 100644
index 0000000..09d8f0f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matmul_fusion.mlir
@@ -0,0 +1,118 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass)))' %s | FileCheck %s
+
+#compilation = #iree_codegen.compilation_info<
+    lowering_config  = <tile_sizes = [[32, 128, 1, 32]]>,
+    translation_info = <SPIRVMatmulPromoteVectorize pipeline_depth = 1>,
+    workgroup_size = [32, 8, 1]>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>,
+    #hal.descriptor_set.binding<4, storage_buffer>
+  ]>
+]>
+
+hal.executable @matmul_i4_quant_weight {
+  hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+    spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader], []>, AMD:DiscreteGPU, #spirv.resource_limits<
+      max_compute_shared_memory_size = 49152,
+      max_compute_workgroup_invocations = 1024,
+      max_compute_workgroup_size = [65535, 65535, 65535],
+      subgroup_size = 32>>}> {
+    hal.executable.export public @matmul_i4_quant_weight ordinal(0) layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @matmul_i4_quant_weight() {
+        %c32 = arith.constant 32 : index
+        %c128 = arith.constant 128 : index
+        %c0 = arith.constant 0 : index
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x128x2048xi4>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x2048xf32>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<86x2048xi4>>
+        %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>>
+        %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+        %6 = flow.dispatch.tensor.load %3, offsets = [%5, 0, 0], sizes = [%c32, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x128xf32>> -> tensor<?x86x128xf32>
+        %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0, %7], sizes = [86, 128, %c128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<86x128x2048xi4>> -> tensor<86x128x?xi4>
+        %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, %9], sizes = [86, %c128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<86x2048xf32>> -> tensor<86x?xf32>
+        %11 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+        %12 = flow.dispatch.tensor.load %2, offsets = [0, %11], sizes = [86, %c128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<86x2048xi4>> -> tensor<86x?xi4>
+        %13 = tensor.empty() : tensor<86x128x128xf32>
+        %cast = tensor.cast %8 : tensor<86x128x?xi4> to tensor<86x128x128xi4>
+        %cast_0 = tensor.cast %10 : tensor<86x?xf32> to tensor<86x128xf32>
+        %cast_1 = tensor.cast %12 : tensor<86x?xi4> to tensor<86x128xi4>
+        %14 = linalg.generic {
+          indexing_maps = [
+            affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+            affine_map<(d0, d1, d2) -> (d0, d2)>,
+            affine_map<(d0, d1, d2) -> (d0, d2)>,
+            affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+          iterator_types = ["parallel", "parallel", "parallel"]
+        } ins(%cast, %cast_0, %cast_1 : tensor<86x128x128xi4>, tensor<86x128xf32>, tensor<86x128xi4>) outs(%13 : tensor<86x128x128xf32>) {
+        ^bb0(%in: i4, %in_4: f32, %in_5: i4, %out: f32):
+          %20 = arith.extsi %in : i4 to i32
+          %21 = arith.extsi %in_5 : i4 to i32
+          %22 = arith.subi %20, %21 : i32
+          %23 = arith.sitofp %22 : i32 to f32
+          %24 = arith.mulf %23, %in_4 : f32
+          linalg.yield %24 : f32
+        } -> tensor<86x128x128xf32>
+        %15 = tensor.empty() : tensor<32x128xf32>
+        %16 = linalg.fill ins(%cst : f32) outs(%15 : tensor<32x128xf32>) -> tensor<32x128xf32>
+        %cast_2 = tensor.cast %6 : tensor<?x86x128xf32> to tensor<32x86x128xf32>
+        %17 = linalg.generic {
+          indexing_maps = [
+            affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+            affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>,
+            affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+        } ins(%cast_2, %14 : tensor<32x86x128xf32>, tensor<86x128x128xf32>) outs(%16 : tensor<32x128xf32>) attrs = {compilation_info = #compilation} {
+        ^bb0(%in: f32, %in_4: f32, %out: f32):
+          %20 = arith.mulf %in, %in_4 : f32
+          %21 = arith.addf %out, %20 : f32
+          linalg.yield %21 : f32
+        } -> tensor<32x128xf32>
+        %cast_3 = tensor.cast %17 : tensor<32x128xf32> to tensor<?x?xf32>
+        %18 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
+        %19 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+        flow.dispatch.tensor.store %cast_3, %4, offsets = [%18, %19], sizes = [%c32, %c128], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096x2048xf32>>
+        return
+      }
+    }
+  }
+}
+
+//     CHECK-LABEL: func.func @matmul_i4_quant_weight()
+//           CHECK:   %[[A_ALLOC:.+]] = memref.alloc() : memref<32x1x36xf32, #gpu.address_space<workgroup>>
+//           CHECK:   %[[B_ALLOC:.+]] = memref.alloc() : memref<1x32x132xf32, #gpu.address_space<workgroup>>
+//           CHECK:   %[[WEIGHT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+//           CHECK:   %[[SCALE_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+//           CHECK:   %[[ZP_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+//           CHECK:   scf.for %arg0 = %c0 to %c86 step %c1 iter_args({{.+}}) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
+//           CHECK:     %[[SCALE:.+]] = vector.transfer_read %[[SCALE_BINDING]]
+//           CHECK:     %[[ZP:.+]] = vector.transfer_read %[[ZP_BINDING]]
+//           CHECK:     %[[ZP_EXT:.+]] = arith.extsi %[[ZP]] : vector<4xi4> to vector<4xi32>
+//           CHECK:     scf.for %arg5 = %c0 to %c96 step %c32 iter_args({{.+}}) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
+//   CHECK-COUNT-4:       vector.transfer_read %[[WEIGHT_BINDING]]
+//   CHECK-COUNT-4:       arith.extsi %{{.+}} : vector<4xi4> to vector<4xi32>
+//   CHECK-COUNT-4:       arith.subi %{{.+}}, %[[ZP_EXT]] : vector<4xi32>
+//   CHECK-COUNT-4:       arith.sitofp %{{.+}} : vector<4xi32> to vector<4xf32>
+//   CHECK-COUNT-4:       arith.mulf %{{.+}}, %[[SCALE]] : vector<4xf32>
+//   CHECK-COUNT-4:       vector.transfer_write %{{.+}}, %[[B_ALLOC]]
+//           CHECK:       gpu.barrier
+//           CHECK:       vector.transfer_write %{{.+}}, %[[A_ALLOC]]
+//           CHECK:       gpu.barrier
+//  CHECK-COUNT-32:       vector.transfer_read %[[A_ALLOC]]
+//  CHECK-COUNT-32:       vector.transfer_read %[[B_ALLOC]]
+// CHECK-COUNT-128:       vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
+//   CHECK-COUNT-2:     scf.yield