[Im2col] Add decomposition pass for iree_linalg_ext.im2col (#17728)

This PR adds a pass for decomposing the `iree_linalg_ext.im2col` op into
serial loops of `extract_slice->copy->insert_slice`. The pass keeps the
innermost (`K`) dimension of the iteration space un-tiled when the inner
slice is contiguous, and otherwise falls back to tiling it to 1, producing
serial loops of scalar slices. The outer dimensions (`B` and `M`) of the
iteration space are always tiled to 1.
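
For example, when the full `K` tile is a contiguous slice of the input (as in
the `im2col_untile_k` test added below), the decomposition keeps `K` un-tiled
and loops only over `B` and `M`. A sketch, with the index computations elided:

```mlir
scf.for %B = %c0 to %c2 step %c1
  scf.for %M = %c0 to %m_size step %c1
    %in_slice = tensor.extract_slice %in[%B, %h, %w, %k] [1, 1, 1, 4] [1, 1, 1, 1]
        : tensor<2x34x34x640xf32> to tensor<4xf32>
    %out_slice = tensor.extract_slice %loop_arg[%B, %M, 0] [1, 1, 4] [1, 1, 1]
        : tensor<2x?x4xf32> to tensor<4xf32>
    %copy = linalg.copy ins(%in_slice) outs(%out_slice)
    %insert = tensor.insert_slice %copy into %loop_arg[%B, %M, 0] [1, 1, 4] [1, 1, 1]
```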

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 7002342..fd907fe 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -774,6 +774,7 @@
 
 def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
     [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index b4370e7..1db43d8 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -8,6 +8,8 @@
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -132,6 +134,40 @@
   return genericOp.getResult(0);
 }
 
+// Helper method to check whether a slice will be contiguous, given the input
+// size, the slice tile size, and the offset into the input. This checks that
+// `inputSize` and `offset` are both evenly divisible by `tileSize`.
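+// For example, with inputSize = 640 and tileSize = 4, a constant offset of 8
+// is contiguous, and an offset computed by affine_map<(d0) -> (d0 * 4)> is a
+// provable multiple of 4, so it is also treated as contiguous.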
+static bool willBeContiguousSlice(OpFoldResult inputSize, OpFoldResult tileSize,
+                                  OpFoldResult offset) {
+  auto constInputSize = getConstantIntValue(inputSize);
+  auto constTileSize = getConstantIntValue(tileSize);
+  if (!constTileSize.has_value() || !constInputSize.has_value() ||
+      constInputSize.value() % constTileSize.value() != 0) {
+    return false;
+  }
+  auto constOffset = getConstantIntValue(offset);
+  if (constOffset.has_value() &&
+      constOffset.value() % constTileSize.value() == 0) {
+    return true;
+  }
+  auto affineOp = cast<Value>(offset).getDefiningOp<affine::AffineApplyOp>();
+  return affineOp &&
+         affineOp.getMap().getResult(0).isMultipleOf(constTileSize.value());
+}
+
+// Helper method to add two OpFoldResults with an affine.apply op.
+static OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+                            OpFoldResult b) {
+  AffineExpr d0, d1;
+  bindDims(builder.getContext(), d0, d1);
+  auto addMap = AffineMap::get(2, 0, {d0 + d1});
+  return affine::makeComposedFoldedAffineApply(builder, loc, addMap, {a, b});
+}
+
+//===----------------------------------------------------------------------===//
+// OnlineAttentionOp
+//===----------------------------------------------------------------------===//
+
 FailureOr<SmallVector<Value>>
 OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   Location loc = getLoc();
@@ -225,4 +261,257 @@
   return SmallVector<Value>{newAcc, newMax, newSum};
 }
 
+//===----------------------------------------------------------------------===//
+// Im2colOp
+//===----------------------------------------------------------------------===//
+
+/// Decomposition implementation for iree_linalg_ext.im2col op.
+/// The im2col op is decomposed into serial loops of
+/// `extract_slice->copy->insert_slice`.
+/// The `batch` and `M` dimensions of the operation iteration space are always
+/// tiled to 1, and the `K` dimension is left un-tiled if possible. When the
+/// full `K` dimension is a contiguous slice of the input tensor, the K dim
+/// can be left un-tiled so it can be vectorized. Otherwise, it will be tiled
+/// to 1 along with the `batch` and `M` dimensions.
+/// TODO(Max191): Fall back to larger tile sizes instead of immediately tiling
+///               the K dimension to 1 when non-contiguous.
+///
+/// The simple decomposition (with K tiled to 1) will look like:
+/// ```
+///   %im2col = iree_linalg_ext.im2col
+///       strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+///       m_offset = [%m_off] k_offset = [%k_off]
+///       batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+///       ins(%in : tensor<2x34x34x640xf32>)
+///       outs(%out : tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+/// ```
+/// Decomposes to:
+/// ```
+/// scf.for %B = %c0 to %c2 step %c1
+///   scf.for %M = %c0 to %c4 step %c1
+///     scf.for %K = %c0 to %c8 step %c1
+///       %slice = tensor.extract_slice %in[%B, %h, %w, %k] ... to tensor<1xf32>
+///       %out_slice = tensor.extract_slice %loop_arg[%B, %M, %K] ... to tensor<1xf32>
+///       %copy = linalg.copy ins(%slice) outs(%out_slice)
+///       %insert = tensor.insert_slice %copy into %loop_arg
+/// ```
+/// Where the offsets are computed as:
+///   `%h` = `(%m_off + %M) / 32 + ((%k_off + %K) / 640) / 3`
+///   `%w` = `(%m_off + %M) mod 32 + ((%k_off + %K) / 640) mod 3`
+///   `%k` = `(%k_off + %K) mod 640`
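+///
+/// When the K slice is contiguous (e.g., if %k_off above were a provable
+/// multiple of 8), the K loop is elided instead, and the extract_slice reads
+/// the full K tile as a tensor<8xf32> inside only the B and M loops.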
+///
+FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
+  Location loc = getLoc();
+  Value inputSlice = getInput();
+  SmallVector<OpFoldResult> kOffset = getMixedKOffset();
+  SmallVector<OpFoldResult> mOffset = getMixedMOffset();
+  // Only support single K and M output dimension for now.
+  if (kOffset.size() != 1 || mOffset.size() != 1) {
+    return failure();
+  }
+
+  // Step 1: Tile the im2col op to loops with contiguous slices in the
+  // innermost loop.
+  //
+  // If the `kOffset` will index to a full contiguous slice of the K dim of
+  // the input tensor, then don't tile the K loop of the im2col op and
+  // maintain a larger contiguous slice.
+  SmallVector<Range> iterationDomain(getIterationDomain(b));
+  OpFoldResult kTileSize = iterationDomain.back().size;
+  auto constKTileSize = getConstantIntValue(kTileSize);
+  if (constKTileSize) {
+    kTileSize = b.getIndexAttr(constKTileSize.value());
+  }
+  SmallVector<OpFoldResult> inputSizes =
+      tensor::getMixedSizes(b, loc, getInput());
+  // Find the innermost non-batch dimension of the input tensor. This is the
+  // fastest changing dimension indexed by the K dimension of the im2col
+  // iteration domain, which makes it the innermost dimension of the
+  // extract_slice on the input tensor, so the slice needs to be contiguous
+  // along this dimension to keep the K loop un-tiled.
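+  // (For an NHWC input such as tensor<2x34x34x640xf32> with batch_pos = [0],
+  // this is the channel dimension of size 640.)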
+  SetVector<int64_t> batchPosSet(getBatchPos().begin(), getBatchPos().end());
+  OpFoldResult innerSliceSize;
+  for (int idx = inputSizes.size() - 1; idx >= 0; --idx) {
+    if (!batchPosSet.contains(idx)) {
+      innerSliceSize = inputSizes[idx];
+      break;
+    }
+  }
+  bool tileK =
+      !willBeContiguousSlice(innerSliceSize, kTileSize, kOffset.front());
+  if (!tileK) {
+    iterationDomain.pop_back();
+  } else {
+    kTileSize = b.getIndexAttr(1);
+  }
+
+  // Build loop nest.
+  SmallVector<Value> lbs, ubs, steps;
+  for (auto range : iterationDomain) {
+    lbs.push_back(getValueOrCreateConstantIndexOp(b, loc, range.offset));
+    ubs.push_back(getValueOrCreateConstantIndexOp(b, loc, range.size));
+    steps.push_back(getValueOrCreateConstantIndexOp(b, loc, range.stride));
+  }
+  scf::LoopNest loopNest = scf::buildLoopNest(
+      b, loc, lbs, ubs, steps, getOutput(),
+      [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
+          ValueRange iterArgs) -> scf::ValueVector { return iterArgs; });
+  SmallVector<Value> ivs;
+  for (scf::ForOp loop : loopNest.loops) {
+    ivs.push_back(loop.getInductionVar());
+  }
+
+  // Step 2: Compute indices into the input tensor for extract_slice.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(loopNest.loops.front());
+  SetVector<int64_t> mPosSet(getMPos().begin(), getMPos().end());
+
+  // Compute the basis for the iteration space of the convolution window
+  // (i.e., the H and W dims of the convolution output).
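+  // For example, a 34x34 input with a 3x3 kernel, unit strides, and unit
+  // dilations gives a 32x32 window: (34 - 1 - (3 - 1) * 1) / 1 + 1 = 32.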
+  SmallVector<Value> mBasis;
+  ArrayRef<int64_t> strides = getStrides();
+  ArrayRef<int64_t> dilations = getDilations();
+  SmallVector<OpFoldResult> kernelSize = getMixedKernelSize();
+  for (auto [idx, pos] : llvm::enumerate(getMPos())) {
+    AffineExpr x, k;
+    bindDims(getContext(), x, k);
+    AffineExpr mapExpr =
+        (x - 1 - (k - 1) * dilations[idx]).floorDiv(strides[idx]) + 1;
+    OpFoldResult size = affine::makeComposedFoldedAffineApply(
+        b, loc, AffineMap::get(2, 0, {mapExpr}, getContext()),
+        {inputSizes[pos], kernelSize[idx]});
+    mBasis.push_back(getValueOrCreateConstantIndexOp(b, loc, size));
+  }
+
+  // Delinearize the linear K index (k_offset, plus the K loop IV when the K
+  // loop is tiled) into an offset into the convolution window and the
+  // reduced channels. For an NHWC conv2d, the basis for delinearization
+  // would be [P, Q, C] for a PxQ kernel with C channels.
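+  // For example, a 3x3 kernel with 640 channels gives kBasis = [3, 3, 640],
+  // so a linear k index of 1000 delinearizes to (0, 1, 360), since
+  // 1000 = 0 * (3 * 640) + 1 * 640 + 360.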
+  Location nestedLoc =
+      loopNest.loops.back().getBody()->getTerminator()->getLoc();
+  b.setInsertionPointToStart(loopNest.loops.back().getBody());
+
+  SmallVector<OpFoldResult> kBasis;
+  SmallVector<int64_t> mKernelIdx(getInputRank(), -1);
+  for (auto [idx, mPos] : llvm::enumerate(getMPos())) {
+    mKernelIdx[mPos] = idx;
+  }
+  for (auto [idx, size] : llvm::enumerate(inputSizes)) {
+    if (batchPosSet.contains(idx))
+      continue;
+    if (mPosSet.contains(idx)) {
+      kBasis.push_back(kernelSize[mKernelIdx[idx]]);
+      continue;
+    }
+    kBasis.push_back(size);
+  }
+  OpFoldResult kIndex = kOffset.front();
+  if (tileK) {
+    kIndex = addOfrs(b, nestedLoc, kOffset.front(), ivs.back());
+  }
+  FailureOr<SmallVector<Value>> maybeDelinKOffset = affine::delinearizeIndex(
+      b, nestedLoc, getValueOrCreateConstantIndexOp(b, loc, kIndex),
+      getValueOrCreateConstantIndexOp(b, loc, kBasis));
+  if (failed(maybeDelinKOffset)) {
+    return failure();
+  }
+  SmallVector<Value> delinKOffset = maybeDelinKOffset.value();
+  // Split the delinearized offsets into the window offsets (for M offsets)
+  // and the K offsets for the input tensor.
+  SmallVector<Value> windowOffset, inputKOffset;
+  int delinKIdx = 0;
+  for (int i = 0; i < getInputRank(); ++i) {
+    if (batchPosSet.contains(i))
+      continue;
+    if (mPosSet.contains(i)) {
+      windowOffset.push_back(delinKOffset[delinKIdx++]);
+      continue;
+    }
+    inputKOffset.push_back(delinKOffset[delinKIdx++]);
+  }
+
+  // Compute offsets for the extract. Start by delinearizing the combined
+  // offset of m_offset and the M loop IV, using the mBasis. This gives an
+  // index into the delinearized output space of the convolution.
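+  // For example, with mBasis = [32, 32], a linear m offset of 40
+  // delinearizes to (1, 8): output window row 1, column 8.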
+  Value mArg = tileK ? ivs[ivs.size() - 2] : ivs.back();
+  OpFoldResult linearMOffset = addOfrs(b, nestedLoc, mArg, mOffset.front());
+  FailureOr<SmallVector<Value>> maybeDelinMOffset = affine::delinearizeIndex(
+      b, nestedLoc,
+      getValueOrCreateConstantIndexOp(b, nestedLoc, linearMOffset), mBasis);
+  if (failed(maybeDelinMOffset)) {
+    return failure();
+  }
+  SmallVector<Value> delinMOffset = maybeDelinMOffset.value();
+
+  // Compute the final offsets into the input tensor.
+  OpFoldResult zero = b.getIndexAttr(0);
+  OpFoldResult one = b.getIndexAttr(1);
+  SmallVector<OpFoldResult> sliceOffsets(getInputRank(), zero);
+  SmallVector<OpFoldResult> sliceStrides(getInputRank(), one);
+  SmallVector<OpFoldResult> sliceSizes(getInputRank(), one);
+  // Add the offset into the convolution window, and account for strides and
+  // dilations.
+  AffineExpr mOff, wOff;
+  bindDims(b.getContext(), mOff, wOff);
+  for (auto [idx, mPos] : llvm::enumerate(getMPos())) {
+    auto map =
+        AffineMap::get(2, 0, {mOff * strides[idx] + wOff * dilations[idx]});
+    OpFoldResult offset = affine::makeComposedFoldedAffineApply(
+        b, nestedLoc, map, {delinMOffset[idx], windowOffset[idx]});
+    sliceOffsets[mPos] = offset;
+    sliceSizes[mPos] = one;
+  }
+  // Set the K offset and size for the input tensor.
+  const int64_t kPos = getKPos().front();
+  sliceOffsets[kPos] = inputKOffset.front();
+  sliceSizes[kPos] = kTileSize;
+
+  // Set the batch offsets for the input tensor.
+  int ivIdx = 0;
+  for (auto bPos : getBatchPos()) {
+    sliceOffsets[bPos] = ivs[ivIdx++];
+  }
+
+  // Step 3. Decompose the im2col op into:
+  // ```
+  // %extract = tensor.extract_slice %input
+  // %copy = linalg.copy ins(%extract) outs(%out_slice)
+  // %insert = tensor.insert_slice %copy into %loop_arg
+  // ```
+  //
+  // Extract a slice from the input tensor.
+  ShapedType outputType = getOutputType();
+  SmallVector<int64_t> kTileSizeStatic;
+  SmallVector<Value> kTileSizeDynamic;
+  dispatchIndexOpFoldResult(kTileSize, kTileSizeDynamic, kTileSizeStatic);
+  auto extractType = cast<RankedTensorType>(outputType.clone(kTileSizeStatic));
+  auto extract =
+      b.create<tensor::ExtractSliceOp>(nestedLoc, extractType, inputSlice,
+                                       sliceOffsets, sliceSizes, sliceStrides);
+
+  // Insert the slice into the destination tensor.
+  sliceOffsets = SmallVector<OpFoldResult>(getOutputRank(), zero);
+  sliceSizes = SmallVector<OpFoldResult>(getOutputRank(), one);
+  sliceStrides = SmallVector<OpFoldResult>(getOutputRank(), one);
+  sliceSizes.back() = kTileSize;
+  for (auto [idx, iv] : llvm::enumerate(ivs)) {
+    sliceOffsets[idx] = iv;
+  }
+  // Insert a `linalg.copy` so there is something to vectorize in the
+  // decomposition. Without this copy, the extract and insert slice ops
+  // do not get vectorized, and the sequence becomes a scalar memref.copy.
+  // This memref.copy could be vectorized after bufferization, but it is
+  // probably better to vectorize during generic vectorization.
+  Value copyDest = b.create<tensor::ExtractSliceOp>(
+      nestedLoc, extractType, loopNest.loops.back().getRegionIterArg(0),
+      sliceOffsets, sliceSizes, sliceStrides);
+  auto copiedSlice =
+      b.create<linalg::CopyOp>(nestedLoc, extract.getResult(), copyDest);
+  auto insert =
+      b.create<tensor::InsertSliceOp>(nestedLoc, copiedSlice.getResult(0),
+                                      loopNest.loops.back().getRegionIterArg(0),
+                                      sliceOffsets, sliceSizes, sliceStrides);
+  auto yieldOp =
+      cast<scf::YieldOp>(loopNest.loops.back().getBody()->getTerminator());
+  yieldOp->getOpOperands().front().assign(insert.getResult());
+  return SmallVector<Value>({loopNest.results[0]});
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
index 1a20a8f..eda9324 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
@@ -34,6 +34,7 @@
         "ConvertConv2DToWinograd.cpp",
         "ConvertToLoops.cpp",
         "DecomposeAttention.cpp",
+        "DecomposeIm2col.cpp",
         "DecomposeWinogradPass.cpp",
         "PadContractionToBlockSize.cpp",
         "PassDetail.h",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 19c7522..bab7997 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -30,6 +30,7 @@
     "ConvertConv2DToWinograd.cpp"
     "ConvertToLoops.cpp"
     "DecomposeAttention.cpp"
+    "DecomposeIm2col.cpp"
     "DecomposeWinogradPass.cpp"
     "PadContractionToBlockSize.cpp"
     "PassDetail.h"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
new file mode 100644
index 0000000..a8fc5f2
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
@@ -0,0 +1,67 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::LinalgExt {
+namespace {
+
+/// Pattern to decompose the tiled im2col op.
+struct DecomposeIm2col : public OpRewritePattern<Im2colOp> {
+  using OpRewritePattern<Im2colOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Im2colOp im2colOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<SmallVector<Value>> decomposedIm2col =
+        im2colOp.decomposeOperation(rewriter);
+    if (failed(decomposedIm2col)) {
+      return failure();
+    }
+    rewriter.replaceOp(im2colOp, decomposedIm2col.value().front());
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+struct DecomposeIm2colPass : public DecomposeIm2colBase<DecomposeIm2colPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<
+        affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
+        linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void DecomposeIm2colPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+  patterns.add<DecomposeIm2col>(context);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createDecomposeIm2colPass() {
+  return std::make_unique<DecomposeIm2colPass>();
+}
+
+} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 165430a..a549b86 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -36,6 +36,10 @@
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createTopkSplitReductionPass();
 
+/// Decompose im2col ops into serial loops of extract_slice, copy, and
+/// insert_slice ops.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createDecomposeIm2colPass();
+
 /// Decompose the winograd transform ops into a sequence of linalg ops.
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createDecomposeWinogradTransformPass();
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index ee801b6..0c2f0ca 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -53,6 +53,14 @@
   ];
 }
 
+def DecomposeIm2col :
+    InterfacePass<"iree-linalg-ext-decompose-im2col", "mlir::FunctionOpInterface"> {
+  let summary =
+      "Decomposes im2col ops into serial loops of extract and insert slice ops";
+  let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
+                    "createDecomposeIm2colPass()";
+}
+
 def DecomposeWinogradTransform :
     InterfacePass<"iree-linalg-ext-decompose-winograd", "mlir::FunctionOpInterface"> {
   let summary =
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index fe21d60..87ce568 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -19,6 +19,7 @@
             "conv2d_to_winograd.mlir",
             "convert_to_loops.mlir",
             "decompose_attention.mlir",
+            "decompose_im2col.mlir",
             "decompose_online_attention.mlir",
             "decompose_winograd.mlir",
             "distribution.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 7ef5fd7..87030f0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
     "conv2d_to_winograd.mlir"
     "convert_to_loops.mlir"
     "decompose_attention.mlir"
+    "decompose_im2col.mlir"
     "decompose_online_attention.mlir"
     "decompose_winograd.mlir"
     "distribution.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
new file mode 100644
index 0000000..ab627b6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
@@ -0,0 +1,73 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-im2col))" --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 4)>
+module {
+  func.func @im2col_untile_k(%arg0: tensor<2x34x34x640xf32>, %m_size: index, %m_off: index, %k: index) -> tensor<2x?x4xf32> {
+    %0 = tensor.empty(%m_size) : tensor<2x?x4xf32>
+    %k_off = affine.apply #map(%k)
+    %7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%arg0 : tensor<2x34x34x640xf32>) outs(%0 : tensor<2x?x4xf32>) -> tensor<2x?x4xf32>
+    return %7 : tensor<2x?x4xf32>
+  }
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
+//      CHECK: func.func @im2col_untile_k(%[[ARG0:.+]]: tensor<2x34x34x640xf32>
+// CHECK-SAME:   %[[mSIZE:.+]]: index, %[[mOFF:.+]]: index, %[[K:.+]]: index)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]]) : tensor<2x?x4xf32>
+//      CHECK:   %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x4xf32>) {
+//      CHECK:     %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x4xf32>) {
+//  CHECK-DAG:       %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
+//  CHECK-DAG:       %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[m]])[%[[mOFF]], %[[K]]]
+//  CHECK-DAG:       %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[m]])[%[[mOFF]], %[[K]]]
+//      CHECK:       %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[b]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<4xf32>
+//      CHECK:       %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x?x4xf32> to tensor<4xf32>
+//      CHECK:       %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<4xf32>) outs(%[[OUT_SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<2x?x4xf32>
+//      CHECK:       scf.yield %[[INSERT]] : tensor<2x?x4xf32>
+//      CHECK:     }
+//      CHECK:     scf.yield %[[mLOOP]] : tensor<2x?x4xf32>
+//      CHECK:   }
+//      CHECK:   return %[[bLOOP]] : tensor<2x?x4xf32>
+
+// -----
+
+module {
+  func.func @im2col_transposed_m_pos(%arg0: tensor<640x2x101x172xf32>, %m_size: index, %k_size: index, %m_off: index, %k_off: index) -> tensor<2x?x?xf32> {
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %0 = tensor.empty(%m_size, %k_size) : tensor<2x?x?xf32>
+    %8 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [1] m_pos = [3, 2] k_pos = [0] ins(%arg0 : tensor<640x2x101x172xf32>) outs(%0 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+    return %8 : tensor<2x?x?xf32>
+  }
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> ((d0 + s0) floordiv 10)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (((d0 + s0) floordiv 32) * 5 + (((d1 + s1) mod 10) floordiv 5) * 4)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 3 + d1 * 7 + s0 * 3 + s1 * 7 - ((d0 + s0) floordiv 32) * 96 - ((d1 + s1) floordiv 5) * 35)>
+//      CHECK: func.func @im2col_transposed_m_pos(%[[ARG0:.+]]: tensor<640x2x101x172xf32>
+// CHECK-SAME:   %[[mSIZE:.+]]: index, %[[kSIZE:.+]]: index, %[[mOFF:.+]]: index, %[[kOFF:.+]]: index)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]], %[[kSIZE]]) : tensor<2x?x?xf32>
+//      CHECK:   %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x?xf32>) {
+//      CHECK:     %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x?xf32>) {
+//      CHECK:       %[[kLOOP:.+]] = scf.for %[[k:.+]] = %[[C0]] to %[[kSIZE]] step %[[C1]] iter_args(%[[OUT2:.+]] = %[[OUT1]]) -> (tensor<2x?x?xf32>) {
+//  CHECK-DAG:         %[[kIDX:.+]] = affine.apply #[[MAP]](%[[k]])[%[[kOFF]]]
+//  CHECK-DAG:         %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
+//  CHECK-DAG:         %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
+//      CHECK:         %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[kIDX]], %[[b]], %[[wIDX]], %[[hIDX]]] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<640x2x101x172xf32> to tensor<1xf32>
+//      CHECK:         %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<2x?x?xf32> to tensor<1xf32>
+//      CHECK:         %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1xf32>) outs(%[[OUT_SLICE]] : tensor<1xf32>) -> tensor<1xf32>
+//      CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<1xf32> into tensor<2x?x?xf32>
+//      CHECK:         scf.yield %[[INSERT]] : tensor<2x?x?xf32>
+//      CHECK:       }
+//      CHECK:       scf.yield %[[kLOOP]] : tensor<2x?x?xf32>
+//      CHECK:     }
+//      CHECK:     scf.yield %[[mLOOP]] : tensor<2x?x?xf32>
+//      CHECK:   }
+//      CHECK:   return %[[bLOOP]] : tensor<2x?x?xf32>