Fix some bugs in consteval related to different type combinations. (#14534)

With these fixes, the test suite builds successfully (with constant
hoisting enabled).
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
index 39f26ae..003fd64 100644
--- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -324,13 +324,12 @@
     hasRequestedTargetBackend =
         targetRegistry.getTargetBackend(requestedTargetBackend) != nullptr;
     options->executableOptions.targets.push_back(requestedTargetBackend);
+    options->targetOptions.f32Extension = true;
+    options->targetOptions.f64Extension = true;
     if (requestedTargetBackend == "vmvx" || !hasRequestedTargetBackend) {
       targetBackend = targetRegistry.getTargetBackend("vmvx");
-      options->targetOptions.f32Extension = true;
-      options->targetOptions.f64Extension = false; // not yet implemented
     } else {
       targetBackend = targetRegistry.getTargetBackend(requestedTargetBackend);
-      // options->executableOptions.targets.push_back(requestedTargetBackend);
     }
 
     // Disable constant evaluation for our Jit compilation pipeline.
@@ -443,8 +442,9 @@
         switch (resultBinding.getType()) {
         case ResultBinding::Type::GlobalOp: {
           TypedAttr attr;
-          if (failed(call.getResultAsAttr(resultBinding.getGlobalOp().getLoc(),
-                                          it.index(), attr)))
+          if (failed(call.getResultAsAttr(
+                  resultBinding.getGlobalOp().getLoc(), it.index(),
+                  resultBinding.getGlobalOp().getType(), attr)))
             return failure();
           resultBinding.getGlobalOp().setInitialValueAttr(attr);
           break;
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
index 967e503..885697e 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -56,6 +56,9 @@
     case 4:
       *outElementType = IREE_HAL_ELEMENT_TYPE_INT_4;
       return success();
+    case 1:
+      *outElementType = IREE_HAL_ELEMENT_TYPE_BOOL_8;
+      return success();
     }
   } else if (baseType == builder.getF32Type()) {
     *outElementType = IREE_HAL_ELEMENT_TYPE_FLOAT_32;
@@ -118,20 +121,13 @@
 
   // For i1, IREE (currently) returns these as 8bit integer values and MLIR
   // has a loader that accepts bool arrays (the raw buffer loader also
-  // supports them but bit-packed, which is not convenient for us). So, if
-  // sizeof(bool) == 1, we just bit-cast. Otherwise, we go through a temporary.
+  // supports them but bit-packed, which is not convenient for us).
   if (elementType.isInteger(1)) {
-    if (sizeof(bool) == 1) {
-      ArrayRef<bool> boolArray(reinterpret_cast<bool *>(rawBuffer.data()),
-                               rawBuffer.size());
-      return DenseElementsAttr::get(tensorType, boolArray);
-    } else {
-      // Note: cannot use std::vector because it specializes bool in a way
-      // that is not compatible with ArrayRef.
-      llvm::SmallVector<bool> boolVector(rawBuffer.begin(), rawBuffer.end());
-      ArrayRef<bool> boolArray(boolVector.data(), boolVector.size());
-      return DenseElementsAttr::get(tensorType, boolArray);
-    }
+    // Note: cannot use std::vector because it specializes bool in a way
+    // that is not compatible with ArrayRef.
+    llvm::SmallVector<bool> boolVector(rawBuffer.begin(), rawBuffer.end());
+    ArrayRef<bool> boolArray(boolVector.data(), boolVector.size());
+    return DenseElementsAttr::get(tensorType, boolArray);
   }
 
   emitError(loc) << "unhandled case when converting raw buffer of "
@@ -163,8 +159,6 @@
                                     &outputs));
 }
 
-// Imports or snapshots a raw host buffer, depending on whether import is
-// possible.
 LogicalResult FunctionCall::importBufferForRead(Location loc,
                                                 const uint8_t *rawData,
                                                 iree_host_size_t length,
@@ -207,6 +201,44 @@
   }
 }
 
