Add conversions for 1x1 conv_2d to matmul (#18736)

Convert1X1FilterConv2DToMatmul: handle dynamic cases and convert to a `linalg.generic` representing a broadcasted batch matmul. The pass is kept because it is still used by plugins.
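
For illustration, a sketch of the rewrite on the NHWC case from the updated `conv1x1_to_matmul.mlir` test below; the shapes and indexing maps come from that test, while the multiply-accumulate region is the standard body conv generalization produces (abbreviated, not verbatim compiler output):

```mlir
// Before: a 1x1 NHWC convolution.
%1 = linalg.conv_2d_nhwc_hwcf {
    dilations = dense<1> : tensor<2xi64>,
    strides = dense<1> : tensor<2xi64>
} ins(%input, %filter : tensor<1x4x5x2xf32>, tensor<1x1x2x7xf32>)
  outs(%0 : tensor<1x4x5x7xf32>) -> tensor<1x4x5x7xf32>

// After: the generalized form. The unit kernel loop dims (d4, d5) are zeroed
// out of the input indexing map, leaving a broadcasted-batch-matmul-shaped
// linalg.generic.
%1 = linalg.generic {
    indexing_maps = [
        affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>,
        affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>,
        affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel",
                      "reduction", "reduction", "reduction"]}
    ins(%input, %filter : tensor<1x4x5x2xf32>, tensor<1x1x2x7xf32>)
    outs(%0 : tensor<1x4x5x7xf32>) {
^bb0(%in: f32, %f: f32, %out: f32):
    // Multiply-accumulate body of the generalized convolution.
    %2 = arith.mulf %in, %f : f32
    %3 = arith.addf %out, %2 : f32
    linalg.yield %3 : f32
} -> tensor<1x4x5x7xf32>
```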

GeneralizeLinalgNamedOps: also generalize conv ops to `linalg.generic` ops when they fold to a contraction (1x1 kernel and unit strides).
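
As a sketch of the new predicate (mirroring `isConvFoldableToContraction` and the new `generalize_named_ops.mlir` tests below), only convs with a 1x1 kernel and unit strides are picked up:

```mlir
// Generalized: 1x1 kernel and unit strides, so the op folds to a contraction.
%1 = linalg.conv_2d_nchw_fchw {
    dilations = dense<1> : tensor<2xi64>,
    strides = dense<1> : tensor<2xi64>
} ins(%input, %filter : tensor<1x2x4x5xf32>, tensor<7x2x1x1xf32>)
  outs(%0 : tensor<1x7x4x5xf32>) -> tensor<1x7x4x5xf32>

// Left alone by this pass: non-unit strides. The standalone
// Convert1X1FilterConv2DToMatmul pass still handles this case, producing a
// strided input map such as (d0, d4, d2 * 2, d3 * 2).
%2 = linalg.conv_2d_nchw_fchw {
    dilations = dense<1> : tensor<2xi64>,
    strides = dense<2> : tensor<2xi64>
} ins(%input2, %filter : tensor<1x2x?x?xf32>, tensor<7x2x1x1xf32>)
  outs(%3 : tensor<1x7x?x?xf32>) -> tensor<1x7x?x?xf32>
```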

Converting more ops to `linalg.generic` allows for better reshape propagation and more fusion opportunities. Convert1X1FilterConv2DToMatmulPass is also removed from the global optimization pipeline, since GeneralizeLinalgNamedOps already generalizes any convolution that pass could convert.

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 0748ec5..fb94905 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -220,9 +220,9 @@
             --goldentime-rocm-unet-ms 419.0 \
             --goldentime-rocm-clip-ms 18.5 \
             --goldentime-rocm-vae-ms 337.0 \
-            --goldendispatch-rocm-unet 1545 \
+            --goldendispatch-rocm-unet 1527 \
             --goldendispatch-rocm-clip 1139 \
-            --goldendispatch-rocm-vae 248 \
+            --goldendispatch-rocm-vae 247 \
             --goldensize-rocm-unet-bytes 2280000  \
             --goldensize-rocm-clip-bytes 860000 \
             --goldensize-rocm-vae-bytes 840000 \
@@ -241,9 +241,9 @@
             --goldentime-rocm-unet-ms 95.0 \
             --goldentime-rocm-clip-ms 15.5 \
             --goldentime-rocm-vae-ms 80.0 \
-            --goldendispatch-rocm-unet 1545 \
+            --goldendispatch-rocm-unet 1527 \
             --goldendispatch-rocm-clip 1139 \
-            --goldendispatch-rocm-vae 248 \
+            --goldendispatch-rocm-vae 247 \
             --goldensize-rocm-unet-bytes 2270000 \
             --goldensize-rocm-clip-bytes 860000  \
             --goldensize-rocm-vae-bytes 840000 \
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp
index a8b4bec..7128dbd 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp
@@ -6,7 +6,8 @@
 
 #include "iree/compiler/GlobalOptimization/Passes.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -26,134 +27,51 @@
 
   LogicalResult matchAndRewrite(Conv2DOpType convOp,
                                 PatternRewriter &rewriter) const override {
-    auto inputShapeType = llvm::dyn_cast<RankedTensorType>(
-        convOp.getDpsInputOperand(0)->get().getType());
     auto filterShapeType = llvm::dyn_cast<RankedTensorType>(
         convOp.getDpsInputOperand(1)->get().getType());
-    auto outputShapeType = llvm::dyn_cast<RankedTensorType>(
-        convOp.getDpsInitOperand(0)->get().getType());
-
-    const bool isNCHW = isa<linalg::Conv2DNchwFchwOp>(convOp);
-    const bool isNHWC = isa<linalg::Conv2DNhwcHwcfOp>(convOp);
-    if (!isNCHW & !isNHWC)
+    if (!filterShapeType)
       return failure();
 
-    if (!inputShapeType || !filterShapeType || !outputShapeType)
-      return failure();
+    constexpr bool isNCHW =
+        std::is_same_v<linalg::Conv2DNchwFchwOp, Conv2DOpType>;
+    constexpr bool isNHWC =
+        std::is_same_v<linalg::Conv2DNhwcHwcfOp, Conv2DOpType>;
+    static_assert(isNCHW || isNHWC);
 
-    auto inputShape = inputShapeType.getShape();
     auto filterShape = filterShapeType.getShape();
-    auto outputShape = outputShapeType.getShape();
+
+    constexpr int64_t numLoops = 7;
 
     // Adjusting dimension indices based on Conv2DOpType.
-    const int nIndex = 0;
-    const int kcIndex = isNHWC ? 2 : 1;
-    const int kfIndex = isNHWC ? 3 : 0;
-    const int khIndex = isNHWC ? 0 : 2;
-    const int kwIndex = isNHWC ? 1 : 3;
-    const int ohIndex = isNHWC ? 1 : 2;
-    const int owIndex = isNHWC ? 2 : 3;
-    const int ocIndex = isNHWC ? 3 : 1;
-
-    bool isInputHWDynamic = ShapedType::isDynamic(inputShape[ohIndex]) &&
-                            ShapedType::isDynamic(inputShape[owIndex]);
-
-    // We cannot merge the width and height if they are both dynamic as we
-    // cannot expand them back to their dynamic values.
-    if (isInputHWDynamic)
-      return failure();
+    constexpr int khIndex = isNHWC ? 0 : 2;
+    constexpr int kwIndex = isNHWC ? 1 : 3;
+    constexpr int khLoopIndex = isNHWC ? 4 : 5;
+    constexpr int kwLoopIndex = isNHWC ? 5 : 6;
 
     if (filterShape[khIndex] != 1 || filterShape[kwIndex] != 1)
       return failure();
 
-    // TODO(ataei): Support conversion to linalg.batch_matmul.
-    if (inputShape[0] != 1)
-      return failure();
-
-    if (!llvm::all_of(convOp.getStrides(), [](APInt element) {
-          return element.getSExtValue() == 1;
-        }))
-      return failure();
-    if (!llvm::all_of(convOp.getDilations(), [](APInt element) {
-          return element.getSExtValue() == 1;
-        }))
-      return failure();
-
-    auto combineDims = [](int64_t a, int64_t b) {
-      if (ShapedType::isDynamic(a) || ShapedType::isDynamic(b))
-        return ShapedType::kDynamic;
-      return a * b;
-    };
-
-    SmallVector<ReassociationIndices> reassociationInputOutputIndices;
-    SmallVector<ReassociationIndices> reassociationFilterIndices;
-    SmallVector<int64_t> reshapedInputShape(2, 0);
-    SmallVector<int64_t> reshapedFilterShape(2, 0);
-    SmallVector<int64_t> reshapedOutputShape(2, 0);
-    if (isNHWC) {
-      // Generate reassociation indices.
-      reassociationInputOutputIndices = {{nIndex, ohIndex, owIndex}, {ocIndex}};
-      reassociationFilterIndices = {{khIndex, kwIndex, kcIndex}, {kfIndex}};
-
-      // Generate matmul shapes from 1x1 conv.
-      reshapedInputShape = {
-          combineDims(inputShape[ohIndex], inputShape[owIndex]),
-          inputShape[ocIndex]};
-      reshapedFilterShape = {filterShape[kcIndex], filterShape[kfIndex]};
-      reshapedOutputShape = {
-          combineDims(outputShape[ohIndex], outputShape[owIndex]),
-          outputShape[ocIndex]};
-    } else if (isNCHW) {
-      // Generate reassociation indices.
-      reassociationInputOutputIndices = {{nIndex, ocIndex}, {ohIndex, owIndex}};
-      reassociationFilterIndices = {{kfIndex}, {kcIndex, khIndex, kwIndex}};
-
-      // Generate matmul shapes from 1x1 conv.
-      reshapedInputShape = {
-          inputShape[ocIndex],
-          combineDims(inputShape[ohIndex], inputShape[owIndex])};
-      reshapedFilterShape = {filterShape[kfIndex], filterShape[kcIndex]};
-      reshapedOutputShape = {
-          outputShape[ocIndex],
-          combineDims(outputShape[ohIndex], outputShape[owIndex])};
+    SmallVector<AffineExpr> dimReplacements;
+    for (int i = 0; i < numLoops; i++) {
+      if (llvm::is_contained({khLoopIndex, kwLoopIndex}, i)) {
+        dimReplacements.push_back(
+            getAffineConstantExpr(0, rewriter.getContext()));
+      } else {
+        dimReplacements.push_back(getAffineDimExpr(i, rewriter.getContext()));
+      }
     }
 
-    auto reshapedInputType = RankedTensorType::get(
-        reshapedInputShape, inputShapeType.getElementType());
+    SmallVector<AffineMap> newMaps = convOp.getIndexingMapsArray();
+    AffineMap inputMap = newMaps[0];
+    SmallVector<AffineExpr> newExprs =
+        llvm::map_to_vector(inputMap.getResults(), [&](AffineExpr resultExpr) {
+          return resultExpr.replaceDims(dimReplacements);
+        });
+    newMaps[0] = AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(),
+                                newExprs, rewriter.getContext());
 
-    auto reshapedFilterType = RankedTensorType::get(
-        reshapedFilterShape, filterShapeType.getElementType());
-
-    auto reshapedOutputType = RankedTensorType::get(
-        reshapedOutputShape, outputShapeType.getElementType());
-
-    Value input = convOp.getDpsInputOperand(0)->get();
-    Value filter = convOp.getDpsInputOperand(1)->get();
-    Value output = convOp.getDpsInitOperand(0)->get();
-    auto loc = convOp.getLoc();
-
-    Value reshapedInput = rewriter.create<tensor::CollapseShapeOp>(
-        loc, reshapedInputType, input, reassociationInputOutputIndices);
-    Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
-        loc, reshapedFilterType, filter, reassociationFilterIndices);
-    Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
-        loc, reshapedOutputType, output, reassociationInputOutputIndices);
-
-    SmallVector<Value, 2> matmulInput;
-    if (isNHWC) {
-      matmulInput = {reshapedInput, reshapedFilter};
-    } else if (isNCHW) {
-      matmulInput = {reshapedFilter, reshapedInput};
-    }
-    auto matmulResult = rewriter.create<linalg::MatmulOp>(
-        loc, reshapedOutputType, matmulInput, ArrayRef<Value>{reshapedOutput});
-
-    auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
-        loc, outputShapeType, matmulResult.getResults()[0],
-        reassociationInputOutputIndices);
-
-    rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
-
+    auto genericOp = linalg::generalizeNamedOp(rewriter, convOp).value();
+    genericOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(newMaps));
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
index 92293bc..99f6268 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
@@ -30,6 +30,34 @@
 };
 } // namespace
 
+/// Returns true if `linalgOp` is a Conv2DNchwFchwOp or Conv2DNhwcHwcfOp with
+/// all strides equal to 1 and a kernel height and width of 1.
+static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
+  auto NCHWOp = dyn_cast<linalg::Conv2DNchwFchwOp>(linalgOp.getOperation());
+  auto NHWCOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(linalgOp.getOperation());
+
+  if (!NCHWOp && !NHWCOp)
+    return false;
+
+  DenseIntElementsAttr strides =
+      NCHWOp ? NCHWOp.getStrides() : NHWCOp.getStrides();
+  if (!llvm::all_of(
+          strides, [](APInt element) { return element.getSExtValue() == 1; })) {
+    return false;
+  }
+
+  auto filterShapeType = llvm::dyn_cast<RankedTensorType>(
+      linalgOp.getDpsInputOperand(1)->get().getType());
+  if (!filterShapeType)
+    return false;
+
+  // Adjust kernel dimension indices based on the convolution layout.
+  const int khIndex = NHWCOp ? 0 : 2;
+  const int kwIndex = NHWCOp ? 1 : 3;
+  auto filterShape = filterShapeType.getShape();
+  return filterShape[khIndex] == 1 && filterShape[kwIndex] == 1;
+}
+
 void GeneralizeLinalgNamedOpsPass::runOnOperation() {
   auto funcOp = getOperation();
   SmallVector<linalg::LinalgOp> namedOpCandidates;
@@ -44,7 +72,8 @@
                         linalg::LogOp, linalg::MapOp, linalg::MaxOp,
                         linalg::MulOp, linalg::NegFOp, linalg::ReduceOp,
                         linalg::SubOp, linalg::TransposeOp>(
-            linalgOp.getOperation())) {
+            linalgOp.getOperation()) ||
+        isConvFoldableToContraction(linalgOp)) {
       namedOpCandidates.push_back(linalgOp);
     }
   });
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index 4f9a33e..bd61d4b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -101,8 +101,7 @@
       .addPass(IREE::Flow::createCanonicalizerPass)
       .addPass(createRemoveZeroExtentTensorsPass)
       .addPass(createDetachElementwiseFromNamedOpsPass)
-      .addPass(mlir::createLinalgNamedOpConversionPass)
-      .addPass(createConvert1X1FilterConv2DToMatmulPass);
+      .addPass(mlir::createLinalgNamedOpConversionPass);
   mainPassManager.addPass(createEraseUnusedLinalgOperandsPass());
 
   // Expand tensor shapes into SSA values and optimize the whole program.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir
index 980db93..607f137 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file -iree-global-opt-convert-1x1-filter-conv2d-to-matmul %s | FileCheck %s
+// RUN: iree-opt --split-input-file --mlir-print-local-scope -iree-global-opt-convert-1x1-filter-conv2d-to-matmul %s | FileCheck %s
 
 util.func public @nhwc_conv_2d(%input: tensor<1x4x5x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x5x7xf32> {
     %0 = tensor.empty() : tensor<1x4x5x7xf32>
@@ -9,20 +9,15 @@
     util.return %1 : tensor<1x4x5x7xf32>
 }
 
-// CHECK: @nhwc_conv_2d
-// CHECK: %[[INPUT:.+]]: tensor<1x4x5x2xf32>
-// CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32>
-// CHECK: %[[OUTPUT:.+]] = tensor.empty() : tensor<1x4x5x7xf32>
-// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
-// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
-// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<20x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<20x7xf32>)
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] output_shape [1, 4, 5, 7] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
-// CHECK: util.return %[[RESULT]]
+// CHECK-LABEL: @nhwc_conv_2d
+//       CHECK:   %[[RESULT:.*]] = linalg.generic
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+//       CHECK:   util.return %[[RESULT]]
 
 // -----
 
-// CHECK: @dynamic_nhwc_conv_2d
 util.func public @dynamic_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x?x7xf32> {
     %c2 = arith.constant 2 : index
     %d2 = tensor.dim %input, %c2 : tensor<1x4x?x2xf32>
@@ -34,34 +29,12 @@
     util.return %1 : tensor<1x4x?x7xf32>
 }
 
-// CHECK: %[[INPUT:.+]]: tensor<1x4x?x2xf32>
-// CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32>
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[D2:.+]] = tensor.dim %[[INPUT]], %[[C2]]
-// CHECK: %[[OUTPUT:.+]] = tensor.empty(%[[D2]]) : tensor<1x4x?x7xf32>
-// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x2xf32> into tensor<?x2xf32>
-// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x7xf32> into tensor<?x7xf32>
-// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<?x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<?x7xf32>)
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]]
-
-// -----
-
-util.func public @fail_dynamic_nhwc_conv_2d(%input: tensor<1x?x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x?x?x7xf32> {
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %d1 = tensor.dim %input, %c1 : tensor<1x?x?x2xf32>
-    %d2 = tensor.dim %input, %c2 : tensor<1x?x?x2xf32>
-    %0 = tensor.empty(%d1, %d2) : tensor<1x?x?x7xf32>
-    %1 = linalg.conv_2d_nhwc_hwcf {
-        dilations = dense<1> : tensor<2xi64>,
-        strides = dense<1> : tensor<2xi64>
-    } ins(%input, %filter : tensor<1x?x?x2xf32>, tensor<1x1x2x7xf32>) outs(%0 : tensor<1x?x?x7xf32>) -> tensor<1x?x?x7xf32>
-    util.return %1 : tensor<1x?x?x7xf32>
-}
-
-// CHECK: @fail_dynamic_nhwc_conv_2d
-// CHECK: linalg.conv_2d_nhwc_hwcf
+// CHECK-LABEL: @dynamic_nhwc_conv_2d
+//       CHECK:   %[[RESULT:.*]] = linalg.generic
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+//       CHECK:   util.return %[[RESULT]]
 
 // -----
 
@@ -73,16 +46,12 @@
     } ins(%input, %filter : tensor<1x2x4x5xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x4x5xf32>) -> tensor<1x7x4x5xf32>
     util.return %1 : tensor<1x7x4x5xf32>
 }
-// CHECK: @nchw_conv_2d
-// CHECK: %[[INPUT:.+]]: tensor<1x2x4x5xf32>
-// CHECK: %[[FILTER:.+]]: tensor<7x2x1x1xf32>
-// CHECK: %[[OUTPUT:.+]] = tensor.empty() : tensor<1x7x4x5xf32>
-// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x2x4x5xf32> into tensor<2x20xf32>
-// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<7x2x1x1xf32> into tensor<7x2xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x7x4x5xf32> into tensor<7x20xf32>
-// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_FILTER]], %[[RESHAPED_INPUT]] : tensor<7x2xf32>, tensor<2x20xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<7x20xf32>)
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1], [2, 3]] output_shape [1, 7, 4, 5] : tensor<7x20xf32> into tensor<1x7x4x5xf32>
-// CHECK: util.return %[[RESULT]]
+// CHECK-LABEL: @nchw_conv_2d
+//       CHECK:   %[[RESULT:.*]] = linalg.generic
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+//       CHECK:   util.return %[[RESULT]]
 
 // -----
 
@@ -97,33 +66,27 @@
     util.return %1 : tensor<1x7x4x?xf32>
 }
 
-// CHECK: @dynamic_nchw_conv_2d
-// CHECK: %[[INPUT:.+]]: tensor<1x2x4x?xf32>
-// CHECK: %[[FILTER:.+]]: tensor<7x2x1x1xf32>
-// CHECK: %[[C3:.+]] = arith.constant 3 : index
-// CHECK: %[[D3:.+]] = tensor.dim %[[INPUT]], %[[C3]]
-// CHECK: %[[OUTPUT:.+]] = tensor.empty(%[[D3]]) : tensor<1x7x4x?xf32>
-// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x2x4x?xf32> into tensor<2x?xf32>
-// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<7x2x1x1xf32> into tensor<7x2xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x7x4x?xf32> into tensor<7x?xf32>
-// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_FILTER]], %[[RESHAPED_INPUT]] : tensor<7x2xf32>, tensor<2x?xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<7x?xf32>)
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1], [2, 3]]
-// CHECK: util.return %[[RESULT]]
+// CHECK-LABEL: @dynamic_nchw_conv_2d
+//       CHECK:   %[[RESULT:.*]] = linalg.generic
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+//       CHECK:   util.return %[[RESULT]]
 
 // -----
 
-util.func public @fail_dynamic_nchw_conv_2d(%input: tensor<1x2x?x?xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x?x?xf32> {
-    %c2 = arith.constant 2 : index
-    %c3 = arith.constant 3 : index
-    %d2 = tensor.dim %input, %c2 : tensor<1x2x?x?xf32>
-    %d3 = tensor.dim %input, %c3 : tensor<1x2x?x?xf32>
+util.func public @strided_nchw_conv_2d(%input: tensor<1x2x?x?xf32>, %filter: tensor<7x2x1x1xf32>, %d2 : index, %d3 : index) -> tensor<1x7x?x?xf32> {
     %0 = tensor.empty(%d2, %d3) : tensor<1x7x?x?xf32>
     %1 = linalg.conv_2d_nchw_fchw {
         dilations = dense<1> : tensor<2xi64>,
-        strides = dense<1> : tensor<2xi64>
+        strides = dense<2> : tensor<2xi64>
     } ins(%input, %filter : tensor<1x2x?x?xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x?x?xf32>) -> tensor<1x7x?x?xf32>
     util.return %1 : tensor<1x7x?x?xf32>
 }
 
-// CHECK: @fail_dynamic_nchw_conv_2d
-// CHECK: linalg.conv_2d_nchw_fchw
+// CHECK-LABEL: @strided_nchw_conv_2d
+//       CHECK:   %[[RESULT:.*]] = linalg.generic
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 * 2, d3 * 2)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+//  CHECK-SAME:     affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+//       CHECK:   util.return %[[RESULT]]
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
index 5111152..f3f0f8a 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --mlir-print-local-scope --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s
 
 util.func public @generalize_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
@@ -34,3 +34,35 @@
 //       CHECK:     %[[ADD:.+]] = linalg.add
 //       CHECK:     flow.return %[[ADD]]
 //       CHECK:   util.return %[[DISPATCH]]
+
+// -----
+
+util.func public @generalize_1x1_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x?x7xf32> {
+    %c2 = arith.constant 2 : index
+    %d2 = tensor.dim %input, %c2 : tensor<1x4x?x2xf32>
+    %0 = tensor.empty(%d2) : tensor<1x4x?x7xf32>
+    %1 = linalg.conv_2d_nhwc_hwcf {
+        dilations = dense<1> : tensor<2xi64>,
+        strides = dense<1> : tensor<2xi64>
+    } ins(%input, %filter : tensor<1x4x?x2xf32>, tensor<1x1x2x7xf32>) outs(%0 : tensor<1x4x?x7xf32>) -> tensor<1x4x?x7xf32>
+    util.return %1 : tensor<1x4x?x7xf32>
+}
+
+// CHECK-LABEL: @generalize_1x1_nhwc_conv_2d
+//       CHECK:   %[[RESULT:.*]] = linalg.generic
+//       CHECK:   util.return %[[RESULT]]
+
+// -----
+
+util.func public @generalize_1x1_nchw_conv_2d(%input: tensor<1x2x4x5xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x4x5xf32> {
+    %0 = tensor.empty() : tensor<1x7x4x5xf32>
+    %1 = linalg.conv_2d_nchw_fchw {
+        dilations = dense<1> : tensor<2xi64>,
+        strides = dense<1> : tensor<2xi64>
+    } ins(%input, %filter : tensor<1x2x4x5xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x4x5xf32>) -> tensor<1x7x4x5xf32>
+    util.return %1 : tensor<1x7x4x5xf32>
+}
+
+// CHECK-LABEL: @generalize_1x1_nchw_conv_2d
+//       CHECK:   %[[RESULT:.*]] = linalg.generic
+//       CHECK:   util.return %[[RESULT]]