[LLVMGPU][NFC] Create LLVMGPU pass for IGEMM (#18871)

This PR refactors the ConvolutionToIGEMM pass into a shared transform
function and adds a new LLVMGPU-specific pass built on top of it. This
keeps the lowering-config details for LLVMGPU separate from the common
pass, and removes the need to pass a control function or config function
through the pass constructor. It is also a precursor to more complex
control-function logic for LLVMGPU, which will be added in a later PR.
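
For illustration, here is a minimal sketch of how a backend can drive the
new shared entry point. The declarations (`IGEMMConfigFn`, `IGEMMControlFn`,
`convertToIGEMMAndSetConfig`) come from `Codegen/Common/Transforms.h` in this
diff; the `Example*`/`example*` names are hypothetical:

```cpp
#include "iree/compiler/Codegen/Common/Transforms.h"

namespace mlir::iree_compiler {

// Hypothetical config function: called on each contraction produced by the
// IGEMM rewrite, paired with its im2col producer.
static LogicalResult exampleConfigFn(linalg::GenericOp genericOp,
                                     IREE::LinalgExt::Im2colOp im2colOp) {
  // Backend-specific lowering configuration would be attached here.
  return success();
}

// Hypothetical control function: skip convolutions that already carry a
// preset lowering config.
static bool exampleControlFn(Operation *op) {
  return !getLoweringConfig(op);
}

// Inside a hypothetical backend pass:
void ExamplePass::runOnOperation() {
  if (failed(convertToIGEMMAndSetConfig(getOperation(), exampleConfigFn,
                                        exampleControlFn))) {
    return signalPassFailure();
  }
}

} // namespace mlir::iree_compiler
```

The new LLVMGPUConvolutionToIGEMMPass in this diff follows exactly this
shape, with its own config and control functions.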

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
index 58b678c..8998b11 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
@@ -12,6 +13,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -26,10 +28,14 @@
 
 using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;
 
+/// Pattern to set a lowering configuration on an IGEMM convolution. Searches
+/// for a contraction with an iree_linalg_ext.im2col producer, and calls the configFn
+/// to set the configuration.
+/// TODO(Max191): Use a funcOp walk instead of a pattern for this.
 struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  SetIGEMMConfiguration(MLIRContext *context, ConfigFn configFn)
+  SetIGEMMConfiguration(MLIRContext *context, IGEMMConfigFn configFn)
       : OpRewritePattern(context), configFn(configFn) {}
 
   LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
@@ -67,7 +73,7 @@
   }
 
 private:
-  ConfigFn configFn;
+  IGEMMConfigFn configFn;
 };
 
 class ConvolutionToIGEMMPass final
@@ -75,91 +81,87 @@
 public:
   using ConvolutionToIGEMMPassBase::ConvolutionToIGEMMPassBase;
 
-  explicit ConvolutionToIGEMMPass(ConfigFn configFn) : configFn(configFn) {}
+  ConvolutionToIGEMMPass(std::optional<IGEMMConfigFn> configFn,
+                         std::optional<IGEMMControlFn> controlFn)
+      : configFn(configFn), controlFn(controlFn) {}
 
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
-  }
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-
-    // Rewrite convolutions into a im2col and GEMM.
-    {
-      auto conv2dToIm2colControlFn = [](Operation *conv) {
-        // Don't transform convolutions that have a preset lowering config.
-        if (getLoweringConfig(conv)) {
-          return false;
-        }
-        return true;
-      };
-      MLIRContext *context = &getContext();
-      RewritePatternSet patterns(context);
-      iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(
-          patterns, conv2dToIm2colControlFn);
-      patterns.add<SetIGEMMConfiguration>(context, configFn);
-      if (failed(applyPatternsAndFoldGreedily(getOperation(),
-                                              std::move(patterns)))) {
-        return signalPassFailure();
-      }
-    }
-
-    // The im2col transformation collapses some of the dimensions of the
-    // convolution operands. Try to push the reshape ops towards the boundaries
-    // of the function and fold with interface tensor ops.
-    //
-    // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
-    //   generate a multi-M dim contraction instead of collapsing and
-    //   propagating reshapes. It should ultimately become a pass option to
-    //   decide whether to collapse the contraction dimensions into a single
-    //   M/N/K dimension.
-    {
-      RewritePatternSet bubbleCollapseShapePatterns(context);
-      linalg::ControlFusionFn bubbleUpExpansionControlFn =
-          [](OpOperand *fusedOperand) {
-            Operation *producer = fusedOperand->get().getDefiningOp();
-            Operation *consumer = fusedOperand->getOwner();
-
-            // Block only if one of the operations has a lowering configuration
-            // which means it likely expects tiling specific to its original
-            // shape.
-            if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
-              return false;
-            }
-            return true;
-          };
-      linalg::populateFoldReshapeOpsByCollapsingPatterns(
-          bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
-      // Add patterns to do some additional cleanup (on top of canonicalizations
-      // that can be done later) of reshape ops.
-      tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
-      linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
-                                                  context);
-      tensor::CollapseShapeOp::getCanonicalizationPatterns(
-          bubbleCollapseShapePatterns, context);
-      tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
-                                                   context);
-      tensor::ExpandShapeOp::getCanonicalizationPatterns(
-          bubbleCollapseShapePatterns, context);
-      populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
-      if (failed(applyPatternsAndFoldGreedily(
-              getOperation(), std::move(bubbleCollapseShapePatterns)))) {
-        return signalPassFailure();
-      }
-    }
-  }
+  void runOnOperation() override;
 
 private:
-  ConfigFn configFn = [](linalg::GenericOp genericOp,
-                         IREE::LinalgExt::Im2colOp im2colOp) {
-    return failure();
-  };
+  std::optional<IGEMMConfigFn> configFn;
+  std::optional<IGEMMControlFn> controlFn;
 };
 
 } // namespace
 
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvolutionToIGEMMPass(ConfigFn configFn) {
-  return std::make_unique<ConvolutionToIGEMMPass>(configFn);
+LogicalResult
+convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
+                           std::optional<IGEMMConfigFn> configFn,
+                           std::optional<IGEMMControlFn> controlFn) {
+  // Rewrite convolutions into an im2col op and a GEMM.
+  MLIRContext *context = funcOp->getContext();
+  {
+    RewritePatternSet patterns(context);
+    iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(patterns,
+                                                                     controlFn);
+    if (configFn.has_value()) {
+      patterns.add<SetIGEMMConfiguration>(context, configFn.value());
+    }
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return failure();
+    }
+  }
+
+  // The im2col transformation collapses some of the dimensions of the
+  // convolution operands. Try to push the reshape ops towards the boundaries
+  // of the function and fold with interface tensor ops.
+  //
+  // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
+  //   generate a multi-M dim contraction instead of collapsing and
+  //   propagating reshapes. It should ultimately become a pass option to
+  //   decide whether to collapse the contraction dimensions into a single
+  //   M/N/K dimension.
+  {
+    RewritePatternSet bubbleCollapseShapePatterns(context);
+    linalg::ControlFusionFn bubbleUpExpansionControlFn =
+        [](OpOperand *fusedOperand) {
+          Operation *producer = fusedOperand->get().getDefiningOp();
+          Operation *consumer = fusedOperand->getOwner();
+
+          // Block only if one of the operations has a lowering configuration,
+          // which means it likely expects tiling specific to its original
+          // shape.
+          if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
+            return false;
+          }
+          return true;
+        };
+    linalg::populateFoldReshapeOpsByCollapsingPatterns(
+        bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
+    // Add patterns to do some additional cleanup (on top of canonicalizations
+    // that can be done later) of reshape ops.
+    tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
+    linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
+                                                context);
+    tensor::CollapseShapeOp::getCanonicalizationPatterns(
+        bubbleCollapseShapePatterns, context);
+    tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
+                                                 context);
+    tensor::ExpandShapeOp::getCanonicalizationPatterns(
+        bubbleCollapseShapePatterns, context);
+    populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
+    if (failed(applyPatternsAndFoldGreedily(
+            funcOp, std::move(bubbleCollapseShapePatterns)))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+void ConvolutionToIGEMMPass::runOnOperation() {
+  if (failed(convertToIGEMMAndSetConfig(getOperation()))) {
+    return signalPassFailure();
+  }
 }
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index 94192d5..eac457d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -60,13 +60,6 @@
 createConvertToDestinationPassingStylePass(
     bool useWARForCooperativeMatrixCodegen);
 
-using ConfigFn =
-    std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
-/// Pass to convert Conv2D ops into IGEMM (Im2colOp + matmul). `configFn` is
-/// used to set lowering configurations on the resulting ops, if necessary.
-std::unique_ptr<InterfacePass<FunctionOpInterface>>
-createConvolutionToIGEMMPass(ConfigFn configFn);
-
 std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);
 
 /// Pass to perform linalg on tensor bufferization. The function passed into
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index ff281d6..6a5a9b5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -83,6 +83,10 @@
     InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> {
   let summary =
       "Transforms convolution operations into an implicit GEMM format.";
+  let dependentDialects = [
+    "tensor::TensorDialect",
+    "iree_compiler::IREE::LinalgExt::IREELinalgExtDialect"
+  ];
 }
 
 def DecomposeAffineOpsPass: Pass<"iree-codegen-decompose-affine-ops"> {
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index 13cdbf5..0a00034 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -18,6 +18,17 @@
 
 namespace mlir::iree_compiler {
 
+using IGEMMConfigFn =
+    std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
+using IGEMMControlFn = std::function<bool(Operation *)>;
+
+/// Converts conv_2d ops into iree_linalg_ext.im2col + matmul, and sets a lowering
+/// configuration on the matmul.
+LogicalResult convertToIGEMMAndSetConfig(
+    FunctionOpInterface funcOp,
+    std::optional<IGEMMConfigFn> configFn = std::nullopt,
+    std::optional<IGEMMControlFn> controlFn = std::nullopt);
+
 /// Eliminates tensor.empty ops to avoid buffer allocations.
 LogicalResult eliminateEmptyTensors(
     RewriterBase &rewriter, Operation *op,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
index 3d5494e..3373fda 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
@@ -69,25 +69,6 @@
 
 // -----
 
-#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
-#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
-func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
-  %cst = arith.constant 0.0 : f32
-  %empty = tensor.empty() : tensor<1x14x14x16xf32>
-  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
-  %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
-      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
-    outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
-  return %0 : tensor<1x14x14x16xf32>
-}
-// CHECK:      func.func public @conv_with_lowering_config
-// CHECK-NOT:    iree_linalg_ext.im2col
-// CHECK:        linalg.conv_2d_nhwc_hwcf
-// CHECK-SAME:     lowering_config
-
-// -----
-
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index b074612..3d8c7a2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -95,6 +95,7 @@
         "LLVMGPUCastTypeToFitMMA.cpp",
         "LLVMGPUConfigureTensorLayouts.cpp",
         "LLVMGPUConfigureVectorLayouts.cpp",
+        "LLVMGPUConvolutionToIGEMM.cpp",
         "LLVMGPULowerExecutableTarget.cpp",
         "LLVMGPUPackSharedMemoryAlloc.cpp",
         "LLVMGPUPrefetching.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 6a92f60..9016d63 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -80,6 +80,7 @@
     "LLVMGPUCastTypeToFitMMA.cpp"
     "LLVMGPUConfigureTensorLayouts.cpp"
     "LLVMGPUConfigureVectorLayouts.cpp"
+    "LLVMGPUConvolutionToIGEMM.cpp"
     "LLVMGPULowerExecutableTarget.cpp"
     "LLVMGPUPackSharedMemoryAlloc.cpp"
     "LLVMGPUPrefetching.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp
new file mode 100644
index 0000000..b88696a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp
@@ -0,0 +1,66 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+#define DEBUG_TYPE "iree-llvmgpu-convolution-to-igemm"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPUCONVOLUTIONTOIGEMMPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+namespace {
+
+/// Function for setting lowering configurations on contractions resulting from
+/// the IGEMM transformation. This currently uses the TileAndFuse pipeline, and
+/// tries to target MMA intrinsics.
+static LogicalResult llvmgpuConfigFn(linalg::GenericOp genericOp,
+                                     IREE::LinalgExt::Im2colOp im2colOp) {
+  auto funcOp = genericOp->getParentOfType<FunctionOpInterface>();
+  if (!funcOp) {
+    return genericOp.emitError("cannot find parent funcOp");
+  }
+  IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+  if (!target) {
+    return funcOp.emitError("missing GPU target in parent funcOp");
+  }
+  if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) {
+    return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp);
+  }
+  return success();
+}
+
+static bool llvmgpuControlFn(Operation *op) {
+  // Do not convert anything that already has a lowering configuration.
+  if (getLoweringConfig(op)) {
+    return false;
+  }
+  return true;
+}
+
+struct LLVMGPUConvolutionToIGEMMPass final
+    : impl::LLVMGPUConvolutionToIGEMMPassBase<LLVMGPUConvolutionToIGEMMPass> {
+  using impl::LLVMGPUConvolutionToIGEMMPassBase<
+      LLVMGPUConvolutionToIGEMMPass>::LLVMGPUConvolutionToIGEMMPassBase;
+
+  void runOnOperation() override;
+};
+
+void LLVMGPUConvolutionToIGEMMPass::runOnOperation() {
+  if (failed(convertToIGEMMAndSetConfig(getOperation(), llvmgpuConfigFn,
+                                        llvmgpuControlFn))) {
+    return signalPassFailure();
+  }
+}
+
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 51fcc6b..aab73c9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1170,29 +1170,12 @@
 // Common Pass Pipelines
 //===----------------------------------------------------------------------===//
 
-static LogicalResult igemmConfigFn(linalg::GenericOp genericOp,
-                                   IREE::LinalgExt::Im2colOp im2colOp) {
-  auto funcOp = genericOp->getParentOfType<FunctionOpInterface>();
-  if (!funcOp) {
-    return genericOp.emitError("cannot find parent funcOp");
-  }
-  IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
-  if (!target) {
-    return funcOp.emitError("missing GPU target in parent funcOp");
-  }
-  if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) {
-    return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp);
-  }
-  return success();
-}
-
 static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
     OpPassManager &modulePassManager) {
   {
     FunctionLikeNest funcPassManager(modulePassManager);
-    funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, []() {
-      return createConvolutionToIGEMMPass(igemmConfigFn);
-    });
+    funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm,
+                                      createLLVMGPUConvolutionToIGEMMPass);
     funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
     addCommonTargetExecutablePreprocessingPasses(funcPassManager);
     addEncodingToNopPasses(funcPassManager);
@@ -1242,9 +1225,8 @@
     OpPassManager &modulePassManager) {
   {
     FunctionLikeNest funcPassManager(modulePassManager);
-    funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, []() {
-      return createConvolutionToIGEMMPass(igemmConfigFn);
-    });
+    funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm,
+                                      createLLVMGPUConvolutionToIGEMMPass);
     funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
     addCommonTargetExecutablePreprocessingPasses(funcPassManager);
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index 815a82f..aa6b552 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -87,6 +87,17 @@
   let summary = "Pass to set layouts for vector distribution";
 }
 
+def LLVMGPUConvolutionToIGEMMPass :
+    InterfacePass<"iree-llvmgpu-convolution-to-igemm", "mlir::FunctionOpInterface"> {
+  let summary = "Pass to convert conv_2d ops to IGEMM and set a lowering configuration.";
+  let dependentDialects = [
+    "tensor::TensorDialect",
+    "iree_compiler::IREE::Codegen::IREECodegenDialect",
+    "iree_compiler::IREE::GPU::IREEGPUDialect",
+    "iree_compiler::IREE::LinalgExt::IREELinalgExtDialect"
+  ];
+}
+
 def LLVMGPULowerExecutableTargetPass :
     InterfacePass<"iree-llvmgpu-lower-executable-target", "mlir::FunctionOpInterface"> {
   let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 00bc6f9..4097320 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -49,6 +49,7 @@
             "legalize.mlir",
             "linalg_transform.mlir",
             "llvmgpu_bufferize.mlir",
+            "llvmgpu_convolution_to_igemm.mlir",
             "pack_pipeline_test.mlir",
             "pack_shared_memory_alloc.mlir",
             "prefetch_shared_memory.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 6be97c0..2a86fd3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -40,6 +40,7 @@
     "legalize.mlir"
     "linalg_transform.mlir"
     "llvmgpu_bufferize.mlir"
+    "llvmgpu_convolution_to_igemm.mlir"
     "nvvm_extract_address_computation.mlir"
     "nvvm_mma_sync_pipeline_test.mlir"
     "nvvm_pipeline_test.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
new file mode 100644
index 0000000..1fa2bae
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
@@ -0,0 +1,36 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --pass-pipeline="builtin.module(func.func(iree-llvmgpu-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s
+
+#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
+func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
+  %cst = arith.constant 0.0 : f32
+  %empty = tensor.empty() : tensor<1x14x14x16xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
+      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+    outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+// CHECK:      func.func public @conv_with_lowering_config
+// CHECK-NOT:    iree_linalg_ext.im2col
+// CHECK:        linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:     lowering_config
+
+// -----
+
+func.func public @set_lowering_config(%arg0: tensor<1x34x34x128xf32>, %arg1: tensor<3x3x128x128xf32>) -> tensor<1x32x32x128xf32> {
+  %cst = arith.constant 0.0 : f32
+  %empty = tensor.empty() : tensor<1x32x32x128xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x32x32x128xf32>) -> tensor<1x32x32x128xf32>
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%arg0, %arg1: tensor<1x34x34x128xf32>, tensor<3x3x128x128xf32>)
+    outs(%fill: tensor<1x32x32x128xf32>) -> tensor<1x32x32x128xf32>
+  return %0 : tensor<1x32x32x128xf32>
+}
+// CHECK:      func.func public @set_lowering_config
+// CHECK:        iree_linalg_ext.im2col
+// CHECK:        linalg.generic
+// CHECK-SAME:     lowering_config = #iree_gpu.lowering_config<
+// CHECK-SAME:         {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
+// CHECK-SAME:          promote_operands = [0, 1], reduction = [0, 0, 0, 0, 8],
+// CHECK-SAME:          subgroup = [0, 0, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
index 6e699fd..131ff3e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
@@ -38,7 +38,7 @@
 
 namespace {
 
-using ControlFnTy = std::optional<std::function<bool(Operation *)>>;
+using ControlFnTy = std::function<bool(Operation *)>;
 
 // Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
 // and linalg.matmul.
@@ -78,7 +78,8 @@
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  ConvertConv2DNhwcHwcf(MLIRContext *context, ControlFnTy controlFn)
+  ConvertConv2DNhwcHwcf(MLIRContext *context,
+                        std::optional<ControlFnTy> controlFn)
       : OpRewritePattern<linalg::Conv2DNhwcHwcfOp>(context),
         controlFn(controlFn) {}
 
@@ -192,7 +193,7 @@
   }
 
 private:
-  ControlFnTy controlFn;
+  std::optional<ControlFnTy> controlFn;
 };
 
 // For nchw, because the channels are to the left of the image shape dimensions,
@@ -204,7 +205,8 @@
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  ConvertConv2DNchwFchw(MLIRContext *context, ControlFnTy controlFn)
+  ConvertConv2DNchwFchw(MLIRContext *context,
+                        std::optional<ControlFnTy> controlFn)
       : OpRewritePattern<linalg::Conv2DNchwFchwOp>(context),
         controlFn(controlFn) {}
 
@@ -314,7 +316,7 @@
   }
 
 private:
-  ControlFnTy controlFn;
+  std::optional<ControlFnTy> controlFn;
 };
 
 struct ConvertConv2DToIm2ColOpPass final
@@ -335,7 +337,7 @@
 } // namespace
 
 void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
-                                      ControlFnTy controlFn) {
+                                      std::optional<ControlFnTy> controlFn) {
   patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(
       patterns.getContext(), std::move(controlFn));
 }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 1e858df..cc894b3 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -28,8 +28,10 @@
 splitReduction(RewriterBase &rewriter, LinalgExt::TopkOp topkOp,
                const TopkSplitReductionControlFn &splitReductionFn);
 
-// Patterns to convert linalg convolution ops into a gemm with an im2col
-// op and reshapes on the inputs.
+/// Patterns to convert linalg convolution ops into a GEMM with an im2col
+/// op and reshapes on the inputs.
+/// TODO(Max191): Maybe move to transforms and use a funcOp walk instead of a
+///               rewrite pattern for this.
 void populateConv2DToIm2colOpPatterns(
     RewritePatternSet &patterns,
     std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);