+LogicalResult FunctionCall::importBitwiseBoolI8BufferForRead(
+    Location loc, const uint8_t *rawDataBits,
+    iree_host_size_t rawDataLengthBytes, iree_hal_buffer_t **buffer) {
+  iree_hal_buffer_params_t params;
+  std::memset(&params, 0, sizeof(params));
+  iree_host_size_t bufferLength = rawDataLengthBytes * 8;
+  params.type =
+      IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  if (failed(handleRuntimeError(
+          loc, iree_hal_allocator_allocate_buffer(
+                   binary.getAllocator(), params, bufferLength,
+                   iree_const_byte_span_t{nullptr, 0}, buffer))))
+    return failure();
+
+  iree_hal_buffer_mapping_t mapping;
+  if (failed(handleRuntimeError(
+          loc, iree_hal_buffer_map_range(
+                   *buffer, IREE_HAL_MAPPING_MODE_SCOPED,
+                   IREE_HAL_MEMORY_ACCESS_WRITE, /*byte_offset=*/0,
+                   /*byte_length=*/bufferLength, &mapping))))
+    return failure();
+
+  // Expand each packed bit into its own byte, least-significant bit first.
+  for (iree_host_size_t i = 0; i < rawDataLengthBytes; ++i) {
+    uint8_t bits = rawDataBits[i];
+    mapping.contents.data[i * 8 + 0] = bits & 0x1;
+    mapping.contents.data[i * 8 + 1] = (bits & 0x2) >> 1;
+    mapping.contents.data[i * 8 + 2] = (bits & 0x4) >> 2;
+    mapping.contents.data[i * 8 + 3] = (bits & 0x8) >> 3;
+    mapping.contents.data[i * 8 + 4] = (bits & 0x10) >> 4;
+    mapping.contents.data[i * 8 + 5] = (bits & 0x20) >> 5;
+    mapping.contents.data[i * 8 + 6] = (bits & 0x40) >> 6;
+    mapping.contents.data[i * 8 + 7] = (bits & 0x80) >> 7;
+  }
+
+  return handleRuntimeError(loc, iree_hal_buffer_unmap_range(&mapping));
+}
+
 LogicalResult FunctionCall::addArgument(Location loc, Attribute attr) {
   if (auto elementsAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) {
     // Meta-data.
@@ -219,8 +251,10 @@
     for (size_t i = 0; i < rank; ++i) {
       shape[i] = stShape[i];
     }
+    Type mlirElementType = st.getElementType();
+    bool isI1 = mlirElementType == IntegerType::get(loc.getContext(), 1);
     iree_hal_element_type_t elementType = IREE_HAL_ELEMENT_TYPE_NONE;
-    if (failed(convertToElementType(loc, st.getElementType(), &elementType)))
+    if (failed(convertToElementType(loc, mlirElementType, &elementType)))
       return failure();
 
     iree::vm::ref<iree_hal_buffer_t> buffer;
@@ -241,6 +275,19 @@
                        buffer.get(), 0, bufferSize,
                        static_cast<const void *>(data.data()), data.size()))))
         return failure();
+    } else if (isI1) {
+      // Dense, non-splat i1.
+      // MLIR DenseElementsAttr made the interesting optimization choice to
+      // densely pack i1 as a bit-vector. It doesn't do this for any other
+      // sub-byte type, and it is aligned linearly (not row-wise), so is
+      // a complete special case.
+      // Since we map this to an 8bit bool on the IREE runtime side, we
+      // just do the best we can when allocating.
+      if (failed(importBitwiseBoolI8BufferForRead(
+              loc, reinterpret_cast<const uint8_t *>(data.data()), data.size(),
+              &buffer))) {
+        return failure();
+      }
     } else {
       // Dense, non-splat.
       if (failed(importBufferForRead(
@@ -260,6 +307,52 @@
 
     return handleRuntimeError(
         loc, iree_vm_list_push_ref_move(inputs.get(), std::move(bv)));
+  } else if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
+    iree_vm_value_t value;
+    APInt apValue = integerAttr.getValue();
+    switch (apValue.getBitWidth()) {
+    case 8:
+      value =
+          iree_vm_value_make_i8(static_cast<uint8_t>(apValue.getZExtValue()));
+      break;
+    case 16:
+      value =
+          iree_vm_value_make_i16(static_cast<uint16_t>(apValue.getZExtValue()));
+      break;
+    case 32:
+      value =
+          iree_vm_value_make_i32(static_cast<uint32_t>(apValue.getZExtValue()));
+      break;
+    case 64:
+      value =
+          iree_vm_value_make_i64(static_cast<uint64_t>(apValue.getZExtValue()));
+      break;
+    default:
+      return emitError(loc) << "internal error: unsupported consteval jit "
+                               "function integer input type ("
+                            << attr << ")";
+    }
+    return handleRuntimeError(loc,
+                              iree_vm_list_push_value(inputs.get(), &value));
+  } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
+    iree_vm_value_t value;
+    APFloat apValue = floatAttr.getValue();
+    // Note that there are many floating point semantics that LLVM knows about,
+    // but we restrict to only those that the VM natively supports here.
+    switch (APFloat::SemanticsToEnum(apValue.getSemantics())) {
+    case APFloat::S_IEEEsingle:
+      value = iree_vm_value_make_f32(apValue.convertToFloat());
+      break;
+    case APFloat::S_IEEEdouble:
+      value = iree_vm_value_make_f64(apValue.convertToDouble());
+      break;
+    default:
+      return emitError(loc) << "internal error: unsupported consteval jit "
+                               "function float input type ("
+                            << attr << ")";
+    }
+    return handleRuntimeError(loc,
+                              iree_vm_list_push_value(inputs.get(), &value));
   }
 
   return emitError(loc)
@@ -288,34 +381,30 @@
 }
 
 LogicalResult FunctionCall::getResultAsAttr(Location loc, size_t index,
-                                            TypedAttr &outAttr) {
+                                            Type mlirType, TypedAttr &outAttr) {
   iree_vm_variant_t variant = iree_vm_variant_empty();
   if (failed(handleRuntimeError(loc, iree_vm_list_get_variant_assign(
                                          outputs.get(), index, &variant))))
     return failure();
 
-  outAttr = binary.convertVariantToAttribute(loc, variant);
+  outAttr = binary.convertVariantToAttribute(loc, variant, mlirType);
   if (!outAttr)
     return failure();
 
   return success();
 }
 
