[LLVMGPU][NFC] Create LLVMGPU pass for IGEMM (#18871)
This PR refactors the ConvolutionToIGEMM pass to a shared transform
function, and creates a new pass for LLVMGPU. This keeps the lowering
config details in LLVMGPU separate from the common pass, and removes the
need for passing a control function or config function in the pass
constructor. This is also a precursor to adding some more complex logic
in the control function for LLVMGPU, which will be added in a later PR.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
index 58b678c..8998b11 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
@@ -12,6 +13,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
@@ -26,10 +28,14 @@
using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;
+/// Pattern to set a lowering configuration on an IGEMM convolution. Searches
+/// for a contraction with a linalg_ext.im2col producer, and calls the configFn
+/// to set the configuration.
+/// TODO(Max191): Use a funcOp walk instead of a pattern for this.
struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern::OpRewritePattern;
- SetIGEMMConfiguration(MLIRContext *context, ConfigFn configFn)
+ SetIGEMMConfiguration(MLIRContext *context, IGEMMConfigFn configFn)
: OpRewritePattern(context), configFn(configFn) {}
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
@@ -67,7 +73,7 @@
}
private:
- ConfigFn configFn;
+ IGEMMConfigFn configFn;
};
class ConvolutionToIGEMMPass final
@@ -75,91 +81,87 @@
public:
using ConvolutionToIGEMMPassBase::ConvolutionToIGEMMPassBase;
- explicit ConvolutionToIGEMMPass(ConfigFn configFn) : configFn(configFn) {}
+ ConvolutionToIGEMMPass(std::optional<IGEMMConfigFn> configFn,
+ std::optional<IGEMMControlFn> controlFn)
+ : configFn(configFn), controlFn(controlFn) {}
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
- }
- void runOnOperation() override {
- MLIRContext *context = &getContext();
-
- // Rewrite convolutions into a im2col and GEMM.
- {
- auto conv2dToIm2colControlFn = [](Operation *conv) {
- // Don't transform convolutions that have a preset lowering config.
- if (getLoweringConfig(conv)) {
- return false;
- }
- return true;
- };
- MLIRContext *context = &getContext();
- RewritePatternSet patterns(context);
- iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(
- patterns, conv2dToIm2colControlFn);
- patterns.add<SetIGEMMConfiguration>(context, configFn);
- if (failed(applyPatternsAndFoldGreedily(getOperation(),
- std::move(patterns)))) {
- return signalPassFailure();
- }
- }
-
- // The im2col transformation collapses some of the dimensions of the
- // convolution operands. Try to push the reshape ops towards the boundaries
- // of the function and fold with interface tensor ops.
- //
- // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
- // generate a multi-M dim contraction instead of collapsing and
- // propagating reshapes. It should ultimately become a pass option to
- // decide whether to collapse the contraction dimensions into a single
- // M/N/K dimension.
- {
- RewritePatternSet bubbleCollapseShapePatterns(context);
- linalg::ControlFusionFn bubbleUpExpansionControlFn =
- [](OpOperand *fusedOperand) {
- Operation *producer = fusedOperand->get().getDefiningOp();
- Operation *consumer = fusedOperand->getOwner();
-
- // Block only if one of the operations has a lowering configuration
- // which means it likely expects tiling specific to its original
- // shape.
- if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
- return false;
- }
- return true;
- };
- linalg::populateFoldReshapeOpsByCollapsingPatterns(
- bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
- // Add patterns to do some additional cleanup (on top of canonicalizations
- // that can be done later) of reshape ops.
- tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
- linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
- context);
- tensor::CollapseShapeOp::getCanonicalizationPatterns(
- bubbleCollapseShapePatterns, context);
- tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
- context);
- tensor::ExpandShapeOp::getCanonicalizationPatterns(
- bubbleCollapseShapePatterns, context);
- populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
- if (failed(applyPatternsAndFoldGreedily(
- getOperation(), std::move(bubbleCollapseShapePatterns)))) {
- return signalPassFailure();
- }
- }
- }
+ void runOnOperation() override;
private:
- ConfigFn configFn = [](linalg::GenericOp genericOp,
- IREE::LinalgExt::Im2colOp im2colOp) {
- return failure();
- };
+ std::optional<IGEMMConfigFn> configFn;
+ std::optional<IGEMMControlFn> controlFn;
};
} // namespace
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvolutionToIGEMMPass(ConfigFn configFn) {
- return std::make_unique<ConvolutionToIGEMMPass>(configFn);
+LogicalResult
+convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
+ std::optional<IGEMMConfigFn> configFn,
+ std::optional<IGEMMControlFn> controlFn) {
+ // Rewrite convolutions into an im2col and GEMM.
+ MLIRContext *context = funcOp->getContext();
+ {
+ RewritePatternSet patterns(context);
+ iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(patterns,
+ controlFn);
+ if (configFn.has_value()) {
+ patterns.add<SetIGEMMConfiguration>(context, configFn.value());
+ }
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return failure();
+ }
+ }
+
+ // The im2col transformation collapses some of the dimensions of the
+ // convolution operands. Try to push the reshape ops towards the boundaries
+ // of the function and fold with interface tensor ops.
+ //
+ // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
+ // generate a multi-M dim contraction instead of collapsing and
+ // propagating reshapes. It should ultimately become a pass option to
+ // decide whether to collapse the contraction dimensions into a single
+ // M/N/K dimension.
+ {
+ RewritePatternSet bubbleCollapseShapePatterns(context);
+ linalg::ControlFusionFn bubbleUpExpansionControlFn =
+ [](OpOperand *fusedOperand) {
+ Operation *producer = fusedOperand->get().getDefiningOp();
+ Operation *consumer = fusedOperand->getOwner();
+
+ // Block only if one of the operations has a lowering configuration
+ // which means it likely expects tiling specific to its original
+ // shape.
+ if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
+ return false;
+ }
+ return true;
+ };
+ linalg::populateFoldReshapeOpsByCollapsingPatterns(
+ bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
+ // Add patterns to do some additional cleanup (on top of canonicalizations
+ // that can be done later) of reshape ops.
+ tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
+ linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
+ context);
+ tensor::CollapseShapeOp::getCanonicalizationPatterns(
+ bubbleCollapseShapePatterns, context);
+ tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
+ context);
+ tensor::ExpandShapeOp::getCanonicalizationPatterns(
+ bubbleCollapseShapePatterns, context);
+ populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(bubbleCollapseShapePatterns)))) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+void ConvolutionToIGEMMPass::runOnOperation() {
+ if (failed(convertToIGEMMAndSetConfig(getOperation()))) {
+ return signalPassFailure();
+ }
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index 94192d5..eac457d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -60,13 +60,6 @@
createConvertToDestinationPassingStylePass(
bool useWARForCooperativeMatrixCodegen);
-using ConfigFn =
- std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
-/// Pass to convert Conv2D ops into IGEMM (Im2colOp + matmul). `configFn` is
-/// used to set lowering configurations on the resulting ops, if necessary.
-std::unique_ptr<InterfacePass<FunctionOpInterface>>
-createConvolutionToIGEMMPass(ConfigFn configFn);
-
std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);
/// Pass to perform linalg on tensor bufferization. The function passed into
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index ff281d6..6a5a9b5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -83,6 +83,10 @@
InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> {
let summary =
"Transforms convolution operations into an implicit GEMM format.";
+ let dependentDialects = [
+ "tensor::TensorDialect",
+ "iree_compiler::IREE::LinalgExt::IREELinalgExtDialect"
+ ];
}
def DecomposeAffineOpsPass: Pass<"iree-codegen-decompose-affine-ops"> {
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index 13cdbf5..0a00034 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -18,6 +18,17 @@
namespace mlir::iree_compiler {
+using IGEMMConfigFn =
+ std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
+using IGEMMControlFn = std::function<bool(Operation *)>;
+
+/// Converts conv_2d ops into linalg_ext.im2col + matmul, and sets a lowering
+/// configuration on the matmul.
+LogicalResult convertToIGEMMAndSetConfig(
+ FunctionOpInterface funcOp,
+ std::optional<IGEMMConfigFn> configFn = std::nullopt,
+ std::optional<IGEMMControlFn> controlFn = std::nullopt);
+
/// Eliminates tensor.empty ops to avoid buffer allocations.
LogicalResult eliminateEmptyTensors(
RewriterBase &rewriter, Operation *op,
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
index 3d5494e..3373fda 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
@@ -69,25 +69,6 @@
// -----
-#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
-#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
-func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
- %cst = arith.constant 0.0 : f32
- %empty = tensor.empty() : tensor<1x14x14x16xf32>
- %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
- %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
- dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
- ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
- outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
- return %0 : tensor<1x14x14x16xf32>
-}
-// CHECK: func.func public @conv_with_lowering_config
-// CHECK-NOT: iree_linalg_ext.im2col
-// CHECK: linalg.conv_2d_nhwc_hwcf
-// CHECK-SAME: lowering_config
-
-// -----
-
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index b074612..3d8c7a2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -95,6 +95,7 @@
"LLVMGPUCastTypeToFitMMA.cpp",
"LLVMGPUConfigureTensorLayouts.cpp",
"LLVMGPUConfigureVectorLayouts.cpp",
+ "LLVMGPUConvolutionToIGEMM.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
"LLVMGPUPrefetching.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 6a92f60..9016d63 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -80,6 +80,7 @@
"LLVMGPUCastTypeToFitMMA.cpp"
"LLVMGPUConfigureTensorLayouts.cpp"
"LLVMGPUConfigureVectorLayouts.cpp"
+ "LLVMGPUConvolutionToIGEMM.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
"LLVMGPUPrefetching.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp
new file mode 100644
index 0000000..b88696a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp
@@ -0,0 +1,66 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+#define DEBUG_TYPE "iree-llvmgpu-convolution-to-igemm"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPUCONVOLUTIONTOIGEMMPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+namespace {
+
+/// Function for setting lowering configurations on contractions resulting from
+/// the IGEMM transformation. This currently uses the TileAndFuse pipeline, and
+/// tries to target MMA intrinsics.
+static LogicalResult llvmgpuConfigFn(linalg::GenericOp genericOp,
+ IREE::LinalgExt::Im2colOp im2colOp) {
+ auto funcOp = genericOp->getParentOfType<FunctionOpInterface>();
+ if (!funcOp) {
+ return genericOp.emitError("cannot find parent funcOp");
+ }
+ IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+ if (!target) {
+ return funcOp.emitError("missing GPU target in parent funcOp");
+ }
+ if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) {
+ return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp);
+ }
+ return success();
+}
+
+static bool llvmgpuControlFn(Operation *op) {
+ // Do not convert anything that already has a lowering configuration.
+ if (getLoweringConfig(op)) {
+ return false;
+ }
+ return true;
+}
+
+struct LLVMGPUConvolutionToIGEMMPass final
+ : impl::LLVMGPUConvolutionToIGEMMPassBase<LLVMGPUConvolutionToIGEMMPass> {
+ using impl::LLVMGPUConvolutionToIGEMMPassBase<
+ LLVMGPUConvolutionToIGEMMPass>::LLVMGPUConvolutionToIGEMMPassBase;
+
+ void runOnOperation() override;
+};
+
+void LLVMGPUConvolutionToIGEMMPass::runOnOperation() {
+ if (failed(convertToIGEMMAndSetConfig(getOperation(), llvmgpuConfigFn,
+ llvmgpuControlFn))) {
+ return signalPassFailure();
+ }
+}
+
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 51fcc6b..aab73c9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1170,29 +1170,12 @@
// Common Pass Pipelines
//===----------------------------------------------------------------------===//
-static LogicalResult igemmConfigFn(linalg::GenericOp genericOp,
- IREE::LinalgExt::Im2colOp im2colOp) {
- auto funcOp = genericOp->getParentOfType<FunctionOpInterface>();
- if (!funcOp) {
- return genericOp.emitError("cannot find parent funcOp");
- }
- IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
- if (!target) {
- return funcOp.emitError("missing GPU target in parent funcOp");
- }
- if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) {
- return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp);
- }
- return success();
-}
-
static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
OpPassManager &modulePassManager) {
{
FunctionLikeNest funcPassManager(modulePassManager);
- funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, []() {
- return createConvolutionToIGEMMPass(igemmConfigFn);
- });
+ funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm,
+ createLLVMGPUConvolutionToIGEMMPass);
funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
addCommonTargetExecutablePreprocessingPasses(funcPassManager);
addEncodingToNopPasses(funcPassManager);
@@ -1242,9 +1225,8 @@
OpPassManager &modulePassManager) {
{
FunctionLikeNest funcPassManager(modulePassManager);
- funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, []() {
- return createConvolutionToIGEMMPass(igemmConfigFn);
- });
+ funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm,
+ createLLVMGPUConvolutionToIGEMMPass);
funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
addCommonTargetExecutablePreprocessingPasses(funcPassManager);
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index 815a82f..aa6b552 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -87,6 +87,17 @@
let summary = "Pass to set layouts for vector distribution";
}
+def LLVMGPUConvolutionToIGEMMPass :
+ InterfacePass<"iree-llvmgpu-convolution-to-igemm", "mlir::FunctionOpInterface"> {
+ let summary = "Pass to convert conv_2d ops to igemm and set a lowering configuration.";
+ let dependentDialects = [
+ "tensor::TensorDialect",
+ "iree_compiler::IREE::Codegen::IREECodegenDialect",
+ "iree_compiler::IREE::GPU::IREEGPUDialect",
+ "iree_compiler::IREE::LinalgExt::IREELinalgExtDialect"
+ ];
+}
+
def LLVMGPULowerExecutableTargetPass :
InterfacePass<"iree-llvmgpu-lower-executable-target", "mlir::FunctionOpInterface"> {
let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 00bc6f9..4097320 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -49,6 +49,7 @@
"legalize.mlir",
"linalg_transform.mlir",
"llvmgpu_bufferize.mlir",
+ "llvmgpu_convolution_to_igemm.mlir",
"pack_pipeline_test.mlir",
"pack_shared_memory_alloc.mlir",
"prefetch_shared_memory.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 6be97c0..2a86fd3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -40,6 +40,7 @@
"legalize.mlir"
"linalg_transform.mlir"
"llvmgpu_bufferize.mlir"
+ "llvmgpu_convolution_to_igemm.mlir"
"nvvm_extract_address_computation.mlir"
"nvvm_mma_sync_pipeline_test.mlir"
"nvvm_pipeline_test.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
new file mode 100644
index 0000000..1fa2bae
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir
@@ -0,0 +1,36 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --pass-pipeline="builtin.module(func.func(iree-llvmgpu-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s
+
+#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
+func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
+ %cst = arith.constant 0.0 : f32
+ %empty = tensor.empty() : tensor<1x14x14x16xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
+ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+// CHECK: func.func public @conv_with_lowering_config
+// CHECK-NOT: iree_linalg_ext.im2col
+// CHECK: linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: lowering_config
+
+// -----
+
+func.func public @set_lowering_config(%arg0: tensor<1x34x34x128xf32>, %arg1: tensor<3x3x128x128xf32>) -> tensor<1x32x32x128xf32> {
+ %cst = arith.constant 0.0 : f32
+ %empty = tensor.empty() : tensor<1x32x32x128xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x32x32x128xf32>) -> tensor<1x32x32x128xf32>
+ %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%arg0, %arg1: tensor<1x34x34x128xf32>, tensor<3x3x128x128xf32>)
+ outs(%fill: tensor<1x32x32x128xf32>) -> tensor<1x32x32x128xf32>
+ return %0 : tensor<1x32x32x128xf32>
+}
+// CHECK: func.func public @set_lowering_config
+// CHECK: iree_linalg_ext.im2col
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<
+// CHECK-SAME: {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
+// CHECK-SAME: promote_operands = [0, 1], reduction = [0, 0, 0, 0, 8],
+// CHECK-SAME: subgroup = [0, 0, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
index 6e699fd..131ff3e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
@@ -38,7 +38,7 @@
namespace {
-using ControlFnTy = std::optional<std::function<bool(Operation *)>>;
+using ControlFnTy = std::function<bool(Operation *)>;
// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
// and linalg.matmul.
@@ -78,7 +78,8 @@
public:
using OpRewritePattern::OpRewritePattern;
- ConvertConv2DNhwcHwcf(MLIRContext *context, ControlFnTy controlFn)
+ ConvertConv2DNhwcHwcf(MLIRContext *context,
+ std::optional<ControlFnTy> controlFn)
: OpRewritePattern<linalg::Conv2DNhwcHwcfOp>(context),
controlFn(controlFn) {}
@@ -192,7 +193,7 @@
}
private:
- ControlFnTy controlFn;
+ std::optional<ControlFnTy> controlFn;
};
// For nchw, because the channels are to the left of the image shape dimensions,
@@ -204,7 +205,8 @@
public:
using OpRewritePattern::OpRewritePattern;
- ConvertConv2DNchwFchw(MLIRContext *context, ControlFnTy controlFn)
+ ConvertConv2DNchwFchw(MLIRContext *context,
+ std::optional<ControlFnTy> controlFn)
: OpRewritePattern<linalg::Conv2DNchwFchwOp>(context),
controlFn(controlFn) {}
@@ -314,7 +316,7 @@
}
private:
- ControlFnTy controlFn;
+ std::optional<ControlFnTy> controlFn;
};
struct ConvertConv2DToIm2ColOpPass final
@@ -335,7 +337,7 @@
} // namespace
void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
- ControlFnTy controlFn) {
+ std::optional<ControlFnTy> controlFn) {
patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(
patterns.getContext(), std::move(controlFn));
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 1e858df..cc894b3 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -28,8 +28,10 @@
splitReduction(RewriterBase &rewriter, LinalgExt::TopkOp topkOp,
const TopkSplitReductionControlFn &splitReductionFn);
-// Patterns to convert linalg convolution ops into a gemm with an im2col
-// op and reshapes on the inputs.
+/// Patterns to convert linalg convolution ops into a gemm with an im2col
+/// op and reshapes on the inputs.
+/// TODO(Max191): Maybe move to transforms and use a funcOp walk instead of a
+/// rewrite pattern for this.
void populateConv2DToIm2colOpPatterns(
RewritePatternSet &patterns,
std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);