Fix some bugs in consteval related to different type combinations. (#14534)
With these fixes, the test suite builds successfully (with constant
hoisting enabled).
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
index 39f26ae..003fd64 100644
--- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -324,13 +324,12 @@
hasRequestedTargetBackend =
targetRegistry.getTargetBackend(requestedTargetBackend) != nullptr;
options->executableOptions.targets.push_back(requestedTargetBackend);
+ options->targetOptions.f32Extension = true;
+ options->targetOptions.f64Extension = true;
if (requestedTargetBackend == "vmvx" || !hasRequestedTargetBackend) {
targetBackend = targetRegistry.getTargetBackend("vmvx");
- options->targetOptions.f32Extension = true;
- options->targetOptions.f64Extension = false; // not yet implemented
} else {
targetBackend = targetRegistry.getTargetBackend(requestedTargetBackend);
- // options->executableOptions.targets.push_back(requestedTargetBackend);
}
// Disable constant evaluation for our Jit compilation pipeline.
@@ -443,8 +442,9 @@
switch (resultBinding.getType()) {
case ResultBinding::Type::GlobalOp: {
TypedAttr attr;
- if (failed(call.getResultAsAttr(resultBinding.getGlobalOp().getLoc(),
- it.index(), attr)))
+ if (failed(call.getResultAsAttr(
+ resultBinding.getGlobalOp().getLoc(), it.index(),
+ resultBinding.getGlobalOp().getType(), attr)))
return failure();
resultBinding.getGlobalOp().setInitialValueAttr(attr);
break;
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
index 967e503..885697e 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -56,6 +56,9 @@
case 4:
*outElementType = IREE_HAL_ELEMENT_TYPE_INT_4;
return success();
+ case 1:
+ *outElementType = IREE_HAL_ELEMENT_TYPE_BOOL_8;
+ return success();
}
} else if (baseType == builder.getF32Type()) {
*outElementType = IREE_HAL_ELEMENT_TYPE_FLOAT_32;
@@ -118,20 +121,13 @@
// For i1, IREE (currently) returns these as 8bit integer values and MLIR
// has a loader that accepts bool arrays (the raw buffer loader also
- // supports them but bit-packed, which is not convenient for us). So, if
- // sizeof(bool) == 1, we just bit-cast. Otherwise, we go through a temporary.
+ // supports them but bit-packed, which is not convenient for us).
if (elementType.isInteger(1)) {
- if (sizeof(bool) == 1) {
- ArrayRef<bool> boolArray(reinterpret_cast<bool *>(rawBuffer.data()),
- rawBuffer.size());
- return DenseElementsAttr::get(tensorType, boolArray);
- } else {
- // Note: cannot use std::vector because it specializes bool in a way
- // that is not compatible with ArrayRef.
- llvm::SmallVector<bool> boolVector(rawBuffer.begin(), rawBuffer.end());
- ArrayRef<bool> boolArray(boolVector.data(), boolVector.size());
- return DenseElementsAttr::get(tensorType, boolArray);
- }
+ // Note: cannot use std::vector because it specializes bool in a way
+ // that is not compatible with ArrayRef.
+ llvm::SmallVector<bool> boolVector(rawBuffer.begin(), rawBuffer.end());
+ ArrayRef<bool> boolArray(boolVector.data(), boolVector.size());
+ return DenseElementsAttr::get(tensorType, boolArray);
}
emitError(loc) << "unhandled case when converting raw buffer of "
@@ -163,8 +159,6 @@
&outputs));
}
-// Imports or snapshots a raw host buffer, depending on whether import is
-// possible.
LogicalResult FunctionCall::importBufferForRead(Location loc,
const uint8_t *rawData,
iree_host_size_t length,
@@ -207,6 +201,44 @@
}
}
+LogicalResult FunctionCall::importBitwiseBoolI8BufferForRead(
+ Location loc, const uint8_t *rawDataBits,
+ iree_host_size_t rawDataLengthBytes, iree_hal_buffer_t **buffer) {
+ iree_hal_buffer_params_t params;
+ std::memset(¶ms, 0, sizeof(params));
+ iree_host_size_t bufferLength = rawDataLengthBytes * 8;
+ params.type =
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+ if (failed(handleRuntimeError(
+ loc, iree_hal_allocator_allocate_buffer(
+ binary.getAllocator(), params, bufferLength,
+ iree_const_byte_span_t{nullptr, 0}, buffer))))
+ return failure();
+
+ iree_hal_buffer_mapping_t mapping;
+ if (failed(handleRuntimeError(
+ loc, iree_hal_buffer_map_range(
+ *buffer, IREE_HAL_MAPPING_MODE_SCOPED,
+ IREE_HAL_MEMORY_ACCESS_WRITE, /*byte_offset=*/0,
+ /*byte_length=*/bufferLength, &mapping))))
+ return failure();
+
+ // Copy.
+ for (iree_host_size_t i = 0; i < rawDataLengthBytes; ++i) {
+ uint8_t bits = rawDataBits[i];
+ mapping.contents.data[i * 8 + 0] = bits & 0x1;
+ mapping.contents.data[i * 8 + 1] = (bits & 0x2) >> 1;
+ mapping.contents.data[i * 8 + 2] = (bits & 0x4) >> 2;
+ mapping.contents.data[i * 8 + 3] = (bits & 0x8) >> 3;
+ mapping.contents.data[i * 8 + 4] = (bits & 0x10) >> 4;
+ mapping.contents.data[i * 8 + 5] = (bits & 0x20) >> 5;
+ mapping.contents.data[i * 8 + 6] = (bits & 0x40) >> 6;
+ mapping.contents.data[i * 8 + 7] = (bits & 0x80) >> 7;
+ }
+
+ return handleRuntimeError(loc, iree_hal_buffer_unmap_range(&mapping));
+}
+
LogicalResult FunctionCall::addArgument(Location loc, Attribute attr) {
if (auto elementsAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) {
// Meta-data.
@@ -219,8 +251,10 @@
for (size_t i = 0; i < rank; ++i) {
shape[i] = stShape[i];
}
+ Type mlirElementType = st.getElementType();
+ bool isI1 = mlirElementType == IntegerType::get(loc.getContext(), 1);
iree_hal_element_type_t elementType = IREE_HAL_ELEMENT_TYPE_NONE;
- if (failed(convertToElementType(loc, st.getElementType(), &elementType)))
+ if (failed(convertToElementType(loc, mlirElementType, &elementType)))
return failure();
iree::vm::ref<iree_hal_buffer_t> buffer;
@@ -241,6 +275,19 @@
buffer.get(), 0, bufferSize,
static_cast<const void *>(data.data()), data.size()))))
return failure();
+ } else if (isI1) {
+ // Dense, non-splat i1.
+ // MLIR DenseElementsAttr made the interesting optimization choice to
+ // densely pack i1 as a bit-vector. It doesn't do this for any other
+ // sub-byte type, and it is aligned linearly (not row-wise), so is
+ // a complete special case.
+ // Since we map this to an 8bit bool on the IREE runtime side, we
+ // just do the best we can when allocating.
+ if (failed(importBitwiseBoolI8BufferForRead(
+ loc, reinterpret_cast<const uint8_t *>(data.data()), data.size(),
+ &buffer))) {
+ return failure();
+ }
} else {
// Dense, non-splat.
if (failed(importBufferForRead(
@@ -260,6 +307,52 @@
return handleRuntimeError(
loc, iree_vm_list_push_ref_move(inputs.get(), std::move(bv)));
+ } else if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
+ iree_vm_value_t value;
+ APInt apValue = integerAttr.getValue();
+ switch (apValue.getBitWidth()) {
+ case 8:
+ value =
+ iree_vm_value_make_i8(static_cast<uint8_t>(apValue.getZExtValue()));
+ break;
+ case 16:
+ value =
+ iree_vm_value_make_i16(static_cast<uint16_t>(apValue.getZExtValue()));
+ break;
+ case 32:
+ value =
+ iree_vm_value_make_i32(static_cast<uint32_t>(apValue.getZExtValue()));
+ break;
+ case 64:
+ value =
+ iree_vm_value_make_i64(static_cast<uint64_t>(apValue.getZExtValue()));
+ break;
+ default:
+ return emitError(loc) << "internal error: unsupported consteval jit "
+ "function integer input type ("
+ << attr << ")";
+ }
+ return handleRuntimeError(loc,
+ iree_vm_list_push_value(inputs.get(), &value));
+ } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
+ iree_vm_value_t value;
+ APFloat apValue = floatAttr.getValue();
+ // Note that there are many floating point semantics that LLVM knows about,
+ // but we restrict to only those that the VM natively supports here.
+ switch (APFloat::SemanticsToEnum(apValue.getSemantics())) {
+ case APFloat::S_IEEEsingle:
+ value = iree_vm_value_make_f32(apValue.convertToFloat());
+ break;
+ case APFloat::S_IEEEdouble:
+ value = iree_vm_value_make_f64(apValue.convertToDouble());
+ break;
+ default:
+ return emitError(loc) << "internal error: unsupported consteval jit "
+ "function float input type ("
+ << attr << ")";
+ }
+ return handleRuntimeError(loc,
+ iree_vm_list_push_value(inputs.get(), &value));
}
return emitError(loc)
@@ -288,34 +381,30 @@
}
LogicalResult FunctionCall::getResultAsAttr(Location loc, size_t index,
- TypedAttr &outAttr) {
+ Type mlirType, TypedAttr &outAttr) {
iree_vm_variant_t variant = iree_vm_variant_empty();
if (failed(handleRuntimeError(loc, iree_vm_list_get_variant_assign(
outputs.get(), index, &variant))))
return failure();
- outAttr = binary.convertVariantToAttribute(loc, variant);
+ outAttr = binary.convertVariantToAttribute(loc, variant, mlirType);
if (!outAttr)
return failure();
return success();
}
-TypedAttr
-CompiledBinary::convertVariantToAttribute(Location loc,
- iree_vm_variant_t &variant) {
+TypedAttr CompiledBinary::convertVariantToAttribute(Location loc,
+ iree_vm_variant_t &variant,
+ Type mlirType) {
auto context = loc.getContext();
Builder builder(context);
if (iree_vm_variant_is_value(variant)) {
switch (iree_vm_type_def_as_value(variant.type)) {
- case IREE_VM_VALUE_TYPE_I8:
- return builder.getI8IntegerAttr(variant.i8);
- case IREE_VM_VALUE_TYPE_I16:
- return builder.getI16IntegerAttr(variant.i16);
case IREE_VM_VALUE_TYPE_I32:
- return builder.getI32IntegerAttr(variant.i32);
+ return builder.getIntegerAttr(mlirType, variant.i32);
case IREE_VM_VALUE_TYPE_I64:
- return builder.getI64IntegerAttr(variant.i64);
+ return builder.getIntegerAttr(mlirType, variant.i64);
case IREE_VM_VALUE_TYPE_F32:
return builder.getF32FloatAttr(variant.f32);
case IREE_VM_VALUE_TYPE_F64:
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.h b/compiler/src/iree/compiler/ConstEval/Runtime.h
index be7259a..5637e86 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.h
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.h
@@ -40,7 +40,8 @@
// explicitly by subclasses, ensuring that any backing images remain valid
// through the call to deinitialize().
void deinitialize();
- TypedAttr convertVariantToAttribute(Location loc, iree_vm_variant_t &variant);
+ TypedAttr convertVariantToAttribute(Location loc, iree_vm_variant_t &variant,
+ Type mlirType);
iree::vm::ref<iree_hal_device_t> device;
iree::vm::ref<iree_vm_module_t> hal_module;
@@ -57,12 +58,22 @@
LogicalResult addArgument(Location loc, Attribute attr);
LogicalResult invoke(Location loc, StringRef name);
- LogicalResult getResultAsAttr(Location loc, size_t index, TypedAttr &outAttr);
+ LogicalResult getResultAsAttr(Location loc, size_t index, Type mlirType,
+ TypedAttr &outAttr);
private:
+ // Imports or snapshots a raw host buffer, depending on whether import is
+ // possible. This should only be used when the MLIR and IREE layout
+ // agree.
LogicalResult importBufferForRead(Location loc, const uint8_t *rawData,
iree_host_size_t length,
iree_hal_buffer_t **buffer);
+ // Imports a bit vector of rawData into a byte buffer, expanding 1->8bit
+ // during import.
+ LogicalResult
+ importBitwiseBoolI8BufferForRead(Location loc, const uint8_t *rawDataBits,
+ iree_host_size_t rawDataLengthBytes,
+ iree_hal_buffer_t **buffer);
CompiledBinary binary;
iree::vm::ref<iree_vm_list_t> inputs;
iree::vm::ref<iree_vm_list_t> outputs;
diff --git a/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel b/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel
index 6b56ebe..7888ae2 100644
--- a/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel
@@ -16,7 +16,10 @@
name = "lit",
srcs = enforce_glob(
[
+ "compile_regressions.mlir",
+ "failing.mlir",
"jit_globals.mlir",
+ "scalar_values.mlir",
],
include = ["*.mlir"],
),
diff --git a/compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt b/compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt
index 79fe3d9..e14f38b 100644
--- a/compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/ConstEval/test/CMakeLists.txt
@@ -14,7 +14,10 @@
NAME
lit
SRCS
+ "compile_regressions.mlir"
+ "failing.mlir"
"jit_globals.mlir"
+ "scalar_values.mlir"
TOOLS
FileCheck
iree-opt
diff --git a/compiler/src/iree/compiler/ConstEval/test/compile_regressions.mlir b/compiler/src/iree/compiler/ConstEval/test/compile_regressions.mlir
new file mode 100644
index 0000000..903ed6c
--- /dev/null
+++ b/compiler/src/iree/compiler/ConstEval/test/compile_regressions.mlir
@@ -0,0 +1,49 @@
+// RUN: iree-opt --split-input-file --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals %s | FileCheck %s
+
+// Test case reduced by running the pass --iree-util-hoist-into-globals on the
+// following (and then chang the check to a return):
+// func.func @i1_inline_constant() {
+// %control = arith.constant dense<[true, false, true, false]> : tensor<4xi1>
+// %a = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+// %b = arith.constant dense<[5, 6, 7, 8]> : tensor<4xi32>
+// %init = tensor.empty() : tensor<4xi32>
+// %c = linalg.generic {
+// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>,
+// affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+// iterator_types = ["parallel"]}
+// ins(%control, %a, %b : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>)
+// outs(%init : tensor<4xi32>) {
+// ^bb0(%b1 : i1, %b2 : i32, %b3 : i32, %b4 : i32):
+// %0 = arith.select %b1, %b2, %b3 : i32
+// linalg.yield %0 : i32
+// } -> tensor<4xi32>
+// check.expect_eq_const(%c, dense<[1, 6, 3, 8]> : tensor<4xi32>) : tensor<4xi32>
+// return
+// }
+
+// CHECK-LABEL: module @hoisted_tensor_i1_input
+// Verify the original check based on constant folding.
+// CHECK: = dense<[1, 6, 3, 8]>
+#map = affine_map<(d0) -> (d0)>
+module @hoisted_tensor_i1_input {
+ util.global private @hoisted : tensor<4xi32>
+ func.func @i1_inline_constant() -> tensor<4xi32> {
+ %hoisted = util.global.load @hoisted : tensor<4xi32>
+ return %hoisted : tensor<4xi32>
+ }
+ util.initializer attributes {iree.compiler.consteval} {
+ %cst = arith.constant dense<[true, false, true, false]> : tensor<4xi1>
+ %cst_0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %cst_1 = arith.constant dense<[5, 6, 7, 8]> : tensor<4xi32>
+ %0 = tensor.empty() : tensor<4xi32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%cst, %cst_0, %cst_1 : tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) outs(%0 : tensor<4xi32>) {
+ ^bb0(%in: i1, %in_2: i32, %in_3: i32, %out: i32):
+ %2 = arith.select %in, %in_2, %in_3 : i32
+ linalg.yield %2 : i32
+ } -> tensor<4xi32>
+ util.global.store %1, @hoisted : tensor<4xi32>
+ util.initializer.return
+ }
+}
+
+// -----
diff --git a/compiler/src/iree/compiler/ConstEval/test/failing.mlir b/compiler/src/iree/compiler/ConstEval/test/failing.mlir
new file mode 100644
index 0000000..71edad7
--- /dev/null
+++ b/compiler/src/iree/compiler/ConstEval/test/failing.mlir
@@ -0,0 +1,20 @@
+// RUN: iree-opt --split-input-file --iree-consteval-jit-target-backend=vmvx --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals %s | FileCheck %s
+// XFAIL: *
+
+// CHECK-LABEL: @eval_f64_scalar
+// CHECK: 4.200000e+01 : f64
+module @eval_i64_scalar {
+ util.global private @offset : f64 = -2.0 : f64
+ util.global private @hoisted : f64
+ func.func @main() -> f64 {
+ %hoisted = util.global.load @hoisted : f64
+ return %hoisted : f64
+ }
+ util.initializer attributes {iree.compiler.consteval} {
+ %cst = arith.constant 44.0 : f64
+ %offset = util.global.load @offset : f64
+ %sum = arith.addf %cst, %offset : f64
+ util.global.store %sum, @hoisted : f64
+ util.initializer.return
+ }
+}
diff --git a/compiler/src/iree/compiler/ConstEval/test/scalar_values.mlir b/compiler/src/iree/compiler/ConstEval/test/scalar_values.mlir
new file mode 100644
index 0000000..513d0e9
--- /dev/null
+++ b/compiler/src/iree/compiler/ConstEval/test/scalar_values.mlir
@@ -0,0 +1,95 @@
+// RUN: iree-opt --split-input-file --iree-consteval-jit-target-backend=vmvx --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals %s | FileCheck %s
+
+// CHECK-LABEL: @eval_i8_scalar
+// CHECK: 42 : i8
+module @eval_i8_scalar {
+ util.global private @offset : i8 = -2 : i8
+ util.global private @hoisted : i8
+ func.func @main() -> i8 {
+ %hoisted = util.global.load @hoisted : i8
+ return %hoisted : i8
+ }
+ util.initializer attributes {iree.compiler.consteval} {
+ %cst = arith.constant 44 : i8
+ %offset = util.global.load @offset : i8
+ %sum = arith.addi %cst, %offset : i8
+ util.global.store %sum, @hoisted : i8
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i16_scalar
+// CHECK: 42 : i16
+module @eval_i16_scalar {
+ util.global private @offset : i16 = -2 : i16
+ util.global private @hoisted : i16
+ func.func @main() -> i16 {
+ %hoisted = util.global.load @hoisted : i16
+ return %hoisted : i16
+ }
+ util.initializer attributes {iree.compiler.consteval} {
+ %cst = arith.constant 44 : i16
+ %offset = util.global.load @offset : i16
+ %sum = arith.addi %cst, %offset : i16
+ util.global.store %sum, @hoisted : i16
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i32_scalar
+// CHECK: 42 : i32
+module @eval_i32_scalar {
+ util.global private @offset : i32 = -2 : i32
+ util.global private @hoisted : i32
+ func.func @main() -> i32 {
+ %hoisted = util.global.load @hoisted : i32
+ return %hoisted : i32
+ }
+ util.initializer attributes {iree.compiler.consteval} {
+ %cst = arith.constant 44 : i32
+ %offset = util.global.load @offset : i32
+ %sum = arith.addi %cst, %offset : i32
+ util.global.store %sum, @hoisted : i32
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i64_scalar
+// CHECK: 42 : i64
+module @eval_i64_scalar {
+ util.global private @offset : i64 = -2 : i64
+ util.global private @hoisted : i64
+ func.func @main() -> i64 {
+ %hoisted = util.global.load @hoisted : i64
+ return %hoisted : i64
+ }
+ util.initializer attributes {iree.compiler.consteval} {
+ %cst = arith.constant 44 : i64
+ %offset = util.global.load @offset : i64
+ %sum = arith.addi %cst, %offset : i64
+ util.global.store %sum, @hoisted : i64
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_f32_scalar
+// CHECK: 4.200000e+01 : f32
+module @eval_f32_scalar {
+ util.global private @offset : f32 = -2.0 : f32
+ util.global private @hoisted : f32
+ func.func @main() -> f32 {
+ %hoisted = util.global.load @hoisted : f32
+ return %hoisted : f32
+ }
+ util.initializer attributes {iree.compiler.consteval} {
+ %cst = arith.constant 44.0 : f32
+ %offset = util.global.load @offset : f32
+ %sum = arith.addf %cst, %offset : f32
+ util.global.store %sum, @hoisted : f32
+ util.initializer.return
+ }
+}