[vm] Add support for SI64 to F32 casts (#19455)

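Adds a `vm.cast.si64.f32` op to the VM for casting a signed 64-bit
integer (`si64`) to `f32`, including the compiler lowering, constant
folder, EmitC pattern, and bytecode dispatch/disassembler/verifier support.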

Enables lowering of `arith.sitofp %arg0 : i64 to f32`, which can appear
after demotion and was previously rejected as an unsupported type.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp
index 8e5f96f..6d902e5 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp
@@ -526,30 +526,93 @@
   }
 };
 
-template <typename OpTy, typename ExtOpTy, typename CastOpTy>
-struct IntToFPOpConversion : public OpConversionPattern<OpTy> {
-  using OpConversionPattern<OpTy>::OpConversionPattern;
+struct SIToFPOpConversion : public OpConversionPattern<arith::SIToFPOp> {
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(OpTy srcOp, typename OpTy::Adaptor adaptor,
+  matchAndRewrite(arith::SIToFPOp srcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto srcType = srcOp.getIn().getType();
+    auto input = srcOp.getIn();
+    auto srcType = input.getType();
     auto dstType = srcOp.getResult().getType();
-    if (!dstType.isF32() ||
-        !(srcType.isSignedInteger() || srcType.isSignlessInteger())) {
+    auto resultType = getTypeConverter()->convertType(dstType);
+
+    if (!(dstType.isF32() || dstType.isF64())) {
       return rewriter.notifyMatchFailure(srcOp, "unsupported type");
     }
-    Value input = srcOp.getIn();
-    if (!(srcType.isSignlessInteger(32) || srcType.isSignedInteger(32))) {
-      if (srcType.getIntOrFloatBitWidth() < 32) {
-        input = rewriter.create<ExtOpTy>(
-            srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
-      } else {
+
+    if (srcType.isSignedInteger(32) || srcType.isSignlessInteger(32)) {
+      if (dstType.isF32()) {
+        rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, resultType,
+                                                             input);
+        return success();
+      }
+      if (dstType.isF64()) {
         return rewriter.notifyMatchFailure(srcOp, "unsupported type");
       }
     }
+    if (srcType.isSignedInteger(64) || srcType.isSignlessInteger(64)) {
+      if (dstType.isF32()) {
+        rewriter.replaceOpWithNewOp<IREE::VM::CastSI64F32Op>(srcOp, resultType,
+                                                             input);
+      } else {
+        rewriter.replaceOpWithNewOp<IREE::VM::CastSI64F64Op>(srcOp, resultType,
+                                                             input);
+      }
+      return success();
+    }
 
-    auto resultType = this->getTypeConverter()->convertType(dstType);
-    rewriter.replaceOpWithNewOp<CastOpTy>(srcOp, resultType, input);
+    if (srcType.getIntOrFloatBitWidth() < 32) {
+      input = rewriter.create<arith::ExtSIOp>(
+          srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
+    }
+
+    rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, resultType,
+                                                         input);
+    return success();
+  }
+};
+
+struct UIToFPOpConversion : public OpConversionPattern<arith::UIToFPOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(arith::UIToFPOp srcOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto input = srcOp.getIn();
+    auto srcType = input.getType();
+    auto dstType = srcOp.getResult().getType();
+
+    if (!(dstType.isF32() || dstType.isF64())) {
+      return rewriter.notifyMatchFailure(srcOp, "unsupported type");
+    }
+
+    auto resultType = getTypeConverter()->convertType(dstType);
+    if (srcType.isUnsignedInteger(32) || srcType.isSignlessInteger(32)) {
+      if (dstType.isF32()) {
+        rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, resultType,
+                                                             input);
+        return success();
+      }
+      if (dstType.isF64()) {
+        return rewriter.notifyMatchFailure(srcOp, "unsupported type");
+      }
+    }
+    if (srcType.isUnsignedInteger(64) || srcType.isSignlessInteger(64)) {
+      if (dstType.isF32()) {
+        return rewriter.notifyMatchFailure(srcOp, "unsupported type");
+      }
+
+      rewriter.replaceOpWithNewOp<IREE::VM::CastUI64F64Op>(srcOp, resultType,
+                                                           input);
+      return success();
+    }
+
+    if (srcType.getIntOrFloatBitWidth() < 32) {
+      input = rewriter.create<arith::ExtUIOp>(
+          srcOp.getLoc(), IntegerType::get(this->getContext(), 32), input);
+    }
+
+    rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, resultType,
+                                                         input);
     return success();
   }
 };
@@ -742,12 +805,9 @@
                                    IREE::VM::MaxF64Op>>(typeConverter, context);
 
   // Floating-point conversion ops.
-  patterns.insert<IntToFPOpConversion<arith::SIToFPOp, arith::ExtSIOp,
-                                      IREE::VM::CastSI32F32Op>,
-                  IntToFPOpConversion<arith::UIToFPOp, arith::ExtUIOp,
-                                      IREE::VM::CastUI32F32Op>,
-                  FPToSIOpConversion, FPToUIOpConversion, BitcastOpConversion>(
-      typeConverter, context);
+  patterns.insert<SIToFPOpConversion, UIToFPOpConversion, FPToSIOpConversion,
+                  FPToUIOpConversion, BitcastOpConversion>(typeConverter,
+                                                           context);
 
   // Shift ops.
   patterns
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir
index be4ec1f..5b8da0b 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir
@@ -275,6 +275,18 @@
 
 // -----
 
+// CHECK-LABEL: @sitofp_i64_f32
+module @sitofp_i64_f32 {
+  // CHECK: vm.func private @fn(%[[ARG0:.+]]: i64)
+  func.func @fn(%arg0: i64) -> f32 {
+    // CHECK: vm.cast.si64.f32 %[[ARG0]] : i64 -> f32
+    %0 = arith.sitofp %arg0 : i64 to f32
+    return %0 : f32
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @uitofp_i8_f32
 module @uitofp_i8_f32 {
   // CHECK: vm.func private @fn(%[[ARG0:.+]]: i32)
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 71eaea6..845c3e3 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -4521,6 +4521,7 @@
   ADD_GENERIC_PATTERN(IREE::VM::CastF32UI32Op, "vm_cast_f32ui32");
   ADD_GENERIC_PATTERN(IREE::VM::CastF32UI64Op, "vm_cast_f32ui64");
   ADD_GENERIC_PATTERN(IREE::VM::CastSI32F32Op, "vm_cast_si32f32");
+  ADD_GENERIC_PATTERN(IREE::VM::CastSI64F32Op, "vm_cast_si64f32");
   ADD_GENERIC_PATTERN(IREE::VM::CastUI32F32Op, "vm_cast_ui32f32");
   ADD_GENERIC_PATTERN(IREE::VM::CeilF32Op, "vm_ceil_f32");
   ADD_GENERIC_PATTERN(IREE::VM::CmpEQF32OOp, "vm_cmp_eq_f32o");
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index bec93d4..f227292 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -1683,6 +1683,16 @@
       });
 }
 
+OpFoldResult CastSI64F32Op::fold(FoldAdaptor operands) {
+  return constFoldCastOp<IntegerAttr, FloatAttr>(
+      Float32Type::get(getContext()), operands.getOperand(),
+      [&](const APInt &a) {
+        APFloat b = APFloat(0.0f);
+        b.convertFromAPInt(a, /*IsSigned=*/true, APFloat::rmNearestTiesToAway);
+        return b;
+      });
+}
+
 OpFoldResult CastUI32F32Op::fold(FoldAdaptor operands) {
   return constFoldCastOp<IntegerAttr, FloatAttr>(
       Float32Type::get(getContext()), operands.getOperand(),
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
index af9295f..1ce37ae 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
@@ -45,6 +45,7 @@
 def VM_OPC_MaxF32                : VM_OPC<0x38, "MaxF32">;
 
 def VM_OPC_CastSI32F32           : VM_OPC<0x14, "CastSI32F32">;
+def VM_OPC_CastSI64F32           : VM_OPC<0x3C, "CastSI64F32">;
 def VM_OPC_CastUI32F32           : VM_OPC<0x15, "CastUI32F32">;
 def VM_OPC_CastF32SI32           : VM_OPC<0x16, "CastF32SI32">;
 def VM_OPC_CastF32SI64           : VM_OPC<0x3A, "CastF32SI64">;
@@ -116,10 +117,12 @@
     VM_OPC_CeilF32,
     VM_OPC_FloorF32,
     VM_OPC_RoundF32,
+    VM_OPC_RoundF32Even,
     VM_OPC_MinF32,
     VM_OPC_MaxF32,
 
     VM_OPC_CastSI32F32,
+    VM_OPC_CastSI64F32,
     VM_OPC_CastUI32F32,
     VM_OPC_CastF32SI32,
     VM_OPC_CastF32SI64,
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index c23e687..f7e5944 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -3167,6 +3167,13 @@
   let hasFolder = 1;
 }
 
+def VM_CastSI64F32Op :
+    VM_ConversionOp<I64, F32, "cast.si64.f32", VM_OPC_CastSI64F32,
+                    [VM_ExtF32]> {
+  let summary = [{cast from a signed integer to a float-point value}];
+  let hasFolder = 1;
+}
+
 def VM_CastUI64F64Op :
     VM_ConversionOp<I64, F64, "cast.ui64.f64", VM_OPC_CastUI64F64,
                     [VM_ExtF64]> {
diff --git a/runtime/src/iree/vm/bytecode/disassembler.c b/runtime/src/iree/vm/bytecode/disassembler.c
index 02d93e9..4b24843 100644
--- a/runtime/src/iree/vm/bytecode/disassembler.c
+++ b/runtime/src/iree/vm/bytecode/disassembler.c
@@ -2111,6 +2111,16 @@
       EMIT_OPTIONAL_VALUE_I32(regs->i32[operand_reg]);
       break;
     }
+    DISASM_OP(EXT_F32, CastSI64F32) {
+      uint16_t operand_reg = VM_ParseOperandRegI64("operand");
+      uint16_t result_reg = VM_ParseResultRegF32("result");
+      EMIT_F32_REG_NAME(result_reg);
+      IREE_RETURN_IF_ERROR(
+          iree_string_builder_append_cstring(b, " = vm.cast.si64.f32 "));
+      EMIT_I64_REG_NAME(operand_reg);
+      EMIT_OPTIONAL_VALUE_I64(regs->i32[operand_reg]);
+      break;
+    }
     DISASM_OP(EXT_F32, CastUI32F32) {
       uint16_t operand_reg = VM_ParseOperandRegI32("operand");
       uint16_t result_reg = VM_ParseResultRegF32("result");
diff --git a/runtime/src/iree/vm/bytecode/dispatch.c b/runtime/src/iree/vm/bytecode/dispatch.c
index 40ae195..ba48f32 100644
--- a/runtime/src/iree/vm/bytecode/dispatch.c
+++ b/runtime/src/iree/vm/bytecode/dispatch.c
@@ -2046,6 +2046,11 @@
         float* result = VM_DecResultRegF32("result");
         *result = vm_cast_si32f32(operand);
       });
+      DISPATCH_OP(EXT_F32, CastSI64F32, {
+        int64_t operand = (int64_t)VM_DecOperandRegI64("operand");
+        float* result = VM_DecResultRegF32("result");
+        *result = vm_cast_si64f32(operand);
+      });
       DISPATCH_OP(EXT_F32, CastUI32F32, {
         int32_t operand = (int32_t)VM_DecOperandRegI32("operand");
         float* result = VM_DecResultRegF32("result");
diff --git a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
index 2a5a76c..2c76073 100644
--- a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
+++ b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
@@ -388,10 +388,10 @@
     OPC(0x77, AbsI32) \
     OPC(0x78, AbsI64) \
     OPC(0x79, Block) \
-    OPC(0x7A, MinI64S) \
-    OPC(0x7B, MinI64U) \
-    OPC(0x7C, MaxI64S) \
-    OPC(0x7D, MaxI64U) \
+    OPC(0x7A, MinI32S) \
+    OPC(0x7B, MinI32U) \
+    OPC(0x7C, MaxI32S) \
+    OPC(0x7D, MaxI32U) \
     OPC(0x7E, MinI64S) \
     OPC(0x7F, MinI64U) \
     OPC(0x80, MaxI64S) \
@@ -584,7 +584,7 @@
   IREE_VM_OP_EXT_F32_RoundF32Even = 0x39,
   IREE_VM_OP_EXT_F32_CastF32SI64 = 0x3A,
   IREE_VM_OP_EXT_F32_CastF32UI64 = 0x3B,
-  IREE_VM_OP_EXT_F32_RSV_0x3C,
+  IREE_VM_OP_EXT_F32_CastSI64F32 = 0x3C,
   IREE_VM_OP_EXT_F32_RSV_0x3D,
   IREE_VM_OP_EXT_F32_RSV_0x3E,
   IREE_VM_OP_EXT_F32_RSV_0x3F,
@@ -843,7 +843,7 @@
     OPC(0x39, RoundF32Even) \
     OPC(0x3A, CastF32SI64) \
     OPC(0x3B, CastF32UI64) \
-    RSV(0x3C) \
+    OPC(0x3C, CastSI64F32) \
     RSV(0x3D) \
     RSV(0x3E) \
     RSV(0x3F) \
diff --git a/runtime/src/iree/vm/bytecode/verifier.c b/runtime/src/iree/vm/bytecode/verifier.c
index c5b9d63..5c726db 100644
--- a/runtime/src/iree/vm/bytecode/verifier.c
+++ b/runtime/src/iree/vm/bytecode/verifier.c
@@ -1823,6 +1823,10 @@
       VM_VerifyOperandRegI32(operand);
       VM_VerifyResultRegF32(result);
     });
+    VERIFY_OP(EXT_F32, CastSI64F32, {
+      VM_VerifyOperandRegI64(operand);
+      VM_VerifyResultRegF32(result);
+    });
     VERIFY_OP(EXT_F32, CastUI32F32, {
       VM_VerifyOperandRegI32(operand);
       VM_VerifyResultRegF32(result);
diff --git a/runtime/src/iree/vm/ops.h b/runtime/src/iree/vm/ops.h
index b9ffd70..68c939e 100644
--- a/runtime/src/iree/vm/ops.h
+++ b/runtime/src/iree/vm/ops.h
@@ -599,6 +599,7 @@
 //===------------------------------------------------------------------===//
 
 static inline float vm_cast_si32f32(int32_t operand) { return (float)operand; }
+static inline float vm_cast_si64f32(int64_t operand) { return (float)operand; }
 static inline float vm_cast_ui32f32(int32_t operand) {
   return (float)(uint32_t)operand;
 }
diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
index 721c34c..eb8e94e 100644
--- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
+++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
@@ -373,9 +373,6 @@
     "onnx/node/generated/test_slice_start_out_of_bounds",
     "onnx/node/generated/test_stft",
     "onnx/node/generated/test_stft_with_window",
-    "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip0",
-    "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip5",
-    "onnx/node/generated/test_tfidfvectorizer_tf_batch_uniandbigrams_skip5",
     "onnx/node/generated/test_tfidfvectorizer_tf_only_bigrams_skip0",
     "onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_levelempty",
     "onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_skip5",