[iree][global] Control the demotion of ops (#17515)

Introduces a `demote-only` flag in `demote-contraction-inputs-to-bf16` to
control which ops are demoted. For example, if `demote-only=conv`, only
convolution ops will be demoted.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
index 793c713..9df9578 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
@@ -24,87 +24,102 @@
 // For narrowable inputs, selects
 struct DemoteContractionInputsToBF16Pattern
     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+  explicit DemoteContractionInputsToBF16Pattern(MLIRContext *ctx,
+                                                DemotionOption &option)
+      : OpInterfaceRewritePattern<linalg::LinalgOp>(ctx), demoteOption(option) {
+  }
+
   LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
+    if (demoteOption == DemotionOption::None) {
+      return failure();
+    }
     if (!isa<linalg::ContractionOpInterface, linalg::ConvolutionOpInterface>(
             linalgOp.getOperation())) {
       return failure();
     }
-    for (auto operand : linalgOp->getOperands()) {
-      auto operandType = dyn_cast<RankedTensorType>(operand.getType());
-      if (!operandType ||
-          operandType.getElementType() != rewriter.getF32Type()) {
-        return failure();
-      }
-    }
-    Location loc = linalgOp.getLoc();
-    SmallVector<Value> demotedInputs;
-    for (auto inputOperand : linalgOp.getDpsInputOperands()) {
-      auto input = inputOperand->get();
-      auto inputType = cast<RankedTensorType>(input.getType());
-      auto demotedInputType =
-          RankedTensorType::get(inputType.getShape(), rewriter.getBF16Type(),
-                                inputType.getEncoding());
-      SmallVector<AffineMap> maps(
-          2, rewriter.getMultiDimIdentityMap(inputType.getRank()));
-      SmallVector<utils::IteratorType> iteratorTypes(
-          inputType.getRank(), utils::IteratorType::parallel);
-      SmallVector<OpFoldResult> mixedSizes =
-          tensor::getMixedSizes(rewriter, loc, input);
-      Value empty = rewriter.create<tensor::EmptyOp>(loc, mixedSizes,
-                                                     rewriter.getBF16Type());
-      demotedInputs.push_back(
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, TypeRange{demotedInputType}, ValueRange{input},
-                  ValueRange{empty}, maps, iteratorTypes,
-                  [&](OpBuilder &b, Location loc, ValueRange args) {
-                    Value result = b.create<arith::TruncFOp>(
-                        loc, rewriter.getBF16Type(), args[0]);
-                    b.create<linalg::YieldOp>(loc, result);
-                  })
-              ->getResults()[0]);
+
+    if (!llvm::all_of(linalgOp->getOperands(), [&](auto operand) {
+          auto operandType = dyn_cast<RankedTensorType>(operand.getType());
+          return operandType &&
+                 operandType.getElementType() == rewriter.getF32Type();
+        })) {
+      return failure();
     }
 
     auto replaceOpInputs = [&](auto *typePtr) {
+      Location loc = linalgOp.getLoc();
+      SmallVector<Value> demotedInputs;
+      for (auto inputOperand : linalgOp.getDpsInputOperands()) {
+        auto input = inputOperand->get();
+        auto inputType = cast<RankedTensorType>(input.getType());
+        auto demotedInputType =
+            RankedTensorType::get(inputType.getShape(), rewriter.getBF16Type(),
+                                  inputType.getEncoding());
+        SmallVector<AffineMap> maps(
+            2, rewriter.getMultiDimIdentityMap(inputType.getRank()));
+        SmallVector<utils::IteratorType> iteratorTypes(
+            inputType.getRank(), utils::IteratorType::parallel);
+        SmallVector<OpFoldResult> mixedSizes =
+            tensor::getMixedSizes(rewriter, loc, input);
+        Value empty = rewriter.create<tensor::EmptyOp>(loc, mixedSizes,
+                                                       rewriter.getBF16Type());
+        demotedInputs.push_back(
+            rewriter
+                .create<linalg::GenericOp>(
+                    loc, TypeRange{demotedInputType}, ValueRange{input},
+                    ValueRange{empty}, maps, iteratorTypes,
+                    [&](OpBuilder &b, Location loc, ValueRange args) {
+                      Value result = b.create<arith::TruncFOp>(
+                          loc, rewriter.getBF16Type(), args[0]);
+                      b.create<linalg::YieldOp>(loc, result);
+                    })
+                ->getResults()[0]);
+      }
       auto namedOp = cast<std::remove_pointer_t<decltype(typePtr)>>(linalgOp);
       rewriter.replaceOpWithNewOp<std::remove_pointer_t<decltype(typePtr)>>(
           linalgOp, demotedInputs, linalgOp.getDpsInits(),
           linalg::getPrunedAttributeList(namedOp));
     };
 
-    if (isa<linalg::MatmulOp>(linalgOp)) {
+    bool demoteMatmul = (demoteOption == DemotionOption::All) ||
+                        (demoteOption == DemotionOption::Matmul);
+
+    bool demoteConv = (demoteOption == DemotionOption::All) ||
+                      (demoteOption == DemotionOption::Conv);
+
+    if (demoteMatmul && isa<linalg::MatmulOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::MatmulOp *>(nullptr));
-    } else if (isa<linalg::MatvecOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::MatvecOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::MatvecOp *>(nullptr));
-    } else if (isa<linalg::VecmatOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::VecmatOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::VecmatOp *>(nullptr));
-    } else if (isa<linalg::BatchMatmulOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::BatchMatmulOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchMatmulOp *>(nullptr));
-    } else if (isa<linalg::BatchMatvecOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::BatchMatvecOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchMatvecOp *>(nullptr));
-    } else if (isa<linalg::BatchVecmatOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::BatchVecmatOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchVecmatOp *>(nullptr));
-    } else if (isa<linalg::MatmulTransposeAOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::MatmulTransposeAOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::MatmulTransposeAOp *>(nullptr));
-    } else if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::MatmulTransposeBOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::MatmulTransposeBOp *>(nullptr));
-    } else if (isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchMatmulTransposeAOp *>(nullptr));
-    } else if (isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
+    } else if (demoteMatmul && isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::BatchMatmulTransposeBOp *>(nullptr));
-    } else if (isa<linalg::Conv2DOp>(linalgOp)) {
+    } else if (demoteConv && isa<linalg::Conv2DOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::Conv2DOp *>(nullptr));
-    } else if (isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
+    } else if (demoteConv && isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::Conv2DNchwFchwOp *>(nullptr));
-    } else if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
+    } else if (demoteConv && isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::Conv2DNhwcHwcfOp *>(nullptr));
-    } else if (isa<linalg::Conv2DNhwcFhwcOp>(linalgOp)) {
+    } else if (demoteConv && isa<linalg::Conv2DNhwcFhwcOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::Conv2DNhwcFhwcOp *>(nullptr));
-    } else if (isa<linalg::Conv2DNgchwFgchwOp>(linalgOp)) {
+    } else if (demoteConv && isa<linalg::Conv2DNgchwFgchwOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::Conv2DNgchwFgchwOp *>(nullptr));
-    } else if (isa<linalg::Conv2DNgchwGfchwOp>(linalgOp)) {
+    } else if (demoteConv && isa<linalg::Conv2DNgchwGfchwOp>(linalgOp)) {
       replaceOpInputs(static_cast<linalg::Conv2DNgchwGfchwOp *>(nullptr));
     } else {
       return failure();
@@ -112,15 +127,24 @@
 
     return success();
   }
