[CPU] Add support for converting math.powf from fp16 to fp32. (#15927)
There is a bug in the polynomial approximation: it generates `NAN` and `INF`
for fp16 types. This is a workaround to keep it functional. See
https://github.com/openxla/iree/issues/15661 for more details.
Also rework the maximumf test. The original generic op is not a typical input
because it reads the `outs` operand even though there are no reduction loops.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ExpandF16OpToF32Pass.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ExpandF16OpToF32Pass.cpp
index 3c7d9b8..1c8bbf4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ExpandF16OpToF32Pass.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ExpandF16OpToF32Pass.cpp
@@ -7,8 +7,7 @@
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -16,7 +15,7 @@
namespace {
-/// A pattern that expands floating-point arithmetic operations with f16
+/// A pattern that expands floating-point arithmetic/math operations with f16
/// operands to f32 operands. It performs the expansion by extending the
/// f16 operands to f32, performing the arithmetic operation on the extended
/// operands, and then truncating the result back to f16.
@@ -27,29 +26,32 @@
LogicalResult matchAndRewrite(Op op,
PatternRewriter &rewriter) const override {
- Type resultType = op.getLhs().getType();
- if (getElementTypeOrSelf(resultType).getIntOrFloatBitWidth() != 16) {
+ auto isElemF16Type = [](Type t) { return getElementTypeOrSelf(t).isF16(); };
+ Type resultType = op.getResult().getType();
+ if (!isElemF16Type(resultType)) {
return failure();
}
Location loc = op.getLoc();
-
- Type wideType = rewriter.getF32Type();
- if (auto vecTy = resultType.dyn_cast<VectorType>()) {
- wideType = VectorType::get(vecTy.getShape(), wideType);
+ Type f32Type = rewriter.getF32Type();
+ SmallVector<Value> operands;
+ for (auto operand : op.getOperands()) {
+ if (!isElemF16Type(operand.getType())) {
+ operands.push_back(operand);
+ continue;
+ }
+ Value ext = rewriter.create<arith::ExtFOp>(loc, f32Type, operand);
+ operands.push_back(ext);
}
+ Value newOp = rewriter.create<Op>(loc, f32Type, operands);
- Value lhsExt = rewriter.create<arith::ExtFOp>(loc, wideType, op.getLhs());
- Value rhsExt = rewriter.create<arith::ExtFOp>(loc, wideType, op.getRhs());
- Value maxExt = rewriter.create<Op>(loc, wideType, lhsExt, rhsExt);
-
- rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, resultType, maxExt);
+ rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, resultType, newOp);
return success();
}
};
struct ExpandF16OpToF32Pass
- : public ExpandArithF16ToF32Base<ExpandF16OpToF32Pass> {
+ : public ExpandF16OpToF32Base<ExpandF16OpToF32Pass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<linalg::LinalgDialect>();
}
@@ -58,6 +60,9 @@
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
patterns.insert<ExpandF16OpToF32Pattern<arith::MaximumFOp>>(context);
+ // TODO(#15661): Remove the expansion for math.powf op after fixing
+ // approximation issue.
+ patterns.insert<ExpandF16OpToF32Pattern<math::PowFOp>>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
index 5fa4f38..1491a3b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
@@ -25,7 +25,7 @@
];
}
-def ExpandArithF16ToF32 :
+def ExpandF16OpToF32 :
Pass<"iree-llvmcpu-expand-f16-op-to-f32", ""> {
let summary =
"Preform f16 opertaions by expanding them to f32.";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/expand_f16_op_to_f32.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/expand_f16_op_to_f32.mlir
index 98f42e3..519f0e5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/expand_f16_op_to_f32.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/expand_f16_op_to_f32.mlir
@@ -1,21 +1,51 @@
-// RUN: iree-opt --iree-llvmcpu-expand-f16-op-to-f32 %s | FileCheck %s
+// RUN: iree-opt --iree-llvmcpu-expand-f16-op-to-f32 --split-input-file %s | FileCheck %s
-func.func @test_expand_f16_maxf(%arg0: tensor<4xf16>, %arg1: tensor<4xf16>) -> tensor<4xf16>{
- %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
- iterator_types = ["parallel"]} ins(%arg0: tensor<4xf16>) outs(%arg1: tensor<4xf16>) {
- ^bb0(%in: f16, %out: f16):
- %2 = arith.maximumf %in, %out : f16
- linalg.yield %2: f16
- } -> tensor<4xf16>
-
- return %1 : tensor<4xf16>
+func.func @maximumf(%arg0: tensor<4xf16>, %arg1: tensor<4xf16>, %arg2: tensor<4xf16>) -> tensor<4xf16>{
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%arg0, %arg1: tensor<4xf16>, tensor<4xf16>)
+ outs(%arg2: tensor<4xf16>) {
+ ^bb0(%in: f16, %in_0: f16, %out: f16):
+ %2 = arith.maximumf %in, %in_0 : f16
+ linalg.yield %2: f16
+ } -> tensor<4xf16>
+ return %1 : tensor<4xf16>
}
-// CHECK-LABEL: func.func @test_expand_f16_maxf
-// CHECK: %[[GEN:.*]] = linalg.generic
-// CHECK: %[[RHSEXT:.*]] = arith.extf %in : f16 to f32
-// CHECK: %[[LHSEXT:.*]] = arith.extf %out : f16 to f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[RHSEXT]], %[[LHSEXT]] : f32
-// CHECK: %[[TRUNC:.*]] = arith.truncf %[[MAX]] : f32 to f16
-// CHECK: linalg.yield %[[TRUNC:.*]] : f16
-// CHECK: return %[[GEN:.*]] : tensor<4xf16>
+// CHECK-LABEL: func.func @maximumf
+// CHECK: %[[GEN:.*]] = linalg.generic
+// CHECK: %[[LHS:.*]] = arith.extf %{{.+}} : f16 to f32
+// CHECK: %[[RHS:.*]] = arith.extf %{{.+}} : f16 to f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[LHS]], %[[RHS]] : f32
+// CHECK: %[[TRUNC:.*]] = arith.truncf %[[MAX]] : f32 to f16
+// CHECK: linalg.yield %[[TRUNC:.*]] : f16
+// CHECK: return %[[GEN:.*]] : tensor<4xf16>
+
+// -----
+
+func.func @powf(%arg0: tensor<4xf16>, %arg1: tensor<4xf16>, %arg2: tensor<4xf16>) -> tensor<4xf16>{
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%arg0, %arg1: tensor<4xf16>, tensor<4xf16>)
+ outs(%arg2: tensor<4xf16>) {
+ ^bb0(%in: f16, %in_0: f16, %out: f16):
+ %2 = math.powf %in, %in_0 : f16
+ linalg.yield %2: f16
+ } -> tensor<4xf16>
+ return %1 : tensor<4xf16>
+}
+// CHECK-LABEL: func.func @powf
+// CHECK: %[[GEN:.*]] = linalg.generic
+// CHECK: %[[LHS:.*]] = arith.extf %{{.+}} : f16 to f32
+// CHECK: %[[RHS:.*]] = arith.extf %{{.+}} : f16 to f32
+// CHECK: %[[POWF:.*]] = math.powf %[[LHS]], %[[RHS]] : f32
+// CHECK: %[[TRUNC:.*]] = arith.truncf %[[POWF]] : f32 to f16
+// CHECK: linalg.yield %[[TRUNC:.*]] : f16
+// CHECK: return %[[GEN:.*]] : tensor<4xf16>
+