Add support for VMVX backend

Add support for VMVX backend for both bytecode and emitc.

Build VMVX targets in simple_vec_mul for illustration.

Note: IREE does not yet fully support the VMVX backend (the LLVM backend is fine) for TFLite models imported via TOSA: the TOSA IR can't be compiled to VMFB. (https://github.com/iree-org/iree/issues/10253)

Change-Id: I69dc84ce848f75e031d68ec21fc03cc3484efd58
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d44e160..7f6b663 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -82,6 +82,7 @@
 include(springbok_ops)
 
 include(springbok_static_module)
+include(springbok_vmvx_module)
 include(springbok_modules)
 include(iree_model_input)
 include(springbok_test)
diff --git a/cmake/springbok_modules.cmake b/cmake/springbok_modules.cmake
index 1bb68ec..5f511c7 100644
--- a/cmake/springbok_modules.cmake
+++ b/cmake/springbok_modules.cmake
@@ -1,5 +1,3 @@
-include(CMakeParseArguments)
-
 # springbok_modules()
 #
 # A wrapper for the springbok_bytecode_module and springbok_c_module to apply common iree-compile flags
@@ -10,6 +8,7 @@
 # C_IDENTIFIER: Identifier to use for generate c embed code.
 #     If omitted then no C embed code will be generated.
 # RVV_OFF: Indicate RVV is OFF (default: ON)
+# VMVX: Compile VMVX backend
 #
 # Examples:
 # springbok_modules(
@@ -40,7 +39,7 @@
 function(springbok_modules)
   cmake_parse_arguments(
     _RULE
-    "PUBLIC;RVV_OFF"
+    "PUBLIC;RVV_OFF;VMVX"
     "NAME;SRC;C_IDENTIFIER"
     "FLAGS"
     ${ARGN}
@@ -73,4 +72,27 @@
     EMITC
   )
 
+  if (${_RULE_VMVX})
+    springbok_vmvx_module(
+      NAME
+        "${_RULE_NAME}_bytecode_module_vmvx"
+      SRC
+        "${_RULE_SRC}"
+      C_IDENTIFIER
+        "${_RULE_C_IDENTIFIER}_bytecode_module_vmvx"
+      FLAGS
+        "${_RULE_FLAGS}"
+    )
+
+    springbok_vmvx_module(
+      NAME
+        "${_RULE_NAME}_c_module_vmvx"
+      SRC
+        "${_RULE_SRC}"
+      FLAGS
+        "${_RULE_FLAGS}"
+      EMITC
+    )
+  endif()
+
 endfunction()
diff --git a/cmake/springbok_static_module.cmake b/cmake/springbok_static_module.cmake
index 52ca1d6..0853335 100644
--- a/cmake/springbok_static_module.cmake
+++ b/cmake/springbok_static_module.cmake
@@ -1,5 +1,3 @@
-include(CMakeParseArguments)
-
 # springbok_static_module()
 #
 # A modified version of iree_static_linker_test to apply common iree-compile flags
