blob: 51aa04ed2ee2f92f4a24c93518fe191a4ac85be4 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-compiler-c/Compiler.h"
#include "iree/compiler/ConstEval/Passes.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "iree/compiler/InputConversion/TOSA/Passes.h"
#include "iree/compiler/Pipelines/Pipelines.h"
#include "iree/compiler/Utils/OptionUtils.h"
#include "iree/tools/init_dialects.h"
#include "iree/tools/init_llvmir_translations.h"
#include "iree/tools/init_passes.h"
#include "iree/tools/init_targets.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/BuiltinOps.h"
using namespace mlir;
using namespace mlir::iree_compiler;
// Short aliases for the HAL and VM option structs held by CompilerOptions.
// TODO: There is a loose ::IREE namespace somewhere which means that we
// have to fully qualify from the unnamed namespace.
using HALTargetOptions = mlir::iree_compiler::IREE::HAL::TargetOptions;
using VMTargetOptions = mlir::iree_compiler::IREE::VM::TargetOptions;
using VMBytecodeTargetOptions =
    mlir::iree_compiler::IREE::VM::BytecodeTargetOptions;
namespace {
// We have one composite options struct for everything. Not all components
// are applicable to every translation.
//
// `binder` is populated in the constructor by each option group's
// bindOptions(binder) call. Presumably the binder then keeps references to the
// sibling fields of *this* object, so that parsed flags are written straight
// into them (NOTE(review): confirm against OptionsBinder). A copied or moved
// CompilerOptions would still have its binder aimed at the original object's
// fields. Copying is therefore disabled, and deleting the copy constructor also
// suppresses the implicit move.
//
// In this file, instances are created with `new` in ireeCompilerOptionsCreate
// and are only handled by pointer through the C API afterwards.
struct CompilerOptions {
  BindingOptions bindingOptions;
  InputDialectOptions inputDialectOptions;
  HighLevelOptimizationOptions highLevelOptimizationOptions;
  SchedulingOptions schedulingOptions;
  HALTargetOptions halTargetOptions;
  VMTargetOptions vmTargetOptions;
  VMBytecodeTargetOptions vmBytecodeTargetOptions;
  OptionsBinder binder;

  CompilerOptions() : binder(OptionsBinder::local()) {
    bindingOptions.bindOptions(binder);
    inputDialectOptions.bindOptions(binder);
    highLevelOptimizationOptions.bindOptions(binder);
    schedulingOptions.bindOptions(binder);
    halTargetOptions.bindOptions(binder);
    vmTargetOptions.bindOptions(binder);
    vmBytecodeTargetOptions.bindOptions(binder);
  }

