Add support for VMVX backend
Add support for VMVX backend for both bytecode and emitc.
Buid VMVX targets in simple_vec_mul for illustration.
Note: IREE currently doesn't well support VMVX backend (LLVM is ok though) for TOSA imported TFLite models: TOSA IR can't be compiled to VMFB. (https://github.com/iree-org/iree/issues/10253)
Change-Id: I69dc84ce848f75e031d68ec21fc03cc3484efd58
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d44e160..7f6b663 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -82,6 +82,7 @@
include(springbok_ops)
include(springbok_static_module)
+include(springbok_vmvx_module)
include(springbok_modules)
include(iree_model_input)
include(springbok_test)
diff --git a/cmake/springbok_modules.cmake b/cmake/springbok_modules.cmake
index 1bb68ec..5f511c7 100644
--- a/cmake/springbok_modules.cmake
+++ b/cmake/springbok_modules.cmake
@@ -1,5 +1,3 @@
-include(CMakeParseArguments)
-
# springbok_modules()
#
# A wrapper for the springbok_bytecode_module and springbok_c_module to apply common iree-compile flags
@@ -10,6 +8,7 @@
# C_IDENTIFIER: Identifier to use for generate c embed code.
# If omitted then no C embed code will be generated.
# RVV_OFF: Indicate RVV is OFF (default: ON)
+# VMVX: Compile VMVX backend
#
# Examples:
# springbok_modules(
@@ -40,7 +39,7 @@
function(springbok_modules)
cmake_parse_arguments(
_RULE
- "PUBLIC;RVV_OFF"
+ "PUBLIC;RVV_OFF;VMVX"
"NAME;SRC;C_IDENTIFIER"
"FLAGS"
${ARGN}
@@ -73,4 +72,27 @@
EMITC
)
+ if (${_RULE_VMVX})
+ springbok_vmvx_module(
+ NAME
+ "${_RULE_NAME}_bytecode_module_vmvx"
+ SRC
+ "${_RULE_SRC}"
+ C_IDENTIFIER
+ "${_RULE_C_IDENTIFIER}_bytecode_module_vmvx"
+ FLAGS
+ "${_RULE_FLAGS}"
+ )
+
+ springbok_vmvx_module(
+ NAME
+ "${_RULE_NAME}_c_module_vmvx"
+ SRC
+ "${_RULE_SRC}"
+ FLAGS
+ "${_RULE_FLAGS}"
+ EMITC
+ )
+ endif()
+
endfunction()
diff --git a/cmake/springbok_static_module.cmake b/cmake/springbok_static_module.cmake
index 52ca1d6..0853335 100644
--- a/cmake/springbok_static_module.cmake
+++ b/cmake/springbok_static_module.cmake
@@ -1,5 +1,3 @@
-include(CMakeParseArguments)
-
# springbok_static_module()
#
# A modified version of iree_static_linker_test to apply common iree-compile flags
diff --git a/cmake/springbok_vmvx_module.cmake b/cmake/springbok_vmvx_module.cmake
new file mode 100644
index 0000000..72e8bde
--- /dev/null
+++ b/cmake/springbok_vmvx_module.cmake
@@ -0,0 +1,105 @@
+# springbok_vmvx_module()
+#
+# A modified version of iree_vmvx_linker_test to apply common iree-compile flags
+# Parameters:
+# NAME: Name of target.
+# SRC: Source file to compile into a bytecode module. Support relative path.
+# FLAGS: Flags to pass to the translation tool (list of strings).
+# EMITC: Uses EmitC to output C code instead of VM bytecode.
+#
+# Examples:
+# springbok_vmvx_module(
+# NAME
+# daredevel_bytecode_module_vmvx
+# SRC
+# "daredevil_quant.tflite"
+# C_IDENTIFIER
+# "daredevil_bytecode_module_vmvx"
+# FLAGS
+# "-iree-input-type=tosa"
+# )
+#
+# springbok_vmvx_module(
+# NAME
+# simple_float_mul_c_module_vmvx
+# SRC
+# "simple_float_mul.mlir"
+# C_IDENTIFIER
+# "simple_float_mul"
+# FLAGS
+# "-iree-input-type=mhlo"
+# EMITC
+# )
+#
+function(springbok_vmvx_module)
+ cmake_parse_arguments(
+ _RULE
+ "EMITC"
+ "NAME;SRC;C_IDENTIFIER"
+ "FLAGS"
+ ${ARGN}
+ )
+
+ set(_MLIR_SRC "${_RULE_SRC}")
+ string(FIND "${_RULE_SRC}" ".tflite" _IS_TFLITE REVERSE)
+ if(${_IS_TFLITE} GREATER 0)
+ iree_get_executable_path(IREE_IMPORT_TFLITE_TOOL "iree-import-tflite")
+ set(_MLIR_SRC "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.mlir")
+ get_filename_component(_SRC_PATH "${_RULE_SRC}" REALPATH)
+ set(_ARGS "${_SRC_PATH}")
+ list(APPEND _ARGS "-o")
+ list(APPEND _ARGS "${_RULE_NAME}.mlir")
+ # Only add the custom_command here. The output is passed to
+ # iree_bytecode_module as the source.
+ add_custom_command(
+ OUTPUT
+ "${_RULE_NAME}.mlir"
+ COMMAND
+ ${IREE_IMPORT_TFLITE_TOOL}
+ ${_ARGS}
+ DEPENDS
+ ${IREE_IMPORT_TFLITE_TOOL}
+ )
+ endif()
+
+ iree_get_executable_path(_COMPILER_TOOL "iree-compile")
+ iree_package_name(_PACKAGE_NAME)
+ iree_package_ns(_PACKAGE_NS)
+
+ # Set common iree-compile flags
+ set(_COMPILER_ARGS ${_RULE_FLAGS})
+ list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=vmvx")
+
+ if(_RULE_EMITC)
+ set(_MODULE_NAME "${_RULE_NAME}_emitc")
+ set(_H_FILE_NAME "${_RULE_NAME}_emitc.h")
+ iree_c_module(
+ NAME
+ ${_MODULE_NAME}
+ SRC
+ "${_MLIR_SRC}"
+ FLAGS
+ ${_COMPILER_ARGS}
+ H_FILE_OUTPUT
+ "${_H_FILE_NAME}"
+ NO_RUNTIME
+ )
+ else() # bytecode module path
+ # Generate the embed data with the bytecode module.
+ set(_MODULE_NAME "${_RULE_NAME}")
+ if(NOT _RULE_C_IDENTIFIER)
+ set(_RULE_C_IDENTIFIER "${_PACKAGE_NAME}_${_RULE_NAME}")
+ endif()
+ iree_bytecode_module(
+ NAME
+ ${_MODULE_NAME}
+ SRC
+ "${_MLIR_SRC}"
+ FLAGS
+ ${_COMPILER_ARGS}
+ C_IDENTIFIER
+ "${_RULE_C_IDENTIFIER}"
+ PUBLIC
+ )
+ endif(_RULE_EMITC)
+endfunction()
diff --git a/device/CMakeLists.txt b/device/CMakeLists.txt
index c3961b9..d7e527a 100644
--- a/device/CMakeLists.txt
+++ b/device/CMakeLists.txt
@@ -12,3 +12,18 @@
iree::hal::local
iree::hal::local::loaders::static_library_loader
)
+
+iree_cc_library(
+ NAME
+ device_vmvx_loader
+ HDRS
+ "device.h"
+ SRCS
+ "device_vmvx_loader.c"
+ DEPS
+ iree::base
+ iree::hal
+ iree::hal::drivers::local_sync::sync_driver
+ iree::hal::local
+ iree::hal::local::loaders::vmvx_module_loader
+)
diff --git a/device/device_vmvx_loader.c b/device/device_vmvx_loader.c
new file mode 100644
index 0000000..e1d89a8
--- /dev/null
+++ b/device/device_vmvx_loader.c
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// VMVX module loading in IREE.
+
+#include "device/device.h"
+#include "iree/hal/drivers/local_sync/sync_device.h"
+#include "iree/hal/local/loaders/vmvx_module_loader.h"
+#include "iree/modules/hal/module.h"
+#include "model_util/model_api.h"
+
+// A function to create the HAL device from the different backend targets.
+// The HAL device is returned based on the implementation, and it must be
+// released by the caller.
+iree_status_t create_sample_device(iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
+ // Set parameters for the device created in the next step.
+ iree_hal_sync_device_params_t params;
+ iree_hal_sync_device_params_initialize(¶ms);
+
+ iree_vm_instance_t* instance = NULL;
+ iree_status_t status = iree_vm_instance_create(host_allocator, &instance);
+
+ iree_hal_executable_loader_t* loader = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_vmvx_module_loader_create(
+ instance, /*user_module_count=*/0, /*user_modules=*/NULL,
+ host_allocator, &loader);
+ }
+ iree_vm_instance_release(instance);
+
+ // Use the default host allocator for buffer allocations.
+ iree_string_view_t identifier = iree_make_cstring_view("vmvx");
+ iree_hal_allocator_t* device_allocator = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_allocator_create_heap(identifier, host_allocator,
+ host_allocator, &device_allocator);
+ }
+
+ if (iree_status_is_ok(status)) {
+ // Create the synchronous device.
+ status = iree_hal_sync_device_create(
+ identifier, ¶ms, /*loader_count=*/1, &loader, device_allocator,
+ host_allocator, out_device);
+ }
+
+ iree_hal_allocator_release(device_allocator);
+ iree_hal_executable_loader_release(loader);
+ return status;
+}
diff --git a/model_util/CMakeLists.txt b/model_util/CMakeLists.txt
index ce3cae0..c14745f 100644
--- a/model_util/CMakeLists.txt
+++ b/model_util/CMakeLists.txt
@@ -22,7 +22,14 @@
DEPS
::util_base
device::device_static_loader
- iree::hal::local::loaders::static_library_loader
+)
+
+iree_cc_library(
+ NAME
+ util_vmvx
+ DEPS
+ ::util_base
+ device::device_vmvx_loader
)
iree_cc_library(
diff --git a/samples/simple_vec_mul/CMakeLists.txt b/samples/simple_vec_mul/CMakeLists.txt
index 6f46eca..9ea573e 100644
--- a/samples/simple_vec_mul/CMakeLists.txt
+++ b/samples/simple_vec_mul/CMakeLists.txt
@@ -15,6 +15,7 @@
"-iree-input-type=mhlo"
"-riscv-v-vector-bits-min=512"
"-riscv-v-fixed-length-vector-lmul-max=8"
+ VMVX
PUBLIC
)
@@ -29,6 +30,7 @@
"-iree-input-type=mhlo"
"-riscv-v-vector-bits-min=512"
"-riscv-v-fixed-length-vector-lmul-max=8"
+ VMVX
PUBLIC
)
@@ -46,6 +48,36 @@
iree_cc_binary(
NAME
+ simple_float_vec_mul_bytecode_vmvx
+ SRCS
+ "float_vec.c"
+ DEPS
+ ::simple_float_mul_bytecode_module_vmvx_c
+ iree::vm::bytecode_module
+ model_util::util_vmvx
+ LINKOPTS
+ "LINKER:--defsym=__stack_size__=20k"
+ COPTS
+ "-DBUILD_VMVX"
+)
+
+iree_cc_binary(
+ NAME
+ simple_float_vec_mul_emitc_vmvx
+ SRCS
+ "float_vec.c"
+ DEPS
+ ::simple_float_mul_c_module_vmvx_emitc
+ model_util::util_vmvx
+ LINKOPTS
+ "LINKER:--defsym=__stack_size__=20k"
+ COPTS
+ "-DBUILD_EMITC"
+ "-DBUILD_VMVX"
+)
+
+iree_cc_binary(
+ NAME
simple_float_vec_mul_bytecode_static
SRCS
"float_vec.c"
@@ -75,6 +107,36 @@
iree_cc_binary(
NAME
+ simple_int_vec_mul_bytecode_vmvx
+ SRCS
+ "int_vec.c"
+ DEPS
+ ::simple_int_mul_bytecode_module_vmvx_c
+ iree::vm::bytecode_module
+ model_util::util_vmvx
+ LINKOPTS
+ "LINKER:--defsym=__stack_size__=20k"
+ COPTS
+ "-DBUILD_VMVX"
+)
+
+iree_cc_binary(
+ NAME
+ simple_int_vec_mul_emitc_vmvx
+ SRCS
+ "int_vec.c"
+ DEPS
+ ::simple_int_mul_c_module_vmvx_emitc
+ model_util::util_vmvx
+ LINKOPTS
+ "LINKER:--defsym=__stack_size__=20k"
+ COPTS
+ "-DBUILD_EMITC"
+ "-DBUILD_VMVX"
+)
+
+iree_cc_binary(
+ NAME
simple_int_vec_mul_bytecode_static
SRCS
"int_vec.c"
diff --git a/samples/simple_vec_mul/float_vec.c b/samples/simple_vec_mul/float_vec.c
index 90ed92e..42ce5b8 100644
--- a/samples/simple_vec_mul/float_vec.c
+++ b/samples/simple_vec_mul/float_vec.c
@@ -21,13 +21,21 @@
#include "model_util/util.h"
// Compiled module embedded here to avoid file IO:
+#if defined(BUILD_VMVX)
+#if !defined(BUILD_EMITC)
+#include "samples/simple_vec_mul/simple_float_mul_bytecode_module_vmvx_c.h"
+#else
+#include "samples/simple_vec_mul/simple_float_mul_c_module_vmvx_emitc.h"
+#endif // !defined(BUILD_EMITC)
+#else
#if !defined(BUILD_EMITC)
#include "samples/simple_vec_mul/simple_float_mul_bytecode_module_static.h"
#include "samples/simple_vec_mul/simple_float_mul_bytecode_module_static_c.h"
#else
#include "samples/simple_vec_mul/simple_float_mul_c_module_static_c.h"
#include "samples/simple_vec_mul/simple_float_mul_c_module_static_emitc.h"
-#endif
+#endif // #if !defined(BUILD_EMITC)
+#endif // #if defined(BUILD_VMVX)
const MlModel kModel = {
.num_input = 2,
@@ -46,20 +54,27 @@
iree_status_t create_module(iree_vm_instance_t *instance,
iree_vm_module_t **module) {
#if !defined(BUILD_EMITC)
+#if defined(BUILD_VMVX)
+ const struct iree_file_toc_t *module_file_toc =
+ samples_simple_vec_mul_simple_float_mul_bytecode_module_vmvx_create();
+#else
const struct iree_file_toc_t *module_file_toc =
samples_simple_vec_mul_simple_float_mul_bytecode_module_static_create();
+#endif // #if defined(BUILD_VMVX)
return iree_vm_bytecode_module_create(
instance,
iree_make_const_byte_span(module_file_toc->data, module_file_toc->size),
iree_allocator_null(), iree_allocator_system(), module);
#else
return module_create(instance, iree_allocator_system(), module);
-#endif
+#endif // #if !defined(BUILD_EMITC)
}
+#if !defined(BUILD_VMVX)
iree_hal_executable_library_query_fn_t library_query(void) {
return &simple_mul_dispatch_0_library_query;
}
+#endif
iree_status_t load_input_data(const MlModel *model, void **buffer,
iree_const_byte_span_t **byte_span) {
diff --git a/samples/simple_vec_mul/int_vec.c b/samples/simple_vec_mul/int_vec.c
index 63788a8..3cf57b3 100644
--- a/samples/simple_vec_mul/int_vec.c
+++ b/samples/simple_vec_mul/int_vec.c
@@ -21,14 +21,21 @@
#include "model_util/util.h"
// Compiled module embedded here to avoid file IO:
+#if defined(BUILD_VMVX)
#if !defined(BUILD_EMITC)
-#include "iree/vm/bytecode_module.h"
+#include "samples/simple_vec_mul/simple_int_mul_bytecode_module_vmvx_c.h"
+#else
+#include "samples/simple_vec_mul/simple_int_mul_c_module_vmvx_emitc.h"
+#endif // !defined(BUILD_EMITC)
+#else
+#if !defined(BUILD_EMITC)
#include "samples/simple_vec_mul/simple_int_mul_bytecode_module_static.h"
#include "samples/simple_vec_mul/simple_int_mul_bytecode_module_static_c.h"
#else
#include "samples/simple_vec_mul/simple_int_mul_c_module_static_c.h"
#include "samples/simple_vec_mul/simple_int_mul_c_module_static_emitc.h"
-#endif
+#endif // !defined(BUILD_EMITC)
+#endif // defined(BUILD_VMVX)
const MlModel kModel = {
.num_input = 2,
@@ -47,20 +54,27 @@
iree_status_t create_module(iree_vm_instance_t *instance,
iree_vm_module_t **module) {
#if !defined(BUILD_EMITC)
+#if defined(BUILD_VMVX)
+ const struct iree_file_toc_t *module_file_toc =
+ samples_simple_vec_mul_simple_int_mul_bytecode_module_vmvx_create();
+#else
const struct iree_file_toc_t *module_file_toc =
samples_simple_vec_mul_simple_int_mul_bytecode_module_static_create();
+#endif // #if defined(BUILD_VMVX)
return iree_vm_bytecode_module_create(
instance,
iree_make_const_byte_span(module_file_toc->data, module_file_toc->size),
iree_allocator_null(), iree_allocator_system(), module);
#else
return module_create(instance, iree_allocator_system(), module);
-#endif
+#endif // #if !defined(BUILD_EMITC)
}
+#if !defined(BUILD_VMVX)
iree_hal_executable_library_query_fn_t library_query(void) {
return &simple_mul_dispatch_0_library_query;
}
+#endif // #if !defined(BUILD_VMVX)
iree_status_t load_input_data(const MlModel *model, void **buffer,
iree_const_byte_span_t **byte_span) {
diff --git a/samples/simple_vec_mul/simple_test.run b/samples/simple_vec_mul/simple_test.run
index fbb18bf..07c6d94 100644
--- a/samples/simple_vec_mul/simple_test.run
+++ b/samples/simple_vec_mul/simple_test.run
@@ -1,4 +1,8 @@
// RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_bytecode_static
// RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_emitc_static
+// RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_bytecode_vmvx
+// RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_emitc_vmvx
// RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_bytecode_static
// RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_emitc_static
+// RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_bytecode_vmvx
+// RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_emitc_vmvx