blob: 9d2be025532fe7d10c4c96f1a775fbe65f6d6fc7 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/IR/Sequencer/OpWriters.h"
#include "compiler/IR/Sequencer/LLOps.h"
#include "compiler/IR/StructureOps.h"
#include "compiler/Serialization/BytecodeWriter.h"
#include "compiler/Utils/Macros.h"
#include "schemas/bytecode/sequencer_bytecode_v0.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
namespace mlir {
namespace iree_compiler {
namespace {
//===----------------------------------------------------------------------===//
// Sequencer ops
//===----------------------------------------------------------------------===//
// Writes a sequencer constant op: the kConstant opcode, the constant payload
// (memref type + "value" attribute), and the result local.
//
// Returns failure if any writer call fails, or (with an error emitted on the
// op) if the op's type is not a MemRefType.
LogicalResult writeOp(IREESeq::LL::ConstantOp op, BytecodeWriter *writer) {
  // Note: the opcode write happens before the type is validated, so
  // WriteOpcode has already been called by the time a type error is reported.
  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kConstant));

  auto constantType = op.getType();
  if (auto memRefType = constantType.dyn_cast<MemRefType>()) {
    RETURN_IF_FAILURE(writer->WriteConstant(memRefType, op.getAttr("value")));
    RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
    return success();
  }

