blob: e7898b5a133f812b31fcd593bdfc078d1b348129 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/VM/Conversion/TargetOptions.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Rewrites a single-operand math op into the VM op matching its operand width:
// Dst32OpTy for 32-bit floats and Dst64OpTy for 64-bit floats. The new op's
// result type is the (type-converted) operand type.
//
// Unsupported inputs (vectors, non-float types, or floats that are neither
// 32 nor 64 bits wide) produce a match failure rather than aborting, so the
// dialect conversion driver can report the op as illegal.
template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy>
class UnaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): support vectors.
    if (srcOp.getResult().getType().template isa<VectorType>())
      return failure();

    Value operand = adaptor.getOperand();
    Type operandType = operand.getType();
    // getIntOrFloatBitWidth() asserts on non-int/float types, and only floats
    // have F32/F64 VM counterparts, so reject anything else up front.
    if (!operandType.isa<FloatType>())
      return rewriter.notifyMatchFailure(srcOp, "expected a scalar float");

    switch (operandType.getIntOrFloatBitWidth()) {
      case 32:
        rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, operandType, operand);
        return success();
      case 64:
        rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, operandType, operand);
        return success();
      default:
        // e.g. f16/bf16: no VM op exists; fail the match instead of hitting
        // llvm_unreachable (which is UB in release builds).
        return rewriter.notifyMatchFailure(srcOp, "invalid target type");
    }
  }
};
// Rewrites a two-operand math op into the VM op matching the width of its
// left-hand operand: Dst32OpTy for 32-bit floats and Dst64OpTy for 64-bit
// floats. The new op's result type is the (type-converted) lhs type.
//
// Unsupported inputs (vectors, non-float types, or floats that are neither
// 32 nor 64 bits wide) produce a match failure rather than aborting, so the
// dialect conversion driver can report the op as illegal.
template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy>
class BinaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): support vectors.
    if (srcOp.getResult().getType().template isa<VectorType>())
      return failure();

    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    Type lhsType = lhs.getType();
    // getIntOrFloatBitWidth() asserts on non-int/float types, and only floats
    // have F32/F64 VM counterparts, so reject anything else up front.
    if (!lhsType.isa<FloatType>())
      return rewriter.notifyMatchFailure(srcOp, "expected a scalar float");

    switch (lhsType.getIntOrFloatBitWidth()) {
      case 32:
        rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, lhsType, lhs, rhs);
        return success();
      case 64:
        rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, lhsType, lhs, rhs);
        return success();
      default:
        // e.g. f16/bf16: no VM op exists; fail the match instead of hitting
        // llvm_unreachable (which is UB in release builds).
        return rewriter.notifyMatchFailure(srcOp, "invalid target type");
    }
  }
};
} // namespace
// Appends to |patterns| one conversion pattern per supported math dialect op.
// Each pattern rewrites its op into the matching IREE::VM F32 or F64 op and is
// constructed with |typeConverter| and |context|. Patterns are grouped by
// function family below; the overall registration order is unchanged.
void populateMathToVMPatterns(MLIRContext *context,
                              TypeConverter &typeConverter,
                              RewritePatternSet &patterns) {
  // Trigonometric functions.
  patterns.insert<
      UnaryArithmeticOpConversion<math::AtanOp, IREE::VM::AtanF32Op,
                                  IREE::VM::AtanF64Op>,
      BinaryArithmeticOpConversion<math::Atan2Op, IREE::VM::Atan2F32Op,
                                   IREE::VM::Atan2F64Op>,
      UnaryArithmeticOpConversion<math::CosOp, IREE::VM::CosF32Op,
                                  IREE::VM::CosF64Op>,
      UnaryArithmeticOpConversion<math::SinOp, IREE::VM::SinF32Op,
                                  IREE::VM::SinF64Op>>(typeConverter, context);

  // Exponentials.
  patterns.insert<
      UnaryArithmeticOpConversion<math::ExpOp, IREE::VM::ExpF32Op,
                                  IREE::VM::ExpF64Op>,
      UnaryArithmeticOpConversion<math::Exp2Op, IREE::VM::Exp2F32Op,
                                  IREE::VM::Exp2F64Op>,
      UnaryArithmeticOpConversion<math::ExpM1Op, IREE::VM::ExpM1F32Op,
                                  IREE::VM::ExpM1F64Op>>(typeConverter,
                                                         context);

  // Logarithms.
  patterns.insert<
      UnaryArithmeticOpConversion<math::LogOp, IREE::VM::LogF32Op,
                                  IREE::VM::LogF64Op>,
      UnaryArithmeticOpConversion<math::Log10Op, IREE::VM::Log10F32Op,
                                  IREE::VM::Log10F64Op>,
      UnaryArithmeticOpConversion<math::Log1pOp, IREE::VM::Log1pF32Op,
                                  IREE::VM::Log1pF64Op>,
      UnaryArithmeticOpConversion<math::Log2Op, IREE::VM::Log2F32Op,
                                  IREE::VM::Log2F64Op>>(typeConverter,
                                                        context);

  // Powers and roots.
  patterns.insert<
      BinaryArithmeticOpConversion<math::PowFOp, IREE::VM::PowF32Op,
                                   IREE::VM::PowF64Op>,
      UnaryArithmeticOpConversion<math::RsqrtOp, IREE::VM::RsqrtF32Op,
                                  IREE::VM::RsqrtF64Op>,
      UnaryArithmeticOpConversion<math::SqrtOp, IREE::VM::SqrtF32Op,
                                  IREE::VM::SqrtF64Op>>(typeConverter,
                                                        context);

  // Hyperbolic and special functions.
  patterns.insert<
      UnaryArithmeticOpConversion<math::TanhOp, IREE::VM::TanhF32Op,
                                  IREE::VM::TanhF64Op>,
      UnaryArithmeticOpConversion<math::ErfOp, IREE::VM::ErfF32Op,
                                  IREE::VM::ErfF64Op>>(typeConverter, context);
}
} // namespace iree_compiler
} // namespace mlir