diff --git a/cmake/springbok_vmvx_module.cmake b/cmake/springbok_vmvx_module.cmake
new file mode 100644
index 0000000..72e8bde
--- /dev/null
+++ b/cmake/springbok_vmvx_module.cmake
@@ -0,0 +1,105 @@
+# springbok_vmvx_module()
+#
+# A modified version of iree_vmvx_linker_test to apply common iree-compile flags
+# Parameters:
+# NAME: Name of target. With EMITC, "_emitc" is appended to the module name.
+# SRC: Source file (.mlir, or .tflite to import first). Support relative path.
+# C_IDENTIFIER: Identifier for the embedded C data; only used without EMITC.
+#     Defaults to "<iree_package_name>_<NAME>" if omitted.
+# FLAGS: Flags to pass to the translation tool (list of strings).
+# EMITC: Uses EmitC to output C code instead of VM bytecode.
+#
+# Examples:
+# springbok_vmvx_module(
+#   NAME
+#     daredevil_bytecode_module_vmvx
+#   SRC
+#     "daredevil_quant.tflite"
+#   C_IDENTIFIER
+#     "daredevil_bytecode_module_vmvx"
+#   FLAGS
+#     "-iree-input-type=tosa"
+# )
+#
+# springbok_vmvx_module(
+#   NAME
+#     simple_float_mul_c_module_vmvx
+#   SRC
+#     "simple_float_mul.mlir"
+#   FLAGS
+#     "-iree-input-type=mhlo"
+#   EMITC
+# )
+#
+function(springbok_vmvx_module)
+  cmake_parse_arguments(
+    _RULE
+    "EMITC"
+    "NAME;SRC;C_IDENTIFIER"
+    "FLAGS"
+    ${ARGN}
+  )
+
+  set(_MLIR_SRC "${_RULE_SRC}")
+  if("${_RULE_SRC}" MATCHES "\\.tflite$")
+    # TFLite input: import it to MLIR first, then compile the imported file.
+    iree_get_executable_path(IREE_IMPORT_TFLITE_TOOL "iree-import-tflite")
+    set(_MLIR_SRC "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.mlir")
+    get_filename_component(_SRC_PATH "${_RULE_SRC}" REALPATH)
+    # Only add the custom_command here. The output is passed to
+    # iree_bytecode_module/iree_c_module as the source.
+    add_custom_command(
+      OUTPUT
+        "${_MLIR_SRC}"
+      COMMAND
+        ${IREE_IMPORT_TFLITE_TOOL}
+        "${_SRC_PATH}"
+        "-o" "${_MLIR_SRC}"
+      DEPENDS
+        ${IREE_IMPORT_TFLITE_TOOL}
+        "${_SRC_PATH}"  # Re-import when the model file changes.
+      VERBATIM
+    )
+  endif()
+
+  iree_get_executable_path(_COMPILER_TOOL "iree-compile")
+  iree_package_name(_PACKAGE_NAME)
+  iree_package_ns(_PACKAGE_NS)
+
+  # Set common iree-compile flags; the VMVX target backend is always appended.
+  set(_COMPILER_ARGS ${_RULE_FLAGS})
+  list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=vmvx")
+
+  if(_RULE_EMITC)
+    set(_MODULE_NAME "${_RULE_NAME}_emitc")
+    set(_H_FILE_NAME "${_RULE_NAME}_emitc.h")
+    iree_c_module(
+      NAME
+        ${_MODULE_NAME}
+      SRC
+        "${_MLIR_SRC}"
+      FLAGS
+        ${_COMPILER_ARGS}
+      H_FILE_OUTPUT
+        "${_H_FILE_NAME}"
+      NO_RUNTIME
+    )
+  else()  # bytecode module path
+    # Generate the embed data with the bytecode module.
+    set(_MODULE_NAME "${_RULE_NAME}")
+    if(NOT _RULE_C_IDENTIFIER)
+      set(_RULE_C_IDENTIFIER "${_PACKAGE_NAME}_${_RULE_NAME}")
+    endif()
+    iree_bytecode_module(
+      NAME
+        ${_MODULE_NAME}
+      SRC
+        "${_MLIR_SRC}"
+      FLAGS
+        ${_COMPILER_ARGS}
+      C_IDENTIFIER
+        "${_RULE_C_IDENTIFIER}"
+      PUBLIC
+    )
+  endif()
+endfunction()
diff --git a/device/CMakeLists.txt b/device/CMakeLists.txt
index c3961b9..d7e527a 100644
--- a/device/CMakeLists.txt
+++ b/device/CMakeLists.txt
@@ -12,3 +12,18 @@
     iree::hal::local
     iree::hal::local::loaders::static_library_loader
 )
