blob: 23eda8833c4146c3eb3dc8ee42bf29657d1baf96 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/Serialization/VMFunctionBuilder.h"

#include <utility>

#include "flatbuffers/flatbuffers.h"
#include "compiler/IR/Dialect.h"
#include "compiler/IR/Types.h"
#include "compiler/Serialization/BytecodeTables.h"
#include "compiler/Utils/Macros.h"
#include "schemas/bytecode/bytecode_v0.h"
#include "schemas/type_def_generated.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Module.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Serializes |op| using the generic bytecode encoding of its opcode.
//
// The opcode is looked up by the op name with the dialect prefix stripped
// (sequencer opcodes for `iree_ll_seq`, interpreter opcodes for
// `iree_ll_interp` and the core IREE dialect). Each entry of the opcode's
// operand encoding list is then written in order until kNone (or the end of
// the list) is reached:
//   - input/output slots consume the next operand,
//   - result slots consume the next result,
//   - variadic slots write a count followed by all remaining operands/results.
// Encodings that need op-specific data (constants, ordinals, predicates, ...)
// are rejected; such ops presumably require a registered custom writer.
// |block| is currently unused.
LogicalResult WriteGenericIreeOp(Block *block, Operation *op,
                                 BytecodeWriter *writer) {
  // Strip the dialect name from the op name and lookup the opcode.
  // TODO(benvanik): adjust for supporting sequencer opcodes.
  auto opName = op->getName().getStringRef();
  auto dialect = op->getDialect();
  if (!dialect) {
    return op->emitOpError() << "Op does not belong to a registered dialect";
  }
  auto dialectNamespace = dialect->getNamespace();
  std::unique_ptr<OpcodeInfo> operandInfo;
  // If there is no '.', find() returns npos and npos + 1 == 0 keeps the
  // whole name.
  auto strippedOpName = opName.substr(opName.find('.') + 1).str();
  if (dialectNamespace == "iree_ll_seq") {
    auto opcode = GetSequencerOpcodeByName(strippedOpName);
    if (!opcode.hasValue()) {
      return op->emitOpError()
             << "No sequencer opcode found for op; is it a pseudo op?";
    }
    RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue()));
    operandInfo =
        std::make_unique<OpcodeInfo>(GetSequencerOpcodeInfo(opcode.getValue()));
  } else if (dialectNamespace == "iree_ll_interp" ||
             // TODO(gcmn) remove special case for IREE dialect?
             dialectNamespace == IREEDialect::getDialectNamespace()) {
    auto opcode = GetInterpreterOpcodeByName(strippedOpName);
    if (!opcode.hasValue()) {
      return op->emitOpError()
             << "No interpreter opcode found for op; is it a pseudo op?";
    }
    RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue()));
    operandInfo = std::make_unique<OpcodeInfo>(
        GetInterpreterOpcodeInfo(opcode.getValue()));
  } else {
    return op->emitOpError()
           << "Op belongs to unknown dialect " << dialectNamespace.str();
  }

  // Write inputs and outputs based on the bytecode encoding.
  const int operandCount = op->getNumOperands();
  const int resultCount = op->getNumResults();
  int operandIndex = 0;
  int resultIndex = 0;
  for (auto op_encoding : operandInfo->operands) {
    // kNone terminates the encoding list early.
    if (op_encoding == iree::OperandEncoding::kNone) break;
    switch (op_encoding) {
      case iree::OperandEncoding::kInputSlot:
      case iree::OperandEncoding::kOutputSlot: {
        // Both input and output slots are taken from the op's operands.
        // Guard the index so a mismatched op reports an error instead of
        // indexing past the operand list.
        if (operandIndex >= operandCount) {
          return op->emitOpError()
                 << "Op has fewer operands than its bytecode encoding "
                    "requires";
        }
        auto *value = op->getOperand(operandIndex++);
        RETURN_IF_FAILURE(writer->WriteLocal(value));
        break;
      }
      case iree::OperandEncoding::kVariadicInputSlots:
      case iree::OperandEncoding::kVariadicOutputSlots: {
        // Variadic slots consume every remaining operand, prefixed by count.
        RETURN_IF_FAILURE(writer->WriteCount(operandCount - operandIndex));
        while (operandIndex < operandCount) {
          auto *value = op->getOperand(operandIndex++);
          RETURN_IF_FAILURE(writer->WriteLocal(value));
        }
        break;
      }
      case iree::OperandEncoding::kResultSlot: {
        if (resultIndex >= resultCount) {
          return op->emitOpError()
                 << "Op has fewer results than its bytecode encoding requires";
        }
        auto *value = op->getResult(resultIndex++);
        RETURN_IF_FAILURE(writer->WriteLocal(value));
        break;
      }
      case iree::OperandEncoding::kVariadicResultSlots: {
        // Variadic result slots consume every remaining result.
        RETURN_IF_FAILURE(writer->WriteCount(resultCount - resultIndex));
        while (resultIndex < resultCount) {
          auto *value = op->getResult(resultIndex++);
          RETURN_IF_FAILURE(writer->WriteLocal(value));
        }
        break;
      }
      case iree::OperandEncoding::kConstant:
      case iree::OperandEncoding::kFunctionOrdinal:
      case iree::OperandEncoding::kBlockOffset:
      case iree::OperandEncoding::kTypeIndex:
      case iree::OperandEncoding::kIndex:
      case iree::OperandEncoding::kIndexList:
      case iree::OperandEncoding::kCmpIPredicate:
      case iree::OperandEncoding::kCmpFPredicate:
        // These encodings need op-specific data the generic writer lacks.
        return op->emitOpError()
               << "Operand encoding " << static_cast<char>(op_encoding)
               << " not supported by generic writer for " << opName.str();
      default:
        return op->emitOpError()
               << "Operand encoding " << static_cast<char>(op_encoding) << " ("
               << static_cast<int>(op_encoding) << ") not recognized (typo?)";
    }
  }
  return success();
}
} // namespace
VMFunctionBuilder::VMFunctionBuilder(FuncOp function,
VMFunctionTableBuilder *functionTable,
::flatbuffers::FlatBufferBuilder *fbb)
: context_(function.getContext()),
function_(function),
functionTable_(functionTable),
fbb_(fbb) {}
// Registers |writerFn| as the serializer for ops named |operationName|.
// Custom writers take precedence over the generic writer in WriteOperation.
// NOTE(review): uses map insert semantics, so a second registration for the
// same name presumably keeps the first writer — confirm against the declared
// type of customWriters_.
void VMFunctionBuilder::RegisterCustomWriter(StringRef operationName,
                                             CustomWriterFn writerFn) {
  // |writerFn| is a by-value sink parameter; move it into the map instead of
  // copying the (possibly type-erased) callable a second time.
  customWriters_.insert({operationName, std::move(writerFn)});
}
// Serializes the body of function_ into a BytecodeDef stored in bytecodeDef_.
//
// The source map is reset, and then every block is written in order. Each
// block is bracketed by BeginBlock/EndBlock; EndBlock receives the block's
// terminator. The whole function is bracketed by BeginFunction/EndFunction.
// On the first operation that fails to serialize, an error is emitted on that
// operation and failure is returned.
LogicalResult VMFunctionBuilder::ConvertBytecode() {
  BytecodeWriter writer;
  sourceMap_ = {};

  // Writes all operations of |block| between the per-block hooks.
  auto writeBlock = [&](Block &block) -> LogicalResult {
    RETURN_IF_FAILURE(BeginBlock(&block, &writer));
    for (Operation &op : block.getOperations()) {
      if (failed(WriteOperation(&block, &op, &writer))) {
        op.emitError() << "Unable to serialize operation";
        return failure();
      }
    }
    RETURN_IF_FAILURE(EndBlock(&block, block.getTerminator(), &writer));
    return success();
  };

  RETURN_IF_FAILURE(BeginFunction(function_, &writer));
  for (Block &block : function_.getBlocks()) {
    RETURN_IF_FAILURE(writeBlock(block));
  }
  RETURN_IF_FAILURE(EndFunction(function_, &writer));

  // Package the raw bytecode and the local slot count into a BytecodeDef.
  const int localCount = writer.local_count();
  auto bodyBytes = writer.Finish();
  const auto *bodyData = reinterpret_cast<const int8_t *>(bodyBytes.data());
  auto contentsOffset = fbb_->CreateVector(bodyData, bodyBytes.size());

  iree::BytecodeDefBuilder bytecodeBuilder(*fbb_);
  bytecodeBuilder.add_local_count(localCount);
  bytecodeBuilder.add_contents(contentsOffset);
  bytecodeDef_ = bytecodeBuilder.Finish();
  return success();
}
// Builds the FunctionDef table for function_ in fbb_.
//
// The table contains the function name, its FunctionTypeDef (input and result
// TypeDefs), and the bytecodeDef_ produced by ConvertBytecode. An empty
// input/result list leaves the matching vector field unset. Returns a null
// offset if any input or result type cannot be serialized.
::flatbuffers::Offset<iree::FunctionDef> VMFunctionBuilder::Finish() {
  using TypeDefOffset = ::flatbuffers::Offset<iree::TypeDef>;
  using TypeDefVectorOffset =
      ::flatbuffers::Offset<::flatbuffers::Vector<TypeDefOffset>>;

  // Serializes each type in |types| and packs them into a vector stored in
  // |*out|; |*out| stays null when |types| is empty. Returns false if any
  // type failed to serialize.
  auto serializeTypeList = [this](const auto &types,
                                  TypeDefVectorOffset *out) -> bool {
    std::vector<TypeDefOffset> offsets;
    for (const auto &type : types) {
      TypeDefOffset offset = SerializeType(type, fbb_);
      if (offset.IsNull()) return false;
      offsets.push_back(offset);
    }
    if (!offsets.empty()) *out = fbb_->CreateVector(offsets);
    return true;
  };

  const auto functionType = function_.getType();
  TypeDefVectorOffset inputsOffset;
  if (!serializeTypeList(functionType.getInputs(), &inputsOffset)) return {};
  TypeDefVectorOffset resultsOffset;
  if (!serializeTypeList(functionType.getResults(), &resultsOffset)) return {};

  iree::FunctionTypeDefBuilder typeBuilder(*fbb_);
  typeBuilder.add_inputs(inputsOffset);
  typeBuilder.add_results(resultsOffset);
  auto functionTypeOffset = typeBuilder.Finish();

  // TODO(benvanik): strip names of internal functions.
  auto nameOffset = fbb_->CreateString(function_.getName().str());

  iree::FunctionDefBuilder defBuilder(*fbb_);
  defBuilder.add_name(nameOffset);
  defBuilder.add_type(functionTypeOffset);
  defBuilder.add_bytecode(bytecodeDef_);
  return defBuilder.Finish();
}
// Function prologue hook: assigns local value slots to every argument of
// |function|, in argument order, before any operation is written.
LogicalResult VMFunctionBuilder::BeginFunction(FuncOp function,
                                               BytecodeWriter *writer) {
  // Only the arguments are pre-assigned here. Placing them first keeps them at
  // the front of the local slots, which makes them easier to find while
  // debugging and makes spans easier to compute at runtime.
  for (auto arg : function.getArguments()) {
    RETURN_IF_FAILURE(writer->PrepareLocal(arg));
  }
  return success();
}
// Function epilogue hook: asks |writer| to fix up the offsets it recorded
// while writing (presumably block offsets used by branches — confirm in
// BytecodeWriter). Returns the writer's result. |function| is unused.
LogicalResult VMFunctionBuilder::EndFunction(FuncOp function,
                                             BytecodeWriter *writer) {
  return writer->FixupOffsets();
}
// Block prologue hook: records the current bytecode position of |block| in
// |writer| (presumably so FixupOffsets can resolve references to it later).
// Returns the writer's result.
LogicalResult VMFunctionBuilder::BeginBlock(Block *block,
                                            BytecodeWriter *writer) {
  return writer->MarkBlockOffset(block);
}
// Block epilogue hook, called after all operations of |block| are written.
// |op| is the block's terminator. Nothing is emitted at block end at present,
// so this always succeeds.
LogicalResult VMFunctionBuilder::EndBlock(Block *block, Operation *op,
                                          BytecodeWriter *writer) {
  // Parameters are intentionally unused for now.
  (void)block;
  (void)op;
  (void)writer;
  return success();
}
// Serializes a single operation |baseOp| from |block| into |writer|.
//
// A source-map entry is recorded at the current writer offset unless the op
// location is UnknownLoc. Writers registered via RegisterCustomWriter take
// precedence. Otherwise ops from dialects whose namespace starts with "iree"
// use the generic writer. Any other op, including ops with no dialect, is
// reported as unsupported.
LogicalResult VMFunctionBuilder::WriteOperation(Block *block, Operation *baseOp,
                                                BytecodeWriter *writer) {
  if (!baseOp->getLoc().isa<UnknownLoc>()) {
    sourceMap_.locations.push_back({writer->offset(), baseOp->getLoc()});
  }

  // Check registered writers first to allow overrides.
  auto writerIt = customWriters_.find(baseOp->getName().getStringRef());
  if (writerIt != customWriters_.end()) {
    return writerIt->second(baseOp, writer);
  }

  // Fallback to using the generic writer. getAbstractOperation() is null for
  // unregistered ops, so query the dialect through getDialect() and treat a
  // missing dialect as unsupported instead of dereferencing null.
  auto *dialect = baseOp->getDialect();
  if (!dialect || !dialect->getNamespace().startswith("iree")) {
    return baseOp->emitError()
           << "Unsupported op " << baseOp->getName().getStringRef().str()
           << "; incorrectly outlined or not yet implemented";
  }
  RETURN_IF_FAILURE(WriteGenericIreeOp(block, baseOp, writer));
  return success();
}
// Serializes |type| into a TypeDef table in |fbb|.
//
// MemRefType produces a MemRefTypeDef. Device, command buffer, event,
// semaphore and fence types produce their (field-less) TypeDefs. Any other
// type emits an error on function_ and returns a null offset.
::flatbuffers::Offset<iree::TypeDef> VMFunctionBuilder::SerializeType(
    Type type, ::flatbuffers::FlatBufferBuilder *fbb) {
  ::flatbuffers::Offset<void> typeDefUnion;
  iree::TypeDefUnion typeUnionType;
  if (auto memRefType = type.dyn_cast<MemRefType>()) {
    // Serialize into the caller-provided builder (not the member fbb_) so the
    // nested MemRefTypeDef lives in the same buffer as the TypeDef that
    // references it.
    auto memRefTypeOffset = SerializeMemRefType(memRefType, fbb);
    if (memRefTypeOffset.IsNull()) return {};
    typeDefUnion = memRefTypeOffset.Union();
    typeUnionType = iree::TypeDefUnion::MemRefTypeDef;
  } else if (type.isa<DeviceType>()) {
    typeDefUnion = iree::CreateDeviceTypeDef(*fbb).Union();
    typeUnionType = iree::TypeDefUnion::DeviceTypeDef;
  } else if (type.isa<CommandBufferType>()) {
    typeDefUnion = iree::CreateCommandBufferTypeDef(*fbb).Union();
    typeUnionType = iree::TypeDefUnion::CommandBufferTypeDef;
  } else if (type.isa<EventType>()) {
    typeDefUnion = iree::CreateEventTypeDef(*fbb).Union();
    typeUnionType = iree::TypeDefUnion::EventTypeDef;
  } else if (type.isa<SemaphoreType>()) {
    typeDefUnion = iree::CreateSemaphoreTypeDef(*fbb).Union();
    typeUnionType = iree::TypeDefUnion::SemaphoreTypeDef;
  } else if (type.isa<FenceType>()) {
    typeDefUnion = iree::CreateFenceTypeDef(*fbb).Union();
    typeUnionType = iree::TypeDefUnion::FenceTypeDef;
  } else {
    function_.emitError() << "Function " << function_.getName().str()
                          << " has unsupported I/O with type " << type;
    return {};
  }
  iree::TypeDefBuilder tdb(*fbb);
  tdb.add_type_union_type(typeUnionType);
  tdb.add_type_union(typeDefUnion);
  return tdb.Finish();
}
// Serializes |type| into a MemRefTypeDef in |fbb|.
// The table holds the element type, the shape and the memory space. Returns a
// null offset if the element type cannot be serialized.
::flatbuffers::Offset<iree::MemRefTypeDef>
VMFunctionBuilder::SerializeMemRefType(const MemRefType &type,
                                       ::flatbuffers::FlatBufferBuilder *fbb) {
  // Child objects (element type, shape vector) must be created before the
  // MemRefTypeDef table is started; flatbuffers forbids nested construction.
  auto elementTypeOffset = SerializeElementType(type.getElementType(), fbb);
  if (elementTypeOffset.IsNull()) return {};

  // Each dimension is converted to int to match the vector element type.
  auto dims = type.getShape();
  std::vector<int> shape(dims.begin(), dims.end());
  auto shapeOffset = fbb->CreateVector(shape);

  iree::MemRefTypeDefBuilder memRefBuilder(*fbb);
  memRefBuilder.add_element_type(elementTypeOffset);
  memRefBuilder.add_shape(shapeOffset);
  memRefBuilder.add_memory_space(type.getMemorySpace());
  return memRefBuilder.Finish();
}
// Serializes the element type |genericType| into an ElementTypeDef in |fbb|.
//
// FloatType and IntegerType are recorded with their bit width. OpaqueType is
// recorded as an UnknownTypeDef carrying its dialect namespace and raw type
// data. Any other type emits an error on function_ and returns a null offset.
::flatbuffers::Offset<iree::ElementTypeDef>
VMFunctionBuilder::SerializeElementType(const Type &genericType,
                                        ::flatbuffers::FlatBufferBuilder *fbb) {
  ::flatbuffers::Offset<void> typeDefUnion;
  iree::ElementTypeDefUnion typeUnionType;
  if (auto type = genericType.dyn_cast<FloatType>()) {
    iree::FloatTypeDefBuilder tb(*fbb);
    tb.add_width(type.getWidth());
    typeDefUnion = tb.Finish().Union();
    typeUnionType = iree::ElementTypeDefUnion::FloatTypeDef;
  } else if (auto type = genericType.dyn_cast<IntegerType>()) {
    iree::IntegerTypeDefBuilder tb(*fbb);
    tb.add_width(type.getWidth());
    typeDefUnion = tb.Finish().Union();
    typeUnionType = iree::ElementTypeDefUnion::IntegerTypeDef;
  } else if (auto type = genericType.dyn_cast<OpaqueType>()) {
    auto dialectOffset = fbb->CreateString(type.getDialectNamespace().c_str());
    // getTypeData() returns a StringRef, which is not guaranteed to be
    // NUL-terminated. Pass the explicit length so the copy neither over-reads
    // nor stops early at an embedded NUL.
    auto typeData = type.getTypeData();
    auto typeDataOffset = fbb->CreateString(typeData.data(), typeData.size());
    iree::UnknownTypeDefBuilder tb(*fbb);
    tb.add_dialect(dialectOffset);
    tb.add_type_data(typeDataOffset);
    typeDefUnion = tb.Finish().Union();
    typeUnionType = iree::ElementTypeDefUnion::UnknownTypeDef;
  } else {
    function_.emitError()
        << "Unimplemented type encoding: " << genericType
        << "; ensure IREE lowering passes are converting types to the IREE "
           "set";
    return {};
  }
  iree::ElementTypeDefBuilder tdb(*fbb);
  tdb.add_type_union_type(typeUnionType);
  tdb.add_type_union(typeDefUnion);
  return tdb.Finish();
}
} // namespace iree_compiler
} // namespace mlir