Define the initial set of VM ExtF32 and ExtF64 ops.

Some ops and conversions are still missing (e.g. int<->fp casts,
exp/log and other transcendental ops, and floating-point
comparisons); those will be added in follow-up changes.
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
index 4315fc5..baa0db2 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
@@ -57,6 +57,30 @@
op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
return success();
+ } else if (convertedType.isF32()) {
+ auto convertedValue = op.initial_value().hasValue()
+ ? rewriter.getF32FloatAttr(static_cast<float>(
+ op.initial_value()
+ .getValue()
+ .cast<FloatAttr>()
+ .getValueAsDouble()))
+ : Attribute{};
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalF32Op>(
+ op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
+ convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
+ return success();
+ } else if (convertedType.isF64()) {
+ auto convertedValue =
+ op.initial_value().hasValue()
+ ? rewriter.getF64FloatAttr(op.initial_value()
+ .getValue()
+ .cast<FloatAttr>()
+ .getValueAsDouble())
+ : Attribute{};
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalF64Op>(
+ op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
+ convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
+ return success();
}
return op.emitOpError("unsupported variable type");
}
@@ -93,12 +117,25 @@
LogicalResult matchAndRewrite(
IREE::HAL::VariableLoadOp op, llvm::ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (IREE::VM::RefType::isCompatible(op.getType())) {
- rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadRefOp>(
- op, typeConverter.convertType(op.getType()), op.variable());
+ auto operandType = op.getType();
+ auto convertedType = typeConverter.convertType(operandType);
+ if (IREE::VM::RefType::isCompatible(operandType)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadRefOp>(op, convertedType,
+ op.variable());
+ } else if (convertedType.isInteger(32)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI32Op>(op, convertedType,
+ op.variable());
+ } else if (convertedType.isInteger(64)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI64Op>(op, convertedType,
+ op.variable());
+ } else if (convertedType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadF32Op>(op, convertedType,
+ op.variable());
+ } else if (convertedType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadF64Op>(op, convertedType,
+ op.variable());
} else {
- rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI32Op>(
- op, typeConverter.convertType(op.getType()), op.variable());
+ return rewriter.notifyMatchFailure(op, "unhandled variable type");
}
return success();
}
@@ -115,15 +152,27 @@
: OpConversionPattern(context), typeConverter(typeConverter) {}
LogicalResult matchAndRewrite(
- IREE::HAL::VariableLoadIndirectOp op, llvm::ArrayRef<Value> newOperands,
+ IREE::HAL::VariableLoadIndirectOp op, llvm::ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- IREE::HAL::VariableLoadIndirectOp::Adaptor operands(newOperands);
- if (IREE::VM::RefType::isCompatible(op.getType())) {
+ auto operandType = op.getType();
+ auto convertedType = typeConverter.convertType(operandType);
+ if (IREE::VM::RefType::isCompatible(operandType)) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectRefOp>(
- op, typeConverter.convertType(op.getType()), operands.variable());
- } else {
+ op, convertedType, op.variable());
+ } else if (convertedType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectI32Op>(
- op, typeConverter.convertType(op.getType()), operands.variable());
+ op, convertedType, op.variable());
+ } else if (convertedType.isInteger(64)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectI64Op>(
+ op, convertedType, op.variable());
+ } else if (convertedType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectF32Op>(
+ op, convertedType, op.variable());
+ } else if (convertedType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectF64Op>(
+ op, convertedType, op.variable());
+ } else {
+ return rewriter.notifyMatchFailure(op, "unhandled variable type");
}
return success();
}
@@ -142,12 +191,24 @@
IREE::HAL::VariableStoreOp op, llvm::ArrayRef<Value> newOperands,
ConversionPatternRewriter &rewriter) const override {
IREE::HAL::VariableStoreOp::Adaptor operands(newOperands);
- if (operands.value().getType().isa<IREE::VM::RefType>()) {
+ auto operandType = operands.value().getType();
+ if (operandType.isa<IREE::VM::RefType>()) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreRefOp>(
op, operands.value(), op.variable());
- } else {
+ } else if (operandType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreI32Op>(
op, operands.value(), op.variable());
+ } else if (operandType.isInteger(64)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreI64Op>(
+ op, operands.value(), op.variable());
+ } else if (operandType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreF32Op>(
+ op, operands.value(), op.variable());
+ } else if (operandType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreF64Op>(
+ op, operands.value(), op.variable());
+ } else {
+ return rewriter.notifyMatchFailure(op, "unhandled variable type");
}
return success();
}
@@ -164,12 +225,24 @@
IREE::HAL::VariableStoreIndirectOp op, llvm::ArrayRef<Value> newOperands,
ConversionPatternRewriter &rewriter) const override {
IREE::HAL::VariableStoreIndirectOp::Adaptor operands(newOperands);
- if (operands.value().getType().isa<IREE::VM::RefType>()) {
+ auto operandType = operands.value().getType();
+ if (operandType.isa<IREE::VM::RefType>()) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectRefOp>(
- op, operands.value(), operands.variable());
- } else {
+ op, operands.value(), op.variable());
+ } else if (operandType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectI32Op>(
- op, operands.value(), operands.variable());
+ op, operands.value(), op.variable());
+ } else if (operandType.isInteger(64)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectI64Op>(
+ op, operands.value(), op.variable());
+ } else if (operandType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectF32Op>(
+ op, operands.value(), op.variable());
+ } else if (operandType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectF64Op>(
+ op, operands.value(), op.variable());
+ } else {
+ return rewriter.notifyMatchFailure(op, "unhandled variable type");
}
return success();
}
diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp b/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp
index f49e8f2..2f8b606 100644
--- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp
@@ -75,9 +75,8 @@
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREE::VM::FailOp>(
srcOp,
- rewriter.createOrFold<mlir::ConstantIntOp>(
- srcOp.getLoc(), static_cast<int32_t>(IREE::StatusCode::Unknown),
- 32),
+ rewriter.createOrFold<IREE::VM::ConstI32Op>(
+ srcOp.getLoc(), static_cast<int32_t>(IREE::StatusCode::Unknown)),
srcOp.message());
return success();
}
@@ -138,6 +137,12 @@
} else if (resultType.isInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::ListGetI64Op>(
srcOp, resultType, srcOperands.list(), srcOperands.index());
+ } else if (resultType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::ListGetF32Op>(
+ srcOp, resultType, srcOperands.list(), srcOperands.index());
+ } else if (resultType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::ListGetF64Op>(
+ srcOp, resultType, srcOperands.list(), srcOperands.index());
} else if (!resultType.isIntOrIndexOrFloat()) {
rewriter.replaceOpWithNewOp<IREE::VM::ListGetRefOp>(
srcOp, resultType, srcOperands.list(), srcOperands.index());
@@ -161,6 +166,12 @@
} else if (valueType.isInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::ListSetI64Op>(
srcOp, srcOperands.list(), srcOperands.index(), srcOperands.value());
+ } else if (valueType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::ListSetF32Op>(
+ srcOp, srcOperands.list(), srcOperands.index(), srcOperands.value());
+ } else if (valueType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::ListSetF64Op>(
+ srcOp, srcOperands.list(), srcOperands.index(), srcOperands.value());
} else if (!valueType.isIntOrIndexOrFloat()) {
rewriter.replaceOpWithNewOp<IREE::VM::ListSetRefOp>(
srcOp, srcOperands.list(), srcOperands.index(), srcOperands.value());
diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index d5e6317..eb98d42 100644
--- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -160,35 +160,61 @@
LogicalResult matchAndRewrite(
ConstantOp srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto integerAttr = srcOp.getValue().dyn_cast<IntegerAttr>();
- if (!integerAttr) {
- return srcOp.emitRemark() << "unsupported const type for dialect";
- }
// TODO(#2878): use getTypeConverter() when we pass it upon creation.
IREE::VM::TypeConverter typeConverter(
IREE::VM::getTargetOptionsFromFlags());
auto targetType = typeConverter.convertType(srcOp.getType());
- switch (targetType.getIntOrFloatBitWidth()) {
- case 1:
- case 32:
- if (integerAttr.getInt()) {
- rewriter.replaceOpWithNewOp<IREE::VM::ConstI32Op>(
- srcOp, integerAttr.getInt());
- } else {
- rewriter.replaceOpWithNewOp<IREE::VM::ConstI32ZeroOp>(srcOp);
- }
- break;
- case 64:
- if (integerAttr.getInt()) {
- rewriter.replaceOpWithNewOp<IREE::VM::ConstI64Op>(
- srcOp, integerAttr.getInt());
- } else {
- rewriter.replaceOpWithNewOp<IREE::VM::ConstI64ZeroOp>(srcOp);
- }
- break;
- default:
- return srcOp.emitRemark()
- << "unsupported const integer bit width for dialect";
+ if (targetType.isa<IntegerType>()) {
+ auto integerAttr = srcOp.getValue().dyn_cast<IntegerAttr>();
+ if (!integerAttr) {
+ return srcOp.emitRemark() << "unsupported const type for dialect";
+ }
+ switch (targetType.getIntOrFloatBitWidth()) {
+ case 1:
+ case 32:
+ if (integerAttr.getInt()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::ConstI32Op>(
+ srcOp, integerAttr.getInt());
+ } else {
+ rewriter.replaceOpWithNewOp<IREE::VM::ConstI32ZeroOp>(srcOp);
+ }
+ break;
+ case 64:
+ if (integerAttr.getInt()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::ConstI64Op>(
+ srcOp, integerAttr.getInt());
+ } else {
+ rewriter.replaceOpWithNewOp<IREE::VM::ConstI64ZeroOp>(srcOp);
+ }
+ break;
+ default:
+ return srcOp.emitRemark()
+ << "unsupported const integer bit width for dialect";
+ }
+ } else if (targetType.isa<FloatType>()) {
+ auto floatAttr = srcOp.getValue().dyn_cast<FloatAttr>();
+ if (!floatAttr) {
+ return srcOp.emitRemark() << "unsupported const type for dialect";
+ }
+ switch (targetType.getIntOrFloatBitWidth()) {
+ case 32:
+          if (floatAttr.getValue().isPosZero()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::ConstF32ZeroOp>(srcOp);
+ } else {
+ rewriter.replaceOpWithNewOp<IREE::VM::ConstF32Op>(srcOp, floatAttr);
+ }
+ break;
+ case 64:
+          if (floatAttr.getValue().isPosZero()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::ConstF64ZeroOp>(srcOp);
+ } else {
+ rewriter.replaceOpWithNewOp<IREE::VM::ConstF64Op>(srcOp, floatAttr);
+ }
+ break;
+        default: return srcOp.emitRemark() << "unsupported const floating-point bit width for dialect";
+      }
+    } else {
+      return srcOp.emitRemark() << "unsupported const type for dialect";
}
return success();
}
@@ -200,48 +226,48 @@
LogicalResult matchAndRewrite(
CmpIOp srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- CmpIOp::Adaptor srcAdapter(operands);
+ CmpIOp::Adaptor srcAdaptor(operands);
auto returnType = rewriter.getIntegerType(32);
switch (srcOp.getPredicate()) {
case CmpIPredicate::eq:
rewriter.replaceOpWithNewOp<IREE::VM::CmpEQI32Op>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::ne:
rewriter.replaceOpWithNewOp<IREE::VM::CmpNEI32Op>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::slt:
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32SOp>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::sle:
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32SOp>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::sgt:
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32SOp>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::sge:
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32SOp>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::ult:
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32UOp>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::ule:
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32UOp>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::ugt:
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32UOp>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
case CmpIPredicate::uge:
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32UOp>(
- srcOp, returnType, srcAdapter.lhs(), srcAdapter.rhs());
+ srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
return success();
default:
return failure();
@@ -249,17 +275,28 @@
}
};
-template <typename SrcOpTy, typename DstOpTy>
+template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy>
class BinaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
SrcOpTy srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- typename SrcOpTy::Adaptor srcAdapter(operands);
-
- rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcAdapter.lhs().getType(),
- srcAdapter.lhs(), srcAdapter.rhs());
+ typename SrcOpTy::Adaptor srcAdaptor(operands);
+ switch (srcAdaptor.lhs().getType().getIntOrFloatBitWidth()) {
+ case 32:
+ rewriter.replaceOpWithNewOp<Dst32OpTy>(
+ srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(),
+ srcAdaptor.rhs());
+ break;
+ case 64:
+ rewriter.replaceOpWithNewOp<Dst64OpTy>(
+ srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(),
+ srcAdaptor.rhs());
+ break;
+ default:
+ llvm_unreachable("invalid target type");
+ }
return success();
}
};
@@ -302,28 +339,37 @@
}
};
-class SelectI32OpConversion : public OpConversionPattern<SelectOp> {
+class SelectOpConversion : public OpConversionPattern<SelectOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
SelectOp srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SelectOp::Adaptor srcAdaptor(operands);
- IntegerType requiredType = IntegerType::get(srcOp.getContext(), 32);
- // Note: This check can correctly just be a verification that
- // actualType == requiredType, but since the VM type conversion also
- // maps Indextype to this type, widening the check here reduces red-herrings
- // when other conversions fail to properly match/rewrite index related ops.
- // (Otherwise, the dialect converter may report the error as a failure to
- // legalize the select op depending on order of resolution).
- auto actualType = srcAdaptor.true_value().getType();
- if (actualType != requiredType && actualType.isa<IndexType>()) {
- return failure();
+ auto valueType = srcAdaptor.true_value().getType();
+ if (valueType.isInteger(32)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::SelectI32Op>(
+ srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
+ srcAdaptor.false_value());
+ return success();
+ } else if (valueType.isInteger(64)) {
+ rewriter.replaceOpWithNewOp<IREE::VM::SelectI64Op>(
+ srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
+ srcAdaptor.false_value());
+ return success();
+ } else if (valueType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::SelectF32Op>(
+ srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
+ srcAdaptor.false_value());
+ return success();
+ } else if (valueType.isF64()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::SelectF64Op>(
+ srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
+ srcAdaptor.false_value());
+ return success();
+ } else {
+ return rewriter.notifyMatchFailure(srcOp,
+ "unsupported select element type");
}
-
- rewriter.replaceOpWithNewOp<IREE::VM::SelectI32Op>(
- srcOp, requiredType, srcAdaptor.condition(), srcAdaptor.true_value(),
- srcAdaptor.false_value());
- return success();
}
};
@@ -386,24 +432,45 @@
patterns.insert<BranchOpConversion, CallOpConversion, CmpIOpConversion,
CondBranchOpConversion, ModuleOpConversion, FuncOpConversion,
ReturnOpConversion, CastingOpConversion<IndexCastOp>,
- CastingOpConversion<TruncateIOp>, SelectI32OpConversion>(
+ CastingOpConversion<TruncateIOp>, SelectOpConversion>(
typeConverter, context);
// TODO(#2878): pass typeConverter here.
patterns.insert<ConstantOpConversion>(context);
- // Binary arithmetic ops
- patterns
- .insert<BinaryArithmeticOpConversion<AddIOp, IREE::VM::AddI32Op>,
- BinaryArithmeticOpConversion<SignedDivIOp, IREE::VM::DivI32SOp>,
- BinaryArithmeticOpConversion<UnsignedDivIOp, IREE::VM::DivI32UOp>,
- BinaryArithmeticOpConversion<MulIOp, IREE::VM::MulI32Op>,
- BinaryArithmeticOpConversion<SignedRemIOp, IREE::VM::RemI32SOp>,
- BinaryArithmeticOpConversion<UnsignedRemIOp, IREE::VM::RemI32UOp>,
- BinaryArithmeticOpConversion<SubIOp, IREE::VM::SubI32Op>,
- BinaryArithmeticOpConversion<AndOp, IREE::VM::AndI32Op>,
- BinaryArithmeticOpConversion<OrOp, IREE::VM::OrI32Op>,
- BinaryArithmeticOpConversion<XOrOp, IREE::VM::XorI32Op>>(
- typeConverter, context);
+ // Integer arithmetic ops.
+ patterns.insert<
+ BinaryArithmeticOpConversion<AddIOp, IREE::VM::AddI32Op,
+ IREE::VM::AddI64Op>,
+ BinaryArithmeticOpConversion<SignedDivIOp, IREE::VM::DivI32SOp,
+ IREE::VM::DivI64SOp>,
+ BinaryArithmeticOpConversion<UnsignedDivIOp, IREE::VM::DivI32UOp,
+ IREE::VM::DivI64UOp>,
+ BinaryArithmeticOpConversion<MulIOp, IREE::VM::MulI32Op,
+ IREE::VM::MulI64Op>,
+ BinaryArithmeticOpConversion<SignedRemIOp, IREE::VM::RemI32SOp,
+ IREE::VM::RemI64SOp>,
+ BinaryArithmeticOpConversion<UnsignedRemIOp, IREE::VM::RemI32UOp,
+ IREE::VM::RemI64UOp>,
+ BinaryArithmeticOpConversion<SubIOp, IREE::VM::SubI32Op,
+ IREE::VM::SubI64Op>,
+ BinaryArithmeticOpConversion<AndOp, IREE::VM::AndI32Op,
+ IREE::VM::AndI64Op>,
+ BinaryArithmeticOpConversion<OrOp, IREE::VM::OrI32Op, IREE::VM::OrI64Op>,
+ BinaryArithmeticOpConversion<XOrOp, IREE::VM::XorI32Op,
+ IREE::VM::XorI64Op>>(typeConverter, context);
+
+ // Floating-point arithmetic ops.
+ patterns.insert<BinaryArithmeticOpConversion<AddFOp, IREE::VM::AddF32Op,
+ IREE::VM::AddF64Op>,
+ BinaryArithmeticOpConversion<DivFOp, IREE::VM::DivF32Op,
+ IREE::VM::DivF64Op>,
+ BinaryArithmeticOpConversion<MulFOp, IREE::VM::MulF32Op,
+ IREE::VM::MulF64Op>,
+ BinaryArithmeticOpConversion<RemFOp, IREE::VM::RemF32Op,
+ IREE::VM::RemF64Op>,
+ BinaryArithmeticOpConversion<SubFOp, IREE::VM::SubF32Op,
+ IREE::VM::SubF64Op>>(
+ typeConverter, context);
// Shift ops
// TODO(laurenzo): The standard dialect is missing shr ops. Add once in place.
diff --git a/iree/compiler/Dialect/VM/Conversion/TargetOptions.cpp b/iree/compiler/Dialect/VM/Conversion/TargetOptions.cpp
index 32f7513..7cb1218 100644
--- a/iree/compiler/Dialect/VM/Conversion/TargetOptions.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/TargetOptions.cpp
@@ -37,7 +37,9 @@
llvm::cl::desc("Supported target opcode extensions"),
llvm::cl::cat(vmTargetOptionsCategory),
llvm::cl::values(
- clEnumValN(OpcodeExtension::kI64, "i64", "i64 type support")),
+ clEnumValN(OpcodeExtension::kI64, "i64", "i64 type support"),
+ clEnumValN(OpcodeExtension::kF32, "f32", "f32 type support"),
+ clEnumValN(OpcodeExtension::kF64, "f64", "f64 type support")),
};
static auto *truncateUnsupportedIntegersFlag = new llvm::cl::opt<bool>{
"iree-vm-target-truncate-unsupported-integers",
@@ -45,6 +47,12 @@
llvm::cl::desc("Truncate i64 to i32 when unsupported"),
llvm::cl::cat(vmTargetOptionsCategory),
};
+ static auto *truncateUnsupportedFloatsFlag = new llvm::cl::opt<bool>{
+ "iree-vm-target-truncate-unsupported-floats",
+ llvm::cl::init(true),
+ llvm::cl::desc("Truncate f64 to f32 when unsupported"),
+ llvm::cl::cat(vmTargetOptionsCategory),
+ };
TargetOptions targetOptions;
targetOptions.indexBits = *indexBitsFlag;
@@ -53,9 +61,16 @@
case OpcodeExtension::kI64:
targetOptions.i64Extension = true;
break;
+ case OpcodeExtension::kF32:
+ targetOptions.f32Extension = true;
+ break;
+ case OpcodeExtension::kF64:
+ targetOptions.f64Extension = true;
+ break;
}
}
targetOptions.truncateUnsupportedIntegers = *truncateUnsupportedIntegersFlag;
+ targetOptions.truncateUnsupportedFloats = *truncateUnsupportedFloatsFlag;
return targetOptions;
}
diff --git a/iree/compiler/Dialect/VM/Conversion/TargetOptions.h b/iree/compiler/Dialect/VM/Conversion/TargetOptions.h
index f68cab7..5f1973e 100644
--- a/iree/compiler/Dialect/VM/Conversion/TargetOptions.h
+++ b/iree/compiler/Dialect/VM/Conversion/TargetOptions.h
@@ -26,6 +26,10 @@
enum class OpcodeExtension {
// Adds ops for manipulating i64 types.
kI64,
+ // Adds ops for manipulating f32 types.
+ kF32,
+ // Adds ops for manipulating f64 types.
+ kF64,
};
// Controls VM translation targets.
@@ -35,10 +39,20 @@
// Whether the i64 extension is enabled in the target VM.
bool i64Extension = false;
+ // Whether the f32 extension is enabled in the target VM.
+ bool f32Extension = false;
+ // Whether the f64 extension is enabled in the target VM.
+ bool f64Extension = false;
// Whether to truncate i64 types to i32 when the i64 extension is not
// enabled.
bool truncateUnsupportedIntegers = true;
+ // Whether to truncate f64 types to f32 when the f64 extension is not
+ // enabled.
+ bool truncateUnsupportedFloats = true;
+
+ // Prefer optimizations that reduce VM stack usage over performance.
+ bool optimizeForStackSize = true;
};
// Returns a TargetOptions struct initialized with the
diff --git a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
index 0a27979..4054a3b 100644
--- a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
@@ -79,6 +79,37 @@
return llvm::None;
});
+ // Convert floating-point types.
+ addConversion([this](FloatType floatType) -> Optional<Type> {
+ if (floatType.getIntOrFloatBitWidth() < 32) {
+ if (targetOptions_.f32Extension) {
+ // Promote f16 -> f32.
+ return FloatType::getF32(floatType.getContext());
+      } else if (targetOptions_.truncateUnsupportedFloats) {
+ // f32 is not supported; can't compile.
+ return llvm::None;
+ }
+ } else if (floatType.isF32()) {
+ if (targetOptions_.f32Extension) {
+ return floatType;
+      } else if (targetOptions_.truncateUnsupportedFloats) {
+ // f32 is not supported; can't compile.
+ return llvm::None;
+ }
+ } else if (floatType.isF64()) {
+ if (targetOptions_.f64Extension) {
+ // f64 is supported by the VM, use directly.
+ return floatType;
+ } else if (targetOptions_.f32Extension &&
+ targetOptions_.truncateUnsupportedFloats) {
+ // f64 is not supported and we still want to compile, so truncate to
+ // f32 (unsafe if all bits are actually required!).
+ return FloatType::getF32(floatType.getContext());
+ }
+ }
+ return llvm::None;
+ });
+
// Convert index types to the target bit width.
addConversion([this](IndexType indexType) -> Optional<Type> {
return IntegerType::get(indexType.getContext(), targetOptions_.indexBits);
diff --git a/iree/compiler/Dialect/VM/IR/BUILD b/iree/compiler/Dialect/VM/IR/BUILD
index 05fc5b8..207dcd2 100644
--- a/iree/compiler/Dialect/VM/IR/BUILD
+++ b/iree/compiler/Dialect/VM/IR/BUILD
@@ -30,6 +30,8 @@
[
"VMBase.td",
"VMOpcodesCore.td",
+ "VMOpcodesF32.td",
+ "VMOpcodesF64.td",
"VMOpcodesI64.td",
"VMOps.td",
],
diff --git a/iree/compiler/Dialect/VM/IR/VMBase.td b/iree/compiler/Dialect/VM/IR/VMBase.td
index 6cec68d..74760fd 100644
--- a/iree/compiler/Dialect/VM/IR/VMBase.td
+++ b/iree/compiler/Dialect/VM/IR/VMBase.td
@@ -56,8 +56,14 @@
// Opcode tables
//===----------------------------------------------------------------------===//
+// Mandatory:
include "iree/compiler/Dialect/VM/IR/VMOpcodesCore.td"
+// Optional:
include "iree/compiler/Dialect/VM/IR/VMOpcodesI64.td"
+// Optional:
+include "iree/compiler/Dialect/VM/IR/VMOpcodesF32.td"
+// Optional:
+include "iree/compiler/Dialect/VM/IR/VMOpcodesF64.td"
//===----------------------------------------------------------------------===//
// Declarative encoding framework
@@ -363,6 +369,14 @@
let constBuilderCall = "$0";
}
+class VM_ConstantFloatValueAttr<F type> :
+ FloatAttrBase<type, type.bitwidth # "-bit floating-point value"> {
+ let storageType = "FloatAttr";
+ let returnType = "FloatAttr";
+ let convertFromStorage = "$_self";
+ let constBuilderCall = "$0";
+}
+
//===----------------------------------------------------------------------===//
// VM structs
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
index d2d796d..a5b6709 100644
--- a/iree/compiler/Dialect/VM/IR/VMDialect.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
@@ -61,11 +61,18 @@
}
if (auto globalLoadOp = dyn_cast<GlobalLoadI32Op>(op)) {
os << globalLoadOp.global();
+ } else if (auto globalLoadOp = dyn_cast<GlobalLoadI64Op>(op)) {
+ os << globalLoadOp.global();
+ } else if (auto globalLoadOp = dyn_cast<GlobalLoadF32Op>(op)) {
+ os << globalLoadOp.global();
+ } else if (auto globalLoadOp = dyn_cast<GlobalLoadF64Op>(op)) {
+ os << globalLoadOp.global();
} else if (auto globalLoadOp = dyn_cast<GlobalLoadRefOp>(op)) {
os << globalLoadOp.global();
} else if (isa<ConstRefZeroOp>(op)) {
os << "null";
- } else if (isa<ConstI32ZeroOp>(op) || isa<ConstI64ZeroOp>(op)) {
+ } else if (isa<ConstI32ZeroOp>(op) || isa<ConstI64ZeroOp>(op) ||
+ isa<ConstF32ZeroOp>(op) || isa<ConstF64ZeroOp>(op)) {
os << "zero";
} else if (auto constOp = dyn_cast<ConstI32Op>(op)) {
getIntegerName(constOp.value().dyn_cast<IntegerAttr>(), os);
@@ -319,6 +326,18 @@
return builder.create<VM::ConstI64ZeroOp>(loc);
}
return builder.create<VM::ConstI64Op>(loc, convertedValue);
+ } else if (ConstF32Op::isBuildableWith(value, type)) {
+ auto convertedValue = ConstF32Op::convertConstValue(value);
+    if (convertedValue.cast<FloatAttr>().getValue().isPosZero()) {
+ return builder.create<VM::ConstF32ZeroOp>(loc);
+ }
+ return builder.create<VM::ConstF32Op>(loc, convertedValue);
+ } else if (ConstF64Op::isBuildableWith(value, type)) {
+ auto convertedValue = ConstF64Op::convertConstValue(value);
+    if (convertedValue.cast<FloatAttr>().getValue().isPosZero()) {
+ return builder.create<VM::ConstF64ZeroOp>(loc);
+ }
+ return builder.create<VM::ConstF64Op>(loc, convertedValue);
} else if (type.isa<IREE::VM::RefType>()) {
// The only constant type we support for refs is null so we can just
// emit that here.
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 246467b..47040d4 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -46,10 +46,11 @@
/// Creates a constant one attribute matching the given type.
Attribute oneOfType(Type type) {
Builder builder(type.getContext());
- if (type.isa<FloatType>()) return builder.getFloatAttr(type, 1.0);
- if (auto integerTy = type.dyn_cast<IntegerType>())
+ if (type.isa<FloatType>()) {
+ return builder.getFloatAttr(type, 1.0);
+ } else if (auto integerTy = type.dyn_cast<IntegerType>()) {
return builder.getIntegerAttr(integerTy, APInt(integerTy.getWidth(), 1));
- if (type.isa<RankedTensorType, VectorType>()) {
+ } else if (type.isa<RankedTensorType, VectorType>()) {
auto vtType = type.cast<ShapedType>();
auto element = oneOfType(vtType.getElementType());
if (!element) return {};
@@ -105,8 +106,12 @@
LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter) const override {
if (!op.initial_value().hasValue()) return failure();
- auto value = op.initial_valueAttr().template cast<IntegerAttr>();
- if (value.getValue() != 0) return failure();
+ if (auto value = op.initial_valueAttr().template dyn_cast<IntegerAttr>()) {
+ if (value.getValue() != 0) return failure();
+ } else if (auto value =
+ op.initial_valueAttr().template dyn_cast<FloatAttr>()) {
+ if (value.getValue().isNonZero()) return failure();
+ }
rewriter.replaceOpWithNewOp<T>(op, op.sym_name(), op.is_mutable(),
op.type(),
llvm::to_vector<4>(op->getDialectAttrs()));
@@ -128,6 +133,18 @@
DropDefaultConstGlobalOpInitializer<GlobalI64Op>>(context);
}
+void GlobalF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<InlineConstGlobalOpInitializer<GlobalF32Op>,
+ DropDefaultConstGlobalOpInitializer<GlobalF32Op>>(context);
+}
+
+void GlobalF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<InlineConstGlobalOpInitializer<GlobalF64Op>,
+ DropDefaultConstGlobalOpInitializer<GlobalF64Op>>(context);
+}
+
void GlobalRefOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<InlineConstGlobalOpInitializer<GlobalRefOp>>(context);
@@ -174,6 +191,20 @@
context);
}
+void GlobalLoadF32Op::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<InlineConstGlobalLoadIntegerOp<GlobalLoadF32Op, GlobalF32Op,
+ ConstF32Op, ConstF32ZeroOp>>(
+ context);
+}
+
+void GlobalLoadF64Op::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<InlineConstGlobalLoadIntegerOp<GlobalLoadF64Op, GlobalF64Op,
+ ConstF64Op, ConstF64ZeroOp>>(
+ context);
+}
+
namespace {
/// Inlines immutable global constants into their loads.
@@ -233,6 +264,20 @@
context);
}
+void GlobalLoadIndirectF32Op::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<
+ PropagateGlobalLoadAddress<GlobalLoadIndirectF32Op, GlobalLoadF32Op>>(
+ context);
+}
+
+void GlobalLoadIndirectF64Op::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<
+ PropagateGlobalLoadAddress<GlobalLoadIndirectF64Op, GlobalLoadF64Op>>(
+ context);
+}
+
void GlobalLoadIndirectRefOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<
@@ -273,6 +318,20 @@
context);
}
+void GlobalStoreIndirectF32Op::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<
+ PropagateGlobalStoreAddress<GlobalStoreIndirectF32Op, GlobalStoreF32Op>>(
+ context);
+}
+
+void GlobalStoreIndirectF64Op::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<
+ PropagateGlobalStoreAddress<GlobalStoreIndirectF64Op, GlobalStoreF64Op>>(
+ context);
+}
+
void GlobalStoreIndirectRefOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<
@@ -287,7 +346,7 @@
namespace {
template <typename GeneralOp, typename ZeroOp>
-struct FoldZeroConstInteger final : public OpRewritePattern<GeneralOp> {
+struct FoldZeroConstPrimitive final : public OpRewritePattern<GeneralOp> {
using OpRewritePattern<GeneralOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GeneralOp constOp,
@@ -306,14 +365,28 @@
void ConstI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<FoldZeroConstInteger<ConstI32Op, ConstI32ZeroOp>>(context);
+ results.insert<FoldZeroConstPrimitive<ConstI32Op, ConstI32ZeroOp>>(context);
}
OpFoldResult ConstI64Op::fold(ArrayRef<Attribute> operands) { return value(); }
void ConstI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<FoldZeroConstInteger<ConstI64Op, ConstI64ZeroOp>>(context);
+ results.insert<FoldZeroConstPrimitive<ConstI64Op, ConstI64ZeroOp>>(context);
+}
+
+OpFoldResult ConstF32Op::fold(ArrayRef<Attribute> operands) { return value(); }
+
+void ConstF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldZeroConstPrimitive<ConstF32Op, ConstF32ZeroOp>>(context);
+}
+
+OpFoldResult ConstF64Op::fold(ArrayRef<Attribute> operands) { return value(); }
+
+void ConstF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldZeroConstPrimitive<ConstF64Op, ConstF64ZeroOp>>(context);
}
OpFoldResult ConstI32ZeroOp::fold(ArrayRef<Attribute> operands) {
@@ -324,6 +397,14 @@
return IntegerAttr::get(getResult().getType(), 0);
}
+OpFoldResult ConstF32ZeroOp::fold(ArrayRef<Attribute> operands) {
+ return FloatAttr::get(getResult().getType(), 0.0f);
+}
+
+OpFoldResult ConstF64ZeroOp::fold(ArrayRef<Attribute> operands) {
+ return FloatAttr::get(getResult().getType(), 0.0);
+}
+
OpFoldResult ConstRefZeroOp::fold(ArrayRef<Attribute> operands) {
// TODO(b/144027097): relace unit attr with a proper null ref attr.
return UnitAttr::get(getContext());
@@ -360,6 +441,14 @@
return foldSelectOp(*this);
}
+OpFoldResult SelectF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldSelectOp(*this);
+}
+
+OpFoldResult SelectF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldSelectOp(*this);
+}
+
OpFoldResult SelectRefOp::fold(ArrayRef<Attribute> operands) {
return foldSelectOp(*this);
}
@@ -400,19 +489,27 @@
return foldSwitchOp(*this);
}
+OpFoldResult SwitchF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldSwitchOp(*this);
+}
+
+OpFoldResult SwitchF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldSwitchOp(*this);
+}
+
OpFoldResult SwitchRefOp::fold(ArrayRef<Attribute> operands) {
return foldSwitchOp(*this);
}
//===----------------------------------------------------------------------===//
-// Native integer arithmetic
+// Integer arithmetic
//===----------------------------------------------------------------------===//
/// Performs const folding `calculate` with element-wise behavior on the given
/// attribute in `operands` and returns the result if possible.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
- class CalculationT = std::function<ElementValueT(ElementValueT)>>
+ class CalculationT = std::function<APInt(ElementValueT)>>
static Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
const CalculationT &calculate) {
assert(operands.size() == 1 && "unary op takes one operand");
@@ -432,6 +529,29 @@
return {};
}
+/// Performs const folding `calculate` with element-wise behavior on the given
+/// attribute in `operands` and returns the result if possible.
+static Attribute constFoldFloatUnaryOp(
+ ArrayRef<Attribute> operands,
+ const std::function<APFloat(APFloat)> &calculate) {
+ assert(operands.size() == 1 && "unary op takes one operand");
+ if (auto operand = operands[0].dyn_cast_or_null<FloatAttr>()) {
+ return FloatAttr::get(operand.getType(), calculate(operand.getValue()));
+ } else if (auto operand = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+ auto elementResult =
+ constFoldFloatUnaryOp({operand.getSplatValue()}, calculate);
+ if (!elementResult) return {};
+ return DenseElementsAttr::get(operand.getType(), elementResult);
+ } else if (auto operand = operands[0].dyn_cast_or_null<ElementsAttr>()) {
+ return operand.mapValues(
+ operand.getType().getElementType(),
+ llvm::function_ref<APInt(const APFloat &)>([&](const APFloat &value) {
+ return calculate(value).bitcastToAPInt();
+ }));
+ }
+ return {};
+}
+
/// Performs const folding `calculate` with element-wise behavior on the two
/// attributes in `operands` and returns the result if possible.
template <class AttrElementT,
@@ -472,7 +592,8 @@
return {};
}
-template <typename ADD, typename SUB>
+template <class AttrElementT, typename ADD, typename SUB,
+ class ElementValueT = typename AttrElementT::ValueType>
static OpFoldResult foldAddOp(ADD op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
// x + 0 = x or 0 + y = y (commutative)
@@ -485,19 +606,21 @@
if (subOp.lhs() == op.lhs()) return subOp.rhs();
if (subOp.rhs() == op.lhs()) return subOp.lhs();
}
- return constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) { return a + b; });
+ return constFoldBinaryOp<AttrElementT>(
+ operands,
+ [](const ElementValueT &a, const ElementValueT &b) { return a + b; });
}
OpFoldResult AddI32Op::fold(ArrayRef<Attribute> operands) {
- return foldAddOp<AddI32Op, SubI32Op>(*this, operands);
+ return foldAddOp<IntegerAttr, AddI32Op, SubI32Op>(*this, operands);
}
OpFoldResult AddI64Op::fold(ArrayRef<Attribute> operands) {
- return foldAddOp<AddI64Op, SubI64Op>(*this, operands);
+ return foldAddOp<IntegerAttr, AddI64Op, SubI64Op>(*this, operands);
}
-template <typename SUB, typename ADD>
+template <class AttrElementT, typename SUB, typename ADD,
+ class ElementValueT = typename AttrElementT::ValueType>
static OpFoldResult foldSubOp(SUB op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
// x - 0 = x
@@ -510,19 +633,21 @@
if (addOp.lhs() == op.lhs()) return addOp.rhs();
if (addOp.rhs() == op.lhs()) return addOp.lhs();
}
- return constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) { return a - b; });
+ return constFoldBinaryOp<AttrElementT>(
+ operands,
+ [](const ElementValueT &a, const ElementValueT &b) { return a - b; });
}
OpFoldResult SubI32Op::fold(ArrayRef<Attribute> operands) {
- return foldSubOp<SubI32Op, AddI32Op>(*this, operands);
+ return foldSubOp<IntegerAttr, SubI32Op, AddI32Op>(*this, operands);
}
OpFoldResult SubI64Op::fold(ArrayRef<Attribute> operands) {
- return foldSubOp<SubI64Op, AddI64Op>(*this, operands);
+ return foldSubOp<IntegerAttr, SubI64Op, AddI64Op>(*this, operands);
}
-template <typename T>
+template <class AttrElementT, typename T,
+ class ElementValueT = typename AttrElementT::ValueType>
static OpFoldResult foldMulOp(T op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
// x * 0 = 0 or 0 * y = 0 (commutative)
@@ -531,25 +656,28 @@
// x * 1 = x or 1 * y = y (commutative)
return op.lhs();
}
- return constFoldBinaryOp<IntegerAttr>(
- operands, [](const APInt &a, const APInt &b) { return a * b; });
+ return constFoldBinaryOp<AttrElementT>(
+ operands,
+ [](const ElementValueT &a, const ElementValueT &b) { return a * b; });
}
-template <typename T, typename CONST_OP>
+template <class AttrElementT, typename T, typename CONST_OP,
+ class ElementValueT = typename AttrElementT::ValueType>
struct FoldConstantMulOperand : public OpRewritePattern<T> {
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter) const override {
- IntegerAttr c1, c2;
+ AttrElementT c1, c2;
if (!matchPattern(op.rhs(), m_Constant(&c1))) return failure();
if (auto mulOp = dyn_cast_or_null<T>(op.lhs().getDefiningOp())) {
if (matchPattern(mulOp.rhs(), m_Constant(&c2))) {
auto c = rewriter.createOrFold<CONST_OP>(
rewriter.getFusedLoc({mulOp.getLoc(), op.getLoc()}),
- constFoldBinaryOp<IntegerAttr>(
- {c1, c2},
- [](const APInt &a, const APInt &b) { return a * b; }));
+ constFoldBinaryOp<AttrElementT>(
+ {c1, c2}, [](const ElementValueT &a, const ElementValueT &b) {
+ return a * b;
+ }));
rewriter.replaceOpWithNewOp<T>(op, op.getType(), mulOp.lhs(), c);
return success();
}
@@ -559,21 +687,23 @@
};
OpFoldResult MulI32Op::fold(ArrayRef<Attribute> operands) {
- return foldMulOp(*this, operands);
+ return foldMulOp<IntegerAttr>(*this, operands);
}
void MulI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<FoldConstantMulOperand<MulI32Op, ConstI32Op>>(context);
+ results.insert<FoldConstantMulOperand<IntegerAttr, MulI32Op, ConstI32Op>>(
+ context);
}
OpFoldResult MulI64Op::fold(ArrayRef<Attribute> operands) {
- return foldMulOp(*this, operands);
+ return foldMulOp<IntegerAttr>(*this, operands);
}
void MulI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<FoldConstantMulOperand<MulI64Op, ConstI64Op>>(context);
+ results.insert<FoldConstantMulOperand<IntegerAttr, MulI64Op, ConstI64Op>>(
+ context);
}
template <typename T>
@@ -652,7 +782,12 @@
template <typename T>
static OpFoldResult foldRemUOp(T op, ArrayRef<Attribute> operands) {
- if (matchPattern(op.lhs(), m_Zero()) || matchPattern(op.rhs(), m_One())) {
+ if (matchPattern(op.rhs(), m_Zero())) {
+ // x % 0 = death
+ op.emitOpError() << "is a remainder by constant zero";
+ return {};
+ } else if (matchPattern(op.lhs(), m_Zero()) ||
+ matchPattern(op.rhs(), m_One())) {
// x % 1 = 0
// 0 % y = 0
return zeroOfType(op.getType());
@@ -669,6 +804,171 @@
return foldRemUOp(*this, operands);
}
+//===----------------------------------------------------------------------===//
+// Floating-point arithmetic
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AddF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldAddOp<FloatAttr, AddF32Op, SubF32Op>(*this, operands);
+}
+
+OpFoldResult AddF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldAddOp<FloatAttr, AddF64Op, SubF64Op>(*this, operands);
+}
+
+OpFoldResult SubF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldSubOp<FloatAttr, SubF32Op, AddF32Op>(*this, operands);
+}
+
+OpFoldResult SubF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldSubOp<FloatAttr, SubF64Op, AddF64Op>(*this, operands);
+}
+
+OpFoldResult MulF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldMulOp<FloatAttr>(*this, operands);
+}
+
+void MulF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldConstantMulOperand<FloatAttr, MulF32Op, ConstF32Op>>(
+ context);
+}
+
+OpFoldResult MulF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldMulOp<FloatAttr>(*this, operands);
+}
+
+void MulF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldConstantMulOperand<FloatAttr, MulF64Op, ConstF64Op>>(
+ context);
+}
+
+template <typename T>
+static OpFoldResult foldDivFOp(T op, ArrayRef<Attribute> operands) {
+  if (matchPattern(op.rhs(), m_Zero())) {
+    // x / 0 = death
+    op.emitOpError() << "is a divide by constant zero";
+    return {};
+  } else if (matchPattern(op.rhs(), m_One())) {
+    // x / 1 = x (exact in IEEE 754 for any x).
+    return op.lhs();
+  }
+  // NOTE: `0 / y` is deliberately not folded to 0: the result is NaN when
+  // y is NaN or a (non-constant) zero, so that fold is unsound for floats.
+  // Constant folding uses round-to-nearest-even, the IEEE 754 default.
+  return constFoldBinaryOp<FloatAttr>(operands,
+                                      [](const APFloat &a, const APFloat &b) {
+                                        APFloat c = a;
+                                        c.divide(b, APFloat::rmNearestTiesToEven);
+                                        return c;
+                                      });
+}
+
+OpFoldResult DivF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldDivFOp(*this, operands);
+}
+
+OpFoldResult DivF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldDivFOp(*this, operands);
+}
+
+template <typename T>
+static OpFoldResult foldRemFOp(T op, ArrayRef<Attribute> operands) {
+  if (matchPattern(op.rhs(), m_Zero())) {
+    // x % 0 = death
+    op.emitOpError() << "is a remainder by constant zero";
+    return {};
+  }
+  // NOTE: unlike the integer fold we cannot simplify `x % 1` to zero here:
+  // floating-point remainder keeps the fractional part (2.5 % 1.0 == 0.5).
+  // `0 % y` is also left alone since it yields NaN (not 0) when y is NaN.
+  // Only fully-constant operand pairs are folded, using APFloat::remainder
+  // (IEEE 754 remainder semantics).
+  return constFoldBinaryOp<FloatAttr>(operands,
+                                      [](const APFloat &a, const APFloat &b) {
+                                        APFloat c = a;
+                                        c.remainder(b);
+                                        return c;
+                                      });
+}
+
+OpFoldResult RemF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldRemFOp(*this, operands);
+}
+
+OpFoldResult RemF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldRemFOp(*this, operands);
+}
+
+OpFoldResult AbsF32Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldFloatUnaryOp(operands, [](const APFloat &a) {
+ auto b = a;
+ b.clearSign();
+ return b;
+ });
+}
+
+OpFoldResult AbsF64Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldFloatUnaryOp(operands, [](const APFloat &a) {
+ auto b = a;
+ b.clearSign();
+ return b;
+ });
+}
+
+OpFoldResult NegF32Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldFloatUnaryOp(operands, [](const APFloat &a) {
+ auto b = a;
+ b.changeSign();
+ return b;
+ });
+}
+
+OpFoldResult NegF64Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldFloatUnaryOp(operands, [](const APFloat &a) {
+ auto b = a;
+ b.changeSign();
+ return b;
+ });
+}
+
+OpFoldResult CeilF32Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldFloatUnaryOp(operands, [](const APFloat &a) {
+ auto b = a;
+ b.roundToIntegral(APFloat::rmTowardPositive);
+ return b;
+ });
+}
+
+OpFoldResult CeilF64Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldFloatUnaryOp(operands, [](const APFloat &a) {
+ auto b = a;
+ b.roundToIntegral(APFloat::rmTowardPositive);
+ return b;
+ });
+}
+
+OpFoldResult FloorF32Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldFloatUnaryOp(operands, [](const APFloat &a) {
+ auto b = a;
+ b.roundToIntegral(APFloat::rmTowardNegative);
+ return b;
+ });
+}
+
+OpFoldResult FloorF64Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldFloatUnaryOp(operands, [](const APFloat &a) {
+ auto b = a;
+ b.roundToIntegral(APFloat::rmTowardNegative);
+ return b;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// Integer bit manipulation
+//===----------------------------------------------------------------------===//
+
template <typename T>
static OpFoldResult foldNotOp(T op, ArrayRef<Attribute> operands) {
return constFoldUnaryOp<IntegerAttr>(operands, [](APInt a) {
@@ -864,6 +1164,12 @@
[&](const APInt &a) { return a.trunc(32); });
}
+OpFoldResult TruncF64F32Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldConversionOp<FloatAttr>(
+ FloatType::getF32(getContext()), operands,
+ [&](const APFloat &a) { return APFloat(a.convertToFloat()); });
+}
+
OpFoldResult ExtI8I32SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(getContext(), 32), operands,
@@ -924,6 +1230,12 @@
[&](const APInt &a) { return a.zext(64); });
}
+OpFoldResult ExtF32F64Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldConversionOp<FloatAttr>(
+ FloatType::getF64(getContext()), operands,
+ [&](const APFloat &a) { return APFloat(a.convertToDouble()); });
+}
+
namespace {
template <typename SRC_OP, typename OP_A, int SZ_T, typename OP_B>
@@ -988,6 +1300,27 @@
namespace {
+/// Performs const folding `calculate` with element-wise behavior on the given
+/// attribute in `operands` and returns the result if possible.
+template <class AttrElementT,
+ class ElementValueT = typename AttrElementT::ValueType,
+ class CalculationT = std::function<APInt(ElementValueT)>>
+static Attribute constFoldCmpOp(ArrayRef<Attribute> operands,
+ const CalculationT &calculate) {
+ assert(operands.size() == 1 && "unary op takes one operand");
+ if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
+ auto boolType = IntegerType::get(operand.getContext(), 32);
+ return IntegerAttr::get(boolType, calculate(operand.getValue()));
+ } else if (auto operand = operands[0].dyn_cast_or_null<ElementsAttr>()) {
+ auto boolType = IntegerType::get(operand.getContext(), 32);
+ return operand.mapValues(
+ boolType,
+ llvm::function_ref<APInt(const ElementValueT &)>(
+ [&](const ElementValueT &value) { return calculate(value); }));
+ }
+ return {};
+}
+
/// Swaps the cmp op with its inverse if the result is inverted.
template <typename OP, typename INV>
struct SwapInvertedCmpOps : public OpRewritePattern<OP> {
@@ -1047,6 +1380,32 @@
}
template <typename T>
+static OpFoldResult foldCmpEQFOp(T op, ArrayRef<Attribute> operands) {
+  // NOTE: unlike the integer fold, `x == x` cannot be folded to true for
+  // floats: IEEE 754 defines NaN == NaN as false, so the identity-operand
+  // simplification is unsound unless the value is known not to be NaN.
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [&](const APFloat &a, const APFloat &b) {
+        // cmpUnordered (either operand NaN) correctly compares != cmpEqual.
+        return a.compare(b) == APFloat::cmpEqual;
+      });
+}
+
+OpFoldResult CmpEQF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpEQFOp(*this, operands);
+}
+
+void CmpEQF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {}
+
+OpFoldResult CmpEQF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpEQFOp(*this, operands);
+}
+
+void CmpEQF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {}
+
+template <typename T>
static OpFoldResult foldCmpNEOp(T op, ArrayRef<Attribute> operands) {
if (op.lhs() == op.rhs()) {
// x != x = false
@@ -1064,6 +1423,26 @@
return foldCmpNEOp(*this, operands);
}
+template <typename T>
+static OpFoldResult foldCmpNEFOp(T op, ArrayRef<Attribute> operands) {
+  // NOTE: unlike the integer fold, `x != x` cannot be folded to false for
+  // floats: IEEE 754 defines NaN != NaN as true, so the identity-operand
+  // simplification is unsound unless the value is known not to be NaN.
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [&](const APFloat &a, const APFloat &b) {
+        // cmpUnordered (either operand NaN) correctly compares != cmpEqual.
+        return a.compare(b) != APFloat::cmpEqual;
+      });
+}
+
+OpFoldResult CmpNEF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpNEFOp(*this, operands);
+}
+
+OpFoldResult CmpNEF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpNEFOp(*this, operands);
+}
+
namespace {
/// Changes a cmp.ne.i32 check against 0 to a cmp.nz.i32.
@@ -1094,6 +1473,18 @@
CmpNEZeroToCmpNZ<CmpNEI64Op, CmpNZI64Op>>(context);
}
+void CmpNEF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SwapInvertedCmpOps<CmpNEF32Op, CmpEQF32Op>,
+ CmpNEZeroToCmpNZ<CmpNEF32Op, CmpNZF32Op>>(context);
+}
+
+void CmpNEF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SwapInvertedCmpOps<CmpNEF64Op, CmpEQF64Op>,
+ CmpNEZeroToCmpNZ<CmpNEF64Op, CmpNZF64Op>>(context);
+}
+
template <typename T>
static OpFoldResult foldCmpLTSOp(T op, ArrayRef<Attribute> operands) {
if (op.lhs() == op.rhs()) {
@@ -1142,6 +1533,32 @@
void CmpLTI64UOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {}
+template <typename T>
+static OpFoldResult foldCmpLTFOp(T op, ArrayRef<Attribute> operands) {
+ if (op.lhs() == op.rhs()) {
+ // x < x = false
+ return zeroOfType(op.getType());
+ }
+ return constFoldBinaryOp<FloatAttr>(
+ operands, [&](const APFloat &a, const APFloat &b) {
+ return a.compare(b) == APFloat::cmpLessThan;
+ });
+}
+
+OpFoldResult CmpLTF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpLTFOp(*this, operands);
+}
+
+OpFoldResult CmpLTF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpLTFOp(*this, operands);
+}
+
+void CmpLTF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {}
+
+void CmpLTF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {}
+
namespace {
/// Rewrites a vm.cmp.lte.* pseudo op to a vm.cmp.lt.* op.
@@ -1222,6 +1639,39 @@
results.insert<RewritePseudoCmpLTEToLT<CmpLTEI64UOp, CmpLTI64UOp>>(context);
}
+template <typename T>
+static OpFoldResult foldCmpLTEFOp(T op, ArrayRef<Attribute> operands) {
+  // NOTE: unlike the integer fold, `x <= x` cannot be folded to true for
+  // floats: it is false when x is NaN (unordered), so the identity-operand
+  // simplification is unsound unless the value is known not to be NaN.
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [&](const APFloat &a, const APFloat &b) {
+        // Compare once; unordered (NaN) results fold to false as required.
+        auto result = a.compare(b);
+        return result == APFloat::cmpLessThan || result == APFloat::cmpEqual;
+      });
+}
+
+OpFoldResult CmpLTEF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpLTEFOp(*this, operands);
+}
+
+OpFoldResult CmpLTEF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpLTEFOp(*this, operands);
+}
+
+void CmpLTEF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SwapInvertedCmpOps<CmpLTEF32Op, CmpGTF32Op>>(context);
+ results.insert<RewritePseudoCmpLTEToLT<CmpLTEF32Op, CmpLTF32Op>>(context);
+}
+
+void CmpLTEF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SwapInvertedCmpOps<CmpLTEF64Op, CmpGTF64Op>>(context);
+ results.insert<RewritePseudoCmpLTEToLT<CmpLTEF64Op, CmpLTF64Op>>(context);
+}
+
namespace {
/// Rewrites a vm.cmp.gt.* pseudo op to a vm.cmp.lt.* op.
@@ -1298,6 +1748,38 @@
results.insert<RewritePseudoCmpGTToLT<CmpGTI64UOp, CmpLTI64UOp>>(context);
}
+template <typename T>
+static OpFoldResult foldCmpGTFOp(T op, ArrayRef<Attribute> operands) {
+ if (op.lhs() == op.rhs()) {
+ // x > x = false
+ return zeroOfType(op.getType());
+ }
+ return constFoldBinaryOp<FloatAttr>(
+ operands, [&](const APFloat &a, const APFloat &b) {
+ return a.compare(b) == APFloat::cmpGreaterThan;
+ });
+}
+
+OpFoldResult CmpGTF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpGTFOp(*this, operands);
+}
+
+OpFoldResult CmpGTF64Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpGTFOp(*this, operands);
+}
+
+void CmpGTF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SwapInvertedCmpOps<CmpGTF32Op, CmpLTEF32Op>>(context);
+ results.insert<RewritePseudoCmpGTToLT<CmpGTF32Op, CmpLTF32Op>>(context);
+}
+
+void CmpGTF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SwapInvertedCmpOps<CmpGTF64Op, CmpLTEF64Op>>(context);
+ results.insert<RewritePseudoCmpGTToLT<CmpGTF64Op, CmpLTF64Op>>(context);
+}
+
namespace {
/// Rewrites a vm.cmp.gte.* pseudo op to a vm.cmp.lt.* op.
@@ -1349,6 +1831,39 @@
}
template <typename T>
+static OpFoldResult foldCmpGTEFOp(T op, ArrayRef<Attribute> operands) {
+  // NOTE: unlike the integer fold, `x >= x` cannot be folded to true for
+  // floats: it is false when x is NaN (unordered), so the identity-operand
+  // simplification is unsound unless the value is known not to be NaN.
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [&](const APFloat &a, const APFloat &b) {
+        // Compare once; unordered (NaN) results fold to false as required.
+        auto result = a.compare(b);
+        return result == APFloat::cmpGreaterThan || result == APFloat::cmpEqual;
+      });
+}
+
+OpFoldResult CmpGTEF32Op::fold(ArrayRef<Attribute> operands) {
+ return foldCmpGTEFOp(*this, operands);
+}
+
+OpFoldResult CmpGTEF64Op::fold(ArrayRef<Attribute> operands) {
+  return foldCmpGTEFOp(*this, operands);
+}
+
+void CmpGTEF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SwapInvertedCmpOps<CmpGTEF32Op, CmpLTF32Op>>(context);
+ results.insert<RewritePseudoCmpGTEToLT<CmpGTEF32Op, CmpLTF32Op>>(context);
+}
+
+void CmpGTEF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<SwapInvertedCmpOps<CmpGTEF64Op, CmpLTF64Op>>(context);
+ results.insert<RewritePseudoCmpGTEToLT<CmpGTEF64Op, CmpLTF64Op>>(context);
+}
+
+template <typename T>
static OpFoldResult foldCmpGTEUOp(T op, ArrayRef<Attribute> operands) {
if (op.lhs() == op.rhs()) {
// x >= x = true
@@ -1388,6 +1903,16 @@
operands, [&](const APInt &a) { return APInt(64, a.getBoolValue()); });
}
+OpFoldResult CmpNZF32Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldCmpOp<FloatAttr>(
+ operands, [&](const APFloat &a) { return APInt(32, a.isNonZero()); });
+}
+
+OpFoldResult CmpNZF64Op::fold(ArrayRef<Attribute> operands) {
+ return constFoldCmpOp<FloatAttr>(
+ operands, [&](const APFloat &a) { return APInt(32, a.isNonZero()); });
+}
+
OpFoldResult CmpEQRefOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs()) {
// x == x = true
@@ -1768,7 +2293,7 @@
/// Rewrites a check op to a cmp and a cond_fail.
template <typename CheckOp, typename CmpI32Op, typename CmpI64Op,
- typename CmpRefOp>
+ typename CmpF32Op, typename CmpF64Op, typename CmpRefOp>
struct RewriteCheckToCondFail : public OpRewritePattern<CheckOp> {
using OpRewritePattern<CheckOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CheckOp op,
@@ -1780,12 +2305,20 @@
condValue = rewriter.template createOrFold<CmpRefOp>(
op.getLoc(), ArrayRef<Type>{condType},
op.getOperation()->getOperands());
+ } else if (operandType.isInteger(32)) {
+ condValue = rewriter.template createOrFold<CmpI32Op>(
+ op.getLoc(), ArrayRef<Type>{condType},
+ op.getOperation()->getOperands());
} else if (operandType.isInteger(64)) {
condValue = rewriter.template createOrFold<CmpI64Op>(
op.getLoc(), ArrayRef<Type>{condType},
op.getOperation()->getOperands());
- } else if (operandType.isInteger(32)) {
- condValue = rewriter.template createOrFold<CmpI32Op>(
+ } else if (operandType.isF32()) {
+ condValue = rewriter.template createOrFold<CmpF32Op>(
+ op.getLoc(), ArrayRef<Type>{condType},
+ op.getOperation()->getOperands());
+ } else if (operandType.isF64()) {
+ condValue = rewriter.template createOrFold<CmpF64Op>(
op.getLoc(), ArrayRef<Type>{condType},
op.getOperation()->getOperands());
} else {
@@ -1806,22 +2339,22 @@
void CheckEQOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<
- RewriteCheckToCondFail<CheckEQOp, CmpEQI32Op, CmpEQI64Op, CmpEQRefOp>>(
+ results.insert<RewriteCheckToCondFail<CheckEQOp, CmpEQI32Op, CmpEQI64Op,
+ CmpEQF32Op, CmpEQF64Op, CmpEQRefOp>>(
context);
}
void CheckNEOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<
- RewriteCheckToCondFail<CheckNEOp, CmpNEI32Op, CmpNEI64Op, CmpNERefOp>>(
+ results.insert<RewriteCheckToCondFail<CheckNEOp, CmpNEI32Op, CmpNEI64Op,
+ CmpNEF32Op, CmpNEF64Op, CmpNERefOp>>(
context);
}
void CheckNZOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<
- RewriteCheckToCondFail<CheckNZOp, CmpNZI32Op, CmpNZI64Op, CmpNZRefOp>>(
+ results.insert<RewriteCheckToCondFail<CheckNZOp, CmpNZI32Op, CmpNZI64Op,
+ CmpNZF32Op, CmpNZF64Op, CmpNZRefOp>>(
context);
}
diff --git a/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td b/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
new file mode 100644
index 0000000..f58dd2d
--- /dev/null
+++ b/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
@@ -0,0 +1,82 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_DIALECT_VM_OPCODES_F32
+#define IREE_DIALECT_VM_OPCODES_F32
+
+include "iree/compiler/Dialect/VM/IR/VMBase.td"
+include "iree/compiler/Dialect/VM/IR/VMOpcodesCore.td"
+
+//===----------------------------------------------------------------------===//
+// F32 VM Opcode Extension
+//===----------------------------------------------------------------------===//
+// Ops are encoded as a VM_OPC_ExtF32 + the opcode below.
+
+def VM_OPC_GlobalLoadF32 : VM_OPC<0x00, "GlobalLoadF32">;
+def VM_OPC_GlobalStoreF32 : VM_OPC<0x01, "GlobalStoreF32">;
+def VM_OPC_GlobalLoadIndirectF32 : VM_OPC<0x02, "GlobalLoadIndirectF32">;
+def VM_OPC_GlobalStoreIndirectF32: VM_OPC<0x03, "GlobalStoreIndirectF32">;
+def VM_OPC_ConstF32Zero : VM_OPC<0x08, "ConstF32Zero">;
+def VM_OPC_ConstF32 : VM_OPC<0x09, "ConstF32">;
+def VM_OPC_ListGetF32 : VM_OPC<0x14, "ListGetF32">;
+def VM_OPC_ListSetF32 : VM_OPC<0x15, "ListSetF32">;
+def VM_OPC_SelectF32 : VM_OPC<0x1E, "SelectF32">;
+def VM_OPC_SwitchF32 : VM_OPC<0x20, "SwitchF32">;
+def VM_OPC_AddF32 : VM_OPC<0x22, "AddF32">;
+def VM_OPC_SubF32 : VM_OPC<0x23, "SubF32">;
+def VM_OPC_MulF32 : VM_OPC<0x24, "MulF32">;
+def VM_OPC_DivF32 : VM_OPC<0x25, "DivF32">;
+def VM_OPC_RemF32 : VM_OPC<0x27, "RemF32">;
+def VM_OPC_AbsF32 : VM_OPC<0x2D, "AbsF32">;
+def VM_OPC_NegF32 : VM_OPC<0x2E, "NegF32">;
+def VM_OPC_CeilF32 : VM_OPC<0x2F, "CeilF32">;
+def VM_OPC_FloorF32 : VM_OPC<0x32, "FloorF32">;
+def VM_OPC_CmpEQF32 : VM_OPC<0x40, "CmpEQF32">;
+def VM_OPC_CmpNEF32 : VM_OPC<0x41, "CmpNEF32">;
+def VM_OPC_CmpNZF32 : VM_OPC<0x42, "CmpNZF32">;
+def VM_OPC_CmpLTF32 : VM_OPC<0x43, "CmpLTF32">;
+
+// Runtime enum iree_vm_ext_f32_op_t:
+def VM_ExtF32OpcodeAttr :
+ VM_OPC_EnumAttr<"ExtF32Opcode",
+ "iree_vm_ext_f32_op_t",
+ "EXT_F32", // IREE_VM_OP_EXT_F32_*
+ "valid VM operation encodings in the f32 extension",
+ VM_OPC_PrefixExtF32, [
+ VM_OPC_GlobalLoadF32,
+ VM_OPC_GlobalStoreF32,
+ VM_OPC_GlobalLoadIndirectF32,
+ VM_OPC_GlobalStoreIndirectF32,
+ VM_OPC_ConstF32Zero,
+ VM_OPC_ConstF32,
+ VM_OPC_ListGetF32,
+ VM_OPC_ListSetF32,
+ VM_OPC_SelectF32,
+ VM_OPC_SwitchF32,
+ VM_OPC_AddF32,
+ VM_OPC_SubF32,
+ VM_OPC_MulF32,
+ VM_OPC_DivF32,
+ VM_OPC_RemF32,
+ VM_OPC_AbsF32,
+ VM_OPC_NegF32,
+ VM_OPC_CeilF32,
+ VM_OPC_FloorF32,
+ VM_OPC_CmpEQF32,
+ VM_OPC_CmpNEF32,
+ VM_OPC_CmpNZF32,
+ VM_OPC_CmpLTF32,
+ ]>;
+
+#endif // IREE_DIALECT_VM_OPCODES_F32
diff --git a/iree/compiler/Dialect/VM/IR/VMOpcodesF64.td b/iree/compiler/Dialect/VM/IR/VMOpcodesF64.td
new file mode 100644
index 0000000..7c4501e
--- /dev/null
+++ b/iree/compiler/Dialect/VM/IR/VMOpcodesF64.td
@@ -0,0 +1,86 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_DIALECT_VM_OPCODES_F64
+#define IREE_DIALECT_VM_OPCODES_F64
+
+include "iree/compiler/Dialect/VM/IR/VMBase.td"
+include "iree/compiler/Dialect/VM/IR/VMOpcodesCore.td"
+
+//===----------------------------------------------------------------------===//
+// F64 VM Opcode Extension
+//===----------------------------------------------------------------------===//
+// Ops are encoded as a VM_OPC_ExtF64 + the opcode below.
+
+def VM_OPC_GlobalLoadF64 : VM_OPC<0x00, "GlobalLoadF64">;
+def VM_OPC_GlobalStoreF64 : VM_OPC<0x01, "GlobalStoreF64">;
+def VM_OPC_GlobalLoadIndirectF64 : VM_OPC<0x02, "GlobalLoadIndirectF64">;
+def VM_OPC_GlobalStoreIndirectF64: VM_OPC<0x03, "GlobalStoreIndirectF64">;
+def VM_OPC_ConstF64Zero : VM_OPC<0x08, "ConstF64Zero">;
+def VM_OPC_ConstF64 : VM_OPC<0x09, "ConstF64">;
+def VM_OPC_ListGetF64 : VM_OPC<0x14, "ListGetF64">;
+def VM_OPC_ListSetF64 : VM_OPC<0x15, "ListSetF64">;
+def VM_OPC_SelectF64 : VM_OPC<0x1E, "SelectF64">;
+def VM_OPC_SwitchF64 : VM_OPC<0x20, "SwitchF64">;
+def VM_OPC_AddF64 : VM_OPC<0x22, "AddF64">;
+def VM_OPC_SubF64 : VM_OPC<0x23, "SubF64">;
+def VM_OPC_MulF64 : VM_OPC<0x24, "MulF64">;
+def VM_OPC_DivF64 : VM_OPC<0x25, "DivF64">;
+def VM_OPC_RemF64 : VM_OPC<0x27, "RemF64">;
+def VM_OPC_AbsF64 : VM_OPC<0x2D, "AbsF64">;
+def VM_OPC_NegF64 : VM_OPC<0x2E, "NegF64">;
+def VM_OPC_CeilF64 : VM_OPC<0x2F, "CeilF64">;
+def VM_OPC_FloorF64 : VM_OPC<0x31, "FloorF64">;
+def VM_OPC_TruncF64F32 : VM_OPC<0x32, "TruncF64F32">;
+def VM_OPC_ExtF32F64 : VM_OPC<0x37, "ExtF32F64">;
+def VM_OPC_CmpEQF64 : VM_OPC<0x40, "CmpEQF64">;
+def VM_OPC_CmpNEF64 : VM_OPC<0x41, "CmpNEF64">;
+def VM_OPC_CmpNZF64 : VM_OPC<0x42, "CmpNZF64">;
+def VM_OPC_CmpLTF64 : VM_OPC<0x43, "CmpLTF64">;
+
+// Runtime enum iree_vm_ext_f64_op_t:
+def VM_ExtF64OpcodeAttr :
+ VM_OPC_EnumAttr<"ExtF64Opcode",
+ "iree_vm_ext_f64_op_t",
+ "EXT_F64", // IREE_VM_OP_EXT_F64_*
+ "valid VM operation encodings in the f64 extension",
+ VM_OPC_PrefixExtF64, [
+ VM_OPC_GlobalLoadF64,
+ VM_OPC_GlobalStoreF64,
+ VM_OPC_GlobalLoadIndirectF64,
+ VM_OPC_GlobalStoreIndirectF64,
+ VM_OPC_ConstF64Zero,
+ VM_OPC_ConstF64,
+ VM_OPC_ListGetF64,
+ VM_OPC_ListSetF64,
+ VM_OPC_SelectF64,
+ VM_OPC_SwitchF64,
+ VM_OPC_AddF64,
+ VM_OPC_SubF64,
+ VM_OPC_MulF64,
+ VM_OPC_DivF64,
+ VM_OPC_RemF64,
+ VM_OPC_AbsF64,
+ VM_OPC_NegF64,
+ VM_OPC_CeilF64,
+ VM_OPC_FloorF64,
+ VM_OPC_TruncF64F32,
+ VM_OPC_ExtF32F64,
+ VM_OPC_CmpEQF64,
+ VM_OPC_CmpNEF64,
+ VM_OPC_CmpNZF64,
+ VM_OPC_CmpLTF64,
+ ]>;
+
+#endif // IREE_DIALECT_VM_OPCODES_F64
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 765330c..896ba5f 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -451,6 +451,16 @@
addMemoryEffectsForGlobal<GlobalI64Op>(*this, global(), effects);
}
+void GlobalLoadF32Op::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ addMemoryEffectsForGlobal<GlobalF32Op>(*this, global(), effects);
+}
+
+void GlobalLoadF64Op::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ addMemoryEffectsForGlobal<GlobalF64Op>(*this, global(), effects);
+}
+
void GlobalLoadRefOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
addMemoryEffectsForGlobal<GlobalRefOp>(*this, global(), effects);
@@ -499,8 +509,7 @@
//===----------------------------------------------------------------------===//
template <typename T>
-static ParseResult parseConstIntegerOp(OpAsmParser &parser,
- OperationState *result) {
+static ParseResult parseConstOp(OpAsmParser &parser, OperationState *result) {
Attribute valueAttr;
NamedAttrList dummyAttrs;
if (failed(parser.parseAttribute(valueAttr, "value", dummyAttrs))) {
@@ -521,7 +530,7 @@
}
template <typename T>
-static void printConstIntegerOp(OpAsmPrinter &p, T &op) {
+static void printConstOp(OpAsmPrinter &p, T &op) {
p << op.getOperationName() << ' ';
p.printAttribute(op.value());
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
@@ -548,6 +557,26 @@
}
template <int SZ>
+static bool isConstFloatBuildableWith(Attribute value, Type type) {
+ // FlatSymbolRefAttr can only be used with a function type.
+ if (value.isa<FlatSymbolRefAttr>()) {
+ return false;
+ }
+ // Otherwise, the attribute must have the same type as 'type'.
+ if (value.getType() != type) {
+ return false;
+ }
+ Type elementType;
+ if (auto floatAttr = value.dyn_cast<FloatAttr>()) {
+ elementType = floatAttr.getType();
+ } else if (auto elementsAttr = value.dyn_cast<ElementsAttr>()) {
+ elementType = elementsAttr.getType().getElementType();
+ }
+ if (!elementType) return false;
+ return elementType.getIntOrFloatBitWidth() == SZ;
+}
+
+template <int SZ>
static Attribute convertConstIntegerValue(Attribute value) {
assert(isConstIntegerBuildableWith<SZ>(value, value.getType()));
Builder builder(value.getContext());
@@ -574,6 +603,42 @@
return Attribute();
}
+static FloatType getFloatType(int bitwidth, MLIRContext *context) {
+ switch (bitwidth) {
+ case 16:
+ return FloatType::getF16(context);
+ case 32:
+ return FloatType::getF32(context);
+ case 64:
+ return FloatType::getF64(context);
+ default:
+ llvm_unreachable("unhandled floating point type");
+ return {};
+ }
+}
+
+template <int SZ>
+static Attribute convertConstFloatValue(Attribute value) {
+ assert(isConstFloatBuildableWith<SZ>(value, value.getType()));
+ Builder builder(value.getContext());
+ auto floatType = getFloatType(SZ, value.getContext());
+ int32_t dims = 1;
+ if (auto v = value.dyn_cast<FloatAttr>()) {
+ return FloatAttr::get(floatType, v.getValue());
+ } else if (auto v = value.dyn_cast<ElementsAttr>()) {
+ dims = v.getNumElements();
+ ShapedType adjustedType = VectorType::get({dims}, floatType);
+ if (auto elements = v.dyn_cast<SplatElementsAttr>()) {
+ return SplatElementsAttr::get(adjustedType, elements.getSplatValue());
+ } else {
+ return DenseElementsAttr::get(
+ adjustedType, llvm::to_vector<4>(v.getValues<Attribute>()));
+ }
+ }
+ llvm_unreachable("unexpected attribute type");
+ return Attribute();
+}
+
// static
bool ConstI32Op::isBuildableWith(Attribute value, Type type) {
return isConstIntegerBuildableWith<32>(value, type);
@@ -618,6 +683,50 @@
return build(builder, result, builder.getI64IntegerAttr(value));
}
+// static
+bool ConstF32Op::isBuildableWith(Attribute value, Type type) {
+ return isConstFloatBuildableWith<32>(value, type);
+}
+
+// static
+Attribute ConstF32Op::convertConstValue(Attribute value) {
+ return convertConstFloatValue<32>(value);
+}
+
+void ConstF32Op::build(OpBuilder &builder, OperationState &result,
+ Attribute value) {
+ Attribute newValue = convertConstValue(value);
+ result.addAttribute("value", newValue);
+ result.addTypes(newValue.getType());
+}
+
+void ConstF32Op::build(OpBuilder &builder, OperationState &result,
+ float value) {
+ return build(builder, result, builder.getF32FloatAttr(value));
+}
+
+// static
+bool ConstF64Op::isBuildableWith(Attribute value, Type type) {
+ return isConstFloatBuildableWith<64>(value, type);
+}
+
+// static
+Attribute ConstF64Op::convertConstValue(Attribute value) {
+ return convertConstFloatValue<64>(value);
+}
+
+void ConstF64Op::build(OpBuilder &builder, OperationState &result,
+ Attribute value) {
+ Attribute newValue = convertConstValue(value);
+ result.addAttribute("value", newValue);
+ result.addTypes(newValue.getType());
+}
+
+void ConstF64Op::build(OpBuilder &builder, OperationState &result,
+ double value) {
+ return build(builder, result, builder.getF64FloatAttr(value));
+}
+
void ConstI32ZeroOp::build(OpBuilder &builder, OperationState &result) {
result.addTypes(builder.getIntegerType(32));
}
@@ -626,6 +735,14 @@
result.addTypes(builder.getIntegerType(64));
}
+void ConstF32ZeroOp::build(OpBuilder &builder, OperationState &result) {
+ result.addTypes(builder.getF32Type());
+}
+
+void ConstF64ZeroOp::build(OpBuilder &builder, OperationState &result) {
+ result.addTypes(builder.getF64Type());
+}
+
void ConstRefZeroOp::build(OpBuilder &builder, OperationState &result,
Type objectType) {
result.addTypes(objectType);
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index 0107f92..04e93bc 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -275,7 +275,8 @@
$_state.addAttribute("initializer",
$_builder.getSymbolRefAttr(initializer.getValue()));
} else if (initialValue.hasValue() &&
- initialValue.getValue().isa<IntegerAttr>()) {
+ (initialValue.getValue().isa<IntegerAttr>() ||
+ initialValue.getValue().isa<FloatAttr>())) {
$_state.addAttribute("initial_value", initialValue.getValue());
}
$_state.addAttribute("type", TypeAttr::get(type));
@@ -343,6 +344,30 @@
let hasCanonicalizer = 1;
}
+def VM_GlobalF32Op : VM_GlobalOp<"global.f32",
+ VM_ConstantFloatValueAttr<F32>,
+ [VM_ExtF32]> {
+ let summary = [{32-bit floating-point global declaration}];
+ let description = [{
+ Defines a global value that is treated as a scalar literal at runtime.
+ Initialized to zero unless a custom initializer function is specified.
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
+def VM_GlobalF64Op : VM_GlobalOp<"global.f64",
+ VM_ConstantFloatValueAttr<F64>,
+ [VM_ExtF64]> {
+ let summary = [{64-bit floating-point global declaration}];
+ let description = [{
+ Defines a global value that is treated as a scalar literal at runtime.
+ Initialized to zero unless a custom initializer function is specified.
+ }];
+
+ let hasCanonicalizer = 1;
+}
+
def VM_GlobalRefOp : VM_GlobalOp<"global.ref", UnitAttr> {
let summary = [{ref<T> global declaration}];
let description = [{
@@ -505,6 +530,20 @@
let hasCanonicalizer = 1;
}
+def VM_GlobalLoadF32Op :
+ VM_GlobalLoadPrimitiveOp<F32, "global.load.f32", VM_OPC_GlobalLoadF32,
+ [VM_ExtF32]> {
+ let summary = [{global 32-bit floating-point load operation}];
+ let hasCanonicalizer = 1;
+}
+
+def VM_GlobalLoadF64Op :
+ VM_GlobalLoadPrimitiveOp<F64, "global.load.f64", VM_OPC_GlobalLoadF64,
+ [VM_ExtF64]> {
+ let summary = [{global 64-bit floating-point load operation}];
+ let hasCanonicalizer = 1;
+}
+
def VM_GlobalStoreI32Op :
VM_GlobalStorePrimitiveOp<I32, "global.store.i32", VM_OPC_GlobalStoreI32> {
let summary = [{global 32-bit integer store operation}];
@@ -516,6 +555,18 @@
let summary = [{global 64-bit integer store operation}];
}
+def VM_GlobalStoreF32Op :
+ VM_GlobalStorePrimitiveOp<F32, "global.store.f32", VM_OPC_GlobalStoreF32,
+ [VM_ExtF32]> {
+ let summary = [{global 32-bit floating-point store operation}];
+}
+
+def VM_GlobalStoreF64Op :
+ VM_GlobalStorePrimitiveOp<F64, "global.store.f64", VM_OPC_GlobalStoreF64,
+ [VM_ExtF64]> {
+ let summary = [{global 64-bit floating-point store operation}];
+}
+
def VM_GlobalLoadIndirectI32Op :
VM_GlobalLoadIndirectPrimitiveOp<I32, "global.load.indirect.i32",
VM_OPC_GlobalLoadIndirectI32> {
@@ -531,6 +582,22 @@
let hasCanonicalizer = 1;
}
+def VM_GlobalLoadIndirectF32Op :
+ VM_GlobalLoadIndirectPrimitiveOp<F32, "global.load.indirect.f32",
+ VM_OPC_GlobalLoadIndirectF32,
+ [VM_ExtF32]> {
+ let summary = [{global 32-bit floating-point load operation}];
+ let hasCanonicalizer = 1;
+}
+
+def VM_GlobalLoadIndirectF64Op :
+ VM_GlobalLoadIndirectPrimitiveOp<F64, "global.load.indirect.f64",
+ VM_OPC_GlobalLoadIndirectF64,
+ [VM_ExtF64]> {
+ let summary = [{global 64-bit floating-point load operation}];
+ let hasCanonicalizer = 1;
+}
+
def VM_GlobalStoreIndirectI32Op :
VM_GlobalStoreIndirectPrimitiveOp<I32, "global.store.indirect.i32",
VM_OPC_GlobalStoreIndirectI32> {
@@ -546,6 +613,22 @@
let hasCanonicalizer = 1;
}
+def VM_GlobalStoreIndirectF32Op :
+ VM_GlobalStoreIndirectPrimitiveOp<F32, "global.store.indirect.f32",
+ VM_OPC_GlobalStoreIndirectF32,
+ [VM_ExtF32]> {
+ let summary = [{global 32-bit floating-point store operation}];
+ let hasCanonicalizer = 1;
+}
+
+def VM_GlobalStoreIndirectF64Op :
+ VM_GlobalStoreIndirectPrimitiveOp<F64, "global.store.indirect.f64",
+ VM_OPC_GlobalStoreIndirectF64,
+ [VM_ExtF64]> {
+ let summary = [{global 64-bit floating-point store operation}];
+ let hasCanonicalizer = 1;
+}
+
def VM_GlobalLoadRefOp : VM_GlobalLoadOp<VM_AnyRef, "global.load.ref"> {
let summary = [{global ref<T> load operation}];
let description = [{
@@ -653,8 +736,8 @@
VM_EncResult<"result">,
];
- let parser = [{ return parseConstIntegerOp<$cppClass>(parser, &result); }];
- let printer = [{ return printConstIntegerOp<$cppClass>(p, *this); }];
+ let parser = [{ return parseConstOp<$cppClass>(parser, &result); }];
+ let printer = [{ return printConstOp<$cppClass>(p, *this); }];
}
def VM_ConstI32Op :
@@ -679,6 +762,28 @@
let hasCanonicalizer = 1;
}
+def VM_ConstF32Op :
+ VM_ConstantPrimitiveOp<F32, 32, "const.f32", VM_OPC_ConstF32,
+ "float", [VM_ExtF32]> {
+ let summary = [{32-bit floating-point constant operation}];
+ let arguments = (ins
+ VM_ConstantFloatValueAttr<F32>:$value
+ );
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+}
+
+def VM_ConstF64Op :
+ VM_ConstantPrimitiveOp<F64, 64, "const.f64", VM_OPC_ConstF64,
+ "double", [VM_ExtF64]> {
+ let summary = [{64-bit floating-point constant operation}];
+ let arguments = (ins
+ VM_ConstantFloatValueAttr<F64>:$value
+ );
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+}
+
class VM_ConstantPrimitiveZeroOp<Type type, string mnemonic, VM_OPC opcode,
string ctype, list<OpTrait> traits = []> :
VM_ConstOp<mnemonic, ctype, traits> {
@@ -717,6 +822,20 @@
let hasFolder = 1;
}
+def VM_ConstF32ZeroOp :
+ VM_ConstantPrimitiveZeroOp<F32, "const.f32.zero", VM_OPC_ConstF32Zero,
+ "float", [VM_ExtF32]> {
+ let summary = [{32-bit floating-point constant zero operation}];
+ let hasFolder = 1;
+}
+
+def VM_ConstF64ZeroOp :
+ VM_ConstantPrimitiveZeroOp<F64, "const.f64.zero", VM_OPC_ConstF64Zero,
+ "double", [VM_ExtF64]> {
+ let summary = [{64-bit floating-point constant zero operation}];
+ let hasFolder = 1;
+}
+
def VM_ConstRefZeroOp : VM_PureOp<"const.ref.zero", [
ConstantLike,
DeclareOpInterfaceMethods<VM_SerializableOpInterface>,
@@ -1024,12 +1143,24 @@
def VM_ListGetI64Op :
VM_ListGetPrimitiveOp<I64, "list.get.i64", VM_OPC_ListGetI64, [VM_ExtI64]>;
+def VM_ListGetF32Op :
+ VM_ListGetPrimitiveOp<F32, "list.get.f32", VM_OPC_ListGetF32, [VM_ExtF32]>;
+
+def VM_ListGetF64Op :
+ VM_ListGetPrimitiveOp<F64, "list.get.f64", VM_OPC_ListGetF64, [VM_ExtF64]>;
+
def VM_ListSetI32Op :
VM_ListSetPrimitiveOp<I32, "list.set.i32", VM_OPC_ListSetI32>;
def VM_ListSetI64Op :
VM_ListSetPrimitiveOp<I64, "list.set.i64", VM_OPC_ListSetI64, [VM_ExtI64]>;
+def VM_ListSetF32Op :
+ VM_ListSetPrimitiveOp<F32, "list.set.f32", VM_OPC_ListSetF32, [VM_ExtF32]>;
+
+def VM_ListSetF64Op :
+ VM_ListSetPrimitiveOp<F64, "list.set.f64", VM_OPC_ListSetF64, [VM_ExtF64]>;
+
def VM_ListGetRefOp :
VM_PureOp<"list.get.ref", [
DeclareOpInterfaceMethods<VM_SerializableOpInterface>,
@@ -1144,6 +1275,18 @@
let hasFolder = 1;
}
+def VM_SelectF32Op : VM_SelectPrimitiveOp<F32, "select.f32", VM_OPC_SelectF32,
+ [VM_ExtF32]> {
+ let summary = [{floating-point select operation}];
+ let hasFolder = 1;
+}
+
+def VM_SelectF64Op : VM_SelectPrimitiveOp<F64, "select.f64", VM_OPC_SelectF64,
+ [VM_ExtF64]> {
+ let summary = [{floating-point select operation}];
+ let hasFolder = 1;
+}
+
def VM_SelectRefOp : VM_PureOp<"select.ref", [
DeclareOpInterfaceMethods<VM_SerializableOpInterface>,
AllTypesMatch<["true_value", "false_value", "result"]>,
@@ -1189,8 +1332,8 @@
the index is out of bounds.
```mlir
- // Switch %arg0 to cases of %c100/%c200/%c300 if arg0==0, ==1, ==2.
- // If %arg0 is out of range (<0 or >2) then default to %c5.
+ // Switch %index to cases of %c100/%c200/%c300 if index==0, ==1, ==2.
+ // If %index is out of range (<0 or >2) then default to %c5.
%0 = vm.switch.i32 %index[%c100, %c200, %c300] else %c5 : i32
```
}];
@@ -1267,6 +1410,18 @@
let hasFolder = 1;
}
+def VM_SwitchF32Op : VM_SwitchFloatOp<F32, "switch.f32", VM_OPC_SwitchF32,
+ [VM_ExtF32]> {
+ let summary = [{floating-point switch operation}];
+ let hasFolder = 1;
+}
+
+def VM_SwitchF64Op : VM_SwitchFloatOp<F64, "switch.f64", VM_OPC_SwitchF64,
+ [VM_ExtF64]> {
+ let summary = [{floating-point switch operation}];
+ let hasFolder = 1;
+}
+
def VM_SwitchRefOp : VM_PureOp<"switch.ref", [
DeclareOpInterfaceMethods<VM_SerializableOpInterface>,
AllTypesMatch<["default_value", "result"]>,
@@ -1478,6 +1633,138 @@
}
//===----------------------------------------------------------------------===//
+// Floating-point arithmetic
+//===----------------------------------------------------------------------===//
+
+def VM_AddF32Op :
+ VM_BinaryArithmeticOp<F32, "add.f32", VM_OPC_AddF32,
+ [VM_ExtF32, Commutative]> {
+ let summary = [{floating-point add operation}];
+ let hasFolder = 1;
+}
+
+def VM_AddF64Op :
+ VM_BinaryArithmeticOp<F64, "add.f64", VM_OPC_AddF64,
+ [VM_ExtF64, Commutative]> {
+ let summary = [{floating-point add operation}];
+ let hasFolder = 1;
+}
+
+def VM_SubF32Op :
+ VM_BinaryArithmeticOp<F32, "sub.f32", VM_OPC_SubF32,
+ [VM_ExtF32]> {
+ let summary = [{floating point subtraction operation}];
+ let hasFolder = 1;
+}
+
+def VM_SubF64Op :
+ VM_BinaryArithmeticOp<F64, "sub.f64", VM_OPC_SubF64,
+ [VM_ExtF64]> {
+ let summary = [{floating point subtraction operation}];
+ let hasFolder = 1;
+}
+
+def VM_MulF32Op :
+ VM_BinaryArithmeticOp<F32, "mul.f32", VM_OPC_MulF32,
+ [VM_ExtF32, Commutative]> {
+ let summary = [{floating point multiplication operation}];
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+}
+
+def VM_MulF64Op :
+ VM_BinaryArithmeticOp<F64, "mul.f64", VM_OPC_MulF64,
+ [VM_ExtF64, Commutative]> {
+ let summary = [{floating point multiplication operation}];
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+}
+
+def VM_DivF32Op :
+ VM_BinaryArithmeticOp<F32, "div.f32", VM_OPC_DivF32,
+ [VM_ExtF32]> {
+ let summary = [{floating point division operation}];
+ let hasFolder = 1;
+}
+
+def VM_DivF64Op :
+ VM_BinaryArithmeticOp<F64, "div.f64", VM_OPC_DivF64,
+ [VM_ExtF64]> {
+ let summary = [{floating point division operation}];
+ let hasFolder = 1;
+}
+
+def VM_RemF32Op :
+ VM_BinaryArithmeticOp<F32, "rem.f32", VM_OPC_RemF32,
+ [VM_ExtF32]> {
+ let summary = [{floating point remainder operation}];
+ let hasFolder = 1;
+}
+
+def VM_RemF64Op :
+ VM_BinaryArithmeticOp<F64, "rem.f64", VM_OPC_RemF64,
+ [VM_ExtF64]> {
+ let summary = [{floating point remainder operation}];
+ let hasFolder = 1;
+}
+
+def VM_AbsF32Op :
+ VM_UnaryArithmeticOp<F32, "abs.f32", VM_OPC_AbsF32,
+ [VM_ExtF32]> {
+ let summary = [{floating point absolute-value operation}];
+ let hasFolder = 1;
+}
+
+def VM_AbsF64Op :
+ VM_UnaryArithmeticOp<F64, "abs.f64", VM_OPC_AbsF64,
+ [VM_ExtF64]> {
+ let summary = [{floating point absolute-value operation}];
+ let hasFolder = 1;
+}
+
+def VM_NegF32Op :
+ VM_UnaryArithmeticOp<F32, "neg.f32", VM_OPC_NegF32,
+ [VM_ExtF32]> {
+ let summary = [{floating point negation operation}];
+ let hasFolder = 1;
+}
+
+def VM_NegF64Op :
+ VM_UnaryArithmeticOp<F64, "neg.f64", VM_OPC_NegF64,
+ [VM_ExtF64]> {
+ let summary = [{floating point negation operation}];
+ let hasFolder = 1;
+}
+
+def VM_CeilF32Op :
+ VM_UnaryArithmeticOp<F32, "ceil.f32", VM_OPC_CeilF32,
+ [VM_ExtF32]> {
+ let summary = [{floating point ceiling operation}];
+ let hasFolder = 1;
+}
+
+def VM_CeilF64Op :
+ VM_UnaryArithmeticOp<F64, "ceil.f64", VM_OPC_CeilF64,
+ [VM_ExtF64]> {
+ let summary = [{floating point ceiling operation}];
+ let hasFolder = 1;
+}
+
+def VM_FloorF32Op :
+ VM_UnaryArithmeticOp<F32, "floor.f32", VM_OPC_FloorF32,
+ [VM_ExtF32]> {
+ let summary = [{floating point floor operation}];
+ let hasFolder = 1;
+}
+
+def VM_FloorF64Op :
+ VM_UnaryArithmeticOp<F64, "floor.f64", VM_OPC_FloorF64,
+ [VM_ExtF64]> {
+ let summary = [{floating point floor operation}];
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
// Integer bit manipulation
//===----------------------------------------------------------------------===//
@@ -1674,6 +1961,13 @@
let hasFolder = 1;
}
+def VM_TruncF64F32Op :
+ VM_ConversionOp<F64, F32, "trunc.f64.f32", VM_OPC_TruncF64F32,
+ [VM_ExtF64]> {
+ let summary = [{floating-point truncate to 32 bits}];
+ let hasFolder = 1;
+}
+
def VM_ExtI8I32SOp :
VM_ConversionOp<I32, I32, "ext.i8.i32.s", VM_OPC_ExtI8I32S> {
let summary = [{integer sign extend 8 bits to 32 bits}];
@@ -1740,6 +2034,13 @@
let hasFolder = 1;
}
+def VM_ExtF32F64Op :
+ VM_ConversionOp<F32, F64, "ext.f32.f64", VM_OPC_ExtF32F64,
+ [VM_ExtF64]> {
+ let summary = [{floating-point extend 32 bits to 64 bits}];
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// Native reduction (horizontal) arithmetic
//===----------------------------------------------------------------------===//
@@ -1840,6 +2141,22 @@
let hasFolder = 1;
}
+def VM_CmpEQF32Op :
+ VM_BinaryComparisonOp<F32, "cmp.eq.f32", VM_OPC_CmpEQF32,
+ [VM_ExtF32, Commutative]> {
+ let summary = [{floating-point equality comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def VM_CmpEQF64Op :
+ VM_BinaryComparisonOp<F64, "cmp.eq.f64", VM_OPC_CmpEQF64,
+ [VM_ExtF64, Commutative]> {
+ let summary = [{floating-point equality comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
def VM_CmpNEI32Op :
VM_BinaryComparisonOp<I32, "cmp.ne.i32", VM_OPC_CmpNEI32,
[Commutative]> {
@@ -1856,6 +2173,22 @@
let hasFolder = 1;
}
+def VM_CmpNEF32Op :
+ VM_BinaryComparisonOp<F32, "cmp.ne.f32", VM_OPC_CmpNEF32,
+ [VM_ExtF32, Commutative]> {
+ let summary = [{floating-point inequality comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def VM_CmpNEF64Op :
+ VM_BinaryComparisonOp<F64, "cmp.ne.f64", VM_OPC_CmpNEF64,
+ [VM_ExtF64, Commutative]> {
+ let summary = [{floating-point inequality comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
def VM_CmpLTI32SOp :
VM_BinaryComparisonOp<I32, "cmp.lt.i32.s", VM_OPC_CmpLTI32S> {
let summary = [{signed integer less-than comparison operation}];
@@ -1871,6 +2204,22 @@
let hasFolder = 1;
}
+def VM_CmpLTF32Op :
+ VM_BinaryComparisonOp<F32, "cmp.lt.f32", VM_OPC_CmpLTF32,
+ [VM_ExtF32]> {
+ let summary = [{floating-point less-than comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def VM_CmpLTF64Op :
+ VM_BinaryComparisonOp<F64, "cmp.lt.f64", VM_OPC_CmpLTF64,
+ [VM_ExtF64]> {
+ let summary = [{floating-point less-than comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
def VM_CmpLTI32UOp :
VM_BinaryComparisonOp<I32, "cmp.lt.i32.u", VM_OPC_CmpLTI32U> {
let summary = [{unsigned integer less-than comparison operation}];
@@ -1916,6 +2265,22 @@
let hasFolder = 1;
}
+def VM_CmpLTEF32Op :
+ VM_BinaryComparisonPseudoOp<F32, "cmp.lte.f32",
+ [VM_ExtF32]> {
+ let summary = [{floating-point less-than-or-equal comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def VM_CmpLTEF64Op :
+ VM_BinaryComparisonPseudoOp<F64, "cmp.lte.f64",
+ [VM_ExtF64]> {
+ let summary = [{floating-point less-than-or-equal comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
def VM_CmpGTI32SOp :
VM_BinaryComparisonPseudoOp<I32, "cmp.gt.i32.s"> {
let summary = [{signed integer greater-than comparison operation}];
@@ -1946,6 +2311,22 @@
let hasFolder = 1;
}
+def VM_CmpGTF32Op :
+ VM_BinaryComparisonPseudoOp<F32, "cmp.gt.f32",
+ [VM_ExtF32]> {
+ let summary = [{floating-point greater-than comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def VM_CmpGTF64Op :
+ VM_BinaryComparisonPseudoOp<F64, "cmp.gt.f64",
+ [VM_ExtF64]> {
+ let summary = [{floating-point greater-than comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
def VM_CmpGTEI32SOp :
VM_BinaryComparisonPseudoOp<I32, "cmp.gte.i32.s"> {
let summary = [{signed integer greater-than-or-equal comparison operation}];
@@ -1976,6 +2357,22 @@
let hasFolder = 1;
}
+def VM_CmpGTEF32Op :
+ VM_BinaryComparisonPseudoOp<F32, "cmp.gte.f32",
+ [VM_ExtF32]> {
+ let summary = [{floating-point greater-than-or-equal comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
+def VM_CmpGTEF64Op :
+ VM_BinaryComparisonPseudoOp<F64, "cmp.gte.f64",
+ [VM_ExtF64]> {
+ let summary = [{floating-point greater-than-or-equal comparison operation}];
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
+}
+
def VM_CmpNZI32Op :
VM_UnaryComparisonOp<I32, "cmp.nz.i32", VM_OPC_CmpNZI32> {
let summary = [{integer non-zero comparison operation}];
@@ -1995,6 +2392,26 @@
let hasFolder = 1;
}
+def VM_CmpNZF32Op :
+ VM_UnaryComparisonOp<F32, "cmp.nz.f32", VM_OPC_CmpNZF32,
+ [VM_ExtF32]> {
+ let summary = [{floating-point non-zero comparison operation}];
+ let description = [{
+ Compares the given floating-point operand for a non-zero value.
+ }];
+ let hasFolder = 1;
+}
+
+def VM_CmpNZF64Op :
+ VM_UnaryComparisonOp<F64, "cmp.nz.f64", VM_OPC_CmpNZF64,
+ [VM_ExtF64]> {
+ let summary = [{floating-point non-zero comparison operation}];
+ let description = [{
+ Compares the given floating-point operand for a non-zero value.
+ }];
+ let hasFolder = 1;
+}
+
def VM_CmpEQRefOp :
VM_BinaryComparisonOp<VM_AnyRef, "cmp.eq.ref", VM_OPC_CmpEQRef,
[Commutative]> {
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
index 643a1da..ade40f9 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
@@ -126,6 +126,28 @@
return currentOp_->emitOpError()
<< "attribute of bitwidth " << bitWidth << " not supported";
}
+ } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
+ switch (bitWidth) {
+ case 32: {
+ union {
+ float f32;
+ uint32_t u32;
+ } value;
+ value.f32 = floatAttr.getValue().convertToFloat();
+ return writeUint32(value.u32);
+ }
+ case 64: {
+ union {
+ double f64;
+ uint64_t u64;
+ } value;
+ value.f64 = floatAttr.getValue().convertToDouble();
+ return writeUint64(value.u64);
+ }
+ default:
+ return currentOp_->emitOpError()
+ << "attribute of bitwidth " << bitWidth << " not supported";
+ }
} else {
return currentOp_->emitOpError()
<< "attribute type not supported for primitive serialization: "
diff --git a/iree/compiler/Dialect/VM/Target/CallingConventionUtils.cpp b/iree/compiler/Dialect/VM/Target/CallingConventionUtils.cpp
index 29578ec..92e0125 100644
--- a/iree/compiler/Dialect/VM/Target/CallingConventionUtils.cpp
+++ b/iree/compiler/Dialect/VM/Target/CallingConventionUtils.cpp
@@ -37,8 +37,8 @@
if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
s.push_back('r');
return success();
- } else if (auto intType = type.dyn_cast<IntegerType>()) {
- switch (intType.getIntOrFloatBitWidth()) {
+ } else if (auto integerType = type.dyn_cast<IntegerType>()) {
+ switch (integerType.getIntOrFloatBitWidth()) {
default:
case 32:
s.push_back('i');
@@ -47,6 +47,16 @@
s.push_back('I');
return success();
}
+ } else if (auto floatType = type.dyn_cast<FloatType>()) {
+ switch (floatType.getIntOrFloatBitWidth()) {
+ default:
+ case 32:
+ s.push_back('f');
+ return success();
+ case 64:
+ s.push_back('F');
+ return success();
+ }
} else if (auto tupleType = type.dyn_cast<TupleType>()) {
// Flatten tuple (so tuple<i32, i64> -> `...iI...`).
SmallVector<Type, 4> flattenedTypes;
diff --git a/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
index bb87505..6218dbc 100644
--- a/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -130,16 +130,31 @@
// Returns {} if the constant is zero.
std::pair<LogicalResult, Value> createConst(Location loc, Attribute value,
OpBuilder &builder) {
- if (auto intValue = value.dyn_cast<IntegerAttr>()) {
- if (intValue.getValue().isNullValue()) {
+ if (auto integerAttr = value.dyn_cast<IntegerAttr>()) {
+ if (integerAttr.getValue().isNullValue()) {
// Globals are zero-initialized by default.
return {success(), {}};
}
- switch (intValue.getValue().getBitWidth()) {
+ switch (integerAttr.getType().getIntOrFloatBitWidth()) {
case 32:
- return {success(), builder.createOrFold<ConstI32Op>(loc, intValue)};
+ return {success(),
+ builder.createOrFold<ConstI32Op>(loc, integerAttr)};
case 64:
- return {success(), builder.createOrFold<ConstI64Op>(loc, intValue)};
+ return {success(),
+ builder.createOrFold<ConstI64Op>(loc, integerAttr)};
+ default:
+ return {failure(), {}};
+ }
+ } else if (auto floatAttr = value.dyn_cast<FloatAttr>()) {
+ if (floatAttr.getValue().isZero()) {
+ // Globals are zero-initialized by default.
+ return {success(), {}};
+ }
+ switch (floatAttr.getType().getIntOrFloatBitWidth()) {
+ case 32:
+ return {success(), builder.createOrFold<ConstF32Op>(loc, floatAttr)};
+ case 64:
+ return {success(), builder.createOrFold<ConstF64Op>(loc, floatAttr)};
default:
return {failure(), {}};
}
@@ -150,8 +165,8 @@
// Stores a value to a global; the global must be mutable.
LogicalResult storePrimitiveGlobal(Location loc, StringRef symName,
Value value, OpBuilder &builder) {
- if (auto intType = value.getType().dyn_cast<IntegerType>()) {
- switch (intType.getIntOrFloatBitWidth()) {
+ if (auto integerType = value.getType().dyn_cast<IntegerType>()) {
+ switch (integerType.getIntOrFloatBitWidth()) {
case 32:
builder.create<GlobalStoreI32Op>(loc, value, symName);
return success();
@@ -161,6 +176,17 @@
default:
return failure();
}
+ } else if (auto floatType = value.getType().dyn_cast<FloatType>()) {
+ switch (floatType.getIntOrFloatBitWidth()) {
+ case 32:
+ builder.create<GlobalStoreF32Op>(loc, value, symName);
+ return success();
+ case 64:
+ builder.create<GlobalStoreF64Op>(loc, value, symName);
+ return success();
+ default:
+ return failure();
+ }
}
return failure();
}
diff --git a/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp b/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp
index dca926a..904d487 100644
--- a/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp
@@ -51,7 +51,8 @@
int nextExportOrdinal = 0;
int nextGlobalRefOrdinal = 0;
int nextRodataOrdinal = 0;
- SmallVector<SmallVector<VMGlobalOp, 4>, 8> primitiveGlobalOps(8);
+ SmallVector<SmallVector<VMGlobalOp, 4>, 8> primitiveGlobalOps(
+ sizeof(int64_t) + 1);
for (auto &op : getOperation().getBlock().getOperations()) {
Optional<int> ordinal = llvm::None;
if (auto funcOp = dyn_cast<FuncOp>(op)) {
diff --git a/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index c88a068..f7fdd54 100644
--- a/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -35,7 +35,9 @@
passManager.addPass(createInlinerPass());
passManager.addPass(createCSEPass());
passManager.addPass(createSymbolDCEPass());
- passManager.addNestedPass<VM::ModuleOp>(createSinkDefiningOpsPass());
+ if (targetOptions.optimizeForStackSize) {
+ passManager.addNestedPass<VM::ModuleOp>(createSinkDefiningOpsPass());
+ }
}
void registerVMTransformPassPipeline() {