[Winograd] Generate winograd.filter_transform op in ConvertConv2DToWinograd (#17106)
This PR changes the `ConvertConv2DToWinograd` pass to generate a
`winograd.filter_transform` op instead of constant-folding the filter
transform within the pass itself. Deferring the fold allows faster and
more reliable constant folding (or const-expr hoisting) of the filter
transform later in the compilation pipeline.
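For a 3x3 HWCF filter, the rewrite now emits the transform op followed by a
collapse that feeds the batched matmul directly. A minimal sketch of the
emitted IR, distilled from the lit tests below (value names are illustrative;
NCHW/FCHW convolutions use `kernel_dimensions([2, 3])` instead):

```mlir
// Winograd filter transform for output_tile_size = 6 and kernel_size = 3,
// so the input tile size is 6 + 3 - 1 = 8.
%init = tensor.empty() : tensor<8x8x4x16xf32>
%t = iree_linalg_ext.winograd.filter_transform
       output_tile_size(6) kernel_size(3) kernel_dimensions([0, 1])
       ins(%filter : tensor<3x3x4x16xf32>)
       outs(%init : tensor<8x8x4x16xf32>) -> tensor<8x8x4x16xf32>
// Collapse the two tile dimensions so the result feeds linalg.batch_matmul.
%collapsed = tensor.collapse_shape %t [[0, 1], [2], [3]]
    : tensor<8x8x4x16xf32> into tensor<64x4x16xf32>
```

When the filter is constant, a later const-eval pass can fold or hoist this op
once, rather than the pattern performing the O(ic * oc) fold eagerly at
rewrite time.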
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
index 6cabb8b..6ce35e1 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
@@ -44,65 +44,6 @@
// for more tile sizes
static constexpr int64_t outputTileSize = 6;
-/// This function computes the Winograd filter transform when
-/// the filter is known to be a constant. Specifically, this
-/// function computes matmul(G, matmul(F, transpose(G))) where
-/// F is a tile of the convolution filter of size m x m
-/// (single input channel, single output channel) and G has
-/// shape m x (m + r - 1) where r is the output tile size and
-/// (m + r - 1) is the input tile size.
-/// The time complexity of this function is O(ic * oc)
-/// where ic is the number of input channels and oc is the
-/// number of output channels since input tile size and kernel size
-/// are constants. So for large ic and oc, this function is
-/// time intensive.
-/// TODO: Codegen this as a kernel and run once at initialization
-static DenseElementsAttr
-foldFilterTransform(ArrayRef<int64_t> shape, int64_t inputTileSize,
- int64_t kernelSize, ShapedType outputType, const float *G,
- bool isSplat, float splatValue,
- DenseElementsAttr::iterator_range<APFloat> &input,
- FloatType floatType, bool isNchw) {
- const int &kh = isNchw ? shape[2] : shape[0];
- const int &kw = isNchw ? shape[3] : shape[1];
- const int &ic = isNchw ? shape[1] : shape[2];
- const int &oc = isNchw ? shape[0] : shape[3];
- const int64_t numElements = inputTileSize * inputTileSize * ic * oc;
- SmallVector<APFloat> output(numElements, APFloat(0.0f));
- for (int d0 = 0; d0 < inputTileSize; d0++) {
- for (int d1 = 0; d1 < inputTileSize; d1++) {
- for (int d2 = 0; d2 < ic; d2++) {
- for (int d3 = 0; d3 < oc; d3++) {
- APFloat accum(0.0f);
- for (int d4 = 0; d4 < kernelSize; d4++) {
- for (int d5 = 0; d5 < kernelSize; d5++) {
- APFloat ival(splatValue);
- if (!isSplat) {
- if (!isNchw) {
- ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)];
- } else {
- ival = input[index(d3, d2, d4, d5, oc, ic, kh, kw)];
- }
- }
- int idx0 = index(d0, d4, inputTileSize, kernelSize);
- int idx1 = index(d1, d5, inputTileSize, kernelSize);
- accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]);
- }
- }
- int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc);
- output[odx] = accum;
- if (floatType.isF16()) {
- bool losesInfo;
- output[odx].convert(APFloat::IEEEhalf(),
- APFloat::rmNearestTiesToEven, &losesInfo);
- }
- }
- }
- }
- }
- return DenseElementsAttr::get(outputType, output);
-}
-
template <typename T>
static bool hasValidStridesAndDilations(Operation *op) {
auto convOp = dyn_cast<T>(op);
@@ -128,73 +69,6 @@
: hasValidStridesAndDilations<linalg::Conv2DNhwcHwcfOp>(op));
}
-namespace {
-
-template <typename ConvOp>
-class FoldWinogradFilterTransform final : public OpRewritePattern<ConvOp> {
-public:
- using OpRewritePattern<ConvOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ConvOp convOp,
- PatternRewriter &rewriter) const override {
-
- bool isNchw;
- if (!isValidConv2d(convOp, isNchw)) {
- return failure();
- }
-
- // Check that kernel size = 3x3
- Value kernel = convOp.getInputs()[1];
- auto kernelType = cast<ShapedType>(kernel.getType());
- if (!kernelType) {
- return failure();
- }
- ArrayRef<int64_t> kernelShape = kernelType.getShape();
- if (kernelShape.size() != 4) {
- return failure();
- }
- const int64_t kh = isNchw ? kernelShape[2] : kernelShape[0];
- const int64_t kw = isNchw ? kernelShape[3] : kernelShape[1];
- if ((kh != 3) || (kw != 3)) {
- return failure();
- }
- const int64_t kernelSize = kh;
- const int64_t inputTileSize = outputTileSize + kernelSize - 1;
-
- DenseIntOrFPElementsAttr kernelAttr;
- if (!matchPattern(kernel, m_Constant(&kernelAttr))) {
- return failure();
- }
-
- Operation *constOp = kernel.getDefiningOp();
- ShapedType type = cast<ShapedType>(constOp->getResult(0).getType());
- auto elemType = cast<FloatType>(type.getElementType());
- ArrayRef<int64_t> shape = type.getShape();
- DenseElementsAttr::iterator_range<APFloat> nonSplatValues =
- kernelAttr.getValues<APFloat>();
- bool isSplat = kernelAttr.isSplat();
- float splatValue{0.0};
- if (isSplat) {
- splatValue = kernelAttr.getSplatValue<APFloat>().convertToFloat();
- }
- SmallVector<int64_t> resultShape{inputTileSize * inputTileSize, shape[2],
- shape[3]};
- if (isNchw) {
- resultShape[1] = shape[1];
- resultShape[2] = shape[0];
- }
- auto resultType = RankedTensorType::get(resultShape, elemType);
- auto foldedKernelAttr =
- foldFilterTransform(shape, inputTileSize, kernelSize, resultType,
- IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat,
- splatValue, nonSplatValues, elemType, isNchw);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, foldedKernelAttr);
- return success();
- }
-};
-
-} // namespace
-
static Value
createCollapse(Value tensor, Location loc, PatternRewriter &rewriter,
SmallVectorImpl<int64_t> &outputShape,
@@ -291,28 +165,61 @@
LogicalResult matchAndRewrite(ConvOp convOp,
PatternRewriter &rewriter) const override {
- bool isNchw;
- if (!isValidConv2d(convOp, isNchw)) {
+ bool isNchwFchw;
+ if (!isValidConv2d(convOp, isNchwFchw)) {
return failure();
}
- // Check that kernel has been constant folded (by validating rank = 3)
+ // Create winograd filter transform op.
Value kernel = convOp.getInputs()[1];
auto kernelType = cast<ShapedType>(kernel.getType());
if (!kernelType) {
return failure();
}
- Type elementType = kernelType.getElementType();
- ArrayRef<int64_t> kernelShape = kernelType.getShape();
- if (kernelShape.size() != 3) {
- return failure();
+ if (!kernelType.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(convOp, "Kernel shape is not static");
}
+ SmallVector<int64_t> kernelShape(kernelType.getShape());
+ assert(kernelShape.size() == 4 && "expected 4D kernel shape");
+ const int64_t kh = isNchwFchw ? kernelShape[2] : kernelShape[0];
+ const int64_t kw = isNchwFchw ? kernelShape[3] : kernelShape[1];
+ if (kh != 3 || kw != 3) {
+ return rewriter.notifyMatchFailure(convOp,
+ "Winograd only supports 3x3 filters");
+ }
+ Type elementType = kernelType.getElementType();
const int64_t kernelSize = 3;
const int64_t inputTileSize = outputTileSize + kernelSize - 1;
- // Create winograd input transform op
Location loc = convOp.getLoc();
+ const std::array<int64_t, 2> hwcfKernelDims = {0, 1};
+ const std::array<int64_t, 2> fchwKernelDims = {2, 3};
+ SmallVector<int64_t> filterResultShape(4, inputTileSize);
+ filterResultShape[2] = isNchwFchw ? kernelShape[1] : kernelShape[2];
+ filterResultShape[3] = isNchwFchw ? kernelShape[0] : kernelShape[3];
+ Value kernelInit =
+ rewriter.create<tensor::EmptyOp>(loc, filterResultShape, elementType);
+ const std::array<int64_t, 2> kernelDims =
+ isNchwFchw ? fchwKernelDims : hwcfKernelDims;
+ Value winogradFilter =
+ rewriter
+ .create<IREE::LinalgExt::WinogradFilterTransformOp>(
+ loc, kernelInit.getType(), ValueRange{kernel},
+ ValueRange{kernelInit}, outputTileSize, kernelSize, kernelDims)
+ .getResults()[0];
+
+ // Collapse the transformed filter to 3D for the batch matmul.
+ SmallVector<int64_t> collapsedFilterShape;
+ collapsedFilterShape.push_back(filterResultShape[0] * filterResultShape[1]);
+ collapsedFilterShape.push_back(filterResultShape[2]);
+ collapsedFilterShape.push_back(filterResultShape[3]);
+ SmallVector<ReassociationIndices> filterReassociations = {{0, 1}, {2}, {3}};
+ Value collapsedWinogradFilter =
+ createCollapse(winogradFilter, loc, rewriter, collapsedFilterShape,
+ filterReassociations);
+
+ // Create winograd input transform op.
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
Value input = convOp.getInputs()[0];
@@ -322,23 +229,23 @@
}
SmallVector<int64_t> inputShape(inputType.getShape());
if (llvm::any_of(inputShape, ShapedType::isDynamic)) {
- return failure();
+ return rewriter.notifyMatchFailure(convOp, "Input shape is not static");
}
assert(inputShape.size() == 4);
- if (isNchw) {
+ if (isNchwFchw) {
permute<IREE::LinalgExt::Permutation::NCHW_TO_NHWC>(inputShape);
}
- const std::array<int64_t, 2> nhwcImageDimensions{1, 2};
- const std::array<int64_t, 2> nchwImageDimensions{2, 3};
- const size_t numImageDims = nhwcImageDimensions.size();
+ const std::array<int64_t, 2> nhwcImageDims = {1, 2};
+ const std::array<int64_t, 2> nchwImageDims = {2, 3};
+ const size_t numImageDims = nhwcImageDims.size();
SmallVector<int64_t> resultShape(6, inputTileSize);
- llvm::SmallSetVector<int64_t, 2> imageDimensionsSet(
- nhwcImageDimensions.begin(), nhwcImageDimensions.end());
+ llvm::SmallSetVector<int64_t, 2> imageDimsSet(nhwcImageDims.begin(),
+ nhwcImageDims.end());
int outputIndex;
for (int i = 0; i < inputShape.size(); i++) {
outputIndex = i + numImageDims;
- if (!imageDimensionsSet.contains(i)) {
+ if (!imageDimsSet.contains(i)) {
resultShape[outputIndex] = inputShape[i];
} else {
resultShape[outputIndex] =
@@ -347,13 +254,14 @@
}
Value emptyTensor =
rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
- auto &imageDimensions = isNchw ? nchwImageDimensions : nhwcImageDimensions;
- auto winogradInputOp =
- rewriter.create<IREE::LinalgExt::WinogradInputTransformOp>(
- loc, emptyTensor.getType(), ValueRange{input},
- ValueRange{emptyTensor}, outputTileSize, kernelSize,
- imageDimensions);
- Value winogradInput = winogradInputOp.getResult()[0];
+ const std::array<int64_t, 2> imageDims =
+ isNchwFchw ? nchwImageDims : nhwcImageDims;
+ Value winogradInput =
+ rewriter
+ .create<IREE::LinalgExt::WinogradInputTransformOp>(
+ loc, emptyTensor.getType(), ValueRange{input},
+ ValueRange{emptyTensor}, outputTileSize, kernelSize, imageDims)
+ .getResults()[0];
// Add collapse shape
SmallVector<int64_t> collapsedShape = {
@@ -368,7 +276,7 @@
Value output = convOp.getOutputs()[0];
auto outputType = cast<RankedTensorType>(output.getType());
SmallVector<int64_t> outputShape(outputType.getShape());
- if (isNchw) {
+ if (isNchwFchw) {
permute<IREE::LinalgExt::Permutation::NCHW_TO_NHWC>(outputShape);
}
bmmShape[2] = outputShape[3];
@@ -377,7 +285,8 @@
auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor});
auto bmmOp = rewriter.create<linalg::BatchMatmulOp>(
- loc, bmmOutputType, ValueRange({collapsedWinogradInput, kernel}),
+ loc, bmmOutputType,
+ ValueRange({collapsedWinogradInput, collapsedWinogradFilter}),
ValueRange({fillOp.result()}));
Value bmmResult = bmmOp.getResult(0);
@@ -392,37 +301,35 @@
// Convert back into original domain
SmallVector<int64_t> paddedResultShape(outputShape.size(), 0);
for (int i = 0; i < outputShape.size(); i++) {
- if (!imageDimensionsSet.contains(i)) {
+ if (!imageDimsSet.contains(i)) {
paddedResultShape[i] = outputShape[i];
} else {
paddedResultShape[i] = resultShape[i + numImageDims] * outputTileSize;
}
}
- if (isNchw) {
+ if (isNchwFchw) {
permute<IREE::LinalgExt::Permutation::NHWC_TO_NCHW>(paddedResultShape);
}
emptyTensor =
rewriter.create<tensor::EmptyOp>(loc, paddedResultShape, elementType);
- auto winogradOutputOp =
- rewriter.create<IREE::LinalgExt::WinogradOutputTransformOp>(
- loc, emptyTensor.getType(), ValueRange{expandedBmmResult},
- ValueRange{emptyTensor}, outputTileSize, kernelSize,
- imageDimensions);
- Value paddedOutput = winogradOutputOp.getResult()[0];
+ Value paddedOutput =
+ rewriter
+ .create<IREE::LinalgExt::WinogradOutputTransformOp>(
+ loc, emptyTensor.getType(), ValueRange{expandedBmmResult},
+ ValueRange{emptyTensor}, outputTileSize, kernelSize, imageDims)
+ .getResults()[0];
// Extract slice
SmallVector<OpFoldResult> offsets(outputShape.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(outputShape.size(),
rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> sizes;
- for (const int64_t shape : outputType.getShape())
- sizes.push_back(rewriter.getIndexAttr(shape));
+ SmallVector<OpFoldResult> sizes =
+ getAsIndexOpFoldResult(rewriter.getContext(), outputType.getShape());
auto winogradOutput = rewriter.create<tensor::ExtractSliceOp>(
loc, outputType, paddedOutput, offsets, sizes, strides);
- Value result = convOp.getResult(0);
- result.replaceAllUsesWith(winogradOutput);
+ rewriter.replaceOp(convOp, winogradOutput);
return success();
}
};
@@ -436,9 +343,7 @@
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
- patterns.insert<FoldWinogradFilterTransform<linalg::Conv2DNchwFchwOp>,
- FoldWinogradFilterTransform<linalg::Conv2DNhwcHwcfOp>,
- ConvertConvToWinograd<linalg::Conv2DNhwcHwcfOp>,
+ patterns.insert<ConvertConvToWinograd<linalg::Conv2DNhwcHwcfOp>,
ConvertConvToWinograd<linalg::Conv2DNchwFchwOp>>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_winograd.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_winograd.mlir
index 42b781d..b5be6e5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_winograd.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_winograd.mlir
@@ -1,146 +1,87 @@
// RUN: iree-opt --split-input-file -iree-linalg-ext-convert-conv2d-to-winograd -mlir-elide-elementsattrs-if-larger=4 %s | FileCheck %s
-func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
- %c0 = arith.constant dense<0.1> : tensor<3x3x4x16xf32>
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_hwcf
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
- ins(%arg0, %c0: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
-// CHECK: func.func @conv_16433136(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
-// CHECK: %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x4x16xf32>
-// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x16x16x4xf32>) outs(%[[D0]] :
+// CHECK: func.func @conv_16433136(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<3x3x4x16xf32>
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[EMPTY0:.+]] = tensor.empty() : tensor<8x8x4x16xf32>
+// CHECK: %[[FILTER_TF:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: kernel_dimensions([0, 1]) ins(%[[ARG1]] : tensor<3x3x4x16xf32>) outs(%[[EMPTY0]] :
+// CHECK-SAME: tensor<8x8x4x16xf32>) -> tensor<8x8x4x16xf32>
+// CHECK: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER_TF]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]]
+// CHECK-SAME: tensor<8x8x4x16xf32> into tensor<64x4x16xf32>
+// CHECK: %[[EMPTY1:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
+// CHECK: %[[INPUT_TF:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x16x16x4xf32>) outs(%[[EMPTY1]] :
// CHECK-SAME: tensor<8x8x1x3x3x4xf32>) -> tensor<8x8x1x3x3x4xf32>
-// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]]
+// CHECK: %[[COLLAPSED_INPUT:.+]] = tensor.collapse_shape %[[INPUT_TF]]
// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
// CHECK-SAME: tensor<8x8x1x3x3x4xf32> into tensor<64x9x4xf32>
-// CHECK: %[[D2:.+]] = tensor.empty() : tensor<64x9x16xf32>
-// CHECK: %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x9x16xf32>) ->
+// CHECK: %[[EMPTY2:.+]] = tensor.empty() : tensor<64x9x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY2]] : tensor<64x9x16xf32>) ->
// CHECK-SAME: tensor<64x9x16xf32>
-// CHECK: %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x9x4xf32>,
-// CHECK-SAME: tensor<64x4x16xf32>) outs(%[[D3]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]]
+// CHECK: %[[BMM:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_INPUT]], %[[COLLAPSED_FILTER]] : tensor<64x9x4xf32>,
+// CHECK-SAME: tensor<64x4x16xf32>) outs(%[[FILL]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[BMM]]
// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
// CHECK-SAME: tensor<64x9x16xf32> into tensor<8x8x1x3x3x16xf32>
-// CHECK: %[[D5:.+]] = tensor.empty() : tensor<1x18x18x16xf32>
-// CHECK: %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[D5]] :
+// CHECK: %[[EMPTY3:.+]] = tensor.empty() : tensor<1x18x18x16xf32>
+// CHECK: %[[OUTPUT_TF:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[EMPTY3]] :
// CHECK-SAME: tensor<1x18x18x16xf32>) -> tensor<1x18x18x16xf32>
-// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] :
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[OUTPUT_TF]][0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] :
// CHECK-SAME: tensor<1x18x18x16xf32> to tensor<1x14x14x16xf32>
// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x14x14x16xf32>
// CHECK: }
// -----
-func.func @conv2d_non_splat_weights(%inputs : tensor<1x4x4x1xf32>, %arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
- %c0 = arith.constant dense<[[ [[1.0]], [[3.0]], [[5.0]] ],
- [ [[7.0]], [[9.0]], [[11.0]] ],
- [ [[13.0]], [[15.0]], [[17.0]] ]]> : tensor<3x3x1x1xf32>
- %0 = linalg.conv_2d_nhwc_hwcf
- {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
- ins(%inputs, %c0: tensor<1x4x4x1xf32>, tensor<3x3x1x1xf32>)
- outs(%arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
- return %0 : tensor<1x2x2x1xf32>
-}
-// CHECK: func.func @conv2d_non_splat_weights(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x4x4x1xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
-// CHECK: %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x1x1xf32>
-// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D0:.+]] = tensor.empty() : tensor<8x8x1x1x1x1xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x4x4x1xf32>) outs(%[[D0]] : tensor<8x8x1x1x1x1xf32>)
-// CHECK-SAME: -> tensor<8x8x1x1x1x1xf32>
-// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]]
-// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
-// CHECK-SAME: tensor<8x8x1x1x1x1xf32> into tensor<64x1x1xf32>
-// CHECK: %[[D2:.+]] = tensor.empty() : tensor<64x1x1xf32>
-// CHECK: %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x1x1xf32>) ->
-// CHECK-SAME: tensor<64x1x1xf32>
-// CHECK: %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x1x1xf32>, tensor<64x1x1xf32>)
-// CHECK-SAME: outs(%[[D3]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]]
-// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
-// CHECK-SAME: tensor<64x1x1xf32> into tensor<8x8x1x1x1x1xf32>
-// CHECK: %[[D5:.+]] = tensor.empty() : tensor<1x6x6x1xf32>
-// CHECK: %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x1x1x1xf32>) outs(%[[D5]] :
-// CHECK-SAME: tensor<1x6x6x1xf32>) -> tensor<1x6x6x1xf32>
-// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 2, 2, 1] [1, 1, 1, 1] :
-// CHECK-SAME: tensor<1x6x6x1xf32> to tensor<1x2x2x1xf32>
-// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x2x2x1xf32>
-// CHECK: }
-
-// -----
-
-func.func @conv_16433136_nchw_fchw(%arg0: tensor<1x4x16x16xf32>, %arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> {
- %c0 = arith.constant dense<0.1> : tensor<16x4x3x3xf32>
+func.func @conv_16433136_nchw_fchw(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> {
%0 = linalg.conv_2d_nchw_fchw
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
- ins(%arg0, %c0: tensor<1x4x16x16xf32>, tensor<16x4x3x3xf32>)
+ ins(%arg0, %arg1: tensor<1x4x16x16xf32>, tensor<16x4x3x3xf32>)
outs(%arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>
return %0 : tensor<1x16x14x14xf32>
}
-// CHECK: func.func @conv_16433136_nchw_fchw(%[[ARG0]]: tensor<1x4x16x16xf32>, %[[ARG1]]: tensor<1x16x14x14xf32>)
-// CHECK-SAME: -> tensor<1x16x14x14xf32> {
-// CHECK: %[[CST]] = arith.constant dense_resource<__elided__> : tensor<64x4x16xf32>
-// CHECK: %[[CST_0]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D0]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
-// CHECK: %[[D1]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x4x16x16xf32>) outs(%[[D0]] :
+// CHECK: func.func @conv_16433136_nchw_fchw(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x4x16x16xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<16x4x3x3xf32>
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[EMPTY0:.+]] = tensor.empty() : tensor<8x8x4x16xf32>
+// CHECK: %[[FILTER_TF:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: kernel_dimensions([2, 3]) ins(%[[ARG1]] : tensor<16x4x3x3xf32>) outs(%[[EMPTY0]] :
+// CHECK-SAME: tensor<8x8x4x16xf32>) -> tensor<8x8x4x16xf32>
+// CHECK: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER_TF]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]]
+// CHECK-SAME: tensor<8x8x4x16xf32> into tensor<64x4x16xf32>
+// CHECK: %[[EMPTY1:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
+// CHECK: %[[INPUT_TF:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x4x16x16xf32>) outs(%[[EMPTY1]] :
// CHECK-SAME: tensor<8x8x1x3x3x4xf32>) -> tensor<8x8x1x3x3x4xf32>
-// CHECK: %[[COLLAPSED]] = tensor.collapse_shape %[[D1]]
-// CHECK: %[[D2]] = tensor.empty() : tensor<64x9x16xf32>
-// CHECK: %[[D3]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
-// CHECK: %[[D4]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x9x4xf32>, tensor<64x4x16xf32>)
-// CHECK-SAME: outs(%[[D3]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
-// CHECK: %[[EXPANDED]] = tensor.expand_shape %[[D4]]
-// CHECK: %[[D5]] = tensor.empty() : tensor<1x16x18x18xf32>
-// CHECK: %[[D6]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([2, 3]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[D5]] :
+// CHECK: %[[COLLAPSED_INPUT:.+]] = tensor.collapse_shape %[[INPUT_TF]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME: tensor<8x8x1x3x3x4xf32> into tensor<64x9x4xf32>
+// CHECK: %[[EMPTY2:.+]] = tensor.empty() : tensor<64x9x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY2]] : tensor<64x9x16xf32>) ->
+// CHECK-SAME: tensor<64x9x16xf32>
+// CHECK: %[[BMM:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_INPUT]], %[[COLLAPSED_FILTER]] : tensor<64x9x4xf32>,
+// CHECK-SAME: tensor<64x4x16xf32>) outs(%[[FILL]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[BMM]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME: tensor<64x9x16xf32> into tensor<8x8x1x3x3x16xf32>
+// CHECK: %[[EMPTY3:.+]] = tensor.empty() : tensor<1x16x18x18xf32>
+// CHECK: %[[OUTPUT_TF:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME: image_dimensions([2, 3]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[EMPTY3]] :
// CHECK-SAME: tensor<1x16x18x18xf32>) -> tensor<1x16x18x18xf32>
-// CHECK: %[[EXTRACTED_SLICE]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 16, 14, 14] [1, 1, 1, 1] :
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[OUTPUT_TF]][0, 0, 0, 0] [1, 16, 14, 14] [1, 1, 1, 1] :
// CHECK-SAME: tensor<1x16x18x18xf32> to tensor<1x16x14x14xf32>
// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x16x14x14xf32>
// CHECK: }
-// CHECK: }
-
-// -----
-
-func.func @conv2d_nchw_non_splat_weights(%inputs : tensor<1x1x4x4xf32>, %arg2: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> {
- %c0 = arith.constant dense<[[[[ 1.0, 3.0, 5.0 ],
- [ 7.0, 9.0, 11.0 ],
- [ 13.0, 15.0, 17.0 ]]]]> : tensor<1x1x3x3xf32>
- %0 = linalg.conv_2d_nchw_fchw
- {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
- ins(%inputs, %c0: tensor<1x1x4x4xf32>, tensor<1x1x3x3xf32>)
- outs(%arg2: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
- return %0 : tensor<1x1x2x2xf32>
-}
-// CHECK: func.func @conv2d_nchw_non_splat_weights(%[[ARG0]]: tensor<1x1x4x4xf32>, %[[ARG1]]: tensor<1x1x2x2xf32>)
-// CHECK-SAME: -> tensor<1x1x2x2xf32> {
-// CHECK: %[[CST]] = arith.constant dense_resource<__elided__> : tensor<64x1x1xf32>
-// CHECK: %[[CST_0]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D0]] = tensor.empty() : tensor<8x8x1x1x1x1xf32>
-// CHECK: %[[D1]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x1x4x4xf32>) outs(%[[D0]] : tensor<8x8x1x1x1x1xf32>)
-// CHECK-SAME: -> tensor<8x8x1x1x1x1xf32>
-// CHECK: %[[COLLAPSED]] = tensor.collapse_shape %[[D1]]
-// CHECK: %[[D2]] = tensor.empty() : tensor<64x1x1xf32>
-// CHECK: %[[D3]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32>
-// CHECK: %[[D4]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x1x1xf32>, tensor<64x1x1xf32>)
-// CHECK-SAME: outs(%[[D3]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32>
-// CHECK: %[[EXPANDED]] = tensor.expand_shape %[[D4]]
-// CHECK: %[[D5]] = tensor.empty() : tensor<1x1x6x6xf32>
-// CHECK: %[[D6]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME: image_dimensions([2, 3]) ins(%[[EXPANDED]] : tensor<8x8x1x1x1x1xf32>) outs(%[[D5]] :
-// CHECK-SAME: tensor<1x1x6x6xf32>) -> tensor<1x1x6x6xf32>
-// CHECK: %[[EXTRACTED_SLICE]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 1, 2, 2] [1, 1, 1, 1] :
-// CHECK-SAME: tensor<1x1x6x6xf32> to tensor<1x1x2x2xf32>
-// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x1x2x2xf32>
-// CHECK: }