| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h" |
| |
| #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" |
| #include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h" |
| #include "iree/compiler/Dialect/VM/IR/VMDialect.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Diagnostics.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace VM { |
| |
| namespace { |
| |
| // v0 bytecode spec. This is in extreme flux and not guaranteed to be a stable |
| // representation. Always generate this from source in tooling and never check |
| // in any emitted files! |
| class V0BytecodeEncoder : public BytecodeEncoder { |
| public: |
| V0BytecodeEncoder(llvm::DenseMap<Type, int> *typeTable, |
| RegisterAllocation *registerAllocation) |
| : typeTable_(typeTable), registerAllocation_(registerAllocation) {} |
| ~V0BytecodeEncoder() = default; |
| |
| LogicalResult beginBlock(Block *block) override { |
| blockOffsets_[block] = bytecode_.size(); |
| return success(); |
| } |
| |
| LogicalResult endBlock(Block *block) override { return success(); } |
| |
| LogicalResult beginOp(Operation *op) override { |
| currentOp_ = op; |
| // TODO(benvanik): encode source information (start). |
| return success(); |
| } |
| |
| LogicalResult endOp(Operation *op) override { |
| // TODO(benvanik): encode source information (end). |
| currentOp_ = nullptr; |
| return success(); |
| } |
| |
| LogicalResult encodeI8(int value) override { return writeUint8(value); } |
| |
| LogicalResult encodeOpcode(StringRef name, int opcode) override { |
| return writeUint8(opcode); |
| } |
| |
| LogicalResult encodeSymbolOrdinal(SymbolTable &syms, |
| StringRef name) override { |
| auto *symbolOp = syms.lookup(name); |
| if (!symbolOp) { |
| return currentOp_->emitOpError() << "target symbol not found: " << name; |
| } |
| auto ordinalAttr = symbolOp->getAttrOfType<IntegerAttr>("ordinal"); |
| if (!ordinalAttr) { |
| return symbolOp->emitOpError() << "missing ordinal"; |
| } |
| int32_t ordinal = ordinalAttr.getInt(); |
| if (isa<IREE::VM::ImportOp>(symbolOp)) { |
| // Imported functions have their MSB set. |
| ordinal |= 0x80000000u; |
| } |
| return writeInt32(ordinal); |
| } |
| |
| LogicalResult encodeType(Value value) override { |
| auto refPtrType = value.getType().dyn_cast<IREE::RefPtrType>(); |
| if (!refPtrType) { |
| return currentOp_->emitOpError() |
| << "type " << value.getType() |
| << " is not supported as a serialized type kind"; |
| } |
| int typeOrdinal = typeTable_->lookup(refPtrType.getObjectType()); |
| return writeUint32(typeOrdinal); |
| } |
| |
| LogicalResult encodeIntAttr(IntegerAttr value) override { |
| auto attr = value.cast<IntegerAttr>(); |
| int bitWidth = attr.getType().getIntOrFloatBitWidth(); |
| uint64_t limitedValue = attr.getValue().extractBitsAsZExtValue(bitWidth, 0); |
| switch (bitWidth) { |
| case 8: |
| return writeUint8(static_cast<uint8_t>(limitedValue)); |
| case 16: |
| return writeUint16(static_cast<uint16_t>(limitedValue)); |
| case 32: |
| return writeUint32(static_cast<uint32_t>(limitedValue)); |
| default: |
| return currentOp_->emitOpError() |
| << "attribute of bitwidth " << bitWidth << " not supported"; |
| } |
| } |
| |
| LogicalResult encodeIntArrayAttr(DenseIntElementsAttr value) override { |
| if (value.getNumElements() > UINT8_MAX || |
| failed(writeUint8(value.getNumElements()))) { |
| return currentOp_->emitOpError() << "integer array size out of bounds"; |
| } |
| for (auto el : value.getAttributeValues()) { |
| if (failed(encodeIntAttr(el.cast<IntegerAttr>()))) { |
| return currentOp_->emitOpError() << "failed to encode element " << el; |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult encodeStrAttr(StringAttr value) override { |
| auto stringValue = value.getValue(); |
| if (stringValue.size() > UINT16_MAX) { |
| return currentOp_->emitOpError() |
| << "string attribute too large for 16-bit p-string (needs " |
| << stringValue.size() << " bytes)"; |
| } |
| return failure(failed(writeUint16(stringValue.size())) || |
| failed(writeBytes(stringValue.data(), stringValue.size()))); |
| } |
| |
| LogicalResult encodeBranch(Block *targetBlock, |
| Operation::operand_range operands, |
| int successorIndex) override { |
| // Reserve space for the block offset. It will get fixed up when we are all |
| // done and know all of the block offsets. |
| blockOffsetFixups_.push_back({targetBlock, bytecode_.size()}); |
| bytecode_.resize(bytecode_.size() + sizeof(int32_t)); |
| |
| // Compute required remappings - we only need to emit them when the source |
| // and dest registers differ. Hopefully the allocator did a good job and |
| // this list is small :) |
| auto srcDstRegs = registerAllocation_->remapSuccessorRegisters( |
| currentOp_, successorIndex); |
| writeUint8(srcDstRegs.size()); |
| for (auto srcDstReg : srcDstRegs) { |
| if (failed(writeUint8(srcDstReg.first)) || |
| failed(writeUint8(srcDstReg.second))) { |
| return failure(); |
| } |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult encodeOperand(Value value, int ordinal) override { |
| uint8_t reg = |
| registerAllocation_->mapUseToRegister(value, currentOp_, ordinal); |
| return writeUint8(reg); |
| } |
| |
| LogicalResult encodeOperands(Operation::operand_range values) override { |
| writeUint8(std::distance(values.begin(), values.end())); |
| for (auto it : llvm::enumerate(values)) { |
| uint8_t reg = registerAllocation_->mapUseToRegister( |
| it.value(), currentOp_, it.index()); |
| if (failed(writeUint8(reg))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult encodeResult(Value value) override { |
| uint8_t reg = registerAllocation_->mapUseToRegister(value, currentOp_, 0); |
| return writeUint8(reg); |
| } |
| |
| LogicalResult encodeResults(Operation::result_range values) override { |
| writeUint8(std::distance(values.begin(), values.end())); |
| for (auto value : values) { |
| uint8_t reg = registerAllocation_->mapToRegister(value); |
| if (failed(writeUint8(reg))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| Optional<std::vector<uint8_t>> finish() { |
| if (failed(fixupOffsets())) { |
| return llvm::None; |
| } |
| return std::move(bytecode_); |
| } |
| |
| private: |
| // TODO(benvanik): replace this with something not using an ever-expanding |
| // vector. I'm sure LLVM has something. |
| |
| MutableArrayRef<uint8_t> reserveBytes(size_t dataLength) { |
| int offset = bytecode_.size(); |
| bytecode_.resize(offset + dataLength); |
| return MutableArrayRef<uint8_t>( |
| reinterpret_cast<uint8_t *>(bytecode_.data()) + offset, dataLength); |
| } |
| |
| LogicalResult writeBytes(const void *data, size_t dataLength) { |
| auto dst = reserveBytes(dataLength); |
| if (dataLength != dst.size()) { |
| return failure(); |
| } |
| std::memcpy(dst.data(), data, dst.size()); |
| return success(); |
| } |
| |
| LogicalResult writeUint8(uint8_t value) { |
| return writeBytes(&value, sizeof(value)); |
| } |
| |
| LogicalResult writeUint16(uint16_t value) { |
| return writeBytes(&value, sizeof(value)); |
| } |
| |
| LogicalResult writeInt32(int32_t value) { |
| return writeBytes(&value, sizeof(value)); |
| } |
| |
| LogicalResult writeUint32(uint32_t value) { |
| return writeBytes(&value, sizeof(value)); |
| } |
| |
| LogicalResult fixupOffsets() { |
| for (const auto &fixup : blockOffsetFixups_) { |
| auto blockOffset = blockOffsets_.find(fixup.first); |
| if (blockOffset == blockOffsets_.end()) { |
| emitError(fixup.first->front().getLoc()) << "Block offset not found"; |
| return failure(); |
| } |
| std::memcpy(bytecode_.data() + fixup.second, &blockOffset->second, |
| sizeof(int32_t)); |
| } |
| blockOffsetFixups_.clear(); |
| return success(); |
| } |
| |
| llvm::DenseMap<Type, int> *typeTable_; |
| RegisterAllocation *registerAllocation_; |
| |
| Operation *currentOp_ = nullptr; |
| |
| std::vector<uint8_t> bytecode_; |
| llvm::DenseMap<Block *, size_t> blockOffsets_; |
| std::vector<std::pair<Block *, size_t>> blockOffsetFixups_; |
| }; |
| |
| } // namespace |
| |
| // static |
| Optional<EncodedBytecodeFunction> BytecodeEncoder::encodeFunction( |
| IREE::VM::FuncOp funcOp, llvm::DenseMap<Type, int> &typeTable, |
| SymbolTable &symbolTable) { |
| EncodedBytecodeFunction result; |
| |
| // Perform register allocation first so that we can quickly lookup values as |
| // we encode the operands/results. |
| RegisterAllocation registerAllocation; |
| if (failed(registerAllocation.recalculate(funcOp))) { |
| funcOp.emitError() << "register allocation failed"; |
| return llvm::None; |
| } |
| result.i32RegisterCount = registerAllocation.getMaxI32RegisterOrdinal() + 1; |
| result.refRegisterCount = registerAllocation.getMaxRefRegisterOrdinal() + 1; |
| |
| V0BytecodeEncoder encoder(&typeTable, ®isterAllocation); |
| for (auto &block : funcOp.getBlocks()) { |
| if (failed(encoder.beginBlock(&block))) { |
| funcOp.emitError() << "failed to begin block"; |
| return llvm::None; |
| } |
| |
| for (auto &op : block.getOperations()) { |
| auto *serializableOp = |
| op.getAbstractOperation()->getInterface<IREE::VM::VMSerializableOp>(); |
| if (!serializableOp) { |
| op.emitOpError() << "is not serializable"; |
| return llvm::None; |
| } |
| if (failed(encoder.beginOp(&op)) || |
| failed(serializableOp->encode(&op, symbolTable, encoder)) || |
| failed(encoder.endOp(&op))) { |
| op.emitOpError() << "failed to encode"; |
| return llvm::None; |
| } |
| } |
| |
| if (failed(encoder.endBlock(&block))) { |
| funcOp.emitError() << "failed to end block"; |
| return llvm::None; |
| } |
| } |
| |
| auto bytecodeData = encoder.finish(); |
| if (!bytecodeData.hasValue()) { |
| funcOp.emitError() << "failed to fixup and finish encoding"; |
| return llvm::None; |
| } |
| result.bytecodeData = bytecodeData.getValue(); |
| return result; |
| } |
| |
| } // namespace VM |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |