[Codegen][IGEMM] Add new pass for IGEMM transformation with reshape propagation (#18161)
This PR adds a new pass to perform the IGEMM transformation in Codegen.
The new pass uses the `Conv2DToIm2colOp` patterns plus some reshape
propagation and cleanup patterns. The PR also adds a control function on
the `Conv2DToIm2colOp` patterns, in order to avoid transforming
configured operations.
This separates the `Conv2DToIm2colOp` transformation from the
codegen-specific IGEMM pipeline, and addresses an issue with fusions
that requires reshape propagation. When there are consumers of the
convolution op, the consumer needs to also be collapsed in order to tile
and fuse it with the GEMM.
Adding reshape propagation is just one solution to the fusion issue. The
other potential solution is to allow the im2col op to have multiple M
dimensions in its result, and create a multi-M contraction instead of
the collapsed version. This second solution is ideal as long as backends
are able to handle the multi-M contraction, but it requires more work to
change the im2col op semantics. For now this PR fixes the issue, and the
alternative solution is left as a TODO.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 15cd82b..ec3eaec 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -95,6 +95,7 @@
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
+ "ConvolutionToIGEMM.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposeConvolutionToLowerDimOps.cpp",
"DecomposeLinalgGeneric.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 76ed4cd..44c25aa 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -86,6 +86,7 @@
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
+ "ConvolutionToIGEMM.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposeConvolutionToLowerDimOps.cpp"
"DecomposeLinalgGeneric.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
new file mode 100644
index 0000000..7eb6003
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
@@ -0,0 +1,104 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/PassDetail.h"
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;
+
+class ConvolutionToIGEMMPass
+ : public ConvolutionToIGEMMBase<ConvolutionToIGEMMPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
+ }
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+
+ // Rewrite convolutions into a im2col and GEMM.
+ {
+ auto conv2dToIm2colControlFn = [](Operation *conv) {
+ // Don't transform convolutions that have a preset lowering config.
+ if (getLoweringConfig(conv)) {
+ return false;
+ }
+ return true;
+ };
+ RewritePatternSet patterns(&getContext());
+ iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(
+ patterns, conv2dToIm2colControlFn);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
+ // The im2col transformation collapses some of the dimensions of the
+ // convolution operands. Try to push the reshape ops towards the boundaries
+ // of the function and fold with interface tensor ops.
+ //
+ // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
+ // generate a multi-M dim contraction instead of collapsing and
+ // propagating reshapes. It should ultimately become a pass option to
+ // decide whether to collapse the contraction dimensions into a single
+ // M/N/K dimension.
+ {
+ RewritePatternSet bubbleCollapseShapePatterns(context);
+ linalg::ControlFusionFn bubbleUpExpansionControlFn =
+ [](OpOperand *fusedOperand) {
+ Operation *producer = fusedOperand->get().getDefiningOp();
+ Operation *consumer = fusedOperand->getOwner();
+
+ // Block only if one of the operations has a lowering configuration
+ // which means it likely expects tiling specific to its original
+ // shape.
+ if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
+ return false;
+ }
+ return true;
+ };
+ linalg::populateFoldReshapeOpsByCollapsingPatterns(
+ bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
+ // Add patterns to do some additional cleanup (on top of canonicalizations
+ // that can be done later) of reshape ops.
+ tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
+ linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
+ context);
+ tensor::CollapseShapeOp::getCanonicalizationPatterns(
+ bubbleCollapseShapePatterns, context);
+ tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
+ context);
+ tensor::ExpandShapeOp::getCanonicalizationPatterns(
+ bubbleCollapseShapePatterns, context);
+ populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
+ if (failed(applyPatternsAndFoldGreedily(
+ getOperation(), std::move(bubbleCollapseShapePatterns)))) {
+ return signalPassFailure();
+ }
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createConvolutionToIGEMMPass() {
+ return std::make_unique<ConvolutionToIGEMMPass>();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index 2880477..ade1550 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -86,6 +86,10 @@
createConvertToDestinationPassingStylePass(
bool useWARForCooperativeMatrixCodegen = false);
+/// Converts convolution operations to a GEMM with an im2col op on the image.
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createConvolutionToIGEMMPass();
+
// Decompose affine.apply operations into sub affine.apply that can be
// hoisted in different loops.
std::unique_ptr<Pass> createDecomposeAffineOpsPass();
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index ed18294..5bda8d1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -70,6 +70,13 @@
];
}
+def ConvolutionToIGEMM :
+ InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> {
+ let summary =
+ "Transforms convolution operations into an implicit GEMM format.";
+ let constructor = "mlir::iree_compiler::createConvolutionToIGEMMPass()";
+}
+
def DecomposeAffineOps: Pass<"decompose-affine-ops"> {
let summary = "Decompose `affine.apply` operations into sub `affine.apply`";
let description = [{
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index f0e3a8e..9651d49 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -27,6 +27,7 @@
"convert_bf16_to_uint16_buffers.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
+ "convolution_to_igemm.mlir",
"convolutions.mlir",
"erase_dead_alloc_and_stores.mlir",
"decompose_affine_ops.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 6f1dd78..d2b97e2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -23,6 +23,7 @@
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
+ "convolution_to_igemm.mlir"
"convolutions.mlir"
"decompose_affine_ops.mlir"
"decompose_conv2d.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
new file mode 100644
index 0000000..46f30fe
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
@@ -0,0 +1,92 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
+func.func public @conv_with_consumer(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf16> {
+ %cst = arith.constant 0.0 : f32
+ %empty = tensor.empty() : tensor<1x14x14x16xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ %1 = tensor.empty() : tensor<1x14x14x16xf16>
+ %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x14x14x16xf32>) outs(%1 : tensor<1x14x14x16xf16>) {
+ ^bb0(%in: f32, %out: f16):
+ %3 = arith.truncf %in : f32 to f16
+ linalg.yield %3 : f16
+ } -> tensor<1x14x14x16xf16>
+ return %2 : tensor<1x14x14x16xf16>
+}
+// CHECK: func.func public @conv_with_consumer
+// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
+// CHECK-SAME: : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: -> tensor<1x196x16xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK: %[[TRUNCF:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[TRUNCF]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf16> into tensor<1x14x14x16xf16>
+// CHECK: return %[[EXPANDED]] : tensor<1x14x14x16xf16>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>
+ ]>
+]>
+#config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}>
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+ func.func @fold_with_interface_tensor() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<1x16x16x4xf32>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<3x3x4x16xf32>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 16, 16, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x16x16x4xf32>> -> tensor<1x16x16x4xf32>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 4, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x4x16xf32>> -> tensor<3x3x4x16xf32>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>> -> tensor<1x14x14x16xf32>
+ %cst = arith.constant 0.0 : f32
+ %fill = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ %6 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%3, %4: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : tensor<1x14x14x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>>
+ return
+ }
+}
+
+// CHECK: func.func @fold_with_interface_tensor
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x16x16x4xf32>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<36x16xf32>
+// CHECK-DAG: %[[RES:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x196x16xf32>
+// CHECK-DAG: %[[IM2COL:.+]] = iree_linalg_ext.im2col {{.*}} ins(%[[LHS]] : tensor<1x16x16x4xf32>){{.*}}-> tensor<1x196x36xf32>
+// CHECK-DAG: %[[FILL:.+]] = linalg.fill {{.*}}outs(%[[RES]] : tensor<1x196x16xf32>)
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[IM2COL]], %[[RHS]] : tensor<1x196x36xf32>, tensor<36x16xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x196x16xf32>) {
+// CHECK: flow.dispatch.tensor.store %[[MATMUL]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
+#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
+func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
+ %cst = arith.constant 0.0 : f32
+ %empty = tensor.empty() : tensor<1x14x14x16xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
+ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+// CHECK: func.func public @conv_with_lowering_config
+// CHECK-NOT: iree_linalg_ext.im2col
+// CHECK: linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: lowering_config
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 957813a..f116b97 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1054,8 +1054,8 @@
OpPassManager &modulePassManager) {
{
FunctionLikeNest funcPassManager(modulePassManager);
- funcPassManager.addPredicatedPass(
- clLLVMGPUUseIgemm, IREE::LinalgExt::createConvertConv2DToIm2ColOpPass);
+ funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm,
+ createConvolutionToIGEMMPass);
funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
addCommonTargetExecutablePreprocessingPasses(funcPassManager);
addEncodingToNopPasses(funcPassManager);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
index a5bb42a..0e7b3b7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
@@ -37,6 +37,8 @@
namespace {
+using ControlFnTy = std::optional<std::function<bool(Operation *)>>;
+
// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
// and linalg.matmul.
//
@@ -75,8 +77,16 @@
public:
using OpRewritePattern::OpRewritePattern;
+ ConvertConv2DNhwcHwcf(MLIRContext *context, ControlFnTy controlFn)
+ : OpRewritePattern<linalg::Conv2DNhwcHwcfOp>(context),
+ controlFn(controlFn) {}
+
LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
PatternRewriter &rewriter) const override {
+ if (controlFn.has_value() && !controlFn.value()(convOp)) {
+ return rewriter.notifyMatchFailure(convOp, "controlFn failed.");
+ }
+
auto inputType = llvm::cast<ShapedType>(convOp.getInputs()[0].getType());
auto filterType = llvm::cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = llvm::cast<ShapedType>(convOp.getOutputs()[0].getType());
@@ -181,6 +191,9 @@
return success();
}
+
+private:
+ ControlFnTy controlFn;
};
// For nchw, because the channels are to the left of the image shape dimensions,
@@ -192,8 +205,16 @@
public:
using OpRewritePattern::OpRewritePattern;
+ ConvertConv2DNchwFchw(MLIRContext *context, ControlFnTy controlFn)
+ : OpRewritePattern<linalg::Conv2DNchwFchwOp>(context),
+ controlFn(controlFn) {}
+
LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
PatternRewriter &rewriter) const override {
+ if (controlFn.has_value() && !controlFn.value()(convOp)) {
+ return rewriter.notifyMatchFailure(convOp, "controlFn failed.");
+ }
+
auto inputType = llvm::cast<ShapedType>(convOp.getInputs()[0].getType());
auto filterType = llvm::cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = llvm::cast<ShapedType>(convOp.getOutputs()[0].getType());
@@ -296,18 +317,19 @@
return success();
}
+
+private:
+ ControlFnTy controlFn;
};
struct ConvertConv2DToIm2ColOpPass
: ConvertConv2DToIm2ColOpBase<ConvertConv2DToIm2ColOpPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<tensor::TensorDialect, IREE::LinalgExt::IREELinalgExtDialect>();
+ registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
void runOnOperation() override {
- MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
- patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(context);
+ populateConv2DToIm2colOpPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
@@ -317,6 +339,12 @@
} // namespace
+void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
+ ControlFnTy controlFn) {
+ patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(
+ patterns.getContext(), std::move(controlFn));
+}
+
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToIm2ColOpPass() {
return std::make_unique<ConvertConv2DToIm2ColOpPass>();
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index c5dcdee..04c1e8f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -49,6 +49,12 @@
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToIm2ColOpPass();
+// Patterns to convert linalg convolution ops into a gemm with an im2col
+// op and reshapes on the inputs.
+void populateConv2DToIm2colOpPatterns(
+ RewritePatternSet &patterns,
+ std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);
+
// Creates a pass to convert linalg convolution ops into a sequence of
// linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd
// tranformation.