blob: 8fe916bfaa38914da0cd593aa69376a41d4770ca [file] [log] [blame]
# springbok_vmvx_module()
#
# A modified version of iree_vmvx_linker_test to apply common iree-compile flags
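# Parameters:
# NAME: Name of target.
# SRC: Source file to compile into a bytecode module. Support relative path.
# A source whose name contains ".tflite" is first imported to MLIR with
# iree-import-tflite.
# C_IDENTIFIER: C identifier passed to iree_bytecode_module for the embedded
# data. Defaults to "<package name>_<NAME>". Not used when EMITC is set.
# FLAGS: Flags to pass to the translation tool (list of strings).
# EMITC: Uses EmitC to output C code instead of VM bytecode.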
#
# Examples:
# springbok_vmvx_module(
# NAME
# daredevil_bytecode_module_vmvx
# SRC
# "daredevil_quant.tflite"
# C_IDENTIFIER
# "daredevil_bytecode_module_vmvx"
# FLAGS
# "-iree-input-type=tosa"
# )
#
# springbok_vmvx_module(
# NAME
# simple_float_mul_c_module_vmvx
# SRC
# "simple_float_mul.mlir"
# C_IDENTIFIER
# "simple_float_mul"
# FLAGS
# "-iree-input-type=mhlo"
# EMITC
# )
#
function(springbok_vmvx_module)
  # Compiles SRC for the vmvx HAL backend. SRC is either an MLIR file or a
  # .tflite file, which is first imported to MLIR in the current binary dir.
  # With EMITC the result is produced through iree_c_module; otherwise through
  # iree_bytecode_module with embedded C data named by C_IDENTIFIER.
  cmake_parse_arguments(
    _RULE
    "EMITC"
    "NAME;SRC;C_IDENTIFIER"
    "FLAGS"
    ${ARGN}
  )

  # NAME and SRC feed every path and target name below; fail early with a
  # clear message instead of generating rules for ".mlir" or an empty source.
  if("${_RULE_NAME}" STREQUAL "")
    message(FATAL_ERROR "springbok_vmvx_module: NAME is required")
  endif()
  if("${_RULE_SRC}" STREQUAL "")
    message(FATAL_ERROR "springbok_vmvx_module: SRC is required")
  endif()

  set(_MLIR_SRC "${_RULE_SRC}")

  # Only sources whose name ends in ".tflite" go through the TFLite importer.
  if("${_RULE_SRC}" MATCHES "\\.tflite$")
    iree_get_executable_path(IREE_IMPORT_TFLITE_TOOL "iree-import-tflite")
    set(_MLIR_SRC "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.mlir")
    get_filename_component(_SRC_PATH "${_RULE_SRC}" REALPATH)

    # Only add the custom_command here. The output is passed to
    # iree_c_module / iree_bytecode_module as the source.
    # The absolute output path is used for both OUTPUT and -o so the rule does
    # not depend on the command's working directory. The .tflite file is
    # listed in DEPENDS so that editing the model re-runs the import.
    add_custom_command(
      OUTPUT
        "${_MLIR_SRC}"
      COMMAND
        ${IREE_IMPORT_TFLITE_TOOL}
        "${_SRC_PATH}"
        "--output-format=mlir-ir"
        "-o"
        "${_MLIR_SRC}"
      DEPENDS
        ${IREE_IMPORT_TFLITE_TOOL}
        "${_SRC_PATH}"
      VERBATIM
    )
  endif()

  iree_get_executable_path(_COMPILER_TOOL "iree-compile")
  iree_package_name(_PACKAGE_NAME)
  # NOTE(review): _PACKAGE_NS is not read below; the call is kept unchanged
  # because iree_package_ns is defined elsewhere - confirm and remove.
  iree_package_ns(_PACKAGE_NS)

  # Set common iree-compile flags: caller FLAGS first, then the vmvx backend.
  set(_COMPILER_ARGS ${_RULE_FLAGS})
  list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=vmvx")

  if(_RULE_EMITC)
    # EmitC path: C module named "<NAME>_emitc" with header "<NAME>_emitc.h".
    list(APPEND _COMPILER_ARGS "--iree-vm-target-index-bits=32")
    set(_MODULE_NAME "${_RULE_NAME}_emitc")
    set(_H_FILE_NAME "${_RULE_NAME}_emitc.h")
    iree_c_module(
      NAME
        ${_MODULE_NAME}
      SRC
        "${_MLIR_SRC}"
      FLAGS
        ${_COMPILER_ARGS}
      H_FILE_OUTPUT
        "${_H_FILE_NAME}"
      NO_RUNTIME
    )
  else()
    # Bytecode module path: generate the embed data with the bytecode module.
    set(_MODULE_NAME "${_RULE_NAME}")
    # Default the C identifier to "<package name>_<NAME>" when not given.
    if(NOT _RULE_C_IDENTIFIER)
      set(_RULE_C_IDENTIFIER "${_PACKAGE_NAME}_${_RULE_NAME}")
    endif()
    iree_bytecode_module(
      NAME
        ${_MODULE_NAME}
      SRC
        "${_MLIR_SRC}"
      FLAGS
        ${_COMPILER_ARGS}
      C_IDENTIFIER
        "${_RULE_C_IDENTIFIER}"
      PUBLIC
    )
  endif()
endfunction()