Implement all feasible types in JitGlobals. (#7956)
Supported: i8/16/32/64, f32/f64, i1
Unsupported: non-byte-aligned integers, f16, bf16
There are still some VM bugs with i1 that I will need to address separately. Sub-byte-aligned types also still need a fair amount of work elsewhere in the system, so I am punting on them for now.
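
For example (distilled from the new jit_globals.mlir tests below), a global
whose value is produced by an initializer:

  util.global private @hoisted : tensor<2xi32>
  util.initializer {
    %cst = arith.constant dense<[2, 3]> : tensor<2xi32>
    util.global.store %cst, @hoisted : tensor<2xi32>
    util.initializer.return
  }

is now evaluated at compile time and folded to an inline constant:

  util.global private @hoisted = dense<[2, 3]> : tensor<2xi32>
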
diff --git a/iree/compiler/ConstEval/JitGlobals.cpp b/iree/compiler/ConstEval/JitGlobals.cpp
index 7613a40..08c8998 100644
--- a/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/iree/compiler/ConstEval/JitGlobals.cpp
@@ -178,16 +178,25 @@
for (Operation &childOp : *innerModule.getBody()) {
auto globalOp = llvm::dyn_cast<IREE::Util::GlobalOp>(childOp);
if (!globalOp) continue;
- if (!globalOp.getInitialValueAttr()) {
- StringAttr funcSymbol = extractor.createAccessor(globalOp);
- uninitializedGlobals.emplace_back(funcSymbol, globalOp.sym_nameAttr());
+ if (globalOp.getInitialValueAttr()) continue;
+
+ // Only generate an accessor for types our runtime bridge knows how to
+ // handle.
+ Type type = globalOp.type();
+ if (!CompiledBinary::isSupportedResultType(type)) {
+ LLVM_DEBUG(dbgs() << "JitGlobals: unsupported global type " << type);
+ continue;
}
+
+ StringAttr funcSymbol = extractor.createAccessor(globalOp);
+ uninitializedGlobals.emplace_back(funcSymbol, globalOp.sym_nameAttr());
}
// Early exit without compiling if no entry-points (this is not just an
// optimization: the low-level compiler will fail on an empty module).
if (uninitializedGlobals.empty()) {
LLVM_DEBUG(dbgs() << "Not JIT'ing globals: no undefined globals found\n");
+ innerModule.erase();
return;
}
@@ -216,7 +225,7 @@
Location loc = targetGlobal->getLoc();
Attribute value =
- binary.invokeNullaryAsElements(loc, funcSymbol.strref());
+ binary.invokeNullaryAsAttribute(loc, funcSymbol.strref());
if (!value) {
return signalPassFailure();
}
diff --git a/iree/compiler/ConstEval/Runtime.cpp b/iree/compiler/ConstEval/Runtime.cpp
index 783ad9d..38d12d6 100644
--- a/iree/compiler/ConstEval/Runtime.cpp
+++ b/iree/compiler/ConstEval/Runtime.cpp
@@ -12,6 +12,7 @@
#include "iree/modules/hal/module.h"
#include "iree/vm/ref_cc.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
namespace mlir {
namespace iree_compiler {
@@ -38,6 +39,50 @@
return {};
}
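+// Converts the raw contents of a mapped buffer, interpreted as tensorType,
+// into a constant attribute, copying the data in the process.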
+static Attribute createAttributeFromRawData(Location loc,
+ RankedTensorType tensorType,
+ MutableArrayRef<char> rawBuffer) {
+ Type elementType = tensorType.getElementType();
+ // For numeric types that are byte-width aligned, we just use the raw buffer
+ // loading support of DenseElementsAttr.
+ if (elementType.isIntOrFloat() &&
+ elementType.getIntOrFloatBitWidth() % 8 == 0) {
+ bool detectedSplat = false;
+ if (DenseElementsAttr::isValidRawBuffer(tensorType, rawBuffer,
+ detectedSplat)) {
+ return DenseElementsAttr::getFromRawBuffer(tensorType, rawBuffer,
+ detectedSplat);
+ } else {
+ emitError(loc) << "mapped memory region was not valid for constructing "
+ "tensor of type "
+ << tensorType << " (length=" << rawBuffer.size() << ")";
+ return {};
+ }
+ }
+
+ // For i1, IREE (currently) returns these as 8-bit integer values and MLIR
+ // has a loader that accepts bool arrays (the raw buffer loader also
+ // supports them but bit-packed, which is not convenient for us). So, if
+ // sizeof(bool) == 1, we just bit-cast. Otherwise, we go through a temporary.
+ if (elementType.isInteger(1)) {
+ if (sizeof(bool) == 1) {
+ ArrayRef<bool> boolArray(reinterpret_cast<bool*>(rawBuffer.data()),
+ rawBuffer.size());
+ return DenseElementsAttr::get(tensorType, boolArray);
+ } else {
+ // Note: cannot use std::vector because it specializes bool in a way
+ // that is not compatible with ArrayRef.
+ llvm::SmallVector<bool> boolVector(rawBuffer.begin(), rawBuffer.end());
+ ArrayRef<bool> boolArray(boolVector.data(), boolVector.size());
+ return DenseElementsAttr::get(tensorType, boolArray);
+ }
+ }
+
+ emitError(loc) << "unhandled case when converting raw buffer of "
+ << tensorType << " to Attribute";
+ return {};
+}
+
} // namespace
CompiledBinary::CompiledBinary() {}
@@ -92,8 +137,8 @@
return success();
}
-Attribute CompiledBinary::invokeNullaryAsElements(Location loc,
- StringRef name) {
+Attribute CompiledBinary::invokeNullaryAsAttribute(Location loc,
+ StringRef name) {
Attribute result;
if (failed(invokeNullary(
loc, name, [&](iree_vm_list_t* outputs) -> LogicalResult {
@@ -112,6 +157,30 @@
return result;
}
+bool CompiledBinary::isSupportedResultType(Type type) {
+ // TODO: Not currently supported.
+ if (type.isa<Float16Type>() || type.isa<BFloat16Type>()) {
+ return false;
+ }
+
+ // Support scalar int and float types of byte-aligned widths.
+ if (type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0) {
+ return true;
+ }
+
+ // Special support for i1.
+ if (type.isInteger(1)) {
+ return true;
+ }
+
+ // Support tensors.
+ if (auto tt = type.dyn_cast<RankedTensorType>()) {
+ return isSupportedResultType(tt.getElementType());
+ }
+
+ return false;
+}
+
Attribute CompiledBinary::convertVariantToAttribute(
Location loc, iree_vm_variant_t& variant) {
auto context = loc.getContext();
@@ -161,28 +230,18 @@
iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bufferView);
// Map the memory and construct.
- DenseElementsAttr elementsAttr;
+ Attribute convertedAttr;
iree_hal_buffer_mapping_t mapping;
IREE_CHECK_OK(
iree_hal_buffer_map_range(buffer, IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, length, &mapping));
- ArrayRef<char> rawBufferArray(
+ MutableArrayRef<char> rawBufferArray(
reinterpret_cast<char*>(mapping.contents.data),
mapping.contents.data_length);
- bool detectedSplat = false;
- if (DenseElementsAttr::isValidRawBuffer(tensorType, rawBufferArray,
- detectedSplat)) {
- elementsAttr = DenseElementsAttr::getFromRawBuffer(
- tensorType, rawBufferArray, detectedSplat);
- }
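+ // The attribute construction copies the data into the MLIRContext, so the
+ // mapping can be safely released immediately afterwards.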
+ convertedAttr =
+ createAttributeFromRawData(loc, tensorType, rawBufferArray);
iree_hal_buffer_unmap_range(&mapping);
-
- if (!elementsAttr) {
- emitError(loc) << "mapped memory region was not valid for constructing "
- "tensor of type "
- << tensorType << " (length=" << length << ")";
- }
- return elementsAttr;
+ return convertedAttr;
} else {
iree_string_view_t typeName =
iree_vm_ref_type_name(variant.type.ref_type);
diff --git a/iree/compiler/ConstEval/Runtime.h b/iree/compiler/ConstEval/Runtime.h
index c3fc748..fe26fbe 100644
--- a/iree/compiler/ConstEval/Runtime.h
+++ b/iree/compiler/ConstEval/Runtime.h
@@ -34,7 +34,10 @@
// Invokes a nullary function and returns its (presumed single) result as
// an Attribute.
- Attribute invokeNullaryAsElements(Location loc, StringRef name);
+ Attribute invokeNullaryAsAttribute(Location loc, StringRef name);
+
+ // Whether the given type is supported in *AsAttribute methods.
+ static bool isSupportedResultType(Type type);
protected:
CompiledBinary();
diff --git a/iree/compiler/ConstEval/test/jit_globals.mlir b/iree/compiler/ConstEval/test/jit_globals.mlir
index a711a46..4962543 100644
--- a/iree/compiler/ConstEval/test/jit_globals.mlir
+++ b/iree/compiler/ConstEval/test/jit_globals.mlir
@@ -40,62 +40,184 @@
}
}
-// TODO: Crashes compiler.
-// COM-CHECK-LABEL: @eval_f16_tensor
-// module @eval_f16_tensor {
-// util.global private @hoisted : tensor<5x6xf16>
-// func @main() -> tensor<5x6xf16> {
-// %hoisted = util.global.load @hoisted : tensor<5x6xf16>
-// return %hoisted : tensor<5x6xf16>
-// }
-// util.initializer {
-// %cst = arith.constant dense<2.0e+2> : tensor<5x6xf16>
-// util.global.store %cst, @hoisted : tensor<5x6xf16>
-// util.initializer.return
-// }
-// }
+// -----
+// CHECK-LABEL: @eval_splat_detection
+// CHECK: util.global private @{{.*}} = dense<2> : tensor<2xi32>
+module @eval_splat_detection {
+ util.global private @hoisted : tensor<2xi32>
+ func @main() -> tensor<2xi32> {
+ %hoisted = util.global.load @hoisted : tensor<2xi32>
+ return %hoisted : tensor<2xi32>
+ }
+ util.initializer {
+ %cst = arith.constant dense<[2, 2]> : tensor<2xi32>
+ util.global.store %cst, @hoisted : tensor<2xi32>
+ util.initializer.return
+ }
+}
-// TODO: Error on 'hal.command_buffer.fill_buffer'
-// COM-CHECK-LABEL: @eval_f16_tensor
-// module @eval_bf16_tensor {
-// util.global private @hoisted : tensor<5x6xbf16>
-// func @main() -> tensor<5x6xbf16> {
-// %hoisted = util.global.load @hoisted : tensor<5x6xbf16>
-// return %hoisted : tensor<5x6xbf16>
-// }
-// util.initializer {
-// %cst = arith.constant dense<2.0e+2> : tensor<5x6xbf16>
-// util.global.store %cst, @hoisted : tensor<5x6xbf16>
-// util.initializer.return
-// }
-// }
-// TODO: Error on 'hal.command_buffer.fill_buffer'
-// COM-CHECK-LABEL: @eval_i4_tensor
-// module @eval_i4_tensor {
-// util.global private @hoisted : tensor<5x6xi4>
-// func @main() -> tensor<5x6xi4> {
-// %hoisted = util.global.load @hoisted : tensor<5x6xi4>
-// return %hoisted : tensor<5x6xi4>
-// }
-// util.initializer {
-// %cst = arith.constant dense<3> : tensor<5x6xi4>
-// util.global.store %cst, @hoisted : tensor<5x6xi4>
-// util.initializer.return
-// }
-// }
+// -----
+// CHECK-LABEL: @eval_f16_tensor
+// Not currently supported (initializer should remain)
+// CHECK: util.initializer
+module @eval_f16_tensor {
+ util.global private @hoisted : tensor<5x6xf16>
+ func @main() -> tensor<5x6xf16> {
+ %hoisted = util.global.load @hoisted : tensor<5x6xf16>
+ return %hoisted : tensor<5x6xf16>
+ }
+ util.initializer {
+ %cst = arith.constant dense<2.0e+2> : tensor<5x6xf16>
+ util.global.store %cst, @hoisted : tensor<5x6xf16>
+ util.initializer.return
+ }
+}
-// TODO: Error: mapped memory region was not valid for constructing tensor of type 'tensor<5x6xi1>' (length=30)
-// COM-CHECK-LABEL: @eval_i1_tensor
-// module @eval_i1_tensor {
-// util.global private @hoisted : tensor<5x6xi1>
-// func @main() -> tensor<5x6xi1> {
-// %hoisted = util.global.load @hoisted : tensor<5x6xi1>
-// return %hoisted : tensor<5x6xi1>
-// }
-// util.initializer {
-// %cst = arith.constant dense<1> : tensor<5x6xi1>
-// util.global.store %cst, @hoisted : tensor<5x6xi1>
-// util.initializer.return
-// }
-// }
+// -----
+// CHECK-LABEL: @eval_bf16_tensor
+// Not currently supported (initializer should remain)
+// CHECK: util.initializer
+module @eval_bf16_tensor {
+ util.global private @hoisted : tensor<5x6xbf16>
+ func @main() -> tensor<5x6xbf16> {
+ %hoisted = util.global.load @hoisted : tensor<5x6xbf16>
+ return %hoisted : tensor<5x6xbf16>
+ }
+ util.initializer {
+ %cst = arith.constant dense<2.0e+2> : tensor<5x6xbf16>
+ util.global.store %cst, @hoisted : tensor<5x6xbf16>
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_f32_tensor
+// CHECK: util.global private @{{.*}} = dense<[2.000000e+02, 3.200000e+03]> : tensor<2xf32>
+module @eval_f32_tensor {
+ util.global private @hoisted : tensor<2xf32>
+ func @main() -> tensor<2xf32> {
+ %hoisted = util.global.load @hoisted : tensor<2xf32>
+ return %hoisted : tensor<2xf32>
+ }
+ util.initializer {
+ %cst = arith.constant dense<[2.0e+2, 3.2e+3]> : tensor<2xf32>
+ util.global.store %cst, @hoisted : tensor<2xf32>
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_f64_tensor
+// CHECK: util.global private @{{.*}} = dense<[2.000000e+02, 3.200000e+03]> : tensor<2xf64>
+module @eval_f64_tensor {
+ util.global private @hoisted : tensor<2xf64>
+ func @main() -> tensor<2xf64> {
+ %hoisted = util.global.load @hoisted : tensor<2xf64>
+ return %hoisted : tensor<2xf64>
+ }
+ util.initializer {
+ %cst = arith.constant dense<[2.0e+2, 3.2e+3]> : tensor<2xf64>
+ util.global.store %cst, @hoisted : tensor<2xf64>
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i1_tensor
+// CHECK: util.global private @{{.*}} = dense<[false, true, false, true, true, false]> : tensor<6xi1>
+module @eval_i1_tensor {
+ util.global private @hoisted : tensor<6xi1>
+ func @main() -> tensor<6xi1> {
+ %hoisted = util.global.load @hoisted : tensor<6xi1>
+ return %hoisted : tensor<6xi1>
+ }
+ util.initializer {
+ // Note that i1 vs i8 handling is a bit odd at the level we are testing,
+ // so we materialize the i1 values by truncating an i8 tensor.
+ %cst = arith.constant dense<[0, 1, 0, 1, 1, 0]> : tensor<6xi8>
+ %casted = arith.trunci %cst : tensor<6xi8> to tensor<6xi1>
+ util.global.store %casted, @hoisted : tensor<6xi1>
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i4_tensor
+// CHECK: util.initializer
+module @eval_i4_tensor {
+ util.global private @hoisted : tensor<5x6xi4>
+ func @main() -> tensor<5x6xi4> {
+ %hoisted = util.global.load @hoisted : tensor<5x6xi4>
+ return %hoisted : tensor<5x6xi4>
+ }
+ util.initializer {
+ %cst = arith.constant dense<3> : tensor<5x6xi4>
+ util.global.store %cst, @hoisted : tensor<5x6xi4>
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i8_tensor
+// CHECK: util.global private @{{.*}} = dense<[2, 3]> : tensor<2xi8>
+module @eval_i8_tensor {
+ util.global private @hoisted : tensor<2xi8>
+ func @main() -> tensor<2xi8> {
+ %hoisted = util.global.load @hoisted : tensor<2xi8>
+ return %hoisted : tensor<2xi8>
+ }
+ util.initializer {
+ %cst = arith.constant dense<[2, 3]> : tensor<2xi8>
+ util.global.store %cst, @hoisted : tensor<2xi8>
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i16_tensor
+// CHECK: util.global private @{{.*}} = dense<[2, 3]> : tensor<2xi16>
+module @eval_i16_tensor {
+ util.global private @hoisted : tensor<2xi16>
+ func @main() -> tensor<2xi16> {
+ %hoisted = util.global.load @hoisted : tensor<2xi16>
+ return %hoisted : tensor<2xi16>
+ }
+ util.initializer {
+ %cst = arith.constant dense<[2, 3]> : tensor<2xi16>
+ util.global.store %cst, @hoisted : tensor<2xi16>
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i32_tensor
+// CHECK: util.global private @{{.*}} = dense<[2, 3]> : tensor<2xi32>
+module @eval_i32_tensor {
+ util.global private @hoisted : tensor<2xi32>
+ func @main() -> tensor<2xi32> {
+ %hoisted = util.global.load @hoisted : tensor<2xi32>
+ return %hoisted : tensor<2xi32>
+ }
+ util.initializer {
+ %cst = arith.constant dense<[2, 3]> : tensor<2xi32>
+ util.global.store %cst, @hoisted : tensor<2xi32>
+ util.initializer.return
+ }
+}
+
+// -----
+// CHECK-LABEL: @eval_i64_tensor
+// CHECK: util.global private @{{.*}} = dense<[2, 3]> : tensor<2xi64>
+module @eval_i64_tensor {
+ util.global private @hoisted : tensor<2xi64>
+ func @main() -> tensor<2xi64> {
+ %hoisted = util.global.load @hoisted : tensor<2xi64>
+ return %hoisted : tensor<2xi64>
+ }
+ util.initializer {
+ %cst = arith.constant dense<[2, 3]> : tensor<2xi64>
+ util.global.store %cst, @hoisted : tensor<2xi64>
+ util.initializer.return
+ }
+}
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index 8fe5f6a..7e22013 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -245,6 +245,7 @@
void registerIREEVMTranslationFlags() {
getBindingOptionsFromFlags();
getInputDialectOptionsFromFlags();
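+ // Also register the high-level optimization flags (which control
+ // const-eval) so they are available to the translation tools.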
+ getHighLevelOptimizationOptionsFromFlags();
}
void registerIREEVMTranslation() {