[iree][global] Add conv2d ops to demote-to-bf16 pass (#17410)
Adds named conv2d ops (via `linalg::ConvolutionOpInterface`) to the
`DemoteContractionInputsToBF16` pass, alongside the existing
contraction (`linalg::ContractionOpInterface`) ops.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
index ceec1bc..793c713 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
@@ -23,11 +23,14 @@
// For narrowable inputs, selects
struct DemoteContractionInputsToBF16Pattern
- : public OpInterfaceRewritePattern<linalg::ContractionOpInterface> {
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(linalg::ContractionOpInterface op,
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
- linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op.getOperation());
+ if (!isa<linalg::ContractionOpInterface, linalg::ConvolutionOpInterface>(
+ linalgOp.getOperation())) {
+ return failure();
+ }
for (auto operand : linalgOp->getOperands()) {
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
if (!operandType ||
@@ -65,32 +68,44 @@
}
auto replaceOpInputs = [&](auto *typePtr) {
- auto namedOp = cast<std::remove_pointer_t<decltype(typePtr)>>(op);
+ auto namedOp = cast<std::remove_pointer_t<decltype(typePtr)>>(linalgOp);
rewriter.replaceOpWithNewOp<std::remove_pointer_t<decltype(typePtr)>>(
linalgOp, demotedInputs, linalgOp.getDpsInits(),
linalg::getPrunedAttributeList(namedOp));
};
- if (isa<linalg::MatmulOp>(op)) {
+ if (isa<linalg::MatmulOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::MatmulOp *>(nullptr));
- } else if (isa<linalg::MatvecOp>(op)) {
+ } else if (isa<linalg::MatvecOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::MatvecOp *>(nullptr));
- } else if (isa<linalg::VecmatOp>(op)) {
+ } else if (isa<linalg::VecmatOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::VecmatOp *>(nullptr));
- } else if (isa<linalg::BatchMatmulOp>(op)) {
+ } else if (isa<linalg::BatchMatmulOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchMatmulOp *>(nullptr));
- } else if (isa<linalg::BatchMatvecOp>(op)) {
+ } else if (isa<linalg::BatchMatvecOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchMatvecOp *>(nullptr));
- } else if (isa<linalg::BatchVecmatOp>(op)) {
+ } else if (isa<linalg::BatchVecmatOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchVecmatOp *>(nullptr));
- } else if (isa<linalg::MatmulTransposeAOp>(op)) {
+ } else if (isa<linalg::MatmulTransposeAOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::MatmulTransposeAOp *>(nullptr));
- } else if (isa<linalg::MatmulTransposeBOp>(op)) {
+ } else if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::MatmulTransposeBOp *>(nullptr));
- } else if (isa<linalg::BatchMatmulTransposeAOp>(op)) {
+ } else if (isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchMatmulTransposeAOp *>(nullptr));
- } else if (isa<linalg::BatchMatmulTransposeBOp>(op)) {
+ } else if (isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
replaceOpInputs(static_cast<linalg::BatchMatmulTransposeBOp *>(nullptr));
+ } else if (isa<linalg::Conv2DOp>(linalgOp)) {
+ replaceOpInputs(static_cast<linalg::Conv2DOp *>(nullptr));
+ } else if (isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
+ replaceOpInputs(static_cast<linalg::Conv2DNchwFchwOp *>(nullptr));
+ } else if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
+ replaceOpInputs(static_cast<linalg::Conv2DNhwcHwcfOp *>(nullptr));
+ } else if (isa<linalg::Conv2DNhwcFhwcOp>(linalgOp)) {
+ replaceOpInputs(static_cast<linalg::Conv2DNhwcFhwcOp *>(nullptr));
+ } else if (isa<linalg::Conv2DNgchwFgchwOp>(linalgOp)) {
+ replaceOpInputs(static_cast<linalg::Conv2DNgchwFgchwOp *>(nullptr));
+ } else if (isa<linalg::Conv2DNgchwGfchwOp>(linalgOp)) {
+ replaceOpInputs(static_cast<linalg::Conv2DNgchwGfchwOp *>(nullptr));
} else {
return failure();
}
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
index ff717b8..ea8b168 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
@@ -219,3 +219,36 @@
// CHECK: linalg.matmul_transpose_b
// CHECK-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<500x250xbf16>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
+
+// -----
+
+util.func public @conv_2d_nchw_fchw_f32f32f32(%arg0 : tensor<1x16x130x130xf32>, %arg1 : tensor<512x16x3x3xf32>,
+ %arg2 : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32> {
+ %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ ins(%arg0, %arg1 : tensor<1x16x130x130xf32>, tensor<512x16x3x3xf32>)
+ outs(%arg2 : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>
+ util.return %0 : tensor<1x512x128x128xf32>
+}
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: util.func public @conv_2d_nchw_fchw_f32f32f32(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x16x130x130xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<512x16x3x3xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32> {
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<1x16x130x130xbf16>
+// CHECK: %[[DEMOT1:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[VAL_0]] : tensor<1x16x130x130xf32>) outs(%[[VAL_3]] : tensor<1x16x130x130xbf16>) {
+// CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: bf16):
+// CHECK: %[[VAL_7:.*]] = arith.truncf %[[VAL_5]] : f32 to bf16
+// CHECK: linalg.yield %[[VAL_7]] : bf16
+// CHECK: } -> tensor<1x16x130x130xbf16>
+// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<512x16x3x3xbf16>
+// CHECK: %[[DEMOT2:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[VAL_1]] : tensor<512x16x3x3xf32>) outs(%[[VAL_8]] : tensor<512x16x3x3xbf16>) {
+// CHECK: ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: bf16):
+// CHECK: %[[VAL_12:.*]] = arith.truncf %[[VAL_10]] : f32 to bf16
+// CHECK: linalg.yield %[[VAL_12]] : bf16
+// CHECK: } -> tensor<512x16x3x3xbf16>
+// CHECK: %[[VAL_13:.*]] = linalg.conv_2d_nchw_fchw ins(%[[DEMOT1]], %[[DEMOT2]] : tensor<1x16x130x130xbf16>, tensor<512x16x3x3xbf16>)
+// CHECK-SAME: outs(%[[VAL_2]] : tensor<1x512x128x128xf32>) -> tensor<1x512x128x128xf32>