[flow] Generalize 1x1 Conv2D to matmul for NCHW (#10616)

Most models from PyTorch use NCHW Conv2D by default. To accelerate 1x1
filter Conv2D workloads with NCHW layout, we generalize the 1x1 filter
Conv2D to matmul pass so that it handles NCHW as well.

Perf boost on ResNet50 (PyTorch):
- CPU (Threadripper PRO 3995WX): 100ms -> 80ms
- Vulkan (Threadripper PRO 3995WX): 48ms -> 31.7ms
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
index c6a886e..8d1064b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -34,7 +34,7 @@
         "CaptureDispatchDynamicDims.cpp",
         "CleanupNumericNarrowing.cpp",
         "CleanupTensorShapes.cpp",
-        "ConvertConv2D1x1ToMatmul.cpp",
+        "Convert1X1FilterConv2DToMatmul.cpp",
         "ConvertConv2DToImg2Col.cpp",
         "ConvertLinalgMatmulToMmt4D.cpp",
         "ConvertRegionToWorkgroups.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 82c3bbb..ac62bf8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -32,7 +32,7 @@
     "CaptureDispatchDynamicDims.cpp"
     "CleanupNumericNarrowing.cpp"
     "CleanupTensorShapes.cpp"
-    "ConvertConv2D1x1ToMatmul.cpp"
+    "Convert1X1FilterConv2DToMatmul.cpp"
     "ConvertConv2DToImg2Col.cpp"
     "ConvertLinalgMatmulToMmt4D.cpp"
     "ConvertRegionToWorkgroups.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Convert1X1FilterConv2DToMatmul.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Convert1X1FilterConv2DToMatmul.cpp
new file mode 100644
index 0000000..6773942
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Convert1X1FilterConv2DToMatmul.cpp
@@ -0,0 +1,193 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+
+// Rewrites a unit-stride, unit-dilation, batch-1 linalg.conv_2d_nhwc_hwcf or
+// linalg.conv_2d_nchw_fchw op with a 1x1 filter into tensor.collapse_shape ops
+// feeding a linalg.matmul, followed by a tensor.expand_shape of the result.
+template <typename Conv2DOpType>
+class Convert1x1FilterConvToMatmul : public OpRewritePattern<Conv2DOpType> {
+ public:
+  using OpRewritePattern<Conv2DOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Conv2DOpType convOp,
+                                PatternRewriter &rewriter) const override {
+    auto inputShapeType = convOp.getInputOperand(0)
+                              ->get()
+                              .getType()
+                              .template dyn_cast<RankedTensorType>();
+    auto filterShapeType = convOp.getInputOperand(1)
+                               ->get()
+                               .getType()
+                               .template dyn_cast<RankedTensorType>();
+    auto outputShapeType = convOp.getOutputOperand(0)
+                               ->get()
+                               .getType()
+                               .template dyn_cast<RankedTensorType>();
+
+    const bool isNCHW = isa<linalg::Conv2DNchwFchwOp>(convOp);
+    const bool isNHWC = isa<linalg::Conv2DNhwcHwcfOp>(convOp);
+    if (!isNCHW && !isNHWC) return failure();
+
+    if (!inputShapeType || !filterShapeType || !outputShapeType)
+      return failure();
+
+    auto inputShape = inputShapeType.getShape();
+    auto filterShape = filterShapeType.getShape();
+    auto outputShape = outputShapeType.getShape();
+
+    // Per-layout dimension positions; the o* indices also index the input.
+    const int nIndex = 0;
+    const int kcIndex = isNHWC ? 2 : 1;
+    const int kfIndex = isNHWC ? 3 : 0;
+    const int khIndex = isNHWC ? 0 : 2;
+    const int kwIndex = isNHWC ? 1 : 3;
+    const int ohIndex = isNHWC ? 1 : 2;
+    const int owIndex = isNHWC ? 2 : 3;
+    const int ocIndex = isNHWC ? 3 : 1;
+
+    bool isInputHWDynamic = inputShape[ohIndex] == ShapedType::kDynamicSize &&
+                            inputShape[owIndex] == ShapedType::kDynamicSize;
+
+    // We cannot merge the width and height if they are both dynamic as we
+    // cannot expand them back to their dynamic values.
+    if (isInputHWDynamic) return failure();
+
+    if (filterShape[khIndex] != 1 || filterShape[kwIndex] != 1)
+      return failure();
+
+    // TODO(ataei): Support conversion to linalg.batch_matmul.
+    if (inputShape[nIndex] != 1) return failure();
+
+    if (!llvm::all_of(convOp.getStrides(), [](APInt element) {
+          return element.getSExtValue() == 1;
+        }))
+      return failure();
+    if (!llvm::all_of(convOp.getDilations(), [](APInt element) {
+          return element.getSExtValue() == 1;
+        }))
+      return failure();
+
+    auto combineDims = [](int64_t a, int64_t b) {
+      if (a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize)
+        return ShapedType::kDynamicSize;
+      return a * b;
+    };
+
+    SmallVector<ReassociationIndices, 4> reassociationInputOutputIndices;
+    SmallVector<ReassociationIndices, 4> reassociationFilterIndices;
+    SmallVector<int64_t> reshapedInputShape(2, 0);
+    SmallVector<int64_t> reshapedFilterShape(2, 0);
+    SmallVector<int64_t> reshapedOutputShape(2, 0);
+    if (isNHWC) {
+      // Collapse N,H,W into one dim (N == 1); filter KH,KW,C into one dim.
+      reassociationInputOutputIndices = {{nIndex, ohIndex, owIndex}, {ocIndex}};
+      reassociationFilterIndices = {{khIndex, kwIndex, kcIndex}, {kfIndex}};
+
+      // Matmul shapes: input (H*W)xC, filter CxF, output (H*W)xF.
+      reshapedInputShape = {
+          combineDims(inputShape[ohIndex], inputShape[owIndex]),
+          inputShape[ocIndex]};
+      reshapedFilterShape = {filterShape[kcIndex], filterShape[kfIndex]};
+      reshapedOutputShape = {
+          combineDims(outputShape[ohIndex], outputShape[owIndex]),
+          outputShape[ocIndex]};
+    } else if (isNCHW) {
+      // Collapse N,C (N == 1) and H,W; filter C,KH,KW into one dim.
+      reassociationInputOutputIndices = {{nIndex, ocIndex}, {ohIndex, owIndex}};
+      reassociationFilterIndices = {{kfIndex}, {kcIndex, khIndex, kwIndex}};
+
+      // Matmul shapes: input Cx(H*W), filter FxC, output Fx(H*W).
+      reshapedInputShape = {
+          inputShape[ocIndex],
+          combineDims(inputShape[ohIndex], inputShape[owIndex])};
+      reshapedFilterShape = {filterShape[kfIndex], filterShape[kcIndex]};
+      reshapedOutputShape = {
+          outputShape[ocIndex],
+          combineDims(outputShape[ohIndex], outputShape[owIndex])};
+    }
+
+    auto reshapedInputType = RankedTensorType::get(
+        reshapedInputShape, inputShapeType.getElementType());
+    auto reshapedFilterType = RankedTensorType::get(
+        reshapedFilterShape, filterShapeType.getElementType());
+    auto reshapedOutputType = RankedTensorType::get(
+        reshapedOutputShape, outputShapeType.getElementType());
+
+    Value input = convOp.getInputOperand(0)->get();
+    Value filter = convOp.getInputOperand(1)->get();
+    Value output = convOp.getOutputOperand(0)->get();
+    auto loc = convOp.getLoc();
+
+    Value reshapedInput = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedInputType, input, reassociationInputOutputIndices);
+    Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedFilterType, filter, reassociationFilterIndices);
+    Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+        loc, reshapedOutputType, output, reassociationInputOutputIndices);
+
+    SmallVector<Value, 2> matmulInput;
+    if (isNHWC) {
+      matmulInput = {reshapedInput, reshapedFilter};
+    } else if (isNCHW) {
+      matmulInput = {reshapedFilter, reshapedInput};
+    }
+    auto matmulResult = rewriter.create<linalg::MatmulOp>(
+        loc, reshapedOutputType, matmulInput, ArrayRef<Value>{reshapedOutput});
+
+    auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, outputShapeType, matmulResult.getResults()[0],
+        reassociationInputOutputIndices);
+
+    rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+    return success();
+  }
+};
+
+// Greedily applies the patterns above, which create linalg and tensor ops.
+struct Convert1X1FilterConv2DToMatmulPass
+    : public Convert1X1FilterConv2DToMatmulBase<
+          Convert1X1FilterConv2DToMatmulPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<Convert1x1FilterConvToMatmul<linalg::Conv2DNhwcHwcfOp>,
+                    Convert1x1FilterConvToMatmul<linalg::Conv2DNchwFchwOp>>(
+        &getContext());
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+}  // namespace
+
+// Creates the pass converting 1x1-filter Conv2D ops into linalg.matmul ops.
+std::unique_ptr<Pass> createConvert1X1FilterConv2DToMatmulPass() {
+  return std::make_unique<Convert1X1FilterConv2DToMatmulPass>();
+}
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmul.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmul.cpp
deleted file mode 100644
index ef45208..0000000
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmul.cpp
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace Flow {
-
-namespace {
-
-// Converts linalg.conv_2d_input_nhwc_filter_nhwc op to linalg.matmul
-class Convert1x1ConvolutionMatmulOp
-    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
- public:
-  using OpRewritePattern<linalg::Conv2DNhwcHwcfOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
-                                PatternRewriter &rewriter) const override {
-    RankedTensorType inputShapeType =
-        convOp.getInputOperand(0)->get().getType().dyn_cast<RankedTensorType>();
-    RankedTensorType filterShapeType =
-        convOp.getInputOperand(1)->get().getType().dyn_cast<RankedTensorType>();
-    RankedTensorType outputShapeType = convOp.getOutputOperand(0)
-                                           ->get()
-                                           .getType()
-                                           .dyn_cast<RankedTensorType>();
-
-    if (!inputShapeType || !filterShapeType || !outputShapeType)
-      return failure();
-
-    auto inputShape = inputShapeType.getShape();
-    auto filterShape = filterShapeType.getShape();
-    auto outputShape = outputShapeType.getShape();
-
-    bool inputDynWidthHeight = inputShape[1] == ShapedType::kDynamicSize &&
-                               inputShape[2] == ShapedType::kDynamicSize;
-
-    // We cannot merge the width and height if they are both dynamic as we
-    // cannot expand them back to their dynamic values.
-    if (inputDynWidthHeight) return failure();
-
-    if (filterShape[0] != 1 || filterShape[1] != 1) return failure();
-
-    // TODO(ataei): Support conversion to linalg.batch_matmul.
-    if (inputShape[0] != 1) return failure();
-
-    if (!llvm::all_of(convOp.getStrides(), [](APInt element) {
-          return element.getSExtValue() == 1;
-        }))
-      return failure();
-    if (!llvm::all_of(convOp.getDilations(), [](APInt element) {
-          return element.getSExtValue() == 1;
-        }))
-      return failure();
-
-    SmallVector<ReassociationIndices, 4> reassociationIndices = {{0, 1, 2},
-                                                                 {3}};
-
-    auto combineDims = [](int64_t a, int64_t b) {
-      if (a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize)
-        return ShapedType::kDynamicSize;
-      return a * b;
-    };
-
-    auto reshapedInputType = RankedTensorType::get(
-        {combineDims(inputShape[1], inputShape[2]), inputShape[3]},
-        inputShapeType.getElementType());
-
-    auto reshapedFilterType = RankedTensorType::get(
-        {filterShape[2], filterShape[3]}, filterShapeType.getElementType());
-
-    auto reshapedOutputType = RankedTensorType::get(
-        {combineDims(outputShape[1], outputShape[2]), outputShape[3]},
-        outputShapeType.getElementType());
-
-    Value input = convOp.getInputOperand(0)->get();
-    Value filter = convOp.getInputOperand(1)->get();
-    Value output = convOp.getOutputOperand(0)->get();
-    auto loc = convOp.getLoc();
-
-    Value reshapedInput = rewriter.create<tensor::CollapseShapeOp>(
-        loc, reshapedInputType, input, reassociationIndices);
-    Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
-        loc, reshapedFilterType, filter, reassociationIndices);
-    Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
-        loc, reshapedOutputType, output, reassociationIndices);
-
-    auto matmulResult = rewriter.create<linalg::MatmulOp>(
-        loc, reshapedOutputType, ArrayRef<Value>{reshapedInput, reshapedFilter},
-        ArrayRef<Value>{reshapedOutput});
-
-    auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
-        loc, outputShapeType, matmulResult.getResults()[0],
-        reassociationIndices);
-
-    rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
-
-    return success();
-  }
-};
-
-struct ConvertConv2D1x1ConvToMatmulPass
-    : public ConvertConv2D1x1ConvToMatmulBase<
-          ConvertConv2D1x1ConvToMatmulPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    RewritePatternSet patterns(&getContext());
-    patterns.insert<Convert1x1ConvolutionMatmulOp>(context);
-    if (failed(applyPatternsAndFoldGreedily(getOperation(),
-                                            std::move(patterns)))) {
-      return signalPassFailure();
-    }
-  }
-};
-}  // namespace
-
-std::unique_ptr<Pass> createConvertConv2D1x1ToMatmulPass() {
-  return std::make_unique<ConvertConv2D1x1ConvToMatmulPass>();
-}
-
-}  // namespace Flow
-}  // namespace IREE
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 11ffd5e..84540d4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -202,7 +202,7 @@
 
   // Special case peephole optimizations.
   FunctionLikeNest(passManager)