+
+private:
+  DemotionOption demoteOption;
 };
 
 class DemoteContractionInputsToBF16Pass
     : public DemoteContractionInputsToBF16Base<
           DemoteContractionInputsToBF16Pass> {
+
+public:
+  explicit DemoteContractionInputsToBF16Pass(const DemotionOption &option) {
+    this->demoteOnly.setValue(option);
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.insert<DemoteContractionInputsToBF16Pattern>(context);
+    patterns.insert<DemoteContractionInputsToBF16Pattern>(
+        context, demoteOnly.getValue());
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
       return signalPassFailure();
@@ -130,8 +154,9 @@
 
 } // namespace
 
-std::unique_ptr<Pass> createDemoteContractionInputsToBF16Pass() {
-  return std::make_unique<DemoteContractionInputsToBF16Pass>();
+std::unique_ptr<Pass>
+createDemoteContractionInputsToBF16Pass(DemotionOption option) {
+  return std::make_unique<DemoteContractionInputsToBF16Pass>(option);
 }
 
 } // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h b/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h
index 49ca67b..b0b79e7 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_
 #define IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_
 
+#include "iree/compiler/GlobalOptimization/Passes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index 12b2939..ddabd1c 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -51,11 +51,17 @@
         "Enables horizontal fusion of contractions with one common operand"),
     llvm::cl::init(false));
 
