[iree][global] Control the demotion of ops (#17515)
Introduces a `demote-only` flag in `demote-contraction-inputs-to-bf16` to
control which kinds of ops are demoted. For example, with `demote-only=conv`,
only convolution ops will be demoted.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
index 793c713..9df9578 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
@@ -24,87 +24,102 @@
// For narrowable inputs, selects
struct DemoteContractionInputsToBF16Pattern
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
- using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+ explicit DemoteContractionInputsToBF16Pattern(MLIRContext *ctx,
+ DemotionOption &option)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(ctx), demoteOption(option) {
+ }
+
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
+ if (demoteOption == DemotionOption::None) {
+ return failure();
+ }
if (!isa<linalg::ContractionOpInterface, linalg::ConvolutionOpInterface>(
linalgOp.getOperation())) {
return failure();
}
- for (auto operand : linalgOp->getOperands()) {
- auto operandType = dyn_cast<RankedTensorType>(operand.getType());
- if (!operandType ||
- operandType.getElementType() != rewriter.getF32Type()) {
- return failure();
- }
- }
- Location loc = linalgOp.getLoc();
- SmallVector<Value> demotedInputs;
- for (auto inputOperand : linalgOp.getDpsInputOperands()) {
- auto input = inputOperand->get();
- auto inputType = cast<RankedTensorType>(input.getType());
- auto demotedInputType =
- RankedTensorType::get(inputType.getShape(), rewriter.getBF16Type(),
- inputType.getEncoding());
- SmallVector<AffineMap> maps(
- 2, rewriter.getMultiDimIdentityMap(inputType.getRank()));
- SmallVector<utils::IteratorType> iteratorTypes(
- inputType.getRank(), utils::IteratorType::parallel);
- SmallVector<OpFoldResult> mixedSizes =
- tensor::getMixedSizes(rewriter, loc, input);
- Value empty = rewriter.create<tensor::EmptyOp>(loc, mixedSizes,
- rewriter.getBF16Type());
- demotedInputs.push_back(
- rewriter
- .create<linalg::GenericOp>(
- loc, TypeRange{demotedInputType}, ValueRange{input},
- ValueRange{empty}, maps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::TruncFOp>(
- loc, rewriter.getBF16Type(), args[0]);
- b.create<linalg::YieldOp>(loc, result);
- })
- ->getResults()[0]);
+
+ if (!llvm::all_of(linalgOp->getOperands(), [&](auto operand) {
+ auto operandType = dyn_cast<RankedTensorType>(operand.getType());
+ return operandType &&
+ operandType.getElementType() == rewriter.getF32Type();
+ })) {
+ return failure();
}
auto replaceOpInputs = [&](auto *typePtr) {
+ Location loc = linalgOp.getLoc();
+ SmallVector<Value> demotedInputs;
+ for (auto inputOperand : linalgOp.getDpsInputOperands()) {
+ auto input = inputOperand->get();
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto demotedInputType =
+ RankedTensorType::get(inputType.getShape(), rewriter.getBF16Type(),
+ inputType.getEncoding());
+ SmallVector<AffineMap> maps(
+ 2, rewriter.getMultiDimIdentityMap(inputType.getRank()));
+ SmallVector<utils::IteratorType> iteratorTypes(
+ inputType.getRank(), utils::IteratorType::parallel);
+ SmallVector<OpFoldResult> mixedSizes =
+ tensor::getMixedSizes(rewriter, loc, input);
+ Value empty = rewriter.create<tensor::EmptyOp>(loc, mixedSizes,
+ rewriter.getBF16Type());
+ demotedInputs.push_back(
+ rewriter
+ .create<linalg::GenericOp>(
+ loc, TypeRange{demotedInputType}, ValueRange{input},
+ ValueRange{empty}, maps, iteratorTypes,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value result = b.create<arith::TruncFOp>(
+ loc, rewriter.getBF16Type(), args[0]);
+ b.create<linalg::YieldOp>(loc, result);
+ })
+ ->getResults()[0]);
+ }
auto namedOp = cast<std::remove_pointer_t<decltype(typePtr)>>(linalgOp);
rewriter.replaceOpWithNewOp<std::remove_pointer_t<decltype(typePtr)>>(
linalgOp, demotedInputs, linalgOp.getDpsInits(),
linalg::getPrunedAttributeList(namedOp));
};
- if (isa<linalg::MatmulOp>(linalgOp)) {
+ bool demoteMatmul = (demoteOption == DemotionOption::All) ||
+ (demoteOption == DemotionOption::Matmul);
+
+ bool demoteConv = (demoteOption == DemotionOption::All) ||
+ (demoteOption == DemotionOption::Conv);
+
+ if (demoteMatmul && isa<linalg::MatmulOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::MatmulOp *>(nullptr));
- } else if (isa<linalg::MatvecOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::MatvecOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::MatvecOp *>(nullptr));
- } else if (isa<linalg::VecmatOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::VecmatOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::VecmatOp *>(nullptr));
- } else if (isa<linalg::BatchMatmulOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchMatmulOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchMatmulOp *>(nullptr));
- } else if (isa<linalg::BatchMatvecOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchMatvecOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchMatvecOp *>(nullptr));
- } else if (isa<linalg::BatchVecmatOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchVecmatOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchVecmatOp *>(nullptr));
- } else if (isa<linalg::MatmulTransposeAOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::MatmulTransposeAOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::MatmulTransposeAOp *>(nullptr));
- } else if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::MatmulTransposeBOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::MatmulTransposeBOp *>(nullptr));
- } else if (isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchMatmulTransposeAOp *>(nullptr));
- } else if (isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchMatmulTransposeBOp *>(nullptr));
- } else if (isa<linalg::Conv2DOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::Conv2DOp *>(nullptr));
- } else if (isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::Conv2DNchwFchwOp *>(nullptr));
- } else if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::Conv2DNhwcHwcfOp *>(nullptr));
- } else if (isa<linalg::Conv2DNhwcFhwcOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNhwcFhwcOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::Conv2DNhwcFhwcOp *>(nullptr));
- } else if (isa<linalg::Conv2DNgchwFgchwOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNgchwFgchwOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::Conv2DNgchwFgchwOp *>(nullptr));
- } else if (isa<linalg::Conv2DNgchwGfchwOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNgchwGfchwOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::Conv2DNgchwGfchwOp *>(nullptr));
} else {
return failure();
@@ -112,15 +127,24 @@
return success();
}
+
+private:
+ DemotionOption demoteOption;
};
class DemoteContractionInputsToBF16Pass
: public DemoteContractionInputsToBF16Base<
DemoteContractionInputsToBF16Pass> {
+
+public:
+ explicit DemoteContractionInputsToBF16Pass(const DemotionOption &option) {
+ this->demoteOnly.setValue(option);
+ }
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- patterns.insert<DemoteContractionInputsToBF16Pattern>(context);
+ patterns.insert<DemoteContractionInputsToBF16Pattern>(
+ context, demoteOnly.getValue());
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
@@ -130,8 +154,9 @@
} // namespace
-std::unique_ptr<Pass> createDemoteContractionInputsToBF16Pass() {
- return std::make_unique<DemoteContractionInputsToBF16Pass>();
+std::unique_ptr<Pass>
+createDemoteContractionInputsToBF16Pass(DemotionOption option) {
+ return std::make_unique<DemoteContractionInputsToBF16Pass>(option);
}
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h b/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h
index 49ca67b..b0b79e7 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_
#define IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_
+#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index 12b2939..ddabd1c 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -51,11 +51,17 @@
"Enables horizontal fusion of contractions with one common operand"),
llvm::cl::init(false));
-static llvm::cl::opt<bool> clEnableDemoteContractionInputsToBF16(
+static llvm::cl::opt<DemotionOption> clDemoteContractionInputsToBF16Strategy(
"iree-global-opt-enable-demote-contraction-inputs-to-bf16",
- llvm::cl::desc(
- "Demote inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16."),
- llvm::cl::init(false));
+ llvm::cl::desc("Demotes inputs (LHS, RHS) of contraction ops to BF16. "
+ "Selects types of contraction ops to demote."),
+ llvm::cl::values(
+ clEnumValN(DemotionOption::All, "all", "Demote all contraction ops."),
+ clEnumValN(DemotionOption::Conv, "conv",
+ "Only demote convolution ops."),
+ clEnumValN(DemotionOption::Matmul, "matmul", "Only demote matmul ops."),
+ clEnumValN(DemotionOption::None, "none", "Demote no contraction ops.")),
+ llvm::cl::init(DemotionOption::None));
void buildGlobalOptExprHoistingPassPipeline(
OpPassManager &passManager, const TransformOptions &transformOptions) {
@@ -120,8 +126,10 @@
.addPass(IREE::Flow::createFoldUnitExtentDimsPass)
.addPredicatedPass(clEnableFuseSiluHorizontalMatmul,
createFuseSiluHorizontalMatmulPass)
- .addPredicatedPass(clEnableDemoteContractionInputsToBF16,
- createDemoteContractionInputsToBF16Pass)
+ .addPass([&]() {
+ return createDemoteContractionInputsToBF16Pass(
+ clDemoteContractionInputsToBF16Strategy);
+ })
.addPass([&]() {
return createFuseDequantizationMatmulPass(
clEnableQuantizedMatmulReassociation);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
index e2f79ba..6ddbeab 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -51,8 +51,13 @@
std::unique_ptr<Pass>
createDecomposeConcatPass(bool enableConcatTransposition = false);
+// Used by the demoteContractionInputsToBF16 pass to determine which op inputs
+// to demote.
+enum class DemotionOption { All, Conv, Matmul, None };
+
/// Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16.
-std::unique_ptr<Pass> createDemoteContractionInputsToBF16Pass();
+std::unique_ptr<Pass> createDemoteContractionInputsToBF16Pass(
+ DemotionOption option = DemotionOption::None);
/// Detaches elementwise ops from named Linalg ops.
std::unique_ptr<Pass> createDetachElementwiseFromNamedOpsPass();
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index 29e1b3a..0f3bcd3 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -32,9 +32,33 @@
];
}
-def DemoteContractionInputsToBF16 : Pass<"iree-global-opt-demote-contraction-inputs-to-bf16", ""> {
- let summary = "Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createDemoteContractionInputsToBF16Pass()";
+def DemoteContractionInputsToBF16
+ : Pass<"iree-global-opt-demote-contraction-inputs-to-bf16", ""> {
+ let summary =
+ "Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16.";
+ let constructor = "mlir::iree_compiler::GlobalOptimization::"
+ "createDemoteContractionInputsToBF16Pass()";
+ let options =
+ [Option<"demoteOnly", "demote-only",
+ "mlir::iree_compiler::GlobalOptimization::DemotionOption",
+ /*default=*/
+ "mlir::iree_compiler::GlobalOptimization::DemotionOption::All",
+ "Select the type of contraction ops to demote.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::iree_compiler::GlobalOptimization::DemotionOption::All,
+ "all",
+ "demote all contraction ops."),
+ clEnumValN(mlir::iree_compiler::GlobalOptimization::DemotionOption::Conv,
+ "conv",
+ "Only demote convolution ops."),
+ clEnumValN(mlir::iree_compiler::GlobalOptimization::DemotionOption::Matmul,
+ "matmul",
+ "Only demote matmul ops."),
+ clEnumValN(mlir::iree_compiler::GlobalOptimization::DemotionOption::None,
+ "none",
+ "demote no contraction ops.")
+ )}]>,
+ ];
}
def DetachElementwiseFromNamedOps :
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
index 709161f..387c8a0 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
@@ -1,4 +1,5 @@
-// RUN: iree-opt --split-input-file -iree-global-opt-demote-contraction-inputs-to-bf16 %s | FileCheck %s
+// RUN: iree-opt --split-input-file -iree-global-opt-demote-contraction-inputs-to-bf16="demote-only=matmul" %s | FileCheck %s --check-prefix=MATMUL
+// RUN: iree-opt --split-input-file -iree-global-opt-demote-contraction-inputs-to-bf16="demote-only=conv" %s | FileCheck %s --check-prefix=CONV
util.func public @matmul_f32f32f32(%arg0 : tensor<100x250xf32>, %arg1 : tensor<250x500xf32>,
%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
@@ -7,19 +8,19 @@
util.return %0 : tensor<100x500xf32>
}
-// CHECK: @matmul_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
+// MATMUL: @matmul_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matmul
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
// -----
@@ -30,17 +31,17 @@
util.return %0 : tensor<?x?xf32>
}
-// CHECK: @dynamic_matmul_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<?x?xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<?x?xbf16>, tensor<?x?xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<?x?xf32>)
+// MATMUL: @dynamic_matmul_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<?x?xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<?x?xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matmul
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<?x?xbf16>, tensor<?x?xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<?x?xf32>)
// -----
@@ -51,19 +52,19 @@
util.return %0 : tensor<4x100x500xf32>
}
-// CHECK: @batch_matmul_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<4x100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.batch_matmul
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
+// MATMUL: @batch_matmul_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<4x100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.batch_matmul
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
// -----
@@ -74,19 +75,19 @@
util.return %0 : tensor<100xf32>
}
-// CHECK: @matvec_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matvec
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<250xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100xf32>)
+// MATMUL: @matvec_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<250xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matvec
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<250xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100xf32>)
// -----
@@ -97,19 +98,19 @@
util.return %0 : tensor<4x500xf32>
}
-// CHECK: @batch_vecmat_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<4x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.batch_vecmat
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250xbf16>, tensor<4x250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<4x500xf32>)
+// MATMUL: @batch_vecmat_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<4x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<4x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.batch_vecmat
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250xbf16>, tensor<4x250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x500xf32>)
// -----
@@ -120,13 +121,13 @@
util.return %0 : tensor<100x500xf64>
}
-// CHECK: @nonmatch_matmul_f32f32f64
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf64>
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<100x250xf32>, tensor<250x500xf32>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf64>)
+// MATMUL: @nonmatch_matmul_f32f32f64
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100x500xf64>
+// MATMUL: linalg.matmul
+// MATMUL-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<100x250xf32>, tensor<250x500xf32>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf64>)
// -----
@@ -137,19 +138,19 @@
util.return %0 : tensor<4x100x500xf32>
}
-// CHECK: @batch_matmul_transpose_a_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x250x100xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<4x250x100xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.batch_matmul_transpose_a
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250x100xbf16>, tensor<4x250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
+// MATMUL: @batch_matmul_transpose_a_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x250x100xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<4x250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<4x250x100xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.batch_matmul_transpose_a
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250x100xbf16>, tensor<4x250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
// -----
@@ -160,19 +161,19 @@
util.return %0 : tensor<4x100x500xf32>
}
-// CHECK: @batch_matmul_transpose_b_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<4x500x250xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<4x100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<4x500x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.batch_matmul_transpose_b
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x500x250xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
+// MATMUL: @batch_matmul_transpose_b_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<4x500x250xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<4x100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<4x100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x500x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.batch_matmul_transpose_b
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x500x250xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
// -----
@@ -183,19 +184,19 @@
util.return %0 : tensor<100x500xf32>
}
-// CHECK: @matmul_transpose_a_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<250x100xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<250x100xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matmul_transpose_a
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<250x100xbf16>, tensor<250x500xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
+// MATMUL: @matmul_transpose_a_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<250x100xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<250x100xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matmul_transpose_a
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<250x100xbf16>, tensor<250x500xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
// -----
@@ -206,19 +207,19 @@
util.return %0 : tensor<100x500xf32>
}
-// CHECK: @matmul_transpose_b_f32f32f32
-// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<500x250xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
-// CHECK: %[[DEMOTED0:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: %[[DEMOTED1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG1]] : tensor<500x250xf32>)
-// CHECK: arith.truncf {{.*}} : f32 to bf16
-// CHECK: linalg.matmul_transpose_b
-// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<500x250xbf16>)
-// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
+// MATMUL: @matmul_transpose_b_f32f32f32
+// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// MATMUL-SAME: %[[ARG1:.+]]: tensor<500x250xf32>
+// MATMUL-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
+// MATMUL: %[[DEMOTED0:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG0]] : tensor<100x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
+// MATMUL-SAME: ins(%[[ARG1]] : tensor<500x250xf32>)
+// MATMUL: arith.truncf {{.*}} : f32 to bf16
+// MATMUL: linalg.matmul_transpose_b
+// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<500x250xbf16>)
+// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
// -----
@@ -229,26 +230,26 @@
outs(%arg2 : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>
util.return %0 : tensor<1x512x128x128xf32>
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: util.func public @conv_2d_nchw_fchw_f32f32f32(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x16x130x130xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<512x16x3x3xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32> {
-// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<1x16x130x130xbf16>
-// CHECK: %[[DEMOT1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins(%[[VAL_0]] : tensor<1x16x130x130xf32>) outs(%[[VAL_3]] : tensor<1x16x130x130xbf16>) {
-// CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: bf16):
-// CHECK: %[[VAL_7:.*]] = arith.truncf %[[VAL_5]] : f32 to bf16
-// CHECK: linalg.yield %[[VAL_7]] : bf16
-// CHECK: } -> tensor<1x16x130x130xbf16>
-// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<512x16x3x3xbf16>
-// CHECK: %[[DEMOT2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins(%[[VAL_1]] : tensor<512x16x3x3xf32>) outs(%[[VAL_8]] : tensor<512x16x3x3xbf16>) {
-// CHECK: ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: bf16):
-// CHECK: %[[VAL_12:.*]] = arith.truncf %[[VAL_10]] : f32 to bf16
-// CHECK: linalg.yield %[[VAL_12]] : bf16
-// CHECK: } -> tensor<512x16x3x3xbf16>
-// CHECK: %[[VAL_13:.*]] = linalg.conv_2d_nchw_fchw ins(%[[DEMOT1]], %[[DEMOT2]] : tensor<1x16x130x130xbf16>, tensor<512x16x3x3xbf16>)
-// CHECK-SAME: outs(%[[VAL_2]] : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>
+// CONV: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CONV-LABEL: util.func public @conv_2d_nchw_fchw_f32f32f32(
+// CONV-SAME: %[[VAL_0:.*]]: tensor<1x16x130x130xf32>,
+// CONV-SAME: %[[VAL_1:.*]]: tensor<512x16x3x3xf32>,
+// CONV-SAME: %[[VAL_2:.*]]: tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32> {
+// CONV: %[[VAL_3:.*]] = tensor.empty() : tensor<1x16x130x130xbf16>
+// CONV: %[[DEMOT1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
+// CONV-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CONV-SAME: ins(%[[VAL_0]] : tensor<1x16x130x130xf32>) outs(%[[VAL_3]] : tensor<1x16x130x130xbf16>) {
+// CONV: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: bf16):
+// CONV: %[[VAL_7:.*]] = arith.truncf %[[VAL_5]] : f32 to bf16
+// CONV: linalg.yield %[[VAL_7]] : bf16
+// CONV: } -> tensor<1x16x130x130xbf16>
+// CONV: %[[VAL_8:.*]] = tensor.empty() : tensor<512x16x3x3xbf16>
+// CONV: %[[DEMOT2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
+// CONV-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CONV-SAME: ins(%[[VAL_1]] : tensor<512x16x3x3xf32>) outs(%[[VAL_8]] : tensor<512x16x3x3xbf16>) {
+// CONV: ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: bf16):
+// CONV: %[[VAL_12:.*]] = arith.truncf %[[VAL_10]] : f32 to bf16
+// CONV: linalg.yield %[[VAL_12]] : bf16
+// CONV: } -> tensor<512x16x3x3xbf16>
+// CONV: %[[VAL_13:.*]] = linalg.conv_2d_nchw_fchw ins(%[[DEMOT1]], %[[DEMOT2]] : tensor<1x16x130x130xbf16>, tensor<512x16x3x3xbf16>)
+// CONV-SAME: outs(%[[VAL_2]] : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>