[vm] Add support for SI64 to F32 casts (#19455)
Adds support to the VM for casting from `si64` type to `f32` type.
Enables the lowering of `arith.sitofp %arg0 : i64 to f32` after
demotion.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp
index 8e5f96f..6d902e5 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp
@@ -526,30 +526,93 @@
}
};
-template <typename OpTy, typename ExtOpTy, typename CastOpTy>
-struct IntToFPOpConversion : public OpConversionPattern<OpTy> {
- using OpConversionPattern<OpTy>::OpConversionPattern;
+struct SIToFPOpConversion : public OpConversionPattern<arith::SIToFPOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(OpTy srcOp, typename OpTy::Adaptor adaptor,
+ matchAndRewrite(arith::SIToFPOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto srcType = srcOp.getIn().getType();
+ auto input = srcOp.getIn();
+ auto srcType = input.getType();
auto dstType = srcOp.getResult().getType();
- if (!dstType.isF32() ||
- !(srcType.isSignedInteger() || srcType.isSignlessInteger())) {
+ auto resultType = getTypeConverter()->convertType(dstType);
+
+ if (!(dstType.isF32() || dstType.isF64())) {
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}
- Value input = srcOp.getIn();
- if (!(srcType.isSignlessInteger(32) || srcType.isSignedInteger(32))) {
- if (srcType.getIntOrFloatBitWidth() < 32) {
- input = rewriter.create<ExtOpTy>(
- srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
- } else {
+
+ if (srcType.isSignedInteger(32) || srcType.isSignlessInteger(32)) {
+ if (dstType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, resultType,
+ input);
+ return success();
+ }
+ if (dstType.isF64()) {
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}
}
+ if (srcType.isSignedInteger(64) || srcType.isSignlessInteger(64)) {
+ if (dstType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::CastSI64F32Op>(srcOp, resultType,
+ input);
+ } else {
+ rewriter.replaceOpWithNewOp<IREE::VM::CastSI64F64Op>(srcOp, resultType,
+ input);
+ }
+ return success();
+ }
- auto resultType = this->getTypeConverter()->convertType(dstType);
- rewriter.replaceOpWithNewOp<CastOpTy>(srcOp, resultType, input);
+ if (srcType.getIntOrFloatBitWidth() < 32) {
+ input = rewriter.create<arith::ExtSIOp>(
+ srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, resultType,
+ input);
+ return success();
+ }
+};
+
+struct UIToFPOpConversion : public OpConversionPattern<arith::UIToFPOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(arith::UIToFPOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto input = srcOp.getIn();
+ auto srcType = input.getType();
+ auto dstType = srcOp.getResult().getType();
+
+ if (!(dstType.isF32() || dstType.isF64())) {
+ return rewriter.notifyMatchFailure(srcOp, "unsupported type");
+ }
+
+ auto resultType = getTypeConverter()->convertType(dstType);
+ if (srcType.isUnsignedInteger(32) || srcType.isSignlessInteger(32)) {
+ if (dstType.isF32()) {
+ rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, resultType,
+ input);
+ return success();
+ }
+ if (dstType.isF64()) {
+ return rewriter.notifyMatchFailure(srcOp, "unsupported type");
+ }
+ }
+ if (srcType.isUnsignedInteger(64) || srcType.isSignlessInteger(64)) {
+ if (dstType.isF32()) {
+ return rewriter.notifyMatchFailure(srcOp, "unsupported type");
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::VM::CastUI64F64Op>(srcOp, resultType,
+ input);
+ return success();
+ }
+
+ if (srcType.getIntOrFloatBitWidth() < 32) {
+ input = rewriter.create<arith::ExtUIOp>(
+ srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, resultType,
+ input);
return success();
}
};
@@ -742,12 +805,9 @@
IREE::VM::MaxF64Op>>(typeConverter, context);
// Floating-point conversion ops.
- patterns.insert<IntToFPOpConversion<arith::SIToFPOp, arith::ExtSIOp,
- IREE::VM::CastSI32F32Op>,
- IntToFPOpConversion<arith::UIToFPOp, arith::ExtUIOp,
- IREE::VM::CastUI32F32Op>,
- FPToSIOpConversion, FPToUIOpConversion, BitcastOpConversion>(
- typeConverter, context);
+ patterns.insert<SIToFPOpConversion, UIToFPOpConversion, FPToSIOpConversion,
+ FPToUIOpConversion, BitcastOpConversion>(typeConverter,
+ context);
// Shift ops.
patterns
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir
index be4ec1f..5b8da0b 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir
@@ -275,6 +275,18 @@
// -----
+// CHECK-LABEL: @sitofp_i64_f32
+module @sitofp_i64_f32 {
+ // CHECK: vm.func private @fn(%[[ARG0:.+]]: i64)
+ func.func @fn(%arg0: i64) -> f32 {
+ // CHECK: vm.cast.si64.f32 %[[ARG0]] : i64 -> f32
+ %0 = arith.sitofp %arg0 : i64 to f32
+ return %0 : f32
+ }
+}
+
+// -----
+
// CHECK-LABEL: @uitofp_i8_f32
module @uitofp_i8_f32 {
// CHECK: vm.func private @fn(%[[ARG0:.+]]: i32)
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 71eaea6..845c3e3 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -4521,6 +4521,7 @@
ADD_GENERIC_PATTERN(IREE::VM::CastF32UI32Op, "vm_cast_f32ui32");
ADD_GENERIC_PATTERN(IREE::VM::CastF32UI64Op, "vm_cast_f32ui64");
ADD_GENERIC_PATTERN(IREE::VM::CastSI32F32Op, "vm_cast_si32f32");
+ ADD_GENERIC_PATTERN(IREE::VM::CastSI64F32Op, "vm_cast_si64f32");
ADD_GENERIC_PATTERN(IREE::VM::CastUI32F32Op, "vm_cast_ui32f32");
ADD_GENERIC_PATTERN(IREE::VM::CeilF32Op, "vm_ceil_f32");
ADD_GENERIC_PATTERN(IREE::VM::CmpEQF32OOp, "vm_cmp_eq_f32o");
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index bec93d4..f227292 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -1683,6 +1683,16 @@
});
}
+OpFoldResult CastSI64F32Op::fold(FoldAdaptor operands) {
+ return constFoldCastOp<IntegerAttr, FloatAttr>(
+ Float32Type::get(getContext()), operands.getOperand(),
+ [&](const APInt &a) {
+ APFloat b = APFloat(0.0f);
+ b.convertFromAPInt(a, /*IsSigned=*/true, APFloat::rmNearestTiesToAway);
+ return b;
+ });
+}
+
OpFoldResult CastUI32F32Op::fold(FoldAdaptor operands) {
return constFoldCastOp<IntegerAttr, FloatAttr>(
Float32Type::get(getContext()), operands.getOperand(),
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
index af9295f..1ce37ae 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
@@ -45,6 +45,7 @@
def VM_OPC_MaxF32 : VM_OPC<0x38, "MaxF32">;
def VM_OPC_CastSI32F32 : VM_OPC<0x14, "CastSI32F32">;
+def VM_OPC_CastSI64F32 : VM_OPC<0x3C, "CastSI64F32">;
def VM_OPC_CastUI32F32 : VM_OPC<0x15, "CastUI32F32">;
def VM_OPC_CastF32SI32 : VM_OPC<0x16, "CastF32SI32">;
def VM_OPC_CastF32SI64 : VM_OPC<0x3A, "CastF32SI64">;
@@ -116,10 +117,12 @@
VM_OPC_CeilF32,
VM_OPC_FloorF32,
VM_OPC_RoundF32,
+ VM_OPC_RoundF32Even,
VM_OPC_MinF32,
VM_OPC_MaxF32,
VM_OPC_CastSI32F32,
+ VM_OPC_CastSI64F32,
VM_OPC_CastUI32F32,
VM_OPC_CastF32SI32,
VM_OPC_CastF32SI64,
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index c23e687..f7e5944 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -3167,6 +3167,13 @@
let hasFolder = 1;
}
+def VM_CastSI64F32Op :
+ VM_ConversionOp<I64, F32, "cast.si64.f32", VM_OPC_CastSI64F32,
+ [VM_ExtF32]> {
+ let summary = [{cast from a signed integer to a float-point value}];
+ let hasFolder = 1;
+}
+
def VM_CastUI64F64Op :
VM_ConversionOp<I64, F64, "cast.ui64.f64", VM_OPC_CastUI64F64,
[VM_ExtF64]> {
diff --git a/runtime/src/iree/vm/bytecode/disassembler.c b/runtime/src/iree/vm/bytecode/disassembler.c
index 02d93e9..4b24843 100644
--- a/runtime/src/iree/vm/bytecode/disassembler.c
+++ b/runtime/src/iree/vm/bytecode/disassembler.c
@@ -2111,6 +2111,16 @@
EMIT_OPTIONAL_VALUE_I32(regs->i32[operand_reg]);
break;
}
+ DISASM_OP(EXT_F32, CastSI64F32) {
+ uint16_t operand_reg = VM_ParseOperandRegI64("operand");
+ uint16_t result_reg = VM_ParseResultRegF32("result");
+ EMIT_F32_REG_NAME(result_reg);
+ IREE_RETURN_IF_ERROR(
+ iree_string_builder_append_cstring(b, " = vm.cast.si64.f32 "));
+ EMIT_I64_REG_NAME(operand_reg);
+ EMIT_OPTIONAL_VALUE_I64(regs->i32[operand_reg]);
+ break;
+ }
DISASM_OP(EXT_F32, CastUI32F32) {
uint16_t operand_reg = VM_ParseOperandRegI32("operand");
uint16_t result_reg = VM_ParseResultRegF32("result");
diff --git a/runtime/src/iree/vm/bytecode/dispatch.c b/runtime/src/iree/vm/bytecode/dispatch.c
index 40ae195..ba48f32 100644
--- a/runtime/src/iree/vm/bytecode/dispatch.c
+++ b/runtime/src/iree/vm/bytecode/dispatch.c
@@ -2046,6 +2046,11 @@
float* result = VM_DecResultRegF32("result");
*result = vm_cast_si32f32(operand);
});
+ DISPATCH_OP(EXT_F32, CastSI64F32, {
+ int64_t operand = (int64_t)VM_DecOperandRegI64("operand");
+ float* result = VM_DecResultRegF32("result");
+ *result = vm_cast_si64f32(operand);
+ });
DISPATCH_OP(EXT_F32, CastUI32F32, {
int32_t operand = (int32_t)VM_DecOperandRegI32("operand");
float* result = VM_DecResultRegF32("result");
diff --git a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
index 2a5a76c..2c76073 100644
--- a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
+++ b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
@@ -388,10 +388,10 @@
OPC(0x77, AbsI32) \
OPC(0x78, AbsI64) \
OPC(0x79, Block) \
- OPC(0x7A, MinI64S) \
- OPC(0x7B, MinI64U) \
- OPC(0x7C, MaxI64S) \
- OPC(0x7D, MaxI64U) \
+ OPC(0x7A, MinI32S) \
+ OPC(0x7B, MinI32U) \
+ OPC(0x7C, MaxI32S) \
+ OPC(0x7D, MaxI32U) \
OPC(0x7E, MinI64S) \
OPC(0x7F, MinI64U) \
OPC(0x80, MaxI64S) \
@@ -584,7 +584,7 @@
IREE_VM_OP_EXT_F32_RoundF32Even = 0x39,
IREE_VM_OP_EXT_F32_CastF32SI64 = 0x3A,
IREE_VM_OP_EXT_F32_CastF32UI64 = 0x3B,
- IREE_VM_OP_EXT_F32_RSV_0x3C,
+ IREE_VM_OP_EXT_F32_CastSI64F32 = 0x3C,
IREE_VM_OP_EXT_F32_RSV_0x3D,
IREE_VM_OP_EXT_F32_RSV_0x3E,
IREE_VM_OP_EXT_F32_RSV_0x3F,
@@ -843,7 +843,7 @@
OPC(0x39, RoundF32Even) \
OPC(0x3A, CastF32SI64) \
OPC(0x3B, CastF32UI64) \
- RSV(0x3C) \
+ OPC(0x3C, CastSI64F32) \
RSV(0x3D) \
RSV(0x3E) \
RSV(0x3F) \
diff --git a/runtime/src/iree/vm/bytecode/verifier.c b/runtime/src/iree/vm/bytecode/verifier.c
index c5b9d63..5c726db 100644
--- a/runtime/src/iree/vm/bytecode/verifier.c
+++ b/runtime/src/iree/vm/bytecode/verifier.c
@@ -1823,6 +1823,10 @@
VM_VerifyOperandRegI32(operand);
VM_VerifyResultRegF32(result);
});
+ VERIFY_OP(EXT_F32, CastSI64F32, {
+ VM_VerifyOperandRegI64(operand);
+ VM_VerifyResultRegF32(result);
+ });
VERIFY_OP(EXT_F32, CastUI32F32, {
VM_VerifyOperandRegI32(operand);
VM_VerifyResultRegF32(result);
diff --git a/runtime/src/iree/vm/ops.h b/runtime/src/iree/vm/ops.h
index b9ffd70..68c939e 100644
--- a/runtime/src/iree/vm/ops.h
+++ b/runtime/src/iree/vm/ops.h
@@ -599,6 +599,7 @@
//===------------------------------------------------------------------===//
static inline float vm_cast_si32f32(int32_t operand) { return (float)operand; }
+static inline float vm_cast_si64f32(int64_t operand) { return (float)operand; }
static inline float vm_cast_ui32f32(int32_t operand) {
return (float)(uint32_t)operand;
}
diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
index 721c34c..eb8e94e 100644
--- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
+++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
@@ -373,9 +373,6 @@
"onnx/node/generated/test_slice_start_out_of_bounds",
"onnx/node/generated/test_stft",
"onnx/node/generated/test_stft_with_window",
- "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip0",
- "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip5",
- "onnx/node/generated/test_tfidfvectorizer_tf_batch_uniandbigrams_skip5",
"onnx/node/generated/test_tfidfvectorizer_tf_only_bigrams_skip0",
"onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_levelempty",
"onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_skip5",