-static llvm::cl::opt<bool> clEnableDemoteContractionInputsToBF16(
+static llvm::cl::opt<DemotionOption> clDemoteContractionInputsToBF16Strategy(
     "iree-global-opt-enable-demote-contraction-inputs-to-bf16",
-    llvm::cl::desc(
-        "Demote inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16."),
-    llvm::cl::init(false));
+    llvm::cl::desc("Demotes inputs (LHS, RHS) of contraction ops to BF16. "
+                   "Selects types of contraction ops to demote."),
+    llvm::cl::values(
+        clEnumValN(DemotionOption::All, "all", "Demote all contraction ops."),
+        clEnumValN(DemotionOption::Conv, "conv",
+                   "Only demote convolution ops."),
+        clEnumValN(DemotionOption::Matmul, "matmul", "Only demote matmul ops."),
+        clEnumValN(DemotionOption::None, "none", "Demote no contraction ops.")),
+    llvm::cl::init(DemotionOption::None));
 
 void buildGlobalOptExprHoistingPassPipeline(
     OpPassManager &passManager, const TransformOptions &transformOptions) {
@@ -120,8 +126,10 @@
       .addPass(IREE::Flow::createFoldUnitExtentDimsPass)
       .addPredicatedPass(clEnableFuseSiluHorizontalMatmul,
                          createFuseSiluHorizontalMatmulPass)
-      .addPredicatedPass(clEnableDemoteContractionInputsToBF16,
-                         createDemoteContractionInputsToBF16Pass)
+      .addPass([&]() {
+        return createDemoteContractionInputsToBF16Pass(
+            clDemoteContractionInputsToBF16Strategy);
+      })
       .addPass([&]() {
         return createFuseDequantizationMatmulPass(
             clEnableQuantizedMatmulReassociation);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
index e2f79ba..6ddbeab 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -51,8 +51,13 @@
 std::unique_ptr<Pass>
 createDecomposeConcatPass(bool enableConcatTransposition = false);
 
+// Used by the demoteContractionInputsToBF16 pass to determine which op inputs
+// to demote.
+enum class DemotionOption { All, Conv, Matmul, None };
+
 /// Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16.
-std::unique_ptr<Pass> createDemoteContractionInputsToBF16Pass();
+std::unique_ptr<Pass> createDemoteContractionInputsToBF16Pass(
+    DemotionOption option = DemotionOption::None);
 
 /// Detaches elementwise ops from named Linalg ops.
 std::unique_ptr<Pass> createDetachElementwiseFromNamedOpsPass();
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index 29e1b3a..0f3bcd3 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -32,9 +32,33 @@
   ];
 }
 
-def DemoteContractionInputsToBF16 : Pass<"iree-global-opt-demote-contraction-inputs-to-bf16", ""> {
-  let summary = "Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16.";
-  let constructor = "mlir::iree_compiler::GlobalOptimization::createDemoteContractionInputsToBF16Pass()";
+def DemoteContractionInputsToBF16
+    : Pass<"iree-global-opt-demote-contraction-inputs-to-bf16", ""> {
+  let summary =
+      "Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16.";
+  let constructor = "mlir::iree_compiler::GlobalOptimization::"
+                    "createDemoteContractionInputsToBF16Pass()";
+  let options =
+      [Option<"demoteOnly", "demote-only",
+              "mlir::iree_compiler::GlobalOptimization::DemotionOption",
+              /*default=*/
+              "mlir::iree_compiler::GlobalOptimization::DemotionOption::All",
+              "Select the type of contraction ops to demote.",
+              [{::llvm::cl::values(
+            clEnumValN(mlir::iree_compiler::GlobalOptimization::DemotionOption::All,
+                       "all",
+                      "demote all contraction ops."),
+            clEnumValN(mlir::iree_compiler::GlobalOptimization::DemotionOption::Conv,
+                       "conv",
+                       "Only demote convolution ops."),
+            clEnumValN(mlir::iree_compiler::GlobalOptimization::DemotionOption::Matmul,
+                       "matmul",
+                       "Only demote matmul ops."),
+            clEnumValN(mlir::iree_compiler::GlobalOptimization::DemotionOption::None,
+                       "none",
+                      "demote no contraction ops.")
+        )}]>,
+  ];
 }
 
 def DetachElementwiseFromNamedOps :
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
index 709161f..387c8a0 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
@@ -1,4 +1,5 @@
-// RUN: iree-opt --split-input-file -iree-global-opt-demote-contraction-inputs-to-bf16 %s | FileCheck %s
+// RUN: iree-opt --split-input-file -iree-global-opt-demote-contraction-inputs-to-bf16="demote-only=matmul" %s | FileCheck %s --check-prefix=MATMUL
+// RUN: iree-opt --split-input-file -iree-global-opt-demote-contraction-inputs-to-bf16="demote-only=conv" %s | FileCheck %s --check-prefix=CONV
 
 util.func public @matmul_f32f32f32(%arg0 : tensor<100x250xf32>, %arg1 : tensor<250x500xf32>,
     %arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
@@ -7,19 +8,19 @@
   util.return %0 : tensor<100x500xf32>
 }
 
-// CHECK: @matmul_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
+// MATMUL: @matmul_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matmul
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
 
 // -----
 
@@ -30,17 +31,17 @@
   util.return %0 : tensor<?x?xf32>
 }
 
-// CHECK: @dynamic_matmul_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<?x?xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<?x?xbf16>, tensor<?x?xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<?x?xf32>)
+// MATMUL: @dynamic_matmul_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<?x?xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<?x?xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matmul
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<?x?xbf16>, tensor<?x?xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<?x?xf32>)
 
 // -----
 
