[Im2col] Allow multiple batch, M, and K dimensions on im2col result (#18593)

This PR adds support for multiple M and K dimensions in the result of
the im2col op. New metadata is added to correctly track the offsets
into the M and K dimensions when they are split across multiple output
dimensions. New `m_strides` and `k_strides` fields are added to the op,
representing a basis for linearizing the `m_offset` and `k_offset` fields.
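
For illustration, a minimal sketch of the new assembly format with multiple M
and K output dimensions, adapted from the `im2col_expanded` roundtrip test
added below (the shapes are those of that test case):

```mlir
// Two M output dims (128x8) and two K output dims (90x64), expanded from a
// flat 1024x5760 result. The strides give the linearization bases, so the
// flat offsets are the inner products:
//   M = m_offset[0] * 8  + m_offset[1] * 1
//   K = k_offset[0] * 64 + k_offset[1] * 1
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
         m_offset = [0, 0] * [8, 1] k_offset = [0, 0] * [64, 1]
         batch_pos = [0, 1] m_pos = [2, 3] k_pos = [4]
         ins(%arg0 : tensor<2x3x34x34x640xf32>)
         outs(%0 : tensor<2x3x128x8x90x64xf32>) -> tensor<2x3x128x8x90x64xf32>
```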

The motivation for doing this is that flattening the M dimension can
create an expand_shape op consumer of the resulting matmul. This can
cause issues with fusion and distribution, so it is useful to be able to
keep the multiple M dimensions intact. This PR does not change the
behavior of the Conv2DToIm2col pass; that will be done in a later PR.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 49d6892..3cec8a9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -389,7 +389,8 @@
       } {
     %4 = iree_linalg_ext.im2col {lowering_config = #config}
       strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-      m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [2, 3] k_pos = [1]
+      m_offset = [0] * [1] k_offset = [0] * [1]
+      batch_pos = [0] m_pos = [2, 3] k_pos = [1]
       ins(%2 : tensor<2x34x34x128xf16>)
       outs(%3 : tensor<2x128x8xf16>) -> tensor<2x128x8xf16>
     return %4 : tensor<2x128x8xf16>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 23bddd8..fe7d739 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -162,7 +162,7 @@
         %6 = tensor.empty() : tensor<2x256x11520xf16>
         %7 = iree_linalg_ext.im2col
             strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3]
-            m_offset = [0] k_offset = [0]
+            m_offset = [0] * [1] k_offset = [0] * [1]
             batch_pos = [0] m_pos = [1, 2] k_pos = [3]
           ins(%3 : tensor<2x34x34x1280xf16>)
           outs(%6 : tensor<2x256x11520xf16>) -> tensor<2x256x11520xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index ad818b2..dafb17f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1513,12 +1513,24 @@
                                    getKOffset());
 }
 
-/// Return all static and dynamic k_offset as OpFoldResults.
+/// Return all static and dynamic m_offset as OpFoldResults.
 SmallVector<OpFoldResult> Im2colOp::getMixedMOffset() {
   return LinalgExt::getMixedValues(getContext(), getStaticMOffset(),
                                    getMOffset());
 }
 
+/// Return all static and dynamic k_strides as OpFoldResults.
+SmallVector<OpFoldResult> Im2colOp::getMixedKStrides() {
+  return LinalgExt::getMixedValues(getContext(), getStaticKStrides(),
+                                   getKStrides());
+}
+
+/// Return all static and dynamic m_strides as OpFoldResults.
+SmallVector<OpFoldResult> Im2colOp::getMixedMStrides() {
+  return LinalgExt::getMixedValues(getContext(), getStaticMStrides(),
+                                   getMStrides());
+}
+
 void Im2colOp::setMixedKOffset(SmallVector<OpFoldResult> kOffset) {
   SmallVector<int64_t> staticKOffset;
   SmallVector<Value> dynamicKOffset;
@@ -1535,23 +1547,59 @@
   getMOffsetMutable().assign(dynamicMOffset);
 }
 
+void Im2colOp::setMixedKStrides(SmallVector<OpFoldResult> kStrides) {
+  SmallVector<int64_t> staticKStrides;
+  SmallVector<Value> dynamicKStrides;
+  dispatchIndexOpFoldResults(kStrides, dynamicKStrides, staticKStrides);
+  setStaticKStrides(staticKStrides);
+  getKStridesMutable().assign(dynamicKStrides);
+}
+
+void Im2colOp::setMixedMStrides(SmallVector<OpFoldResult> mStrides) {
+  SmallVector<int64_t> staticMStrides;
+  SmallVector<Value> dynamicMStrides;
+  dispatchIndexOpFoldResults(mStrides, dynamicMStrides, staticMStrides);
+  setStaticMStrides(staticMStrides);
+  getMStridesMutable().assign(dynamicMStrides);
+}
+
+SmallVector<int64_t> Im2colOp::getBatchOutputDims() {
+  return llvm::to_vector(llvm::seq<int64_t>(0, getBatchPos().size()));
+}
+
+SmallVector<int64_t> Im2colOp::getMOutputDims() {
+  int64_t begin = getBatchPos().size();
+  int64_t end = begin + getMixedMOffset().size();
+  return llvm::to_vector(llvm::seq<int64_t>(begin, end));
+}
+
+SmallVector<int64_t> Im2colOp::getKOutputDims() {
+  int64_t begin = getBatchPos().size() + getMixedMOffset().size();
+  int64_t end = begin + getMixedKOffset().size();
+  return llvm::to_vector(llvm::seq<int64_t>(begin, end));
+}
+
 /// Custom builder methods for im2col op.
-void Im2colOp::build(OpBuilder &builder, OperationState &state, Value input,
-                     Value output, ArrayRef<int64_t> strides,
-                     ArrayRef<int64_t> dilations,
-                     ArrayRef<OpFoldResult> kernelSize,
-                     ArrayRef<OpFoldResult> kOffset,
-                     ArrayRef<OpFoldResult> mOffset, ArrayRef<int64_t> batchPos,
-                     ArrayRef<int64_t> mPos, ArrayRef<int64_t> kPos) {
+void Im2colOp::build(
+    OpBuilder &builder, OperationState &state, Value input, Value output,
+    ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations,
+    ArrayRef<OpFoldResult> kernelSize, ArrayRef<OpFoldResult> mOffset,
+    ArrayRef<OpFoldResult> mStrides, ArrayRef<OpFoldResult> kOffset,
+    ArrayRef<OpFoldResult> kStrides, ArrayRef<int64_t> batchPos,
+    ArrayRef<int64_t> mPos, ArrayRef<int64_t> kPos) {
   assert(strides.size() == kernelSize.size() &&
          dilations.size() == kernelSize.size() &&
          mPos.size() == kernelSize.size() &&
          "strides, dilations, m_pos, and kernel expected to be the same rank");
-  SmallVector<int64_t> staticKernelSize, staticMOffset, staticKOffset;
-  SmallVector<Value> dynamicKernelSize, dynamicMOffset, dynamicKOffset;
+  SmallVector<int64_t> staticKernelSize, staticMOffset, staticKOffset,
+      staticMStrides, staticKStrides;
+  SmallVector<Value> dynamicKernelSize, dynamicMOffset, dynamicKOffset,
+      dynamicMStrides, dynamicKStrides;
   dispatchIndexOpFoldResults(kernelSize, dynamicKernelSize, staticKernelSize);
   dispatchIndexOpFoldResults(mOffset, dynamicMOffset, staticMOffset);
+  dispatchIndexOpFoldResults(mStrides, dynamicMStrides, staticMStrides);
   dispatchIndexOpFoldResults(kOffset, dynamicKOffset, staticKOffset);
+  dispatchIndexOpFoldResults(kStrides, dynamicKStrides, staticKStrides);
   SmallVector<Type> resultType;
   auto outputType = output.getType();
   if (isa<RankedTensorType>(outputType)) {
@@ -1560,9 +1608,11 @@
   build(builder, state, resultType, input, output,
         builder.getDenseI64ArrayAttr(strides),
         builder.getDenseI64ArrayAttr(dilations), dynamicKernelSize,
-        builder.getDenseI64ArrayAttr(staticKernelSize), dynamicKOffset,
-        builder.getDenseI64ArrayAttr(staticKOffset), dynamicMOffset,
-        builder.getDenseI64ArrayAttr(staticMOffset),
+        builder.getDenseI64ArrayAttr(staticKernelSize), dynamicMOffset,
+        builder.getDenseI64ArrayAttr(staticMOffset), dynamicMStrides,
+        builder.getDenseI64ArrayAttr(staticMStrides), dynamicKOffset,
+        builder.getDenseI64ArrayAttr(staticKOffset), dynamicKStrides,
+        builder.getDenseI64ArrayAttr(staticKStrides),
         builder.getDenseI64ArrayAttr(batchPos),
         builder.getDenseI64ArrayAttr(mPos), builder.getDenseI64ArrayAttr(kPos));
 }
@@ -1578,14 +1628,35 @@
     return op->emitOpError("expected one output operand");
   }
 
-  // TODO(Max191): Support cases with more than 1 m or k dimension, and remove
-  // the check for a single m_offset and k_offset.
-  if (getMixedMOffset().size() != 1) {
-    return op->emitOpError("expected one m_offset");
+  // Verify offsets and strides
+  SmallVector<OpFoldResult> kOffset = getMixedKOffset();
+  SmallVector<OpFoldResult> mOffset = getMixedMOffset();
+  SmallVector<OpFoldResult> kStrides = getMixedKStrides();
+  SmallVector<OpFoldResult> mStrides = getMixedMStrides();
+  if (kOffset.size() < 1) {
+    return op->emitOpError("expected at least one k_offset");
   }
-  if (getMixedKOffset().size() != 1) {
-    return op->emitOpError("expected one k_offset");
+  if (mOffset.size() < 1) {
+    return op->emitOpError("expected at least one m_offset");
   }
+  if (kOffset.size() != kStrides.size()) {
+    return op->emitOpError("expected the same size k_offset and k_strides");
+  }
+  if (mOffset.size() != mStrides.size()) {
+    return op->emitOpError("expected the same size m_offset and m_strides");
+  }
+  std::optional<int64_t> constInnerKStrides =
+      getConstantIntValue(kStrides.back());
+  if (!constInnerKStrides.has_value() || constInnerKStrides.value() != 1) {
+    return op->emitOpError("expected inner k_strides to be 1");
+  }
+  std::optional<int64_t> constInnerMStrides =
+      getConstantIntValue(mStrides.back());
+  if (!constInnerMStrides.has_value() || constInnerMStrides.value() != 1) {
+    return op->emitOpError("expected inner m_strides to be 1");
+  }
+
+  // Verify operand ranks and dim position sizes.
   auto inputType = getInputType();
   unsigned inputRank = inputType.getRank();
   ArrayRef<int64_t> batchPos = getBatchPos();
@@ -1595,6 +1666,14 @@
     return op->emitOpError(
         "expected input rank to be the sum of batch, m, and k ranks");
   }
+  auto outputType = getOutputType();
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != batchPos.size() + kOffset.size() + mOffset.size()) {
+    return op->emitOpError("expected output rank to be the sum of "
+                           "batch_pos, k_offset, and m_offset ranks");
+  }
+
+  // Verify convolution metadata.
   ArrayRef<int64_t> strides = getStrides();
   ArrayRef<int64_t> dilations = getDilations();
   SmallVector<OpFoldResult> kernelSize = getMixedKernelSize();
@@ -1611,17 +1690,16 @@
         "expected dilations rank to be equal to the kernel rank");
   }
 
+  // Verify input and output shapes.
   ArrayRef<int64_t> inputShape = inputType.getShape();
-  SmallVector<int64_t> expectedOutputShape;
-  for (auto pos : batchPos) {
-    expectedOutputShape.push_back(inputShape[pos]);
-  }
-  ArrayRef<int64_t> outputShape = getOutputType().getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
   // When the op is tiled, the m and k dimensions of the output are tiled, but
   // they are not tiled in the input, so we cannot verify the output size of
-  // these dimensions.
-  expectedOutputShape.push_back(outputShape[outputShape.size() - 2]);
-  expectedOutputShape.push_back(outputShape.back());
+  // these dimensions. Only verify the shape of the batch dimensions.
+  SmallVector<int64_t> expectedOutputShape(outputShape);
+  for (auto [idx, pos] : llvm::enumerate(batchPos)) {
+    expectedOutputShape[idx] = inputShape[pos];
+  }
   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
     return op->emitOpError("incompatible output shape");
   }
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 0d0c44b..e6aab96 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -827,7 +827,7 @@
     ```
       %im2col = iree_linalg_ext.im2col
           strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-          m_offset = [0] k_offset = [0]
+          m_offset = [0] * [1] k_offset = [0] * [1]
           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
           ins(%in : tensor<2x34x34x640xf32>)
           outs(%out : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
@@ -839,7 +839,7 @@
           scf.for %arg2 = %c0 to %c5760 step %c1
             %im2col = iree_linalg_ext.im2col
                 strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-                m_offset = [%arg1] k_offset = [%arg2]
+                m_offset = [%arg1] * [1] k_offset = [%arg2] * [1]
                 batch_pos = [0] m_pos = [1, 2] k_pos = [3]
                 ins(%in_tile : tensor<1x34x34x640xf32>)
                 outs(%out_tile : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
@@ -851,6 +851,21 @@
     (b, m, k) -> (b, M / 32 + K / (640*3), M % 32 + K % (640*3) / 640, K % 640)
     Where `(b, m, k)` are the indices of the tiled op's iteration space, and
     `M = m + m_offset` and `K = k + K_offset`.
+
+    The `m_strides` and `k_strides` fields are used as a basis for linearizing
+    the `m_offset` and `k_offset`. This is used when there are multiple M or K
+    output dimensions, and therefore multiple `m_offset` or `k_offset` values.
+    The strides fields are assembled in the IR as if they are multiplied as an
+    inner product with `m_offset` and `k_offset`, indicating that the total
+    linear offset along the dimension is equal to this inner product. These
+    strides fields also determine the strides of the output dimensions along
+    M and K. For example, an op with `m_strides = [32, 1]`, `k_strides = [4, 1]`,
+    and output type `tensor<BxM0xM1xK0xK1>` (expanded from `tensor<BxMxK>`),
+    would have strides along the M dim of 32 for `M0`, meaning as `M0` increases
+    by 1, the index into the flat `M` increases by 32. Along the K dim, strides
+    would be 4 for `K0`, and 1 for `K1`, meaning as `K0` increases by 1, the
+    index into the flat `K` increases by 4. The strides in M from `m_strides`
+    are orthogonal to the strides in `K` from `k_strides`.
   }];
 
   let arguments = (ins AnyShaped:$input, AnyShaped:$output,
@@ -860,8 +875,12 @@
                        DenseI64ArrayAttr:$static_kernel_size,
                        Variadic<Index>:$m_offset,
                        DenseI64ArrayAttr:$static_m_offset,
+                       Variadic<Index>:$m_strides,
+                       DenseI64ArrayAttr:$static_m_strides,
                        Variadic<Index>:$k_offset,
                        DenseI64ArrayAttr:$static_k_offset,
+                       Variadic<Index>:$k_strides,
+                       DenseI64ArrayAttr:$static_k_strides,
                        DenseI64ArrayAttr:$batch_pos,
                        DenseI64ArrayAttr:$m_pos,
                        DenseI64ArrayAttr:$k_pos);
@@ -876,8 +895,10 @@
     custom<DynamicIndexList>($kernel_size, $static_kernel_size)
     `m_offset` `=`
     custom<DynamicIndexList>($m_offset, $static_m_offset)
+    `*` custom<DynamicIndexList>($m_strides, $static_m_strides)
     `k_offset` `=`
     custom<DynamicIndexList>($k_offset, $static_k_offset)
+    `*` custom<DynamicIndexList>($k_strides, $static_k_strides)
     `batch_pos` `=` $batch_pos
     `m_pos` `=` $m_pos
     `k_pos` `=` $k_pos
@@ -892,7 +913,9 @@
       "ArrayRef<int64_t>":$dilations,
       "ArrayRef<OpFoldResult>":$kernel_size,
       "ArrayRef<OpFoldResult>":$m_offset,
+      "ArrayRef<OpFoldResult>":$m_strides,
       "ArrayRef<OpFoldResult>":$k_offset,
+      "ArrayRef<OpFoldResult>":$k_strides,
       "ArrayRef<int64_t>":$batch_dimensions,
       "ArrayRef<int64_t>":$m_dimensions,
       "ArrayRef<int64_t>":$k_dimensions)>
@@ -911,14 +934,24 @@
     int64_t getOutputRank() {
       return getOutputType().getRank();
     }
+
+    // Helpers to get output dimensions corresponding to batch, m, and k.
+    SmallVector<int64_t> getBatchOutputDims();
+    SmallVector<int64_t> getMOutputDims();
+    SmallVector<int64_t> getKOutputDims();
+
     // Return op metadata.
     SmallVector<OpFoldResult> getMixedKernelSize();
     SmallVector<OpFoldResult> getMixedMOffset();
     SmallVector<OpFoldResult> getMixedKOffset();
+    SmallVector<OpFoldResult> getMixedMStrides();
+    SmallVector<OpFoldResult> getMixedKStrides();
 
     // Set op metadata.
     void setMixedKOffset(SmallVector<OpFoldResult> kOffset);
     void setMixedMOffset(SmallVector<OpFoldResult> mOffset);
+    void setMixedKStrides(SmallVector<OpFoldResult> kStrides);
+    void setMixedMStrides(SmallVector<OpFoldResult> mStrides);
 
     // Method to implement for specifying output range for
     // DestinationStyleOpInterface
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
index 6af4173..95ff938 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
@@ -1265,9 +1265,10 @@
   SmallVector<OpFoldResult> inputSizes = getDims(builder, loc, getInput());
 
   // Set batch offsets and sizes for input
-  for (auto [idx, dim] : llvm::enumerate(getBatchPos())) {
-    inputOffsets[dim] = offsets[idx];
-    inputSizes[dim] = sizes[idx];
+  for (auto [outDim, inDim] :
+       llvm::zip_equal(getBatchOutputDims(), getBatchPos())) {
+    inputOffsets[inDim] = offsets[outDim];
+    inputSizes[inDim] = sizes[outDim];
   }
 
   SmallVector<OpFoldResult> inputStrides(getInputRank(), one);
@@ -1292,26 +1293,30 @@
                        outputSlice->result_type_end());
   }
 
-  AffineExpr d0, d1;
-  bindDims(getContext(), d0, d1);
-  auto map = AffineMap::get(2, 0, {d0 + d1}, getContext());
-  OpFoldResult kTileOffset = offsets.back();
-  OpFoldResult kOpOffset = getMixedKOffset()[0];
-  OpFoldResult kOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, map, {kTileOffset, kOpOffset});
-  OpFoldResult mTileOffset = offsets[offsets.size() - 2];
-  OpFoldResult mOpOffset = getMixedMOffset()[0];
-  OpFoldResult mOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, map, {mTileOffset, mOpOffset});
+  // Adjust m_offset and k_offset by adding the offsets from tiling.
+  SmallVector<OpFoldResult> newKOffsets, newMOffsets;
+  for (auto [outDim, kOffset] :
+       llvm::zip_equal(getKOutputDims(), getMixedKOffset())) {
+    OpFoldResult kTileOffset = offsets[outDim];
+    newKOffsets.push_back(addOfrs(builder, loc, kTileOffset, kOffset));
+  }
+  for (auto [outDim, mOffset] :
+       llvm::zip_equal(getMOutputDims(), getMixedMOffset())) {
+    OpFoldResult mTileOffset = offsets[outDim];
+    newMOffsets.push_back(addOfrs(builder, loc, mTileOffset, mOffset));
+  }
 
+  // Create the tiled op.
   SmallVector<Value> operands = {inputSlice->getResult(0),
                                  outputSlice->getResult(0)};
+  // Copy all metadata operands from the untiled operation.
   operands.append(getOperation()->getOperands().begin() + 2,
                   getOperation()->getOperands().end());
   Im2colOp tiledOp =
       mlir::clone(builder, *this, outputSlice->getResultTypes(), operands);
-  tiledOp.setMixedKOffset({kOffset});
-  tiledOp.setMixedMOffset({mOffset});
+  // Set the new k_offset and m_offset, since they have changed with tiling.
+  tiledOp.setMixedKOffset(newKOffsets);
+  tiledOp.setMixedMOffset(newMOffsets);
 
   return TilingResult{{tiledOp},
                       SmallVector<Value>(tiledOp->getResults()),
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 0643d8a..6884e9f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -550,7 +550,8 @@
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   // expected-error @+1 {{expected strides rank to be equal to the kernel rank}}
   %1 = iree_linalg_ext.im2col strides = [1] dilations = [1, 1] kernel_size = [3, 3]
-           m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<2x34x34x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
@@ -562,7 +563,8 @@
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   // expected-error @+1 {{expected dilations rank to be equal to the kernel rank}}
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1, 1] kernel_size = [3, 3]
-           m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<2x34x34x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
@@ -574,7 +576,60 @@
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   // expected-error @+1 {{expected kernel rank to be equal to the m_pos rank}}
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3]
-           m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           ins(%arg0 : tensor<2x34x34x640xf32>)
+           outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+  return %1 : tensor<2x1024x5760xf32>
+}
+
+// -----
+
+func.func @illegal_im2col_m_offset(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
+  %0 = tensor.empty() : tensor<2x1024x5760xf32>
+  // expected-error @+1 {{expected the same size m_offset and m_strides}}
+  %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+           m_offset = [0, 0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           ins(%arg0 : tensor<2x34x34x640xf32>)
+           outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+  return %1 : tensor<2x1024x5760xf32>
+}
+
+// -----
+
+func.func @illegal_im2col_k_offset(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
+  %0 = tensor.empty() : tensor<2x1024x5760xf32>
+  // expected-error @+1 {{expected the same size k_offset and k_strides}}
+  %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+           m_offset = [0] * [1] k_offset = [0, 0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           ins(%arg0 : tensor<2x34x34x640xf32>)
+           outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+  return %1 : tensor<2x1024x5760xf32>
+}
+
+// -----
+
+func.func @illegal_im2col_m_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
+  %0 = tensor.empty() : tensor<2x1024x5760xf32>
+  // expected-error @+1 {{expected inner m_strides to be 1}}
+  %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+           m_offset = [0] * [0] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           ins(%arg0 : tensor<2x34x34x640xf32>)
+           outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+  return %1 : tensor<2x1024x5760xf32>
+}
+
+// -----
+
+func.func @illegal_im2col_k_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
+  %0 = tensor.empty() : tensor<2x1024x5760xf32>
+  // expected-error @+1 {{expected inner k_strides to be 1}}
+  %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+           m_offset = [0] * [1] k_offset = [0] * [2]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<2x34x34x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
@@ -586,7 +641,8 @@
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   // expected-error @+1 {{expected input rank to be the sum of batch, m, and k ranks}}
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-           m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<1x2x34x34x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
@@ -594,6 +650,19 @@
 
 // -----
 
+func.func @illegal_im2col_output_rank(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x9x640xf32> {
+  %0 = tensor.empty() : tensor<2x1024x9x640xf32>
+  // expected-error @+1 {{expected output rank to be the sum of batch_pos, k_offset, and m_offset ranks}}
+  %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           ins(%arg0 : tensor<2x34x34x640xf32>)
+           outs(%0 : tensor<2x1024x9x640xf32>) -> tensor<2x1024x9x640xf32>
+  return %1 : tensor<2x1024x9x640xf32>
+}
+
+// -----
+
 func.func @illegal_winograd_input_shape(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6x6x32xf32> {
   %0 = tensor.empty() : tensor<8x8x1x6x6x32xf32>
   // expected-error @+1 {{incompatible output shape}}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 7fedef7..df94c9e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -1,10 +1,5 @@
 // RUN: iree-opt --split-input-file %s | FileCheck %s
 
-// CHECK-LABEL: func.func @sort_tensor
-// CHECK:         iree_linalg_ext.sort
-// CHECK-SAME:      dimension(0)
-// CHECK-SAME:      outs({{.*}})
-// CHECK:           iree_linalg_ext.yield
 func.func @sort_tensor(%arg0: tensor<128xi32>) -> tensor<128xi32> {
   %0 = iree_linalg_ext.sort
     dimension(0)
@@ -15,14 +10,14 @@
   } -> tensor<128xi32>
   return %0 : tensor<128xi32>
 }
-
-// -----
-
-// CHECK-LABEL: func.func @sort_memref
+// CHECK-LABEL: func.func @sort_tensor(
 // CHECK:         iree_linalg_ext.sort
 // CHECK-SAME:      dimension(0)
 // CHECK-SAME:      outs({{.*}})
 // CHECK:           iree_linalg_ext.yield
+
+// -----
+
 func.func @sort_memref(%arg0: memref<128xi32>) {
   iree_linalg_ext.sort dimension(0)
     outs(%arg0 : memref<128xi32>) {
@@ -32,6 +27,11 @@
   }
   return
 }
+// CHECK-LABEL: func.func @sort_memref(
+// CHECK:         iree_linalg_ext.sort
+// CHECK-SAME:      dimension(0)
+// CHECK-SAME:      outs({{.*}})
+// CHECK:           iree_linalg_ext.yield
 
 // -----
 
@@ -46,7 +46,7 @@
       } -> tensor<?x?xi32>, tensor<?x?xf32>
   return %0#0, %0#1 : tensor<?x?xi32>, tensor<?x?xf32>
 }
-// CHECK-LABEL: func.func @sort_multi_result_tensor
+// CHECK-LABEL: func.func @sort_multi_result_tensor(
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xi32>
 //  CHECK-SAME:   %[[ARG1:.+]]: tensor<?x?xf32>
 //       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.sort dimension(0)
@@ -65,7 +65,7 @@
      }
   return
 }
-// CHECK-LABEL: func.func @sort_multi_result_memref
+// CHECK-LABEL: func.func @sort_multi_result_memref(
 //  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?xi32>
 //  CHECK-SAME:   %[[ARG1:.+]]: memref<?x?xf32>
 //       CHECK:   iree_linalg_ext.sort dimension(0)
@@ -524,10 +524,9 @@
         } -> tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>
   return %0#0, %0#1 : tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>
 }
-
-// CHECK-LABEL: func.func @topk_tensor
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x10x8x4xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x10x8x4xi32>
+// CHECK-LABEL: func.func @topk_tensor(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<20x10x8x4xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<20x10x8x4xi32>
 //       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
 //       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
 //       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk
@@ -550,7 +549,7 @@
         }
   return
 }
-// CHECK-LABEL: func.func @topk_memref
+// CHECK-LABEL: func.func @topk_memref(
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<4x10xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: memref<4x10xi32>
 //  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: memref<4x3xf32>
@@ -574,7 +573,7 @@
         } -> tensor<?x?xf32>, tensor<?x?xi32>
   return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xi32>
 }
-// CHECK-LABEL: func.func @topk_dynamic_tensor
+// CHECK-LABEL: func.func @topk_dynamic_tensor(
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xi32>
 //  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -602,7 +601,7 @@
   return %0#0, %0#1 : tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>
 }
 
-// CHECK-LABEL: func.func @topk_tensor
+// CHECK-LABEL: func.func @topk_tensor_optional(
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x10x8x4xf32>
 //       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
 //       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
@@ -620,11 +619,11 @@
   return %1 : tensor<3x3x1x1xf32>
 }
 
-// CHECK: func.func @pack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x3xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32>
-// CHECK: %[[RES:.*]] = iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (tensor<3x3xf32> tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32>
-// CHECK: return %[[RES]] : tensor<3x3x1x1xf32>
+// CHECK-LABEL: func.func @pack(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x3xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x3x1x1xf32>
+// CHECK:         %[[RES:.*]] = iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (tensor<3x3xf32> tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32>
+// CHECK:         return %[[RES]] : tensor<3x3x1x1xf32>
 
 // -----
 
@@ -633,10 +632,10 @@
   return
 }
 
-// CHECK: func.func @pack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) {
-// CHECK: iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (memref<3x3xf32> memref<3x3x1x1xf32>)
+// CHECK-LABEL: func.func @pack(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>
+// CHECK:         iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (memref<3x3xf32> memref<3x3x1x1xf32>)
 
 // -----
 
@@ -645,16 +644,16 @@
   %0 = iree_linalg_ext.pack %input padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<13x15xf32> tensor<3x8x8x2xf32>) -> tensor<3x8x8x2xf32>
   return %0 : tensor<3x8x8x2xf32>
 }
-// CHECK:      func @extra_pad_and_pack(
-// CHECK-SAME:   %[[INPUT:.+]]: tensor<13x15xf32>
-// CHECK-SAME:   %[[OUTPUT:.+]]: tensor<3x8x8x2xf32>
-// CHECK-SAME:   %[[PAD:.+]]: f32
-// CHECK:        %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
-// CHECK-SAME:     padding_value(%[[PAD]] : f32)
-// CHECK-SAME:     inner_dims_pos = [0, 1]
-// CHECK-SAME:     inner_tiles = [8, 2]
-// CHECK-SAME:     into %[[OUTPUT]]
-// CHECK:       return %[[RES]]
+// CHECK-LABEL: func @extra_pad_and_pack(
+// CHECK-SAME:    %[[INPUT:.+]]: tensor<13x15xf32>
+// CHECK-SAME:    %[[OUTPUT:.+]]: tensor<3x8x8x2xf32>
+// CHECK-SAME:    %[[PAD:.+]]: f32
+// CHECK:         %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
+// CHECK-SAME:      padding_value(%[[PAD]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1]
+// CHECK-SAME:      inner_tiles = [8, 2]
+// CHECK-SAME:      into %[[OUTPUT]]
+// CHECK:         return %[[RES]]
 
 // -----
 
@@ -662,16 +661,16 @@
   %0 = iree_linalg_ext.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<13x15xf32> tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
   return %0 : tensor<2x8x8x2xf32>
 }
-// CHECK:      func.func @pad_and_pack_static
-// CHECK-SAME:   %[[INPUT:[a-zA-Z0-9_]+]]: tensor<13x15xf32>
-// CHECK-SAME:   %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<2x8x8x2xf32>
-// CHECK-SAME:   %[[PAD:[a-zA-Z0-9_]+]]: f32
-// CHECK:        %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
-// CHECK-SAME:     padding_value(%[[PAD]] : f32)
-// CHECK-SAME:     inner_dims_pos = [0, 1]
-// CHECK-SAME:     inner_tiles = [8, 2]
-// CHECK-SAME:     into %[[OUTPUT]]
-// CHECK:        return %[[RES]]
+// CHECK-LABEL: func.func @pad_and_pack_static(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]: tensor<13x15xf32>
+// CHECK-SAME:    %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<2x8x8x2xf32>
+// CHECK-SAME:    %[[PAD:[a-zA-Z0-9_]+]]: f32
+// CHECK:         %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
+// CHECK-SAME:      padding_value(%[[PAD]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1]
+// CHECK-SAME:      inner_tiles = [8, 2]
+// CHECK-SAME:      into %[[OUTPUT]]
+// CHECK:         return %[[RES]]
 
 // -----
 
@@ -679,16 +678,16 @@
   %0 = iree_linalg_ext.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<?x?xf32> tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
   return %0 : tensor<?x?x8x2xf32>
 }
-// CHECK:      func.func @pad_and_pack_partially_dynamic
-// CHECK-SAME:   %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME:   %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<?x?x8x2xf32>
-// CHECK-SAME:   %[[PAD:[a-zA-Z0-9_]+]]: f32
-// CHECK:        %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
-// CHECK-SAME:     padding_value(%[[PAD]] : f32)
-// CHECK-SAME:     inner_dims_pos = [0, 1]
-// CHECK-SAME:     inner_tiles = [8, 2]
-// CHECK-SAME:     into %[[OUTPUT]]
-// CHECK:        return %[[RES]]
+// CHECK-LABEL: func.func @pad_and_pack_partially_dynamic(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<?x?x8x2xf32>
+// CHECK-SAME:    %[[PAD:[a-zA-Z0-9_]+]]: f32
+// CHECK:         %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
+// CHECK-SAME:      padding_value(%[[PAD]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1]
+// CHECK-SAME:      inner_tiles = [8, 2]
+// CHECK-SAME:      into %[[OUTPUT]]
+// CHECK:         return %[[RES]]
 
 // -----
 
@@ -697,18 +696,18 @@
     inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %output : (tensor<?x?xf32> tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   return %0 : tensor<?x?x?x?xf32>
 }
-// CHECK:      func.func @pad_and_pack_fully_dynamic
-// CHECK-SAME:   %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME:   %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
-// CHECK-SAME:   %[[PAD:[a-zA-Z0-9_]+]]: f32
-// CHECK-SAME:   %[[TILE_N:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[TILE_M:[a-zA-Z0-9_]+]]: index
-// CHECK:        %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
-// CHECK-SAME:     padding_value(%[[PAD]] : f32)
-// CHECK-SAME:     inner_dims_pos = [0, 1]
-// CHECK-SAME:     inner_tiles = [%[[TILE_N]], %[[TILE_M]]]
-// CHECK-SAME:     into %[[OUTPUT]]
-// CHECK:        return %[[RES]]
+// CHECK-LABEL: func.func @pad_and_pack_fully_dynamic(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:    %[[PAD:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME:    %[[TILE_N:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:    %[[TILE_M:[a-zA-Z0-9_]+]]: index
+// CHECK:         %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
+// CHECK-SAME:      padding_value(%[[PAD]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1]
+// CHECK-SAME:      inner_tiles = [%[[TILE_N]], %[[TILE_M]]]
+// CHECK-SAME:      into %[[OUTPUT]]
+// CHECK:         return %[[RES]]
 
 // -----
 
@@ -717,10 +716,10 @@
   return
 }
 
-// CHECK: func.func @unpack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) {
-// CHECK: iree_linalg_ext.unpack %[[ARG1]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>)
+// CHECK-LABEL: func.func @unpack(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>,
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) {
+// CHECK:         iree_linalg_ext.unpack %[[ARG1]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>)
 
 // -----
 
@@ -729,15 +728,15 @@
   return %0 : tensor<256x128xf32>
 }
 
-// CHECK:      func.func @unpack_static
-// CHECK-SAME:   %[[INPUT:[a-zA-Z0-9_]+]]
-// CHECK-SAME:   %[[OUTPUT:[a-zA-Z0-9_]+]]
-// CHECK:        %[[UNPACK:.+]] = iree_linalg_ext.unpack
-// CHECK-SAME:     %[[INPUT]]
-// CHECK-SAME      dim_pos = [0, 1]
-// CHECK-SAME      inner_pos = [32, 16]
-// CHECK-SAME:     into %[[OUTPUT]]
-// CHECK:        return %[[UNPACK]]
+// CHECK-LABEL: func.func @unpack_static(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]
+// CHECK-SAME:    %[[OUTPUT:[a-zA-Z0-9_]+]]
+// CHECK:         %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME:      %[[INPUT]]
+// CHECK-SAME       dim_pos = [0, 1]
+// CHECK-SAME       inner_pos = [32, 16]
+// CHECK-SAME:      into %[[OUTPUT]]
+// CHECK:         return %[[UNPACK]]
 
 // -----
 
@@ -745,15 +744,15 @@
   %0 = iree_linalg_ext.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<2x8x8x2xf32> tensor<13x15xf32>) -> tensor<13x15xf32>
   return %0 : tensor<13x15xf32>
 }
-// CHECK:      func.func @unpack_undo_padding
-// CHECK-SAME:   %[[INPUT:[a-zA-Z0-9_]+]]
-// CHECK-SAME:   %[[OUTPUT:[a-zA-Z0-9_]+]]
-// CHECK:        %[[UNPACK:.+]] = iree_linalg_ext.unpack
-// CHECK-SAME:     %[[INPUT]]
-// CHECK-SAME      dim_pos = [0, 1]
-// CHECK-SAME      inner_pos = [32, 16]
-// CHECK-SAME:     into %[[OUTPUT]]
-// CHECK:        return %[[UNPACK]]
+// CHECK-LABEL: func.func @unpack_undo_padding(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]
+// CHECK-SAME:    %[[OUTPUT:[a-zA-Z0-9_]+]]
+// CHECK:         %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME:      %[[INPUT]]
+// CHECK-SAME       dim_pos = [0, 1]
+// CHECK-SAME       inner_pos = [32, 16]
+// CHECK-SAME:      into %[[OUTPUT]]
+// CHECK:         return %[[UNPACK]]
 
 // -----
 
@@ -762,10 +761,10 @@
   return
 }
 
-// CHECK: func.func @unpack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) {
-// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>)
+// CHECK-LABEL: func.func @unpack(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>
+// CHECK:         iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>)
 
 // -----
 
@@ -774,10 +773,10 @@
   return
 }
 
-// CHECK: func.func @pack
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>) {
-// CHECK: iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<32x4x32x8xf32>)
+// CHECK-LABEL: func.func @pack(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>
+// CHECK:         iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<32x4x32x8xf32>)
 
 // -----
 
@@ -786,10 +785,10 @@
   return
 }
 
-// CHECK: func.func @pack
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>) {
-// CHECK: iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<4x32x32x8xf32>)
+// CHECK-LABEL: func.func @pack(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>
+// CHECK:         iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<4x32x32x8xf32>)
 
 // -----
 
@@ -798,10 +797,10 @@
   return
 }
 
-// CHECK: func.func @unpack
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>) {
-// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<4x32x32x8xf32> memref<128x256xf32>)
+// CHECK-LABEL: func.func @unpack(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>
+// CHECK:         iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<4x32x32x8xf32> memref<128x256xf32>)
 
 // -----
 
@@ -810,28 +809,31 @@
   return
 }
 
-// CHECK: func.func @unpack
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>) {
-// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<32x4x32x8xf32> memref<128x256xf32>)
+// CHECK-LABEL: func.func @unpack(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>
+// CHECK:         iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<32x4x32x8xf32> memref<128x256xf32>)
 
 // -----
 
 func.func @im2col(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-           m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<2x34x34x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
 }
-// CHECK:      func.func @im2col(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:     m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<2x34x34x640xf32>)
-// CHECK-SAME:     outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x34x34x640xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME:      m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME:      batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME:      ins(%[[ARG0]] : tensor<2x34x34x640xf32>)
+// CHECK-SAME:      outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK:         return %[[D1]] : tensor<2x1024x5760xf32>
 
 // -----
 
@@ -839,91 +841,127 @@
                           %mOffset: index, %kOffset: index) -> tensor<?x?x?xf32> {
   %0 = tensor.empty(%s0, %s1, %s2) : tensor<?x?x?xf32>
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-           m_offset = [%mOffset] k_offset = [%kOffset] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [%mOffset] * [1] k_offset = [%kOffset] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<?x?x?x?xf32>)
            outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK:      func.func @im2col_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>,
-// CHECK-SAME:   %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[MOFFSET:.+]]: index, %[[KOFFSET:.+]]: index
-// CHECK:        %[[D0:.+]] = tensor.empty({{.+}}) : tensor<?x?x?xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:     m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
-// CHECK-SAME:     outs(%[[D0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK:        return %[[D1]] : tensor<?x?x?xf32>
+// CHECK-LABEL: func.func @im2col_dynamic(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:    %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[MOFFSET:.+]]: index, %[[KOFFSET:.+]]: index
+// CHECK:         %[[D0:.+]] = tensor.empty({{.+}}) : tensor<?x?x?xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME:      m_offset = [%[[MOFFSET]]] * [1] k_offset = [%[[KOFFSET]]] * [1]
+// CHECK-SAME:      batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME:      ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME:      outs(%[[D0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:         return %[[D1]] : tensor<?x?x?xf32>
 
 // -----
 
 func.func @im2col_strided(%arg0: tensor<2x65x96x640xf32>) -> tensor<2x1024x5760xf32> {
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   %1 = iree_linalg_ext.im2col strides = [2, 3] dilations = [1, 1] kernel_size = [3, 3]
-           m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<2x65x96x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
 }
-// CHECK:      func.func @im2col_strided(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x65x96x640xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.im2col strides = [2, 3] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:     m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<2x65x96x640xf32>)
-// CHECK-SAME:     outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col_strided(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x65x96x640xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.im2col strides = [2, 3] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME:      m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME:      batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME:      ins(%[[ARG0]] : tensor<2x65x96x640xf32>)
+// CHECK-SAME:      outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK:         return %[[D1]] : tensor<2x1024x5760xf32>
 
 // -----
 
 func.func @im2col_dilated(%arg0: tensor<2x44x46x640xf32>) -> tensor<2x1024x5760xf32> {
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [6, 7] kernel_size = [3, 3]
-           m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<2x44x46x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
 }
-// CHECK:      func.func @im2col_dilated(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x44x46x640xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [6, 7] kernel_size = [3, 3]
-// CHECK-SAME:     m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<2x44x46x640xf32>)
-// CHECK-SAME:     outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col_dilated(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x44x46x640xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [6, 7] kernel_size = [3, 3]
+// CHECK-SAME:      m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME:      batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME:      ins(%[[ARG0]] : tensor<2x44x46x640xf32>)
+// CHECK-SAME:      outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK:         return %[[D1]] : tensor<2x1024x5760xf32>
 
 // -----
 
 func.func @im2col_strided_dilated_mixed_kernel(%arg0: tensor<2x172x101x640xf32>) -> tensor<2x1024x5760xf32> {
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   %1 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-           m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<2x172x101x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
 }
-// CHECK:      func.func @im2col_strided_dilated_mixed_kernel(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x172x101x640xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-// CHECK-SAME:     m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<2x172x101x640xf32>)
-// CHECK-SAME:     outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col_strided_dilated_mixed_kernel(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x172x101x640xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
+// CHECK-SAME:      m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME:      batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME:      ins(%[[ARG0]] : tensor<2x172x101x640xf32>)
+// CHECK-SAME:      outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK:         return %[[D1]] : tensor<2x1024x5760xf32>
 
 // -----
 
 func.func @im2col_transposed_m_pos(%arg0: tensor<640x2x101x172xf32>) -> tensor<2x1024x5760xf32> {
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   %1 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-           m_offset = [0] k_offset = [0] batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+           m_offset = [0] * [1] k_offset = [0] * [1]
+           batch_pos = [1] m_pos = [3, 2] k_pos = [0]
            ins(%arg0 : tensor<640x2x101x172xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
 }
-// CHECK:      func.func @im2col_transposed_m_pos(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<640x2x101x172xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-// CHECK-SAME:     m_offset = [0] k_offset = [0] batch_pos = [1] m_pos = [3, 2] k_pos = [0]
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<640x2x101x172xf32>)
-// CHECK-SAME:     outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK:        return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col_transposed_m_pos(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<640x2x101x172xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
+// CHECK-SAME:      m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME:      batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+// CHECK-SAME:      ins(%[[ARG0]] : tensor<640x2x101x172xf32>)
+// CHECK-SAME:      outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK:         return %[[D1]] : tensor<2x1024x5760xf32>
+
+// -----
+
+func.func @im2col_expanded(%arg0: tensor<2x3x34x34x640xf32>) -> tensor<2x3x128x8x90x64xf32> {
+  %0 = tensor.empty() : tensor<2x3x128x8x90x64xf32>
+  %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+           m_offset = [0, 0] * [8, 1] k_offset = [0, 0] * [64, 1]
+           batch_pos = [0, 1] m_pos = [2, 3] k_pos = [4]
+           ins(%arg0 : tensor<2x3x34x34x640xf32>)
+           outs(%0 : tensor<2x3x128x8x90x64xf32>) -> tensor<2x3x128x8x90x64xf32>
+  return %1 : tensor<2x3x128x8x90x64xf32>
+}
+// CHECK-LABEL: func.func @im2col_expanded(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x34x34x640xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<2x3x128x8x90x64xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME:      m_offset = [0, 0] * [8, 1] k_offset = [0, 0] * [64, 1]
+// CHECK-SAME:      batch_pos = [0, 1] m_pos = [2, 3] k_pos = [4]
+// CHECK-SAME:      ins(%[[ARG0]] : tensor<2x3x34x34x640xf32>)
+// CHECK-SAME:      outs(%[[D0]] : tensor<2x3x128x8x90x64xf32>) -> tensor<2x3x128x8x90x64xf32>
+// CHECK:         return %[[D1]] : tensor<2x3x128x8x90x64xf32>
 
 // -----
 
@@ -934,14 +972,13 @@
     ins(%arg0 : tensor<3x3x64x128xf32>) outs(%0 : tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
   return %1 : tensor<8x8x64x128xf32>
 }
-// CHECK:      func.func @winograd_filter_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x3x64x128xf32>) ->
-// CHECK-SAME:   tensor<8x8x64x128xf32> {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8x64x128xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     kernel_dimensions([0, 1]) ins(%[[ARG0]] : tensor<3x3x64x128xf32>) outs(%[[D0]] :
-// CHECK-SAME:     tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
-// CHECK:        return %[[D1]] : tensor<8x8x64x128xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @winograd_filter_transform(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x3x64x128xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<8x8x64x128xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      kernel_dimensions([0, 1]) ins(%[[ARG0]] : tensor<3x3x64x128xf32>) outs(%[[D0]] :
+// CHECK-SAME:      tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
+// CHECK:         return %[[D1]] : tensor<8x8x64x128xf32>
 
 // -----
 
@@ -951,13 +988,13 @@
     ins(%arg0 : tensor<3x3x?x?xf32>) outs(%arg1 : tensor<8x8x?x?xf32>) -> tensor<8x8x?x?xf32>
   return %1 : tensor<8x8x?x?xf32>
 }
-// CHECK:      func.func @winograd_filter_transform_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x3x?x?xf32>,
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?xf32>) -> tensor<8x8x?x?xf32> {
-// CHECK:        %[[D0:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     kernel_dimensions([0, 1]) ins(%[[ARG0]] : tensor<3x3x?x?xf32>) outs(%[[ARG1]] :
-// CHECK-SAME:     tensor<8x8x?x?xf32>) -> tensor<8x8x?x?xf32>
-// CHECK:        return %[[D0]] : tensor<8x8x?x?xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @winograd_filter_transform_dynamic(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x3x?x?xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?xf32>
+// CHECK:         %[[D0:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      kernel_dimensions([0, 1]) ins(%[[ARG0]] : tensor<3x3x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME:      tensor<8x8x?x?xf32>) -> tensor<8x8x?x?xf32>
+// CHECK:         return %[[D0]] : tensor<8x8x?x?xf32>
 
 // -----
 
@@ -968,15 +1005,13 @@
     ins(%arg0 : tensor<128x64x3x3xf32>) outs(%0 : tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
   return %1 : tensor<8x8x64x128xf32>
 }
-// CHECK:      func.func @winograd_filter_transform_fchw(%[[ARG0]]: tensor<128x64x3x3xf32>) ->
-// CHECK-SAME:   tensor<8x8x64x128xf32> {
-// CHECK:        %[[D0]] = tensor.empty() : tensor<8x8x64x128xf32>
-// CHECK:        %[[D1]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     kernel_dimensions([2, 3]) ins(%[[ARG0]] : tensor<128x64x3x3xf32>) outs(%[[D0]] :
-// CHECK-SAME:     tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
-// CHECK:        return %[[D1]] : tensor<8x8x64x128xf32>
-// CHECK:      }
-// CHECK:    }
+// CHECK-LABEL: func.func @winograd_filter_transform_fchw(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<128x64x3x3xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<8x8x64x128xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      kernel_dimensions([2, 3]) ins(%[[ARG0]] : tensor<128x64x3x3xf32>) outs(%[[D0]] :
+// CHECK-SAME:      tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
+// CHECK:         return %[[D1]] : tensor<8x8x64x128xf32>
 
 // -----
 
@@ -986,14 +1021,13 @@
     ins(%arg0 : tensor<1x10x10x1280xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
   return %1 : tensor<8x8x1x2x2x1280xf32>
 }
-// CHECK:      func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) ->
-// CHECK-SAME:   tensor<8x8x1x2x2x1280xf32> {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x10x10x1280xf32>) outs(%[[D0]] :
-// CHECK-SAME:     tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
-// CHECK:        return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @winograd_input_transform(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x10x10x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME:      tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+// CHECK:         return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
 
 // -----
 
@@ -1003,13 +1037,13 @@
     ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
   return %1 : tensor<8x8x?x?x?x?xf32>
 }
-// CHECK:      func.func @winograd_input_transform_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>,
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> {
-// CHECK:        %[[D0:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<?x?x?x?xf32>) outs(%[[ARG1]] :
-// CHECK-SAME:     tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
-// CHECK:        return %[[D0]] : tensor<8x8x?x?x?x?xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @winograd_input_transform_dynamic(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>
+// CHECK:         %[[D0:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<?x?x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME:      tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
+// CHECK:         return %[[D0]] : tensor<8x8x?x?x?x?xf32>
 
 // -----
 
@@ -1019,15 +1053,13 @@
     ins(%arg0 : tensor<1x1280x10x10xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
   return %1 : tensor<8x8x1x2x2x1280xf32>
 }
-// CHECK:      func.func @winograd_input_transform_nchw(%[[ARG0]]: tensor<1x1280x10x10xf32>) ->
-// CHECK-SAME:   tensor<8x8x1x2x2x1280xf32> {
-// CHECK:        %[[D0]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
-// CHECK:        %[[D1]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x1280x10x10xf32>) outs(%[[D0]] :
-// CHECK-SAME:     tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
-// CHECK:        return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
-// CHECK:      }
-// CHECK:    }
+// CHECK-LABEL: func.func @winograd_input_transform_nchw(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1280x10x10xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x1280x10x10xf32>) outs(%[[D0]] :
+// CHECK-SAME:      tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+// CHECK:         return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
 
 // -----
 
@@ -1037,29 +1069,28 @@
     ins(%arg0 : tensor<8x8x1x2x2x1280xf32>) outs(%0 : tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
   return %1 : tensor<1x12x12x1280xf32>
 }
-// CHECK:      func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>) ->
-// CHECK-SAME:   tensor<1x12x12x1280xf32> {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<1x12x12x1280xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
-// CHECK-SAME:     tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
-// CHECK:        return %[[D1]] : tensor<1x12x12x1280xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @winograd_output_transform(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<1x12x12x1280xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME:      tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
+// CHECK:         return %[[D1]] : tensor<1x12x12x1280xf32>
 
 // -----
 
-func.func @winograd_output_transform(%arg0: tensor<8x8x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+func.func @winograd_output_transform_dynamic(%arg0: tensor<8x8x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   %1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
     ins(%arg0 : tensor<8x8x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
-// CHECK:      func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>,
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
-// CHECK:        %[[D0:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x?x?x?x?xf32>) outs(%[[ARG1]] :
-// CHECK-SAME:     tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK:        return %[[D0]] : tensor<?x?x?x?xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @winograd_output_transform_dynamic(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK:         %[[D0:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x?x?x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME:      tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK:         return %[[D0]] : tensor<?x?x?x?xf32>
 
 // -----
 
@@ -1069,15 +1100,13 @@
     ins(%arg0 : tensor<8x8x1x2x2x1280xf32>) outs(%0 : tensor<1x1280x12x12xf32>) -> tensor<1x1280x12x12xf32>
   return %1 : tensor<1x1280x12x12xf32>
 }
-// CHECK:      func.func @winograd_output_transform_nchw(%[[ARG0]]: tensor<8x8x1x2x2x1280xf32>) ->
-// CHECK-SAME:   tensor<1x1280x12x12xf32> {
-// CHECK:        %[[D0]] = tensor.empty() : tensor<1x1280x12x12xf32>
-// CHECK:        %[[D1]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
-// CHECK-SAME:     tensor<1x1280x12x12xf32>) -> tensor<1x1280x12x12xf32>
-// CHECK:        return %[[D1]] : tensor<1x1280x12x12xf32>
-// CHECK:      }
-// CHECK:    }
+// CHECK-LABEL: func.func @winograd_output_transform_nchw(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<1x1280x12x12xf32>
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:      image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME:      tensor<1x1280x12x12xf32>) -> tensor<1x1280x12x12xf32>
+// CHECK:         return %[[D1]] : tensor<1x1280x12x12xf32>
 
 // -----
 
@@ -1099,18 +1128,18 @@
 // CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
-// CHECK:      func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME:   tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK-SAME:   {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
-// CHECK-SAME:                ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @attention(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK:         %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:                 {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME:                 ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME:      tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] :
+// CHECK-SAME:      tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+// CHECK:         return %[[D1]] : tensor<192x1024x64xf32>
 
 // -----
 
@@ -1131,18 +1160,18 @@
 // CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
-// CHECK:      func.func @cross_attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME:   tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK-SAME:   {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
-// CHECK-SAME:                ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @cross_attention(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK:         %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:                 {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME:                 ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME:      tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] :
+// CHECK-SAME:      tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+// CHECK:         return %[[D1]] : tensor<192x1024x64xf32>
 
 // -----
 
@@ -1165,18 +1194,18 @@
 // CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
-// CHECK:      func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME:   tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32>
-// CHECK-SAME:   {
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
-// CHECK-SAME:                ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] :
-// CHECK-SAME:     tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK:        return %[[D1]] : tensor<192x1024x64xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @cross_attention_transposev(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x64x2048xf32>
+// CHECK:         %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK:         %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:                 {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME:                 ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME:      tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] :
+// CHECK-SAME:      tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+// CHECK:         return %[[D1]] : tensor<192x1024x64xf32>
 
 // -----
 
@@ -1196,17 +1225,18 @@
 // CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
 // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 
-// CHECK:      func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME:   tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK-SAME:   {
-// CHECK:        %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME:                {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
-// CHECK-SAME:                ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
-// CHECK-SAME:     tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%[[ARG3]] :
-// CHECK-SAME:     tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK:        return %[[D1]] : tensor<?x?x?xf32>
-// CHECK:      }
+// CHECK-LABEL: func.func @cross_attention_transposev_dyn(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK:         %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:         %[[D1:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:                 {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME:                 ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME:      tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%[[ARG3]] :
+// CHECK-SAME:      tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:         return %[[D1]] : tensor<?x?x?xf32>
 
 // -----
 
@@ -1220,12 +1250,12 @@
   } -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
-//       CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
-//       CHECK: func @custom_op_default(
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @custom_op_default(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
 //       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.custom_op
-//  CHECK-SAME:       indexing_maps = [#[[MAP]], #[[MAP]]]
+//  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP]]]
 //  CHECK-SAME:       iterator_types = [#iree_linalg_ext.iterator_type<parallel>]
 //  CHECK-SAME:       ins(%[[ARG0]] : tensor<?xf32>) outs(%[[ARG1]] : tensor<?xf32>)
 //  CHECK-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: tensor<?xf32>, %[[B1:[a-zA-Z0-9]+]]: tensor<?xf32>)
@@ -1245,11 +1275,11 @@
   } -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
-//       CHECK: #[[MAP:.+]] = affine_map<(d0) -> ()>
-//       CHECK: func @custom_op_scalar_arg(
+//       CHECK: #[[$MAP:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: func @custom_op_scalar_arg(
 //  CHECK-SAME:     %[[SCALAR_ARG:[a-zA-Z0-9]+]]: f32
 //       CHECK:   iree_linalg_ext.custom_op
-//  CHECK-SAME:       indexing_maps = [#{{.+}}, #[[MAP]], #{{.+}}]
+//  CHECK-SAME:       indexing_maps = [#{{.+}}, #[[$MAP]], #{{.+}}]
 //  CHECK-SAME:       ins(%{{.+}}, %[[SCALAR_ARG]] : tensor<?xf32>, f32)
 //  CHECK-NEXT:     %[[B1:.+]]: f32
 
@@ -1265,10 +1295,10 @@
   } -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
-//       CHECK: #[[MAP:.+]] = affine_map<() -> ()>
-//       CHECK: func @custom_op_empty_affine_map(
+//       CHECK: #[[$MAP:.+]] = affine_map<() -> ()>
+// CHECK-LABEL: func @custom_op_empty_affine_map(
 //       CHECK:   iree_linalg_ext.custom_op
-//  CHECK-SAME:       indexing_maps = [#{{.+}}, #[[MAP]], #{{.+}}]
+//  CHECK-SAME:       indexing_maps = [#{{.+}}, #[[$MAP]], #{{.+}}]
 
 // -----
 
@@ -1282,8 +1312,7 @@
   } -> tensor<10xf32>
   return %0 : tensor<10xf32>
 }
-//       CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
-//       CHECK: func @custom_op_static_args(
+// CHECK-LABEL: func @custom_op_static_args(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<10xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<10xf32>
 //       CHECK:   iree_linalg_ext.custom_op
@@ -1317,7 +1346,7 @@
   } -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
-//       CHECK: func @custom_op_reduction(
+// CHECK-LABEL: func @custom_op_reduction(
 //       CHECK:   iree_linalg_ext.custom_op
 //  CHECK-SAME:     iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]
 //  CHECK-NEXT:   ^bb0
@@ -1339,8 +1368,7 @@
   } -> tensor<?xf32>, tensor<?xf32>
   return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
 }
-//       CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
-//       CHECK: func @custom_op_multiple_results(
+// CHECK-LABEL: func @custom_op_multiple_results(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
 //       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.custom_op
@@ -1376,14 +1404,14 @@
     } -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
   return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
 }
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, s0)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (s0, s1)>
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (s1, d1)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, s1)>
-//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
-//      CHECK: func @custom_op_symbolic_dims
-//      CHECK:   iree_linalg_ext.custom_op
-// CHECK-SAME:       indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]]]
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, s0)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (s0, s1)>
+//   CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (s1, d1)>
+//   CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, s1)>
+//   CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+// CHECK-LABEL: func @custom_op_symbolic_dims(
+//       CHECK:   iree_linalg_ext.custom_op
+//  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index c9895be..5f886c7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/CommandLine.h"
@@ -17,6 +18,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 
@@ -269,15 +271,6 @@
          affineOp.getMap().getResult(0).isMultipleOf(constTileSize.value());
 }
 
-// Helper method to add 2 OpFoldResult inputs with affine.apply.
-static OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
-                            OpFoldResult b) {
-  AffineExpr d0, d1;
-  bindDims(builder.getContext(), d0, d1);
-  auto addMap = AffineMap::get(2, 0, {d0 + d1});
-  return affine::makeComposedFoldedAffineApply(builder, loc, addMap, {a, b});
-}
-
 //===----------------------------------------------------------------------===//
 // OnlineAttentionOp
 //===----------------------------------------------------------------------===//
@@ -448,7 +441,7 @@
 /// ```
 ///   %im2col = iree_linalg_ext.im2col
 ///       strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-///       m_offset = [%m_off] k_offset = [%k_off]
+///       m_offset = [%m_off] * [1] k_offset = [%k_off] * [1]
 ///       batch_pos = [0] m_pos = [1, 2] k_pos = [3]
 ///       ins(%in : tensor<2x34x34x640xf32>)
 ///       outs(%out : tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
@@ -470,13 +463,30 @@
 FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
   Location loc = getLoc();
   Value inputSlice = getInput();
-  // Unroll all but the K loop
-  SmallVector<OpFoldResult> kOffset = getMixedKOffset();
-  SmallVector<OpFoldResult> mOffset = getMixedMOffset();
-  // Only support single K and M output dimension for now.
-  if (kOffset.size() != 1 || mOffset.size() != 1) {
-    return failure();
-  }
+
+  // The im2col verifier already enforces this, but assert here in case the
+  // verifier ever changes.
+  assert(getConstantIntValue(getMixedMStrides().back()).value() == 1 &&
+         getConstantIntValue(getMixedKStrides().back()).value() == 1 &&
+         "Expected inner m_strides and k_strides to be 1");
+
+  // Get the linearized mOffset and kOffset.
+  auto linearizeIndex = [&](ArrayRef<OpFoldResult> inds,
+                            ArrayRef<OpFoldResult> basis) {
+    MLIRContext *ctx = b.getContext();
+    SmallVector<AffineExpr> dims(inds.size()), symbols(basis.size());
+    bindDimsList<AffineExpr>(ctx, dims);
+    bindSymbolsList<AffineExpr>(ctx, symbols);
+    AffineExpr linearExpr = mlir::linearize(ctx, dims, symbols);
+    SmallVector<OpFoldResult> mapOperands(inds);
+    mapOperands.append(basis.begin(), basis.end());
+    auto linearMap = AffineMap::get(
+        /*dimCount=*/inds.size(), /*symbolCount=*/basis.size(), linearExpr);
+    OpFoldResult linearIdx =
+        affine::makeComposedFoldedAffineApply(b, loc, linearMap, mapOperands);
+    return linearIdx;
+  };
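+  // The linearized offset is the inner product of the offsets with their
+  // strides basis; e.g., m_offset = [%m0, %m1] * [8, 1] linearizes to
+  // %m0 * 8 + %m1.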
+  OpFoldResult mOffset = linearizeIndex(getMixedMOffset(), getMixedMStrides());
+  OpFoldResult kOffset = linearizeIndex(getMixedKOffset(), getMixedKStrides());
 
   // Step 1: Tile the im2col op to loops with contiguous slices in the
   // innermost loop.
@@ -499,14 +509,26 @@
   SetVector<int64_t> batchPosSet(getBatchPos().begin(), getBatchPos().end());
   OpFoldResult innerSliceSize;
   for (int idx = inputSizes.size() - 1; idx >= 0; --idx) {
-    if (!batchPosSet.contains(idx)) {
-      innerSliceSize = inputSizes[idx];
-      break;
+    if (batchPosSet.contains(idx)) {
+      continue;
     }
+    innerSliceSize = inputSizes[idx];
+    // If the innermost non-batch dimension is an m_pos dimension, then use the
+    // corresponding kernel_size instead of the input tensor size. This is
+    // because the slice will be of size `kernel_size` at some offset
+    // `i * kernel_size` in this case.
+    for (auto [mPos, kernelSize] :
+         llvm::zip_equal(getMPos(), getMixedKernelSize())) {
+      if (mPos == idx) {
+        innerSliceSize = kernelSize;
+      }
+    }
+    break;
   }
-  bool tileK =
-      !willBeContiguousSlice(innerSliceSize, kTileSize, kOffset.front());
-  if (!tileK) {
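+  // Only drop the inner K loop (copying a contiguous K tile as a single
+  // slice) when the innermost k_pos dimension is also the innermost input
+  // dimension and the slice will be contiguous; otherwise tile K to size 1.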
+  bool vectorizeInnerKLoop =
+      getKPos().back() == getInputRank() - 1 &&
+      willBeContiguousSlice(innerSliceSize, kTileSize, kOffset);
+  if (vectorizeInnerKLoop) {
     iterationDomain.pop_back();
   } else {
     kTileSize = b.getIndexAttr(1);
@@ -571,9 +593,14 @@
     }
     kBasis.push_back(size);
   }
-  OpFoldResult kIndex = kOffset.front();
-  if (tileK) {
-    kIndex = addOfrs(b, nestedLoc, kOffset.front(), ivs.back());
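+  // Accumulate the contributions of the tiled k output loops into the
+  // linearized k index, scaling each induction variable by its k stride. The
+  // innermost k dimension is skipped when it is handled by the contiguous
+  // inner slice.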
+  OpFoldResult kIndex = kOffset;
+  for (auto [i, ivIdx, stride] :
+       llvm::enumerate(getKOutputDims(), getMixedKStrides())) {
+    if (vectorizeInnerKLoop && i == getMixedKOffset().size() - 1) {
+      break;
+    }
+    OpFoldResult ivOffset = mulOfrs(b, nestedLoc, stride, ivs[ivIdx]);
+    kIndex = addOfrs(b, nestedLoc, kIndex, ivOffset);
   }
   FailureOr<SmallVector<Value>> maybeDelinKOffset = affine::delinearizeIndex(
       b, nestedLoc, getValueOrCreateConstantIndexOp(b, loc, kIndex),
@@ -596,11 +623,22 @@
     inputKOffset.push_back(delinKOffset[delinKIdx++]);
   }
 
-  // Compute offsets for extract. Start by delinearizing the combined offset
-  // of m_offset and the offset from the tiled loop, using the mBasis. This
-  // will give an index into the delinearized output space of the convolution.
-  Value mArg = tileK ? ivs[ivs.size() - 2] : ivs.back();
-  OpFoldResult linearMOffset = addOfrs(b, nestedLoc, mArg, mOffset.front());
+  // Compute offsets for extract. The linearized im2col result M offset is
+  // computed as the m_offset * m_strides inner product plus the linearized
+  // offset from the tiled m loops. The M offsets into the im2col input are then
+  // computed as the delinearized im2col result M offset (in the convolution
+  // result iteration space), plus the convolutional window offsets computed
+  // above.
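+  // For example, with m loop ivs (%iv0, %iv1), m_offset = [%m0, %m1], and
+  // m_strides = [%s, 1], the linearized M offset is
+  // (%iv0 + %m0) * %s + (%iv1 + %m1).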
+  SmallVector<int64_t> mOutDims = getMOutputDims();
+  SmallVector<OpFoldResult> mIvs, mOutStrides(getMixedMStrides());
+  for (int64_t dim : mOutDims) {
+    mIvs.push_back(ivs[dim]);
+  }
+  OpFoldResult linearMIv = linearizeIndex(mIvs, mOutStrides);
+  OpFoldResult linearMOffset = addOfrs(b, nestedLoc, linearMIv, mOffset);
+  // Delinearize the combined M offset (`linearMOffset`) into the convolution
+  // output space. `mBasis` contains the basis for the iteration space of the
+  // result of the convolution op (i.e., the basis for the result H and W dims).
   FailureOr<SmallVector<Value>> maybeDelinMOffset = affine::delinearizeIndex(
       b, nestedLoc,
       getValueOrCreateConstantIndexOp(b, nestedLoc, linearMOffset), mBasis);
@@ -648,7 +686,8 @@
   // Extract a slice from the input tensor.
 
   ShapedType outputType = getOutputType();
-  SmallVector<OpFoldResult> kTileSizes(outputType.getRank(), b.getIndexAttr(1));
+  SmallVector<OpFoldResult> kTileSizes(
+      std::min<int64_t>(getOutputRank(), getInputRank()), b.getIndexAttr(1));
   kTileSizes.back() = kTileSize;
 
   SmallVector<int64_t> kTileSizeStatic;
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
index 31daf1f..b4d4cca 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
@@ -137,17 +137,19 @@
     SmallVector<int64_t> dilations(convOp.getDilations().getValues<int64_t>());
     SmallVector<OpFoldResult> kernelSize = {rewriter.getIndexAttr(fh),
                                             rewriter.getIndexAttr(fw)};
-    SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
     SmallVector<OpFoldResult> mOffset = {rewriter.getIndexAttr(0)};
+    SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(1)};
+    SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
+    SmallVector<OpFoldResult> kBasis = {rewriter.getIndexAttr(1)};
     SmallVector<int64_t> batchPos = {0};
     SmallVector<int64_t> mPos = {1, 2};
     SmallVector<int64_t> kPos = {3};
-    Value img2ColTensor =
-        rewriter
-            .create<IREE::LinalgExt::Im2colOp>(
-                loc, input, /*output=*/colTensor, strides, dilations,
-                kernelSize, mOffset, kOffset, batchPos, mPos, kPos)
-            .getResult(0);
+    Value img2ColTensor = rewriter
+                              .create<IREE::LinalgExt::Im2colOp>(
+                                  loc, input, /*output=*/colTensor, strides,
+                                  dilations, kernelSize, mOffset, mBasis,
+                                  kOffset, kBasis, batchPos, mPos, kPos)
+                              .getResult(0);
 
     SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
     auto reshapedFilterType =
@@ -260,17 +262,19 @@
     SmallVector<int64_t> dilations(convOp.getDilations().getValues<int64_t>());
     SmallVector<OpFoldResult> kernelSize = {rewriter.getIndexAttr(fh),
                                             rewriter.getIndexAttr(fw)};
-    SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
     SmallVector<OpFoldResult> mOffset = {rewriter.getIndexAttr(0)};
+    SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(1)};
+    SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
+    SmallVector<OpFoldResult> kBasis = {rewriter.getIndexAttr(1)};
     SmallVector<int64_t> batchPos = {0};
     SmallVector<int64_t> mPos = {2, 3};
     SmallVector<int64_t> kPos = {1};
-    Value img2ColTensor =
-        rewriter
-            .create<IREE::LinalgExt::Im2colOp>(
-                loc, input, /*output=*/colTensor, strides, dilations,
-                kernelSize, mOffset, kOffset, batchPos, mPos, kPos)
-            .getResult(0);
+    Value img2ColTensor = rewriter
+                              .create<IREE::LinalgExt::Im2colOp>(
+                                  loc, input, /*output=*/colTensor, strides,
+                                  dilations, kernelSize, mOffset, mBasis,
+                                  kOffset, kBasis, batchPos, mPos, kPos)
+                              .getResult(0);
 
     SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
     auto reshapedFilterType =
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
index e827a7f..fd59547 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
@@ -17,7 +17,7 @@
 // CHECK:      %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf32>
 // CHECK:      %[[IM2COL:.+]] = iree_linalg_ext.im2col
 // CHECK-SAME:   strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:   m_offset = [0] k_offset = [0]
+// CHECK-SAME:   m_offset = [0] * [1] k_offset = [0] * [1]
 // CHECK-SAME:   batch_pos = [0] m_pos = [1, 2] k_pos = [3]
 // CHECK-SAME:   ins(%[[ARG0]] : tensor<1x16x16x4xf32>)
 // CHECK-SAME:   outs(%[[EMPTY]] : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
@@ -53,7 +53,7 @@
 // CHECK:      %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf32>
 // CHECK:      %[[IM2COL:.+]] = iree_linalg_ext.im2col
 // CHECK-SAME:   strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:   m_offset = [0] k_offset = [0]
+// CHECK-SAME:   m_offset = [0] * [1] k_offset = [0] * [1]
 // CHECK-SAME:   batch_pos = [0] m_pos = [2, 3] k_pos = [1]
 // CHECK-SAME:   ins(%[[ARG0]] : tensor<1x4x16x16xf32>)
 // CHECK-SAME:   outs(%[[EMPTY]] : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
@@ -89,7 +89,7 @@
 // CHECK:      %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf16>
 // CHECK:      %[[IM2COL:.+]] = iree_linalg_ext.im2col
 // CHECK-SAME:   strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:   m_offset = [0] k_offset = [0]
+// CHECK-SAME:   m_offset = [0] * [1] k_offset = [0] * [1]
 // CHECK-SAME:   batch_pos = [0] m_pos = [1, 2] k_pos = [3]
 // CHECK-SAME:   ins(%[[ARG0]] : tensor<1x16x16x4xf16>)
 // CHECK-SAME:   outs(%[[EMPTY]] : tensor<1x196x36xf16>) -> tensor<1x196x36xf16>
@@ -127,7 +127,7 @@
 // CHECK:      %[[EMPTY:.+]] = tensor.empty() : tensor<1x49x36xf16>
 // CHECK:      %[[IM2COL:.+]] = iree_linalg_ext.im2col
 // CHECK-SAME:   strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:   m_offset = [0] k_offset = [0]
+// CHECK-SAME:   m_offset = [0] * [1] k_offset = [0] * [1]
 // CHECK-SAME:   batch_pos = [0] m_pos = [1, 2] k_pos = [3]
 // CHECK-SAME:   ins(%[[ARG0]] : tensor<1x16x16x4xf16>)
 // CHECK-SAME:   outs(%[[EMPTY]] : tensor<1x49x36xf16>) -> tensor<1x49x36xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
index 455381a..0316e8f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
@@ -6,33 +6,39 @@
   func.func @im2col_untile_k(%arg0: tensor<2x34x34x640xf32>, %m_size: index, %m_off: index, %k: index) -> tensor<2x?x4xf32> {
     %0 = tensor.empty(%m_size) : tensor<2x?x4xf32>
     %k_off = affine.apply #map(%k)
-    %7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%arg0 : tensor<2x34x34x640xf32>) outs(%0 : tensor<2x?x4xf32>) -> tensor<2x?x4xf32>
+    %7 = iree_linalg_ext.im2col
+            strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+            m_offset = [%m_off] * [1] k_offset = [%k_off] * [1]
+            batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+            ins(%arg0 : tensor<2x34x34x640xf32>)
+            outs(%0 : tensor<2x?x4xf32>) -> tensor<2x?x4xf32>
     return %7 : tensor<2x?x4xf32>
   }
 }
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
-//      CHECK: func.func @im2col_untile_k(%[[ARG0:.+]]: tensor<2x34x34x640xf32>
-// CHECK-SAME:   %[[mSIZE:.+]]: index, %[[mOFF:.+]]: index, %[[K:.+]]: index)
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//      CHECK:   %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]]) : tensor<2x?x4xf32>
-//      CHECK:   %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x4xf32>) {
-//      CHECK:     %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x4xf32>) {
-//  CHECK-DAG:       %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-//  CHECK-DAG:       %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[m]])[%[[mOFF]], %[[K]]]
-//  CHECK-DAG:       %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[m]])[%[[mOFF]], %[[K]]]
-//      CHECK:       %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[b]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-//      CHECK:       %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x?x4xf32> to tensor<1x1x4xf32>
-//      CHECK:       %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x?x4xf32>
-//      CHECK:       scf.yield %[[INSERT]] : tensor<2x?x4xf32>
-//      CHECK:     }
-//      CHECK:     scf.yield %[[mLOOP]] : tensor<2x?x4xf32>
-//      CHECK:   }
-//      CHECK:   return %[[bLOOP]] : tensor<2x?x4xf32>
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
+//   CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
+// CHECK-LABEL: func.func @im2col_untile_k
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mSIZE:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mOFF:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[K:[a-zA-Z0-9_]+]]
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:   %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]]) : tensor<2x?x4xf32>
+//       CHECK:   %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x4xf32>)
+//       CHECK:     %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x4xf32>)
+//   CHECK-DAG:       %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+//   CHECK-DAG:       %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[m]])[%[[mOFF]], %[[K]]]
+//   CHECK-DAG:       %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[m]])[%[[mOFF]], %[[K]]]
+//       CHECK:       %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[b]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+//       CHECK:       %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x?x4xf32> to tensor<1x1x4xf32>
+//       CHECK:       %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+//       CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x?x4xf32>
+//       CHECK:       scf.yield %[[INSERT]] : tensor<2x?x4xf32>
+//       CHECK:     scf.yield %[[mLOOP]] : tensor<2x?x4xf32>
+//       CHECK:   return %[[bLOOP]] : tensor<2x?x4xf32>
 
 // -----
 
@@ -42,36 +48,106 @@
     %c1 = arith.constant 1 : index
     %c0 = arith.constant 0 : index
     %0 = tensor.empty(%m_size, %k_size) : tensor<2x?x?xf32>
-    %8 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [1] m_pos = [3, 2] k_pos = [0] ins(%arg0 : tensor<640x2x101x172xf32>) outs(%0 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+    %8 = iree_linalg_ext.im2col
+            strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
+            m_offset = [%m_off] * [1] k_offset = [%k_off] * [1]
+            batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+            ins(%arg0 : tensor<640x2x101x172xf32>)
+            outs(%0 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
     return %8 : tensor<2x?x?xf32>
   }
 }
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> ((d0 + s0) floordiv 10)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (((d0 + s0) floordiv 32) * 5 + (((d1 + s1) mod 10) floordiv 5) * 4)>
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 3 + d1 * 7 + s0 * 3 + s1 * 7 - ((d0 + s0) floordiv 32) * 96 - ((d1 + s1) floordiv 5) * 35)>
-//      CHECK: func.func @im2col_transposed_m_pos(%[[ARG0:.+]]: tensor<640x2x101x172xf32>
-// CHECK-SAME:   %[[mSIZE:.+]]: index, %[[kSIZE:.+]]: index, %[[mOFF:.+]]: index, %[[kOFF:.+]]: index)
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//      CHECK:   %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]], %[[kSIZE]]) : tensor<2x?x?xf32>
-//      CHECK:   %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x?xf32>) {
-//      CHECK:     %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x?xf32>) {
-//      CHECK:       %[[kLOOP:.+]] = scf.for %[[k:.+]] = %[[C0]] to %[[kSIZE]] step %[[C1]] iter_args(%[[OUT2:.+]] = %[[OUT1]]) -> (tensor<2x?x?xf32>) {
-//  CHECK-DAG:         %[[kIDX:.+]] = affine.apply #[[MAP]](%[[k]])[%[[kOFF]]]
-//  CHECK-DAG:         %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
-//  CHECK-DAG:         %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
-//      CHECK:         %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[kIDX]], %[[b]], %[[wIDX]], %[[hIDX]]] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<640x2x101x172xf32> to tensor<1x1x1xf32>
-//      CHECK:         %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<2x?x?xf32> to tensor<1x1x1xf32>
-//      CHECK:         %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x1xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
-//      CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<2x?x?xf32>
-//      CHECK:         scf.yield %[[INSERT]] : tensor<2x?x?xf32>
-//      CHECK:       }
-//      CHECK:       scf.yield %[[kLOOP]] : tensor<2x?x?xf32>
-//      CHECK:     }
-//      CHECK:     scf.yield %[[mLOOP]] : tensor<2x?x?xf32>
-//      CHECK:   }
-//      CHECK:   return %[[bLOOP]] : tensor<2x?x?xf32>
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0)[s0] -> ((d0 + s0) floordiv 10)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (((d0 + s0) floordiv 32) * 5 + (((d1 + s1) mod 10) floordiv 5) * 4)>
+//   CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 3 + d1 * 7 + s0 * 3 + s1 * 7 - ((d0 + s0) floordiv 32) * 96 - ((d1 + s1) floordiv 5) * 35)>
+// CHECK-LABEL: func.func @im2col_transposed_m_pos
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mSIZE:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[kSIZE:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mOFF:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[kOFF:[a-zA-Z0-9_]+]]
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:   %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]], %[[kSIZE]]) : tensor<2x?x?xf32>
+//       CHECK:   %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x?xf32>)
+//       CHECK:     %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x?xf32>)
+//       CHECK:       %[[kLOOP:.+]] = scf.for %[[k:.+]] = %[[C0]] to %[[kSIZE]] step %[[C1]] iter_args(%[[OUT2:.+]] = %[[OUT1]]) -> (tensor<2x?x?xf32>)
+//   CHECK-DAG:         %[[kIDX:.+]] = affine.apply #[[$MAP]](%[[k]])[%[[kOFF]]]
+//   CHECK-DAG:         %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
+//   CHECK-DAG:         %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
+//       CHECK:         %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[kIDX]], %[[b]], %[[wIDX]], %[[hIDX]]] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<640x2x101x172xf32> to tensor<1x1x1xf32>
+//       CHECK:         %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<2x?x?xf32> to tensor<1x1x1xf32>
+//       CHECK:         %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x1xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+//       CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<2x?x?xf32>
+//       CHECK:         scf.yield %[[INSERT]] : tensor<2x?x?xf32>
+//       CHECK:       scf.yield %[[kLOOP]] : tensor<2x?x?xf32>
+//       CHECK:     scf.yield %[[mLOOP]] : tensor<2x?x?xf32>
+//       CHECK:   return %[[bLOOP]] : tensor<2x?x?xf32>
+
+// -----
+
+module {
+  func.func @im2col_expanded(%arg0: tensor<2x34x34x640xf32>, %m_size0: index, %m_size1: index, %m0: index, %m1: index, %k: index, %m_stride: index) -> tensor<2x?x?x2x4xf32> {
+    %0 = tensor.empty(%m_size0, %m_size1) : tensor<2x?x?x2x4xf32>
+    %7 = iree_linalg_ext.im2col
+            strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+            m_offset = [%m0, %m1] * [%m_stride, 1] k_offset = [%k, 0] * [4, 1]
+            batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+            ins(%arg0 : tensor<2x34x34x640xf32>)
+            outs(%0 : tensor<2x?x?x2x4xf32>) -> tensor<2x?x?x2x4xf32>
+    return %7 : tensor<2x?x?x2x4xf32>
+  }
+}
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0)[s0] -> (d0 * 4 + s0 * 4 - ((d0 + s0) floordiv 160) * 640)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> ((d1 + s2 + d0 * s0 + s1 * s0) floordiv 32 + (d2 + s3) floordiv 480)>
+//   CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s0 + d1 + s1 * s0 + s2 - ((d1 + s2 + d0 * s0 + s1 * s0) floordiv 32) * 32 + (d2 + s3) floordiv 160 - ((d2 + s3) floordiv 480) * 3)>
+// CHECK-LABEL: func.func @im2col_expanded
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mSIZE0:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mSIZE1:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mOFF0:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mOFF1:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[kOFF:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     %[[mSTRIDE:[a-zA-Z0-9_]+]]
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:   %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE0]], %[[mSIZE1]]) : tensor<2x?x?x2x4xf32>
+//       CHECK:   %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x?x2x4xf32>)
+//       CHECK:     %[[mLOOP0:.+]] = scf.for %[[m0:.+]] = %[[C0]] to %[[mSIZE0]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x?x2x4xf32>)
+//       CHECK:       %[[mLOOP1:.+]] = scf.for %[[m1:.+]] = %[[C0]] to %[[mSIZE1]] step %[[C1]] iter_args(%[[OUT2:.+]] = %[[OUT1]]) -> (tensor<2x?x?x2x4xf32>)
+//       CHECK:         %[[kLOOP:.+]] = scf.for %[[k:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT3:.+]] = %[[OUT2]]) -> (tensor<2x?x?x2x4xf32>)
+//   CHECK-DAG:           %[[kIDX:.+]] = affine.apply #[[$MAP]](%[[k]])[%[[kOFF]]]
+//   CHECK-DAG:           %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[m0]], %[[m1]], %[[k]])[%[[mSTRIDE]], %[[mOFF0]], %[[mOFF1]], %[[kOFF]]]
+//   CHECK-DAG:           %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[m0]], %[[m1]], %[[k]])[%[[mSTRIDE]], %[[mOFF0]], %[[mOFF1]], %[[kOFF]]]
+//       CHECK:           %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[b]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x1x4xf32>
+//       CHECK:           %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT3]][%[[b]], %[[m0]], %[[m1]], %[[k]], 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] : tensor<2x?x?x2x4xf32> to tensor<1x1x1x4xf32>
+//       CHECK:           %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x1x4xf32>) -> tensor<1x1x1x4xf32>
+//       CHECK:           %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT3]][%[[b]], %[[m0]], %[[m1]], %[[k]], 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] : tensor<1x1x1x4xf32> into tensor<2x?x?x2x4xf32>
+//       CHECK:           scf.yield %[[INSERT]] : tensor<2x?x?x2x4xf32>
+//       CHECK:         scf.yield %[[kLOOP]] : tensor<2x?x?x2x4xf32>
+//       CHECK:       scf.yield %[[mLOOP1]] : tensor<2x?x?x2x4xf32>
+//       CHECK:     scf.yield %[[mLOOP0]] : tensor<2x?x?x2x4xf32>
+//       CHECK:   return %[[bLOOP]] : tensor<2x?x?x2x4xf32>
+
+// -----
+
+module {
+  func.func @im2col_expanded_nchw(%arg0: tensor<2x640x34x34xf32>, %m0: index, %m1: index, %k: index) -> tensor<2x1x1x2x4xf32> {
+    %0 = tensor.empty() : tensor<2x1x1x2x4xf32>
+    %7 = iree_linalg_ext.im2col
+            strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+            m_offset = [%m0, %m1] * [32, 1] k_offset = [%k, 0] * [4, 1]
+            batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+            ins(%arg0 : tensor<2x640x34x34xf32>)
+            outs(%0 : tensor<2x1x1x2x4xf32>) -> tensor<2x1x1x2x4xf32>
+    return %7 : tensor<2x1x1x2x4xf32>
+  }
+}
+// Verify that the NCHW layout does not vectorize.
+// CHECK-LABEL: func.func @im2col_expanded_nchw
+//       CHECK:   linalg.copy ins({{.*}} : tensor<1x1x1x1xf32>) outs({{.*}} : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
 
 // -----
 
@@ -80,57 +156,64 @@
   func.func @im2col_unrolled(%arg0: tensor<2x34x34x640xf32>, %m_off: index, %k: index) -> tensor<2x2x4xf32> {
     %0 = tensor.empty() : tensor<2x2x4xf32>
     %k_off = affine.apply #map(%k)
-    %7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%arg0 : tensor<2x34x34x640xf32>) outs(%0 : tensor<2x2x4xf32>) -> tensor<2x2x4xf32>
+    %7 = iree_linalg_ext.im2col
+            strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+            m_offset = [%m_off] * [1] k_offset = [%k_off] * [1]
+            batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+            ins(%arg0 : tensor<2x34x34x640xf32>)
+            outs(%0 : tensor<2x2x4xf32>) -> tensor<2x2x4xf32>
     return %7 : tensor<2x2x4xf32>
   }
 }
-//  CHECK-UNROLL-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
-//  CHECK-UNROLL-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
-//  CHECK-UNROLL-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
-//      CHECK-UNROLL: func.func @im2col_unrolled(%[[ARG0:.+]]: tensor<2x34x34x640xf32>
-// CHECK-UNROLL-SAME:   %[[mOFF:.+]]: index, %[[K:.+]]: index)
-//  CHECK-UNROLL-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-UNROLL-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK-UNROLL:   %[[OUT_TILE:.+]] = tensor.empty() : tensor<2x2x4xf32>
+//   CHECK-UNROLL-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
+//   CHECK-UNROLL-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
+//   CHECK-UNROLL-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
+// CHECK-UNROLL-LABEL: func.func @im2col_unrolled
+//  CHECK-UNROLL-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]
+//  CHECK-UNROLL-SAME:     %[[mOFF:[a-zA-Z0-9_]+]]
+//  CHECK-UNROLL-SAME:     %[[K:[a-zA-Z0-9_]+]]
+//   CHECK-UNROLL-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-UNROLL-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK-UNROLL:   %[[OUT_TILE:.+]] = tensor.empty() : tensor<2x2x4xf32>
 
 //  First iteration
 //
-//  CHECK-UNROLL-DAG:   %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-//  CHECK-UNROLL-DAG:   %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
-//  CHECK-UNROLL-DAG:   %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
-//      CHECK-UNROLL:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[INSERT0:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
+//   CHECK-UNROLL-DAG:   %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+//   CHECK-UNROLL-DAG:   %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
+//   CHECK-UNROLL-DAG:   %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
+//       CHECK-UNROLL:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[INSERT0:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
 
 //  Second iteration
 //
-//  CHECK-UNROLL-DAG:   %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-//  CHECK-UNROLL-DAG:   %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
-//  CHECK-UNROLL-DAG:   %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
-//      CHECK-UNROLL:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[INSERT1:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
+//   CHECK-UNROLL-DAG:   %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+//   CHECK-UNROLL-DAG:   %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
+//   CHECK-UNROLL-DAG:   %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
+//       CHECK-UNROLL:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[INSERT1:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
 
 //  Third iteration
 //
-//  CHECK-UNROLL-DAG:   %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-//  CHECK-UNROLL-DAG:   %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
-//  CHECK-UNROLL-DAG:   %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
-//      CHECK-UNROLL:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[INSERT2:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
+//   CHECK-UNROLL-DAG:   %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+//   CHECK-UNROLL-DAG:   %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
+//   CHECK-UNROLL-DAG:   %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
+//       CHECK-UNROLL:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[INSERT2:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
 
 //  Fourth iteration
 //
-//  CHECK-UNROLL-DAG:   %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-//  CHECK-UNROLL-DAG:   %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
-//  CHECK-UNROLL-DAG:   %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
-//      CHECK-UNROLL:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-//      CHECK-UNROLL:   %[[INSERT3:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
+//   CHECK-UNROLL-DAG:   %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+//   CHECK-UNROLL-DAG:   %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
+//   CHECK-UNROLL-DAG:   %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
+//       CHECK-UNROLL:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+//       CHECK-UNROLL:   %[[INSERT3:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
 
-//      CHECK-UNROLL:   return %[[INSERT3]] : tensor<2x2x4xf32>
+//       CHECK-UNROLL:   return %[[INSERT3]] : tensor<2x2x4xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 983ec33..f42a2fa 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -745,7 +745,8 @@
 func.func @im2col(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-           m_offset = [34] k_offset = [1000] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [34] * [1] k_offset = [1000] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<2x34x34x640xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
@@ -770,11 +771,11 @@
 // CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
 // CHECK:        %[[RES0:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK-SAME:       iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME:       iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x1024x5760xf32>)
 // CHECK:          %[[RES1:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C5]]
-// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x1024x5760xf32>)
 // CHECK:            %[[RES2:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C5760]] step %[[C4]]
-// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x1024x5760xf32>)
 // CHECK-DAG:          %[[MSIZE:.+]] = affine.min #[[MAP]](%[[ARG3]])
 // CHECK-DAG:          %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, 0]
 // CHECK-SAME:           [1, 34, 34, 640] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x34x34x640xf32>
@@ -783,18 +784,16 @@
 // CHECK-DAG:          %[[KOFFSET:.+]] = affine.apply #[[MAP1]](%[[ARG5]])
 // CHECK-DAG:          %[[MOFFSET:.+]] = affine.apply #[[MAP2]](%[[ARG3]])
 // CHECK:              %[[IM2COL:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:           m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME:           m_offset = [%[[MOFFSET]]] * [1] k_offset = [%[[KOFFSET]]] * [1]
+// CHECK-SAME:           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
 // CHECK-SAME:           ins(%[[EXTRACTED_SLICE]] : tensor<1x34x34x640xf32>)
 // CHECK-SAME:           outs(%[[EXTRACTED_SLICE_0]] : tensor<1x?x4xf32>) -> tensor<1x?x4xf32>
 // CHECK:              %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[IM2COL]] into %[[ARG6]]
 // CHECK-SAME:           [%[[ARG1]], %[[ARG3]], %[[ARG5]]] [1, %[[MSIZE]], 4] [1, 1, 1]
 // CHECK-SAME:           tensor<1x?x4xf32> into tensor<2x1024x5760xf32>
 // CHECK:              scf.yield %[[INSERTED_SLICE]] : tensor<2x1024x5760xf32>
-// CHECK:            }
 // CHECK:            scf.yield %[[RES2]] : tensor<2x1024x5760xf32>
-// CHECK:          }
 // CHECK:          scf.yield %[[RES1]] : tensor<2x1024x5760xf32>
-// CHECK:        }
 // CHECK:        return %[[RES0]] : tensor<2x1024x5760xf32>
 
 // -----
@@ -802,7 +801,8 @@
 func.func @im2col_transposed_m_pos(%arg0: tensor<640x2x101x172xf32>) -> tensor<2x1024x5760xf32> {
   %0 = tensor.empty() : tensor<2x1024x5760xf32>
   %1 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-           m_offset = [42] k_offset = [7] batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+           m_offset = [42] * [1] k_offset = [7] * [1]
+           batch_pos = [1] m_pos = [3, 2] k_pos = [0]
            ins(%arg0 : tensor<640x2x101x172xf32>)
            outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
   return %1 : tensor<2x1024x5760xf32>
@@ -828,11 +828,11 @@
 // CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
 // CHECK:        %[[RES0:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK-SAME:       iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME:       iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x1024x5760xf32>)
 // CHECK:          %[[RES1:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C9]]
-// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x1024x5760xf32>)
 // CHECK:            %[[RES2:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C5760]] step %[[C7]]
-// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x1024x5760xf32>)
 // CHECK-DAG:          %[[MSIZE:.+]] = affine.min #[[MAP]](%[[ARG3]])
 // CHECK-DAG:          %[[KSIZE:.+]] = affine.min #[[MAP1]](%[[ARG5]])
 // CHECK-DAG:          %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[ARG1]], 0, 0]
@@ -842,18 +842,16 @@
 // CHECK-DAG:          %[[KOFFSET:.+]] = affine.apply #[[MAP2]](%[[ARG5]])
 // CHECK-DAG:          %[[MOFFSET:.+]] = affine.apply #[[MAP3]](%[[ARG3]])
 // CHECK:              %[[IM2COL:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-// CHECK-SAME:           m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+// CHECK-SAME:           m_offset = [%[[MOFFSET]]] * [1] k_offset = [%[[KOFFSET]]] * [1]
+// CHECK-SAME:           batch_pos = [1] m_pos = [3, 2] k_pos = [0]
 // CHECK-SAME:           ins(%[[EXTRACTED_SLICE]] : tensor<640x1x101x172xf32>)
 // CHECK-SAME:           outs(%[[EXTRACTED_SLICE_0]] : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
 // CHECK:              %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[IM2COL]] into %[[ARG6]]
 // CHECK-SAME:           [%[[ARG1]], %[[ARG3]], %[[ARG5]]] [1, %[[MSIZE]], %[[KSIZE]]] [1, 1, 1]
 // CHECK-SAME:           tensor<1x?x?xf32> into tensor<2x1024x5760xf32>
 // CHECK:              scf.yield %[[INSERTED_SLICE]] : tensor<2x1024x5760xf32>
-// CHECK:            }
 // CHECK:            scf.yield %[[RES2]] : tensor<2x1024x5760xf32>
-// CHECK:          }
 // CHECK:          scf.yield %[[RES1]] : tensor<2x1024x5760xf32>
-// CHECK:        }
 // CHECK:        return %[[RES0]] : tensor<2x1024x5760xf32>
 
 // -----
@@ -862,7 +860,8 @@
                           %mOffset: index, %kOffset: index) -> tensor<?x?x?xf32> {
   %0 = tensor.empty(%s0, %s1, %s2) : tensor<?x?x?xf32>
   %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-           m_offset = [%mOffset] k_offset = [%kOffset] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+           m_offset = [%mOffset] * [1] k_offset = [%kOffset] * [1]
+           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
            ins(%arg0 : tensor<?x?x?x?xf32>)
            outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
@@ -888,11 +887,11 @@
 // CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // CHECK:        %[[D0:.+]] = tensor.empty(%[[S0]], %[[S1]], %[[S2]]) : tensor<?x?x?xf32>
 // CHECK:        %[[RES0:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[S0]] step %[[C2]]
-// CHECK-SAME:       iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<?x?x?xf32>) {
+// CHECK-SAME:       iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<?x?x?xf32>)
 // CHECK:          %[[RES1:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[S1]] step %[[C7]]
-// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<?x?x?xf32>) {
+// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<?x?x?xf32>)
 // CHECK:            %[[RES2:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[S2]] step %[[C5]]
-// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<?x?x?xf32>) {
+// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<?x?x?xf32>)
 // CHECK-DAG:          %[[BSIZE:.+]] = affine.min #[[MAP]](%[[ARG1]])
 // CHECK-DAG:          %[[MSIZE:.+]] = affine.min #[[MAP1]](%[[ARG3]])
 // CHECK-DAG:          %[[KSIZE:.+]] = affine.min #[[MAP2]](%[[ARG5]])
@@ -906,22 +905,89 @@
 // CHECK-DAG:          %[[KOFFSET:.+]] = affine.apply #[[MAP3]](%[[ARG5]])[%[[KOFF]]]
 // CHECK-DAG:          %[[MOFFSET:.+]] = affine.apply #[[MAP3]](%[[ARG3]])[%[[MOFF]]]
 // CHECK:              %[[IM2COL:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME:           m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME:           m_offset = [%[[MOFFSET]]] * [1] k_offset = [%[[KOFFSET]]] * [1]
+// CHECK-SAME:           batch_pos = [0] m_pos = [1, 2] k_pos = [3]
 // CHECK-SAME:           ins(%[[EXTRACTED_SLICE]] : tensor<?x?x?x?xf32>)
 // CHECK-SAME:           outs(%[[EXTRACTED_SLICE_0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK:              %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[IM2COL]] into %[[ARG6]]
 // CHECK-SAME:           [%[[ARG1]], %[[ARG3]], %[[ARG5]]] [%[[BSIZE]], %[[MSIZE]], %[[KSIZE]]] [1, 1, 1]
 // CHECK-SAME:           tensor<?x?x?xf32> into tensor<?x?x?xf32>
 // CHECK:              scf.yield %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
-// CHECK:            }
 // CHECK:            scf.yield %[[RES2]] : tensor<?x?x?xf32>
-// CHECK:          }
 // CHECK:          scf.yield %[[RES1]] : tensor<?x?x?xf32>
-// CHECK:        }
 // CHECK:        return %[[RES0]] : tensor<?x?x?xf32>
 
 // -----
 
+module {
+  func.func @im2col_expanded(%arg0: tensor<2x34x34x640xf32>, %m_stride: index) -> tensor<2x32x32x1440x4xf32> {
+    %0 = tensor.empty() : tensor<2x32x32x1440x4xf32>
+    %7 = iree_linalg_ext.im2col
+            strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+            m_offset = [0, 0] * [%m_stride, 1] k_offset = [0, 0] * [4, 1]
+            batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+            ins(%arg0 : tensor<2x34x34x640xf32>)
+            outs(%0 : tensor<2x32x32x1440x4xf32>) -> tensor<2x32x32x1440x4xf32>
+    return %7 : tensor<2x32x32x1440x4xf32>
+  }
+}
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.im2col"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    %1, %loops:5 = transform.structured.tile_using_for %0 tile_sizes [1, 7, 5, 11, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 32, 7)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 32, 5)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0) -> (-d0 + 1440, 11)>
+// CHECK:      func.func @im2col_expanded
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x34x34x640xf32>
+// CHECK-SAME:   %[[M_STRIDE:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG:    %[[C11:.+]] = arith.constant 11 : index
+// CHECK-DAG:    %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG:    %[[C7:.+]] = arith.constant 7 : index
+// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG:    %[[C1440:.+]] = arith.constant 1440 : index
+// CHECK-DAG:    %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG:    %[[C2:.+]] = arith.constant 2 : index
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x32x32x1440x4xf32>
+// CHECK:        %[[RES0:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME:       iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK:          %[[RES1:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C7]]
+// CHECK-SAME:       iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK:            %[[RES2:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C5]]
+// CHECK-SAME:         iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK:              %[[RES3:.+]] = scf.for %[[ARG7:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1440]] step %[[C11]]
+// CHECK-SAME:           iter_args(%[[ARG8:[a-zA-Z0-9_]+]] = %[[ARG6]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK:                %[[RES4:.+]] = scf.for %[[ARG9:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C2]]
+// CHECK-SAME:             iter_args(%[[ARG10:[a-zA-Z0-9_]+]] = %[[ARG8]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK-DAG:              %[[M0SIZE:.+]] = affine.min #[[MAP]](%[[ARG3]])
+// CHECK-DAG:              %[[M1SIZE:.+]] = affine.min #[[MAP1]](%[[ARG5]])
+// CHECK-DAG:              %[[K0SIZE:.+]] = affine.min #[[MAP2]](%[[ARG7]])
+// CHECK-DAG:              %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, 0]
+// CHECK-SAME:               [1, 34, 34, 640] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x34x34x640xf32>
+// CHECK-DAG:              %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG10]][%[[ARG1]], %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]]
+// CHECK-SAME:               [1, %[[M0SIZE]], %[[M1SIZE]], %[[K0SIZE]], 2] [1, 1, 1, 1, 1] : tensor<2x32x32x1440x4xf32> to tensor<1x?x?x?x2xf32>
+// CHECK:                  %[[IM2COL:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME:               m_offset = [%[[ARG3]], %[[ARG5]]] * [%[[M_STRIDE]], 1] k_offset = [%[[ARG7]], %[[ARG9]]] * [4, 1]
+// CHECK-SAME:               batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME:               ins(%[[EXTRACTED_SLICE]] : tensor<1x34x34x640xf32>)
+// CHECK-SAME:               outs(%[[EXTRACTED_SLICE_0]] : tensor<1x?x?x?x2xf32>) -> tensor<1x?x?x?x2xf32>
+// CHECK:                  %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[IM2COL]] into %[[ARG10]]
+// CHECK-SAME:               [%[[ARG1]], %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [1, %[[M0SIZE]], %[[M1SIZE]], %[[K0SIZE]], 2] [1, 1, 1, 1, 1]
+// CHECK-SAME:               tensor<1x?x?x?x2xf32> into tensor<2x32x32x1440x4xf32>
+// CHECK:                  scf.yield %[[INSERTED_SLICE]] : tensor<2x32x32x1440x4xf32>
+// CHECK:                scf.yield %[[RES4]] : tensor<2x32x32x1440x4xf32>
+// CHECK:              scf.yield %[[RES3]] : tensor<2x32x32x1440x4xf32>
+// CHECK:            scf.yield %[[RES2]] : tensor<2x32x32x1440x4xf32>
+// CHECK:          scf.yield %[[RES1]] : tensor<2x32x32x1440x4xf32>
+// CHECK:        return %[[RES0]] : tensor<2x32x32x1440x4xf32>
+
+// -----
+
 func.func @winograd_filter_transform(%arg0: tensor<3x3x64x128xf32>) -> tensor<8x8x64x128xf32> {
   %0 = tensor.empty() : tensor<8x8x64x128xf32>
   %1 = iree_linalg_ext.winograd.filter_transform
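
The `@im2col_expanded` case above exercises the multi-dimensional form in which `m_offset`/`k_offset` are paired with a stride basis via the `[offsets] * [strides]` syntax; the checks show that tiling only advances the offsets while the basis is carried through unchanged. As an illustration of the arithmetic this notation encodes (a dot product of offsets and strides), here is a small standalone C++ sketch; `linearizeOffsets` is a hypothetical name used only for this example and is not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: the flat offset implied by `[offsets] * [strides]` in
// the im2col assembly format is the dot product of the two lists.
static int64_t linearizeOffsets(const std::vector<int64_t> &offsets,
                                const std::vector<int64_t> &strides) {
  assert(offsets.size() == strides.size() && "offsets/strides must pair up");
  int64_t linear = 0;
  for (size_t i = 0; i < offsets.size(); ++i)
    linear += offsets[i] * strides[i];
  return linear;
}

int main() {
  // K basis from @im2col_expanded: k_offset = [k0, k1] * [4, 1].
  // The tile starting at k0 = 11, k1 = 2 corresponds to flat offset 46.
  assert(linearizeOffsets({11, 2}, {4, 1}) == 11 * 4 + 2);
  return 0;
}
```
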
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
index 0da302b..0212d6e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
@@ -25,6 +25,7 @@
     ],
     deps = [
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
index b4c519c..564ba5b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
@@ -22,6 +22,7 @@
     "Utils.cpp"
   DEPS
     LLVMSupport
+    MLIRAffineDialect
     MLIRArithDialect
     MLIRIR
     MLIRLinalgDialect
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index ed9524d..1aa2efa 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -16,6 +17,22 @@
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
+OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+                     OpFoldResult b) {
+  AffineExpr d0, d1;
+  bindDims(builder.getContext(), d0, d1);
+  auto addMap = AffineMap::get(2, 0, {d0 + d1});
+  return affine::makeComposedFoldedAffineApply(builder, loc, addMap, {a, b});
+}
+
+OpFoldResult mulOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+                     OpFoldResult b) {
+  AffineExpr d0, d1;
+  bindDims(builder.getContext(), d0, d1);
+  auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+  return affine::makeComposedFoldedAffineApply(builder, loc, mulMap, {a, b});
+}
+
 Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) {
   ShapedType type = cast<ShapedType>(v.getType());
   if (!type.isDynamicDim(dim)) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index 4d1b986..a89cb67 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -20,6 +20,14 @@
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
+// Helper method to add two OpFoldResults using affine.apply.
+OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+                     OpFoldResult b);
+
+// Helper method to multiply two OpFoldResults using affine.apply.
+OpFoldResult mulOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+                     OpFoldResult b);
+
 /// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
 /// `dim`.
 Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim);
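
For completeness, here is a minimal builder-level sketch of how the `addOfrs`/`mulOfrs` helpers declared above can be chained to fold such an offset/stride pair into a single `OpFoldResult`. The `linearizeOfrs` wrapper and its include set are assumptions made only for this illustration; they are not part of the patch.

```cpp
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"

#include <cassert>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Builders.h"

namespace mlir::iree_compiler::IREE::LinalgExt {

// Illustrative only: fold sum_i(offsets[i] * strides[i]) into one
// OpFoldResult. Because the helpers go through
// affine::makeComposedFoldedAffineApply, fully static inputs fold to an
// attribute while dynamic inputs produce composed affine.apply ops.
static OpFoldResult linearizeOfrs(OpBuilder &builder, Location loc,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> strides) {
  assert(offsets.size() == strides.size() && "expected matching basis size");
  OpFoldResult linear = builder.getIndexAttr(0);
  for (auto [offset, stride] : llvm::zip_equal(offsets, strides))
    linear = addOfrs(builder, loc, linear, mulOfrs(builder, loc, offset, stride));
  return linear;
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
```

This mirrors the `[offsets] * [strides]` pairing in the op's assembly format; relying on affine folding means constant-only bases collapse to a plain attribute without materializing any extra IR.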