+
+iree_cc_library(  # device.h implemented with the VMVX module loader.
+  NAME
+    device_vmvx_loader
+  HDRS
+    "device.h"
+  SRCS
+    "device_vmvx_loader.c"
+  DEPS
+    iree::base
+    iree::hal
+    iree::hal::drivers::local_sync::sync_driver
+    iree::hal::local
+    iree::hal::local::loaders::vmvx_module_loader
+)
diff --git a/device/device_vmvx_loader.c b/device/device_vmvx_loader.c
new file mode 100644
index 0000000..e1d89a8
--- /dev/null
+++ b/device/device_vmvx_loader.c
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// VMVX module loading in IREE.
+
+#include "device/device.h"
+#include "iree/hal/drivers/local_sync/sync_device.h"
+#include "iree/hal/local/loaders/vmvx_module_loader.h"
+#include "iree/modules/hal/module.h"
+#include "model_util/model_api.h"
+
+// Creates a synchronous HAL device that loads executables through the VMVX
+// module loader. On success the device is stored in |out_device| and must be
+// released by the caller; otherwise the failing status is returned.
+iree_status_t create_sample_device(iree_allocator_t host_allocator,
+                                   iree_hal_device_t** out_device) {
+  // Set parameters for the device created in the next step.
+  iree_hal_sync_device_params_t params;
+  iree_hal_sync_device_params_initialize(&params);
+
+  iree_vm_instance_t* instance = NULL;
+  iree_status_t status = iree_vm_instance_create(host_allocator, &instance);
+
+  iree_hal_executable_loader_t* loader = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_vmvx_module_loader_create(
+        instance, /*user_module_count=*/0, /*user_modules=*/NULL,
+        host_allocator, &loader);
+  }
+  iree_vm_instance_release(instance);  // Only needed to create the loader.
+
+  // Use the default host allocator for buffer allocations.
+  iree_string_view_t identifier = iree_make_cstring_view("vmvx");
+  iree_hal_allocator_t* device_allocator = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_allocator_create_heap(identifier, host_allocator,
+                                            host_allocator, &device_allocator);
+  }
+
+  if (iree_status_is_ok(status)) {
+    // Create the synchronous device.
+    status = iree_hal_sync_device_create(
+        identifier, &params, /*loader_count=*/1, &loader, device_allocator,
+        host_allocator, out_device);
+  }
+  // Drop the local handles whether or not device creation succeeded.
+  iree_hal_allocator_release(device_allocator);
+  iree_hal_executable_loader_release(loader);
+  return status;
+}
diff --git a/model_util/CMakeLists.txt b/model_util/CMakeLists.txt
index ce3cae0..c14745f 100644
--- a/model_util/CMakeLists.txt
+++ b/model_util/CMakeLists.txt
@@ -22,7 +22,14 @@
   DEPS
     ::util_base
     device::device_static_loader
-    iree::hal::local::loaders::static_library_loader
+)
+
+iree_cc_library(
+  NAME
+    util_vmvx
+  DEPS
+    ::util_base
+    device::device_vmvx_loader
 )
 
 iree_cc_library(
diff --git a/samples/simple_vec_mul/CMakeLists.txt b/samples/simple_vec_mul/CMakeLists.txt
index 6f46eca..9ea573e 100644
--- a/samples/simple_vec_mul/CMakeLists.txt
+++ b/samples/simple_vec_mul/CMakeLists.txt
@@ -15,6 +15,7 @@
     "-iree-input-type=mhlo"
     "-riscv-v-vector-bits-min=512"
     "-riscv-v-fixed-length-vector-lmul-max=8"
+  VMVX
   PUBLIC
 )
 
@@ -29,6 +30,7 @@
     "-iree-input-type=mhlo"
     "-riscv-v-vector-bits-min=512"
     "-riscv-v-fixed-length-vector-lmul-max=8"
+  VMVX
   PUBLIC
 )
 