  // Non-copyable and non-movable: see the note about `binder` above.
  CompilerOptions(const CompilerOptions &) = delete;
  CompilerOptions &operator=(const CompilerOptions &) = delete;
};
} // namespace
// Defines the wrap()/unwrap() conversions used throughout this file between the
// opaque C handle IreeCompilerOptions and the C++ CompilerOptions struct.
DEFINE_C_API_PTR_METHODS(IreeCompilerOptions, CompilerOptions)
// Builds a registry containing the dialects from
// iree_compiler::registerAllDialects and the LLVM IR translations from
// iree_compiler::registerLLVMIRTranslations. The registry is then appended to
// the given context.
void ireeCompilerRegisterAllDialects(MlirContext context) {
  MLIRContext *ctx = unwrap(context);

  DialectRegistry dialects;
  mlir::iree_compiler::registerAllDialects(dialects);
  mlir::iree_compiler::registerLLVMIRTranslations(dialects);

  ctx->appendDialectRegistry(dialects);
}
// C entry point that forwards to registerAllPasses() from
// iree/tools/init_passes.h.
void ireeCompilerRegisterAllPasses() {
  registerAllPasses();
}
// C entry point that forwards to registerHALTargetBackends() from
// iree/tools/init_targets.h.
void ireeCompilerRegisterTargetBackends() {
  registerHALTargetBackends();
}
// Heap-allocates a default-constructed CompilerOptions and turns on the VM
// f32 extension. The caller owns the result and releases it with
// ireeCompilerOptionsDestroy.
IreeCompilerOptions ireeCompilerOptionsCreate() {
  CompilerOptions *created = new CompilerOptions();
  // TODO: Make configurable.
  created->vmTargetOptions.f32Extension = true;
  return wrap(created);
}
// Parses command-line-style flags (argv[0 .. argc)) into `options` through its
// OptionsBinder. Each error message the binder reports is passed to
// `onError(message, userData)`, but only when `onError` is non-null. Returns
// failure if the binder fails to parse the arguments.
MlirLogicalResult ireeCompilerOptionsSetFlags(
    IreeCompilerOptions options, int argc, const char *const *argv,
    void (*onError)(MlirStringRef, void *), void *userData) {
  OptionsBinder &binder = unwrap(options)->binder;
  auto reportError = [&](llvm::StringRef message) {
    if (!onError)
      return;
    onError(wrap(message), userData);
  };
  return succeeded(binder.parseArguments(argc, argv, reportError))
             ? mlirLogicalResultSuccess()
             : mlirLogicalResultFailure();
}
// Reports the flags held in `options` by calling `onFlag(flag, userData)` once
// for each string returned by OptionsBinder::printArguments(nonDefaultOnly).
// Presumably `nonDefaultOnly` limits the output to flags changed from their
// defaults; TODO confirm against OptionsBinder.
//
// A null `onFlag` makes this a no-op. This matches the null check that
// ireeCompilerOptionsSetFlags applies to its `onError` callback; previously a
// null `onFlag` was called directly, which is undefined behaviour.
void ireeCompilerOptionsGetFlags(IreeCompilerOptions options,
                                 bool nonDefaultOnly,
                                 void (*onFlag)(MlirStringRef, void *),
                                 void *userData) {
  if (!onFlag)
    return;
  auto flagVector = unwrap(options)->binder.printArguments(nonDefaultOnly);
  // The strings are only read, so bind them by const reference.
  for (const std::string &value : flagVector) {
    onFlag(wrap(llvm::StringRef(value)), userData);
  }
}
// Frees an options object that was allocated by ireeCompilerOptionsCreate.
void ireeCompilerOptionsDestroy(IreeCompilerOptions options) {
  CompilerOptions *owned = unwrap(options);
  delete owned;
}
// Appends a copy of the NUL-terminated `targetBackend` string to the HAL
// target options' `targets` list.
void ireeCompilerOptionsAddTargetBackend(IreeCompilerOptions options,
                                         const char *targetBackend) {
  auto &backendList = unwrap(options)->halTargetOptions.targets;
  backendList.emplace_back(targetBackend);
}
// Marks the input as MHLO. The setting is stored in `inputDialectOptions`,
// which ireeCompilerBuildIREEVMPassPipeline reads.
void ireeCompilerOptionsSetInputDialectMHLO(IreeCompilerOptions options) {
  InputDialectOptions &inputOptions = unwrap(options)->inputDialectOptions;
  inputOptions.type = InputDialectOptions::Type::mhlo;
}
// Marks the input as TOSA. The setting is stored in `inputDialectOptions`,
// which ireeCompilerBuildIREEVMPassPipeline reads.
void ireeCompilerOptionsSetInputDialectTOSA(IreeCompilerOptions options) {
  InputDialectOptions &inputOptions = unwrap(options)->inputDialectOptions;
  inputOptions.type = InputDialectOptions::Type::tosa;
}
// Marks the input as XLA. The setting is stored in `inputDialectOptions`,
// which ireeCompilerBuildIREEVMPassPipeline reads.
void ireeCompilerOptionsSetInputDialectXLA(IreeCompilerOptions options) {
  InputDialectOptions &inputOptions = unwrap(options)->inputDialectOptions;
  inputOptions.type = InputDialectOptions::Type::xla;
}
// Appends the passes from MHLO::buildXLACleanupPassPipeline to `passManager`.
void ireeCompilerBuildXLACleanupPassPipeline(MlirOpPassManager passManager) {
  OpPassManager &pm = *unwrap(passManager);
  MHLO::buildXLACleanupPassPipeline(pm);
}
// Appends the passes from MHLO::buildMHLOInputConversionPassPipeline to
// `passManager`.
void ireeCompilerBuildMHLOImportPassPipeline(MlirOpPassManager passManager) {
  OpPassManager &pm = *unwrap(passManager);
  MHLO::buildMHLOInputConversionPassPipeline(pm);
}
// Appends the passes from buildTOSAInputConversionPassPipeline to
// `passManager`.
void ireeCompilerBuildTOSAImportPassPipeline(MlirOpPassManager passManager) {
  OpPassManager &pm = *unwrap(passManager);
  buildTOSAInputConversionPassPipeline(pm);
}
// Adds the IREE VM transformation pipeline (buildIREEVMTransformPassPipeline)
// to `passManager`, configured from the option groups held in `options`.
// Constant evaluation is wired in through a pipeline hook that schedules
// ConstEval's JitGlobals pass.
void ireeCompilerBuildIREEVMPassPipeline(IreeCompilerOptions options,
                                         MlirOpPassManager passManager) {
  CompilerOptions &opts = *unwrap(options);
  OpPassManager &pm = *unwrap(passManager);

  auto addJitGlobals = [](OpPassManager &hookPm) {
    hookPm.addPass(ConstEval::createJitGlobalsPass());
  };
  // Aggregate-initializes the first hook field,
  // buildConstEvalPassPipelineCallback.
  IREEVMPipelineHooks hooks = {addJitGlobals};

  buildIREEVMTransformPassPipeline(
      opts.bindingOptions, opts.inputDialectOptions,
      opts.highLevelOptimizationOptions, opts.schedulingOptions,
      opts.halTargetOptions, opts.vmTargetOptions, hooks, pm);
}
// Serializes the result of ireeCompilerBuildIREEVMPassPipeline to VM bytecode.
// The serialized bytes are streamed to `dataCallback(chunk, dataUserObject)`.
// `moduleOp` may be an outer builtin ModuleOp that wraps a VM::ModuleOp, or a
// VM::ModuleOp directly. Any other op type gets an error diagnostic at its
// location and the function returns failure.
MlirLogicalResult ireeCompilerTranslateModuletoVMBytecode(
    IreeCompilerOptions options, MlirOperation moduleOp,
    MlirStringCallback dataCallback, void *dataUserObject) {
  auto &bytecodeOptions = unwrap(options)->vmBytecodeTargetOptions;
  Operation *op = unwrap(moduleOp);
  mlir::detail::CallbackOstream output(dataCallback, dataUserObject);

  if (auto builtinModule = llvm::dyn_cast<mlir::ModuleOp>(op)) {
    return wrap(iree_compiler::IREE::VM::translateModuleToBytecode(
        builtinModule, bytecodeOptions, output));
  }
  if (auto vmModule = llvm::dyn_cast<iree_compiler::IREE::VM::ModuleOp>(op)) {
    return wrap(iree_compiler::IREE::VM::translateModuleToBytecode(
        vmModule, bytecodeOptions, output));
  }

  emitError(op->getLoc()) << "expected a supported module operation";
  return wrap(failure());
}