@@ -51,19 +52,19 @@
   util.return %0 : tensor<4x100x500xf32>
 }
 
-// CHECK: @batch_matmul_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<4x100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.batch_matmul
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
+// MATMUL: @batch_matmul_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<4x100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.batch_matmul
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
 
 // -----
 
@@ -74,19 +75,19 @@
   util.return %0 : tensor<100xf32>
 }
 
-// CHECK: @matvec_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matvec
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<250xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100xf32>)
+// MATMUL: @matvec_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<250xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matvec
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<250xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100xf32>)
 
 // -----
 
@@ -97,19 +98,19 @@
   util.return %0 : tensor<4x500xf32>
 }
 
-// CHECK: @batch_vecmat_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<4x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.batch_vecmat
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250xbf16>, tensor<4x250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<4x500xf32>)
+// MATMUL: @batch_vecmat_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<4x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<4x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.batch_vecmat
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250xbf16>, tensor<4x250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x500xf32>)
 
 // -----
 
@@ -120,13 +121,13 @@
   util.return %0 : tensor<100x500xf64>
 }
 
-// CHECK: @nonmatch_matmul_f32f32f64
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf64>
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<100x250xf32>, tensor<250x500xf32>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf64>)
+// MATMUL: @nonmatch_matmul_f32f32f64
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100x500xf64>
+// MATMUL: linalg.matmul
+// MATMUL-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<100x250xf32>, tensor<250x500xf32>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf64>)
 
 // -----
 
@@ -137,19 +138,19 @@
   util.return %0 : tensor<4x100x500xf32>
 }
 
-// CHECK: @batch_matmul_transpose_a_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x250x100xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<4x250x100xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.batch_matmul_transpose_a
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250x100xbf16>, tensor<4x250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
+// MATMUL: @batch_matmul_transpose_a_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x250x100xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<4x250x100xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.batch_matmul_transpose_a
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250x100xbf16>, tensor<4x250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
 
 // -----
 
@@ -160,19 +161,19 @@
   util.return %0 : tensor<4x100x500xf32>
 }
 
-// CHECK: @batch_matmul_transpose_b_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<4x500x250xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<4x100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<4x500x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.batch_matmul_transpose_b
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x500x250xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
+// MATMUL: @batch_matmul_transpose_b_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<4x500x250xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<4x100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x500x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.batch_matmul_transpose_b
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x500x250xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
 
 // -----
 
@@ -183,19 +184,19 @@
   util.return %0 : tensor<100x500xf32>
 }
 
-// CHECK: @matmul_transpose_a_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<250x100xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<250x100xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matmul_transpose_a
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<250x100xbf16>, tensor<250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
+// MATMUL: @matmul_transpose_a_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<250x100xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<250x100xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matmul_transpose_a
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<250x100xbf16>, tensor<250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
 
 // -----
 
@@ -206,19 +207,19 @@
   util.return %0 : tensor<100x500xf32>
 }
 