-TypedAttr
-CompiledBinary::convertVariantToAttribute(Location loc,
-                                          iree_vm_variant_t &variant) {
+TypedAttr CompiledBinary::convertVariantToAttribute(Location loc,
+                                                    iree_vm_variant_t &variant,
+                                                    Type mlirType) {
   auto context = loc.getContext();
   Builder builder(context);
   if (iree_vm_variant_is_value(variant)) {
     switch (iree_vm_type_def_as_value(variant.type)) {
-    case IREE_VM_VALUE_TYPE_I8:
-      return builder.getI8IntegerAttr(variant.i8);
-    case IREE_VM_VALUE_TYPE_I16:
-      return builder.getI16IntegerAttr(variant.i16);
     case IREE_VM_VALUE_TYPE_I32:
-      return builder.getI32IntegerAttr(variant.i32);
+      return builder.getIntegerAttr(mlirType, variant.i32);
     case IREE_VM_VALUE_TYPE_I64:
-      return builder.getI64IntegerAttr(variant.i64);
+      return builder.getIntegerAttr(mlirType, variant.i64);
     case IREE_VM_VALUE_TYPE_F32:
       return builder.getF32FloatAttr(variant.f32);
     case IREE_VM_VALUE_TYPE_F64:
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.h b/compiler/src/iree/compiler/ConstEval/Runtime.h
index be7259a..5637e86 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.h
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.h
@@ -40,7 +40,8 @@
   // explicitly by subclasses, ensuring that any backing images remain valid
   // through the call to deinitialize().
   void deinitialize();
-  TypedAttr convertVariantToAttribute(Location loc, iree_vm_variant_t &variant);
+  TypedAttr convertVariantToAttribute(Location loc, iree_vm_variant_t &variant,
+                                      Type mlirType);
 
   iree::vm::ref<iree_hal_device_t> device;
   iree::vm::ref<iree_vm_module_t> hal_module;
@@ -57,12 +58,22 @@
 
   LogicalResult addArgument(Location loc, Attribute attr);
   LogicalResult invoke(Location loc, StringRef name);
-  LogicalResult getResultAsAttr(Location loc, size_t index, TypedAttr &outAttr);
+  LogicalResult getResultAsAttr(Location loc, size_t index, Type mlirType,
+                                TypedAttr &outAttr);
 
 private:
+  // Imports or snapshots a raw host buffer, depending on whether import is
+  // possible. This should only be used when the MLIR and IREE layouts
+  // agree.
   LogicalResult importBufferForRead(Location loc, const uint8_t *rawData,
                                     iree_host_size_t length,
                                     iree_hal_buffer_t **buffer);
+  // Imports a bit vector of rawData into a byte buffer, expanding 1->8bit
+  // during import.
+  LogicalResult
+  importBitwiseBoolI8BufferForRead(Location loc, const uint8_t *rawDataBits,
+                                   iree_host_size_t rawDataLengthBytes,
+                                   iree_hal_buffer_t **buffer);
   CompiledBinary binary;
   iree::vm::ref<iree_vm_list_t> inputs;
   iree::vm::ref<iree_vm_list_t> outputs;
diff --git a/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel b/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel
index 6b56ebe..7888ae2 100644
--- a/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel
@@ -16,7 +16,10 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "compile_regressions.mlir",
+            "failing.mlir",
             "jit_globals.mlir",
+            "scalar_values.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt b/compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt
index 79fe3d9..e14f38b 100644
--- a/compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt
@@ -14,7 +14,10 @@
   NAME
     lit
   SRCS
+    "compile_regressions.mlir"
+    "failing.mlir"
     "jit_globals.mlir"
+    "scalar_values.mlir"
   TOOLS
     FileCheck
     iree-opt
diff --git a/compiler/src/iree/compiler/ConstEval/test/compile_regressions.mlir b/compiler/src/iree/compiler/ConstEval/test/compile_regressions.mlir
new file mode 100644
index 0000000..903ed6c
--- /dev/null
+++ b/compiler/src/iree/compiler/ConstEval/test/compile_regressions.mlir
@@ -0,0 +1,49 @@
+// RUN: iree-opt --split-input-file --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals  %s | FileCheck %s
+
+// Test case reduced by running the pass --iree-util-hoist-into-globals on the
+// following (and then changing the check to a return):
+// func.func @i1_inline_constant() {
+//   %control = arith.constant dense<[true, false, true, false]> : tensor<4xi1>
+//   %a = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+//   %b = arith.constant dense<[5, 6, 7, 8]> : tensor<4xi32>
+//   %init = tensor.empty() : tensor<4xi32>
+//   %c = linalg.generic {
+//       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>,
+//                        affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+//       iterator_types = ["parallel"]}
+//       ins(%control, %a, %b : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>)
+//       outs(%init : tensor<4xi32>) {
+//     ^bb0(%b1 : i1, %b2 : i32, %b3 : i32, %b4 : i32):
+//       %0 = arith.select %b1, %b2, %b3 : i32
+//       linalg.yield %0 : i32
+//     } -> tensor<4xi32>
+//   check.expect_eq_const(%c, dense<[1, 6, 3, 8]> : tensor<4xi32>) : tensor<4xi32>
+//   return
+// }
+
+// CHECK-LABEL: module @hoisted_tensor_i1_input
+// Verify the original check based on constant folding.
+// CHECK: = dense<[1, 6, 3, 8]>
+#map = affine_map<(d0) -> (d0)>
+module @hoisted_tensor_i1_input {
+  util.global private @hoisted : tensor<4xi32>
+  func.func @i1_inline_constant() -> tensor<4xi32> {
+    %hoisted = util.global.load @hoisted : tensor<4xi32>
+    return %hoisted : tensor<4xi32>
+  }
+  util.initializer attributes {iree.compiler.consteval} {
+    %cst = arith.constant dense<[true, false, true, false]> : tensor<4xi1>
+    %cst_0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+    %cst_1 = arith.constant dense<[5, 6, 7, 8]> : tensor<4xi32>
+    %0 = tensor.empty() : tensor<4xi32>
+    %1 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%cst, %cst_0, %cst_1 : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) outs(%0 : tensor<4xi32>) {
+    ^bb0(%in: i1, %in_2: i32, %in_3: i32, %out: i32):
+      %2 = arith.select %in, %in_2, %in_3 : i32
+      linalg.yield %2 : i32
+    } -> tensor<4xi32>
+    util.global.store %1, @hoisted : tensor<4xi32>
+    util.initializer.return
+  }
+}
+
+// -----
diff --git a/compiler/src/iree/compiler/ConstEval/test/failing.mlir b/compiler/src/iree/compiler/ConstEval/test/failing.mlir
new file mode 100644
index 0000000..71edad7
--- /dev/null
+++ b/compiler/src/iree/compiler/ConstEval/test/failing.mlir
@@ -0,0 +1,20 @@
+// RUN: iree-opt --split-input-file --iree-consteval-jit-target-backend=vmvx --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals  %s | FileCheck %s
+// XFAIL: *
+
+// CHECK-LABEL: @eval_f64_scalar
+// CHECK: 4.200000e+01 : f64
+module @eval_i64_scalar {
+  util.global private @offset : f64 = -2.0 : f64
+  util.global private @hoisted : f64
+  func.func @main() -> f64 {
+    %hoisted = util.global.load @hoisted : f64
+    return %hoisted : f64
+  }
+  util.initializer attributes {iree.compiler.consteval} {
+    %cst = arith.constant 44.0 : f64
+    %offset = util.global.load @offset : f64
+    %sum = arith.addf %cst, %offset : f64
+    util.global.store %sum, @hoisted : f64
+    util.initializer.return
+  }
+}
diff --git a/compiler/src/iree/compiler/ConstEval/test/scalar_values.mlir b/compiler/src/iree/compiler/ConstEval/test/scalar_values.mlir
new file mode 100644
index 0000000..513d0e9
--- /dev/null
+++ b/compiler/src/iree/compiler/ConstEval/test/scalar_values.mlir
@@ -0,0 +1,95 @@
+// RUN: iree-opt --split-input-file --iree-consteval-jit-target-backend=vmvx --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals  %s | FileCheck %s
+
+// CHECK-LABEL: @eval_i8_scalar
+// CHECK: 42 : i8
+module @eval_i8_scalar {
+  util.global private @offset : i8 = -2 : i8
+  util.global private @hoisted : i8
+  func.func @main() -> i8 {
+    %hoisted = util.global.load @hoisted : i8
+    return %hoisted : i8
+  }
+  util.initializer attributes {iree.compiler.consteval} {
+    %cst = arith.constant 44 : i8
+    %offset = util.global.load @offset : i8
+    %sum = arith.addi %cst, %offset : i8
+    util.global.store %sum, @hoisted : i8
+    util.initializer.return
+  }
+}
+
+// -----
+// CHECK-LABEL: @eval_i16_scalar
+// CHECK: 42 : i16
+module @eval_i16_scalar {
+  util.global private @offset : i16 = -2 : i16
+  util.global private @hoisted : i16
+  func.func @main() -> i16 {
+    %hoisted = util.global.load @hoisted : i16
+    return %hoisted : i16
+  }
+  util.initializer attributes {iree.compiler.consteval} {
+    %cst = arith.constant 44 : i16
+    %offset = util.global.load @offset : i16
+    %sum = arith.addi %cst, %offset : i16
+    util.global.store %sum, @hoisted : i16
+    util.initializer.return
+  }
+}
+
+// -----
+// CHECK-LABEL: @eval_i32_scalar
+// CHECK: 42 : i32
+module @eval_i32_scalar {
+  util.global private @offset : i32 = -2 : i32
+  util.global private @hoisted : i32
+  func.func @main() -> i32 {
+    %hoisted = util.global.load @hoisted : i32
+    return %hoisted : i32
+  }
+  util.initializer attributes {iree.compiler.consteval} {
+    %cst = arith.constant 44 : i32
+    %offset = util.global.load @offset : i32
+    %sum = arith.addi %cst, %offset : i32
+    util.global.store %sum, @hoisted : i32
+    util.initializer.return
+  }
+}
+
+// -----
+// CHECK-LABEL: @eval_i64_scalar
+// CHECK: 42 : i64
+module @eval_i64_scalar {
+  util.global private @offset : i64 = -2 : i64
+  util.global private @hoisted : i64
+  func.func @main() -> i64 {
+    %hoisted = util.global.load @hoisted : i64
+    return %hoisted : i64
+  }
+  util.initializer attributes {iree.compiler.consteval} {
+    %cst = arith.constant 44 : i64
+    %offset = util.global.load @offset : i64
+    %sum = arith.addi %cst, %offset : i64
+    util.global.store %sum, @hoisted : i64
+    util.initializer.return
+  }
+}
+
+// -----
+// CHECK-LABEL: @eval_f32_scalar
+// CHECK: 4.200000e+01 : f32
+module @eval_f32_scalar {
+  util.global private @offset : f32 = -2.0 : f32
+  util.global private @hoisted : f32
+  func.func @main() -> f32 {
+    %hoisted = util.global.load @hoisted : f32
+    return %hoisted : f32
+  }
+  util.initializer attributes {iree.compiler.consteval} {
+    %cst = arith.constant 44.0 : f32
+    %offset = util.global.load @offset : f32
+    %sum = arith.addf %cst, %offset : f32
+    util.global.store %sum, @hoisted : f32
+    util.initializer.return
+  }
+}