[Winograd] Generate winograd.filter_transform op in ConvertConv2DToWinograd (#17106)

This PR changes the `ConvertConv2DToWinograd` pass to generate the
`winograd.filter_transform` op instead of eagerly constant-folding the
filter transform inside the pass itself. Deferring the fold allows faster
and more reliable constant folding (or const-expr-hoisting) of the filter
transform later in the pipeline.
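For illustration, a minimal sketch of the IR the pass now emits for the filter
operand, assuming the NHWC/HWCF shapes from the test case below (SSA names are
hypothetical):

```mlir
// Transform the 3x3x4x16 filter into 8x8 Winograd tile space, then
// collapse the two tile dimensions so the result can feed the batch matmul.
%empty = tensor.empty() : tensor<8x8x4x16xf32>
%filter_tf = iree_linalg_ext.winograd.filter_transform
    output_tile_size(6) kernel_size(3) kernel_dimensions([0, 1])
    ins(%filter : tensor<3x3x4x16xf32>)
    outs(%empty : tensor<8x8x4x16xf32>) -> tensor<8x8x4x16xf32>
%collapsed = tensor.collapse_shape %filter_tf [[0, 1], [2], [3]]
    : tensor<8x8x4x16xf32> into tensor<64x4x16xf32>
```

With the transform expressed as an op, downstream constant folding or
const-expr-hoisting can evaluate it once, instead of the pass materializing a
dense constant while rewriting the convolution.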
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
index 6cabb8b..6ce35e1 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
@@ -44,65 +44,6 @@
 // for more tile sizes
 static constexpr int64_t outputTileSize = 6;
 
-/// This function computes the Winograd filter transform when
-/// the filter is known to be a constant. Specifically, this
-/// function computes matmul(G, matmul(F, transpose(G))) where
-/// F is a tile of the convolution filter of size m x m
-/// (single input channel, single output channel) and G has
-/// shape m x (m + r - 1) where r is the output tile size and
-/// (m + r - 1) is the input tile size.
-/// The time complexity of this function is O(ic * oc)
-/// where ic is the number of input channels and oc is the
-/// number of output channels since input tile size and kernel size
-/// are constants. So for large ic and oc, this function is
-/// time intensive.
-/// TODO: Codegen this as a kernel and run once at initialization
-static DenseElementsAttr
-foldFilterTransform(ArrayRef<int64_t> shape, int64_t inputTileSize,
-                    int64_t kernelSize, ShapedType outputType, const float *G,
-                    bool isSplat, float splatValue,
-                    DenseElementsAttr::iterator_range<APFloat> &input,
-                    FloatType floatType, bool isNchw) {
-  const int &kh = isNchw ? shape[2] : shape[0];
-  const int &kw = isNchw ? shape[3] : shape[1];
-  const int &ic = isNchw ? shape[1] : shape[2];
-  const int &oc = isNchw ? shape[0] : shape[3];
-  const int64_t numElements = inputTileSize * inputTileSize * ic * oc;
-  SmallVector<APFloat> output(numElements, APFloat(0.0f));
-  for (int d0 = 0; d0 < inputTileSize; d0++) {
-    for (int d1 = 0; d1 < inputTileSize; d1++) {
-      for (int d2 = 0; d2 < ic; d2++) {
-        for (int d3 = 0; d3 < oc; d3++) {
-          APFloat accum(0.0f);
-          for (int d4 = 0; d4 < kernelSize; d4++) {
-            for (int d5 = 0; d5 < kernelSize; d5++) {
-              APFloat ival(splatValue);
-              if (!isSplat) {
-                if (!isNchw) {
-                  ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)];
-                } else {
-                  ival = input[index(d3, d2, d4, d5, oc, ic, kh, kw)];
-                }
-              }
-              int idx0 = index(d0, d4, inputTileSize, kernelSize);
-              int idx1 = index(d1, d5, inputTileSize, kernelSize);
-              accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]);
-            }
-          }
-          int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc);
-          output[odx] = accum;
-          if (floatType.isF16()) {
-            bool losesInfo;
-            output[odx].convert(APFloat::IEEEhalf(),
-                                APFloat::rmNearestTiesToEven, &losesInfo);
-          }
-        }
-      }
-    }
-  }
-  return DenseElementsAttr::get(outputType, output);
-}
-
 template <typename T>
 static bool hasValidStridesAndDilations(Operation *op) {
   auto convOp = dyn_cast<T>(op);
@@ -128,73 +69,6 @@
                  : hasValidStridesAndDilations<linalg::Conv2DNhwcHwcfOp>(op));
 }
 
-namespace {
-
-template <typename ConvOp>
-class FoldWinogradFilterTransform final : public OpRewritePattern<ConvOp> {
-public:
-  using OpRewritePattern<ConvOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ConvOp convOp,
-                                PatternRewriter &rewriter) const override {
-
-    bool isNchw;
-    if (!isValidConv2d(convOp, isNchw)) {
-      return failure();
-    }
-
-    // Check that kernel size = 3x3
-    Value kernel = convOp.getInputs()[1];
-    auto kernelType = cast<ShapedType>(kernel.getType());
-    if (!kernelType) {
-      return failure();
-    }
-    ArrayRef<int64_t> kernelShape = kernelType.getShape();
-    if (kernelShape.size() != 4) {
-      return failure();
-    }
-    const int64_t kh = isNchw ? kernelShape[2] : kernelShape[0];
-    const int64_t kw = isNchw ? kernelShape[3] : kernelShape[1];
-    if ((kh != 3) || (kw != 3)) {
-      return failure();
-    }
-    const int64_t kernelSize = kh;
-    const int64_t inputTileSize = outputTileSize + kernelSize - 1;
-
-    DenseIntOrFPElementsAttr kernelAttr;
-    if (!matchPattern(kernel, m_Constant(&kernelAttr))) {
-      return failure();
-    }
-
-    Operation *constOp = kernel.getDefiningOp();
-    ShapedType type = cast<ShapedType>(constOp->getResult(0).getType());
-    auto elemType = cast<FloatType>(type.getElementType());
-    ArrayRef<int64_t> shape = type.getShape();
-    DenseElementsAttr::iterator_range<APFloat> nonSplatValues =
-        kernelAttr.getValues<APFloat>();
-    bool isSplat = kernelAttr.isSplat();
-    float splatValue{0.0};
-    if (isSplat) {
-      splatValue = kernelAttr.getSplatValue<APFloat>().convertToFloat();
-    }
-    SmallVector<int64_t> resultShape{inputTileSize * inputTileSize, shape[2],
-                                     shape[3]};
-    if (isNchw) {
-      resultShape[1] = shape[1];
-      resultShape[2] = shape[0];
-    }
-    auto resultType = RankedTensorType::get(resultShape, elemType);
-    auto foldedKernelAttr =
-        foldFilterTransform(shape, inputTileSize, kernelSize, resultType,
-                            IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat,
-                            splatValue, nonSplatValues, elemType, isNchw);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, foldedKernelAttr);
-    return success();
-  }
-};
-
-} // namespace
-
 static Value
 createCollapse(Value tensor, Location loc, PatternRewriter &rewriter,
                SmallVectorImpl<int64_t> &outputShape,
@@ -291,28 +165,61 @@
   LogicalResult matchAndRewrite(ConvOp convOp,
                                 PatternRewriter &rewriter) const override {
 
-    bool isNchw;
-    if (!isValidConv2d(convOp, isNchw)) {
+    bool isNchwFchw;
+    if (!isValidConv2d(convOp, isNchwFchw)) {
       return failure();
     }
 
-    // Check that kernel has been constant folded (by validating rank = 3)
+    // Create winograd filter transform op.
     Value kernel = convOp.getInputs()[1];
     auto kernelType = cast<ShapedType>(kernel.getType());
     if (!kernelType) {
       return failure();
     }
-    Type elementType = kernelType.getElementType();
-    ArrayRef<int64_t> kernelShape = kernelType.getShape();
-    if (kernelShape.size() != 3) {
-      return failure();
+    if (!kernelType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(convOp, "Kernel shape is not static");
     }
+    SmallVector<int64_t> kernelShape(kernelType.getShape());
+    const int64_t kh = isNchwFchw ? kernelShape[2] : kernelShape[0];
+    const int64_t kw = isNchwFchw ? kernelShape[3] : kernelShape[1];
+    if (kh != 3 || kw != 3) {
+      return rewriter.notifyMatchFailure(convOp,
+                                         "Winograd only supports 3x3 filters");
+    }
+    assert(kernelShape.size() == 4);
+    Type elementType = kernelType.getElementType();
 
     const int64_t kernelSize = 3;
     const int64_t inputTileSize = outputTileSize + kernelSize - 1;
 
-    // Create winograd input transform op
     Location loc = convOp.getLoc();
+    const std::array<int64_t, 2> hwcfKernelDims = {0, 1};
+    const std::array<int64_t, 2> fchwKernelDims = {2, 3};
+    SmallVector<int64_t> filterResultShape(4, inputTileSize);
+    filterResultShape[2] = isNchwFchw ? kernelShape[1] : kernelShape[2];
+    filterResultShape[3] = isNchwFchw ? kernelShape[0] : kernelShape[3];
+    Value kernelInit =
+        rewriter.create<tensor::EmptyOp>(loc, filterResultShape, elementType);
+    const std::array<int64_t, 2> kernelDims =
+        isNchwFchw ? fchwKernelDims : hwcfKernelDims;
+    Value winogradFilter =
+        rewriter
+            .create<IREE::LinalgExt::WinogradFilterTransformOp>(
+                loc, kernelInit.getType(), ValueRange{kernel},
+                ValueRange{kernelInit}, outputTileSize, kernelSize, kernelDims)
+            .getResults()[0];
+
+    // Add collapse shape
+    SmallVector<int64_t> collapsedFilterShape;
+    collapsedFilterShape.push_back(filterResultShape[0] * filterResultShape[1]);
+    collapsedFilterShape.push_back(filterResultShape[2]);
+    collapsedFilterShape.push_back(filterResultShape[3]);
+    SmallVector<ReassociationIndices> filterReassociations = {{0, 1}, {2}, {3}};
+    Value collapsedWinogradFilter =
+        createCollapse(winogradFilter, loc, rewriter, collapsedFilterShape,
+                       filterReassociations);
+
+    // Create winograd input transform op.
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
     Value input = convOp.getInputs()[0];
@@ -322,23 +229,23 @@
     }
     SmallVector<int64_t> inputShape(inputType.getShape());
     if (llvm::any_of(inputShape, ShapedType::isDynamic)) {
-      return failure();
+      return rewriter.notifyMatchFailure(convOp, "Input shape is not static");
     }
     assert(inputShape.size() == 4);
-    if (isNchw) {
+    if (isNchwFchw) {
       permute<IREE::LinalgExt::Permutation::NCHW_TO_NHWC>(inputShape);
     }
 
-    const std::array<int64_t, 2> nhwcImageDimensions{1, 2};
-    const std::array<int64_t, 2> nchwImageDimensions{2, 3};
-    const size_t numImageDims = nhwcImageDimensions.size();
+    const std::array<int64_t, 2> nhwcImageDims = {1, 2};
+    const std::array<int64_t, 2> nchwImageDims = {2, 3};
+    const size_t numImageDims = nhwcImageDims.size();
     SmallVector<int64_t> resultShape(6, inputTileSize);
-    llvm::SmallSetVector<int64_t, 2> imageDimensionsSet(
-        nhwcImageDimensions.begin(), nhwcImageDimensions.end());
+    llvm::SmallSetVector<int64_t, 2> imageDimsSet(nhwcImageDims.begin(),
+                                                  nhwcImageDims.end());
     int outputIndex;
     for (int i = 0; i < inputShape.size(); i++) {
       outputIndex = i + numImageDims;
-      if (!imageDimensionsSet.contains(i)) {
+      if (!imageDimsSet.contains(i)) {
         resultShape[outputIndex] = inputShape[i];
       } else {
         resultShape[outputIndex] =
@@ -347,13 +254,14 @@
     }
     Value emptyTensor =
         rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
-    auto &imageDimensions = isNchw ? nchwImageDimensions : nhwcImageDimensions;
-    auto winogradInputOp =
-        rewriter.create<IREE::LinalgExt::WinogradInputTransformOp>(
-            loc, emptyTensor.getType(), ValueRange{input},
-            ValueRange{emptyTensor}, outputTileSize, kernelSize,
-            imageDimensions);
-    Value winogradInput = winogradInputOp.getResult()[0];
+    const std::array<int64_t, 2> imageDims =
+        isNchwFchw ? nchwImageDims : nhwcImageDims;
+    Value winogradInput =
+        rewriter
+            .create<IREE::LinalgExt::WinogradInputTransformOp>(
+                loc, emptyTensor.getType(), ValueRange{input},
+                ValueRange{emptyTensor}, outputTileSize, kernelSize, imageDims)
+            .getResults()[0];
 
     // Add collapse shape
     SmallVector<int64_t> collapsedShape = {
@@ -368,7 +276,7 @@
     Value output = convOp.getOutputs()[0];
     auto outputType = cast<RankedTensorType>(output.getType());
     SmallVector<int64_t> outputShape(outputType.getShape());
-    if (isNchw) {
+    if (isNchwFchw) {
       permute<IREE::LinalgExt::Permutation::NCHW_TO_NHWC>(outputShape);
     }
     bmmShape[2] = outputShape[3];
@@ -377,7 +285,8 @@
     auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor});
     auto bmmOp = rewriter.create<linalg::BatchMatmulOp>(
-        loc, bmmOutputType, ValueRange({collapsedWinogradInput, kernel}),
+        loc, bmmOutputType,
+        ValueRange({collapsedWinogradInput, collapsedWinogradFilter}),
         ValueRange({fillOp.result()}));
     Value bmmResult = bmmOp.getResult(0);
 
@@ -392,37 +301,35 @@
     // Convert back into original domain
     SmallVector<int64_t> paddedResultShape(outputShape.size(), 0);
     for (int i = 0; i < outputShape.size(); i++) {
-      if (!imageDimensionsSet.contains(i)) {
+      if (!imageDimsSet.contains(i)) {
         paddedResultShape[i] = outputShape[i];
       } else {
         paddedResultShape[i] = resultShape[i + numImageDims] * outputTileSize;
       }
     }
-    if (isNchw) {
+    if (isNchwFchw) {
       permute<IREE::LinalgExt::Permutation::NHWC_TO_NCHW>(paddedResultShape);
     }
     emptyTensor =
         rewriter.create<tensor::EmptyOp>(loc, paddedResultShape, elementType);
-    auto winogradOutputOp =
-        rewriter.create<IREE::LinalgExt::WinogradOutputTransformOp>(
-            loc, emptyTensor.getType(), ValueRange{expandedBmmResult},
-            ValueRange{emptyTensor}, outputTileSize, kernelSize,
-            imageDimensions);
-    Value paddedOutput = winogradOutputOp.getResult()[0];
+    Value paddedOutput =
+        rewriter
+            .create<IREE::LinalgExt::WinogradOutputTransformOp>(
+                loc, emptyTensor.getType(), ValueRange{expandedBmmResult},
+                ValueRange{emptyTensor}, outputTileSize, kernelSize, imageDims)
+            .getResults()[0];
 
     // Extract slice
     SmallVector<OpFoldResult> offsets(outputShape.size(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(outputShape.size(),
                                       rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes;
-    for (const int64_t shape : outputType.getShape())
-      sizes.push_back(rewriter.getIndexAttr(shape));
+    SmallVector<OpFoldResult> sizes =
+        getAsIndexOpFoldResult(rewriter.getContext(), outputType.getShape());
     auto winogradOutput = rewriter.create<tensor::ExtractSliceOp>(
         loc, outputType, paddedOutput, offsets, sizes, strides);
 
-    Value result = convOp.getResult(0);
-    result.replaceAllUsesWith(winogradOutput);
+    rewriter.replaceOp(convOp, winogradOutput);
     return success();
   }
 };
@@ -436,9 +343,7 @@
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(&getContext());
-    patterns.insert<FoldWinogradFilterTransform<linalg::Conv2DNchwFchwOp>,
-                    FoldWinogradFilterTransform<linalg::Conv2DNhwcHwcfOp>,
-                    ConvertConvToWinograd<linalg::Conv2DNhwcHwcfOp>,
+    patterns.insert<ConvertConvToWinograd<linalg::Conv2DNhwcHwcfOp>,
                     ConvertConvToWinograd<linalg::Conv2DNchwFchwOp>>(context);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_winograd.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_winograd.mlir
index 42b781d..b5be6e5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_winograd.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_winograd.mlir
@@ -1,146 +1,87 @@
 // RUN: iree-opt --split-input-file -iree-linalg-ext-convert-conv2d-to-winograd -mlir-elide-elementsattrs-if-larger=4 %s | FileCheck %s
 
-func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
-  %c0 = arith.constant dense<0.1> : tensor<3x3x4x16xf32>
+func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
   %0 = linalg.conv_2d_nhwc_hwcf
     {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
-     ins(%arg0, %c0: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
+     ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
     outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
   return %0 : tensor<1x14x14x16xf32>
 }
-// CHECK:      func.func @conv_16433136(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
-// CHECK-SAME:   tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
-// CHECK:        %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x4x16xf32>
-// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x16x16x4xf32>) outs(%[[D0]] :
+// CHECK:      func.func @conv_16433136(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<3x3x4x16xf32>
+// CHECK-DAG:    %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[EMPTY0:.+]] = tensor.empty() : tensor<8x8x4x16xf32>
+// CHECK:        %[[FILTER_TF:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     kernel_dimensions([0, 1]) ins(%[[ARG1]] : tensor<3x3x4x16xf32>) outs(%[[EMPTY0]] :
+// CHECK-SAME:     tensor<8x8x4x16xf32>) -> tensor<8x8x4x16xf32>
+// CHECK:        %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER_TF]]
+// CHECK-SAME{LITERAL}:  [[0, 1], [2], [3]]
+// CHECK-SAME:           tensor<8x8x4x16xf32> into tensor<64x4x16xf32>
+// CHECK:        %[[EMPTY1:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
+// CHECK:        %[[INPUT_TF:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x16x16x4xf32>) outs(%[[EMPTY1]] :
 // CHECK-SAME:     tensor<8x8x1x3x3x4xf32>) -> tensor<8x8x1x3x3x4xf32>
-// CHECK:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]]
+// CHECK:        %[[COLLAPSED_INPUT:.+]] = tensor.collapse_shape %[[INPUT_TF]]
 // CHECK-SAME{LITERAL}:  [[0, 1], [2, 3, 4], [5]]
 // CHECK-SAME:           tensor<8x8x1x3x3x4xf32> into tensor<64x9x4xf32>
-// CHECK:        %[[D2:.+]] = tensor.empty() : tensor<64x9x16xf32>
-// CHECK:        %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x9x16xf32>) ->
+// CHECK:        %[[EMPTY2:.+]] = tensor.empty() : tensor<64x9x16xf32>
+// CHECK:        %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY2]] : tensor<64x9x16xf32>) ->
 // CHECK-SAME:     tensor<64x9x16xf32>
-// CHECK:        %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x9x4xf32>,
-// CHECK-SAME:     tensor<64x4x16xf32>) outs(%[[D3]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
-// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]]
+// CHECK:        %[[BMM:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_INPUT]], %[[COLLAPSED_FILTER]] : tensor<64x9x4xf32>,
+// CHECK-SAME:     tensor<64x4x16xf32>) outs(%[[FILL]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
+// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[BMM]]
 // CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
 // CHECK-SAME:          tensor<64x9x16xf32> into tensor<8x8x1x3x3x16xf32>
-// CHECK:        %[[D5:.+]] = tensor.empty() : tensor<1x18x18x16xf32>
-// CHECK:        %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[D5]] :
+// CHECK:        %[[EMPTY3:.+]] = tensor.empty() : tensor<1x18x18x16xf32>
+// CHECK:        %[[OUTPUT_TF:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[EMPTY3]] :
 // CHECK-SAME:     tensor<1x18x18x16xf32>) -> tensor<1x18x18x16xf32>
-// CHECK:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] :
+// CHECK:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[OUTPUT_TF]][0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] :
 // CHECK-SAME:     tensor<1x18x18x16xf32> to tensor<1x14x14x16xf32>
 // CHECK:        return %[[EXTRACTED_SLICE]] : tensor<1x14x14x16xf32>
 // CHECK:      }
 
 // -----
 
-func.func @conv2d_non_splat_weights(%inputs : tensor<1x4x4x1xf32>, %arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
-  %c0 = arith.constant dense<[[ [[1.0]],  [[3.0]],  [[5.0]]  ],
-                              [ [[7.0]],  [[9.0]],  [[11.0]] ],
-                              [ [[13.0]], [[15.0]], [[17.0]] ]]> : tensor<3x3x1x1xf32>
-  %0 = linalg.conv_2d_nhwc_hwcf
-    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
-     ins(%inputs, %c0: tensor<1x4x4x1xf32>, tensor<3x3x1x1xf32>)
-    outs(%arg2: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
-  return %0 : tensor<1x2x2x1xf32>
-}
-// CHECK:      func.func @conv2d_non_splat_weights(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x4x4x1xf32>,
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
-// CHECK:        %[[CST:.+]] = arith.constant dense_resource<__elided__> : tensor<64x1x1xf32>
-// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<8x8x1x1x1x1xf32>
-// CHECK:        %[[D1:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[ARG0]] : tensor<1x4x4x1xf32>) outs(%[[D0]] : tensor<8x8x1x1x1x1xf32>)
-// CHECK-SAME:     -> tensor<8x8x1x1x1x1xf32>
-// CHECK:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[D1]]
-// CHECK-SAME{LITERAL}:   [[0, 1], [2, 3, 4], [5]]
-// CHECK-SAME:            tensor<8x8x1x1x1x1xf32> into tensor<64x1x1xf32>
-// CHECK:        %[[D2:.+]] = tensor.empty() : tensor<64x1x1xf32>
-// CHECK:        %[[D3:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x1x1xf32>) ->
-// CHECK-SAME:     tensor<64x1x1xf32>
-// CHECK:        %[[D4:.+]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x1x1xf32>, tensor<64x1x1xf32>)
-// CHECK-SAME:     outs(%[[D3]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32>
-// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[D4]]
-// CHECK-SAME{LITERAL}:   [[0, 1], [2, 3, 4], [5]]
-// CHECK-SAME:            tensor<64x1x1xf32> into tensor<8x8x1x1x1x1xf32>
-// CHECK:        %[[D5:.+]] = tensor.empty() : tensor<1x6x6x1xf32>
-// CHECK:        %[[D6:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([1, 2]) ins(%[[EXPANDED]] : tensor<8x8x1x1x1x1xf32>) outs(%[[D5]] :
-// CHECK-SAME:     tensor<1x6x6x1xf32>) -> tensor<1x6x6x1xf32>
-// CHECK:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 2, 2, 1] [1, 1, 1, 1] :
-// CHECK-SAME:     tensor<1x6x6x1xf32> to tensor<1x2x2x1xf32>
-// CHECK:        return %[[EXTRACTED_SLICE]] : tensor<1x2x2x1xf32>
-// CHECK:      }
-
-// -----
-
-func.func @conv_16433136_nchw_fchw(%arg0: tensor<1x4x16x16xf32>, %arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> {
-  %c0 = arith.constant dense<0.1> : tensor<16x4x3x3xf32>
+func.func @conv_16433136_nchw_fchw(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> {
   %0 = linalg.conv_2d_nchw_fchw
     {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
-     ins(%arg0, %c0: tensor<1x4x16x16xf32>, tensor<16x4x3x3xf32>)
+     ins(%arg0, %arg1: tensor<1x4x16x16xf32>, tensor<16x4x3x3xf32>)
     outs(%arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>
   return %0 : tensor<1x16x14x14xf32>
 }
-// CHECK:      func.func @conv_16433136_nchw_fchw(%[[ARG0]]: tensor<1x4x16x16xf32>, %[[ARG1]]: tensor<1x16x14x14xf32>)
-// CHECK-SAME:   -> tensor<1x16x14x14xf32> {
-// CHECK:        %[[CST]] = arith.constant dense_resource<__elided__> : tensor<64x4x16xf32>
-// CHECK:        %[[CST_0]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D0]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
-// CHECK:        %[[D1]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x4x16x16xf32>) outs(%[[D0]] :
+// CHECK:      func.func @conv_16433136_nchw_fchw(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x4x16x16xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<16x4x3x3xf32>
+// CHECK-DAG:    %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[EMPTY0:.+]] = tensor.empty() : tensor<8x8x4x16xf32>
+// CHECK:        %[[FILTER_TF:.+]] = iree_linalg_ext.winograd.filter_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     kernel_dimensions([2, 3]) ins(%[[ARG1]] : tensor<16x4x3x3xf32>) outs(%[[EMPTY0]] :
+// CHECK-SAME:     tensor<8x8x4x16xf32>) -> tensor<8x8x4x16xf32>
+// CHECK:        %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER_TF]]
+// CHECK-SAME{LITERAL}:  [[0, 1], [2], [3]]
+// CHECK-SAME:           tensor<8x8x4x16xf32> into tensor<64x4x16xf32>
+// CHECK:        %[[EMPTY1:.+]] = tensor.empty() : tensor<8x8x1x3x3x4xf32>
+// CHECK:        %[[INPUT_TF:.+]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x4x16x16xf32>) outs(%[[EMPTY1]] :
 // CHECK-SAME:     tensor<8x8x1x3x3x4xf32>) -> tensor<8x8x1x3x3x4xf32>
-// CHECK:        %[[COLLAPSED]] = tensor.collapse_shape %[[D1]]
-// CHECK:        %[[D2]] = tensor.empty() : tensor<64x9x16xf32>
-// CHECK:        %[[D3]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
-// CHECK:        %[[D4]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x9x4xf32>, tensor<64x4x16xf32>)
-// CHECK-SAME:     outs(%[[D3]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
-// CHECK:        %[[EXPANDED]] = tensor.expand_shape %[[D4]]
-// CHECK:        %[[D5]] = tensor.empty() : tensor<1x16x18x18xf32>
-// CHECK:        %[[D6]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([2, 3]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[D5]] :
+// CHECK:        %[[COLLAPSED_INPUT:.+]] = tensor.collapse_shape %[[INPUT_TF]]
+// CHECK-SAME{LITERAL}:  [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME:           tensor<8x8x1x3x3x4xf32> into tensor<64x9x4xf32>
+// CHECK:        %[[EMPTY2:.+]] = tensor.empty() : tensor<64x9x16xf32>
+// CHECK:        %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY2]] : tensor<64x9x16xf32>) ->
+// CHECK-SAME:     tensor<64x9x16xf32>
+// CHECK:        %[[BMM:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_INPUT]], %[[COLLAPSED_FILTER]] : tensor<64x9x4xf32>,
+// CHECK-SAME:     tensor<64x4x16xf32>) outs(%[[FILL]] : tensor<64x9x16xf32>) -> tensor<64x9x16xf32>
+// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[BMM]]
+// CHECK-SAME{LITERAL}: [[0, 1], [2, 3, 4], [5]]
+// CHECK-SAME:          tensor<64x9x16xf32> into tensor<8x8x1x3x3x16xf32>
+// CHECK:        %[[EMPTY3:.+]] = tensor.empty() : tensor<1x16x18x18xf32>
+// CHECK:        %[[OUTPUT_TF:.+]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
+// CHECK-SAME:     image_dimensions([2, 3]) ins(%[[EXPANDED]] : tensor<8x8x1x3x3x16xf32>) outs(%[[EMPTY3]] :
 // CHECK-SAME:     tensor<1x16x18x18xf32>) -> tensor<1x16x18x18xf32>
-// CHECK:        %[[EXTRACTED_SLICE]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 16, 14, 14] [1, 1, 1, 1] :
+// CHECK:        %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[OUTPUT_TF]][0, 0, 0, 0] [1, 16, 14, 14] [1, 1, 1, 1] :
 // CHECK-SAME:     tensor<1x16x18x18xf32> to tensor<1x16x14x14xf32>
 // CHECK:        return %[[EXTRACTED_SLICE]] : tensor<1x16x14x14xf32>
 // CHECK:      }
-// CHECK:    }
-
-// -----
-
-func.func @conv2d_nchw_non_splat_weights(%inputs : tensor<1x1x4x4xf32>, %arg2: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> {
-  %c0 = arith.constant dense<[[[[ 1.0,  3.0,  5.0  ],
-                                [ 7.0,  9.0,  11.0 ],
-                                [ 13.0, 15.0, 17.0 ]]]]> : tensor<1x1x3x3xf32>
-  %0 = linalg.conv_2d_nchw_fchw
-    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
-     ins(%inputs, %c0: tensor<1x1x4x4xf32>, tensor<1x1x3x3xf32>)
-    outs(%arg2: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
-  return %0 : tensor<1x1x2x2xf32>
-}
-// CHECK:      func.func @conv2d_nchw_non_splat_weights(%[[ARG0]]: tensor<1x1x4x4xf32>, %[[ARG1]]: tensor<1x1x2x2xf32>)
-// CHECK-SAME:   -> tensor<1x1x2x2xf32> {
-// CHECK:        %[[CST]] = arith.constant dense_resource<__elided__> : tensor<64x1x1xf32>
-// CHECK:        %[[CST_0]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D0]] = tensor.empty() : tensor<8x8x1x1x1x1xf32>
-// CHECK:        %[[D1]] = iree_linalg_ext.winograd.input_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([2, 3]) ins(%[[ARG0]] : tensor<1x1x4x4xf32>) outs(%[[D0]] : tensor<8x8x1x1x1x1xf32>)
-// CHECK-SAME:     -> tensor<8x8x1x1x1x1xf32>
-// CHECK:        %[[COLLAPSED]] = tensor.collapse_shape %[[D1]]
-// CHECK:        %[[D2]] = tensor.empty() : tensor<64x1x1xf32>
-// CHECK:        %[[D3]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D2]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32>
-// CHECK:        %[[D4]] = linalg.batch_matmul ins(%[[COLLAPSED]], %[[CST]] : tensor<64x1x1xf32>, tensor<64x1x1xf32>)
-// CHECK-SAME:     outs(%[[D3]] : tensor<64x1x1xf32>) -> tensor<64x1x1xf32>
-// CHECK:        %[[EXPANDED]] = tensor.expand_shape %[[D4]]
-// CHECK:        %[[D5]] = tensor.empty() : tensor<1x1x6x6xf32>
-// CHECK:        %[[D6]] = iree_linalg_ext.winograd.output_transform output_tile_size(6) kernel_size(3)
-// CHECK-SAME:     image_dimensions([2, 3]) ins(%[[EXPANDED]] : tensor<8x8x1x1x1x1xf32>) outs(%[[D5]] :
-// CHECK-SAME:     tensor<1x1x6x6xf32>) -> tensor<1x1x6x6xf32>
-// CHECK:        %[[EXTRACTED_SLICE]] = tensor.extract_slice %[[D6]][0, 0, 0, 0] [1, 1, 2, 2] [1, 1, 1, 1] :
-// CHECK-SAME:     tensor<1x1x6x6xf32> to tensor<1x1x2x2xf32>
-// CHECK:        return %[[EXTRACTED_SLICE]] : tensor<1x1x2x2xf32>
-// CHECK:      }