[Im2col] Add decomposition pass for iree_linalg_ext.im2col (#17728)
This PR adds a pass for decomposing the `iree_linalg_ext.im2col` op into
serial loops of `extract_slice->copy->insert_slice`. The pass tries to
keep the innermost dimension of the iteration space un-tiled when the
inner slice is contiguous, but otherwise falls back to serial loops of
scalar slices. The outer dimensions (`B` and `M`) of the iteration
space, however, are always tiled to 1.
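
For example, the following op (taken verbatim from the lit test added in this
PR) has a contiguous inner slice: the K tile of 4 evenly divides the channel
dim of 640, and `k_offset = 4 * %k` is statically a multiple of 4, so only the
`B` and `M` dimensions are tiled:

```mlir
%k_off = affine.apply affine_map<(d0) -> (d0 * 4)>(%k)
%7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1]
       kernel_size = [3, 3] m_offset = [%m_off] k_offset = [%k_off]
       batch_pos = [0] m_pos = [1, 2] k_pos = [3]
       ins(%arg0 : tensor<2x34x34x640xf32>)
       outs(%0 : tensor<2x?x4xf32>) -> tensor<2x?x4xf32>
```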
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 7002342..fd907fe 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -774,6 +774,7 @@
def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index b4370e7..1db43d8 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -8,6 +8,8 @@
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -132,6 +134,40 @@
return genericOp.getResult(0);
}
+// Helper method to check whether a slice will be contiguous, given its
+// offset and size. The slice is contiguous when `inputSize` is evenly
+// divisible by `tileSize` and `offset` is statically known to be a multiple
+// of `tileSize` (either a constant or an affine expression with that
+// property).
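+// For example (mirroring the `im2col_untile_k` lit test added in this PR):
+// with `inputSize` = 640, `tileSize` = 4, and `offset` produced by
+// `affine.apply affine_map<(d0) -> (d0 * 4)>(%k)`, the slice is contiguous,
+// since 640 % 4 == 0 and `d0 * 4` is statically a multiple of 4.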
+static bool willBeContiguousSlice(OpFoldResult inputSize, OpFoldResult tileSize,
+ OpFoldResult offset) {
+ auto constInputSize = getConstantIntValue(inputSize);
+ auto constTileSize = getConstantIntValue(tileSize);
+ if (!constTileSize.has_value() || !constInputSize.has_value() ||
+ constInputSize.value() % constTileSize.value() != 0) {
+ return false;
+ }
+  auto constOffset = getConstantIntValue(offset);
+  if (constOffset.has_value()) {
+    return constOffset.value() % constTileSize.value() == 0;
+  }
+  // Otherwise the offset is a Value; check whether it is an affine.apply
+  // whose result is statically known to be a multiple of the tile size.
+ auto affineOp = cast<Value>(offset).getDefiningOp<affine::AffineApplyOp>();
+ return affineOp &&
+ affineOp.getMap().getResult(0).isMultipleOf(constTileSize.value());
+}
+
+// Helper method to add two OpFoldResults with an affine.apply.
+static OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+ OpFoldResult b) {
+ AffineExpr d0, d1;
+ bindDims(builder.getContext(), d0, d1);
+ auto addMap = AffineMap::get(2, 0, {d0 + d1});
+ return affine::makeComposedFoldedAffineApply(builder, loc, addMap, {a, b});
+}
+
+//===----------------------------------------------------------------------===//
+// OnlineAttentionOp
+//===----------------------------------------------------------------------===//
+
FailureOr<SmallVector<Value>>
OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
Location loc = getLoc();
@@ -225,4 +261,257 @@
return SmallVector<Value>{newAcc, newMax, newSum};
}
+//===----------------------------------------------------------------------===//
+// Im2colOp
+//===----------------------------------------------------------------------===//
+
+/// Decomposition implementation for iree_linalg_ext.im2col op.
+/// The im2col op is decomposed into serial loops of
+/// `extract_slice->copy->insert_slice`.
+/// The `batch` and `M` dimensions of the operation iteration space are always
+/// tiled to 1, and the `K` dimension is left un-tiled if possible. When the
+/// full `K` dimension is a contiguous slice of the input tensor, the K dim
+/// can be left un-tiled so it can be vectorized. Otherwise, it will be tiled
+/// to 1 along with the `batch` and `M` dimensions.
+/// TODO(Max191): Fallback to larger tile sizes instead of immediately tiling K
+/// dimension to 1 when non-contiguous.
+///
+/// The simple decomposition (with K tiled to 1) will look like:
+/// ```
+/// %im2col = iree_linalg_ext.im2col
+/// strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+/// m_offset = [%m_off] k_offset = [%k_off]
+/// batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+/// ins(%in : tensor<2x34x34x640xf32>)
+/// outs(%out : tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+/// ```
+/// Decomposes to:
+/// ```
+/// scf.for %B = %c0 to %c2 step %c1
+/// scf.for %M = %c0 to %c4 step %c1
+/// scf.for %K = %c0 to %c8 step %c1
+/// %slice = tensor.extract_slice %in[%B, %h, %w, %k] ... to tensor<1xf32>
+/// %copy = linalg.copy ins(%slice) outs(%out)
+/// %insert = tensor.insert_slice %copy into %loop_arg
+/// ```
+/// Where the offsets are computed as:
+/// `%h` = `(%m_off + %M) / 32 + ((%k_off + %K) / 640) / 3`
+/// `%w` = `(%m_off + %M) mod 32 + ((%k_off + %K) / 640) mod 3`
+/// `%k` = `(%k_off + %K) mod 640`
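+///
+/// When the inner slice is contiguous and K is left un-tiled (a sketch based
+/// on the `im2col_untile_k` lit test added in this PR, with a K tile of 4 on
+/// a 640 channel input), only the `B` and `M` loops are created:
+/// ```
+///   scf.for %B = %c0 to %c2 step %c1
+///     scf.for %M = %c0 to %m_size step %c1
+///       %slice = tensor.extract_slice %in[%B, %h, %w, %k] [1, 1, 1, 4] ...
+///           to tensor<4xf32>
+///       %out_slice = tensor.extract_slice %loop_arg[%B, %M, 0] ... to tensor<4xf32>
+///       %copy = linalg.copy ins(%slice) outs(%out_slice)
+///       %insert = tensor.insert_slice %copy into %loop_arg
+/// ```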
+///
+FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
+ Location loc = getLoc();
+ Value inputSlice = getInput();
+  // Get the offsets into the M and K dims of the output space.
+ SmallVector<OpFoldResult> kOffset = getMixedKOffset();
+ SmallVector<OpFoldResult> mOffset = getMixedMOffset();
+ // Only support single K and M output dimension for now.
+ if (kOffset.size() != 1 || mOffset.size() != 1) {
+ return failure();
+ }
+
+ // Step 1: Tile the im2col op to loops with contiguous slices in the
+ // innermost loop.
+ //
+ // If the `kOffset` will index to a full contiguous slice of the K dim of
+ // the input tensor, then don't tile the K loop of the im2col op and
+ // maintain a larger contiguous slice.
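+  // For example, a K tile of 4 on a 640 wide channel dim is contiguous only
+  // when the `k_offset` is statically a multiple of 4; with an opaque
+  // `k_offset`, fall back to tiling the K loop to 1.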
+ SmallVector<Range> iterationDomain(getIterationDomain(b));
+ OpFoldResult kTileSize = iterationDomain.back().size;
+ auto constKTileSize = getConstantIntValue(kTileSize);
+ if (constKTileSize) {
+ kTileSize = b.getIndexAttr(constKTileSize.value());
+ }
+ SmallVector<OpFoldResult> inputSizes =
+ tensor::getMixedSizes(b, loc, getInput());
+  // Find the innermost non-batch dimension of the input. This is the
+  // fastest-varying dimension as the K dimension of the im2col iteration
+  // domain increments, so it is the innermost dimension of the extract_slice
+  // on the input tensor, and the slice should be contiguous along it.
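+  // For example, for an NHWC input with `batch_pos = [0]`, the innermost
+  // non-batch dimension is the channel dim C at index 3.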
+ SetVector<int64_t> batchPosSet(getBatchPos().begin(), getBatchPos().end());
+ OpFoldResult innerSliceSize;
+ for (int idx = inputSizes.size() - 1; idx >= 0; --idx) {
+ if (!batchPosSet.contains(idx)) {
+ innerSliceSize = inputSizes[idx];
+ break;
+ }
+ }
+ bool tileK =
+ !willBeContiguousSlice(innerSliceSize, kTileSize, kOffset.front());
+ if (!tileK) {
+ iterationDomain.pop_back();
+ } else {
+ kTileSize = b.getIndexAttr(1);
+ }
+
+ // Build loop nest.
+ SmallVector<Value> lbs, ubs, steps;
+ for (auto range : iterationDomain) {
+ lbs.push_back(getValueOrCreateConstantIndexOp(b, loc, range.offset));
+ ubs.push_back(getValueOrCreateConstantIndexOp(b, loc, range.size));
+ steps.push_back(getValueOrCreateConstantIndexOp(b, loc, range.stride));
+ }
+ scf::LoopNest loopNest = scf::buildLoopNest(
+ b, loc, lbs, ubs, steps, getOutput(),
+ [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
+ ValueRange iterArgs) -> scf::ValueVector { return iterArgs; });
+ SmallVector<Value> ivs;
+ for (scf::ForOp loop : loopNest.loops) {
+ ivs.push_back(loop.getInductionVar());
+ }
+
+ // Step 2: Compute indices into the input tensor for extract_slice.
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPoint(loopNest.loops.front());
+ SetVector<int64_t> mPosSet(getMPos().begin(), getMPos().end());
+
+ // Compute the basis for the iteration space of the convolution window
+ // (i.e., the H and W dims of the convolution output).
+ SmallVector<Value> mBasis;
+ ArrayRef<int64_t> strides = getStrides();
+ ArrayRef<int64_t> dilations = getDilations();
+ SmallVector<OpFoldResult> kernelSize = getMixedKernelSize();
+ for (auto [idx, pos] : llvm::enumerate(getMPos())) {
+ AffineExpr x, k;
+ bindDims(getContext(), x, k);
+ AffineExpr mapExpr =
+ (x - 1 - (k - 1) * dilations[idx]).floorDiv(strides[idx]) + 1;
+ OpFoldResult size = affine::makeComposedFoldedAffineApply(
+ b, loc, AffineMap::get(2, 0, {mapExpr}, getContext()),
+ {inputSizes[pos], kernelSize[idx]});
+ mBasis.push_back(getValueOrCreateConstantIndexOp(b, loc, size));
+ }
+
+ // Delinearize the k_offset into an offset into the convolution window and
+ // any reduced channels. For an NHWC conv2d, the basis for delinearization
+ // would be [P, Q, C] for a PxQ kernel with C channels.
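+  // For the `tensor<2x34x34x640xf32>` example above with a 3x3 kernel, the
+  // delinearization basis would be [3, 3, 640].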
+ Location nestedLoc =
+ loopNest.loops.back().getBody()->getTerminator()->getLoc();
+ b.setInsertionPointToStart(loopNest.loops.back().getBody());
+
+ SmallVector<OpFoldResult> kBasis;
+ SmallVector<int64_t> mKernelIdx(getInputRank(), -1);
+ for (auto [idx, mPos] : enumerate(getMPos())) {
+ mKernelIdx[mPos] = idx;
+ }
+ for (auto [idx, size] : enumerate(inputSizes)) {
+ if (batchPosSet.contains(idx))
+ continue;
+ if (mPosSet.contains(idx)) {
+ kBasis.push_back(kernelSize[mKernelIdx[idx]]);
+ continue;
+ }
+ kBasis.push_back(size);
+ }
+ OpFoldResult kIndex = kOffset.front();
+ if (tileK) {
+ kIndex = addOfrs(b, nestedLoc, kOffset.front(), ivs.back());
+ }
+  FailureOr<SmallVector<Value>> maybeDelinKOffset = affine::delinearizeIndex(
+      b, nestedLoc, getValueOrCreateConstantIndexOp(b, nestedLoc, kIndex),
+      getValueOrCreateConstantIndexOp(b, nestedLoc, kBasis));
+ if (failed(maybeDelinKOffset)) {
+ return failure();
+ }
+ SmallVector<Value> delinKOffset = maybeDelinKOffset.value();
+  // Split the delinearized offsets into offsets into the convolution window
+  // (applied to the M dims of the input) and K offsets into the input tensor.
+ SmallVector<Value> windowOffset, inputKOffset;
+ int delinKIdx = 0;
+ for (int i = 0; i < getInputRank(); ++i) {
+ if (batchPosSet.contains(i))
+ continue;
+ if (mPosSet.contains(i)) {
+ windowOffset.push_back(delinKOffset[delinKIdx++]);
+ continue;
+ }
+ inputKOffset.push_back(delinKOffset[delinKIdx++]);
+ }
+
+ // Compute offsets for extract. Start by delinearizing the combined offset
+ // of m_offset and the offset from the tiled loop, using the mBasis. This
+ // will give an index into the delinearized output space of the convolution.
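+  // For the 34x34 spatial input above, `mBasis` is [32, 32], so the combined
+  // M offset delinearizes to (h, w) = (m floordiv 32, m mod 32) in the
+  // convolution output space.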
+ Value mArg = tileK ? ivs[ivs.size() - 2] : ivs.back();
+ OpFoldResult linearMOffset = addOfrs(b, nestedLoc, mArg, mOffset.front());
+ FailureOr<SmallVector<Value>> maybeDelinMOffset = affine::delinearizeIndex(
+ b, nestedLoc,
+ getValueOrCreateConstantIndexOp(b, nestedLoc, linearMOffset), mBasis);
+ if (failed(maybeDelinMOffset)) {
+ return failure();
+ }
+ SmallVector<Value> delinMOffset = maybeDelinMOffset.value();
+
+ // Compute the final offsets into the input tensor.
+ OpFoldResult zero = b.getIndexAttr(0);
+ OpFoldResult one = b.getIndexAttr(1);
+ SmallVector<OpFoldResult> sliceOffsets(getInputRank(), zero);
+ SmallVector<OpFoldResult> sliceStrides(getInputRank(), one);
+ SmallVector<OpFoldResult> sliceSizes(getInputRank(), one);
+ // Add the offset into the convolution window, and account for strides and
+ // dilations.
+ AffineExpr mOff, wOff;
+ bindDims(b.getContext(), mOff, wOff);
+ for (auto [idx, mPos] : llvm::enumerate(getMPos())) {
+ auto map =
+ AffineMap::get(2, 0, {mOff * strides[idx] + wOff * dilations[idx]});
+ OpFoldResult offset = affine::makeComposedFoldedAffineApply(
+ b, nestedLoc, map, {delinMOffset[idx], windowOffset[idx]});
+ sliceOffsets[mPos] = offset;
+ sliceSizes[mPos] = one;
+ }
+ // Set the K offset and size for the input tensor.
+ const int64_t kPos = getKPos().front();
+ sliceOffsets[kPos] = inputKOffset.front();
+ sliceSizes[kPos] = kTileSize;
+
+ // Set the batch offsets for the input tensor.
+ int ivIdx = 0;
+ for (auto bPos : getBatchPos()) {
+ sliceOffsets[bPos] = ivs[ivIdx++];
+ }
+
+ // Step 3. Decompose the im2col op into:
+ // ```
+ // %extract = tensor.extract_slice %input
+ // %copy = linalg.copy ins(%extract) outs(%out_slice)
+ // %insert = tensor.insert_slice %copy into %loop_arg
+ // ```
+ //
+ // Extract a slice from the input tensor.
+ ShapedType outputType = getOutputType();
+ SmallVector<int64_t> kTileSizeStatic;
+ SmallVector<Value> kTileSizeDynamic;
+ dispatchIndexOpFoldResult(kTileSize, kTileSizeDynamic, kTileSizeStatic);
+ auto extractType = cast<RankedTensorType>(outputType.clone(kTileSizeStatic));
+ auto extract =
+ b.create<tensor::ExtractSliceOp>(nestedLoc, extractType, inputSlice,
+ sliceOffsets, sliceSizes, sliceStrides);
+
+ // Insert the slice into the destination tensor.
+ sliceOffsets = SmallVector<OpFoldResult>(getOutputRank(), zero);
+ sliceSizes = SmallVector<OpFoldResult>(getOutputRank(), one);
+ sliceStrides = SmallVector<OpFoldResult>(getOutputRank(), one);
+ sliceSizes.back() = kTileSize;
+ for (auto [idx, iv] : llvm::enumerate(ivs)) {
+ sliceOffsets[idx] = iv;
+ }
+ // Insert a `linalg.copy` so there is something to vectorize in the
+ // decomposition. Without this copy, the extract and insert slice ops
+ // do not get vectorized, and the sequence becomes a scalar memref.copy.
+ // This memref.copy could be vectorized after bufferization, but it is
+ // probably better to vectorize during generic vectorization.
+ Value copyDest = b.create<tensor::ExtractSliceOp>(
+ nestedLoc, extractType, loopNest.loops.back().getRegionIterArg(0),
+ sliceOffsets, sliceSizes, sliceStrides);
+ auto copiedSlice =
+ b.create<linalg::CopyOp>(nestedLoc, extract.getResult(), copyDest);
+ auto insert =
+ b.create<tensor::InsertSliceOp>(nestedLoc, copiedSlice.getResult(0),
+ loopNest.loops.back().getRegionIterArg(0),
+ sliceOffsets, sliceSizes, sliceStrides);
+ auto yieldOp =
+ cast<scf::YieldOp>(loopNest.loops.back().getBody()->getTerminator());
+ yieldOp->getOpOperands().front().assign(insert.getResult());
+ return SmallVector<Value>({loopNest.results[0]});
+}
+
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
index 1a20a8f..eda9324 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
@@ -34,6 +34,7 @@
"ConvertConv2DToWinograd.cpp",
"ConvertToLoops.cpp",
"DecomposeAttention.cpp",
+ "DecomposeIm2col.cpp",
"DecomposeWinogradPass.cpp",
"PadContractionToBlockSize.cpp",
"PassDetail.h",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 19c7522..bab7997 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -30,6 +30,7 @@
"ConvertConv2DToWinograd.cpp"
"ConvertToLoops.cpp"
"DecomposeAttention.cpp"
+ "DecomposeIm2col.cpp"
"DecomposeWinogradPass.cpp"
"PadContractionToBlockSize.cpp"
"PassDetail.h"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
new file mode 100644
index 0000000..a8fc5f2
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
@@ -0,0 +1,67 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::LinalgExt {
+namespace {
+
+/// Pattern to decompose the tiled im2col op.
+struct DecomposeIm2col : public OpRewritePattern<Im2colOp> {
+ using OpRewritePattern<Im2colOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(Im2colOp im2colOp,
+ PatternRewriter &rewriter) const override {
+ FailureOr<SmallVector<Value>> decomposedIm2col =
+ im2colOp.decomposeOperation(rewriter);
+ if (failed(decomposedIm2col)) {
+ return failure();
+ }
+ rewriter.replaceOp(im2colOp, decomposedIm2col.value().front());
+ return success();
+ }
+};
+
+} // namespace
+
+namespace {
+struct DecomposeIm2colPass : public DecomposeIm2colBase<DecomposeIm2colPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<
+ affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
+ linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
+ }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void DecomposeIm2colPass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ patterns.add<DecomposeIm2col>(context);
+ memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createDecomposeIm2colPass() {
+ return std::make_unique<DecomposeIm2colPass>();
+}
+
+} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 165430a..a549b86 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -36,6 +36,10 @@
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createTopkSplitReductionPass();
+/// Decompose im2col ops into serial loops of extract_slice, copy, and
+/// insert_slice ops.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createDecomposeIm2colPass();
+
/// Decompose the winograd transform ops into a sequence of linalg ops.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDecomposeWinogradTransformPass();
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index ee801b6..0c2f0ca 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -53,6 +53,14 @@
];
}
+def DecomposeIm2col :
+ InterfacePass<"iree-linalg-ext-decompose-im2col", "mlir::FunctionOpInterface"> {
+  let summary =
+      "Decomposes im2col ops into extract_slice, copy, and insert_slice ops";
+ let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
+ "createDecomposeIm2colPass()";
+}
+
def DecomposeWinogradTransform :
InterfacePass<"iree-linalg-ext-decompose-winograd", "mlir::FunctionOpInterface"> {
let summary =
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index fe21d60..87ce568 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -19,6 +19,7 @@
"conv2d_to_winograd.mlir",
"convert_to_loops.mlir",
"decompose_attention.mlir",
+ "decompose_im2col.mlir",
"decompose_online_attention.mlir",
"decompose_winograd.mlir",
"distribution.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 7ef5fd7..87030f0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
"conv2d_to_winograd.mlir"
"convert_to_loops.mlir"
"decompose_attention.mlir"
+ "decompose_im2col.mlir"
"decompose_online_attention.mlir"
"decompose_winograd.mlir"
"distribution.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
new file mode 100644
index 0000000..ab627b6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
@@ -0,0 +1,73 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-im2col))" --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 4)>
+module {
+ func.func @im2col_untile_k(%arg0: tensor<2x34x34x640xf32>, %m_size: index, %m_off: index, %k: index) -> tensor<2x?x4xf32> {
+ %0 = tensor.empty(%m_size) : tensor<2x?x4xf32>
+ %k_off = affine.apply #map(%k)
+ %7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%arg0 : tensor<2x34x34x640xf32>) outs(%0 : tensor<2x?x4xf32>) -> tensor<2x?x4xf32>
+ return %7 : tensor<2x?x4xf32>
+ }
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
+// CHECK: func.func @im2col_untile_k(%[[ARG0:.+]]: tensor<2x34x34x640xf32>
+// CHECK-SAME: %[[mSIZE:.+]]: index, %[[mOFF:.+]]: index, %[[K:.+]]: index)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]]) : tensor<2x?x4xf32>
+// CHECK: %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x4xf32>) {
+// CHECK: %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x4xf32>) {
+// CHECK-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
+// CHECK-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[m]])[%[[mOFF]], %[[K]]]
+// CHECK-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[m]])[%[[mOFF]], %[[K]]]
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[b]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<4xf32>
+// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x?x4xf32> to tensor<4xf32>
+// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<4xf32>) outs(%[[OUT_SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<2x?x4xf32>
+// CHECK: scf.yield %[[INSERT]] : tensor<2x?x4xf32>
+// CHECK: }
+// CHECK: scf.yield %[[mLOOP]] : tensor<2x?x4xf32>
+// CHECK: }
+// CHECK: return %[[bLOOP]] : tensor<2x?x4xf32>
+
+// -----
+
+module {
+ func.func @im2col_transposed_m_pos(%arg0: tensor<640x2x101x172xf32>, %m_size: index, %k_size: index, %m_off: index, %k_off: index) -> tensor<2x?x?xf32> {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty(%m_size, %k_size) : tensor<2x?x?xf32>
+ %8 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [1] m_pos = [3, 2] k_pos = [0] ins(%arg0 : tensor<640x2x101x172xf32>) outs(%0 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+ return %8 : tensor<2x?x?xf32>
+ }
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> ((d0 + s0) floordiv 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (((d0 + s0) floordiv 32) * 5 + (((d1 + s1) mod 10) floordiv 5) * 4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 3 + d1 * 7 + s0 * 3 + s1 * 7 - ((d0 + s0) floordiv 32) * 96 - ((d1 + s1) floordiv 5) * 35)>
+// CHECK: func.func @im2col_transposed_m_pos(%[[ARG0:.+]]: tensor<640x2x101x172xf32>
+// CHECK-SAME: %[[mSIZE:.+]]: index, %[[kSIZE:.+]]: index, %[[mOFF:.+]]: index, %[[kOFF:.+]]: index)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]], %[[kSIZE]]) : tensor<2x?x?xf32>
+// CHECK: %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x?xf32>) {
+// CHECK: %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x?xf32>) {
+// CHECK: %[[kLOOP:.+]] = scf.for %[[k:.+]] = %[[C0]] to %[[kSIZE]] step %[[C1]] iter_args(%[[OUT2:.+]] = %[[OUT1]]) -> (tensor<2x?x?xf32>) {
+// CHECK-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]](%[[k]])[%[[kOFF]]]
+// CHECK-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
+// CHECK-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[kIDX]], %[[b]], %[[wIDX]], %[[hIDX]]] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<640x2x101x172xf32> to tensor<1xf32>
+// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<2x?x?xf32> to tensor<1xf32>
+// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1xf32>) outs(%[[OUT_SLICE]] : tensor<1xf32>) -> tensor<1xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<1xf32> into tensor<2x?x?xf32>
+// CHECK: scf.yield %[[INSERT]] : tensor<2x?x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[kLOOP]] : tensor<2x?x?xf32>
+// CHECK: }
+// CHECK: scf.yield %[[mLOOP]] : tensor<2x?x?xf32>
+// CHECK: }
+// CHECK: return %[[bLOOP]] : tensor<2x?x?xf32>