blob: beb1e1e0e1b38671b5255133bf7191ae2483bef2 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/Serialization/BytecodeWriter.h"
#include <algorithm>
#include "compiler/Utils/Macros.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
namespace iree_compiler {
// Writes |count| to the bytecode stream as a single uint8 element count.
//
// Only values in [0, UINT8_MAX] can be encoded. Anything else is reported on
// llvm::errs() and failure() is returned without writing any bytes. Negative
// counts are rejected explicitly: narrowing them to uint8_t would otherwise
// wrap around and serialize an unrelated, valid-looking count.
LogicalResult BytecodeWriter::WriteCount(int count) {
  if (count < 0) {
    llvm::errs() << "Negative item count: " << count
                 << "; only 0-UINT8_MAX are supported";
    return failure();
  }
  if (count > UINT8_MAX) {
    // TODO(benvanik): varints?
    llvm::errs() << "Too many items: " << count
                 << "; only 0-UINT8_MAX are supported";
    return failure();
  }
  return WriteUint8(static_cast<uint8_t>(count));
}
// Writes the iree::BuiltinType index matching |type| as a single uint8.
//
// Recognized types are 8/16/32/64-bit integers and f16/f32/f64. For any other
// type an error is emitted at an unknown location and the resulting failure is
// returned; nothing is written in that case.
LogicalResult BytecodeWriter::WriteTypeIndex(Type type) {
  // Common tail for every recognized type: the enum is serialized as one byte.
  auto writeIndex = [this](iree::BuiltinType builtinType) -> LogicalResult {
    return WriteUint8(static_cast<uint8_t>(builtinType));
  };

  // Integer types, narrowest first.
  if (type.isInteger(8)) return writeIndex(iree::BuiltinType::kI8);
  if (type.isInteger(16)) return writeIndex(iree::BuiltinType::kI16);
  if (type.isInteger(32)) return writeIndex(iree::BuiltinType::kI32);
  if (type.isInteger(64)) return writeIndex(iree::BuiltinType::kI64);

  // Floating point types.
  if (type.isF16()) return writeIndex(iree::BuiltinType::kF16);
  if (type.isF32()) return writeIndex(iree::BuiltinType::kF32);
  if (type.isF64()) return writeIndex(iree::BuiltinType::kF64);

  // TODO(benvanik): support unknown types as BuiltinType::kOpaque?
  return emitError(UnknownLoc::get(type.getContext()))
         << "Type " << type << " cannot be represented by a builtin type";
}
// Writes the ordinal of |function| as a uint32.
//
// The ordinal is read from the function's "iree.ordinal" integer attribute.
// If the attribute is missing an error is emitted on the function and the
// resulting failure is returned.
LogicalResult BytecodeWriter::WriteFunctionOrdinal(FuncOp function) {
  if (auto ordinalAttr = function.getAttrOfType<IntegerAttr>("iree.ordinal")) {
    // The attribute value is narrowed to 32 bits for serialization.
    return WriteUint32(static_cast<uint32_t>(ordinalAttr.getInt()));
  }
  return function.emitError() << "Ordinal not assigned to function";
}
// Writes the ordinal of the imported |function|.
LogicalResult BytecodeWriter::WriteImportOrdinal(FuncOp function) {
  // Imports currently share the encoding used for internal function ordinals;
  // a more compact encoding could probably be used here.
  LogicalResult result = WriteFunctionOrdinal(function);
  return result;
}
// Writes a constant of type |memRefType| whose contents are given by
// |baseAttr|.
//
// Layout, in order: element type index, rank (as a count), each dimension size
// as an int32, a one-byte iree::ConstantEncoding tag, then the attribute data.
// Splat attributes are tagged kSplat and only their splat value is written;
// everything else is tagged kDense and written through WriteAttributeData.
LogicalResult BytecodeWriter::WriteConstant(MemRefType memRefType,
                                            Attribute baseAttr) {
  // Every value is a memref, so only the element type has to be recorded.
  RETURN_IF_FAILURE(WriteTypeIndex(memRefType.getElementType()));

  // Shape: rank followed by the dimensions (scalars and similar cases could be
  // encoded more compactly).
  RETURN_IF_FAILURE(WriteCount(memRefType.getRank()));
  for (int dimIndex = 0; dimIndex < memRefType.getRank(); ++dimIndex) {
    RETURN_IF_FAILURE(WriteInt32(memRefType.getDimSize(dimIndex)));
  }

  // Choose the encoding tag once, then emit the matching payload.
  auto splatAttr = baseAttr.dyn_cast<SplatElementsAttr>();
  const iree::ConstantEncoding encoding = splatAttr
                                              ? iree::ConstantEncoding::kSplat
                                              : iree::ConstantEncoding::kDense;
  RETURN_IF_FAILURE(WriteUint8(static_cast<uint8_t>(encoding)));
  if (splatAttr) {
    return WriteAttributeData(splatAttr.getSplatValue());
  }
  return WriteAttributeData(baseAttr);
}
// Serializes the raw payload of |baseAttr| into the bytecode stream.
//
// Supported attribute kinds:
//   - BoolAttr: one byte, 1 or 0.
//   - IntegerAttr: index-typed values are narrowed to int32 and written as 4
//     bytes; otherwise 8/16/32/64-bit values are written as bitWidth/8 bytes.
//   - FloatAttr: 16/32/64-bit values, written as their bit pattern.
//   - SplatElementsAttr: only the splat value is written (recursively).
//   - DenseIntElementsAttr (8/16/32/64-bit) and DenseFPElementsAttr
//     (16/32/64-bit): elements are written back-to-back with no header; an
//     attribute with zero elements writes nothing and succeeds.
//
// Element bytes are copied straight out of APInt's word storage
// (getRawData()). NOTE(review): this presumably assumes a little-endian host
// so that the low-order bytes come first - confirm.
//
// Any other attribute kind, or an unsupported bit width, emits an error at an
// unknown location and returns failure. The dense paths validate the element
// width in the same way as the scalar paths, so that sub-byte or
// non-power-of-two widths are diagnosed instead of failing silently or being
// truncated to bitWidth/8 bytes per element.
LogicalResult BytecodeWriter::WriteAttributeData(Attribute baseAttr) {
  if (auto attr = baseAttr.dyn_cast<BoolAttr>()) {
    return WriteUint8(attr.getValue() ? 1 : 0);
  } else if (auto attr = baseAttr.dyn_cast<IntegerAttr>()) {
    if (attr.getType().isIndex()) {
      // Index values are always serialized as 32-bit integers.
      int32_t value = static_cast<int32_t>(attr.getInt());
      return WriteBytes(&value, 4);
    } else {
      int bitWidth = attr.getValue().getBitWidth();
      switch (bitWidth) {
        case 8:
        case 16:
        case 32:
        case 64:
          return WriteBytes(attr.getValue().getRawData(), bitWidth / 8);
        default:
          return emitError(UnknownLoc::get(baseAttr.getContext()))
                 << "Bit width for integers must be one of 8,16,32,64; others "
                    "not implemented: "
                 << bitWidth;
      }
    }
  } else if (auto attr = baseAttr.dyn_cast<FloatAttr>()) {
    int bitWidth = attr.getType().getIntOrFloatBitWidth();
    auto bitcastValue = attr.getValue().bitcastToAPInt();
    switch (bitWidth) {
      case 16:
      case 32:
      case 64:
        return WriteBytes(bitcastValue.getRawData(), bitWidth / 8);
      default:
        return emitError(UnknownLoc::get(baseAttr.getContext()))
               << "Bit width for floats must be one of 16,32,64; others "
                  "not implemented: "
               << bitWidth;
    }
  } else if (auto attr = baseAttr.dyn_cast<StringAttr>()) {
    // TODO(benvanik): other attribute encodings.
  } else if (auto attr = baseAttr.dyn_cast<ArrayAttr>()) {
    // TODO(benvanik): other attribute encodings.
  } else if (auto attr = baseAttr.dyn_cast<AffineMapAttr>()) {
    // TODO(benvanik): other attribute encodings.
  } else if (auto attr = baseAttr.dyn_cast<IntegerSetAttr>()) {
    // TODO(benvanik): other attribute encodings.
  } else if (auto attr = baseAttr.dyn_cast<TypeAttr>()) {
    // TODO(benvanik): other attribute encodings.
  } else if (auto attr = baseAttr.dyn_cast<SymbolRefAttr>()) {
    // TODO(benvanik): other attribute encodings.
  } else if (auto attr = baseAttr.dyn_cast<SplatElementsAttr>()) {
    return WriteAttributeData(attr.getSplatValue());
  } else if (auto attr = baseAttr.dyn_cast<DenseIntElementsAttr>()) {
    const int64_t elementCount = attr.getType().getNumElements();
    if (elementCount == 0) {
      return success();
    }
    int bitWidth = attr.getType().getElementTypeBitWidth();
    switch (bitWidth) {
      case 8:
      case 16:
      case 32:
      case 64:
        break;
      default:
        // Without this check bitWidth / 8 would be 0 for sub-byte elements
        // (silent failure below) or would drop high bits for odd widths.
        return emitError(UnknownLoc::get(baseAttr.getContext()))
               << "Bit width for dense integer elements must be one of "
                  "8,16,32,64; others not implemented: "
               << bitWidth;
    }
    const size_t byteWidth = static_cast<size_t>(bitWidth) / 8;
    // Computed in size_t so large tensors do not overflow int arithmetic.
    auto dst = ReserveBytes(static_cast<size_t>(elementCount) * byteWidth);
    if (dst.empty()) return failure();
    uint8_t *dstPtr = dst.data();
    for (auto element : attr) {
      assert(element.getBitWidth() == static_cast<unsigned>(bitWidth));
      std::memcpy(dstPtr, element.getRawData(), byteWidth);
      dstPtr += byteWidth;
    }
    return success();
  } else if (auto attr = baseAttr.dyn_cast<DenseFPElementsAttr>()) {
    const int64_t elementCount = attr.getType().getNumElements();
    if (elementCount == 0) {
      return success();
    }
    int bitWidth = attr.getType().getElementTypeBitWidth();
    switch (bitWidth) {
      case 16:
      case 32:
      case 64:
        break;
      default:
        return emitError(UnknownLoc::get(baseAttr.getContext()))
               << "Bit width for dense float elements must be one of "
                  "16,32,64; others not implemented: "
               << bitWidth;
    }
    const size_t byteWidth = static_cast<size_t>(bitWidth) / 8;
    auto dst = ReserveBytes(static_cast<size_t>(elementCount) * byteWidth);
    if (dst.empty()) return failure();
    uint8_t *dstPtr = dst.data();
    for (auto element : attr) {
      auto bitcastValue = element.bitcastToAPInt();
      assert(bitcastValue.getBitWidth() == static_cast<unsigned>(bitWidth));
      // Copy exactly the stride that was reserved for each element.
      std::memcpy(dstPtr, bitcastValue.getRawData(), byteWidth);
      dstPtr += byteWidth;
    }
    return success();
  } else if (auto attr = baseAttr.dyn_cast<DenseElementsAttr>()) {
    // TODO(benvanik): other attribute encodings.
  } else if (auto attr = baseAttr.dyn_cast<OpaqueElementsAttr>()) {
    // TODO(benvanik): other attribute encodings.
  } else if (auto attr = baseAttr.dyn_cast<SparseElementsAttr>()) {
    // TODO(benvanik): other attribute encodings.
  }
  // Reached by every recognized-but-unimplemented kind above as well as by
  // kinds that are not listed at all.
  return emitError(UnknownLoc::get(baseAttr.getContext()))
         << "Serializer for attribute kind "
         << static_cast<int>(baseAttr.getKind()) << " not implemented";
}
// Returns the local ordinal assigned to |value|, assigning the next free one
// (the current size of localMap_) if |value| has not been seen before.
//
// Ordinals above UINT16_MAX cannot be encoded: in that case an error is
// emitted and llvm::None is returned. The mapping is recorded before the range
// check, so a value that overflows stays in localMap_.
Optional<int> BytecodeWriter::LookupLocalOrdinal(Value *value) {
  int ordinal = 0;
  auto existing = localMap_.find(value);
  if (existing == localMap_.end()) {
    // First sighting: hand out the next sequential ordinal.
    ordinal = localMap_.size();
    localMap_.insert({value, ordinal});
  } else {
    ordinal = existing->second;
  }

  if (ordinal <= UINT16_MAX) {
    return ordinal;
  }

  // TODO(benvanik): varints?
  emitError(UnknownLoc::get(value->getContext()))
      << "Too many ordinals: " << ordinal
      << "; only 0-UINT16_MAX are supported";
  return llvm::None;
}
// Ensures |value| has a local ordinal without writing anything to the stream.
// Fails when LookupLocalOrdinal cannot provide an ordinal for it.
LogicalResult BytecodeWriter::PrepareLocal(Value *value) {
  return LookupLocalOrdinal(value).hasValue() ? success() : failure();
}
// Writes the local ordinal of |value| as a uint16, assigning an ordinal first
// if the value does not have one yet.
//
// Returns failure when LookupLocalOrdinal yields no ordinal. That function
// already emits a diagnostic and returns llvm::None for ordinals above
// UINT16_MAX, so any ordinal obtained here is known to fit in 16 bits and no
// second range check is needed.
LogicalResult BytecodeWriter::WriteLocal(Value *value) {
  auto ordinal = LookupLocalOrdinal(value);
  if (!ordinal.hasValue()) {
    return failure();
  }
  return WriteUint16(static_cast<uint16_t>(ordinal.getValue()));
}
// Writes a counted list of locals for an operand range: the number of values
// (via WriteCount) followed by the ordinal of each value in iteration order.
// Stops and propagates failure as soon as any write fails.
LogicalResult BytecodeWriter::WriteLocals(
    llvm::iterator_range<Operation::operand_iterator> values) {
  const int valueCount = std::distance(values.begin(), values.end());
  RETURN_IF_FAILURE(WriteCount(valueCount));
  for (auto it = values.begin(), end = values.end(); it != end; ++it) {
    RETURN_IF_FAILURE(WriteLocal(*it));
  }
  return success();
}
// Writes a counted list of locals for a result range: the number of values
// (via WriteCount) followed by the ordinal of each value in iteration order.
// Stops and propagates failure as soon as any write fails.
LogicalResult BytecodeWriter::WriteLocals(
    llvm::iterator_range<Operation::result_iterator> values) {
  const int valueCount = std::distance(values.begin(), values.end());
  RETURN_IF_FAILURE(WriteCount(valueCount));
  for (auto it = values.begin(), end = values.end(); it != end; ++it) {
    RETURN_IF_FAILURE(WriteLocal(*it));
  }
  return success();
}
// Grows the bytecode buffer by |dataLength| bytes and returns a mutable view
// of exactly the newly added region, which starts at the previous end of the
// buffer.
//
// The offset is kept as size_t so it is not narrowed through int.
// NOTE(review): the returned view points into bytecode_ and is presumably
// invalidated by any later growth of the buffer - confirm against the
// container type declared in the header.
MutableArrayRef<uint8_t> BytecodeWriter::ReserveBytes(size_t dataLength) {
  const size_t offset = bytecode_.size();
  bytecode_.resize(offset + dataLength);
  return MutableArrayRef<uint8_t>(
      reinterpret_cast<uint8_t *>(bytecode_.data()) + offset, dataLength);
}
// Appends |dataLength| bytes read from |data| to the bytecode stream.
// Fails if the region returned by ReserveBytes does not have the requested
// size.
LogicalResult BytecodeWriter::WriteBytes(const void *data, size_t dataLength) {
  MutableArrayRef<uint8_t> reserved = ReserveBytes(dataLength);
  if (reserved.size() == dataLength) {
    std::memcpy(reserved.data(), data, reserved.size());
    return success();
  }
  return failure();
}
// Appends |value| to the bytecode stream as one byte.
LogicalResult BytecodeWriter::WriteUint8(uint8_t value) {
  const void *valueBytes = &value;
  return WriteBytes(valueBytes, sizeof(uint8_t));
}
// Appends the in-memory representation of |value| (2 bytes) to the bytecode
// stream.
LogicalResult BytecodeWriter::WriteUint16(uint16_t value) {
  const void *valueBytes = &value;
  return WriteBytes(valueBytes, sizeof(uint16_t));
}
// Appends the in-memory representation of |value| (4 bytes) to the bytecode
// stream.
LogicalResult BytecodeWriter::WriteInt32(int32_t value) {
  const void *valueBytes = &value;
  return WriteBytes(valueBytes, sizeof(int32_t));
}
// Appends the in-memory representation of |value| (4 bytes) to the bytecode
// stream.
LogicalResult BytecodeWriter::WriteUint32(uint32_t value) {
  const void *valueBytes = &value;
  return WriteBytes(valueBytes, sizeof(uint32_t));
}
// Writes |attr| as a counted list of int32 values: the element count (via
// WriteCount) followed by every element read as int32_t, in iteration order.
LogicalResult BytecodeWriter::WriteElementsAttrInt32(ElementsAttr attr) {
  // The count goes through WriteCount, so its single-byte limit applies.
  RETURN_IF_FAILURE(WriteCount(attr.getType().getNumElements()));
  for (auto element : attr.getValues<int32_t>()) {
    RETURN_IF_FAILURE(WriteInt32(element));
  }
  return success();
}
// Writes the shape of |type|: its rank (via WriteCount) followed by each
// dimension size as an int32.
LogicalResult BytecodeWriter::WriteShapePieces(const ShapedType &type) {
  RETURN_IF_FAILURE(WriteCount(type.getRank()));
  auto dims = type.getShape();
  for (auto dimIt = dims.begin(); dimIt != dims.end(); ++dimIt) {
    // Dimension sizes are stored as int64_t and narrowed to int32 here.
    const int64_t dim = *dimIt;
    RETURN_IF_FAILURE(WriteInt32(dim));
  }
  return success();
}
// Writes shape pieces held in an elements attribute; this simply forwards to
// the counted int32 list encoding of WriteElementsAttrInt32.
LogicalResult BytecodeWriter::WriteShapePieces(ElementsAttr pieces) {
  LogicalResult result = WriteElementsAttrInt32(pieces);
  return result;
}
// Records the current end of the bytecode buffer as the offset of |block|,
// replacing any offset recorded for it earlier. Always succeeds.
LogicalResult BytecodeWriter::MarkBlockOffset(Block *block) {
  const auto currentOffset = bytecode_.size();
  blockOffsets_[block] = currentOffset;
  return success();
}
// Reserves an int32-sized slot in the bytecode for the offset of
// |targetBlock| and queues a fixup for it; the slot is filled in later by
// FixupOffsets. Always succeeds.
LogicalResult BytecodeWriter::WriteBlockOffset(Block *targetBlock) {
  // The slot begins at the current end of the buffer.
  const auto slotOffset = bytecode_.size();
  blockOffsetFixups_.push_back({targetBlock, slotOffset});
  bytecode_.resize(slotOffset + sizeof(int32_t));
  return success();
}
// Patches every slot queued by WriteBlockOffset with the offset recorded for
// its target block by MarkBlockOffset, then clears the fixup list.
//
// Fails, after printing to llvm::errs(), at the first fixup whose target
// block has no recorded offset; slots patched before that point stay patched
// and the fixup list is left untouched.
LogicalResult BytecodeWriter::FixupOffsets() {
  for (const auto &fixup : blockOffsetFixups_) {
    auto it = blockOffsets_.find(fixup.first);
    if (it == blockOffsets_.end()) {
      llvm::errs() << "Block offset not found: " << fixup.first;
      return failure();
    }
    // Convert to the serialized width first. Copying sizeof(int32_t) bytes
    // straight from the map entry would depend on the width and byte order of
    // whatever integer type the map stores.
    const int32_t blockOffset = static_cast<int32_t>(it->second);
    std::memcpy(bytecode_.data() + fixup.second, &blockOffset,
                sizeof(blockOffset));
  }
  blockOffsetFixups_.clear();
  return success();
}
// Clears the local ordinal map and hands the accumulated bytecode to the
// caller by moving it out of the writer.
std::vector<uint8_t> BytecodeWriter::Finish() {
  localMap_.clear();
  // bytecode_ is a member, so it has to be moved explicitly.
  std::vector<uint8_t> finished(std::move(bytecode_));
  return finished;
}
} // namespace iree_compiler
} // namespace mlir