Rework consteval to use SerializableAttrInterface. * Removes the need for fill-based splat expansion, which also sidesteps having to support 8-byte fills. * Adds handling for the previously unsupported i1 conversion case. * Reworks serialization to take a Location for better error reporting.
diff --git a/compiler/src/iree/compiler/ConstEval/BUILD.bazel b/compiler/src/iree/compiler/ConstEval/BUILD.bazel index 1771f7c..92eaf5e 100644 --- a/compiler/src/iree/compiler/ConstEval/BUILD.bazel +++ b/compiler/src/iree/compiler/ConstEval/BUILD.bazel
@@ -73,6 +73,7 @@ "Runtime.h", ], deps = [ + "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/VM/Target/Bytecode", "//runtime/src/iree/hal", "//runtime/src/iree/hal/drivers/local_task/registration",
diff --git a/compiler/src/iree/compiler/ConstEval/CMakeLists.txt b/compiler/src/iree/compiler/ConstEval/CMakeLists.txt index e13efe8..b91bfe7 100644 --- a/compiler/src/iree/compiler/ConstEval/CMakeLists.txt +++ b/compiler/src/iree/compiler/ConstEval/CMakeLists.txt
@@ -65,6 +65,7 @@ DEPS LLVMSupport MLIRIR + iree::compiler::Dialect::Util::IR iree::compiler::Dialect::VM::Target::Bytecode iree::hal iree::hal::drivers::local_task::registration
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index 003fd64..52a1f40 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -326,6 +326,7 @@ options->executableOptions.targets.push_back(requestedTargetBackend); options->targetOptions.f32Extension = true; options->targetOptions.f64Extension = true; + options->targetOptions.truncateUnsupportedFloats = false; if (requestedTargetBackend == "vmvx" || !hasRequestedTargetBackend) { targetBackend = targetRegistry.getTargetBackend("vmvx"); } else {
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp index 885697e..c781e60 100644 --- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp +++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -6,6 +6,7 @@ #include "iree/compiler/ConstEval/Runtime.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h" #include "iree/hal/drivers/local_task/registration/driver_module.h" #include "mlir/IR/BuiltinOps.h" @@ -159,154 +160,80 @@ &outputs)); } -LogicalResult FunctionCall::importBufferForRead(Location loc, - const uint8_t *rawData, - iree_host_size_t length, - iree_hal_buffer_t **buffer) { - // TODO: Allow import when we have resources in the input where alignment - // can be guaranteed. - bool tryImport = false; - if (tryImport) { - iree_hal_buffer_params_t params; - std::memset(&params, 0, sizeof(params)); - params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; - iree_hal_external_buffer_t external_buffer; - std::memset(&external_buffer, 0, sizeof(external_buffer)); - external_buffer.type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION; - external_buffer.size = length; - external_buffer.handle.host_allocation.ptr = - const_cast<void *>(static_cast<const void *>(rawData)); - auto status = iree_hal_allocator_import_buffer( - binary.getAllocator(), params, &external_buffer, - /*release_callback=*/{nullptr, nullptr}, buffer); - if (iree_status_is_ok(status)) - return success(); - else if (!(iree_status_is_out_of_range(status) || - iree_status_is_unavailable(status))) - return handleRuntimeError(loc, status); +LogicalResult FunctionCall::addArgumentElementsAttr(Location loc, + ElementsAttr elementsAttr) { + auto ser = + llvm::dyn_cast<IREE::Util::SerializableAttrInterface>(elementsAttr); + if (!ser) { + return emitError(loc) << "internal error: ElementsAttr does not implement " + "SerializableAttrInterface"; + } + // Meta-data. 
+ ShapedType st = llvm::cast<ShapedType>(elementsAttr.getType()); + auto stShape = st.getShape(); + auto rank = static_cast<size_t>(st.getRank()); + iree_hal_dim_t *shape = + static_cast<iree_hal_dim_t *>(alloca(rank * sizeof(iree_hal_dim_t))); + for (size_t i = 0; i < rank; ++i) { + shape[i] = stShape[i]; + } + Type mlirElementType = st.getElementType(); + iree_hal_element_type_t elementType = IREE_HAL_ELEMENT_TYPE_NONE; + if (failed(convertToElementType(loc, mlirElementType, &elementType))) + return failure(); + + // Allocate buffer. + int64_t storageSize = ser.getStorageSize(); + if (storageSize < 0) { + return emitError(loc) << "unsupported serializable attribute: " + << elementsAttr; } - // Buffer is not compatible with import. Snapshot. - { - iree_hal_buffer_params_t params; - std::memset(&params, 0, sizeof(params)); - params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; - LLVM_DEBUG( - dbgs() - << "Cannot import consteval buffer. Falling back to snapshot.\n"); - return handleRuntimeError(loc, iree_hal_allocator_allocate_buffer( - binary.getAllocator(), params, length, - iree_const_byte_span_t{rawData, length}, - buffer)); - } -} - -LogicalResult FunctionCall::importBitwiseBoolI8BufferForRead( - Location loc, const uint8_t *rawDataBits, - iree_host_size_t rawDataLengthBytes, iree_hal_buffer_t **buffer) { + iree::vm::ref<iree_hal_buffer_t> buffer; iree_hal_buffer_params_t params; std::memset(&params, 0, sizeof(params)); - iree_host_size_t bufferLength = rawDataLengthBytes * 8; params.type = IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE; if (failed(handleRuntimeError( loc, iree_hal_allocator_allocate_buffer( - binary.getAllocator(), params, bufferLength, - iree_const_byte_span_t{nullptr, 0}, buffer)))) + binary.getAllocator(), params, storageSize, + iree_const_byte_span_t{nullptr, 0}, &buffer)))) return failure(); iree_hal_buffer_mapping_t mapping; if (failed(handleRuntimeError( loc, iree_hal_buffer_map_range( - *buffer, 
IREE_HAL_MAPPING_MODE_SCOPED, + buffer.get(), IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_WRITE, /*byte_offset=*/0, - /*byte_length=*/bufferLength, &mapping)))) + /*byte_length=*/storageSize, &mapping)))) return failure(); // Copy. - for (iree_host_size_t i = 0; i < rawDataLengthBytes; ++i) { - uint8_t bits = rawDataBits[i]; - mapping.contents.data[i * 8 + 0] = bits & 0x1; - mapping.contents.data[i * 8 + 1] = (bits & 0x2) >> 1; - mapping.contents.data[i * 8 + 2] = (bits & 0x4) >> 2; - mapping.contents.data[i * 8 + 3] = (bits & 0x8) >> 3; - mapping.contents.data[i * 8 + 4] = (bits & 0x10) >> 4; - mapping.contents.data[i * 8 + 5] = (bits & 0x20) >> 5; - mapping.contents.data[i * 8 + 6] = (bits & 0x40) >> 6; - mapping.contents.data[i * 8 + 7] = (bits & 0x80) >> 7; + LogicalResult copyResult = ser.serializeToBuffer( + loc, llvm::support::endian::system_endianness(), + ArrayRef<char>(reinterpret_cast<char *>(mapping.contents.data), + storageSize)); + + if (failed(handleRuntimeError(loc, iree_hal_buffer_unmap_range(&mapping))) || + failed(copyResult)) { + return failure(); } - return handleRuntimeError(loc, iree_hal_buffer_unmap_range(&mapping)); + // Construct buffer view. + iree::vm::ref<iree_hal_buffer_view_t> bv; + if (failed(handleRuntimeError(loc, iree_hal_buffer_view_create( + buffer.get(), rank, shape, elementType, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + iree_allocator_system(), &bv)))) + return failure(); + + return handleRuntimeError( + loc, iree_vm_list_push_ref_move(inputs.get(), std::move(bv))); } LogicalResult FunctionCall::addArgument(Location loc, Attribute attr) { - if (auto elementsAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) { - // Meta-data. 
- ArrayRef<char> data = elementsAttr.getRawData(); - ShapedType st = elementsAttr.getType(); - auto stShape = st.getShape(); - auto rank = static_cast<size_t>(st.getRank()); - iree_hal_dim_t *shape = - static_cast<iree_hal_dim_t *>(alloca(rank * sizeof(iree_hal_dim_t))); - for (size_t i = 0; i < rank; ++i) { - shape[i] = stShape[i]; - } - Type mlirElementType = st.getElementType(); - bool isI1 = mlirElementType == IntegerType::get(loc.getContext(), 1); - iree_hal_element_type_t elementType = IREE_HAL_ELEMENT_TYPE_NONE; - if (failed(convertToElementType(loc, mlirElementType, &elementType))) - return failure(); - - iree::vm::ref<iree_hal_buffer_t> buffer; - if (elementsAttr.isSplat()) { - // Handle splat. In this case, the data size is one element. - iree_device_size_t bufferSize = data.size() * st.getNumElements(); - iree_hal_buffer_params_t params; - std::memset(&params, 0, sizeof(params)); - params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; - if (failed(handleRuntimeError( - loc, iree_hal_allocator_allocate_buffer( - binary.getAllocator(), params, bufferSize, - iree_const_byte_span_t{nullptr, 0}, &buffer)))) - return failure(); - - if (failed(handleRuntimeError( - loc, iree_hal_buffer_map_fill( - buffer.get(), 0, bufferSize, - static_cast<const void *>(data.data()), data.size())))) - return failure(); - } else if (isI1) { - // Dense, non-splat i1. - // MLIR DenseElementsAttr made the interesting optimization choice to - // densely pack i1 as a bit-vector. It doesn't do this for any other - // sub-byte type, and it is aligned linearly (not row-wise), so is - // a complete special case. - // Since we map this to an 8bit bool on the IREE runtime side, we - // just do the best we can when allocating. - if (failed(importBitwiseBoolI8BufferForRead( - loc, reinterpret_cast<const uint8_t *>(data.data()), data.size(), - &buffer))) { - return failure(); - } - } else { - // Dense, non-splat. 
- if (failed(importBufferForRead( - loc, reinterpret_cast<const uint8_t *>(data.data()), data.size(), - &buffer))) - return failure(); - } - - // Construct buffer view. - iree::vm::ref<iree_hal_buffer_view_t> bv; - if (failed(handleRuntimeError( - loc, - iree_hal_buffer_view_create(buffer.get(), rank, shape, elementType, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - iree_allocator_system(), &bv)))) - return failure(); - - return handleRuntimeError( - loc, iree_vm_list_push_ref_move(inputs.get(), std::move(bv))); + if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(attr)) { + return addArgumentElementsAttr(loc, elementsAttr); } else if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr)) { iree_vm_value_t value; APInt apValue = integerAttr.getValue(); @@ -337,8 +264,9 @@ } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) { iree_vm_value_t value; APFloat apValue = floatAttr.getValue(); - // Note that there are many floating point semantics that LLVM knows about, - // but we restrict to only those that the VM natively supports here. + // Note that there are many floating point semantics that LLVM knows + // about, but we restrict to only those that the VM natively supports + // here. switch (APFloat::SemanticsToEnum(apValue.getSemantics())) { case APFloat::S_IEEEsingle: value = iree_vm_value_make_f32(apValue.convertToFloat());
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.h b/compiler/src/iree/compiler/ConstEval/Runtime.h index 5637e86..bd1d05b 100644 --- a/compiler/src/iree/compiler/ConstEval/Runtime.h +++ b/compiler/src/iree/compiler/ConstEval/Runtime.h
@@ -62,18 +62,9 @@ TypedAttr &outAttr); private: - // Imports or snapshots a raw host buffer, depending on whether import is - // possible. This should only be used when the MLIR and IREE layout - // agree. - LogicalResult importBufferForRead(Location loc, const uint8_t *rawData, - iree_host_size_t length, - iree_hal_buffer_t **buffer); - // Imports a bit vector of rawData into a byte buffer, expanding 1->8bit - // during import. - LogicalResult - importBitwiseBoolI8BufferForRead(Location loc, const uint8_t *rawDataBits, - iree_host_size_t rawDataLengthBytes, - iree_hal_buffer_t **buffer); + LogicalResult addArgumentElementsAttr(Location loc, + ElementsAttr elementsAttr); + CompiledBinary binary; iree::vm::ref<iree_vm_list_t> inputs; iree::vm::ref<iree_vm_list_t> outputs;
diff --git a/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir b/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir index e0e0360..c345764 100644 --- a/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir +++ b/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir
@@ -220,3 +220,20 @@ util.initializer.return } } + +// ----- +// Splat of an 8byte value ensures that large fills are possible. +// CHECK-LABEL: @eval_i64_tensor_splat +// CHECK: util.global private @{{.*}} = dense<2> : tensor<2xi64> +module @eval_i64_tensor_splat { + util.global private @hoisted : tensor<2xi64> + func.func @main() -> tensor<2xi64> { + %hoisted = util.global.load @hoisted : tensor<2xi64> + return %hoisted : tensor<2xi64> + } + util.initializer attributes {iree.compiler.consteval} { + %cst = arith.constant dense<2> : tensor<2xi64> + util.global.store %cst, @hoisted : tensor<2xi64> + util.initializer.return + } +}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp index 582298d..f545926 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
@@ -100,7 +100,8 @@ // Appends the raw bytes of |value| in the given endianness to |buffer|. // Non-byte-aligned types are rounded up to the next power of two byte-aligned // bit width (i1 -> i8, i4 -> i8, i17 -> i32, etc). -static LogicalResult serializeAPIntRawData(APInt value, uint64_t bitWidth, +static LogicalResult serializeAPIntRawData(Location loc, APInt value, + uint64_t bitWidth, llvm::support::endianness endian, SmallVectorImpl<char> &buffer) { // Round up to 8-bit aligned bytes. @@ -136,12 +137,14 @@ return success(); } default: - return failure(); + return emitError(loc) << "unhandled byte width in serializeAPIntRawData: " + << byteWidth; } } // Appends the raw bytes of |value| in the given endianness to |buffer|. -static LogicalResult serializeAPFloatRawData(APFloat value, size_t bitWidth, +static LogicalResult serializeAPFloatRawData(Location loc, APFloat value, + size_t bitWidth, llvm::support::endianness endian, SmallVectorImpl<char> &buffer) { buffer.resize(bitWidth / 8); @@ -171,30 +174,32 @@ return success(); } default: - return failure(); + return emitError(loc) << "unhandled bitWidth in serializeAPFloatRawData: " + << bitWidth; } } // Serializes |count| copies of |splatAttr| to |os|. // Significantly faster than the generic ElementsAttr path that needs to perform // conversion of the same splat value |count| times. -static LogicalResult serializeSplatValue(Attribute splatAttr, int64_t count, +static LogicalResult serializeSplatValue(Location loc, Attribute splatAttr, + int64_t count, llvm::support::endianness endian, llvm::raw_ostream &os) { // Get the encoded byte contents of the splat element. 
SmallVector<char> elementBuffer; if (auto attr = llvm::dyn_cast<SerializableAttrInterface>(splatAttr)) { - if (failed(attr.serializeToVector(endian, elementBuffer))) { + if (failed(attr.serializeToVector(loc, endian, elementBuffer))) { return failure(); } } else if (auto attr = llvm::dyn_cast<IntegerAttr>(splatAttr)) { - if (failed(serializeAPIntRawData(attr.getValue(), + if (failed(serializeAPIntRawData(loc, attr.getValue(), attr.getType().getIntOrFloatBitWidth(), endian, elementBuffer))) { return failure(); } } else if (auto attr = llvm::dyn_cast<FloatAttr>(splatAttr)) { - if (failed(serializeAPFloatRawData(attr.getValue(), + if (failed(serializeAPFloatRawData(loc, attr.getValue(), attr.getType().getIntOrFloatBitWidth(), endian, elementBuffer))) { return failure(); @@ -214,7 +219,8 @@ // Serializes the raw data of the given |elementsAttr| to |os|. // Assumes that the caller knows what they are doing; the raw data must be in // the expected endianness and be densely packed. -static LogicalResult serializeRawData(DenseElementsAttr elementsAttr, +static LogicalResult serializeRawData(Location loc, + DenseElementsAttr elementsAttr, llvm::raw_ostream &os) { auto rawData = elementsAttr.getRawData(); os.write(rawData.data(), rawData.size()); @@ -265,7 +271,7 @@ }; static LogicalResult -serializeSubByteIntegerElements(DenseIntElementsAttr attr, +serializeSubByteIntegerElements(Location loc, DenseIntElementsAttr attr, llvm::support::endianness endian, llvm::raw_ostream &os) { const unsigned logicalBitWidth = @@ -308,9 +314,8 @@ return success(); } default: - return emitError(UnknownLoc::get(attr.getContext())) - << "unhandled packed integer physical bit width " << physicalBitWidth - << " for type " << attr.getType(); + return emitError(loc) << "unhandled packed integer physical bit width " + << physicalBitWidth << " for type " << attr.getType(); } } @@ -340,10 +345,31 @@ return success(); } +// Expands 8-values per byte raw data from DenseIntElementsAttr to 0/1 byte +// 
values in the output. +static LogicalResult serializeBitIntegerValuesAsBytes(DenseIntElementsAttr attr, + llvm::raw_ostream &os) { + auto rawData = attr.getRawData(); + char bytes[8]; // Scratch: one packed input byte expands to 8 bytes. + for (size_t i = 0; i < rawData.size(); ++i) { + int32_t bits = rawData[i]; + bytes[0] = bits & 0x1; + bytes[1] = (bits & 0x2) >> 1; + bytes[2] = (bits & 0x4) >> 2; + bytes[3] = (bits & 0x8) >> 3; + bytes[4] = (bits & 0x10) >> 4; + bytes[5] = (bits & 0x20) >> 5; + bytes[6] = (bits & 0x40) >> 6; + bytes[7] = (bits & 0x80) >> 7; + os.write(bytes, sizeof(bytes)); // Flush once per input byte. + } + return success(); +} + // Performs slow generic serialization of all of the elements in |elementsAttr|. // Respects the target |endian| setting, performing byte swaps if required. static LogicalResult -serializeGenericElementData(DenseElementsAttr elementsAttr, +serializeGenericElementData(Location loc, DenseElementsAttr elementsAttr, llvm::support::endianness endian, llvm::raw_ostream &os) { if (auto attr = llvm::dyn_cast<DenseIntElementsAttr>(elementsAttr)) { @@ -351,8 +377,17 @@ // element type is not integer or floating-point. unsigned bitWidth = attr.getType().getElementTypeBitWidth(); switch (bitWidth) { + case 1: { + // NOTE: i1 is treated as i8 in a lot of places in MLIR/IREE and will need + // a larger cleanup to serialize as a sub-byte value like the others. + // In this one case, we know that DenseIntElementsAttr has been + // prematurely optimized to densely pack bit values ala std::vector<bool>. + // Further, it packs them linearly, regardless of shape, so we have to + // do a simple expansion. 
+ return serializeBitIntegerValuesAsBytes(attr, os); + } case 8: - return serializeRawData(attr, os); + return serializeRawData(loc, attr, os); case 16: return serializeGenericIntegerElements<uint16_t>(attr, endian, os); case 32: @@ -360,15 +395,13 @@ case 64: return serializeGenericIntegerElements<uint64_t>(attr, endian, os); default: - // NOTE: i1 is treated as i8 in a lot of places in MLIR/IREE and will need - // a larger cleanup to serialize as a sub-byte value. if (bitWidth != 1 && bitWidth < 64) { // Special case for bit-packing of sub-byte aligned types. // This could be extended to handle larger widths (i33, etc) but they // are rare today. - return serializeSubByteIntegerElements(attr, endian, os); + return serializeSubByteIntegerElements(loc, attr, endian, os); } - return emitError(UnknownLoc::get(elementsAttr.getContext())) + return emitError(loc) << "unhandled integer element bit width " << bitWidth << " for type " << elementsAttr.getType(); } @@ -387,13 +420,11 @@ case 64: return serializeGenericFloatElements<uint64_t>(attr, endian, os); default: - return emitError(UnknownLoc::get(elementsAttr.getContext())) - << "unhandled float element bit width " << bitWidth << " for type " - << elementsAttr.getType(); + return emitError(loc) << "unhandled float element bit width " << bitWidth + << " for type " << elementsAttr.getType(); } } - return emitError(UnknownLoc::get(elementsAttr.getContext())) - << "unhandled constant type " << elementsAttr.getType(); + return emitError(loc) << "unhandled constant type " << elementsAttr.getType(); } //===----------------------------------------------------------------------===// @@ -558,23 +589,25 @@ int64_t CompositeAttr::getStorageSize() const { return getTotalLength(); } -LogicalResult CompositeAttr::serializeToBuffer(llvm::support::endianness endian, +LogicalResult CompositeAttr::serializeToBuffer(Location loc, + llvm::support::endianness endian, ArrayRef<char> buffer) const { raw_inplace_ostream os(buffer); - return 
serializeToStream(endian, os); + return serializeToStream(loc, endian, os); } -LogicalResult CompositeAttr::serializeToStream(llvm::support::endianness endian, +LogicalResult CompositeAttr::serializeToStream(Location loc, + llvm::support::endianness endian, llvm::raw_ostream &os) const { for (auto valueAttr : getValues()) { auto serializableAttr = llvm::dyn_cast<SerializableAttrInterface>(valueAttr); if (!serializableAttr) { - llvm::errs() << "unable to serialize a non-serializable attribute: " - << valueAttr << "\n"; - return failure(); + return emitError(loc) + << "unable to serialize a non-serializable attribute: " + << valueAttr; } - if (failed(serializableAttr.serializeToStream(endian, os))) { + if (failed(serializableAttr.serializeToStream(loc, endian, os))) { return failure(); } } @@ -593,21 +626,21 @@ cast<ShapedType>(attr.getType()).getElementType()); } - LogicalResult serializeToVector(Attribute baseAttr, + LogicalResult serializeToVector(Attribute baseAttr, Location loc, llvm::support::endianness endian, SmallVectorImpl<char> &buffer) const { buffer.resize(getStorageSize(baseAttr)); - return serializeToBuffer(baseAttr, endian, buffer); + return serializeToBuffer(baseAttr, loc, endian, buffer); } - LogicalResult serializeToBuffer(Attribute baseAttr, + LogicalResult serializeToBuffer(Attribute baseAttr, Location loc, llvm::support::endianness endian, ArrayRef<char> buffer) const { raw_inplace_ostream os(buffer); - return serializeToStream(baseAttr, endian, os); + return serializeToStream(baseAttr, loc, endian, os); } - LogicalResult serializeToStream(Attribute baseAttr, + LogicalResult serializeToStream(Attribute baseAttr, Location loc, llvm::support::endianness endian, llvm::raw_ostream &os) const { // NOTE: not all ostream implementations handle this but for buffering ones @@ -617,7 +650,7 @@ auto elementsAttr = llvm::cast<DenseElementsAttr>(baseAttr); if (elementsAttr.isSplat()) { // Fast-path for splat (no need to convert the value a bunch). 
- return serializeSplatValue(elementsAttr.getSplatValue<Attribute>(), + return serializeSplatValue(loc, elementsAttr.getSplatValue<Attribute>(), elementsAttr.getNumElements(), endian, os); } @@ -625,10 +658,10 @@ // Fast-path for bulk data copies that don't require endianness handling. // This relies on DenseElementsAttr storing 8-bit values as 8-bit values; // other sized types are stored in an opaque format. - return serializeRawData(elementsAttr, os); + return serializeRawData(loc, elementsAttr, os); } else { // Slow-path that performs expensive conversion. - return serializeGenericElementData(elementsAttr, endian, os); + return serializeGenericElementData(loc, elementsAttr, endian, os); } } }; @@ -645,21 +678,21 @@ attr.getNumElements(), attr.getType().getElementType()); } - LogicalResult serializeToVector(Attribute baseAttr, + LogicalResult serializeToVector(Attribute baseAttr, Location loc, llvm::support::endianness endian, SmallVectorImpl<char> &buffer) const { buffer.resize(getStorageSize(baseAttr)); - return serializeToBuffer(baseAttr, endian, buffer); + return serializeToBuffer(baseAttr, loc, endian, buffer); } - LogicalResult serializeToBuffer(Attribute baseAttr, + LogicalResult serializeToBuffer(Attribute baseAttr, Location loc, llvm::support::endianness endian, ArrayRef<char> buffer) const { raw_inplace_ostream os(buffer); - return serializeToStream(baseAttr, endian, os); + return serializeToStream(baseAttr, loc, endian, os); } - LogicalResult serializeToStream(Attribute baseAttr, + LogicalResult serializeToStream(Attribute baseAttr, Location loc, llvm::support::endianness endian, llvm::raw_ostream &os) const { auto attr = llvm::cast<DenseResourceElementsAttr>(baseAttr); @@ -670,7 +703,7 @@ // results if executed but it can be useful when building reproducers. 
if (handle.getKey() == "__elided__") { if (!clZeroFillElidedAttrs) { - return mlir::emitError(UnknownLoc::get(baseAttr.getContext())) + return mlir::emitError(loc) << "elided attributes cannot be serialized; provide non-elided " "values or pass --iree-util-zero-fill-elided-attrs for " "testing and expect invalid execution results"; @@ -679,7 +712,7 @@ return success(); } - return mlir::emitError(UnknownLoc::get(baseAttr.getContext())) + return mlir::emitError(loc) << "DenseResourceElementsAttr not yet supported for serialization"; } }; @@ -694,21 +727,21 @@ return attr.getValue().size(); } - LogicalResult serializeToVector(Attribute baseAttr, + LogicalResult serializeToVector(Attribute baseAttr, Location loc, llvm::support::endianness endian, SmallVectorImpl<char> &buffer) const { buffer.resize(getStorageSize(baseAttr)); - return serializeToBuffer(baseAttr, endian, buffer); + return serializeToBuffer(baseAttr, loc, endian, buffer); } - LogicalResult serializeToBuffer(Attribute baseAttr, + LogicalResult serializeToBuffer(Attribute baseAttr, Location loc, llvm::support::endianness endian, ArrayRef<char> buffer) const { raw_inplace_ostream os(buffer); - return serializeToStream(baseAttr, endian, os); + return serializeToStream(baseAttr, loc, endian, os); } - LogicalResult serializeToStream(Attribute baseAttr, + LogicalResult serializeToStream(Attribute baseAttr, Location loc, llvm::support::endianness endian, llvm::raw_ostream &os) const { // NOTE: not all ostream implementations handle this but for buffering ones
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td index 2e67a46..be9bb48 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -1116,13 +1116,14 @@ /*retTy=*/"LogicalResult", /*methodName=*/"serializeToVector", /*args=*/(ins + "Location":$loc, "llvm::support::endianness":$endian, "SmallVectorImpl<char> &":$buffer ), /*methodBody=*/[{}], /*defaultImplementation=*/[{ buffer.resize($_attr.getStorageSize()); - return $_attr.serializeToBuffer(endian, buffer); + return $_attr.serializeToBuffer(loc, endian, buffer); }] >, InterfaceMethod< @@ -1133,6 +1134,7 @@ /*retTy=*/"LogicalResult", /*methodName=*/"serializeToBuffer", /*args=*/(ins + "Location":$loc, "llvm::support::endianness":$endian, "ArrayRef<char>":$buffer ) @@ -1145,13 +1147,14 @@ /*retTy=*/"LogicalResult", /*methodName=*/"serializeToStream", /*args=*/(ins + "Location":$loc, "llvm::support::endianness":$endian, "llvm::raw_ostream &":$os ), /*methodBody=*/[{}], /*defaultImplementation=*/[{ SmallVector<char> buffer; - if (failed($_attr.serializeToVector(endian, buffer))) { + if (failed($_attr.serializeToVector(loc, endian, buffer))) { return failure(); } os.write(reinterpret_cast<const char *>(buffer.data()), buffer.size());
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index 544bf11..b17729d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -116,7 +116,7 @@ // Serialize the constant into the reserved memory. if (failed(value.serializeToBuffer( - llvm::support::endianness::little, + loc, llvm::support::endianness::little, ArrayRef<char>(reinterpret_cast<char *>(bytePtr), static_cast<size_t>(totalSize))))) { mlir::emitError(loc) << "constant attribute failed to serialize: " @@ -652,6 +652,7 @@ actualSize >= kMaxEmbeddedDataSize); RodataRef rodataRef; + Location rodataLoc = rodataOp.getLoc(); rodataRef.rodataOp = rodataOp; rodataRef.alignment = rodataOp.getAlignment().value_or(kDefaultRodataAlignment); @@ -665,7 +666,7 @@ fileName, rodataRef.alignment, rodataRef.totalSize, [=](llvm::raw_ostream &os) { return rodataValue.serializeToStream( - llvm::support::endianness::little, os); + rodataLoc, llvm::support::endianness::little, os); }); } rodataRefs[rodataOp.getOrdinal()->getLimitedValue()] = rodataRef;
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp index 2645d33..66bec5b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -74,7 +74,8 @@ rodataOp.getValue()); assert(value && "expected a serializable rodata value"); SmallVector<char> byteBuffer; - if (failed(value.serializeToVector(llvm::support::endianness::little, + if (failed(value.serializeToVector(rodataOp.getLoc(), + llvm::support::endianness::little, byteBuffer))) { return rodataOp.emitError() << "error during serialization"; }