@@ -46,6 +48,36 @@
 
 iree_cc_binary(
   NAME
+    simple_float_vec_mul_bytecode_vmvx
+  SRCS
+    "float_vec.c"
+  DEPS
+    ::simple_float_mul_bytecode_module_vmvx_c
+    iree::vm::bytecode_module
+    model_util::util_vmvx
+  LINKOPTS
+    "LINKER:--defsym=__stack_size__=20k"
+  COPTS
+    "-DBUILD_VMVX"
+)
+
+iree_cc_binary(
+  NAME
+    simple_float_vec_mul_emitc_vmvx
+  SRCS
+    "float_vec.c"
+  DEPS
+    ::simple_float_mul_c_module_vmvx_emitc
+    model_util::util_vmvx
+  LINKOPTS
+    "LINKER:--defsym=__stack_size__=20k"
+  COPTS
+    "-DBUILD_EMITC"
+    "-DBUILD_VMVX"
+)
+
+iree_cc_binary(
+  NAME
     simple_float_vec_mul_bytecode_static
   SRCS
     "float_vec.c"
@@ -75,6 +107,36 @@
 
 iree_cc_binary(
   NAME
+    simple_int_vec_mul_bytecode_vmvx
+  SRCS
+    "int_vec.c"
+  DEPS
+    ::simple_int_mul_bytecode_module_vmvx_c
+    iree::vm::bytecode_module
+    model_util::util_vmvx
+  LINKOPTS
+    "LINKER:--defsym=__stack_size__=20k"
+  COPTS
+    "-DBUILD_VMVX"
+)
+
+iree_cc_binary(
+  NAME
+    simple_int_vec_mul_emitc_vmvx
+  SRCS
+    "int_vec.c"
+  DEPS
+    ::simple_int_mul_c_module_vmvx_emitc
+    model_util::util_vmvx
+  LINKOPTS
+    "LINKER:--defsym=__stack_size__=20k"
+  COPTS
+    "-DBUILD_EMITC"
+    "-DBUILD_VMVX"
+)
+
+iree_cc_binary(
+  NAME
     simple_int_vec_mul_bytecode_static
   SRCS
     "int_vec.c"
diff --git a/samples/simple_vec_mul/float_vec.c b/samples/simple_vec_mul/float_vec.c
index 90ed92e..42ce5b8 100644
--- a/samples/simple_vec_mul/float_vec.c
+++ b/samples/simple_vec_mul/float_vec.c
@@ -21,13 +21,21 @@
 #include "model_util/util.h"
 
 // Compiled module embedded here to avoid file IO:
+#if defined(BUILD_VMVX)
+#if !defined(BUILD_EMITC)
+#include "samples/simple_vec_mul/simple_float_mul_bytecode_module_vmvx_c.h"
+#else
+#include "samples/simple_vec_mul/simple_float_mul_c_module_vmvx_emitc.h"
+#endif  // !defined(BUILD_EMITC)
+#else
 #if !defined(BUILD_EMITC)
 #include "samples/simple_vec_mul/simple_float_mul_bytecode_module_static.h"
 #include "samples/simple_vec_mul/simple_float_mul_bytecode_module_static_c.h"
 #else
 #include "samples/simple_vec_mul/simple_float_mul_c_module_static_c.h"
 #include "samples/simple_vec_mul/simple_float_mul_c_module_static_emitc.h"
-#endif
+#endif  // #if !defined(BUILD_EMITC)
+#endif  // #if defined(BUILD_VMVX)
 
 const MlModel kModel = {
     .num_input = 2,
@@ -46,20 +54,27 @@
 iree_status_t create_module(iree_vm_instance_t *instance,
                             iree_vm_module_t **module) {
 #if !defined(BUILD_EMITC)
+#if defined(BUILD_VMVX)
+  const struct iree_file_toc_t *module_file_toc =
+      samples_simple_vec_mul_simple_float_mul_bytecode_module_vmvx_create();
+#else
   const struct iree_file_toc_t *module_file_toc =
       samples_simple_vec_mul_simple_float_mul_bytecode_module_static_create();
+#endif  // #if defined(BUILD_VMVX)
   return iree_vm_bytecode_module_create(
       instance,
       iree_make_const_byte_span(module_file_toc->data, module_file_toc->size),
       iree_allocator_null(), iree_allocator_system(), module);
 #else
   return module_create(instance, iree_allocator_system(), module);
-#endif
+#endif  // #if !defined(BUILD_EMITC)
 }
 
+#if !defined(BUILD_VMVX)
 iree_hal_executable_library_query_fn_t library_query(void) {
   return &simple_mul_dispatch_0_library_query;
 }
+#endif  // #if !defined(BUILD_VMVX)
 
 iree_status_t load_input_data(const MlModel *model, void **buffer,
                               iree_const_byte_span_t **byte_span) {
diff --git a/samples/simple_vec_mul/int_vec.c b/samples/simple_vec_mul/int_vec.c
index 63788a8..3cf57b3 100644
--- a/samples/simple_vec_mul/int_vec.c
+++ b/samples/simple_vec_mul/int_vec.c
@@ -21,14 +21,21 @@
 #include "model_util/util.h"
 
 // Compiled module embedded here to avoid file IO:
+#if defined(BUILD_VMVX)
 #if !defined(BUILD_EMITC)
-#include "iree/vm/bytecode_module.h"
+#include "samples/simple_vec_mul/simple_int_mul_bytecode_module_vmvx_c.h"
+#else
+#include "samples/simple_vec_mul/simple_int_mul_c_module_vmvx_emitc.h"
+#endif  // !defined(BUILD_EMITC)
+#else
+#if !defined(BUILD_EMITC)
 #include "samples/simple_vec_mul/simple_int_mul_bytecode_module_static.h"
 #include "samples/simple_vec_mul/simple_int_mul_bytecode_module_static_c.h"
 #else
 #include "samples/simple_vec_mul/simple_int_mul_c_module_static_c.h"
 #include "samples/simple_vec_mul/simple_int_mul_c_module_static_emitc.h"
-#endif
+#endif  // !defined(BUILD_EMITC)
+#endif  // defined(BUILD_VMVX)
 
 const MlModel kModel = {
     .num_input = 2,
@@ -47,20 +54,27 @@
 iree_status_t create_module(iree_vm_instance_t *instance,
                             iree_vm_module_t **module) {
 #if !defined(BUILD_EMITC)
+#if defined(BUILD_VMVX)
+  const struct iree_file_toc_t *module_file_toc =
+      samples_simple_vec_mul_simple_int_mul_bytecode_module_vmvx_create();
+#else
   const struct iree_file_toc_t *module_file_toc =
       samples_simple_vec_mul_simple_int_mul_bytecode_module_static_create();
+#endif  // #if defined(BUILD_VMVX)
   return iree_vm_bytecode_module_create(
       instance,
       iree_make_const_byte_span(module_file_toc->data, module_file_toc->size),
       iree_allocator_null(), iree_allocator_system(), module);
 #else
   return module_create(instance, iree_allocator_system(), module);
-#endif
+#endif  // #if !defined(BUILD_EMITC)
 }
 
+#if !defined(BUILD_VMVX)
 iree_hal_executable_library_query_fn_t library_query(void) {
   return &simple_mul_dispatch_0_library_query;
 }
+#endif  // #if !defined(BUILD_VMVX)
 
 iree_status_t load_input_data(const MlModel *model, void **buffer,
                               iree_const_byte_span_t **byte_span) {
diff --git a/samples/simple_vec_mul/simple_test.run b/samples/simple_vec_mul/simple_test.run
index fbb18bf..07c6d94 100644
--- a/samples/simple_vec_mul/simple_test.run
+++ b/samples/simple_vec_mul/simple_test.run
@@ -1,4 +1,8 @@
 // RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_bytecode_static
 // RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_emitc_static
+// RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_bytecode_vmvx
+// RUN: ${TEST_RUNNER_CMD} %S/simple_int_vec_mul_emitc_vmvx
 // RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_bytecode_static
 // RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_emitc_static
+// RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_bytecode_vmvx
+// RUN: ${TEST_RUNNER_CMD} %S/simple_float_vec_mul_emitc_vmvx