blob: 564f03d6713f47ac983af4d76b8a2f6d0c916f01 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.h"
#include "flatbuffers/flatbuffers.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
#include "iree/schemas/vmla_executable_def_generated.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
namespace {
// Destructively merges |sourceModuleOp| into |targetModuleOp|.
//
// Every non-terminator op in |sourceModuleOp|'s block is moved (not cloned)
// to just before the terminator of |targetModuleOp|'s block. Ops implementing
// SymbolOpInterface are recorded in |targetSymbolMap| by name; a symbol op
// whose name is already present in |targetSymbolMap| is NOT moved and is
// destroyed together with |sourceModuleOp|, which is erased at the end.
void mergeModuleInto(IREE::VM::ModuleOp sourceModuleOp,
                     IREE::VM::ModuleOp targetModuleOp,
                     DenseMap<StringRef, Operation *> &targetSymbolMap) {
  // Snapshot the op list first: moving an op out of the source block while
  // iterating that block directly would invalidate the iteration.
  auto allOps = llvm::to_vector<8>(llvm::map_range(
      sourceModuleOp.getBlock(), [&](Operation &op) { return &op; }));
  // Ops are always inserted before the target's terminator; that terminator
  // does not change as ops are moved in front of it, so look it up once.
  Operation *insertionPoint = &targetModuleOp.getBlock().back();
  for (Operation *op : allOps) {
    if (op->isKnownTerminator()) continue;
    if (auto symbolInterface = dyn_cast<SymbolOpInterface>(op)) {
      // try_emplace inserts only when the name is new, so a single hash
      // lookup both deduplicates and records the symbol.
      if (!targetSymbolMap.try_emplace(symbolInterface.getName(), op).second) {
        // TODO(scotttodd): compare ops to ensure we aren't copying different
        // things with the same name.
        continue;
      }
    }
    op->moveBefore(insertionPoint);
  }
  // All wanted ops have been moved out; delete the (now-depleted) source
  // module along with any skipped duplicate symbols still inside it.
  sourceModuleOp.erase();
}
// Rewrites the entry point symbol of every hal.command_buffer.dispatch.symbol
// op found inside the FuncOps directly contained in |moduleOp|.
// |replacements| maps an old entry point SymbolRefAttr to its new
// SymbolRefAttr; dispatches whose entry point has no mapping are left as-is.
void replaceEntryPointUses(mlir::ModuleOp moduleOp,
                           const DenseMap<Attribute, Attribute> &replacements) {
  auto rewriteDispatch = [&](IREE::HAL::CommandBufferDispatchSymbolOp op) {
    auto found = replacements.find(op.entry_point());
    if (found == replacements.end()) return;
    op.entry_pointAttr(found->second.cast<SymbolRefAttr>());
  };
  for (mlir::FuncOp funcOp : moduleOp.getOps<mlir::FuncOp>()) {
    funcOp.walk(rewriteDispatch);
  }
}
} // namespace
// Returns VMLATargetOptions populated from command-line flags. No VMLA flags
// are defined yet, so the options are returned default-initialized.
VMLATargetOptions getVMLATargetOptionsFromFlags() {
  // TODO(benvanik): flags.
  VMLATargetOptions options;
  return options;
}
// HAL target backend that lowers executables to VMLA and serializes them as
// VM bytecode wrapped in a VMLAExecutableDef flatbuffer.
class VMLATargetBackend final : public TargetBackend {
 public:
  explicit VMLATargetBackend(VMLATargetOptions options)
      : options_(std::move(options)) {}

