blob: 6f8f95a15a9e19952f5145bea239a02e237c31f2 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Simple unit test that demonstrates compiling from MHLO using the CAPI.
// There is room for improvement on the high level APIs, and if some of what
// is here is extracted into new APIs, please simplify this test accordingly.
//
// Originally contributed due to the work of edubart who figured out how to
// be the first user of the combined MLIR+IREE CAPI:
// https://github.com/google/iree/pull/8582
#include <iree-compiler-c/Compiler.h>
#include <iree/base/string_builder.h>
#include <stdio.h>
// Output callback handed to the bytecode translation step: forwards each
// emitted chunk of bytes into the string builder passed through |userdata|.
// The result of the append call is not inspected here.
static void bytecode_builder_callback(MlirStringRef str, void* userdata) {
  iree_string_builder_append_string(
      (iree_string_builder_t*)userdata,
      iree_make_string_view(str.data, str.length));
}
// Compiles MLIR code into VM bytecode for the given target backend.
//
// |mlir_source| is parsed as a textual MLIR module, run through the MHLO
// import pipeline followed by the IREE VM pipeline, and the resulting bytecode
// is appended to |out_builder| via bytecode_builder_callback.
//
// |target_backend| is passed to the compiler as the value of the
// `--iree-hal-target-backends=` flag.
//
// |out_builder| must be non-NULL and is expected to already be initialized.
//
// Returns true if every stage succeeded and false otherwise. All MLIR/IREE
// compiler objects created here are destroyed before returning, on both the
// success and the failure paths.
static bool iree_compile_mlir_to_bytecode(iree_string_view_t mlir_source,
                                          iree_string_view_t target_backend,
                                          iree_string_builder_t* out_builder) {
  // TODO: support customizing compiling flags?
  // TODO: support enabling different input dialects other than MHLO?
  // TODO: return IREE status with error information instead of a boolean?
  // TODO: only call registers once to speedup second calls?
  // TODO: cache MLIR context, pass manager to speedup second calls?

  // Expects string builder to be initialized.
  if (out_builder == NULL) {
    return false;
  }

  // Single result flag; every exit after this point goes through the cleanup
  // labels at the bottom so that no resource is leaked.
  bool succeeded = false;

  // The IREE source code states that this function should be called before
  // creating any MLIRContext if one expects all the possible target backends
  // to be available.
  ireeCompilerRegisterTargetBackends();
  // Register passes that may be required in the lowering pipeline.
  ireeCompilerRegisterAllPasses();

  // Create MLIR context.
  MlirContext context = mlirContextCreate();
  // Register all IREE dialects and dialects it depends on.
  ireeCompilerRegisterAllDialects(context);

  // Create MLIR module from a chunk of text.
  MlirModule module = mlirModuleCreateParse(
      context, mlirStringRefCreate(mlir_source.data, mlir_source.size));
  if (mlirModuleIsNull(module)) {
    // Parsing failed: there is no module to destroy, but the context created
    // above still has to be released.
    goto destroy_context;
  }

  // Prepare target backend flag.
  // NOTE(review): the flag is assembled in a fixed 128-byte buffer and the
  // results of the append calls are not checked; confirm how an over-long
  // |target_backend| is handled by the string builder.
  char target_buf[128];
  iree_string_builder_t target_builder;
  iree_string_builder_initialize_with_storage(target_buf, sizeof(target_buf),
                                              &target_builder);
  iree_string_builder_append_cstring(&target_builder,
                                     "--iree-hal-target-backends=");
  iree_string_builder_append_string(&target_builder, target_backend);

  // Create compiler options.
  IreeCompilerOptions options = ireeCompilerOptionsCreate();
  const char* compiler_flags[] = {iree_string_builder_buffer(&target_builder)};
  MlirLogicalResult status =
      ireeCompilerOptionsSetFlags(options, 1, compiler_flags, NULL, NULL);
  if (mlirLogicalResultIsFailure(status)) {
    goto destroy_options;
  }

  // Run MLIR pass pipeline to lower the high level MLIR code down to IREE VM
  // MLIR code.
  MlirPassManager pass = mlirPassManagerCreate(context);
  MlirOpPassManager op_pass = mlirPassManagerGetAsOpPassManager(pass);
  // Enable use of MHLO dialect.
  ireeCompilerBuildMHLOImportPassPipeline(op_pass);
  ireeCompilerBuildIREEVMPassPipeline(options, op_pass);
  status = mlirPassManagerRun(pass, module);
  if (mlirLogicalResultIsFailure(status)) {
    goto destroy_pass_manager;
  }

  // Compile MLIR VM code to VM bytecode.
  status = ireeCompilerTranslateModuletoVMBytecode(
      options, mlirModuleGetOperation(module), bytecode_builder_callback,
      out_builder);
  succeeded = mlirLogicalResultIsSuccess(status);

  // Cleanups, in reverse order of creation. Each label releases everything
  // that existed at the point of the corresponding jump.
destroy_pass_manager:
  mlirPassManagerDestroy(pass);
destroy_options:
  ireeCompilerOptionsDestroy(options);
  mlirModuleDestroy(module);
destroy_context:
  mlirContextDestroy(context);
  return succeeded;
}
// Test entry point: compiles a small MHLO element-wise multiply function for
// the "dylib" target backend and prints the size of the generated bytecode.
//
// Returns 0 on success. On compilation failure a message is written to stderr
// and -1 is returned.
int main(int argc, char** argv) {
  // Command line arguments are not used by this test.
  (void)argc;
  (void)argv;

  // MLIR code that we will compile.
  iree_string_view_t mlir_code = iree_make_cstring_view(
      "func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> "
      "tensor<4xf32>\n"
      " {\n"
      "  %0 = \"mhlo.multiply\"(%arg0, %arg1) : "
      "(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n"
      "  return %0 : tensor<4xf32>\n"
      "}\n");

  // Initializes the string builder that will contain the output bytecode.
  iree_string_builder_t bytecode_builder;
  iree_string_builder_initialize(iree_allocator_system(), &bytecode_builder);

  // Compiles MLIR to VM bytecode.
  bool status = iree_compile_mlir_to_bytecode(
      mlir_code, iree_make_cstring_view("dylib"), &bytecode_builder);
  if (!status) {
    iree_string_builder_deinitialize(&bytecode_builder);
    fprintf(stderr, "failed to compile MLIR code\n");
    return -1;
  }

  // For testing purposes, just print the length vs the full contents.
  // Printed through %zu so the size is not narrowed to int first.
  iree_string_view_t bytecode = iree_string_builder_view(&bytecode_builder);
  printf("GENERATED VMFB SIZE: %zu\n", (size_t)bytecode.size);

  // Cleanups.
  iree_string_builder_deinitialize(&bytecode_builder);
  return 0;
}