Add a pass for FastExp conversion (#3839)
Convert `llvm.intr.exp` into a sequence of ops that computes an approximation to exp(x).
This is using the fact that:
```
exp(x) = exp(x - floor(x / ln(2)) * ln(2)) * 2^(floor(x / ln(2)))
       = exp(x - k * ln(2)) * 2^k,  where k = floor(x / ln(2))
```
With k = floor(x / ln(2)), the reduced argument x - k * ln(2) lies in [0, ln(2)), and exp on that range is approximated with a 4th-degree polynomial.
The real number 2^k is computed with integer bitwise arithmetic.
MobileBert benchmarks:
Before:
```
------------------------------------------------------------------------------------
Benchmark Time CPU Iterations
------------------------------------------------------------------------------------
BM_serving_default/process_time/real_time 907 ms 905 ms 1
```
After:
```
------------------------------------------------------------------------------------
Benchmark Time CPU Iterations
------------------------------------------------------------------------------------
BM_serving_default/process_time/real_time 819 ms 815 ms 1
```
diff --git a/iree/compiler/Conversion/LLVMToLLVM/BUILD b/iree/compiler/Conversion/LLVMToLLVM/BUILD
new file mode 100644
index 0000000..1656647
--- /dev/null
+++ b/iree/compiler/Conversion/LLVMToLLVM/BUILD
@@ -0,0 +1,36 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "LLVMToLLVM",
+ srcs = [
+ "FastExpConversion.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ ],
+ deps = [
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt
new file mode 100644
index 0000000..4b02edc
--- /dev/null
+++ b/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt
@@ -0,0 +1,31 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ LLVMToLLVM
+ HDRS
+ "Passes.h"
+ SRCS
+ "FastExpConversion.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ MLIRLLVMIR
+ MLIRPass
+ MLIRTransforms
+ PUBLIC
+)
diff --git a/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp b/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp
new file mode 100644
index 0000000..749ada4
--- /dev/null
+++ b/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp
@@ -0,0 +1,125 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+// Rewrites a scalar f32 llvm.intr.exp into a fast approximation of exp(x).
+// With k = floor(x / ln(2)) and y = x - k * ln(2), y lies in [0, ln(2)) and
+// exp(x) = exp(y) * 2^k. exp(y) is evaluated with a 4th degree polynomial
+// (Horner form) and 2^k is assembled directly from its IEEE-754 bit pattern.
+struct FastExpConversionPattern : public OpRewritePattern<LLVM::ExpOp> {
+  using OpRewritePattern<LLVM::ExpOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LLVM::ExpOp op,
+                                PatternRewriter &rewriter) const override {
+    constexpr float ln2Const = 0.693147181;
+    constexpr float ln2InvConst = 1.44269504;
+
+    // Least squares polynomial fit, highest degree coefficient first:
+    //   xs = np.linspace(0, math.log(2), 10000)
+    //   cValues = np.polyfit(xs, np.exp(xs), 4)
+    constexpr float cValues[5] = {0.05924867, 0.15514645, 0.50308552,
+                                  0.99968939, 1.00000721531};
+    auto loc = op.getLoc();
+    Value x = op.getOperand();
+
+    auto floatType = LLVM::LLVMType::getFloatTy(rewriter.getContext());
+    auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
+    // All ops below are scalar f32; bail out on f64/vector exp operands.
+    if (x.getType() != floatType) return failure();
+
+    Value ln2 = rewriter.create<LLVM::ConstantOp>(
+        loc, floatType, rewriter.getF32FloatAttr(ln2Const));
+    Value ln2Inv = rewriter.create<LLVM::ConstantOp>(
+        loc, floatType, rewriter.getF32FloatAttr(ln2InvConst));
+
+    // Compute reduced range input y = x - floor(x / ln(2)) * ln(2)
+    Value xL2Inv = rewriter.create<LLVM::FMulOp>(loc, floatType, x, ln2Inv);
+    Value kF32 = rewriter.create<LLVM::FFloorOp>(loc, floatType, xL2Inv);
+    Value kLn2 = rewriter.create<LLVM::FMulOp>(loc, floatType, kF32, ln2);
+    Value y = rewriter.create<LLVM::FSubOp>(loc, floatType, x, kLn2);
+
+    SmallVector<Value, 5> PConst;
+    for (float c : cValues) {
+      PConst.push_back(rewriter.create<LLVM::ConstantOp>(
+          loc, floatType, rewriter.getF32FloatAttr(c)));
+    }
+    // Horner evaluation of exp(y) = sum(c[i] * y**(4 - i), i = 0..4)
+    Value expY = PConst[0];
+    for (int i = 1; i < 5; ++i) {
+      expY = rewriter.create<LLVM::FMulOp>(loc, floatType, expY, y);
+      expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[i]);
+    }
+
+    // Compute exp2(k) with integer bitshift:
+    // exp2(k) = f32_bitcast((127 + k) << 23)
+    Value fPBias = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getI32IntegerAttr(127));
+    Value k = rewriter.create<LLVM::FPToSIOp>(loc, i32Type, kF32);
+    Value kPlusfPBias = rewriter.create<LLVM::AddOp>(loc, i32Type, k, fPBias);
+    Value shiftConst = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getI32IntegerAttr(23));
+    Value twoPowkI =
+        rewriter.create<LLVM::ShlOp>(loc, i32Type, kPlusfPBias, shiftConst);
+    Value twoPowk = rewriter.create<LLVM::BitcastOp>(loc, floatType, twoPowkI);
+    expY = rewriter.create<LLVM::FMulOp>(loc, floatType, expY, twoPowk);
+    rewriter.replaceOp(op, {expY});
+    // TODO(ataei): Handle overflow and underflow cases (e.g |k| > 128).
+    return success();
+  }
+};
+
+// Module pass that rewrites llvm.intr.exp using FastExpConversionPattern.
+struct FastExpConversionPass
+    : public PassWrapper<FastExpConversionPass, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();  // Rewrite emits LLVM ops.
+  }
+  void runOnOperation() override;
+};
+} // namespace
+
+// Registers the llvm.intr.exp fast approximation rewrite in `patterns`.
+void populateFastExpConversionPatterns(OwningRewritePatternList &patterns,
+                                       MLIRContext *context) {
+  patterns.insert<FastExpConversionPattern>(context);
+}
+
+// Greedily applies the fast exp rewrite across the whole module.
+void FastExpConversionPass::runOnOperation() {
+  OwningRewritePatternList patterns;
+  populateFastExpConversionPatterns(patterns, getOperation().getContext());
+  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createFastExpApproximationConversionPass() {
+  return std::make_unique<FastExpConversionPass>();
+}
+
+static PassRegistration<OperationPass<ModuleOp>> pass(
+    "iree-codegen-linalg-to-llvm-fast-exp-conversion-pass",
+    "Convert llvm.intr.exp into its fast polynomial approximation version",
+    [] { return std::make_unique<FastExpConversionPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LLVMToLLVM/Passes.h b/iree/compiler/Conversion/LLVMToLLVM/Passes.h
new file mode 100644
index 0000000..e40fe90
--- /dev/null
+++ b/iree/compiler/Conversion/LLVMToLLVM/Passes.h
@@ -0,0 +1,30 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_
+#define IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Creates a pass to rewrite llvm.intr.exp using its reduced range polynomial
+// approximation.
+std::unique_ptr<OperationPass<ModuleOp>>
+createFastExpApproximationConversionPass();
+
+} // namespace iree_compiler
+} // namespace mlir
+#endif // IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index c4396c5..5dbbadd 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -37,6 +37,7 @@
"//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
+ "//iree/compiler/Conversion/LLVMToLLVM",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 9564f1e..4018a29 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -49,6 +49,7 @@
iree::compiler::Conversion::Common
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
+ iree::compiler::Conversion::LLVMToLLVM
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 6f2f30b..348b425 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/Common/Passes.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
+#include "iree/compiler/Conversion/LLVMToLLVM/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
@@ -34,6 +35,12 @@
"linag.matmul"),
llvm::cl::init(false));
+static llvm::cl::opt<bool> fastExpConversion(
+ "iree-codegen-linalg-to-llvm-fast-exp",
+ llvm::cl::desc("If true convert llvm.intr.exp into its range reduced "
+ "polynomial approximation."),
+ llvm::cl::init(false));
+
void addLinalgToLLVMPasses(OpPassManager &passManager) {
// Distribute linalg op among a 3d grid of parallel threads. Tile each
// workgroup thread memory then vectorize the linalg op.
@@ -66,6 +73,11 @@
passManager.addPass(createConvertToLLVMPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
+
+ // Approximate llvm.intr.exp with a 4th-order polynomial in range [0, ln(2)].
+ if (fastExpConversion) {
+ passManager.addPass(createFastExpApproximationConversionPass());
+ }
}
void buildLLVMTransformPassPipeline(OpPassManager &passManager) {
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 8dcc089..33c74d4 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -30,6 +30,9 @@
/// Vectorize linalg ops executed in the same iree.workgroup.
std::unique_ptr<FunctionPass> createLinalgTileAndVectorizeWorkgroupsPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createFastExpApproximationConversionPass();
+
/// Populates patterns to rewrite linalg::ConvOp into packed img2col operation
/// followed by linalg::MatmulOp.
void populateConvImg2ColMatmulConversionPatterns(
diff --git a/iree/test/e2e/llvmir_specific/BUILD b/iree/test/e2e/llvmir_specific/BUILD
index 8770dc3..f0bf5de 100644
--- a/iree/test/e2e/llvmir_specific/BUILD
+++ b/iree/test/e2e/llvmir_specific/BUILD
@@ -34,3 +34,15 @@
driver = "llvm",
target_backend = "llvm-ir",
)
+
+iree_check_single_backend_test_suite(
+ name = "check_llvm-ir-exponential_fast",
+ srcs = [
+ "exponential.mlir",
+ ],
+ compiler_flags = [
+ "-iree-codegen-linalg-to-llvm-fast-exp=true",
+ ],
+ driver = "llvm",
+ target_backend = "llvm-ir",
+)
diff --git a/iree/test/e2e/llvmir_specific/CMakeLists.txt b/iree/test/e2e/llvmir_specific/CMakeLists.txt
index 0d00ec1..496841a 100644
--- a/iree/test/e2e/llvmir_specific/CMakeLists.txt
+++ b/iree/test/e2e/llvmir_specific/CMakeLists.txt
@@ -26,3 +26,16 @@
COMPILER_FLAGS
"-iree-codegen-linalg-to-llvm-conv-img2col-conversion=true"
)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_llvm-ir-exponential_fast
+ SRCS
+ "exponential.mlir"
+ TARGET_BACKEND
+ llvm-ir
+ DRIVER
+ llvm
+ COMPILER_FLAGS
+ "-iree-codegen-linalg-to-llvm-fast-exp=true"
+)
diff --git a/iree/test/e2e/llvmir_specific/exponential.mlir b/iree/test/e2e/llvmir_specific/exponential.mlir
new file mode 100644
index 0000000..6e91326
--- /dev/null
+++ b/iree/test/e2e/llvmir_specific/exponential.mlir
@@ -0,0 +1,27 @@
+func @tensor() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+ %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+ check.expect_almost_eq_const(%result, dense<[1.0, 2.7183, 7.3891, 54.5981]> : tensor<4xf32>) : tensor<4xf32>
+ return
+}
+
+func @scalar() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<1.0> : tensor<f32>
+ %result = "mhlo.exponential"(%input) : (tensor<f32>) -> tensor<f32>
+ check.expect_almost_eq_const(%result, dense<2.7183> : tensor<f32>) : tensor<f32>
+ return
+}
+
+func @double() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<1.0> : tensor<f64>
+ %result = "mhlo.exponential"(%input) : (tensor<f64>) -> tensor<f64>
+ check.expect_almost_eq_const(%result, dense<2.7183> : tensor<f64>) : tensor<f64>
+ return
+}
+
+func @negative() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<-1.0> : tensor<f32>
+ %result = "mhlo.exponential"(%input) : (tensor<f32>) -> tensor<f32>
+ check.expect_almost_eq_const(%result, dense<0.367879> : tensor<f32>) : tensor<f32>
+ return
+}