  std::string name() const override { return "vmla"; }
  std::string filter_pattern() const override { return "vmla"; }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VM::VMDialect, VMLA::VMLADialect>();
  }

  // VMLA transforms first, then the VM pipeline that produces vm.module ops.
  void buildTranslationPassPipeline(OpPassManager &passManager) override {
    IREE::VMLA::buildVMLATransformPassPipeline(passManager);
    // TODO(#614): remove this when the std->vm conversion isn't looking for
    // iree.module.export.
    passManager.addPass(IREE::VM::createMarkPublicSymbolsExportedPass());
    IREE::VM::buildVMTransformPassPipeline(
        passManager, IREE::VM::getTargetOptionsFromFlags());
  }

  // Links every hal.executable.target matching filter_pattern() into a single
  // private `linked_vmla` hal.executable:
  //  - equivalent hal.interface ops are deduplicated and renamed
  //    `legacy_io_<N>`;
  //  - entry points are re-created in the linked target with fresh ordinals;
  //  - each target's vm.module contents are merged into one vm.module;
  //  - dispatch sites are rewritten to the new @executable::@target::@entry.
  // Source targets are erased, and source executables left with no targets
  // are erased too. If nothing was linked the new executable is removed.
  // NOTE(review): entry points with the same name in different executables
  // are not renamed here and would collide in the linked target.
  LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override {
    OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
    auto executableOps =
        llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());

    // Create our new "linked" hal.executable.
    auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
        moduleOp.getLoc(), "linked_vmla");
    SymbolTable::setSymbolVisibility(linkedExecutableOp,
                                     SymbolTable::Visibility::Private);
    // Add our VMLA hal.executable.target with an empty module.
    builder.setInsertionPointToStart(linkedExecutableOp.getBody());
    auto linkedTargetOp = builder.create<IREE::HAL::ExecutableTargetOp>(
        moduleOp.getLoc(), name(), filter_pattern());
    builder.setInsertionPoint(&linkedTargetOp.getBlock().back());
    auto linkedModuleOp = builder.create<ModuleOp>(moduleOp.getLoc());
    // Add an empty vm.module to that module.
    builder.setInsertionPointToStart(linkedModuleOp.getBody());
    auto linkedVmModuleOp =
        builder.create<IREE::VM::ModuleOp>(moduleOp.getLoc(), "linked_module");

    llvm::SmallVector<IREE::HAL::InterfaceOp, 4> interfaceOps;
    int nextEntryPointOrdinal = 0;
    DenseMap<StringRef, Operation *> symbolMap;
    DenseMap<Attribute, Attribute> entryPointRefReplacements;
    auto linkedExecutableBuilder =
        OpBuilder::atBlockBegin(linkedExecutableOp.getBody());
    auto linkedTargetBuilder =
        OpBuilder::atBlockBegin(linkedTargetOp.getBody());
    for (auto executableOp : executableOps) {
      auto targetOps = llvm::to_vector<4>(
          executableOp.getOps<IREE::HAL::ExecutableTargetOp>());
      for (auto targetOp : targetOps) {
        // Only process targets matching our pattern.
        if (!matchPattern(targetOp.target_backend_filter(), filter_pattern())) {
          continue;
        }

        // Look the source interface up once; guard against an executable
        // without one instead of dereferencing a null op below.
        auto sourceInterfaceOp = executableOp.getFirstInterfaceOp();
        if (!sourceInterfaceOp) {
          return executableOp.emitError()
                 << "executable has no hal.interface op to link";
        }

        // Reuse an already-linked interface if an equivalent one exists.
        IREE::HAL::InterfaceOp interfaceOpForExecutable;
        for (auto interfaceOp : interfaceOps) {
          if (interfaceOp.isEquivalentTo(sourceInterfaceOp)) {
            interfaceOpForExecutable = interfaceOp;
            break;
          }
        }
        if (!interfaceOpForExecutable) {
          interfaceOpForExecutable = dyn_cast<IREE::HAL::InterfaceOp>(
              linkedExecutableBuilder.clone(*sourceInterfaceOp));
          interfaceOpForExecutable.setName(
              llvm::formatv("legacy_io_{0}", interfaceOps.size()).str());
          interfaceOps.push_back(interfaceOpForExecutable);
        }

        // Clone entry point ops and queue remapping ordinals and updating
        // symbol refs.
        for (auto entryPointOp :
             targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
          auto newEntryPointOp =
              linkedTargetBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
                  entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
                  builder.getI32IntegerAttr(nextEntryPointOrdinal++),
                  builder.getSymbolRefAttr(interfaceOpForExecutable.getName()),
                  entryPointOp.signatureAttr());

          // Add to replacement table for fixing up dispatch calls referencing
          // this entry point.
          auto oldSymbolRefAttr = builder.getSymbolRefAttr(
              executableOp.getName(), {builder.getSymbolRefAttr(targetOp),
                                       builder.getSymbolRefAttr(entryPointOp)});
          auto newSymbolRefAttr = builder.getSymbolRefAttr(
              linkedExecutableOp.getName(),
              {builder.getSymbolRefAttr(linkedTargetOp),
               builder.getSymbolRefAttr(newEntryPointOp)});
          entryPointRefReplacements[oldSymbolRefAttr] = newSymbolRefAttr;
        }

        // Merge the existing vm.module op into the new linked vm.module op.
        auto vmModuleOps =
            targetOp.getInnerModule().getOps<IREE::VM::ModuleOp>();
        if (vmModuleOps.empty()) {
          return targetOp.getInnerModule().emitError()
                 << "target's outer module does not contain a vm.module op";
        }
        mergeModuleInto(*vmModuleOps.begin(), linkedVmModuleOp, symbolMap);

        targetOp.erase();
      }

      if (executableOp.getOps<IREE::HAL::ExecutableTargetOp>().empty()) {
        executableOp.erase();
      }
    }

    // Update references to @executable::@target::@entry symbols.
    replaceEntryPointUses(moduleOp, entryPointRefReplacements);

    // Remove if we didn't add anything.
    if (linkedTargetOp.getOps<IREE::HAL::ExecutableEntryPointOp>().empty()) {
      linkedTargetOp.erase();
      linkedExecutableOp.erase();
    }

    return success();
  }

  // Translates the target's inner module to VM bytecode, wraps it in a
  // VMLAExecutableDef flatbuffer, and emits a hal.executable.binary op with
  // those bytes through |executableBuilder|.
  LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
                                    OpBuilder &executableBuilder) override {
    // Serialize the VM module to bytes.
    std::string byteStreamValue;
    llvm::raw_string_ostream byteStream(byteStreamValue);
    IREE::VM::BytecodeTargetOptions bytecodeOptions;
    if (failed(translateModuleToBytecode(targetOp.getInnerModule(),
                                         bytecodeOptions, byteStream))) {
      return targetOp.emitError() << "failed to serialize converted VM module";
    }
    // raw_string_ostream may hold pending bytes in its buffer; flush so the
    // backing string is complete before it is read directly below.
    byteStream.flush();

    // Pack the executable definition and get the bytes with the proper header.
    // The header is used to verify the contents at runtime.
    ::flatbuffers::FlatBufferBuilder fbb;
    iree::VMLAExecutableDefT vmlaExecutableDef;
    vmlaExecutableDef.bytecode_module.assign(byteStreamValue.begin(),
                                             byteStreamValue.end());
    auto executableOffset =
        iree::VMLAExecutableDef::Pack(fbb, &vmlaExecutableDef);
    iree::FinishVMLAExecutableDefBuffer(fbb, executableOffset);
    const uint8_t *bufferBegin = fbb.GetBufferPointer();
    std::vector<uint8_t> bytes(bufferBegin, bufferBegin + fbb.GetSize());

    // Add the binary data to the target executable.
    executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
        targetOp.getLoc(),
        static_cast<uint32_t>(IREE::HAL::ExecutableFormat::VMLA),
        std::move(bytes));

    return success();
  }

  // Every dispatch runs as a single 1x1x1 workgroup (no tiling yet).
  std::array<Value, 3> calculateDispatchWorkgroupCount(
      Location loc, IREE::HAL::ExecutableOp executableOp,
      IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
      OpBuilder &builder) override {
    // For now we are not tiling and just dispatch everything as 1,1,1.
    auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
    return {constantOne, constantOne, constantOne};
  }

 private:
  VMLATargetOptions options_;
};
// Registers the "vmla" HAL target backend factory. The registration object is
// a function-local static, so only the |queryOptions| supplied on the first
// call is captured; subsequent calls do not re-register.
void registerVMLATargetBackends(
    std::function<VMLATargetOptions()> queryOptions) {
  // The returned options are deliberately discarded here.
  (void)getVMLATargetOptionsFromFlags();
  static TargetBackendRegistration registration(
      "vmla", [queryOptions = std::move(queryOptions)]() {
        return std::make_unique<VMLATargetBackend>(queryOptions());
      });
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir