Adapt the mhlo.conv-to-Linalg-on-buffers patterns to work on tensors. (#4748)

Lowers mhlo.conv to static-rank Linalg conv ops; the general linalg.conv op only
works on memrefs. This limits conv support to a specific data layout and to a
few ranks, i.e., 1-D, 2-D, and 3-D. The lowering is split between normal conv
and depthwise conv: the former now lowers to Linalg on tensors, while the
latter keeps its existing behavior, i.e., lowering to Linalg on buffers.

Also adds support for the TC conv ops to the later GPU passes.

Depends on https://reviews.llvm.org/D96038
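
For illustration, a minimal before/after sketch of the new tensor-based path,
adapted from the updated conv.mlir test in this diff (the op name and attribute
syntax assume the named conv ops introduced in D96038; dimension_numbers elided):

  // Input: a padding-free mhlo.convolution with canonical dimension numbers,
  // %arg0 : tensor<1x4x5x2xf32> (NHWC input), %arg1 : tensor<2x2x2x3xf32> (HWCF filter).
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = { ... },
    feature_group_count = 1 : i64,
    padding = dense<0> : tensor<2x2xi64>,
    rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>

  // Output: init_tensor + fill of the result tensor, then a static-rank named conv.
  %zero = constant 0.000000e+00 : f32
  %init = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32>
  %fill = linalg.fill(%init, %zero) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32>
  %1 = linalg.conv_2d_input_nhwc_filter_hwcf
         {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
          ins(%arg0, %arg1 : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
         outs(%fill : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>
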
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 1086851..d972351 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -325,13 +325,6 @@
     ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
   if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
 
-  llvm::SmallVector<Attribute, 4> strides;
-  if (auto windowStrides = op.window_strides()) {
-    auto range = windowStrides->getAttributeValues();
-    strides.append(range.begin(), range.end());
-  }
-  auto stridesArg = ArrayAttr::get(op.getContext(), strides);
-
   // TODO(ataei): Only support dilated convolution for now. We need to consider
   // LHS dilation for deconvolution cases.
   llvm::SmallVector<Attribute, 4> dilation;
@@ -339,7 +332,6 @@
     auto range = rhsDilation->getAttributeValues();
     dilation.append(range.begin(), range.end());
   }
-  auto dilationArg = ArrayAttr::get(op.getContext(), dilation);
 
   // Set padding only if it is non-zero.
   DenseIntElementsAttr padding = op.paddingAttr();
@@ -361,92 +353,88 @@
       shape[op.dimension_numbers().kernel_input_feature_dimension().getInt()];
   auto groupSize =
       shape[op.dimension_numbers().kernel_output_feature_dimension().getInt()];
-  // Depthwise conv path...
-  if (op.feature_group_count() > 1u && op.feature_group_count() == numGroups) {
-    // Lowering depthwise convolution to linalg.generic op. The idea is to use
-    // the group convolution formulation to perform the separable depthwise
-    // convolution as the following, given an n-dimensional input x and filter w
-    // the direct convolution operation can be written as:
-    //  y[n, d1, d2, ....dn, ci * groupSize + co] = sum(k1, k2, ....kn,
-    // x[n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn]
-    // * w[k1, k2, ...kn, ci, co])
-
-    // TODO(ataei): Support dilation.
-    if (llvm::any_of(dilation, [](Attribute attr) {
-          return (attr.dyn_cast<IntegerAttr>().getInt() != 1);
-        })) {
-      return failure();
-    }
-
-    SmallVector<AffineExpr, 4> inputExprs;
-    SmallVector<AffineExpr, 4> filterExprs;
-    SmallVector<AffineExpr, 4> outputExprs;
-
-    const auto spatialDims =
-        llvm::size(op.dimension_numbers().input_spatial_dimensions());
-    const int d1Index = 1;
-    const int coIndex = d1Index + spatialDims;
-    const int ciIndex = coIndex + 1;
-    const int k1Index = ciIndex + 1;
-    // n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn
-    inputExprs.push_back(rewriter.getAffineDimExpr(0));
-    for (int i = 0; i < spatialDims; ++i) {
-      if (op.window_stridesAttr()) {
-        auto stride = op.window_stridesAttr().getValue<APInt>(i);
-        inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
-                                 stride.getZExtValue() +
-                             rewriter.getAffineDimExpr(k1Index + i));
-      } else {
-        inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
-                             rewriter.getAffineDimExpr(k1Index + i));
-      }
-    }
-    inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
-
-    // k1, k2, ...kn, ci, co
-    for (int i = 0; i < spatialDims; ++i) {
-      filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
-    }
-    filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
-    filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));
-
-    // n, d1, d2, ....dn, ci * groupSize + co
-    outputExprs.push_back(rewriter.getAffineDimExpr(0));
-    for (int i = 0; i < spatialDims; ++i) {
-      outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
-    }
-    outputExprs.push_back(rewriter.getAffineDimExpr(ciIndex) * groupSize +
-                          rewriter.getAffineDimExpr(coIndex));
-
-    // nloops = |d| + |k| + |{n, ci, co}|
-    int nloops = spatialDims * 2 + 3;
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.emplace_back(AffineMap::get(
-        nloops, /*symbolCount=*/0, inputExprs, rewriter.getContext()));
-    indexingMaps.emplace_back(AffineMap::get(
-        nloops, /*symbolCount=*/0, filterExprs, rewriter.getContext()));
-    indexingMaps.emplace_back(AffineMap::get(
-        nloops, /*symbolCount=*/0, outputExprs, rewriter.getContext()));
-
-    Location loc = op.getLoc();
-
-    SmallVector<StringRef, 3> loopAttributeTypes(spatialDims + 3, "parallel");
-    loopAttributeTypes.append(spatialDims, "reduction");
-    rewriter.create<linalg::GenericOp>(
-        loc,
-        /*resultTensorTypes=*/ArrayRef<Type>{},
-        /*inputs=*/inputBuffers,
-        /*outputs=*/resultBuffers, indexingMaps, loopAttributeTypes,
-        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
-          Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
-          nestedBuilder.create<linalg::YieldOp>(loc, add);
-        });
-  } else {
-    rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1],
-                                    inputBuffers[0], resultBuffers[0],
-                                    stridesArg, dilationArg, padding);
+  if (op.feature_group_count() <= 1u || op.feature_group_count() != numGroups) {
+    return failure();
   }
+  // Lowering depthwise convolution to linalg.generic op. The idea is to use
+  // the group convolution formulation to perform the separable depthwise
+  // convolution as the following, given an n-dimensional input x and filter w
+  // the direct convolution operation can be written as:
+  //  y[n, d1, d2, ....dn, ci * groupSize + co] = sum(k1, k2, ....kn,
+  // x[n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn]
+  // * w[k1, k2, ...kn, ci, co])
+
+  // TODO(ataei): Support dilation.
+  if (llvm::any_of(dilation, [](Attribute attr) {
+        return (attr.dyn_cast<IntegerAttr>().getInt() != 1);
+      })) {
+    return failure();
+  }
+
+  SmallVector<AffineExpr, 4> inputExprs;
+  SmallVector<AffineExpr, 4> filterExprs;
+  SmallVector<AffineExpr, 4> outputExprs;
+
+  const auto spatialDims =
+      llvm::size(op.dimension_numbers().input_spatial_dimensions());
+  const int d1Index = 1;
+  const int coIndex = d1Index + spatialDims;
+  const int ciIndex = coIndex + 1;
+  const int k1Index = ciIndex + 1;
+  // n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn
+  inputExprs.push_back(rewriter.getAffineDimExpr(0));
+  for (int i = 0; i < spatialDims; ++i) {
+    if (op.window_stridesAttr()) {
+      auto stride = op.window_stridesAttr().getValue<APInt>(i);
+      inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
+                               stride.getZExtValue() +
+                           rewriter.getAffineDimExpr(k1Index + i));
+    } else {
+      inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
+                           rewriter.getAffineDimExpr(k1Index + i));
+    }
+  }
+  inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
+
+  // k1, k2, ...kn, ci, co
+  for (int i = 0; i < spatialDims; ++i) {
+    filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
+  }
+  filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
+  filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));
+
+  // n, d1, d2, ....dn, ci * groupSize + co
+  outputExprs.push_back(rewriter.getAffineDimExpr(0));
+  for (int i = 0; i < spatialDims; ++i) {
+    outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
+  }
+  outputExprs.push_back(rewriter.getAffineDimExpr(ciIndex) * groupSize +
+                        rewriter.getAffineDimExpr(coIndex));
+
+  // nloops = |d| + |k| + |{n, ci, co}|
+  int nloops = spatialDims * 2 + 3;
+  SmallVector<AffineMap, 4> indexingMaps;
+  indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+                                           inputExprs, rewriter.getContext()));
+  indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+                                           filterExprs, rewriter.getContext()));
+  indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+                                           outputExprs, rewriter.getContext()));
+
+  Location loc = op.getLoc();
+
+  SmallVector<StringRef, 3> loopAttributeTypes(spatialDims + 3, "parallel");
+  loopAttributeTypes.append(spatialDims, "reduction");
+  rewriter.create<linalg::GenericOp>(
+      loc,
+      /*resultTensorTypes=*/ArrayRef<Type>{},
+      /*inputs=*/inputBuffers,
+      /*outputs=*/resultBuffers, indexingMaps, loopAttributeTypes,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
+        Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
+        nestedBuilder.create<linalg::YieldOp>(loc, add);
+      });
   return success();
 }
 
@@ -838,20 +826,26 @@
   }
 };
 
-/// Converts linalg.matmul-ish on tensors to linalg.matmul-ish on buffers.
+/// Converts a linalg named op on tensors to a linalg named op on buffers.
 template <typename LinalgOpTy>
-struct MatmulOnTensorConversion
-    : public ConvertToLinalgBufferOp<MatmulOnTensorConversion<LinalgOpTy>,
+struct NamedOpConversion
+    : public ConvertToLinalgBufferOp<NamedOpConversion<LinalgOpTy>,
                                      LinalgOpTy> {
-  using ConvertToLinalgBufferOp<MatmulOnTensorConversion<LinalgOpTy>,
+  using ConvertToLinalgBufferOp<NamedOpConversion<LinalgOpTy>,
                                 LinalgOpTy>::ConvertToLinalgBufferOp;
   LogicalResult apply(LinalgOpTy op, ArrayRef<Value> inputBuffers,
                       ArrayRef<Value> resultBuffers,
                       ConversionPatternRewriter &rewriter) const {
     if (!op.hasTensorSemantics()) return failure();
-    // The last one is a init tensor.
-    rewriter.create<LinalgOpTy>(
-        op.getLoc(), inputBuffers.drop_back(op.getNumResults()), resultBuffers);
+    auto linalgOp = cast<linalg::LinalgOp>(op.getOperation());
+    SmallVector<Value, 8> newOperands;
+    newOperands.append(inputBuffers.begin(),
+                       inputBuffers.end() - op.getNumResults());
+    newOperands.append(resultBuffers.begin(), resultBuffers.end());
+    auto otherOperands = linalgOp.getAssumedNonShapedOperands();
+    newOperands.append(otherOperands.begin(), otherOperands.end());
+    Location loc = op.getLoc();
+    linalgOp.clone(rewriter, loc, /*resultTypes=*/TypeRange{}, newOperands);
     return success();
   }
 };
@@ -1244,12 +1238,16 @@
 void populateHLOToLinalgOnBuffersConversionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns,
     TensorToBufferMap const &resultTensorToBufferMap) {
-  patterns.insert<ConvOpConversion, ConcatenateOpConversion,
-                  FillOpOnTensorConversion, InitTensorOpConversion,
+  patterns.insert<ConvOpConversion, DepthwiseConvOpConversion,
+                  ConcatenateOpConversion, FillOpOnTensorConversion,
+                  InitTensorOpConversion,
                   LinalgOpOnTensorConversion<linalg::GenericOp>,
                   LinalgOpOnTensorConversion<linalg::IndexedGenericOp>,
-                  MatmulOnTensorConversion<linalg::MatmulOp>,
-                  MatmulOnTensorConversion<linalg::BatchMatmulOp>,
+                  NamedOpConversion<linalg::ConvInputNWCFilterWCFOp>,
+                  NamedOpConversion<linalg::ConvInputNHWCFilterHWCFOp>,
+                  NamedOpConversion<linalg::ConvInputNDHWCFilterDHWCFOp>,
+                  NamedOpConversion<linalg::MatmulOp>,
+                  NamedOpConversion<linalg::BatchMatmulOp>,
                   PadTensorOpConversion, ReduceWindowOpConversion,
                   SubTensorOpConversion, TensorReshapeOpConversion>(
       context, resultTensorToBufferMap);
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index d088a14..c145e90 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -251,6 +251,156 @@
 };
 }  // namespace
 
+//===----------------------------------------------------------------------===//
+// mhlo.conv conversion patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+static bool isDepthwiseConv(mhlo::ConvOp op) {
+  auto shape = op.rhs().getType().cast<ShapedType>().getShape();
+  auto numGroups =
+      shape[op.dimension_numbers().kernel_input_feature_dimension().getInt()];
+  return op.feature_group_count() > 1u && op.feature_group_count() == numGroups;
+}
+
+/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
+/// follows a canonical form:
+///
+/// * Input dimensions have order: (batch_count, spatial_dims,
+///   input_channel_count).
+/// * Filter dimensions have order: (spatial_dims, input_channel_count,
+///   output_channel_count).
+/// * Output dimensions have order: (batch_count, spatial_dims,
+///   output_channel_count).
+static bool hasCanonicalDimensionNumbers(
+    const mhlo::ConvDimensionNumbers &dimensionNumbers) {
+  const int inputSpatialRank =
+      llvm::size(dimensionNumbers.input_spatial_dimensions());
+  // The dimensions for input should follow the order of
+  // batch_count, spatial_dims..., input_feature_count.
+  if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
+      dimensionNumbers.input_feature_dimension().getInt() !=
+          (inputSpatialRank + 1)) {
+    return false;
+  }
+
+  const int kernelSpatialRank =
+      llvm::size(dimensionNumbers.kernel_spatial_dimensions());
+  // The dimensions for filter should follow the order of
+  // spatial_dims..., input_feature_count, output_feature_count.
+  if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
+          kernelSpatialRank ||
+      dimensionNumbers.kernel_output_feature_dimension().getInt() !=
+          (kernelSpatialRank + 1)) {
+    return false;
+  }
+
+  const int outputSpatialRank =
+      llvm::size(dimensionNumbers.output_spatial_dimensions());
+  // The dimensions for output should follow the order of
+  // batch_count, spatial_dims.., output_feature_count.
+  if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
+      dimensionNumbers.output_feature_dimension().getInt() !=
+          (outputSpatialRank + 1)) {
+    return false;
+  }
+
+  if (inputSpatialRank != outputSpatialRank ||
+      inputSpatialRank != kernelSpatialRank) {
+    return false;
+  }
+
+  auto inputSpatialDim = dimensionNumbers.input_spatial_dimensions().begin();
+  auto kernelSpatialDim = dimensionNumbers.kernel_spatial_dimensions().begin();
+  auto outputSpatialDim = dimensionNumbers.output_spatial_dimensions().begin();
+  // Check that the spatial dims are ordered correctly.
+  for (int i = 0; i < inputSpatialRank; ++i) {
+    const int dim = i + 1;
+    if ((*inputSpatialDim++).getZExtValue() != dim ||
+        (*outputSpatialDim++).getZExtValue() != dim ||
+        (*kernelSpatialDim++).getZExtValue() != i) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+/// Converts mhlo.conv operation to linalg named op. This only covers normal
+/// convolution cases. The op must have canonical dimension numbers. Depthwise
+/// convolution and pointwise convolution are not handled in the conversion.
+struct NormalConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
+  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ConvOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
+    if (isDepthwiseConv(op)) return failure();
+
+    mhlo::ConvOp::Adaptor adaptor(args);
+    Location loc = op.getLoc();
+    Value input = adaptor.lhs();
+    Value filter = adaptor.rhs();
+    auto resultType = op.getResult().getType().cast<ShapedType>();
+    int rank = resultType.getRank();
+
+    // Check if padding is zero.
+    DenseIntElementsAttr padding = op.paddingAttr();
+    if (padding &&
+        (!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
+      return rewriter.notifyMatchFailure(op, "expected no padding");
+    }
+
+    // The output shape is N spatial_dims F.
+    SmallVector<Value, 8> dynSizes;
+    for (int i = 0, e = rank - 1; i < e; ++i) {
+      if (!resultType.isDynamicDim(i)) continue;
+      dynSizes.push_back(rewriter.create<DimOp>(loc, input, i));
+    }
+    if (resultType.isDynamicDim(rank - 1)) {
+      dynSizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
+    }
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, dynSizes, resultType.getShape(), resultType.getElementType());
+    auto zeroAttr = rewriter.getZeroAttr(resultType.getElementType());
+    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
+    Value zeroTensor =
+        rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
+    linalg::LinalgOp res;
+    Attribute strides = op.window_stridesAttr();
+    // TODO(ataei): Only kernel (RHS) dilation is supported for now. We need to
+    // consider input (LHS) dilation for deconvolution cases.
+    Attribute dilations = op.rhs_dilationAttr();
+    switch (rank) {
+      case 3: {
+        res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
+            loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+            dilations, strides);
+        break;
+      }
+      case 4: {
+        res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
+            loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+            dilations, strides);
+        break;
+      }
+      case 5: {
+        res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
+            loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+            dilations, strides);
+        break;
+      }
+      default:
+        return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
+    }
+    rewriter.replaceOp(op, res.getOperation()->getResults());
+    return success();
+  }
+};
+}  // namespace
+
 struct ConvertHLOToLinalgOnTensorsPass
     : public PassWrapper<ConvertHLOToLinalgOnTensorsPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -295,7 +445,8 @@
     MLIRContext *context, OwningRewritePatternList &patterns) {
   mhlo::populateHLOToLinalgConversionPattern(context, &patterns);
   patterns.insert<TorchIndexSelectOpConversion, SliceOpConversion,
-                  ConstOpConversion, PadOpConversion>(context);
+                  ConstOpConversion, PadOpConversion, NormalConvOpConversion>(
+      context);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnTensorsPass() {
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir b/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
index 4006f02..008b380 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
@@ -1,123 +1,161 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+// RUN: iree-opt -iree-codegen-hlo-to-linalg-on-tensors -canonicalize %s | IreeFileCheck %s
 
-module {
-  // CHECK-LABEL: func @conv
-  func @conv() {
-    %c0 = constant 0 : index
-    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<3x5x5x3xf32>
-    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x3x4xf32>
-    //      CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) {
-    // CHECK-SAME: dilations = [1, 2]
-    // CHECK-SAME: padding = dense<[
-    // CHECK-SAME:                  [0, 1], [0, 1]]> : tensor<2x2xi64>
-    // CHECK-SAME: strides = [2, 1]}
-    %2 = "mhlo.convolution"(%1, %0) {
-      batch_group_count = 1 : i64,
-      dimension_numbers = {
-        input_batch_dimension = 0 : i64,
-        input_feature_dimension = 3 : i64,
-        input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-        kernel_input_feature_dimension = 2 : i64,
-        kernel_output_feature_dimension = 3 : i64,
-        kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-        output_batch_dimension = 0 : i64,
-        output_feature_dimension = 3 : i64,
-        output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
-      },
-      feature_group_count = 1 : i64,
-      padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
-      rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
-      window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
-    hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<3x5x5x4xf32>
-    return
-  }
-  hal.interface @legacy_io attributes {sym_visibility = "private"} {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
-  }
+// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG:     %[[C0:.+]] = constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
+// CHECK-DAG:     %[[C2:.+]] = constant 2 : index
+// CHECK-DAG:     %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+// CHECK:         %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK:         %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
+// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK:         linalg.conv_1d_input_nwc_filter_wcf
+// CHECK-SAME:      {dilations = dense<1> : tensor<1xi64>
+// CHECK-SAME:       strides = dense<1> : tensor<1xi64>}
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SAME:    outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
+  -> tensor<?x?x?xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 2 : i64,
+      input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+      kernel_input_feature_dimension = 1 : i64,
+      kernel_output_feature_dimension = 2 : i64,
+      kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 2 : i64,
+      output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<[[0], [0]]> : tensor<2x1xi64>,
+    rhs_dilation = dense<1> : tensor<1xi64>,
+    window_strides = dense<1> : tensor<1xi64>
+  } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
 }
 
-// -----
-
-module {
-  func @depthwise_conv() {
-    %c0 = constant 0 : index
-    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x4x5x2xf32>
-    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x2x3xf32>
-    %2 = "mhlo.convolution"(%0, %1) {
-      batch_group_count = 1 : i64,
-      dimension_numbers = {
-        input_batch_dimension = 0 : i64,
-        input_feature_dimension = 3 : i64,
-        input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-        kernel_input_feature_dimension = 2 : i64,
-        kernel_output_feature_dimension = 3 : i64,
-        kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-        output_batch_dimension = 0 : i64,
-        output_feature_dimension = 3 : i64,
-        output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
-      },
-      feature_group_count = 2 : i64,
-      padding = dense<0> : tensor<2x2xi64>,
-      rhs_dilation = dense<1> : tensor<2xi64>,
-      window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
-    hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<2x3x4x6xf32>
-    return
-  }
-  hal.interface @legacy_io attributes {sym_visibility = "private"} {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
-  }
+// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG:     %[[C0:.+]] = constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
+// CHECK-DAG:     %[[C2:.+]] = constant 2 : index
+// CHECK-DAG:     %[[C3:.+]] = constant 3 : index
+// CHECK-DAG:     %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK:         linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME:      {dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME:       strides = dense<1> : tensor<2xi64>}
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SAME:    outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
+  -> tensor<?x?x?x?xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
 }
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4 * 3 + d3)>
-// CHECK: func @depthwise_conv()
-// CHECK: linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]
-// CHECK-SAME:   ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
-// CHECK-SAME:   outs(%{{[a-z0-9]*}} : memref<2x3x4x6xf32>)
-// CHECK: mulf
-// CHECK: addf
 
-// -----
-
-module {
-  func @depthwise_conv_multiplier_1() {
-    %c0 = constant 0 : index
-    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x113x113x96xf32>
-    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<3x3x1x96xf32>
-    %2 = "mhlo.convolution"(%0, %1) {
-      batch_group_count = 1 : i64,
-      dimension_numbers = {
-        input_batch_dimension = 0 : i64,
-        input_feature_dimension = 3 : i64,
-        input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-        kernel_input_feature_dimension = 2 : i64,
-        kernel_output_feature_dimension = 3 : i64,
-        kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-        output_batch_dimension = 0 : i64,
-        output_feature_dimension = 3 : i64,
-        output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
-      },
-      feature_group_count = 96 : i64,
-      padding = dense<0> : tensor<2x2xi64>,
-      rhs_dilation = dense<1> : tensor<2xi64>,
-      window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
-    hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<1x56x56x96xf32>
-    return
-  }
-  hal.interface @legacy_io attributes {sym_visibility = "private"} {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
-  }
+// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG:     %[[C0:.+]] = constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
+// CHECK-DAG:     %[[C2:.+]] = constant 2 : index
+// CHECK-DAG:     %[[C3:.+]] = constant 3 : index
+// CHECK-DAG:     %[[C4:.+]] = constant 4 : index
+// CHECK-DAG:     %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK:         %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
+// CHECK:         %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
+// CHECK:         %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
+// CHECK:         %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
+// CHECK:         %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
+// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK:         linalg.conv_3d_input_ndhwc_filter_dhwcf
+// CHECK-SAME:      {dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME:       strides = dense<1> : tensor<3xi64>}
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SAME:    outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
+  -> tensor<?x?x?x?x?xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 4 : i64,
+      input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+      kernel_input_feature_dimension = 3 : i64,
+      kernel_output_feature_dimension = 4 : i64,
+      kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 4 : i64,
+      output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
+    rhs_dilation = dense<1> : tensor<3xi64>,
+    window_strides = dense<1> : tensor<3xi64>
+  } : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
 }
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK: func @depthwise_conv_multiplier_1()
-// CHECK: linalg.fill
-// CHECK: %[[FILTER:.+]] = linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] : memref<3x3x1x96xf32> into memref<3x3x96xf32>
-// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%{{.+}}, %[[FILTER]] : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%{{.+}} : memref<1x56x56x96xf32>)
+
+// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK:         %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32>
+// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32>
+// CHECK:         linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME:      {dilations = dense<[2, 1]> : tensor<2xi64>
+// CHECK-SAME:       strides = dense<1> : tensor<2xi64>}
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
+// CHECK-SAME:    outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>
+func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>)
+  -> tensor<1x2x4x3xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<0> : tensor<2x2xi64>,
+    rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
+  return %0 : tensor<1x2x4x3xf32>
+}
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir b/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir
new file mode 100644
index 0000000..d68f749
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir
@@ -0,0 +1,83 @@
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+
+module {
+  func @depthwise_conv() {
+    %c0 = constant 0 : index
+    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x4x5x2xf32>
+    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x2x3xf32>
+    %2 = "mhlo.convolution"(%0, %1) {
+      batch_group_count = 1 : i64,
+      dimension_numbers = {
+        input_batch_dimension = 0 : i64,
+        input_feature_dimension = 3 : i64,
+        input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+        kernel_input_feature_dimension = 2 : i64,
+        kernel_output_feature_dimension = 3 : i64,
+        kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+        output_batch_dimension = 0 : i64,
+        output_feature_dimension = 3 : i64,
+        output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+      },
+      feature_group_count = 2 : i64,
+      padding = dense<0> : tensor<2x2xi64>,
+      rhs_dilation = dense<1> : tensor<2xi64>,
+      window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
+    hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<2x3x4x6xf32>
+    return
+  }
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4 * 3 + d3)>
+// CHECK: func @depthwise_conv()
+// CHECK: linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]
+// CHECK-SAME:   ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+// CHECK-SAME:   outs(%{{[a-z0-9]*}} : memref<2x3x4x6xf32>)
+// CHECK: mulf
+// CHECK: addf
+
+// -----
+
+module {
+  func @depthwise_conv_multiplier_1() {
+    %c0 = constant 0 : index
+    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x113x113x96xf32>
+    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<3x3x1x96xf32>
+    %2 = "mhlo.convolution"(%0, %1) {
+      batch_group_count = 1 : i64,
+      dimension_numbers = {
+        input_batch_dimension = 0 : i64,
+        input_feature_dimension = 3 : i64,
+        input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+        kernel_input_feature_dimension = 2 : i64,
+        kernel_output_feature_dimension = 3 : i64,
+        kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+        output_batch_dimension = 0 : i64,
+        output_feature_dimension = 3 : i64,
+        output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+      },
+      feature_group_count = 96 : i64,
+      padding = dense<0> : tensor<2x2xi64>,
+      rhs_dilation = dense<1> : tensor<2xi64>,
+      window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
+    hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<1x56x56x96xf32>
+    return
+  }
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK: func @depthwise_conv_multiplier_1()
+// CHECK: linalg.fill
+// CHECK: %[[FILTER:.+]] = linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] : memref<3x3x1x96xf32> into memref<3x3x96xf32>
+// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%{{.+}}, %[[FILTER]] : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%{{.+}} : memref<1x56x56x96xf32>)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index bf27373..2ec5408 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -771,22 +771,25 @@
 
   OwningRewritePatternList patterns;
 
-  patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
-                  MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
-                  MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
-                  MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::FillOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::PoolingMinOp>,
-                  MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>,
-                  RemoveLinalgRange, SerializeParallelLoopPattern>(
-      context, options.useLinalgOnTensors);
+  patterns.insert<
+      MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
+      MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
+      MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
+      MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
+      MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
+      MapLinalgOpToLocalInvocationId<linalg::ConvInputNWCFilterWCFOp>,
+      MapLinalgOpToLocalInvocationId<linalg::ConvInputNHWCFilterHWCFOp>,
+      MapLinalgOpToLocalInvocationId<linalg::ConvInputNDHWCFilterDHWCFOp>,
+      MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
+      MapLinalgOpToLocalInvocationId<linalg::FillOp>,
+      MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
+      MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
+      MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
+      MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
+      MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
+      MapLinalgOpToLocalInvocationId<linalg::PoolingMinOp>,
+      MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>, RemoveLinalgRange,
+      SerializeParallelLoopPattern>(context, options.useLinalgOnTensors);
   FrozenRewritePatternList frozenPatterns(std::move(patterns));
 
   for (FuncOp funcOp : getOperation().getOps<FuncOp>()) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 1c4bcd5..885dbe3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -350,13 +350,19 @@
   std::array<int64_t, 3> workgroupSize;
 };
 
-static LogicalResult getMaliSpecificConfig(linalg::ConvOp op,
+template <typename ConvOpTy>
+static LogicalResult getMaliSpecificConfig(ConvOpTy op,
                                            TileSizesListType &tileSizes,
                                            LaunchConfigInfo &config) {
-  auto inputType = op.getInput(1).getType().cast<MemRefType>();
-  auto outputType = op.getOutputBufferTypes()[0].cast<MemRefType>();
+  auto inputType = op.getInput(1).getType().template cast<MemRefType>();
+  auto outputType = op.getOutputBufferTypes()[0].template cast<MemRefType>();
   if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
     return failure();
+  // Only support NHWC conv.
+  if (!isa<linalg::ConvOp, linalg::ConvInputNHWCFilterHWCFOp>(
+          op.getOperation())) {
+    return failure();
+  }
 
   bool isInputTilable = inputType.getDimSize(3) % 4 == 0;
   if (!isInputTilable) return failure();
@@ -406,12 +412,11 @@
   return failure();
 }
 
-template <>
-LogicalResult getOpLaunchConfig(linalg::ConvOp op,
-                                const spirv::TargetEnv &targetEnv,
-                                const SPIRVCodegenOptions &options,
-                                TileSizesListType &tileSizes,
-                                LaunchConfigInfo &config) {
+template <typename T>
+LogicalResult getConvOpLaunchConfig(T op, const spirv::TargetEnv &targetEnv,
+                                    const SPIRVCodegenOptions &options,
+                                    TileSizesListType &tileSizes,
+                                    LaunchConfigInfo &config) {
   if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
       succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
     return success();
@@ -428,6 +433,22 @@
   return success();
 }
 
+#define GET_CONV_LAUNCH_CONFIG(opType)                                       \
+  template <>                                                                \
+  LogicalResult getOpLaunchConfig(                                           \
+      opType op, const spirv::TargetEnv &targetEnv,                          \
+      const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,      \
+      LaunchConfigInfo &config) {                                            \
+    return getConvOpLaunchConfig(op, targetEnv, options, tileSizes, config); \
+  }
+
+GET_CONV_LAUNCH_CONFIG(linalg::ConvOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNWCFilterWCFOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNHWCFilterHWCFOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNDHWCFilterDHWCFOp)
+
+#undef GET_CONV_LAUNCH_CONFIG
+
 static LogicalResult getMaliSpecificConfig(
     linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes,
     LaunchConfigInfo &config) {
@@ -583,6 +604,9 @@
     DISPATCH(linalg::BatchMatmulOp)
     DISPATCH(linalg::ConvOp)
     DISPATCH(linalg::DepthwiseConvInputNHWCFilterHWCOp)
+    DISPATCH(linalg::ConvInputNWCFilterWCFOp)
+    DISPATCH(linalg::ConvInputNHWCFilterHWCFOp)
+    DISPATCH(linalg::ConvInputNDHWCFilterDHWCFOp)
     DISPATCH(linalg::MatmulOp)
     DISPATCH(linalg::PoolingMaxOp)
     DISPATCH(linalg::PoolingMinOp)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index c893944..04b9219 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -159,13 +159,14 @@
 // 2) Maybe there are better alternatives for handling filter like using
 //    different storage classes, since for inference workloads these are model
 //    constants. This is TBD.
+template <typename ConvOpTy>
 struct PromoteConvSubviewsPattern
-    : public linalg::LinalgPromotionPattern<linalg::ConvOp> {
+    : public linalg::LinalgPromotionPattern<ConvOpTy> {
   PromoteConvSubviewsPattern(MLIRContext *context,
                              linalg::LinalgPromotionOptions options,
                              linalg::LinalgTransformationFilter marker,
                              PatternBenefit benefit = 1)
-      : linalg::LinalgPromotionPattern<linalg::ConvOp>(
+      : linalg::LinalgPromotionPattern<ConvOpTy>(
             context,
             options.setOperandsToPromote({1}).setUseFullTileBuffers(
                 {false, false}),
@@ -175,7 +176,11 @@
 
 static void populatePromotionPatterns(MLIRContext *context,
                                       OwningRewritePatternList &patterns) {
-  patterns.insert<PromoteMatmulSubviewsPattern, PromoteConvSubviewsPattern>(
+  patterns.insert<
+      PromoteMatmulSubviewsPattern, PromoteConvSubviewsPattern<linalg::ConvOp>,
+      PromoteConvSubviewsPattern<linalg::ConvInputNWCFilterWCFOp>,
+      PromoteConvSubviewsPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+      PromoteConvSubviewsPattern<linalg::ConvInputNDHWCFilterDHWCFOp>>(
       context,
       linalg::LinalgPromotionOptions()
           .setAllocationDeallocationFns(allocateWorkgroupMemory,
@@ -313,6 +318,9 @@
 
   patterns.insert<
       linalg::LinalgTilingPattern<linalg::ConvOp>,
+      linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
+      linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+      linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>,
       linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
       context, tilingOptions,
       getLinalgMatchAndReplaceMarker(
@@ -431,8 +439,12 @@
                                .setInterchange(loopOrder)
                                .setTileSizeComputationFunction(getTileSizeFn);
 
-  patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
-      context, convTilingOptions, marker);
+  patterns
+      .insert<linalg::LinalgTilingPattern<linalg::ConvOp>,
+              linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
+              linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+              linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>>(
+          context, convTilingOptions, marker);
 }
 
 //====---------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index d314304..9f6c443 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -85,7 +85,11 @@
                      linalg::LinalgDependenceGraph::DependenceType::RAW)
     ADD_FUSABLE_PAIR(linalg::FillOp, linalg::BatchMatmulOp,
                      linalg::LinalgDependenceGraph::DependenceType::WAW)
-    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvOp,
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNWCFilterWCFOp,
+                     linalg::LinalgDependenceGraph::DependenceType::WAW)
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNHWCFilterHWCFOp,
+                     linalg::LinalgDependenceGraph::DependenceType::WAW)
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNDHWCFilterDHWCFOp,
                      linalg::LinalgDependenceGraph::DependenceType::WAW)
     ADD_FUSABLE_PAIR(linalg::FillOp, linalg::MatmulOp,
                      linalg::LinalgDependenceGraph::DependenceType::WAW)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index f937cd6..d201d0e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -354,9 +354,12 @@
       %26 = dim %arg2, %c3 : memref<?x?x?x?xf32>
       %27 = subview %arg2[%arg3, %arg4, %arg5, 0] [%23, %24, %25, %26] [1, 1, 1, 1]
               : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
-      linalg.conv(%arg0, %21, %27)
-        {__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]}
-        : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32, #map5>
+      linalg.conv_2d_input_nhwc_filter_hwcf {
+        __internal_linalg_transform__ = "workgroup",
+        dilations = dense<1> : tensor<2xi64>,
+        strides = dense<2> : tensor<2xi64>}
+         ins(%21, %arg0 : memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32>)
+        outs(%27 : memref<?x?x?x?xf32, #map5>)
       scf.yield
     }
     return
@@ -407,7 +410,7 @@
 //       CHECK:             scf.for
 //       CHECK:               scf.for
 //       CHECK:                 scf.for
-//   CHECK-NOT:                   linalg.conv
+//   CHECK-NOT:                   linalg.conv_2d_input_nhwc_filter_hwcf
 
 // -----
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 5278cdd..9e0fa44 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -1,29 +1,71 @@
 // RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-codegen-split-dispatch-function -verify-diagnostics %s | IreeFileCheck %s
 
 module {
-  //     CHECK: func @kernel_fusable_fill_conv_ops
+  //     CHECK: func @kernel_fusable_fill_conv1d_ops
   //     CHECK:   linalg.fill
   // CHECK-NOT:   return
-  //     CHECK:   linalg.conv
+  //     CHECK:   linalg.conv_1d_input_nwc_filter_wcf
   //     CHECK:   return
 
-  func @kernel_fusable_fill_conv_ops()
+  func @kernel_fusable_fill_conv1d_ops()
   attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
     %cst = constant 0.000000e+00 : f32
     %dim = hal.interface.load.constant offset = 0 : index
-    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,512]>
+    %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,512]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x512xf32>
+    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x512xf32>, !shapex.ranked_shape<[?,3,512]>
+    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x512x1xf32>
+    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x512xf32>
+    %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x512xf32>, !shapex.ranked_shape<[?,1,512]>
+    linalg.fill(%ts2, %cst) : memref<?x1x512xf32>, f32
+    linalg.conv_1d_input_nwc_filter_wcf {
+      dilations = dense<1> : tensor<1xi64>,
+      strides = dense<2> : tensor<1xi64>}
+       ins(%ts1, %1 : memref<?x3x512xf32>, memref<3x512x1xf32>)
+      outs(%ts2 : memref<?x1x512xf32>)
+    return
+  }
+  func private @kernel_fusable_fill_conv_ops_num_workgroups__
+    (!shapex.ranked_shape<[?,3,512]>, !shapex.ranked_shape<[3,512,1]>,
+     !shapex.ranked_shape<[?,1,512]>) -> (index, index, index)
+  hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
+
+// -----
+
+module {
+  //     CHECK: func @kernel_fusable_fill_conv2d_ops
+  //     CHECK:   linalg.fill
+  // CHECK-NOT:   return
+  //     CHECK:   linalg.conv_2d_input_nhwc_filter_hwcf
+  //     CHECK:   return
+
+  func @kernel_fusable_fill_conv2d_ops()
+  attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+    %cst = constant 0.000000e+00 : f32
+    %dim = hal.interface.load.constant offset = 0 : index
+    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
     %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
-    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
-    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
     %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
     %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
     %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
     linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
-    linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+    linalg.conv_2d_input_nhwc_filter_hwcf {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<2> : tensor<2xi64>}
+       ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+      outs(%ts2 : memref<?x1x1x512xf32>)
     return
   }
   func private @kernel_fusable_fill_conv_ops_num_workgroups__
-    (!shapex.ranked_shape<[?,2,2,512]>, !shapex.ranked_shape<[3,3,512,1]>,
+    (!shapex.ranked_shape<[?,3,3,512]>, !shapex.ranked_shape<[3,3,512,1]>,
      !shapex.ranked_shape<[?,1,1,512]>) -> (index, index, index)
   hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -35,6 +77,44 @@
 // -----
 
 module {
+  //     CHECK: func @kernel_fusable_fill_conv3d_ops
+  //     CHECK:   linalg.fill
+  // CHECK-NOT:   return
+  //     CHECK:   linalg.conv_3d_input_ndhwc_filter_dhwcf
+  //     CHECK:   return
+
+  func @kernel_fusable_fill_conv3d_ops()
+  attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+    %cst = constant 0.000000e+00 : f32
+    %dim = hal.interface.load.constant offset = 0 : index
+    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,3,512]>
+    %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,1,512]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x3x512xf32>
+    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,3,512]>
+    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x3x512x1xf32>
+    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x1x512xf32>
+    %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,1,512]>
+    linalg.fill(%ts2, %cst) : memref<?x1x1x1x512xf32>, f32
+    linalg.conv_3d_input_ndhwc_filter_dhwcf {
+      dilations = dense<1> : tensor<3xi64>,
+      strides = dense<2> : tensor<3xi64>}
+       ins(%ts1, %1 : memref<?x3x3x3x512xf32>, memref<3x3x3x512x1xf32>)
+      outs(%ts2 : memref<?x1x1x1x512xf32>)
+    return
+  }
+  func private @kernel_fusable_fill_conv_ops_num_workgroups__
+    (!shapex.ranked_shape<[?,3,3,3,512]>, !shapex.ranked_shape<[3,3,3,512,1]>,
+     !shapex.ranked_shape<[?,1,1,1,512]>) -> (index, index, index)
+  hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
+
+// -----
+
+module {
   //     CHECK: func @kernel_fusable_fill_matmul_ops
   //     CHECK:   linalg.fill
   // CHECK-NOT:   return
@@ -118,29 +198,35 @@
   // CHECK:   %[[DIM:.+]] = hal.interface.load.constant
   // CHECK:   %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
   // CHECK:   %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
-  // CHECK:   %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+  // CHECK:   %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
   // CHECK:   %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
   // CHECK:   %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
   // CHECK:   %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
   // CHECK:   %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
-  // CHECK:   linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+  // CHECK:   linalg.conv_2d_input_nhwc_filter_hwcf
+  // CHECK-SAME: ins(%[[TS1]], %[[IN2]] : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+  // CHECK-SAME: outs(%[[TS2]] : memref<?x1x1x512xf32>)
   // CHECK:   return
 
   func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
     %cst = constant 0.000000e+00 : f32
     %dim = hal.interface.load.constant offset = 0 : index
-    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
     %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
-    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
-    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
     %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
     %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
     %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
-    linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+    linalg.conv_2d_input_nhwc_filter_hwcf {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<2> : tensor<2xi64>}
+       ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+      outs(%ts2 : memref<?x1x1x512xf32>)
     linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
     return
   }
-  func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+  func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,3,3,512]>,
                                  !shapex.ranked_shape<[3,3,512,1]>,
                                  !shapex.ranked_shape<[?,1,1,512]>)
                                  -> (index, index, index)
@@ -160,12 +246,14 @@
 //      CHECK:   %[[DIM:.+]] = hal.interface.load.constant
 //      CHECK:   %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
 //      CHECK:   %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
-//      CHECK:   %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+//      CHECK:   %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
 //      CHECK:   %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
 //      CHECK:   %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
 //      CHECK:   %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
 //      CHECK:   %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
-//      CHECK:   linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+//      CHECK:   linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME:     ins(%[[TS1]], %[[IN2]] : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+// CHECK-SAME:     outs(%[[TS2]] : memref<?x1x1x512xf32>)
 //      CHECK:   return
 
 //      CHECK: func private @[[NUM_WORKGROUPS_FN2]]
@@ -197,10 +285,10 @@
     %c0 = constant 0 : index
     %c1 = constant 1 : index
     %dim = hal.interface.load.constant offset = 0 : index
-    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
     %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
-    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
-    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
     %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
     %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
     %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
@@ -208,10 +296,14 @@
     scf.parallel (%iv) = (%c0) to (%c1) step (%c1) {
       scf.yield
     }
-    linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+    linalg.conv_2d_input_nhwc_filter_hwcf {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<2> : tensor<2xi64>}
+       ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+      outs(%ts2 : memref<?x1x1x512xf32>)
     return
   }
-  func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+  func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,3,3,512]>,
                                  !shapex.ranked_shape<[3,3,512,1]>,
                                  !shapex.ranked_shape<[?,1,1,512]>)
                                  -> (index, index, index)
@@ -231,10 +323,14 @@
   // CHECK-LABEL: @kernel()
   func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
     %cst = constant 0.000000e+00 : f32
-    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x3x3x512xf32>
     %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
     %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x1x1x512xf32>
-    linalg.conv(%1, %0, %2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<1x2x2x512xf32>, memref<1x1x1x512xf32>
+    linalg.conv_2d_input_nhwc_filter_hwcf {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<2> : tensor<2xi64>}
+       ins(%0, %1 : memref<1x3x3x512xf32>, memref<3x3x512x1xf32>)
+      outs(%2 : memref<1x1x1x512xf32>)
     return
   }
   // CHECK-LABEL: @kernel__num_workgroups__
@@ -256,12 +352,16 @@
   // expected-error @+1 {{cannot separate Linalg/Parallel ops into multiple kernels}}
   func @kernel() {
     %cst = constant 0.000000e+00 : f32
-    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x3x3x512xf32>
     %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
     %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x1x1x512xf32>
     linalg.fill(%2, %cst) : memref<1x1x1x512xf32>, f32
     "some_op"() : () -> ()
-    linalg.conv(%1, %0, %2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<1x2x2x512xf32>, memref<1x1x1x512xf32>
+    linalg.conv_2d_input_nhwc_filter_hwcf {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<2> : tensor<2xi64>}
+       ins(%0, %1 : memref<1x3x3x512xf32>, memref<3x3x512x1xf32>)
+      outs(%2 : memref<1x1x1x512xf32>)
     return
   }
   hal.interface @legacy_io attributes {sym_visibility = "private"} {
diff --git a/iree/test/e2e/llvm_specific/conv.mlir b/iree/test/e2e/llvm_specific/conv.mlir
index 6ac1719..5e29fa2 100644
--- a/iree/test/e2e/llvm_specific/conv.mlir
+++ b/iree/test/e2e/llvm_specific/conv.mlir
@@ -205,3 +205,53 @@
         [105921.0, 107874.0, 109827.0, 111780.0, 113733.0, 115686.0]]]]> : tensor<2x3x3x6xf32>) : tensor<2x3x3x6xf32>
   return
 }
+
+func @conv_1d() {
+  %inputs = iree.unfoldable_constant dense<2.0> : tensor<3x8x1xf32>
+  %weights = iree.unfoldable_constant dense<2.0> : tensor<3x1x1xf32>
+  %res = "mhlo.convolution"(%inputs, %weights) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 2 : i64,
+      input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+      kernel_input_feature_dimension = 1 : i64,
+      kernel_output_feature_dimension = 2 : i64,
+      kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 2 : i64,
+      output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<0> : tensor<1x2xi64>,
+    rhs_dilation = dense<1> : tensor<1xi64>,
+    window_strides = dense<1> : tensor<1xi64>
+  } : (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>
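+  // With all-2.0 inputs and an all-2.0 3x1x1 filter, each output element is a
+  // sum over 3 window positions of 2.0 * 2.0, i.e. 12.0.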
+  check.expect_almost_eq_const(%res, dense<12.0> : tensor<3x6x1xf32>) : tensor<3x6x1xf32>
+  return
+}
+
+func @conv_3d() {
+  %inputs = iree.unfoldable_constant dense<1.0> : tensor<2x8x8x8x3xf32>
+  %weights = iree.unfoldable_constant dense<1.0> : tensor<2x2x2x3x2xf32>
+  %res = "mhlo.convolution"(%inputs, %weights) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 4 : i64,
+      input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+      kernel_input_feature_dimension = 3 : i64,
+      kernel_output_feature_dimension = 4 : i64,
+      kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 4 : i64,
+      output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<0> : tensor<3x2xi64>,
+    rhs_dilation = dense<1> : tensor<3xi64>,
+    window_strides = dense<1> : tensor<3xi64>
+  } : (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>
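+  // Each output element sums 1.0 * 1.0 products over a 2x2x2 window and 3
+  // input channels, i.e. 8 * 3 = 24.0.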
+  check.expect_almost_eq_const(%res, dense<24.0> : tensor<2x7x7x7x2xf32>) : tensor<2x7x7x7x2xf32>
+  return
+}
diff --git a/iree/test/e2e/vulkan_specific/conv.mlir b/iree/test/e2e/vulkan_specific/conv.mlir
index e2dc59d..a1f273b 100644
--- a/iree/test/e2e/vulkan_specific/conv.mlir
+++ b/iree/test/e2e/vulkan_specific/conv.mlir
@@ -51,14 +51,14 @@
        batch_group_count = 1 : i64,
        dimension_numbers = {
          input_batch_dimension = 0 : i64,
-	 input_feature_dimension = 3 : i64,
-	 input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-	 kernel_input_feature_dimension = 2 : i64,
-	 kernel_output_feature_dimension = 3 : i64,
-	 kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-	 output_batch_dimension = 0 : i64,
-	 output_feature_dimension = 3 : i64,
-	 output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+         input_feature_dimension = 3 : i64,
+         input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+         kernel_input_feature_dimension = 2 : i64,
+         kernel_output_feature_dimension = 3 : i64,
+         kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+         output_batch_dimension = 0 : i64,
+         output_feature_dimension = 3 : i64,
+         output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
        feature_group_count = 1 : i64,
        rhs_dilation = dense<1> : tensor<2xi64>,
        window_strides = dense<1> : tensor<2xi64>}
@@ -81,3 +81,53 @@
         : tensor<1x3x4x3xf32>) : tensor<1x3x4x3xf32>
    return
 }
+
+func @conv_1d() {
+  %inputs = iree.unfoldable_constant dense<2.0> : tensor<3x8x1xf32>
+  %weights = iree.unfoldable_constant dense<2.0> : tensor<3x1x1xf32>
+  %res = "mhlo.convolution"(%inputs, %weights) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 2 : i64,
+      input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+      kernel_input_feature_dimension = 1 : i64,
+      kernel_output_feature_dimension = 2 : i64,
+      kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 2 : i64,
+      output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<0> : tensor<1x2xi64>,
+    rhs_dilation = dense<1> : tensor<1xi64>,
+    window_strides = dense<1> : tensor<1xi64>
+  } : (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>
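+  // With all-2.0 inputs and an all-2.0 3x1x1 filter, each output element is a
+  // sum over 3 window positions of 2.0 * 2.0, i.e. 12.0.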
+  check.expect_almost_eq_const(%res, dense<12.0> : tensor<3x6x1xf32>) : tensor<3x6x1xf32>
+  return
+}
+
+func @conv_3d() {
+  %inputs = iree.unfoldable_constant dense<1.0> : tensor<2x8x8x8x3xf32>
+  %weights = iree.unfoldable_constant dense<1.0> : tensor<2x2x2x3x2xf32>
+  %res = "mhlo.convolution"(%inputs, %weights) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 4 : i64,
+      input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+      kernel_input_feature_dimension = 3 : i64,
+      kernel_output_feature_dimension = 4 : i64,
+      kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 4 : i64,
+      output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<0> : tensor<3x2xi64>,
+    rhs_dilation = dense<1> : tensor<3xi64>,
+    window_strides = dense<1> : tensor<3xi64>
+  } : (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>
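+  // Each output element sums 1.0 * 1.0 products over a 2x2x2 window and 3
+  // input channels, i.e. 8 * 3 = 24.0.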
+  check.expect_almost_eq_const(%res, dense<24.0> : tensor<2x7x7x7x2xf32>) : tensor<2x7x7x7x2xf32>
+  return
+}
diff --git a/iree/test/e2e/xla_ops/convolution.mlir b/iree/test/e2e/xla_ops/convolution.mlir
index bfdea3c..7fb1008 100644
--- a/iree/test/e2e/xla_ops/convolution.mlir
+++ b/iree/test/e2e/xla_ops/convolution.mlir
@@ -341,3 +341,64 @@
         [105921.0, 107874.0, 109827.0, 111780.0, 113733.0, 115686.0]]]]> : tensor<2x3x3x6xf32>) : tensor<2x3x3x6xf32>
   return
 }
+
+func @conv2d_1452x2223_dilated_valid() {
+  %inputs = iree.unfoldable_constant dense<
+     [[[[0.09762701,  0.43037874],
+       [ 0.20552675,  0.08976637],
+       [-0.1526904,   0.29178822],
+       [-0.12482557,  0.78354603],
+       [ 0.92732555, -0.23311697]],
+      [[ 0.5834501,   0.05778984],
+       [ 0.13608912,  0.85119325],
+       [-0.85792786, -0.8257414 ],
+       [-0.9595632,   0.6652397 ],
+       [ 0.5563135,   0.74002427]],
+      [[ 0.9572367,   0.59831715],
+       [-0.07704128,  0.56105834],
+       [-0.76345116,  0.27984205],
+       [-0.71329343,  0.88933784],
+       [ 0.04369664, -0.17067613]],
+      [[-0.47088876,  0.5484674 ],
+       [-0.08769934,  0.1368679 ],
+       [-0.9624204,   0.23527099],
+       [ 0.22419144,  0.23386799],
+       [ 0.8874962,   0.3636406 ]]]]> : tensor<1x4x5x2xf32>
+  %weights = iree.unfoldable_constant dense<
+    [[[[-0.2809842,  -0.12593609,  0.3952624 ],
+       [-0.8795491,   0.33353344,  0.34127575]],
+      [[-0.5792349,  -0.7421474,  -0.3691433 ],
+       [-0.27257845,  0.14039354, -0.12279698]]],
+     [[[ 0.9767477,  -0.79591036, -0.5822465 ],
+       [-0.677381,    0.30621666, -0.4934168 ]],
+      [[-0.06737845, -0.5111488,  -0.68206084],
+       [-0.7792497,   0.31265917, -0.7236341 ]]]]> : tensor<2x2x2x3xf32>
+  %res = "mhlo.convolution"(%inputs, %weights) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<0> : tensor<2x2xi64>,
+    rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
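+  // rhs_dilation = [2, 1] gives an effective 3x2 filter, so the VALID output
+  // shape is 1x(4-3+1)x(5-2+1)x3 = 1x2x4x3.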
+  check.expect_almost_eq_const(%res, dense<
+    [[[[-0.45181108, -0.37253797, -1.1074474 ],
+       [-0.74972206,  0.8691965,   0.21864426],
+       [-1.9352274,   1.6551838,   0.13848126],
+       [-2.296763,    0.32046723, -0.02542188]],
+      [[-1.4578199,   0.59465677,  0.0599021 ],
+       [-0.3617443,   1.4647548,   1.2320882 ],
+       [ 0.04506956,  1.4347346,  -0.22625303],
+       [-1.122044,   -0.41301775, -1.5628793 ]]]]> : tensor<1x2x4x3xf32>) : tensor<1x2x4x3xf32>
+  return
+}