-// CHECK: @matmul_transpose_b_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<500x250xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<500x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matmul_transpose_b
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<500x250xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
+// MATMUL: @matmul_transpose_b_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<500x250xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<500x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matmul_transpose_b
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<500x250xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
 
 // -----
 
@@ -229,26 +230,26 @@
          outs(%arg2 : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>
   util.return %0 : tensor<1x512x128x128xf32>
 }
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL:   util.func public @conv_2d_nchw_fchw_f32f32f32(
-// CHECK-SAME:                                                  %[[VAL_0:.*]]: tensor<1x16x130x130xf32>,
-// CHECK-SAME:                                                  %[[VAL_1:.*]]: tensor<512x16x3x3xf32>,
-// CHECK-SAME:                                                  %[[VAL_2:.*]]: tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32> {
-// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<1x16x130x130xbf16>
-// CHECK:           %[[DEMOT1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
-// CHECK-SAME:          iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-// CHECK-SAME:          ins(%[[VAL_0]] : tensor<1x16x130x130xf32>) outs(%[[VAL_3]] : tensor<1x16x130x130xbf16>) {
-// CHECK:           ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: bf16):
-// CHECK:             %[[VAL_7:.*]] = arith.truncf %[[VAL_5]] : f32 to bf16
-// CHECK:             linalg.yield %[[VAL_7]] : bf16
-// CHECK:           } -> tensor<1x16x130x130xbf16>
-// CHECK:           %[[VAL_8:.*]] = tensor.empty() : tensor<512x16x3x3xbf16>
-// CHECK:           %[[DEMOT2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
-// CHECK-SAME:          iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-// CHECK-SAME:          ins(%[[VAL_1]] : tensor<512x16x3x3xf32>) outs(%[[VAL_8]] : tensor<512x16x3x3xbf16>) {
-// CHECK:           ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: bf16):
-// CHECK:             %[[VAL_12:.*]] = arith.truncf %[[VAL_10]] : f32 to bf16
-// CHECK:             linalg.yield %[[VAL_12]] : bf16
-// CHECK:           } -> tensor<512x16x3x3xbf16>
-// CHECK:           %[[VAL_13:.*]] = linalg.conv_2d_nchw_fchw ins(%[[DEMOT1]], %[[DEMOT2]] : tensor<1x16x130x130xbf16>, tensor<512x16x3x3xbf16>)
-// CHECK-SAME:      outs(%[[VAL_2]] : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>
+// CONV: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CONV-LABEL:   util.func public @conv_2d_nchw_fchw_f32f32f32(
+// CONV-SAME:                                                  %[[VAL_0:.*]]: tensor<1x16x130x130xf32>,
+// CONV-SAME:                                                  %[[VAL_1:.*]]: tensor<512x16x3x3xf32>,
+// CONV-SAME:                                                  %[[VAL_2:.*]]: tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32> {
+// CONV:           %[[VAL_3:.*]] = tensor.empty() : tensor<1x16x130x130xbf16>
+// CONV:           %[[DEMOT1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
+// CONV-SAME:          iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CONV-SAME:          ins(%[[VAL_0]] : tensor<1x16x130x130xf32>) outs(%[[VAL_3]] : tensor<1x16x130x130xbf16>) {
+// CONV:           ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: bf16):
+// CONV:             %[[VAL_7:.*]] = arith.truncf %[[VAL_5]] : f32 to bf16
+// CONV:             linalg.yield %[[VAL_7]] : bf16
+// CONV:           } -> tensor<1x16x130x130xbf16>
+// CONV:           %[[VAL_8:.*]] = tensor.empty() : tensor<512x16x3x3xbf16>
+// CONV:           %[[DEMOT2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
+// CONV-SAME:          iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CONV-SAME:          ins(%[[VAL_1]] : tensor<512x16x3x3xf32>) outs(%[[VAL_8]] : tensor<512x16x3x3xbf16>) {
+// CONV:           ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: bf16):
+// CONV:             %[[VAL_12:.*]] = arith.truncf %[[VAL_10]] : f32 to bf16
+// CONV:             linalg.yield %[[VAL_12]] : bf16
+// CONV:           } -> tensor<512x16x3x3xbf16>
+// CONV:           %[[VAL_13:.*]] = linalg.conv_2d_nchw_fchw ins(%[[DEMOT1]], %[[DEMOT2]] : tensor<1x16x130x130xbf16>, tensor<512x16x3x3xbf16>)
+// CONV-SAME:      outs(%[[VAL_2]] : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>