  // Anything other than a memref cannot be written as a constant here.
  return op.emitError()
         << "Constant has an unsupported type; must be a memref: "
         << constantType;
}
// Writes a direct call: the kCall opcode, the callee's function ordinal, the
// argument locals, and the result locals.
//
// The callee is resolved by symbol name in the enclosing ModuleOp. Returns
// failure (with an error emitted on the op) if the symbol does not resolve to
// a FuncOp, or if any writer call fails.
LogicalResult writeOp(IREESeq::LL::CallOp op, BytecodeWriter *writer) {
  auto module = op.getOperation()->getParentOfType<ModuleOp>();
  StringRef calleeName = op.getCallee();
  auto callee = module.lookupSymbol<FuncOp>(calleeName);
  if (!callee) {
    // lookupSymbol yields a null op when the symbol is missing or is not a
    // FuncOp; report it instead of handing a null callee to the writer.
    return op.emitError() << "Callee @" << calleeName.str()
                          << " not found in module";
  }
  // TODO(benvanik): switch with kCallTail if attr exists.
  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall));
  RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee));
  RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
  RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
  return success();
}
// Writes an import call: kCallImport when the callee is external (kCall
// otherwise), then the callee's import ordinal, the argument locals, and the
// result locals.
//
// The callee is resolved by symbol name in the enclosing ModuleOp. Returns
// failure (with an error emitted on the op) if the symbol does not resolve to
// a FuncOp, or if any writer call fails.
LogicalResult writeOp(IREESeq::LL::CallImportOp op, BytecodeWriter *writer) {
  auto module = op.getOperation()->getParentOfType<ModuleOp>();
  StringRef calleeName = op.getCallee();
  auto callee = module.lookupSymbol<FuncOp>(calleeName);
  if (!callee) {
    // lookupSymbol yields a null op when the symbol is missing or is not a
    // FuncOp; calling isExternal() on it below would be invalid.
    return op.emitError() << "Callee @" << calleeName.str()
                          << " not found in module";
  }
  // TODO(benvanik): transforms to convert Call->CallImport.
  // TODO(benvanik): switch with kCallTail if attr exists.
  if (callee.isExternal()) {
    RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallImport));
  } else {
    RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall));
  }
  // NOTE(review): the import ordinal is written even on the kCall path, while
  // the CallOp writer uses WriteFunctionOrdinal for kCall — confirm that this
  // is intended.
  RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee));
  RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
  RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
  return success();
}
// Writes an indirect call: the kCallIndirect opcode, the type index of the
// callee value's type, the callee local, the argument locals, and the result
// locals. Returns failure as soon as any writer call fails.
LogicalResult writeOp(IREESeq::LL::CallIndirectOp op, BytecodeWriter *writer) {
  auto calleeValue = op.getCallee();

  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallIndirect));

  // Callee: its type first, then the local holding it.
  RETURN_IF_FAILURE(writer->WriteTypeIndex(calleeValue->getType()));
  RETURN_IF_FAILURE(writer->WriteLocal(calleeValue));

  // Call arguments followed by call results.
  RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
  RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
  return success();
}
// Writes an unconditional branch: the kBranch opcode, the destination block
// offset, the operand count, and then one (source local, destination block
// argument local) pair per operand. Returns failure as soon as any writer call
// fails.
LogicalResult writeOp(IREESeq::LL::BranchOp op, BytecodeWriter *writer) {
  auto *targetBlock = op.getDest();
  const auto operandCount = op.getNumOperands();

  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kBranch));
  RETURN_IF_FAILURE(writer->WriteBlockOffset(targetBlock));
  RETURN_IF_FAILURE(writer->WriteCount(operandCount));

  // Each branch operand is paired with the block argument at the same index.
  for (unsigned idx = 0; idx < operandCount; ++idx) {
    RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(idx)));
    RETURN_IF_FAILURE(writer->WriteLocal(targetBlock->getArgument(idx)));
  }
  return success();
}
// Writes a conditional branch: the kCondBranch opcode and the condition local,
// followed by two successor records (true first, then false). Each record is
// the block offset, the operand count, and one (source local, destination
// block argument local) pair per operand. Returns failure as soon as any
// writer call fails.
LogicalResult writeOp(IREESeq::LL::CondBranchOp op, BytecodeWriter *writer) {
  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCondBranch));
  RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition()));

  // Successor taken when the condition holds.
  auto *trueBlock = op.getTrueDest();
  const auto trueCount = op.getNumTrueOperands();
  RETURN_IF_FAILURE(writer->WriteBlockOffset(trueBlock));
  RETURN_IF_FAILURE(writer->WriteCount(trueCount));
  for (unsigned idx = 0; idx < trueCount; ++idx) {
    RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(idx)));
    RETURN_IF_FAILURE(writer->WriteLocal(trueBlock->getArgument(idx)));
  }

  // Successor taken otherwise.
  auto *falseBlock = op.getFalseDest();
  const auto falseCount = op.getNumFalseOperands();
  RETURN_IF_FAILURE(writer->WriteBlockOffset(falseBlock));
  RETURN_IF_FAILURE(writer->WriteCount(falseCount));
  for (unsigned idx = 0; idx < falseCount; ++idx) {
    RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(idx)));
    RETURN_IF_FAILURE(writer->WriteLocal(falseBlock->getArgument(idx)));
  }
  return success();
}
// Writes the (executable ordinal, entry point ordinal) pair that identifies a
// dispatch target.
//
// |op| is the dispatch operation; it is used to find the enclosing module and
// as the location for emitted errors. |executable| names a
// MultiArchExecutableOp symbol in that module and |entryPoint| names a FuncOp
// symbol in the inner module of the first op in the multi-arch executable's
// block. Both ops must carry an integer "iree.ordinal" attribute.
//
// The executable ordinal is written as a uint32 and the entry point ordinal as
// a uint16. Returns failure (with an error emitted on |op|) if either symbol
// cannot be resolved or lacks an integer ordinal, or if a writer call fails.
LogicalResult writeDispatchOpExecutableRef(Operation *op, StringRef executable,
                                           StringRef entryPoint,
                                           BytecodeWriter *writer) {
  auto module = op->getParentOfType<ModuleOp>();
  auto multiArchExecutableOp =
      module.lookupSymbol<IREE::MultiArchExecutableOp>(executable);
  if (!multiArchExecutableOp) {
    return op->emitError() << "Executable @" << executable.str()
                           << " not found in module";
  }

  auto executableOrdinalAttr = multiArchExecutableOp.getAttr("iree.ordinal")
                                   .dyn_cast_or_null<IntegerAttr>();
  if (!executableOrdinalAttr) {
    return op->emitError() << "No ordinal assigned to executable";
  }
  int executableOrdinal = executableOrdinalAttr.getInt();

  // TODO(benvanik): move an export table to the MAE to make this cleaner.
  // NOTE(review): this assumes the first op in the block is an ExecutableOp;
  // cast<> asserts otherwise.
  auto executableOp =
      cast<IREE::ExecutableOp>(multiArchExecutableOp.getBlock().front());
  auto entryPointOp =
      executableOp.getInnerModule().lookupSymbol<FuncOp>(entryPoint);
  if (!entryPointOp) {
    return op->emitError() << "Entry point @" << entryPoint.str()
                           << " not found in executable @" << executable.str();
  }

  // Fetch the attribute once and use a checked cast, mirroring the executable
  // ordinal handling above: a missing or non-integer attribute is reported as
  // an error rather than tripping the assertion inside cast<>.
  auto entryPointOrdinalAttr =
      entryPointOp.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>();
  if (!entryPointOrdinalAttr) {
    return op->emitError() << "No ordinal assigned to entry point";
  }
  int entryPointOrdinal = entryPointOrdinalAttr.getInt();

  RETURN_IF_FAILURE(writer->WriteUint32(executableOrdinal));
  RETURN_IF_FAILURE(writer->WriteUint16(entryPointOrdinal));
  return success();
}
// Writes a dynamic dispatch: the kDynamicDispatch opcode, the executable/entry
// point reference, the workload local, the argument locals, a zero output
// argument count, and the result locals. Returns failure as soon as any step
// fails.
LogicalResult writeOp(IREESeq::LL::DynamicDispatchOp op,
                      BytecodeWriter *writer) {
  RETURN_IF_FAILURE(
      writer->WriteOpcode(iree::SequencerOpcode::kDynamicDispatch));

  // Dispatch target.
  auto executableName = op.getExecutable();
  auto entryPointName = op.getEntryPoint();
  RETURN_IF_FAILURE(
      writeDispatchOpExecutableRef(op, executableName, entryPointName, writer));

  // The workload is a runtime value here, so it is written as a local.
  RETURN_IF_FAILURE(writer->WriteLocal(op.getWorkload()));

  RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
  // TODO(benvanik): support output arg group (or change to tags).
  RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0));
  RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
  return success();
}
// Writes a static dispatch: the kStaticDispatch opcode, the executable/entry
// point reference, three int32 workload values read from the workload
// attribute at indices 0, 1 and 2, the argument locals, a zero output argument
// count, and the result locals. Returns failure as soon as any step fails.
LogicalResult writeOp(IREESeq::LL::StaticDispatchOp op,
                      BytecodeWriter *writer) {
  RETURN_IF_FAILURE(
      writer->WriteOpcode(iree::SequencerOpcode::kStaticDispatch));
  RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(),
                                                 op.getEntryPoint(), writer));

  // The workload is a compile-time attribute; its first three elements are
  // written inline, in index order.
  auto workload = op.getWorkload();
  for (uint64_t axis = 0; axis < 3; ++axis) {
    RETURN_IF_FAILURE(
        writer->WriteInt32(workload.getValue<IntegerAttr>({axis}).getInt()));
  }

  RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands()));
  // TODO(benvanik): support output arg group (or change to tags).
  RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0));
  RETURN_IF_FAILURE(writer->WriteLocals(op.getResults()));
  return success();
}
// Writes a heap allocation: the kAllocHeap opcode, an int32 zero, the type
// index of the memref element type, the memref's shape pieces, the operand
// locals, and the result local. The op's type is cast<> to MemRefType, which
// asserts if it is anything else. Returns failure as soon as any writer call
// fails.
LogicalResult writeOp(IREESeq::LL::AllocHeapOp op, BytecodeWriter *writer) {
  auto allocType = op.getType().cast<MemRefType>();

  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kAllocHeap));
  // NOTE(review): the meaning of this constant 0 field is not visible in this
  // file — confirm against the bytecode schema.
  RETURN_IF_FAILURE(writer->WriteInt32(0));

  // Describe the allocation: element type, then shape.
  RETURN_IF_FAILURE(writer->WriteTypeIndex(allocType.getElementType()));
  RETURN_IF_FAILURE(writer->WriteShapePieces(allocType));

  RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands()));
  RETURN_IF_FAILURE(writer->WriteLocal(op.getResult()));
  return success();
}
// Writes a compute-range op: the kComputeRange opcode, the shape local, the
// element size as a uint8, the indices and lengths locals, and the two
// destination locals (offset, then length). Returns failure as soon as any
// writer call fails.
LogicalResult writeOp(IREESeq::LL::ComputeRangeOp op, BytecodeWriter *writer) {
  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kComputeRange));

  // Inputs.
  RETURN_IF_FAILURE(writer->WriteLocal(op.shape()));
  const auto elementSizeValue = op.elementSize().getZExtValue();
  RETURN_IF_FAILURE(writer->WriteUint8(elementSizeValue));
  RETURN_IF_FAILURE(writer->WriteLocal(op.indices()));
  RETURN_IF_FAILURE(writer->WriteLocal(op.lengths()));

  // Outputs.
  RETURN_IF_FAILURE(writer->WriteLocal(op.dstOffset()));
  RETURN_IF_FAILURE(writer->WriteLocal(op.dstLength()));
  return success();
}
// Writes a static slice: the kStaticSlice opcode, the source local, the offset
// and length as int32 values, the type index and shape pieces of the result
// type, and the result local. The result type is cast<> to ShapedType, which
// asserts if it is not shaped. Returns failure as soon as any writer call
// fails.
LogicalResult writeOp(IREESeq::LL::StaticSliceOp op, BytecodeWriter *writer) {
  auto sliceResult = op.getResult();
  auto sliceType = sliceResult->getType();

  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticSlice));

  // Source and the constant range taken from it.
  RETURN_IF_FAILURE(writer->WriteLocal(op.src()));
  RETURN_IF_FAILURE(writer->WriteInt32(op.offset().getZExtValue()));
  RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue()));

  // Result description followed by the result itself.
  RETURN_IF_FAILURE(writer->WriteTypeIndex(sliceType));
  RETURN_IF_FAILURE(writer->WriteShapePieces(sliceType.cast<ShapedType>()));
  RETURN_IF_FAILURE(writer->WriteLocal(sliceResult));
  return success();
}
// Writes a static copy: the kStaticCopy opcode, the source local and its
// int32 offset, the destination local and its int32 offset, and the int32
// length. Returns failure as soon as any writer call fails.
LogicalResult writeOp(IREESeq::LL::StaticCopyOp op, BytecodeWriter *writer) {
  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticCopy));

  // Source side.
  const auto sourceOffset = op.srcOffset().getZExtValue();
  RETURN_IF_FAILURE(writer->WriteLocal(op.src()));
  RETURN_IF_FAILURE(writer->WriteInt32(sourceOffset));

  // Destination side.
  const auto targetOffset = op.dstOffset().getZExtValue();
  RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
  RETURN_IF_FAILURE(writer->WriteInt32(targetOffset));

  const auto copyLength = op.length().getZExtValue();
  RETURN_IF_FAILURE(writer->WriteInt32(copyLength));
  return success();
}
// Writes a static fill: the kStaticFill opcode, the int32 fill value, the
// destination local, and the destination's int32 offset and length. Returns
// failure as soon as any writer call fails.
LogicalResult writeOp(IREESeq::LL::StaticFillOp op, BytecodeWriter *writer) {
  RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticFill));

  const auto fillValue = op.value().getZExtValue();
  RETURN_IF_FAILURE(writer->WriteInt32(fillValue));

  // Destination and the constant range within it to fill.
  RETURN_IF_FAILURE(writer->WriteLocal(op.dst()));
  const auto targetOffset = op.dstOffset().getZExtValue();
  RETURN_IF_FAILURE(writer->WriteInt32(targetOffset));
  const auto fillLength = op.length().getZExtValue();
  RETURN_IF_FAILURE(writer->WriteInt32(fillLength));
  return success();
}
} // namespace
// Registers a custom bytecode writer on |builder| for every sequencer LL op
// that has a writeOp overload in this file. Each writer is registered under
// the op's operation name and forwards to the matching writeOp overload after
// casting the generic Operation to the concrete op type.
void registerSequencerCustomWriters(VMFunctionBuilder *builder) {
// The unary '+' converts the captureless lambda to a plain function pointer.
// The macro deliberately has no trailing ';' so each use below reads as (and
// is) a single ordinary statement.
#define REGISTER_CUSTOM_WRITER_IMPL(op_type)       \
  builder->RegisterCustomWriter(                   \
      op_type::getOperationName(),                 \
      +[](Operation *op, BytecodeWriter *writer) { \
        return writeOp(cast<op_type>(op), writer); \
      })
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ConstantOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallImportOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallIndirectOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::BranchOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CondBranchOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::DynamicDispatchOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticDispatchOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::AllocHeapOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ComputeRangeOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticSliceOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticCopyOp);
  REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticFillOp);
// Keep the helper macro from leaking into the rest of the translation unit.
#undef REGISTER_CUSTOM_WRITER_IMPL
}
} // namespace iree_compiler
} // namespace mlir