-      .addPass(IREE::Flow::createConvertConv2D1x1ToMatmulPass)
+      .addPass(IREE::Flow::createConvert1X1FilterConv2DToMatmulPass)
       .addPredicatedPass(clEnableConvToImg2Col,
                          IREE::Flow::createConvertConv2DToImg2ColPass)
       .addPredicatedPass(
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 71a1226..b11318b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -73,7 +73,7 @@
 
 // Creates a pass to convert linalg convolution ops with 1x1 kernels into
 // linalg.matmul
-std::unique_ptr<Pass> createConvertConv2D1x1ToMatmulPass();
+std::unique_ptr<Pass> createConvert1X1FilterConv2DToMatmulPass();
 
 // Creates a pass to convert linalg convolution ops into linalg.matmul ops
 // using im2col tranformation.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 3c4f6f8..cf008a2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -26,10 +26,10 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createCleanupNumericNarrowingPass()";
 }
 
-def ConvertConv2D1x1ConvToMatmul :
-    Pass<"iree-flow-convert-conv2d-1x1-to-matmul", ""> {
+def Convert1X1FilterConv2DToMatmul:
+    Pass<"iree-flow-convert-1x1-filter-conv2d-to-matmul", ""> {
   let summary = "Convert linalg convolution ops with 1x1 kernels into linalg matrix multiplication ops.";
-  let constructor = "mlir::iree_compiler::IREE::Flow::createConvertConv2D1x1ToMatmulPass()";
+  let constructor = "mlir::iree_compiler::IREE::Flow::createConvert1X1FilterConv2DToMatmulPass()";
 }
 
 def ConvertConv2DToImg2Col :
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/conv1x1_to_matmul.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/conv1x1_to_matmul.mlir
index 3956456..ea2b897 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/conv1x1_to_matmul.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/conv1x1_to_matmul.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-opt --split-input-file -iree-flow-convert-conv2d-1x1-to-matmul %s | FileCheck %s
+// RUN: iree-opt --split-input-file -iree-flow-convert-1x1-filter-conv2d-to-matmul %s | FileCheck %s
 
-func.func @conv_2d_1x1(%input: tensor<1x4x5x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x5x7xf32> {
+func.func @nhwc_conv_2d(%input: tensor<1x4x5x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x5x7xf32> {
     %0 = linalg.init_tensor [1, 4, 5, 7] : tensor<1x4x5x7xf32>
     %1 = linalg.conv_2d_nhwc_hwcf {
         dilations = dense<1> : tensor<2xi64>,
@@ -8,21 +8,22 @@
     } ins(%input, %filter : tensor<1x4x5x2xf32>, tensor<1x1x2x7xf32>) outs(%0 : tensor<1x4x5x7xf32>) -> tensor<1x4x5x7xf32>
     return %1 : tensor<1x4x5x7xf32>
 }
-// CHECK: @conv_2d_1x1
+
+// CHECK: @nhwc_conv_2d
 // CHECK: %[[INPUT:.+]]: tensor<1x4x5x2xf32>
 // CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32>
-// CHECK: %[[OTUPUT:.+]] = linalg.init_tensor [1, 4, 5, 7] : tensor<1x4x5x7xf32>
+// CHECK: %[[OUTPUT:.+]] = linalg.init_tensor [1, 4, 5, 7] : tensor<1x4x5x7xf32>
 // CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
 // CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OTUPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
+// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
 // CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<20x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<20x7xf32>)
 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
 // CHECK: return %[[RESULT]]
 
 // -----
 
-// CHECK: @conv_2d_1x1_dyn
-func.func @conv_2d_1x1_dyn(%input: tensor<1x4x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x?x7xf32> {
+// CHECK: @dynamic_nhwc_conv_2d
+func.func @dynamic_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x?x7xf32> {
     %c2 = arith.constant 2 : index
     %d2 = tensor.dim %input, %c2 : tensor<1x4x?x2xf32>
     %0 = linalg.init_tensor [1, 4, %d2, 7] : tensor<1x4x?x7xf32>
@@ -37,16 +38,16 @@
 // CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32>
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[D2:.+]] = tensor.dim %[[INPUT]], %[[C2]]
-// CHECK: %[[OTUPUT:.+]] = linalg.init_tensor [1, 4, %[[D2]], 7] : tensor<1x4x?x7xf32>
+// CHECK: %[[OUTPUT:.+]] = linalg.init_tensor [1, 4, %[[D2]], 7] : tensor<1x4x?x7xf32>
 // CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x2xf32> into tensor<?x2xf32>
 // CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OTUPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x7xf32> into tensor<?x7xf32>
+// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x7xf32> into tensor<?x7xf32>
 // CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<?x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<?x7xf32>)
 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<?x7xf32> into tensor<1x4x?x7xf32>
 
 // -----
 
-func.func @conv_2d_1x1_fail_dyn(%input: tensor<1x?x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x?x?x7xf32> {
+func.func @fail_dynamic_nhwc_conv_2d(%input: tensor<1x?x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x?x?x7xf32> {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
     %d1 = tensor.dim %input, %c1 : tensor<1x?x?x2xf32>
@@ -59,5 +60,70 @@
     return %1 : tensor<1x?x?x7xf32>
 }
 
-// CHECK: @conv_2d_1x1_fail_dyn
+// CHECK: @fail_dynamic_nhwc_conv_2d
 // CHECK: linalg.conv_2d_nhwc_hwcf
+
+// -----
+
+func.func @nchw_conv_2d(%input: tensor<1x2x4x5xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x4x5xf32> {
+    %0 = linalg.init_tensor [1, 7, 4, 5] : tensor<1x7x4x5xf32>
+    %1 = linalg.conv_2d_nchw_fchw {
+        dilations = dense<1> : tensor<2xi64>,
+        strides = dense<1> : tensor<2xi64>
+    } ins(%input, %filter : tensor<1x2x4x5xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x4x5xf32>) -> tensor<1x7x4x5xf32>
+    return %1 : tensor<1x7x4x5xf32>
+}
+// CHECK: @nchw_conv_2d
+// CHECK: %[[INPUT:.+]]: tensor<1x2x4x5xf32>
+// CHECK: %[[FILTER:.+]]: tensor<7x2x1x1xf32>
+// CHECK: %[[OUTPUT:.+]] = linalg.init_tensor [1, 7, 4, 5] : tensor<1x7x4x5xf32>
+// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x2x4x5xf32> into tensor<2x20xf32>
+// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<7x2x1x1xf32> into tensor<7x2xf32>
+// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x7x4x5xf32> into tensor<7x20xf32>
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_FILTER]], %[[RESHAPED_INPUT]] : tensor<7x2xf32>, tensor<2x20xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<7x20xf32>)
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1], [2, 3]] : tensor<7x20xf32> into tensor<1x7x4x5xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @dynamic_nchw_conv_2d(%input: tensor<1x2x4x?xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x4x?xf32> {
+    %c3 = arith.constant 3 : index
+    %d3 = tensor.dim %input, %c3 : tensor<1x2x4x?xf32>
+    %0 = linalg.init_tensor [1, 7, 4, %d3] : tensor<1x7x4x?xf32>
+    %1 = linalg.conv_2d_nchw_fchw {
+        dilations = dense<1> : tensor<2xi64>,
+        strides = dense<1> : tensor<2xi64>
+    } ins(%input, %filter : tensor<1x2x4x?xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x4x?xf32>) -> tensor<1x7x4x?xf32>
+    return %1 : tensor<1x7x4x?xf32>
+}
+
+// CHECK: @dynamic_nchw_conv_2d
+// CHECK: %[[INPUT:.+]]: tensor<1x2x4x?xf32>
+// CHECK: %[[FILTER:.+]]: tensor<7x2x1x1xf32>
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[D3:.+]] = tensor.dim %[[INPUT]], %[[C3]]
+// CHECK: %[[OUTPUT:.+]] = linalg.init_tensor [1, 7, 4, %[[D3]]] : tensor<1x7x4x?xf32>
+// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x2x4x?xf32> into tensor<2x?xf32>
+// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<7x2x1x1xf32> into tensor<7x2xf32>
+// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x7x4x?xf32> into tensor<7x?xf32>
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_FILTER]], %[[RESHAPED_INPUT]] : tensor<7x2xf32>, tensor<2x?xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<7x?xf32>)
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1], [2, 3]] : tensor<7x?xf32> into tensor<1x7x4x?xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @fail_dynamic_nchw_conv_2d(%input: tensor<1x2x?x?xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x?x?xf32> {
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %d2 = tensor.dim %input, %c2 : tensor<1x2x?x?xf32>
+    %d3 = tensor.dim %input, %c3 : tensor<1x2x?x?xf32>
+    %0 = linalg.init_tensor [1, 7, %d2, %d3] : tensor<1x7x?x?xf32>
+    %1 = linalg.conv_2d_nchw_fchw {
+        dilations = dense<1> : tensor<2xi64>,
+        strides = dense<1> : tensor<2xi64>
+    } ins(%input, %filter : tensor<1x2x?x?xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x?x?xf32>) -> tensor<1x7x?x?xf32>
+    return %1 : tensor<1x7x?x?xf32>
+}
+
+// CHECK: @fail_dynamic_nchw_conv_2d
+// CHECK: linalg.conv_2d_nchw_fchw