Generate header-only style libraries with EmitC (#7929)
This basically implements the proposal in #7605. The library identified
by `NAME` now contains the implementation whereas a the headers without
implementation are provided by the suffixed target `NAME`_hdrs.
diff --git a/build_tools/cmake/iree_c_module.cmake b/build_tools/cmake/iree_c_module.cmake
index b1571ff..af83a8b 100644
--- a/build_tools/cmake/iree_c_module.cmake
+++ b/build_tools/cmake/iree_c_module.cmake
@@ -41,7 +41,7 @@
# Prefix the library with the package name, so we get: iree_package_name.
iree_package_name(_PACKAGE_NAME)
- set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
+ set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}_hdrs")
# Set defaults for TRANSLATE_TOOL.
if(DEFINED _RULE_TRANSLATE_TOOL)
@@ -65,6 +65,15 @@
DEPENDS ${_TRANSLATE_TOOL_EXECUTABLE} ${_RULE_SRC}
)
+ iree_cc_library(
+ NAME ${_RULE_NAME}
+ HDRS "${_RULE_H_FILE_OUTPUT}"
+ SRCS "${IREE_SOURCE_DIR}/iree/vm/module_impl_emitc.c"
+ INCLUDES "${CMAKE_CURRENT_BINARY_DIR}"
+ COPTS "-DEMITC_IMPLEMENTATION=\"${_RULE_H_FILE_OUTPUT}\""
+ "${_TESTONLY_ARG}"
+ )
+
set(_GEN_TARGET "${_NAME}_gen")
add_custom_target(
${_GEN_TARGET}
@@ -84,9 +93,9 @@
# Alias the iree_package_name library to iree::package::name.
# This lets us more clearly map to Bazel and makes it possible to
# disambiguate the underscores in paths vs. the separators.
- add_library(${_PACKAGE_NS}::${_RULE_NAME} ALIAS ${_NAME})
+ add_library(${_PACKAGE_NS}::${_RULE_NAME}_hdrs ALIAS ${_NAME})
iree_package_dir(_PACKAGE_DIR)
- if(${_RULE_NAME} STREQUAL ${_PACKAGE_DIR})
+ if(${_RULE_NAME}_hdrs STREQUAL ${_PACKAGE_DIR})
# If the library name matches the package then treat it as a default.
# For example, foo/bar/ library 'bar' would end up as 'foo::bar'.
add_library(${_PACKAGE_NS} ALIAS ${_NAME})
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 155a2b4..84d435b 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -1195,13 +1195,15 @@
typeConverter.analysisCache.insert(
std::make_pair(funcOp.getOperation(), VMAnalysis()));
- funcOp.getOperation()->setAttr("emitc.static", UnitAttr::get(ctx));
-
// This function needs an iree_vm_native_module_descriptor_t that is emitted
// by the CModuleTarget at the moment. So we add a marker to this function
// and delay the printing of it.
funcOp.getOperation()->setAttr("vm.emit_at_end", UnitAttr::get(ctx));
+ // This functions is the only one users need and it is therefore declared
+ // separatly from all other functions.
+ funcOp.getOperation()->setAttr("vm.module.constructor", UnitAttr::get(ctx));
+
Block *entryBlock = funcOp.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
index 0bd1ef4..52123b0 100644
--- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -38,6 +38,28 @@
<< std::string(77, '=') << "\n";
}
+static LogicalResult printFunctionDeclaration(
+ mlir::FuncOp funcOp, llvm::raw_ostream &output,
+ mlir::emitc::CppEmitter &emitter) {
+ Operation *op = funcOp.getOperation();
+ if (op->hasAttr("emitc.static")) output << "static ";
+
+ if (failed(emitter.emitTypes(funcOp.getLoc(), funcOp.getType().getResults())))
+ return failure();
+ output << " " << funcOp.getName();
+
+ output << "(";
+
+ bool error = false;
+ llvm::interleaveComma(funcOp.getArguments(), output, [&](BlockArgument arg) {
+ if (failed(emitter.emitType(funcOp.getLoc(), arg.getType()))) error = true;
+ });
+ if (error) return failure();
+ output << ");\n";
+
+ return success();
+}
+
static LogicalResult printRodataBuffers(IREE::VM::ModuleOp &moduleOp,
mlir::emitc::CppEmitter &emitter) {
llvm::raw_ostream &output = emitter.ostream();
@@ -135,7 +157,7 @@
<< "[] = {\n";
if (exportOps.empty()) {
// Empty list placeholder.
- output << " {0},\n";
+ output << " {{0}},\n";
} else {
// sort export ops
llvm::sort(exportOps, [](auto &lhs, auto &rhs) {
@@ -322,11 +344,39 @@
return success();
}
+ std::string includeGuard = moduleOp.getName().upper();
+ output << "#ifndef " << includeGuard << "_H_\n";
+ output << "#define " << includeGuard << "_H_\n";
+
auto printInclude = [&output](std::string include) {
output << "#include \"" << include << "\"\n";
};
printInclude("iree/vm/api.h");
+ output << "\n";
+
+ output << "#ifdef __cplusplus\n";
+ output << "extern \"C\" {\n";
+ output << "#endif // __cplusplus\n";
+ output << "\n";
+
+ mlir::emitc::CppEmitter emitter(output, /*declareVariablesAtTop=*/true);
+ for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
+ Operation *op = funcOp.getOperation();
+ if (!op->hasAttr("vm.module.constructor")) continue;
+ if (failed(printFunctionDeclaration(funcOp, output, emitter)))
+ return failure();
+ }
+
+ output << "\n";
+ output << "#ifdef __cplusplus\n";
+ output << "} // extern \"C\"\n";
+ output << "#endif // __cplusplus\n";
+ output << "\n";
+
+ output << "#endif // " << includeGuard << "_H_\n\n";
+ output << "#if defined(EMITC_IMPLEMENTATION)\n";
+
printInclude("iree/vm/ops.h");
printInclude("iree/vm/ops_emitc.h");
printInclude("iree/vm/shims_emitc.h");
@@ -338,7 +388,6 @@
printModuleComment(moduleOp, output);
output << "\n";
- mlir::emitc::CppEmitter emitter(output, /*declareVariablesAtTop=*/true);
mlir::emitc::CppEmitter::Scope scope(emitter);
if (failed(printRodataBuffers(moduleOp, emitter))) {
@@ -355,23 +404,9 @@
for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
Operation *op = funcOp.getOperation();
- if (op->hasAttr("emitc.static")) output << "static ";
-
- if (failed(
- emitter.emitTypes(funcOp.getLoc(), funcOp.getType().getResults())))
+ if (op->hasAttr("vm.module.constructor")) continue;
+ if (failed(printFunctionDeclaration(funcOp, output, emitter)))
return failure();
- output << " " << funcOp.getName();
-
- output << "(";
-
- bool error = false;
- llvm::interleaveComma(
- funcOp.getArguments(), output, [&](BlockArgument arg) {
- if (failed(emitter.emitType(funcOp.getLoc(), arg.getType())))
- error = true;
- });
- if (error) return failure();
- output << ");\n";
}
output << "// DEFINE FUNCTIONS\n";
@@ -406,6 +441,7 @@
return failure();
}
+ output << "#endif // EMITC_IMPLEMENTATION\n";
return success();
}
diff --git a/iree/samples/emitc_modules/CMakeLists.txt b/iree/samples/emitc_modules/CMakeLists.txt
index be6aaab..9f0569e 100644
--- a/iree/samples/emitc_modules/CMakeLists.txt
+++ b/iree/samples/emitc_modules/CMakeLists.txt
@@ -62,7 +62,7 @@
iree_c_module(
NAME
- import_module_b
+ import_module_b
SRC
"import_module_b.mlir"
H_FILE_OUTPUT
diff --git a/iree/samples/static_library/CMakeLists.txt b/iree/samples/static_library/CMakeLists.txt
index d6cc26b..d7fe192 100644
--- a/iree/samples/static_library/CMakeLists.txt
+++ b/iree/samples/static_library/CMakeLists.txt
@@ -141,19 +141,28 @@
)
# TODO(marbre): Cleanup SRCS and DEPS.
+iree_cc_library(
+ NAME
+ simple_mul_emitc
+ HDRS
+ "simple_mul_emitc.h"
+ DEFINES
+ "EMITC_IMPLEMENTATION"
+)
+
iree_cc_binary(
-NAME
- static_library_demo_c
-SRCS
- "create_c_module.c"
- "static_library_demo.c"
- "simple_mul_emitc.h"
-DEPS
- iree::runtime
- iree::hal::local::loaders::static_library_loader
- iree::hal::local::sync_driver
- iree::vm::shims_emitc
- simple_mul_c_module
+ NAME
+ static_library_demo_c
+ SRCS
+ "create_c_module.c"
+ "static_library_demo.c"
+ DEPS
+ ::simple_mul_emitc
+ iree::runtime
+ iree::hal::local::loaders::static_library_loader
+ iree::hal::local::sync_driver
+ iree::vm::shims_emitc
+ simple_mul_c_module
)
iree_lit_test(
diff --git a/iree/vm/module_impl_emitc.c b/iree/vm/module_impl_emitc.c
new file mode 100644
index 0000000..c03694c
--- /dev/null
+++ b/iree/vm/module_impl_emitc.c
@@ -0,0 +1,7 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include EMITC_IMPLEMENTATION
diff --git a/iree/vm/test/emitc/module_test.cc b/iree/vm/test/emitc/module_test.cc
index 89dfebb..eee8d61 100644
--- a/iree/vm/test/emitc/module_test.cc
+++ b/iree/vm/test/emitc/module_test.cc
@@ -15,6 +15,7 @@
#include "iree/base/status_cc.h"
#include "iree/testing/gtest.h"
#include "iree/vm/api.h"
+#define EMITC_IMPLEMENTATION
#include "iree/vm/test/emitc/arithmetic_ops.h"
#include "iree/vm/test/emitc/arithmetic_ops_f32.h"
#include "iree/vm/test/emitc/arithmetic_ops_i64.h"