[iree][global] Add conv2d op to demote to bf16 pass (#17410)

Adds conv2d ops (i.e. ops implementing `linalg::ConvolutionOpInterface`) to the
DemoteContractionInputsToBF16 pass.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
index ceec1bc..793c713 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
@@ -23,11 +23,14 @@
 
 // For narrowable inputs, selects
 struct DemoteContractionInputsToBF16Pattern
-    : public OpInterfaceRewritePattern<linalg::ContractionOpInterface> {
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-  LogicalResult matchAndRewrite(linalg::ContractionOpInterface op,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op.getOperation());
+    if (!isa<linalg::ContractionOpInterface, linalg::ConvolutionOpInterface>(
+            linalgOp.getOperation())) {
+      return failure();
+    }
     for (auto operand : linalgOp->getOperands()) {
       auto operandType = dyn_cast<RankedTensorType>(operand.getType());
       if (!operandType ||
@@ -65,32 +68,44 @@
     }
 
     auto replaceOpInputs = [&](auto *typePtr) {
-      auto namedOp = cast<std::remove_pointer_t<decltype(typePtr)>>(op);
+      auto namedOp = cast<std::remove_pointer_t<decltype(typePtr)>>(linalgOp);
       rewriter.replaceOpWithNewOp<std::remove_pointer_t<decltype(typePtr)>>(
           linalgOp, demotedInputs, linalgOp.getDpsInits(),
           linalg::getPrunedAttributeList(namedOp));
     };
 
-    if (isa<linalg::MatmulOp>(op)) {
+    if (isa<linalg::MatmulOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::MatmulOp *>(nullptr));
-    } else if (isa<linalg::MatvecOp>(op)) {
+    } else if (isa<linalg::MatvecOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::MatvecOp *>(nullptr));
-    } else if (isa<linalg::VecmatOp>(op)) {
+    } else if (isa<linalg::VecmatOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::VecmatOp *>(nullptr));
-    } else if (isa<linalg::BatchMatmulOp>(op)) {
+    } else if (isa<linalg::BatchMatmulOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchMatmulOp *>(nullptr));
-    } else if (isa<linalg::BatchMatvecOp>(op)) {
+    } else if (isa<linalg::BatchMatvecOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchMatvecOp *>(nullptr));
-    } else if (isa<linalg::BatchVecmatOp>(op)) {
+    } else if (isa<linalg::BatchVecmatOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchVecmatOp *>(nullptr));
-    } else if (isa<linalg::MatmulTransposeAOp>(op)) {
+    } else if (isa<linalg::MatmulTransposeAOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::MatmulTransposeAOp *>(nullptr));
-    } else if (isa<linalg::MatmulTransposeBOp>(op)) {
+    } else if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::MatmulTransposeBOp *>(nullptr));
-    } else if (isa<linalg::BatchMatmulTransposeAOp>(op)) {
+    } else if (isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchMatmulTransposeAOp *>(nullptr));
-    } else if (isa<linalg::BatchMatmulTransposeBOp>(op)) {
+    } else if (isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchMatmulTransposeBOp *>(nullptr));
+    } else if (isa<linalg::Conv2DOp>(linalgOp)) {
+      replaceOpInputs(static_cast<linalg::Conv2DOp *>(nullptr));
+    } else if (isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
+      replaceOpInputs(static_cast<linalg::Conv2DNchwFchwOp *>(nullptr));
+    } else if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
+      replaceOpInputs(static_cast<linalg::Conv2DNhwcHwcfOp *>(nullptr));
+    } else if (isa<linalg::Conv2DNhwcFhwcOp>(linalgOp)) {
+      replaceOpInputs(static_cast<linalg::Conv2DNhwcFhwcOp *>(nullptr));
+    } else if (isa<linalg::Conv2DNgchwFgchwOp>(linalgOp)) {
+      replaceOpInputs(static_cast<linalg::Conv2DNgchwFgchwOp *>(nullptr));
+    } else if (isa<linalg::Conv2DNgchwGfchwOp>(linalgOp)) {
+      replaceOpInputs(static_cast<linalg::Conv2DNgchwGfchwOp *>(nullptr));
     } else {
       return failure();
     }
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
index ff717b8..ea8b168 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
@@ -219,3 +219,36 @@
 // CHECK: linalg.matmul_transpose_b
 // CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<500x250xbf16>)
 // CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
+
+// -----
+
+util.func public @conv_2d_nchw_fchw_f32f32f32(%arg0 : tensor<1x16x130x130xf32>, %arg1 : tensor<512x16x3x3xf32>,
+    %arg2 : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32> {
+    %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} 
+         ins(%arg0, %arg1 : tensor<1x16x130x130xf32>, tensor<512x16x3x3xf32>) 
+         outs(%arg2 : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>
+  util.return %0 : tensor<1x512x128x128xf32>
+}
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL:   util.func public @conv_2d_nchw_fchw_f32f32f32(
+// CHECK-SAME:                                                  %[[VAL_0:.*]]: tensor<1x16x130x130xf32>,
+// CHECK-SAME:                                                  %[[VAL_1:.*]]: tensor<512x16x3x3xf32>,
+// CHECK-SAME:                                                  %[[VAL_2:.*]]: tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32> {
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<1x16x130x130xbf16>
+// CHECK:           %[[DEMOT1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]], 
+// CHECK-SAME:          iterator_types = ["parallel", "parallel", "parallel", "parallel"]} 
+// CHECK-SAME:          ins(%[[VAL_0]] : tensor<1x16x130x130xf32>) outs(%[[VAL_3]] : tensor<1x16x130x130xbf16>) {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: bf16):
+// CHECK:             %[[VAL_7:.*]] = arith.truncf %[[VAL_5]] : f32 to bf16
+// CHECK:             linalg.yield %[[VAL_7]] : bf16
+// CHECK:           } -> tensor<1x16x130x130xbf16>
+// CHECK:           %[[VAL_8:.*]] = tensor.empty() : tensor<512x16x3x3xbf16>
+// CHECK:           %[[DEMOT2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]], 
+// CHECK-SAME:          iterator_types = ["parallel", "parallel", "parallel", "parallel"]} 
+// CHECK-SAME:          ins(%[[VAL_1]] : tensor<512x16x3x3xf32>) outs(%[[VAL_8]] : tensor<512x16x3x3xbf16>) {
+// CHECK:           ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: bf16):
+// CHECK:             %[[VAL_12:.*]] = arith.truncf %[[VAL_10]] : f32 to bf16
+// CHECK:             linalg.yield %[[VAL_12]] : bf16
+// CHECK:           } -> tensor<512x16x3x3xbf16>
+// CHECK:           %[[VAL_13:.*]] = linalg.conv_2d_nchw_fchw ins(%[[DEMOT1]], %[[DEMOT2]] : tensor<1x16x130x130xbf16>, tensor<512x16x3x3xbf16>) 
+// CHECK-SAME:      outs(%[[VAL_2]] : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>