[Im2col] Allow multiple batch, M, and K dimensions on im2col result (#18593)
This PR adds support for multiple M and K dimensions in the result of
the im2col op. New metadata is added for correctly tracking the offsets
into the M and K dimensions along the multiple dimensions. New `m_strides`
and `k_strides` fields are added to the op, which represent a basis for
linearizing the `m_offset` and `k_offset` fields.
The motivation for doing this is that flattening the M dimension can
create an expand_shape op consumer of the resulting matmul. This can
cause issues with fusion and distribution, so it is useful to be able to
keep the multiple M dimensions intact. This PR does not change any
behavior of Conv2DToIm2col pass, which will be done in a later PR.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 49d6892..3cec8a9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -389,7 +389,8 @@
} {
%4 = iree_linalg_ext.im2col {lowering_config = #config}
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [2, 3] k_pos = [1]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [2, 3] k_pos = [1]
ins(%2 : tensor<2x34x34x128xf16>)
outs(%3 : tensor<2x128x8xf16>) -> tensor<2x128x8xf16>
return %4 : tensor<2x128x8xf16>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 23bddd8..fe7d739 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -162,7 +162,7 @@
%6 = tensor.empty() : tensor<2x256x11520xf16>
%7 = iree_linalg_ext.im2col
strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0]
+ m_offset = [0] * [1] k_offset = [0] * [1]
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%3 : tensor<2x34x34x1280xf16>)
outs(%6 : tensor<2x256x11520xf16>) -> tensor<2x256x11520xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index ad818b2..dafb17f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1513,12 +1513,24 @@
getKOffset());
}
-/// Return all static and dynamic k_offset as OpFoldResults.
+/// Return all static and dynamic m_offset as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedMOffset() {
return LinalgExt::getMixedValues(getContext(), getStaticMOffset(),
getMOffset());
}
+/// Return all static and dynamic k_strides as OpFoldResults.
+SmallVector<OpFoldResult> Im2colOp::getMixedKStrides() {
+ return LinalgExt::getMixedValues(getContext(), getStaticKStrides(),
+ getKStrides());
+}
+
+/// Return all static and dynamic m_strides as OpFoldResults.
+SmallVector<OpFoldResult> Im2colOp::getMixedMStrides() {
+ return LinalgExt::getMixedValues(getContext(), getStaticMStrides(),
+ getMStrides());
+}
+
void Im2colOp::setMixedKOffset(SmallVector<OpFoldResult> kOffset) {
SmallVector<int64_t> staticKOffset;
SmallVector<Value> dynamicKOffset;
@@ -1535,23 +1547,59 @@
getMOffsetMutable().assign(dynamicMOffset);
}
+void Im2colOp::setMixedKStrides(SmallVector<OpFoldResult> kStrides) {
+ SmallVector<int64_t> staticKStrides;
+ SmallVector<Value> dynamicKStrides;
+ dispatchIndexOpFoldResults(kStrides, dynamicKStrides, staticKStrides);
+ setStaticKStrides(staticKStrides);
+ getKStridesMutable().assign(dynamicKStrides);
+}
+
+void Im2colOp::setMixedMStrides(SmallVector<OpFoldResult> mStrides) {
+ SmallVector<int64_t> staticMStrides;
+ SmallVector<Value> dynamicMStrides;
+ dispatchIndexOpFoldResults(mStrides, dynamicMStrides, staticMStrides);
+ setStaticMStrides(staticMStrides);
+ getMStridesMutable().assign(dynamicMStrides);
+}
+
+SmallVector<int64_t> Im2colOp::getBatchOutputDims() {
+ return llvm::to_vector(llvm::seq<int64_t>(0, getBatchPos().size()));
+}
+
+SmallVector<int64_t> Im2colOp::getMOutputDims() {
+ int64_t begin = getBatchPos().size();
+ int64_t end = begin + getMixedMOffset().size();
+ return llvm::to_vector(llvm::seq<int64_t>(begin, end));
+}
+
+SmallVector<int64_t> Im2colOp::getKOutputDims() {
+ int64_t begin = getBatchPos().size() + getMixedMOffset().size();
+ int64_t end = begin + getMixedKOffset().size();
+ return llvm::to_vector(llvm::seq<int64_t>(begin, end));
+}
+
/// Custom builder methods for im2col op.
-void Im2colOp::build(OpBuilder &builder, OperationState &state, Value input,
- Value output, ArrayRef<int64_t> strides,
- ArrayRef<int64_t> dilations,
- ArrayRef<OpFoldResult> kernelSize,
- ArrayRef<OpFoldResult> kOffset,
- ArrayRef<OpFoldResult> mOffset, ArrayRef<int64_t> batchPos,
- ArrayRef<int64_t> mPos, ArrayRef<int64_t> kPos) {
+void Im2colOp::build(
+ OpBuilder &builder, OperationState &state, Value input, Value output,
+ ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations,
+ ArrayRef<OpFoldResult> kernelSize, ArrayRef<OpFoldResult> mOffset,
+ ArrayRef<OpFoldResult> mStrides, ArrayRef<OpFoldResult> kOffset,
+ ArrayRef<OpFoldResult> kStrides, ArrayRef<int64_t> batchPos,
+ ArrayRef<int64_t> mPos, ArrayRef<int64_t> kPos) {
assert(strides.size() == kernelSize.size() &&
dilations.size() == kernelSize.size() &&
mPos.size() == kernelSize.size() &&
"strides, dilations, m_pos, and kernel expected to be the same rank");
- SmallVector<int64_t> staticKernelSize, staticMOffset, staticKOffset;
- SmallVector<Value> dynamicKernelSize, dynamicMOffset, dynamicKOffset;
+ SmallVector<int64_t> staticKernelSize, staticMOffset, staticKOffset,
+ staticMStrides, staticKStrides;
+ SmallVector<Value> dynamicKernelSize, dynamicMOffset, dynamicKOffset,
+ dynamicMStrides, dynamicKStrides;
dispatchIndexOpFoldResults(kernelSize, dynamicKernelSize, staticKernelSize);
dispatchIndexOpFoldResults(mOffset, dynamicMOffset, staticMOffset);
+ dispatchIndexOpFoldResults(mStrides, dynamicMStrides, staticMStrides);
dispatchIndexOpFoldResults(kOffset, dynamicKOffset, staticKOffset);
+ dispatchIndexOpFoldResults(kStrides, dynamicKStrides, staticKStrides);
SmallVector<Type> resultType;
auto outputType = output.getType();
if (isa<RankedTensorType>(outputType)) {
@@ -1560,9 +1608,11 @@
build(builder, state, resultType, input, output,
builder.getDenseI64ArrayAttr(strides),
builder.getDenseI64ArrayAttr(dilations), dynamicKernelSize,
- builder.getDenseI64ArrayAttr(staticKernelSize), dynamicKOffset,
- builder.getDenseI64ArrayAttr(staticKOffset), dynamicMOffset,
- builder.getDenseI64ArrayAttr(staticMOffset),
+ builder.getDenseI64ArrayAttr(staticKernelSize), dynamicMOffset,
+ builder.getDenseI64ArrayAttr(staticMOffset), dynamicMStrides,
+ builder.getDenseI64ArrayAttr(staticMStrides), dynamicKOffset,
+ builder.getDenseI64ArrayAttr(staticKOffset), dynamicKStrides,
+ builder.getDenseI64ArrayAttr(staticKStrides),
builder.getDenseI64ArrayAttr(batchPos),
builder.getDenseI64ArrayAttr(mPos), builder.getDenseI64ArrayAttr(kPos));
}
@@ -1578,14 +1628,35 @@
return op->emitOpError("expected one output operand");
}
- // TODO(Max191): Support cases with more than 1 m or k dimension, and remove
- // the check for a single m_offset and k_offset.
- if (getMixedMOffset().size() != 1) {
- return op->emitOpError("expected one m_offset");
+ // Verify offsets and strides
+ SmallVector<OpFoldResult> kOffset = getMixedKOffset();
+ SmallVector<OpFoldResult> mOffset = getMixedMOffset();
+ SmallVector<OpFoldResult> kStrides = getMixedKStrides();
+ SmallVector<OpFoldResult> mStrides = getMixedMStrides();
+ if (kOffset.size() < 1) {
+ return op->emitOpError("expected at least one k_offset");
}
- if (getMixedKOffset().size() != 1) {
- return op->emitOpError("expected one k_offset");
+ if (mOffset.size() < 1) {
+ return op->emitOpError("expected at least one m_offset");
}
+ if (kOffset.size() != kStrides.size()) {
+ return op->emitOpError("expected the same size k_offset and k_strides");
+ }
+ if (mOffset.size() != mStrides.size()) {
+ return op->emitOpError("expected the same size m_offset and m_strides");
+ }
+ std::optional<int64_t> constInnerKStrides =
+ getConstantIntValue(kStrides.back());
+ if (!constInnerKStrides.has_value() || constInnerKStrides.value() != 1) {
+ return op->emitOpError("expected inner k_strides to be 1");
+ }
+ std::optional<int64_t> constInnerMStrides =
+ getConstantIntValue(mStrides.back());
+ if (!constInnerMStrides.has_value() || constInnerMStrides.value() != 1) {
+ return op->emitOpError("expected inner m_strides to be 1");
+ }
+
+ // Verify operand ranks and dim position sizes.
auto inputType = getInputType();
unsigned inputRank = inputType.getRank();
ArrayRef<int64_t> batchPos = getBatchPos();
@@ -1595,6 +1666,14 @@
return op->emitOpError(
"expected input rank to be the sum of batch, m, and k ranks");
}
+ auto outputType = getOutputType();
+ unsigned outputRank = outputType.getRank();
+ if (outputRank != batchPos.size() + kOffset.size() + mOffset.size()) {
+ return op->emitOpError("expected output rank to be the sum of "
+ "batch_pos, k_offset, and m_offset ranks");
+ }
+
+ // Verify convolution metadata.
ArrayRef<int64_t> strides = getStrides();
ArrayRef<int64_t> dilations = getDilations();
SmallVector<OpFoldResult> kernelSize = getMixedKernelSize();
@@ -1611,17 +1690,16 @@
"expected dilations rank to be equal to the kernel rank");
}
+ // Verify input and output shapes.
ArrayRef<int64_t> inputShape = inputType.getShape();
- SmallVector<int64_t> expectedOutputShape;
- for (auto pos : batchPos) {
- expectedOutputShape.push_back(inputShape[pos]);
- }
- ArrayRef<int64_t> outputShape = getOutputType().getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
// When the op is tiled, the m and k dimensions of the output are tiled, but
// they are not tiled in the input, so we cannot verify the output size of
- // these dimensions.
- expectedOutputShape.push_back(outputShape[outputShape.size() - 2]);
- expectedOutputShape.push_back(outputShape.back());
+ // these dimensions. Only verify the shape of the batch dimensions.
+ SmallVector<int64_t> expectedOutputShape(outputShape);
+ for (auto [idx, pos] : llvm::enumerate(batchPos)) {
+ expectedOutputShape[idx] = inputShape[pos];
+ }
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 0d0c44b..e6aab96 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -827,7 +827,7 @@
```
%im2col = iree_linalg_ext.im2col
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0]
+ m_offset = [0] * [1] k_offset = [0] * [1]
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%in : tensor<2x34x34x640xf32>)
outs(%out : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
@@ -839,7 +839,7 @@
scf.for %arg2 = %c0 to %c5760 step %c1
%im2col = iree_linalg_ext.im2col
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [%arg1] k_offset = [%arg2]
+ m_offset = [%arg1] * [1] k_offset = [%arg2] * [1]
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%in_tile : tensor<1x34x34x640xf32>)
outs(%out_tile : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
@@ -851,6 +851,21 @@
(b, m, k) -> (b, M / 32 + K / (640*3), M % 32 + K % (640*3) / 640, K % 640)
Where `(b, m, k)` are the indices of the tiled op's iteration space, and
`M = m + m_offset` and `K = k + K_offset`.
+
+ The `m_strides` and `k_strides` fields are used as a basis for linearizing
+ the `m_offset` and `k_offset`. This is used when there are multiple M or K
+ output dimensions, and therefore multiple `m_offset` or `k_offset` values.
+ The strides fields are assembled in the IR as if they are multiplied as an
+  inner product with `m_offset` and `k_offset`, indicating that the total
+ linear offset along the dimension is equal to this inner product. These
+ strides fields also determine the strides of the output dimensions along
+ M and K. For example, an op with `m_strides = [32, 1]`, `k_strides = [4, 1]`,
+ and output type `tensor<BxM0xM1xK0xK1>` (expanded from `tensor<BxMxK>`),
+ would have strides along the M dim of 32 for `M0`, meaning as `M0` increases
+ by 1, the index into the flat `M` increases by 32. Along the K dim, strides
+ would be 4 for `K0`, and 1 for `K1`, meaning as `K0` increases by 1, the
+ index into the flat `K` increases by 4. The strides in M from `m_strides`
+ are orthogonal to the strides in `K` from `k_strides`.
}];
let arguments = (ins AnyShaped:$input, AnyShaped:$output,
@@ -860,8 +875,12 @@
DenseI64ArrayAttr:$static_kernel_size,
Variadic<Index>:$m_offset,
DenseI64ArrayAttr:$static_m_offset,
+ Variadic<Index>:$m_strides,
+ DenseI64ArrayAttr:$static_m_strides,
Variadic<Index>:$k_offset,
DenseI64ArrayAttr:$static_k_offset,
+ Variadic<Index>:$k_strides,
+ DenseI64ArrayAttr:$static_k_strides,
DenseI64ArrayAttr:$batch_pos,
DenseI64ArrayAttr:$m_pos,
DenseI64ArrayAttr:$k_pos);
@@ -876,8 +895,10 @@
custom<DynamicIndexList>($kernel_size, $static_kernel_size)
`m_offset` `=`
custom<DynamicIndexList>($m_offset, $static_m_offset)
+ `*` custom<DynamicIndexList>($m_strides, $static_m_strides)
`k_offset` `=`
custom<DynamicIndexList>($k_offset, $static_k_offset)
+ `*` custom<DynamicIndexList>($k_strides, $static_k_strides)
`batch_pos` `=` $batch_pos
`m_pos` `=` $m_pos
`k_pos` `=` $k_pos
@@ -892,7 +913,9 @@
"ArrayRef<int64_t>":$dilations,
"ArrayRef<OpFoldResult>":$kernel_size,
"ArrayRef<OpFoldResult>":$m_offset,
+ "ArrayRef<OpFoldResult>":$m_strides,
"ArrayRef<OpFoldResult>":$k_offset,
+ "ArrayRef<OpFoldResult>":$k_strides,
"ArrayRef<int64_t>":$batch_dimensions,
"ArrayRef<int64_t>":$m_dimensions,
"ArrayRef<int64_t>":$k_dimensions)>
@@ -911,14 +934,24 @@
int64_t getOutputRank() {
return getOutputType().getRank();
}
+
+ // Helpers to get output dimensions corresponding to batch, m, and k.
+ SmallVector<int64_t> getBatchOutputDims();
+ SmallVector<int64_t> getMOutputDims();
+ SmallVector<int64_t> getKOutputDims();
+
// Return op metadata.
SmallVector<OpFoldResult> getMixedKernelSize();
SmallVector<OpFoldResult> getMixedMOffset();
SmallVector<OpFoldResult> getMixedKOffset();
+ SmallVector<OpFoldResult> getMixedMStrides();
+ SmallVector<OpFoldResult> getMixedKStrides();
// Set op metadata.
void setMixedKOffset(SmallVector<OpFoldResult> kOffset);
void setMixedMOffset(SmallVector<OpFoldResult> mOffset);
+ void setMixedKStrides(SmallVector<OpFoldResult> kStrides);
+ void setMixedMStrides(SmallVector<OpFoldResult> mStrides);
// Method to implement for specifying output range for
// DestinationStyleOpInterface
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
index 6af4173..95ff938 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
@@ -1265,9 +1265,10 @@
SmallVector<OpFoldResult> inputSizes = getDims(builder, loc, getInput());
// Set batch offsets and sizes for input
- for (auto [idx, dim] : llvm::enumerate(getBatchPos())) {
- inputOffsets[dim] = offsets[idx];
- inputSizes[dim] = sizes[idx];
+ for (auto [outDim, inDim] :
+ llvm::zip_equal(getBatchOutputDims(), getBatchPos())) {
+ inputOffsets[inDim] = offsets[outDim];
+ inputSizes[inDim] = sizes[outDim];
}
SmallVector<OpFoldResult> inputStrides(getInputRank(), one);
@@ -1292,26 +1293,30 @@
outputSlice->result_type_end());
}
- AffineExpr d0, d1;
- bindDims(getContext(), d0, d1);
- auto map = AffineMap::get(2, 0, {d0 + d1}, getContext());
- OpFoldResult kTileOffset = offsets.back();
- OpFoldResult kOpOffset = getMixedKOffset()[0];
- OpFoldResult kOffset = affine::makeComposedFoldedAffineApply(
- builder, loc, map, {kTileOffset, kOpOffset});
- OpFoldResult mTileOffset = offsets[offsets.size() - 2];
- OpFoldResult mOpOffset = getMixedMOffset()[0];
- OpFoldResult mOffset = affine::makeComposedFoldedAffineApply(
- builder, loc, map, {mTileOffset, mOpOffset});
+ // Adjust m_offset and k_offset by adding the offsets from tiling.
+ SmallVector<OpFoldResult> newKOffsets, newMOffsets;
+ for (auto [outDim, kOffset] :
+ llvm::zip_equal(getKOutputDims(), getMixedKOffset())) {
+ OpFoldResult kTileOffset = offsets[outDim];
+ newKOffsets.push_back(addOfrs(builder, loc, kTileOffset, kOffset));
+ }
+ for (auto [outDim, mOffset] :
+ llvm::zip_equal(getMOutputDims(), getMixedMOffset())) {
+ OpFoldResult mTileOffset = offsets[outDim];
+ newMOffsets.push_back(addOfrs(builder, loc, mTileOffset, mOffset));
+ }
+ // Create the tiled op.
SmallVector<Value> operands = {inputSlice->getResult(0),
outputSlice->getResult(0)};
+ // Copy all metadata operands from the untiled operation.
operands.append(getOperation()->getOperands().begin() + 2,
getOperation()->getOperands().end());
Im2colOp tiledOp =
mlir::clone(builder, *this, outputSlice->getResultTypes(), operands);
- tiledOp.setMixedKOffset({kOffset});
- tiledOp.setMixedMOffset({mOffset});
+ // Set the new k_offset and m_offset, since they have changed with tiling.
+ tiledOp.setMixedKOffset(newKOffsets);
+ tiledOp.setMixedMOffset(newMOffsets);
return TilingResult{{tiledOp},
SmallVector<Value>(tiledOp->getResults()),
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 0643d8a..6884e9f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -550,7 +550,8 @@
%0 = tensor.empty() : tensor<2x1024x5760xf32>
// expected-error @+1 {{expected strides rank to be equal to the kernel rank}}
%1 = iree_linalg_ext.im2col strides = [1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<2x34x34x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
@@ -562,7 +563,8 @@
%0 = tensor.empty() : tensor<2x1024x5760xf32>
// expected-error @+1 {{expected dilations rank to be equal to the kernel rank}}
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1, 1] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<2x34x34x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
@@ -574,7 +576,60 @@
%0 = tensor.empty() : tensor<2x1024x5760xf32>
// expected-error @+1 {{expected kernel rank to be equal to the m_pos rank}}
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+ return %1 : tensor<2x1024x5760xf32>
+}
+
+// -----
+
+func.func @illegal_im2col_m_offset(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
+ %0 = tensor.empty() : tensor<2x1024x5760xf32>
+ // expected-error @+1 {{expected the same size m_offset and m_strides}}
+ %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [0, 0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+ return %1 : tensor<2x1024x5760xf32>
+}
+
+// -----
+
+func.func @illegal_im2col_k_offset(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
+ %0 = tensor.empty() : tensor<2x1024x5760xf32>
+ // expected-error @+1 {{expected the same size k_offset and k_strides}}
+ %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [0] * [1] k_offset = [0, 0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+ return %1 : tensor<2x1024x5760xf32>
+}
+
+// -----
+
+func.func @illegal_im2col_m_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
+ %0 = tensor.empty() : tensor<2x1024x5760xf32>
+ // expected-error @+1 {{expected inner m_strides to be 1}}
+ %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [0] * [0] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+ return %1 : tensor<2x1024x5760xf32>
+}
+
+// -----
+
+func.func @illegal_im2col_k_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
+ %0 = tensor.empty() : tensor<2x1024x5760xf32>
+ // expected-error @+1 {{expected inner k_strides to be 1}}
+ %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [0] * [1] k_offset = [0] * [2]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<2x34x34x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
@@ -586,7 +641,8 @@
%0 = tensor.empty() : tensor<2x1024x5760xf32>
// expected-error @+1 {{expected input rank to be the sum of batch, m, and k ranks}}
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<1x2x34x34x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
@@ -594,6 +650,19 @@
// -----
+func.func @illegal_im2col_output_rank(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x9x640xf32> {
+ %0 = tensor.empty() : tensor<2x1024x9x640xf32>
+ // expected-error @+1 {{expected output rank to be the sum of batch_pos, k_offset, and m_offset ranks}}
+ %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x1024x9x640xf32>) -> tensor<2x1024x9x640xf32>
+ return %1 : tensor<2x1024x9x640xf32>
+}
+
+// -----
+
func.func @illegal_winograd_input_shape(%arg0: tensor<1x10x10x32xf32>) -> tensor<8x8x1x6x6x32xf32> {
%0 = tensor.empty() : tensor<8x8x1x6x6x32xf32>
// expected-error @+1 {{incompatible output shape}}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 7fedef7..df94c9e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -1,10 +1,5 @@
// RUN: iree-opt --split-input-file %s | FileCheck %s
-// CHECK-LABEL: func.func @sort_tensor
-// CHECK: iree_linalg_ext.sort
-// CHECK-SAME: dimension(0)
-// CHECK-SAME: outs({{.*}})
-// CHECK: iree_linalg_ext.yield
func.func @sort_tensor(%arg0: tensor<128xi32>) -> tensor<128xi32> {
%0 = iree_linalg_ext.sort
dimension(0)
@@ -15,14 +10,14 @@
} -> tensor<128xi32>
return %0 : tensor<128xi32>
}
-
-// -----
-
-// CHECK-LABEL: func.func @sort_memref
+// CHECK-LABEL: func.func @sort_tensor(
// CHECK: iree_linalg_ext.sort
// CHECK-SAME: dimension(0)
// CHECK-SAME: outs({{.*}})
// CHECK: iree_linalg_ext.yield
+
+// -----
+
func.func @sort_memref(%arg0: memref<128xi32>) {
iree_linalg_ext.sort dimension(0)
outs(%arg0 : memref<128xi32>) {
@@ -32,6 +27,11 @@
}
return
}
+// CHECK-LABEL: func.func @sort_memref(
+// CHECK: iree_linalg_ext.sort
+// CHECK-SAME: dimension(0)
+// CHECK-SAME: outs({{.*}})
+// CHECK: iree_linalg_ext.yield
// -----
@@ -46,7 +46,7 @@
} -> tensor<?x?xi32>, tensor<?x?xf32>
return %0#0, %0#1 : tensor<?x?xi32>, tensor<?x?xf32>
}
-// CHECK-LABEL: func.func @sort_multi_result_tensor
+// CHECK-LABEL: func.func @sort_multi_result_tensor(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.sort dimension(0)
@@ -65,7 +65,7 @@
}
return
}
-// CHECK-LABEL: func.func @sort_multi_result_memref
+// CHECK-LABEL: func.func @sort_multi_result_memref(
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xi32>
// CHECK-SAME: %[[ARG1:.+]]: memref<?x?xf32>
// CHECK: iree_linalg_ext.sort dimension(0)
@@ -524,10 +524,9 @@
} -> tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>
return %0#0, %0#1 : tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>
}
-
-// CHECK-LABEL: func.func @topk_tensor
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x10x8x4xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x10x8x4xi32>
+// CHECK-LABEL: func.func @topk_tensor(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<20x10x8x4xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<20x10x8x4xi32>
// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk
@@ -550,7 +549,7 @@
}
return
}
-// CHECK-LABEL: func.func @topk_memref
+// CHECK-LABEL: func.func @topk_memref(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<4x10xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x10xi32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<4x3xf32>
@@ -574,7 +573,7 @@
} -> tensor<?x?xf32>, tensor<?x?xi32>
return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xi32>
}
-// CHECK-LABEL: func.func @topk_dynamic_tensor
+// CHECK-LABEL: func.func @topk_dynamic_tensor(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xi32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -602,7 +601,7 @@
return %0#0, %0#1 : tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>
}
-// CHECK-LABEL: func.func @topk_tensor
+// CHECK-LABEL: func.func @topk_tensor_optional(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x10x8x4xf32>
// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
@@ -620,11 +619,11 @@
return %1 : tensor<3x3x1x1xf32>
}
-// CHECK: func.func @pack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x3xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32>
-// CHECK: %[[RES:.*]] = iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (tensor<3x3xf32> tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32>
-// CHECK: return %[[RES]] : tensor<3x3x1x1xf32>
+// CHECK-LABEL: func.func @pack(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x3x1x1xf32>
+// CHECK: %[[RES:.*]] = iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (tensor<3x3xf32> tensor<3x3x1x1xf32>) -> tensor<3x3x1x1xf32>
+// CHECK: return %[[RES]] : tensor<3x3x1x1xf32>
// -----
@@ -633,10 +632,10 @@
return
}
-// CHECK: func.func @pack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) {
-// CHECK: iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (memref<3x3xf32> memref<3x3x1x1xf32>)
+// CHECK-LABEL: func.func @pack(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>
+// CHECK: iree_linalg_ext.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG1]] : (memref<3x3xf32> memref<3x3x1x1xf32>)
// -----
@@ -645,16 +644,16 @@
%0 = iree_linalg_ext.pack %input padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<13x15xf32> tensor<3x8x8x2xf32>) -> tensor<3x8x8x2xf32>
return %0 : tensor<3x8x8x2xf32>
}
-// CHECK: func @extra_pad_and_pack(
-// CHECK-SAME: %[[INPUT:.+]]: tensor<13x15xf32>
-// CHECK-SAME: %[[OUTPUT:.+]]: tensor<3x8x8x2xf32>
-// CHECK-SAME: %[[PAD:.+]]: f32
-// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
-// CHECK-SAME: padding_value(%[[PAD]] : f32)
-// CHECK-SAME: inner_dims_pos = [0, 1]
-// CHECK-SAME: inner_tiles = [8, 2]
-// CHECK-SAME: into %[[OUTPUT]]
-// CHECK: return %[[RES]]
+// CHECK-LABEL: func @extra_pad_and_pack(
+// CHECK-SAME: %[[INPUT:.+]]: tensor<13x15xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<3x8x8x2xf32>
+// CHECK-SAME: %[[PAD:.+]]: f32
+// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
+// CHECK-SAME: padding_value(%[[PAD]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [8, 2]
+// CHECK-SAME: into %[[OUTPUT]]
+// CHECK: return %[[RES]]
// -----
@@ -662,16 +661,16 @@
%0 = iree_linalg_ext.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<13x15xf32> tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
return %0 : tensor<2x8x8x2xf32>
}
-// CHECK: func.func @pad_and_pack_static
-// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<13x15xf32>
-// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<2x8x8x2xf32>
-// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32
-// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
-// CHECK-SAME: padding_value(%[[PAD]] : f32)
-// CHECK-SAME: inner_dims_pos = [0, 1]
-// CHECK-SAME: inner_tiles = [8, 2]
-// CHECK-SAME: into %[[OUTPUT]]
-// CHECK: return %[[RES]]
+// CHECK-LABEL: func.func @pad_and_pack_static(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<13x15xf32>
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<2x8x8x2xf32>
+// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32
+// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
+// CHECK-SAME: padding_value(%[[PAD]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [8, 2]
+// CHECK-SAME: into %[[OUTPUT]]
+// CHECK: return %[[RES]]
// -----
@@ -679,16 +678,16 @@
%0 = iree_linalg_ext.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<?x?xf32> tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
return %0 : tensor<?x?x8x2xf32>
}
-// CHECK: func.func @pad_and_pack_partially_dynamic
-// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<?x?x8x2xf32>
-// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32
-// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
-// CHECK-SAME: padding_value(%[[PAD]] : f32)
-// CHECK-SAME: inner_dims_pos = [0, 1]
-// CHECK-SAME: inner_tiles = [8, 2]
-// CHECK-SAME: into %[[OUTPUT]]
-// CHECK: return %[[RES]]
+// CHECK-LABEL: func.func @pad_and_pack_partially_dynamic(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<?x?x8x2xf32>
+// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32
+// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
+// CHECK-SAME: padding_value(%[[PAD]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [8, 2]
+// CHECK-SAME: into %[[OUTPUT]]
+// CHECK: return %[[RES]]
// -----
@@ -697,18 +696,18 @@
inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %output : (tensor<?x?xf32> tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
-// CHECK: func.func @pad_and_pack_fully_dynamic
-// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
-// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32
-// CHECK-SAME: %[[TILE_N:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[TILE_M:[a-zA-Z0-9_]+]]: index
-// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
-// CHECK-SAME: padding_value(%[[PAD]] : f32)
-// CHECK-SAME: inner_dims_pos = [0, 1]
-// CHECK-SAME: inner_tiles = [%[[TILE_N]], %[[TILE_M]]]
-// CHECK-SAME: into %[[OUTPUT]]
-// CHECK: return %[[RES]]
+// CHECK-LABEL: func.func @pad_and_pack_fully_dynamic(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[PAD:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[TILE_N:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[TILE_M:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[RES:.+]] = iree_linalg_ext.pack %[[INPUT]]
+// CHECK-SAME: padding_value(%[[PAD]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [%[[TILE_N]], %[[TILE_M]]]
+// CHECK-SAME: into %[[OUTPUT]]
+// CHECK: return %[[RES]]
// -----
@@ -717,10 +716,10 @@
return
}
-// CHECK: func.func @unpack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) {
-// CHECK: iree_linalg_ext.unpack %[[ARG1]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>)
+// CHECK-LABEL: func.func @unpack(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) {
+// CHECK: iree_linalg_ext.unpack %[[ARG1]] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>)
// -----
@@ -729,15 +728,15 @@
return %0 : tensor<256x128xf32>
}
-// CHECK: func.func @unpack_static
-// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
-// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
-// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
-// CHECK-SAME: %[[INPUT]]
-// CHECK-SAME dim_pos = [0, 1]
-// CHECK-SAME inner_pos = [32, 16]
-// CHECK-SAME: into %[[OUTPUT]]
-// CHECK: return %[[UNPACK]]
+// CHECK-LABEL: func.func @unpack_static(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
+// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME: %[[INPUT]]
+// CHECK-SAME:  inner_dims_pos = [0, 1]
+// CHECK-SAME:  inner_tiles = [32, 16]
+// CHECK-SAME: into %[[OUTPUT]]
+// CHECK: return %[[UNPACK]]
// -----
@@ -745,15 +744,15 @@
%0 = iree_linalg_ext.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : (tensor<2x8x8x2xf32> tensor<13x15xf32>) -> tensor<13x15xf32>
return %0 : tensor<13x15xf32>
}
-// CHECK: func.func @unpack_undo_padding
-// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
-// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
-// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
-// CHECK-SAME: %[[INPUT]]
-// CHECK-SAME dim_pos = [0, 1]
-// CHECK-SAME inner_pos = [32, 16]
-// CHECK-SAME: into %[[OUTPUT]]
-// CHECK: return %[[UNPACK]]
+// CHECK-LABEL: func.func @unpack_undo_padding(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
+// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK-SAME: %[[INPUT]]
+// CHECK-SAME:  inner_dims_pos = [0, 1]
+// CHECK-SAME:  inner_tiles = [8, 2]
+// CHECK-SAME: into %[[OUTPUT]]
+// CHECK: return %[[UNPACK]]
// -----
@@ -762,10 +761,10 @@
return
}
-// CHECK: func.func @unpack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>) {
-// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>)
+// CHECK-LABEL: func.func @unpack(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x3xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x3x1x1xf32>
+// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %[[ARG0]] : (memref<3x3x1x1xf32> memref<3x3xf32>)
// -----
@@ -774,10 +773,10 @@
return
}
-// CHECK: func.func @pack
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>) {
-// CHECK: iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<32x4x32x8xf32>)
+// CHECK-LABEL: func.func @pack(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>
+// CHECK: iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<32x4x32x8xf32>)
// -----
@@ -786,10 +785,10 @@
return
}
-// CHECK: func.func @pack
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>) {
-// CHECK: iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<4x32x32x8xf32>)
+// CHECK-LABEL: func.func @pack(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>
+// CHECK: iree_linalg_ext.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG1]] : (memref<128x256xf32> memref<4x32x32x8xf32>)
// -----
@@ -798,10 +797,10 @@
return
}
-// CHECK: func.func @unpack
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>) {
-// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<4x32x32x8xf32> memref<128x256xf32>)
+// CHECK-LABEL: func.func @unpack(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x32x32x8xf32>
+// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<4x32x32x8xf32> memref<128x256xf32>)
// -----
@@ -810,28 +809,31 @@
return
}
-// CHECK: func.func @unpack
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>) {
-// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<32x4x32x8xf32> memref<128x256xf32>)
+// CHECK-LABEL: func.func @unpack(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>
+// CHECK: iree_linalg_ext.unpack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %[[ARG0]] : (memref<32x4x32x8xf32> memref<128x256xf32>)
// -----
func.func @im2col(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
%0 = tensor.empty() : tensor<2x1024x5760xf32>
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<2x34x34x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
}
-// CHECK: func.func @im2col(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<2x34x34x640xf32>)
-// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x34x34x640xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<2x34x34x640xf32>)
+// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
// -----
@@ -839,91 +841,127 @@
%mOffset: index, %kOffset: index) -> tensor<?x?x?xf32> {
%0 = tensor.empty(%s0, %s1, %s2) : tensor<?x?x?xf32>
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [%mOffset] k_offset = [%kOffset] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [%mOffset] * [1] k_offset = [%kOffset] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<?x?x?x?xf32>)
outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
-// CHECK: func.func @im2col_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>,
-// CHECK-SAME: %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[MOFFSET:.+]]: index, %[[KOFFSET:.+]]: index
-// CHECK: %[[D0:.+]] = tensor.empty({{.+}}) : tensor<?x?x?xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
-// CHECK-SAME: outs(%[[D0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: return %[[D1]] : tensor<?x?x?xf32>
+// CHECK-LABEL: func.func @im2col_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[MOFFSET:.+]]: index, %[[KOFFSET:.+]]: index
+// CHECK: %[[D0:.+]] = tensor.empty({{.+}}) : tensor<?x?x?xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME: m_offset = [%[[MOFFSET]]] * [1] k_offset = [%[[KOFFSET]]] * [1]
+// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[D0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: return %[[D1]] : tensor<?x?x?xf32>
// -----
func.func @im2col_strided(%arg0: tensor<2x65x96x640xf32>) -> tensor<2x1024x5760xf32> {
%0 = tensor.empty() : tensor<2x1024x5760xf32>
%1 = iree_linalg_ext.im2col strides = [2, 3] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<2x65x96x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
}
-// CHECK: func.func @im2col_strided(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x65x96x640xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [2, 3] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<2x65x96x640xf32>)
-// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col_strided(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x65x96x640xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [2, 3] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<2x65x96x640xf32>)
+// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
// -----
func.func @im2col_dilated(%arg0: tensor<2x44x46x640xf32>) -> tensor<2x1024x5760xf32> {
%0 = tensor.empty() : tensor<2x1024x5760xf32>
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [6, 7] kernel_size = [3, 3]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<2x44x46x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
}
-// CHECK: func.func @im2col_dilated(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x44x46x640xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [6, 7] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<2x44x46x640xf32>)
-// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col_dilated(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x44x46x640xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [6, 7] kernel_size = [3, 3]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<2x44x46x640xf32>)
+// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
// -----
func.func @im2col_strided_dilated_mixed_kernel(%arg0: tensor<2x172x101x640xf32>) -> tensor<2x1024x5760xf32> {
%0 = tensor.empty() : tensor<2x1024x5760xf32>
%1 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
- m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<2x172x101x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
}
-// CHECK: func.func @im2col_strided_dilated_mixed_kernel(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x172x101x640xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<2x172x101x640xf32>)
-// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col_strided_dilated_mixed_kernel(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x172x101x640xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<2x172x101x640xf32>)
+// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
// -----
func.func @im2col_transposed_m_pos(%arg0: tensor<640x2x101x172xf32>) -> tensor<2x1024x5760xf32> {
%0 = tensor.empty() : tensor<2x1024x5760xf32>
%1 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
- m_offset = [0] k_offset = [0] batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+ m_offset = [0] * [1] k_offset = [0] * [1]
+ batch_pos = [1] m_pos = [3, 2] k_pos = [0]
ins(%arg0 : tensor<640x2x101x172xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
}
-// CHECK: func.func @im2col_transposed_m_pos(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<640x2x101x172xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-// CHECK-SAME: m_offset = [0] k_offset = [0] batch_pos = [1] m_pos = [3, 2] k_pos = [0]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<640x2x101x172xf32>)
-// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
-// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
+// CHECK-LABEL: func.func @im2col_transposed_m_pos(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<640x2x101x172xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
+// CHECK-SAME: batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<640x2x101x172xf32>)
+// CHECK-SAME: outs(%[[D0]] : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
+// CHECK: return %[[D1]] : tensor<2x1024x5760xf32>
+
+// -----
+
+func.func @im2col_expanded(%arg0: tensor<2x3x34x34x640xf32>) -> tensor<2x3x128x8x90x64xf32> {
+ %0 = tensor.empty() : tensor<2x3x128x8x90x64xf32>
+ %1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [0, 0] * [8, 1] k_offset = [0, 0] * [64, 1]
+ batch_pos = [0, 1] m_pos = [2, 3] k_pos = [4]
+ ins(%arg0 : tensor<2x3x34x34x640xf32>)
+ outs(%0 : tensor<2x3x128x8x90x64xf32>) -> tensor<2x3x128x8x90x64xf32>
+ return %1 : tensor<2x3x128x8x90x64xf32>
+}
+// CHECK-LABEL: func.func @im2col_expanded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x34x34x640xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x3x128x8x90x64xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME: m_offset = [0, 0] * [8, 1] k_offset = [0, 0] * [64, 1]
+// CHECK-SAME: batch_pos = [0, 1] m_pos = [2, 3] k_pos = [4]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<2x3x34x34x640xf32>)
+// CHECK-SAME: outs(%[[D0]] : tensor<2x3x128x8x90x64xf32>) -> tensor<2x3x128x8x90x64xf32>
+// CHECK: return %[[D1]] : tensor<2x3x128x8x90x64xf32>
// -----
@@ -934,14 +972,13 @@
ins(%arg0 : tensor<3x3x64x128xf32>) outs(%0 : tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
return %1 : tensor<8x8x64x128xf32>
}
-// CHECK: func.func @winograd_filter_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x3x64x128xf32>) ->
-// CHECK-SAME: tensor<8x8x64x128xf32> {
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x64x128xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: kernel_dimensions([0, 1]) ins(%[[ARG0]] : tensor<3x3x64x128xf32>) outs(%[[D0]] :
-// CHECK-SAME: tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
-// CHECK: return %[[D1]] : tensor<8x8x64x128xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_filter_transform(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x3x64x128xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x64x128xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: kernel_dimensions([0, 1]) ins(%[[ARG0]] : tensor<3x3x64x128xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
+// CHECK: return %[[D1]] : tensor<8x8x64x128xf32>
// -----
@@ -951,13 +988,13 @@
ins(%arg0 : tensor<3x3x?x?xf32>) outs(%arg1 : tensor<8x8x?x?xf32>) -> tensor<8x8x?x?xf32>
return %1 : tensor<8x8x?x?xf32>
}
-// CHECK: func.func @winograd_filter_transform_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x3x?x?xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?xf32>) -> tensor<8x8x?x?xf32> {
-// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: kernel_dimensions([0, 1]) ins(%[[ARG0]] : tensor<3x3x?x?xf32>) outs(%[[ARG1]] :
-// CHECK-SAME: tensor<8x8x?x?xf32>) -> tensor<8x8x?x?xf32>
-// CHECK: return %[[D0]] : tensor<8x8x?x?xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_filter_transform_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<3x3x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?xf32>
+// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: kernel_dimensions([0, 1]) ins(%[[ARG0]] : tensor<3x3x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME: tensor<8x8x?x?xf32>) -> tensor<8x8x?x?xf32>
+// CHECK: return %[[D0]] : tensor<8x8x?x?xf32>
// -----
@@ -968,15 +1005,13 @@
ins(%arg0 : tensor<128x64x3x3xf32>) outs(%0 : tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
return %1 : tensor<8x8x64x128xf32>
}
-// CHECK: func.func @winograd_filter_transform_fchw(%[[ARG0]]: tensor<128x64x3x3xf32>) ->
-// CHECK-SAME: tensor<8x8x64x128xf32> {
-// CHECK: %[[D0]] = tensor.empty() : tensor<8x8x64x128xf32>
-// CHECK: %[[D1]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: kernel_dimensions([2, 3]) ins(%[[ARG0]] : tensor<128x64x3x3xf32>) outs(%[[D0]] :
-// CHECK-SAME: tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
-// CHECK: return %[[D1]] : tensor<8x8x64x128xf32>
-// CHECK: }
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_filter_transform_fchw(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<128x64x3x3xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x64x128xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: kernel_dimensions([2, 3]) ins(%[[ARG0]] : tensor<128x64x3x3xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
+// CHECK: return %[[D1]] : tensor<8x8x64x128xf32>
// -----
@@ -986,14 +1021,13 @@
ins(%arg0 : tensor<1x10x10x1280xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
return %1 : tensor<8x8x1x2x2x1280xf32>
}
-// CHECK: func.func @winograd_input_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>) ->
-// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> {
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x10x10x1280xf32>) outs(%[[D0]] :
-// CHECK-SAME: tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
-// CHECK: return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_input_transform(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10x10x1280xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x10x10x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+// CHECK: return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
// -----
@@ -1003,13 +1037,13 @@
ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
return %1 : tensor<8x8x?x?x?x?xf32>
}
-// CHECK: func.func @winograd_input_transform_dynamic(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32> {
-// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<?x?x?x?xf32>) outs(%[[ARG1]] :
-// CHECK-SAME: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
-// CHECK: return %[[D0]] : tensor<8x8x?x?x?x?xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_input_transform_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>
+// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<?x?x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME: tensor<8x8x?x?x?x?xf32>) -> tensor<8x8x?x?x?x?xf32>
+// CHECK: return %[[D0]] : tensor<8x8x?x?x?x?xf32>
// -----
@@ -1019,15 +1053,13 @@
ins(%arg0 : tensor<1x1280x10x10xf32>) outs(%0 : tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
return %1 : tensor<8x8x1x2x2x1280xf32>
}
-// CHECK: func.func @winograd_input_transform_nchw(%[[ARG0]]: tensor<1x1280x10x10xf32>) ->
-// CHECK-SAME: tensor<8x8x1x2x2x1280xf32> {
-// CHECK: %[[D0]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
-// CHECK: %[[D1]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x1280x10x10xf32>) outs(%[[D0]] :
-// CHECK-SAME: tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
-// CHECK: return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
-// CHECK: }
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_input_transform_nchw(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1280x10x10xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x2x2x1280xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x1280x10x10xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<8x8x1x2x2x1280xf32>) -> tensor<8x8x1x2x2x1280xf32>
+// CHECK: return %[[D1]] : tensor<8x8x1x2x2x1280xf32>
// -----
@@ -1037,29 +1069,28 @@
ins(%arg0 : tensor<8x8x1x2x2x1280xf32>) outs(%0 : tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
return %1 : tensor<1x12x12x1280xf32>
}
-// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>) ->
-// CHECK-SAME: tensor<1x12x12x1280xf32> {
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x12x12x1280xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
-// CHECK-SAME: tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
-// CHECK: return %[[D1]] : tensor<1x12x12x1280xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_output_transform(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x12x12x1280xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<1x12x12x1280xf32>) -> tensor<1x12x12x1280xf32>
+// CHECK: return %[[D1]] : tensor<1x12x12x1280xf32>
// -----
-func.func @winograd_output_transform(%arg0: tensor<8x8x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+func.func @winograd_output_transform_dynamic(%arg0: tensor<8x8x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
%1 = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
ins(%arg0 : tensor<8x8x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
-// CHECK: func.func @winograd_output_transform(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
-// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x?x?x?x?xf32>) outs(%[[ARG1]] :
-// CHECK-SAME: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK: return %[[D0]] : tensor<?x?x?x?xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_output_transform_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[D0:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<8x8x?x?x?x?xf32>) outs(%[[ARG1]] :
+// CHECK-SAME: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: return %[[D0]] : tensor<?x?x?x?xf32>
// -----
@@ -1069,15 +1100,13 @@
ins(%arg0 : tensor<8x8x1x2x2x1280xf32>) outs(%0 : tensor<1x1280x12x12xf32>) -> tensor<1x1280x12x12xf32>
return %1 : tensor<1x1280x12x12xf32>
}
-// CHECK: func.func @winograd_output_transform_nchw(%[[ARG0]]: tensor<8x8x1x2x2x1280xf32>) ->
-// CHECK-SAME: tensor<1x1280x12x12xf32> {
-// CHECK: %[[D0]] = tensor.empty() : tensor<1x1280x12x12xf32>
-// CHECK: %[[D1]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
-// CHECK-SAME: tensor<1x1280x12x12xf32>) -> tensor<1x1280x12x12xf32>
-// CHECK: return %[[D1]] : tensor<1x1280x12x12xf32>
-// CHECK: }
-// CHECK: }
+// CHECK-LABEL: func.func @winograd_output_transform_nchw(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<8x8x1x2x2x1280xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<1x1280x12x12xf32>
+// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<8x8x1x2x2x1280xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<1x1280x12x12xf32>) -> tensor<1x1280x12x12xf32>
+// CHECK: return %[[D1]] : tensor<1x1280x12x12xf32>
// -----
@@ -1099,18 +1128,18 @@
// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
-// CHECK: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME: tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK-SAME: {
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
-// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] :
-// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @attention(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] :
+// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
// -----
@@ -1131,18 +1160,18 @@
// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
-// CHECK: func.func @cross_attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME: tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK-SAME: {
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
-// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] :
-// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @cross_attention(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] :
+// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
// -----
@@ -1165,18 +1194,18 @@
// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
-// CHECK: func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME: tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32>
-// CHECK-SAME: {
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
-// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
-// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] :
-// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
-// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @cross_attention_transposev(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x64x2048xf32>
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32>
+// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] :
+// CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32>
+// CHECK: return %[[D1]] : tensor<192x1024x64xf32>
// -----
@@ -1196,17 +1225,18 @@
// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
-// CHECK: func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME: tensor<?x?x?xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK-SAME: {
-// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
-// CHECK-SAME: tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%[[ARG3]] :
-// CHECK-SAME: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: return %[[D1]] : tensor<?x?x?xf32>
-// CHECK: }
+// CHECK-LABEL: func.func @cross_attention_transposev_dyn(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[D1:.+]] = iree_linalg_ext.attention
+// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] :
+// CHECK-SAME: tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%[[ARG3]] :
+// CHECK-SAME: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: return %[[D1]] : tensor<?x?x?xf32>
// -----
@@ -1220,12 +1250,12 @@
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}
-// CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
-// CHECK: func @custom_op_default(
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @custom_op_default(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
// CHECK: %[[RESULT:.+]] = iree_linalg_ext.custom_op
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: iterator_types = [#iree_linalg_ext.iterator_type<parallel>]
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>) outs(%[[ARG1]] : tensor<?xf32>)
// CHECK-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: tensor<?xf32>, %[[B1:[a-zA-Z0-9]+]]: tensor<?xf32>)
@@ -1245,11 +1275,11 @@
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}
-// CHECK: #[[MAP:.+]] = affine_map<(d0) -> ()>
-// CHECK: func @custom_op_scalar_arg(
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: func @custom_op_scalar_arg(
// CHECK-SAME: %[[SCALAR_ARG:[a-zA-Z0-9]+]]: f32
// CHECK: iree_linalg_ext.custom_op
-// CHECK-SAME: indexing_maps = [#{{.+}}, #[[MAP]], #{{.+}}]
+// CHECK-SAME: indexing_maps = [#{{.+}}, #[[$MAP]], #{{.+}}]
// CHECK-SAME: ins(%{{.+}}, %[[SCALAR_ARG]] : tensor<?xf32>, f32)
// CHECK-NEXT: %[[B1:.+]]: f32
@@ -1265,10 +1295,10 @@
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}
-// CHECK: #[[MAP:.+]] = affine_map<() -> ()>
-// CHECK: func @custom_op_empty_affine_map(
+// CHECK: #[[$MAP:.+]] = affine_map<() -> ()>
+// CHECK-LABEL: func @custom_op_empty_affine_map(
// CHECK: iree_linalg_ext.custom_op
-// CHECK-SAME: indexing_maps = [#{{.+}}, #[[MAP]], #{{.+}}]
+// CHECK-SAME: indexing_maps = [#{{.+}}, #[[$MAP]], #{{.+}}]
// -----
@@ -1282,8 +1312,7 @@
} -> tensor<10xf32>
return %0 : tensor<10xf32>
}
-// CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
-// CHECK: func @custom_op_static_args(
+// CHECK-LABEL: func @custom_op_static_args(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<10xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<10xf32>
// CHECK: iree_linalg_ext.custom_op
@@ -1317,7 +1346,7 @@
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}
-// CHECK: func @custom_op_reduction(
+// CHECK-LABEL: func @custom_op_reduction(
// CHECK: iree_linalg_ext.custom_op
// CHECK-SAME: iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]
// CHECK-NEXT: ^bb0
@@ -1339,8 +1368,7 @@
} -> tensor<?xf32>, tensor<?xf32>
return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
}
-// CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
-// CHECK: func @custom_op_multiple_results(
+// CHECK-LABEL: func @custom_op_multiple_results(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.custom_op
@@ -1376,14 +1404,14 @@
} -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, s0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (s0, s1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (s1, d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, s1)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
-// CHECK: func @custom_op_symbolic_dims
-// CHECK: iree_linalg_ext.custom_op
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]]]
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, s0)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (s0, s1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (s1, d1)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, s1)>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+// CHECK-LABEL: func @custom_op_symbolic_dims(
+// CHECK: iree_linalg_ext.custom_op
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]]
// -----
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index c9895be..5f886c7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/CommandLine.h"
@@ -17,6 +18,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -269,15 +271,6 @@
affineOp.getMap().getResult(0).isMultipleOf(constTileSize.value());
}
-// Helper method to add 2 OpFoldResult inputs with affine.apply.
-static OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
- OpFoldResult b) {
- AffineExpr d0, d1;
- bindDims(builder.getContext(), d0, d1);
- auto addMap = AffineMap::get(2, 0, {d0 + d1});
- return affine::makeComposedFoldedAffineApply(builder, loc, addMap, {a, b});
-}
-
//===----------------------------------------------------------------------===//
// OnlineAttentionOp
//===----------------------------------------------------------------------===//
@@ -448,7 +441,7 @@
/// ```
/// %im2col = iree_linalg_ext.im2col
/// strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-/// m_offset = [%m_off] k_offset = [%k_off]
+/// m_offset = [%m_off] * [1] k_offset = [%k_off] * [1]
/// batch_pos = [0] m_pos = [1, 2] k_pos = [3]
/// ins(%in : tensor<2x34x34x640xf32>)
/// outs(%out : tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
@@ -470,13 +463,30 @@
FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
Location loc = getLoc();
Value inputSlice = getInput();
- // Unroll all but the K loop
- SmallVector<OpFoldResult> kOffset = getMixedKOffset();
- SmallVector<OpFoldResult> mOffset = getMixedMOffset();
- // Only support single K and M output dimension for now.
- if (kOffset.size() != 1 || mOffset.size() != 1) {
- return failure();
- }
+
+ // This is part of the im2col verifier, but check here in case this changes.
+ assert(getConstantIntValue(getMixedMStrides().back()).value() == 1 &&
+ getConstantIntValue(getMixedKStrides().back()).value() == 1 &&
+         "Expected inner m_strides and k_strides to be 1");
+
+ // Get the linearized mOffset and kOffset.
+ auto linearizeIndex = [&](ArrayRef<OpFoldResult> inds,
+ ArrayRef<OpFoldResult> basis) {
+ MLIRContext *ctx = b.getContext();
+ SmallVector<AffineExpr> dims(inds.size()), symbols(basis.size());
+ bindDimsList<AffineExpr>(ctx, dims);
+ bindSymbolsList<AffineExpr>(ctx, symbols);
+ AffineExpr linearExpr = mlir::linearize(ctx, dims, symbols);
+ SmallVector<OpFoldResult> mapOperands(inds);
+ mapOperands.append(basis.begin(), basis.end());
+ auto linearMap = AffineMap::get(
+ /*dimCount=*/inds.size(), /*symbolCount=*/basis.size(), linearExpr);
+ OpFoldResult linearIdx =
+ affine::makeComposedFoldedAffineApply(b, loc, linearMap, mapOperands);
+ return linearIdx;
+ };
+ OpFoldResult mOffset = linearizeIndex(getMixedMOffset(), getMixedMStrides());
+ OpFoldResult kOffset = linearizeIndex(getMixedKOffset(), getMixedKStrides());
// Step 1: Tile the im2col op to loops with contiguous slices in the
// innermost loop.
@@ -499,14 +509,26 @@
SetVector<int64_t> batchPosSet(getBatchPos().begin(), getBatchPos().end());
OpFoldResult innerSliceSize;
for (int idx = inputSizes.size() - 1; idx >= 0; --idx) {
- if (!batchPosSet.contains(idx)) {
- innerSliceSize = inputSizes[idx];
- break;
+ if (batchPosSet.contains(idx)) {
+ continue;
}
+ innerSliceSize = inputSizes[idx];
+ // If the innermost non-batch dimension is an m_pos dimension, then use the
+ // corresponding kernel_size instead of the input tensor size. This is
+ // because the slice will be of size `kernel_size` at some offset
+ // `i * kernel_size` in this case.
+ for (auto [mPos, kernelSize] :
+ llvm::zip_equal(getMPos(), getMixedKernelSize())) {
+ if (mPos == idx) {
+ innerSliceSize = kernelSize;
+ }
+ }
+ break;
}
- bool tileK =
- !willBeContiguousSlice(innerSliceSize, kTileSize, kOffset.front());
- if (!tileK) {
+ bool vectorizeInnerKLoop =
+ getKPos().back() == getInputRank() - 1 &&
+ willBeContiguousSlice(innerSliceSize, kTileSize, kOffset);
+ if (vectorizeInnerKLoop) {
iterationDomain.pop_back();
} else {
kTileSize = b.getIndexAttr(1);
@@ -571,9 +593,14 @@
}
kBasis.push_back(size);
}
- OpFoldResult kIndex = kOffset.front();
- if (tileK) {
- kIndex = addOfrs(b, nestedLoc, kOffset.front(), ivs.back());
+ OpFoldResult kIndex = kOffset;
+ for (auto [i, ivIdx, stride] :
+ llvm::enumerate(getKOutputDims(), getMixedKStrides())) {
+ if (vectorizeInnerKLoop && i == getMixedKOffset().size() - 1) {
+ break;
+ }
+ OpFoldResult ivOffset = mulOfrs(b, nestedLoc, stride, ivs[ivIdx]);
+ kIndex = addOfrs(b, nestedLoc, kIndex, ivOffset);
}
FailureOr<SmallVector<Value>> maybeDelinKOffset = affine::delinearizeIndex(
b, nestedLoc, getValueOrCreateConstantIndexOp(b, loc, kIndex),
@@ -596,11 +623,22 @@
inputKOffset.push_back(delinKOffset[delinKIdx++]);
}
- // Compute offsets for extract. Start by delinearizing the combined offset
- // of m_offset and the offset from the tiled loop, using the mBasis. This
- // will give an index into the delinearized output space of the convolution.
- Value mArg = tileK ? ivs[ivs.size() - 2] : ivs.back();
- OpFoldResult linearMOffset = addOfrs(b, nestedLoc, mArg, mOffset.front());
+ // Compute offsets for extract. The linearized im2col result M offset is
+ // computed as the m_offset * m_strides inner product plus the linearized
+ // offset from the tiled m loops. The M offsets into the im2col input are then
+ // computed as the delinearized im2col result M offset (in the convolution
+ // result iteration space), plus the convolutional window offsets computed
+ // above.
+ SmallVector<int64_t> mOutDims = getMOutputDims();
+ SmallVector<OpFoldResult> mIvs, mOutStrides(getMixedMStrides());
+ for (auto [idx, dim] : llvm::enumerate(getMOutputDims())) {
+ mIvs.push_back(ivs[dim]);
+ }
+ OpFoldResult linearMIv = linearizeIndex(mIvs, mOutStrides);
+ OpFoldResult linearMOffset = addOfrs(b, nestedLoc, linearMIv, mOffset);
+  // Delinearize the combined M offset (m_offset * m_strides plus the
+  // linearized loop IV offset) into the convolution output space.
+ // `mBasis` contains the basis for the iteration space of result of the
+ // convolution op (i.e., basis for result H and W dims).
FailureOr<SmallVector<Value>> maybeDelinMOffset = affine::delinearizeIndex(
b, nestedLoc,
getValueOrCreateConstantIndexOp(b, nestedLoc, linearMOffset), mBasis);
@@ -648,7 +686,8 @@
// Extract a slice from the input tensor.
ShapedType outputType = getOutputType();
- SmallVector<OpFoldResult> kTileSizes(outputType.getRank(), b.getIndexAttr(1));
+ SmallVector<OpFoldResult> kTileSizes(
+ std::min<int64_t>(getOutputRank(), getInputRank()), b.getIndexAttr(1));
kTileSizes.back() = kTileSize;
SmallVector<int64_t> kTileSizeStatic;
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
index 31daf1f..b4d4cca 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
@@ -137,17 +137,19 @@
SmallVector<int64_t> dilations(convOp.getDilations().getValues<int64_t>());
SmallVector<OpFoldResult> kernelSize = {rewriter.getIndexAttr(fh),
rewriter.getIndexAttr(fw)};
- SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> mOffset = {rewriter.getIndexAttr(0)};
+ SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(1)};
+ SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
+ SmallVector<OpFoldResult> kBasis = {rewriter.getIndexAttr(1)};
SmallVector<int64_t> batchPos = {0};
SmallVector<int64_t> mPos = {1, 2};
SmallVector<int64_t> kPos = {3};
- Value img2ColTensor =
- rewriter
- .create<IREE::LinalgExt::Im2colOp>(
- loc, input, /*output=*/colTensor, strides, dilations,
- kernelSize, mOffset, kOffset, batchPos, mPos, kPos)
- .getResult(0);
+ Value img2ColTensor = rewriter
+ .create<IREE::LinalgExt::Im2colOp>(
+ loc, input, /*output=*/colTensor, strides,
+ dilations, kernelSize, mOffset, mBasis,
+ kOffset, kBasis, batchPos, mPos, kPos)
+ .getResult(0);
SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
auto reshapedFilterType =
@@ -260,17 +262,19 @@
SmallVector<int64_t> dilations(convOp.getDilations().getValues<int64_t>());
SmallVector<OpFoldResult> kernelSize = {rewriter.getIndexAttr(fh),
rewriter.getIndexAttr(fw)};
- SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> mOffset = {rewriter.getIndexAttr(0)};
+ SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(1)};
+ SmallVector<OpFoldResult> kOffset = {rewriter.getIndexAttr(0)};
+ SmallVector<OpFoldResult> kBasis = {rewriter.getIndexAttr(1)};
SmallVector<int64_t> batchPos = {0};
SmallVector<int64_t> mPos = {2, 3};
SmallVector<int64_t> kPos = {1};
- Value img2ColTensor =
- rewriter
- .create<IREE::LinalgExt::Im2colOp>(
- loc, input, /*output=*/colTensor, strides, dilations,
- kernelSize, mOffset, kOffset, batchPos, mPos, kPos)
- .getResult(0);
+ Value img2ColTensor = rewriter
+ .create<IREE::LinalgExt::Im2colOp>(
+ loc, input, /*output=*/colTensor, strides,
+ dilations, kernelSize, mOffset, mBasis,
+ kOffset, kBasis, batchPos, mPos, kPos)
+ .getResult(0);
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
index e827a7f..fd59547 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
@@ -17,7 +17,7 @@
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [0] k_offset = [0]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x16x16x4xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
@@ -53,7 +53,7 @@
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [0] k_offset = [0]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0] m_pos = [2, 3] k_pos = [1]
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x4x16x16xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
@@ -89,7 +89,7 @@
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x196x36xf16>
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [0] k_offset = [0]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x16x16x4xf16>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x196x36xf16>) -> tensor<1x196x36xf16>
@@ -127,7 +127,7 @@
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x49x36xf16>
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [0] k_offset = [0]
+// CHECK-SAME: m_offset = [0] * [1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x16x16x4xf16>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x49x36xf16>) -> tensor<1x49x36xf16>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
index 455381a..0316e8f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_im2col.mlir
@@ -6,33 +6,39 @@
func.func @im2col_untile_k(%arg0: tensor<2x34x34x640xf32>, %m_size: index, %m_off: index, %k: index) -> tensor<2x?x4xf32> {
%0 = tensor.empty(%m_size) : tensor<2x?x4xf32>
%k_off = affine.apply #map(%k)
- %7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%arg0 : tensor<2x34x34x640xf32>) outs(%0 : tensor<2x?x4xf32>) -> tensor<2x?x4xf32>
+ %7 = iree_linalg_ext.im2col
+ strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [%m_off] * [1] k_offset = [%k_off] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x?x4xf32>) -> tensor<2x?x4xf32>
return %7 : tensor<2x?x4xf32>
}
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
-// CHECK: func.func @im2col_untile_k(%[[ARG0:.+]]: tensor<2x34x34x640xf32>
-// CHECK-SAME: %[[mSIZE:.+]]: index, %[[mOFF:.+]]: index, %[[K:.+]]: index)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]]) : tensor<2x?x4xf32>
-// CHECK: %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x4xf32>) {
-// CHECK: %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x4xf32>) {
-// CHECK-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-// CHECK-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[m]])[%[[mOFF]], %[[K]]]
-// CHECK-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[m]])[%[[mOFF]], %[[K]]]
-// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[b]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x?x4xf32> to tensor<1x1x4xf32>
-// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x?x4xf32>
-// CHECK: scf.yield %[[INSERT]] : tensor<2x?x4xf32>
-// CHECK: }
-// CHECK: scf.yield %[[mLOOP]] : tensor<2x?x4xf32>
-// CHECK: }
-// CHECK: return %[[bLOOP]] : tensor<2x?x4xf32>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
+// CHECK-LABEL: func.func @im2col_untile_k
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mSIZE:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mOFF:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[K:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]]) : tensor<2x?x4xf32>
+// CHECK: %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x4xf32>)
+// CHECK: %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x4xf32>)
+// CHECK-DAG: %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+// CHECK-DAG: %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[m]])[%[[mOFF]], %[[K]]]
+// CHECK-DAG: %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[m]])[%[[mOFF]], %[[K]]]
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[b]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x?x4xf32> to tensor<1x1x4xf32>
+// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT1]][%[[b]], %[[m]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x?x4xf32>
+// CHECK: scf.yield %[[INSERT]] : tensor<2x?x4xf32>
+// CHECK: scf.yield %[[mLOOP]] : tensor<2x?x4xf32>
+// CHECK: return %[[bLOOP]] : tensor<2x?x4xf32>
// -----
@@ -42,36 +48,106 @@
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = tensor.empty(%m_size, %k_size) : tensor<2x?x?xf32>
- %8 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [1] m_pos = [3, 2] k_pos = [0] ins(%arg0 : tensor<640x2x101x172xf32>) outs(%0 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+ %8 = iree_linalg_ext.im2col
+ strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
+ m_offset = [%m_off] * [1] k_offset = [%k_off] * [1]
+ batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+ ins(%arg0 : tensor<640x2x101x172xf32>)
+ outs(%0 : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
return %8 : tensor<2x?x?xf32>
}
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> ((d0 + s0) floordiv 10)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (((d0 + s0) floordiv 32) * 5 + (((d1 + s1) mod 10) floordiv 5) * 4)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 3 + d1 * 7 + s0 * 3 + s1 * 7 - ((d0 + s0) floordiv 32) * 96 - ((d1 + s1) floordiv 5) * 35)>
-// CHECK: func.func @im2col_transposed_m_pos(%[[ARG0:.+]]: tensor<640x2x101x172xf32>
-// CHECK-SAME: %[[mSIZE:.+]]: index, %[[kSIZE:.+]]: index, %[[mOFF:.+]]: index, %[[kOFF:.+]]: index)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]], %[[kSIZE]]) : tensor<2x?x?xf32>
-// CHECK: %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x?xf32>) {
-// CHECK: %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x?xf32>) {
-// CHECK: %[[kLOOP:.+]] = scf.for %[[k:.+]] = %[[C0]] to %[[kSIZE]] step %[[C1]] iter_args(%[[OUT2:.+]] = %[[OUT1]]) -> (tensor<2x?x?xf32>) {
-// CHECK-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]](%[[k]])[%[[kOFF]]]
-// CHECK-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
-// CHECK-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
-// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[kIDX]], %[[b]], %[[wIDX]], %[[hIDX]]] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<640x2x101x172xf32> to tensor<1x1x1xf32>
-// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<2x?x?xf32> to tensor<1x1x1xf32>
-// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x1xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<2x?x?xf32>
-// CHECK: scf.yield %[[INSERT]] : tensor<2x?x?xf32>
-// CHECK: }
-// CHECK: scf.yield %[[kLOOP]] : tensor<2x?x?xf32>
-// CHECK: }
-// CHECK: scf.yield %[[mLOOP]] : tensor<2x?x?xf32>
-// CHECK: }
-// CHECK: return %[[bLOOP]] : tensor<2x?x?xf32>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0)[s0] -> ((d0 + s0) floordiv 10)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (((d0 + s0) floordiv 32) * 5 + (((d1 + s1) mod 10) floordiv 5) * 4)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 3 + d1 * 7 + s0 * 3 + s1 * 7 - ((d0 + s0) floordiv 32) * 96 - ((d1 + s1) floordiv 5) * 35)>
+// CHECK-LABEL: func.func @im2col_transposed_m_pos
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mSIZE:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[kSIZE:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mOFF:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[kOFF:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE]], %[[kSIZE]]) : tensor<2x?x?xf32>
+// CHECK: %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x?xf32>)
+// CHECK: %[[mLOOP:.+]] = scf.for %[[m:.+]] = %[[C0]] to %[[mSIZE]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x?xf32>)
+// CHECK: %[[kLOOP:.+]] = scf.for %[[k:.+]] = %[[C0]] to %[[kSIZE]] step %[[C1]] iter_args(%[[OUT2:.+]] = %[[OUT1]]) -> (tensor<2x?x?xf32>)
+// CHECK-DAG: %[[kIDX:.+]] = affine.apply #[[$MAP]](%[[k]])[%[[kOFF]]]
+// CHECK-DAG: %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
+// CHECK-DAG: %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[m]], %[[k]])[%[[mOFF]], %[[kOFF]]]
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[kIDX]], %[[b]], %[[wIDX]], %[[hIDX]]] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<640x2x101x172xf32> to tensor<1x1x1xf32>
+// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<2x?x?xf32> to tensor<1x1x1xf32>
+// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x1xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT2]][%[[b]], %[[m]], %[[k]]] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<2x?x?xf32>
+// CHECK: scf.yield %[[INSERT]] : tensor<2x?x?xf32>
+// CHECK: scf.yield %[[kLOOP]] : tensor<2x?x?xf32>
+// CHECK: scf.yield %[[mLOOP]] : tensor<2x?x?xf32>
+// CHECK: return %[[bLOOP]] : tensor<2x?x?xf32>
+
+// -----
+
+module {
+ func.func @im2col_expanded(%arg0: tensor<2x34x34x640xf32>, %m_size0: index, %m_size1: index, %m0: index, %m1: index, %k: index, %m_stride: index) -> tensor<2x?x?x2x4xf32> {
+ %0 = tensor.empty(%m_size0, %m_size1) : tensor<2x?x?x2x4xf32>
+ %7 = iree_linalg_ext.im2col
+ strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [%m0, %m1] * [%m_stride, 1] k_offset = [%k, 0] * [4, 1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x?x?x2x4xf32>) -> tensor<2x?x?x2x4xf32>
+ return %7 : tensor<2x?x?x2x4xf32>
+ }
+}
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0)[s0] -> (d0 * 4 + s0 * 4 - ((d0 + s0) floordiv 160) * 640)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> ((d1 + s2 + d0 * s0 + s1 * s0) floordiv 32 + (d2 + s3) floordiv 480)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s0 + d1 + s1 * s0 + s2 - ((d1 + s2 + d0 * s0 + s1 * s0) floordiv 32) * 32 + (d2 + s3) floordiv 160 - ((d2 + s3) floordiv 480) * 3)>
+// CHECK-LABEL: func.func @im2col_expanded
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mSIZE0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mSIZE1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mOFF0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mOFF1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[kOFF:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[mSTRIDE:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OUT_TILE:.+]] = tensor.empty(%[[mSIZE0]], %[[mSIZE1]]) : tensor<2x?x?x2x4xf32>
+// CHECK: %[[bLOOP:.+]] = scf.for %[[b:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT0:.+]] = %[[OUT_TILE]]) -> (tensor<2x?x?x2x4xf32>)
+// CHECK: %[[mLOOP0:.+]] = scf.for %[[m0:.+]] = %[[C0]] to %[[mSIZE0]] step %[[C1]] iter_args(%[[OUT1:.+]] = %[[OUT0]]) -> (tensor<2x?x?x2x4xf32>)
+// CHECK: %[[mLOOP1:.+]] = scf.for %[[m1:.+]] = %[[C0]] to %[[mSIZE1]] step %[[C1]] iter_args(%[[OUT2:.+]] = %[[OUT1]]) -> (tensor<2x?x?x2x4xf32>)
+// CHECK: %[[kLOOP:.+]] = scf.for %[[k:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[OUT3:.+]] = %[[OUT2]]) -> (tensor<2x?x?x2x4xf32>)
+// CHECK-DAG: %[[kIDX:.+]] = affine.apply #[[$MAP]](%[[k]])[%[[kOFF]]]
+// CHECK-DAG: %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[m0]], %[[m1]], %[[k]])[%[[mSTRIDE]], %[[mOFF0]], %[[mOFF1]], %[[kOFF]]]
+// CHECK-DAG: %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[m0]], %[[m1]], %[[k]])[%[[mSTRIDE]], %[[mOFF0]], %[[mOFF1]], %[[kOFF]]]
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[b]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x1x4xf32>
+// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT3]][%[[b]], %[[m0]], %[[m1]], %[[k]], 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] : tensor<2x?x?x2x4xf32> to tensor<1x1x1x4xf32>
+// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x1x4xf32>) -> tensor<1x1x1x4xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT3]][%[[b]], %[[m0]], %[[m1]], %[[k]], 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] : tensor<1x1x1x4xf32> into tensor<2x?x?x2x4xf32>
+// CHECK: scf.yield %[[INSERT]] : tensor<2x?x?x2x4xf32>
+// CHECK: scf.yield %[[kLOOP]] : tensor<2x?x?x2x4xf32>
+// CHECK: scf.yield %[[mLOOP1]] : tensor<2x?x?x2x4xf32>
+// CHECK: scf.yield %[[mLOOP0]] : tensor<2x?x?x2x4xf32>
+// CHECK: return %[[bLOOP]] : tensor<2x?x?x2x4xf32>
+
+// -----
+
+module {
+ func.func @im2col_expanded_nchw(%arg0: tensor<2x640x34x34xf32>, %m0: index, %m1: index, %k: index) -> tensor<2x1x1x2x4xf32> {
+ %0 = tensor.empty() : tensor<2x1x1x2x4xf32>
+ %7 = iree_linalg_ext.im2col
+ strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [%m0, %m1] * [32, 1] k_offset = [%k, 0] * [4, 1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x640x34x34xf32>)
+ outs(%0 : tensor<2x1x1x2x4xf32>) -> tensor<2x1x1x2x4xf32>
+ return %7 : tensor<2x1x1x2x4xf32>
+ }
+}
+// Verify that the NCHW layout does not vectorize.
+// CHECK-LABEL: func.func @im2col_expanded_nchw
+// CHECK: linalg.copy ins({{.*}} : tensor<1x1x1x1xf32>) outs({{.*}} : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
// -----
@@ -80,57 +156,64 @@
func.func @im2col_unrolled(%arg0: tensor<2x34x34x640xf32>, %m_off: index, %k: index) -> tensor<2x2x4xf32> {
%0 = tensor.empty() : tensor<2x2x4xf32>
%k_off = affine.apply #map(%k)
- %7 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [%m_off] k_offset = [%k_off] batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%arg0 : tensor<2x34x34x640xf32>) outs(%0 : tensor<2x2x4xf32>) -> tensor<2x2x4xf32>
+ %7 = iree_linalg_ext.im2col
+ strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [%m_off] * [1] k_offset = [%k_off] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x2x4xf32>) -> tensor<2x2x4xf32>
return %7 : tensor<2x2x4xf32>
}
}
-// CHECK-UNROLL-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
-// CHECK-UNROLL-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
-// CHECK-UNROLL-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
-// CHECK-UNROLL: func.func @im2col_unrolled(%[[ARG0:.+]]: tensor<2x34x34x640xf32>
-// CHECK-UNROLL-SAME: %[[mOFF:.+]]: index, %[[K:.+]]: index)
-// CHECK-UNROLL-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-UNROLL-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-UNROLL: %[[OUT_TILE:.+]] = tensor.empty() : tensor<2x2x4xf32>
+// CHECK-UNROLL-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 160) * 640)>
+// CHECK-UNROLL-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) floordiv 32 + s1 floordiv 480)>
+// CHECK-UNROLL-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0, s1] -> ((d0 + s0) mod 32 + s1 floordiv 160 - (s1 floordiv 480) * 3)>
+// CHECK-UNROLL-LABEL: func.func @im2col_unrolled
+// CHECK-UNROLL-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-UNROLL-SAME: %[[mOFF:[a-zA-Z0-9_]+]]
+// CHECK-UNROLL-SAME: %[[K:[a-zA-Z0-9_]+]]
+// CHECK-UNROLL-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-UNROLL-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-UNROLL: %[[OUT_TILE:.+]] = tensor.empty() : tensor<2x2x4xf32>
// First iteration
//
-// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
-// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
-// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[INSERT0:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
+// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[INSERT0:.+]] = tensor.insert_slice %[[COPY]] into %[[OUT_TILE]][%[[C0]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
// Second iteration
//
-// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
-// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
-// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[INSERT1:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
+// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C0]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[INSERT1:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT0]][%[[C0]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
// Third iteration
//
-// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
-// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
-// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[INSERT2:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
+// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[C0]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[C0]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[INSERT2:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT1]][%[[C1]], %[[C0]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
// Fourth iteration
//
-// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[MAP]]()[%[[K]]]
-// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
-// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
-// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
-// CHECK-UNROLL: %[[INSERT3:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
+// CHECK-UNROLL-DAG: %[[kIDX:.+]] = affine.apply #[[$MAP]]()[%[[K]]]
+// CHECK-UNROLL-DAG: %[[hIDX:.+]] = affine.apply #[[$MAP1]](%[[C1]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL-DAG: %[[wIDX:.+]] = affine.apply #[[$MAP2]](%[[C1]])[%[[mOFF]], %[[K]]]
+// CHECK-UNROLL: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[C1]], %[[hIDX]], %[[wIDX]], %[[kIDX]]] [1, 1, 1, 4] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<2x2x4xf32> to tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<1x1x4xf32>) outs(%[[OUT_SLICE]] : tensor<1x1x4xf32>) -> tensor<1x1x4xf32>
+// CHECK-UNROLL: %[[INSERT3:.+]] = tensor.insert_slice %[[COPY]] into %[[INSERT2]][%[[C1]], %[[C1]], 0] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<2x2x4xf32>
-// CHECK-UNROLL: return %[[INSERT3]] : tensor<2x2x4xf32>
+// CHECK-UNROLL: return %[[INSERT3]] : tensor<2x2x4xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 983ec33..f42a2fa 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -745,7 +745,8 @@
func.func @im2col(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1024x5760xf32> {
%0 = tensor.empty() : tensor<2x1024x5760xf32>
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [34] k_offset = [1000] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [34] * [1] k_offset = [1000] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<2x34x34x640xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
@@ -770,11 +771,11 @@
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
// CHECK: %[[RES0:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x1024x5760xf32>)
// CHECK: %[[RES1:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C5]]
-// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x1024x5760xf32>)
// CHECK: %[[RES2:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C5760]] step %[[C4]]
-// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x1024x5760xf32>)
// CHECK-DAG: %[[MSIZE:.+]] = affine.min #[[MAP]](%[[ARG3]])
// CHECK-DAG: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, 0]
// CHECK-SAME: [1, 34, 34, 640] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x34x34x640xf32>
@@ -783,18 +784,16 @@
// CHECK-DAG: %[[KOFFSET:.+]] = affine.apply #[[MAP1]](%[[ARG5]])
// CHECK-DAG: %[[MOFFSET:.+]] = affine.apply #[[MAP2]](%[[ARG3]])
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME: m_offset = [%[[MOFFSET]]] * [1] k_offset = [%[[KOFFSET]]] * [1]
+// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
// CHECK-SAME: ins(%[[EXTRACTED_SLICE]] : tensor<1x34x34x640xf32>)
// CHECK-SAME: outs(%[[EXTRACTED_SLICE_0]] : tensor<1x?x4xf32>) -> tensor<1x?x4xf32>
// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[IM2COL]] into %[[ARG6]]
// CHECK-SAME: [%[[ARG1]], %[[ARG3]], %[[ARG5]]] [1, %[[MSIZE]], 4] [1, 1, 1]
// CHECK-SAME: tensor<1x?x4xf32> into tensor<2x1024x5760xf32>
// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x1024x5760xf32>
-// CHECK: }
// CHECK: scf.yield %[[RES2]] : tensor<2x1024x5760xf32>
-// CHECK: }
// CHECK: scf.yield %[[RES1]] : tensor<2x1024x5760xf32>
-// CHECK: }
// CHECK: return %[[RES0]] : tensor<2x1024x5760xf32>
// -----
@@ -802,7 +801,8 @@
func.func @im2col_transposed_m_pos(%arg0: tensor<640x2x101x172xf32>) -> tensor<2x1024x5760xf32> {
%0 = tensor.empty() : tensor<2x1024x5760xf32>
%1 = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
- m_offset = [42] k_offset = [7] batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+ m_offset = [42] * [1] k_offset = [7] * [1]
+ batch_pos = [1] m_pos = [3, 2] k_pos = [0]
ins(%arg0 : tensor<640x2x101x172xf32>)
outs(%0 : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
return %1 : tensor<2x1024x5760xf32>
@@ -828,11 +828,11 @@
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x1024x5760xf32>
// CHECK: %[[RES0:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x1024x5760xf32>)
// CHECK: %[[RES1:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C9]]
-// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x1024x5760xf32>)
// CHECK: %[[RES2:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C5760]] step %[[C7]]
-// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x1024x5760xf32>) {
+// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x1024x5760xf32>)
// CHECK-DAG: %[[MSIZE:.+]] = affine.min #[[MAP]](%[[ARG3]])
// CHECK-DAG: %[[KSIZE:.+]] = affine.min #[[MAP1]](%[[ARG5]])
// CHECK-DAG: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[ARG1]], 0, 0]
@@ -842,18 +842,16 @@
// CHECK-DAG: %[[KOFFSET:.+]] = affine.apply #[[MAP2]](%[[ARG5]])
// CHECK-DAG: %[[MOFFSET:.+]] = affine.apply #[[MAP3]](%[[ARG3]])
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col strides = [5, 3] dilations = [4, 7] kernel_size = [5, 2]
-// CHECK-SAME: m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [1] m_pos = [3, 2] k_pos = [0]
+// CHECK-SAME: m_offset = [%[[MOFFSET]]] * [1] k_offset = [%[[KOFFSET]]] * [1]
+// CHECK-SAME: batch_pos = [1] m_pos = [3, 2] k_pos = [0]
// CHECK-SAME: ins(%[[EXTRACTED_SLICE]] : tensor<640x1x101x172xf32>)
// CHECK-SAME: outs(%[[EXTRACTED_SLICE_0]] : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[IM2COL]] into %[[ARG6]]
// CHECK-SAME: [%[[ARG1]], %[[ARG3]], %[[ARG5]]] [1, %[[MSIZE]], %[[KSIZE]]] [1, 1, 1]
// CHECK-SAME: tensor<1x?x?xf32> into tensor<2x1024x5760xf32>
// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x1024x5760xf32>
-// CHECK: }
// CHECK: scf.yield %[[RES2]] : tensor<2x1024x5760xf32>
-// CHECK: }
// CHECK: scf.yield %[[RES1]] : tensor<2x1024x5760xf32>
-// CHECK: }
// CHECK: return %[[RES0]] : tensor<2x1024x5760xf32>
// -----
@@ -862,7 +860,8 @@
%mOffset: index, %kOffset: index) -> tensor<?x?x?xf32> {
%0 = tensor.empty(%s0, %s1, %s2) : tensor<?x?x?xf32>
%1 = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
- m_offset = [%mOffset] k_offset = [%kOffset] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ m_offset = [%mOffset] * [1] k_offset = [%kOffset] * [1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%arg0 : tensor<?x?x?x?xf32>)
outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
@@ -888,11 +887,11 @@
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[D0:.+]] = tensor.empty(%[[S0]], %[[S1]], %[[S2]]) : tensor<?x?x?xf32>
// CHECK: %[[RES0:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[S0]] step %[[C2]]
-// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<?x?x?xf32>) {
+// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<?x?x?xf32>)
// CHECK: %[[RES1:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[S1]] step %[[C7]]
-// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<?x?x?xf32>) {
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<?x?x?xf32>)
// CHECK: %[[RES2:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[S2]] step %[[C5]]
-// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<?x?x?xf32>) {
+// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<?x?x?xf32>)
// CHECK-DAG: %[[BSIZE:.+]] = affine.min #[[MAP]](%[[ARG1]])
// CHECK-DAG: %[[MSIZE:.+]] = affine.min #[[MAP1]](%[[ARG3]])
// CHECK-DAG: %[[KSIZE:.+]] = affine.min #[[MAP2]](%[[ARG5]])
@@ -906,22 +905,89 @@
// CHECK-DAG: %[[KOFFSET:.+]] = affine.apply #[[MAP3]](%[[ARG5]])[%[[KOFF]]]
// CHECK-DAG: %[[MOFFSET:.+]] = affine.apply #[[MAP3]](%[[ARG3]])[%[[MOFF]]]
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
-// CHECK-SAME: m_offset = [%[[MOFFSET]]] k_offset = [%[[KOFFSET]]] batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME: m_offset = [%[[MOFFSET]]] * [1] k_offset = [%[[KOFFSET]]] * [1]
+// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
// CHECK-SAME: ins(%[[EXTRACTED_SLICE]] : tensor<?x?x?x?xf32>)
// CHECK-SAME: outs(%[[EXTRACTED_SLICE_0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[IM2COL]] into %[[ARG6]]
// CHECK-SAME: [%[[ARG1]], %[[ARG3]], %[[ARG5]]] [%[[BSIZE]], %[[MSIZE]], %[[KSIZE]]] [1, 1, 1]
// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?x?xf32>
// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<?x?x?xf32>
-// CHECK: }
// CHECK: scf.yield %[[RES2]] : tensor<?x?x?xf32>
-// CHECK: }
// CHECK: scf.yield %[[RES1]] : tensor<?x?x?xf32>
-// CHECK: }
// CHECK: return %[[RES0]] : tensor<?x?x?xf32>
// -----
+module {
+ func.func @im2col_expanded(%arg0: tensor<2x34x34x640xf32>, %m_stride: index) -> tensor<2x32x32x1440x4xf32> {
+ %0 = tensor.empty() : tensor<2x32x32x1440x4xf32>
+ %7 = iree_linalg_ext.im2col
+ strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+ m_offset = [0, 0] * [%m_stride, 1] k_offset = [0, 0] * [4, 1]
+ batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+ ins(%arg0 : tensor<2x34x34x640xf32>)
+ outs(%0 : tensor<2x32x32x1440x4xf32>) -> tensor<2x32x32x1440x4xf32>
+ return %7 : tensor<2x32x32x1440x4xf32>
+ }
+}
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["iree_linalg_ext.im2col"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ %1, %loops:5 = transform.structured.tile_using_for %0 tile_sizes [1, 7, 5, 11, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 32, 7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 32, 5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (-d0 + 1440, 11)>
+// CHECK: func.func @im2col_expanded
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x34x34x640xf32>
+// CHECK-SAME: %[[M_STRIDE:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C11:.+]] = arith.constant 11 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C1440:.+]] = arith.constant 1440 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x32x32x1440x4xf32>
+// CHECK: %[[RES0:.+]] = scf.for %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[ARG2:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK: %[[RES1:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C7]]
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK: %[[RES2:.+]] = scf.for %[[ARG5:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C32]] step %[[C5]]
+// CHECK-SAME: iter_args(%[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG4]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK: %[[RES3:.+]] = scf.for %[[ARG7:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1440]] step %[[C11]]
+// CHECK-SAME: iter_args(%[[ARG8:[a-zA-Z0-9_]+]] = %[[ARG6]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK: %[[RES4:.+]] = scf.for %[[ARG9:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C2]]
+// CHECK-SAME: iter_args(%[[ARG10:[a-zA-Z0-9_]+]] = %[[ARG8]]) -> (tensor<2x32x32x1440x4xf32>)
+// CHECK-DAG: %[[M0SIZE:.+]] = affine.min #[[MAP]](%[[ARG3]])
+// CHECK-DAG: %[[M1SIZE:.+]] = affine.min #[[MAP1]](%[[ARG5]])
+// CHECK-DAG: %[[K0SIZE:.+]] = affine.min #[[MAP2]](%[[ARG7]])
+// CHECK-DAG: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, 0, 0]
+// CHECK-SAME: [1, 34, 34, 640] [1, 1, 1, 1] : tensor<2x34x34x640xf32> to tensor<1x34x34x640xf32>
+// CHECK-DAG: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG10]][%[[ARG1]], %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]]
+// CHECK-SAME: [1, %[[M0SIZE]], %[[M1SIZE]], %[[K0SIZE]], 2] [1, 1, 1, 1, 1] : tensor<2x32x32x1440x4xf32> to tensor<1x?x?x?x2xf32>
+// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+// CHECK-SAME: m_offset = [%[[ARG3]], %[[ARG5]]] * [%[[M_STRIDE]], 1] k_offset = [%[[ARG7]], %[[ARG9]]] * [4, 1]
+// CHECK-SAME: batch_pos = [0] m_pos = [1, 2] k_pos = [3]
+// CHECK-SAME: ins(%[[EXTRACTED_SLICE]] : tensor<1x34x34x640xf32>)
+// CHECK-SAME: outs(%[[EXTRACTED_SLICE_0]] : tensor<1x?x?x?x2xf32>) -> tensor<1x?x?x?x2xf32>
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[IM2COL]] into %[[ARG10]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [1, %[[M0SIZE]], %[[M1SIZE]], %[[K0SIZE]], 2] [1, 1, 1, 1, 1]
+// CHECK-SAME: tensor<1x?x?x?x2xf32> into tensor<2x32x32x1440x4xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x32x32x1440x4xf32>
+// CHECK: scf.yield %[[RES4]] : tensor<2x32x32x1440x4xf32>
+// CHECK: scf.yield %[[RES3]] : tensor<2x32x32x1440x4xf32>
+// CHECK: scf.yield %[[RES2]] : tensor<2x32x32x1440x4xf32>
+// CHECK: scf.yield %[[RES1]] : tensor<2x32x32x1440x4xf32>
+// CHECK: return %[[RES0]] : tensor<2x32x32x1440x4xf32>
+
+// -----
+
func.func @winograd_filter_transform(%arg0: tensor<3x3x64x128xf32>) -> tensor<8x8x64x128xf32> {
%0 = tensor.empty() : tensor<8x8x64x128xf32>
%1 = iree_linalg_ext.winograd.filter_transform
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
index 0da302b..0212d6e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
@@ -25,6 +25,7 @@
],
deps = [
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
index b4c519c..564ba5b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
@@ -22,6 +22,7 @@
"Utils.cpp"
DEPS
LLVMSupport
+ MLIRAffineDialect
MLIRArithDialect
MLIRIR
MLIRLinalgDialect
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index ed9524d..1aa2efa 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -8,6 +8,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -16,6 +17,22 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
+OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+                     OpFoldResult b) {
+  AffineExpr d0, d1;
+  bindDims(builder.getContext(), d0, d1);
+  auto addMap = AffineMap::get(2, 0, {d0 + d1});  // Map (d0, d1) -> d0 + d1.
+  return affine::makeComposedFoldedAffineApply(builder, loc, addMap, {a, b});  // Composes and folds; yields an attribute when statically computable.
+}
+
+OpFoldResult mulOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+                     OpFoldResult b) {
+  AffineExpr d0, d1;
+  bindDims(builder.getContext(), d0, d1);
+  auto mulMap = AffineMap::get(2, 0, {d0 * d1});  // Map (d0, d1) -> d0 * d1.
+  return affine::makeComposedFoldedAffineApply(builder, loc, mulMap, {a, b});  // Composes and folds; yields an attribute when statically computable.
+}
+
Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) {
ShapedType type = cast<ShapedType>(v.getType());
if (!type.isDynamicDim(dim)) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index 4d1b986..a89cb67 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -20,6 +20,14 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
+// Helper that adds two OpFoldResult values via a composed, folded affine.apply.
+OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+ OpFoldResult b);
+
+// Helper that multiplies two OpFoldResult values via a composed, folded affine.apply.
+OpFoldResult mulOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
+ OpFoldResult b);
+
/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
/// `dim`.
Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim);