Add support for VMVX backend Add support for VMVX backend for both bytecode and emitc. Buid VMVX targets in simple_vec_mul for illustration. Note: IREE currently doesn't well support VMVX backend (LLVM is ok though) for TOSA imported TFLite models: TOSA IR can't be compiled to VMFB. (https://github.com/iree-org/iree/issues/10253) Change-Id: I69dc84ce848f75e031d68ec21fc03cc3484efd58
diff --git a/CMakeLists.txt b/CMakeLists.txt index d44e160..7f6b663 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt
@@ -82,6 +82,7 @@ include(springbok_ops) include(springbok_static_module) +include(springbok_vmvx_module) include(springbok_modules) include(iree_model_input) include(springbok_test)
diff --git a/cmake/springbok_modules.cmake b/cmake/springbok_modules.cmake index 1bb68ec..5f511c7 100644 --- a/cmake/springbok_modules.cmake +++ b/cmake/springbok_modules.cmake
@@ -1,5 +1,3 @@ -include(CMakeParseArguments) - # springbok_modules() # # A wrapper for the springbok_bytecode_module and springbok_c_module to apply common iree-compile flags @@ -10,6 +8,7 @@ # C_IDENTIFIER: Identifier to use for generate c embed code. # If omitted then no C embed code will be generated. # RVV_OFF: Indicate RVV is OFF (default: ON) +# VMVX: Compile VMVX backend # # Examples: # springbok_modules( @@ -40,7 +39,7 @@ function(springbok_modules) cmake_parse_arguments( _RULE - "PUBLIC;RVV_OFF" + "PUBLIC;RVV_OFF;VMVX" "NAME;SRC;C_IDENTIFIER" "FLAGS" ${ARGN} @@ -73,4 +72,27 @@ EMITC ) + if (${_RULE_VMVX}) + springbok_vmvx_module( + NAME + "${_RULE_NAME}_bytecode_module_vmvx" + SRC + "${_RULE_SRC}" + C_IDENTIFIER + "${_RULE_C_IDENTIFIER}_bytecode_module_vmvx" + FLAGS + "${_RULE_FLAGS}" + ) + + springbok_vmvx_module( + NAME + "${_RULE_NAME}_c_module_vmvx" + SRC + "${_RULE_SRC}" + FLAGS + "${_RULE_FLAGS}" + EMITC + ) + endif() + endfunction()
diff --git a/cmake/springbok_static_module.cmake b/cmake/springbok_static_module.cmake index 52ca1d6..0853335 100644 --- a/cmake/springbok_static_module.cmake +++ b/cmake/springbok_static_module.cmake
@@ -1,5 +1,3 @@ -include(CMakeParseArguments) - # springbok_static_module() # # A modified version of iree_static_linker_test to apply common iree-compile flags
diff --git a/cmake/springbok_vmvx_module.cmake b/cmake/springbok_vmvx_module.cmake new file mode 100644 index 0000000..72e8bde --- /dev/null +++ b/cmake/springbok_vmvx_module.cmake
@@ -0,0 +1,105 @@ +# springbok_vmvx_module() +# +# A modified version of iree_vmvx_linker_test to apply common iree-compile flags +# Parameters: +# NAME: Name of target. +# SRC: Source file to compile into a bytecode module. Support relative path. +# FLAGS: Flags to pass to the translation tool (list of strings). +# EMITC: Uses EmitC to output C code instead of VM bytecode. +# +# Examples: +# springbok_vmvx_module( +# NAME +# daredevel_bytecode_module_vmvx +# SRC +# "daredevil_quant.tflite" +# C_IDENTIFIER +# "daredevil_bytecode_module_vmvx" +# FLAGS +# "-iree-input-type=tosa" +# ) +# +# springbok_vmvx_module( +# NAME +# simple_float_mul_c_module_vmvx +# SRC +# "simple_float_mul.mlir" +# C_IDENTIFIER +# "simple_float_mul" +# FLAGS +# "-iree-input-type=mhlo" +# EMITC +# ) +# +function(springbok_vmvx_module) + cmake_parse_arguments( + _RULE + "EMITC" + "NAME;SRC;C_IDENTIFIER" + "FLAGS" + ${ARGN} + ) + + set(_MLIR_SRC "${_RULE_SRC}") + string(FIND "${_RULE_SRC}" ".tflite" _IS_TFLITE REVERSE) + if(${_IS_TFLITE} GREATER 0) + iree_get_executable_path(IREE_IMPORT_TFLITE_TOOL "iree-import-tflite") + set(_MLIR_SRC "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.mlir") + get_filename_component(_SRC_PATH "${_RULE_SRC}" REALPATH) + set(_ARGS "${_SRC_PATH}") + list(APPEND _ARGS "-o") + list(APPEND _ARGS "${_RULE_NAME}.mlir") + # Only add the custom_command here. The output is passed to + # iree_bytecode_module as the source. + add_custom_command( + OUTPUT + "${_RULE_NAME}.mlir" + COMMAND + ${IREE_IMPORT_TFLITE_TOOL} + ${_ARGS} + DEPENDS + ${IREE_IMPORT_TFLITE_TOOL} + ) + endif() + + iree_get_executable_path(_COMPILER_TOOL "iree-compile") + iree_package_name(_PACKAGE_NAME) + iree_package_ns(_PACKAGE_NS) + + # Set common iree-compile flags + set(_COMPILER_ARGS ${_RULE_FLAGS}) + list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=vmvx") + + if(_RULE_EMITC) + set(_MODULE_NAME "${_RULE_NAME}_emitc") + set(_H_FILE_NAME "${_RULE_NAME}_emitc.h") + iree_c_module( + NAME + ${_MODULE_NAME} + SRC + "${_MLIR_SRC}" + FLAGS + ${_COMPILER_ARGS} + H_FILE_OUTPUT + "${_H_FILE_NAME}" + NO_RUNTIME + ) + else() # bytecode module path + # Generate the embed data with the bytecode module. + set(_MODULE_NAME "${_RULE_NAME}") + if(NOT _RULE_C_IDENTIFIER) + set(_RULE_C_IDENTIFIER "${_PACKAGE_NAME}_${_RULE_NAME}") + endif() + iree_bytecode_module( + NAME + ${_MODULE_NAME} + SRC + "${_MLIR_SRC}" + FLAGS + ${_COMPILER_ARGS} + C_IDENTIFIER + "${_RULE_C_IDENTIFIER}" + PUBLIC + ) + endif(_RULE_EMITC) +endfunction()
diff --git a/device/CMakeLists.txt b/device/CMakeLists.txt index c3961b9..d7e527a 100644 --- a/device/CMakeLists.txt +++ b/device/CMakeLists.txt
@@ -12,3 +12,18 @@ iree::hal::local iree::hal::local::loaders::static_library_loader ) + +iree_cc_library( + NAME + device_vmvx_loader + HDRS + "device.h" + SRCS + "device_vmvx_loader.c" + DEPS + iree::base + iree::hal + iree::hal::drivers::local_sync::sync_driver + iree::hal::local + iree::hal::local::loaders::vmvx_module_loader +)
diff --git a/device/device_vmvx_loader.c b/device/device_vmvx_loader.c new file mode 100644 index 0000000..e1d89a8 --- /dev/null +++ b/device/device_vmvx_loader.c
@@ -0,0 +1,63 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// VMVX module loading in IREE. + +#include "device/device.h" +#include "iree/hal/drivers/local_sync/sync_device.h" +#include "iree/hal/local/loaders/vmvx_module_loader.h" +#include "iree/modules/hal/module.h" +#include "model_util/model_api.h" + +// A function to create the HAL device from the different backend targets. +// The HAL device is returned based on the implementation, and it must be +// released by the caller. +iree_status_t create_sample_device(iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + // Set parameters for the device created in the next step. + iree_hal_sync_device_params_t params; + iree_hal_sync_device_params_initialize(¶ms); + + iree_vm_instance_t* instance = NULL; + iree_status_t status = iree_vm_instance_create(host_allocator, &instance); + + iree_hal_executable_loader_t* loader = NULL; + if (iree_status_is_ok(status)) { + status = iree_hal_vmvx_module_loader_create( + instance, /*user_module_count=*/0, /*user_modules=*/NULL, + host_allocator, &loader); + } + iree_vm_instance_release(instance); + + // Use the default host allocator for buffer allocations. + iree_string_view_t identifier = iree_make_cstring_view("vmvx"); + iree_hal_allocator_t* device_allocator = NULL; + if (iree_status_is_ok(status)) { + status = iree_hal_allocator_create_heap(identifier, host_allocator, + host_allocator, &device_allocator); + } + + if (iree_status_is_ok(status)) { + // Create the synchronous device. + status = iree_hal_sync_device_create( + identifier, ¶ms, /*loader_count=*/1, &loader, device_allocator, + host_allocator, out_device); + } + + iree_hal_allocator_release(device_allocator); + iree_hal_executable_loader_release(loader); + return status; +}
diff --git a/model_util/CMakeLists.txt b/model_util/CMakeLists.txt index ce3cae0..c14745f 100644 --- a/model_util/CMakeLists.txt +++ b/model_util/CMakeLists.txt
@@ -22,7 +22,14 @@ DEPS ::util_base device::device_static_loader - iree::hal::local::loaders::static_library_loader +) + +iree_cc_library( + NAME + util_vmvx + DEPS + ::util_base + device::device_vmvx_loader ) iree_cc_library(
diff --git a/samples/simple_vec_mul/CMakeLists.txt b/samples/simple_vec_mul/CMakeLists.txt index 6f46eca..9ea573e 100644 --- a/samples/simple_vec_mul/CMakeLists.txt +++ b/samples/simple_vec_mul/CMakeLists.txt
@@ -15,6 +15,7 @@ "-iree-input-type=mhlo" "-riscv-v-vector-bits-min=512" "-riscv-v-fixed-length-vector-lmul-max=8" + VMVX PUBLIC ) @@ -29,6 +30,7 @@ "-iree-input-type=mhlo" "-riscv-v-vector-bits-min=512" "-riscv-v-fixed-length-vector-lmul-max=8" + VMVX PUBLIC ) @@ -46,6 +48,36 @@ iree_cc_binary( NAME + simple_float_vec_mul_bytecode_vmvx + SRCS + "float_vec.c" + DEPS + ::simple_float_mul_bytecode_module_vmvx_c + iree::vm::bytecode_module + model_util::util_vmvx + LINKOPTS + "LINKER:--defsym=__stack_size__=20k" + COPTS + "-DBUILD_VMVX" +) + +iree_cc_binary( + NAME + simple_float_vec_mul_emitc_vmvx + SRCS + "float_vec.c" + DEPS + ::simple_float_mul_c_module_vmvx_emitc + model_util::util_vmvx + LINKOPTS + "LINKER:--defsym=__stack_size__=20k" + COPTS + "-DBUILD_EMITC" + "-DBUILD_VMVX" +) + +iree_cc_binary( + NAME simple_float_vec_mul_bytecode_static SRCS "float_vec.c" @@ -75,6 +107,36 @@ iree_cc_binary( NAME + simple_int_vec_mul_bytecode_vmvx + SRCS + "int_vec.c" + DEPS + ::simple_int_mul_bytecode_module_vmvx_c + iree::vm::bytecode_module + model_util::util_vmvx + LINKOPTS + "LINKER:--defsym=__stack_size__=20k" + COPTS + "-DBUILD_VMVX" +) + +iree_cc_binary( + NAME + simple_int_vec_mul_emitc_vmvx + SRCS + "int_vec.c" + DEPS + ::simple_int_mul_c_module_vmvx_emitc + model_util::util_vmvx + LINKOPTS + "LINKER:--defsym=__stack_size__=20k" + COPTS + "-DBUILD_EMITC" + "-DBUILD_VMVX" +) + +iree_cc_binary( + NAME simple_int_vec_mul_bytecode_static SRCS "int_vec.c"
diff --git a/samples/simple_vec_mul/float_vec.c b/samples/simple_vec_mul/float_vec.c index 90ed92e..42ce5b8 100644 --- a/samples/simple_vec_mul/float_vec.c +++ b/samples/simple_vec_mul/float_vec.c
@@ -21,13 +21,21 @@ #include "model_util/util.h" // Compiled module embedded here to avoid file IO: +#if defined(BUILD_VMVX) +#if !defined(BUILD_EMITC) +#include "samples/simple_vec_mul/simple_float_mul_bytecode_module_vmvx_c.h" +#else +#include "samples/simple_vec_mul/simple_float_mul_c_module_vmvx_emitc.h" +#endif // !defined(BUILD_EMITC) +#else #if !defined(BUILD_EMITC) #include "samples/simple_vec_mul/simple_float_mul_bytecode_module_static.h" #include "samples/simple_vec_mul/simple_float_mul_bytecode_module_static_c.h" #else #include "samples/simple_vec_mul/simple_float_mul_c_module_static_c.h" #include "samples/simple_vec_mul/simple_float_mul_c_module_static_emitc.h" -#endif +#endif // #if !defined(BUILD_EMITC) +#endif // #if defined(BUILD_VMVX) const MlModel kModel = { .num_input = 2, @@ -46,20 +54,27 @@ iree_status_t create_module(iree_vm_instance_t *instance, iree_vm_module_t **module) { #if !defined(BUILD_EMITC) +#if defined(BUILD_VMVX) + const struct iree_file_toc_t *module_file_toc = + samples_simple_vec_mul_simple_float_mul_bytecode_module_vmvx_create(); +#else const struct iree_file_toc_t *module_file_toc = samples_simple_vec_mul_simple_float_mul_bytecode_module_static_create(); +#endif // #if defined(BUILD_VMVX) return iree_vm_bytecode_module_create( instance, iree_make_const_byte_span(module_file_toc->data, module_file_toc->size), iree_allocator_null(), iree_allocator_system(), module); #else return module_create(instance, iree_allocator_system(), module); -#endif +#endif // #if !defined(BUILD_EMITC) } +#if !defined(BUILD_VMVX) iree_hal_executable_library_query_fn_t library_query(void) { return &simple_mul_dispatch_0_library_query; } +#endif iree_status_t load_input_data(const MlModel *model, void **buffer, iree_const_byte_span_t **byte_span) {
diff --git a/samples/simple_vec_mul/int_vec.c b/samples/simple_vec_mul/int_vec.c index 63788a8..3cf57b3 100644 --- a/samples/simple_vec_mul/int_vec.c +++ b/samples/simple_vec_mul/int_vec.c
@@ -21,14 +21,21 @@ #include "model_util/util.h" // Compiled module embedded here to avoid file IO: +#if defined(BUILD_VMVX) #if !defined(BUILD_EMITC) -#include "iree/vm/bytecode_module.h" +#include "samples/simple_vec_mul/simple_int_mul_bytecode_module_vmvx_c.h" +#else +#include "samples/simple_vec_mul/simple_int_mul_c_module_vmvx_emitc.h" +#endif // !defined(BUILD_EMITC) +#else +#if !defined(BUILD_EMITC) #include "samples/simple_vec_mul/simple_int_mul_bytecode_module_static.h" #include "samples/simple_vec_mul/simple_int_mul_bytecode_module_static_c.h" #else #include "samples/simple_vec_mul/simple_int_mul_c_module_static_c.h" #include "samples/simple_vec_mul/simple_int_mul_c_module_static_emitc.h" -#endif +#endif // !defined(BUILD_EMITC) +#endif // defined(BUILD_VMVX) const MlModel kModel = { .num_input = 2, @@ -47,20 +54,27 @@ iree_status_t create_module(iree_vm_instance_t *instance, iree_vm_module_t **module) { #if !defined(BUILD_EMITC) +#if defined(BUILD_VMVX) + const struct iree_file_toc_t *module_file_toc = + samples_simple_vec_mul_simple_int_mul_bytecode_module_vmvx_create(); +#else const struct iree_file_toc_t *module_file_toc = samples_simple_vec_mul_simple_int_mul_bytecode_module_static_create(); +#endif // #if defined(BUILD_VMVX) return iree_vm_bytecode_module_create( instance, iree_make_const_byte_span(module_file_toc->data, module_file_toc->size), iree_allocator_null(), iree_allocator_system(), module); #else return module_create(instance, iree_allocator_system(), module); -#endif +#endif // #if !defined(BUILD_EMITC) } +#if !defined(BUILD_VMVX) iree_hal_executable_library_query_fn_t library_query(void) { return &simple_mul_dispatch_0_library_query; } +#endif // #if !defined(BUILD_VMVX) iree_status_t load_input_data(const MlModel *model, void **buffer, iree_const_byte_span_t **byte_span) {
diff --git a/samples/simple_vec_mul/simple_test.run b/samples/simple_vec_mul/simple_test.run index fbb18bf..07c6d94 100644 --- a/samples/simple_vec_mul/simple_test.run +++ b/samples/simple_vec_mul/simple_test.run
@@ -1,4 +1,8 @@ // RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_bytecode_static // RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_emitc_static +// RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_bytecode_vmvx +// RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_emitc_vmvx // RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_bytecode_static // RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_emitc_static +// RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_bytecode_vmvx +// RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_emitc_vmvx