| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Vulkan Graphics + IREE API Integration Sample. |
| |
| // Vulkan GUI utility functions |
| // Other matters here: we need to pull in this first to make sure Vulkan API |
| // prototypes are defined so that we can statically link against them. |
| #include "iree/testing/vulkan/vulkan_gui_util.h" |
| |
| // IREE's C API: |
| #include "iree/base/api.h" |
| #include "iree/hal/api.h" |
| #include "iree/hal/vulkan/api.h" |
| #include "iree/hal/vulkan/registration/driver_module.h" |
| #include "iree/modules/hal/module.h" |
| #include "iree/vm/api.h" |
| #include "iree/vm/bytecode_module.h" |
| #include "iree/vm/ref_cc.h" |
| |
| // Other dependencies (helpers, etc.) |
| #include "iree/base/internal/main.h" |
| #include "iree/base/logging.h" |
| |
| // Compiled module embedded here to avoid file IO: |
| #include "iree/samples/vulkan/simple_mul_bytecode_module_c.h" |
| |
| static VkAllocationCallbacks* g_Allocator = NULL; |
| static VkInstance g_Instance = VK_NULL_HANDLE; |
| static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE; |
| static VkDevice g_Device = VK_NULL_HANDLE; |
| static uint32_t g_QueueFamily = (uint32_t)-1; |
| static VkQueue g_Queue = VK_NULL_HANDLE; |
| static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE; |
| static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE; |
| |
| static ImGui_ImplVulkanH_Window g_MainWindowData; |
| static uint32_t g_MinImageCount = 2; |
| static bool g_SwapChainRebuild = false; |
| static int g_SwapChainResizeWidth = 0; |
| static int g_SwapChainResizeHeight = 0; |
| |
| static void check_vk_result(VkResult err) { |
| if (err == 0) return; |
| IREE_LOG(FATAL) << "VkResult: " << err; |
| } |
| |
| static void CleanupVulkan() { |
| vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator); |
| |
| vkDestroyDevice(g_Device, g_Allocator); |
| vkDestroyInstance(g_Instance, g_Allocator); |
| } |
| |
| static void CleanupVulkanWindow() { |
| ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData, |
| g_Allocator); |
| } |
| |
| namespace iree { |
| |
| extern "C" int iree_main(int argc, char** argv) { |
| // -------------------------------------------------------------------------- |
| // Create a window. |
| if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) { |
| IREE_LOG(FATAL) << "Failed to initialize SDL"; |
| return 1; |
| } |
| |
| // Setup window |
| // clang-format off |
| SDL_WindowFlags window_flags = (SDL_WindowFlags)( |
| SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI); |
| // clang-format on |
| SDL_Window* window = SDL_CreateWindow( |
| "IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED, |
| SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags); |
| |
| // Setup Vulkan |
| iree_hal_vulkan_features_t iree_vulkan_features = |
| static_cast<iree_hal_vulkan_features_t>( |
| IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS | |
| IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); |
| std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features); |
| std::vector<const char*> extensions = |
| GetInstanceExtensions(window, iree_vulkan_features); |
| SetupVulkan(iree_vulkan_features, layers.data(), |
| static_cast<uint32_t>(layers.size()), extensions.data(), |
| static_cast<uint32_t>(extensions.size()), g_Allocator, |
| &g_Instance, &g_QueueFamily, &g_PhysicalDevice, &g_Queue, |
| &g_Device, &g_DescriptorPool); |
| |
| // Create Window Surface |
| VkSurfaceKHR surface; |
| VkResult err; |
| if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) { |
| printf("Failed to create Vulkan surface.\n"); |
| return 1; |
| } |
| |
| // Create Framebuffers |
| int w, h; |
| SDL_GetWindowSize(window, &w, &h); |
| ImGui_ImplVulkanH_Window* wd = &g_MainWindowData; |
| SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily, |
| g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount); |
| |
| // Setup Dear ImGui context |
| IMGUI_CHECKVERSION(); |
| ImGui::CreateContext(); |
| ImGuiIO& io = ImGui::GetIO(); |
| (void)io; |
| |
| ImGui::StyleColorsDark(); |
| |
| // Setup Platform/Renderer bindings |
| ImGui_ImplSDL2_InitForVulkan(window); |
| ImGui_ImplVulkan_InitInfo init_info = {}; |
| init_info.Instance = g_Instance; |
| init_info.PhysicalDevice = g_PhysicalDevice; |
| init_info.Device = g_Device; |
| init_info.QueueFamily = g_QueueFamily; |
| init_info.Queue = g_Queue; |
| init_info.PipelineCache = g_PipelineCache; |
| init_info.DescriptorPool = g_DescriptorPool; |
| init_info.Allocator = g_Allocator; |
| init_info.MinImageCount = g_MinImageCount; |
| init_info.ImageCount = wd->ImageCount; |
| init_info.CheckVkResultFn = check_vk_result; |
| ImGui_ImplVulkan_Init(&init_info, wd->RenderPass); |
| |
| // Upload Fonts |
| { |
| // Use any command queue |
| VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool; |
| VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer; |
| |
| err = vkResetCommandPool(g_Device, command_pool, 0); |
| check_vk_result(err); |
| VkCommandBufferBeginInfo begin_info = {}; |
| begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; |
| begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; |
| err = vkBeginCommandBuffer(command_buffer, &begin_info); |
| check_vk_result(err); |
| |
| ImGui_ImplVulkan_CreateFontsTexture(command_buffer); |
| |
| VkSubmitInfo end_info = {}; |
| end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; |
| end_info.commandBufferCount = 1; |
| end_info.pCommandBuffers = &command_buffer; |
| err = vkEndCommandBuffer(command_buffer); |
| check_vk_result(err); |
| err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE); |
| check_vk_result(err); |
| |
| err = vkDeviceWaitIdle(g_Device); |
| check_vk_result(err); |
| ImGui_ImplVulkan_DestroyFontUploadObjects(); |
| } |
| |
| // Demo state. |
| bool show_demo_window = true; |
| bool show_iree_window = true; |
| // -------------------------------------------------------------------------- |
| |
| // -------------------------------------------------------------------------- |
| // Setup IREE. |
| |
| // Check API version. |
| iree_api_version_t actual_version; |
| iree_status_t status = |
| iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version); |
| if (iree_status_is_ok(status)) { |
| IREE_LOG(INFO) << "IREE runtime API version " << actual_version; |
| } else { |
| IREE_LOG(FATAL) << "Unsupported runtime API version " << actual_version; |
| } |
| |
| // Register HAL drivers and VM module types. |
| IREE_CHECK_OK(iree_hal_vulkan_driver_module_register( |
| iree_hal_driver_registry_default())); |
| IREE_CHECK_OK(iree_hal_module_register_types()); |
| |
| // Create a runtime Instance. |
| iree_vm_instance_t* iree_instance = nullptr; |
| IREE_CHECK_OK( |
| iree_vm_instance_create(iree_allocator_system(), &iree_instance)); |
| |
| // Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice. |
| IREE_LOG(INFO) << "Creating Vulkan driver/device"; |
| // Load symbols from our static `vkGetInstanceProcAddr` for IREE to use. |
| iree_hal_vulkan_syms_t* iree_vk_syms = nullptr; |
| IREE_CHECK_OK(iree_hal_vulkan_syms_create( |
| reinterpret_cast<void*>(&vkGetInstanceProcAddr), iree_allocator_system(), |
| &iree_vk_syms)); |
| // Create the driver sharing our VkInstance. |
| iree_hal_driver_t* iree_vk_driver = nullptr; |
| iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan"); |
| iree_hal_vulkan_driver_options_t driver_options; |
| driver_options.api_version = VK_API_VERSION_1_0; |
| driver_options.requested_features = static_cast<iree_hal_vulkan_features_t>( |
| IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); |
| IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance( |
| driver_identifier, &driver_options, iree_vk_syms, g_Instance, |
| iree_allocator_system(), &iree_vk_driver)); |
| // Create a device sharing our VkDevice and queue. |
| // We could also create a separate (possibly low priority) compute queue for |
| // IREE, and/or provide a dedicated transfer queue. |
| iree_string_view_t device_identifier = iree_make_cstring_view("vulkan"); |
| iree_hal_vulkan_queue_set_t compute_queue_set; |
| compute_queue_set.queue_family_index = g_QueueFamily; |
| compute_queue_set.queue_indices = 1 << 0; |
| iree_hal_vulkan_queue_set_t transfer_queue_set; |
| transfer_queue_set.queue_indices = 0; |
| iree_hal_device_t* iree_vk_device = nullptr; |
| IREE_CHECK_OK(iree_hal_vulkan_wrap_device( |
| device_identifier, &driver_options.device_options, iree_vk_syms, |
| g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set, |
| &transfer_queue_set, iree_allocator_system(), &iree_vk_device)); |
| // Create a HAL module using the HAL device. |
| iree_vm_module_t* hal_module = nullptr; |
| IREE_CHECK_OK(iree_hal_module_create(iree_vk_device, iree_allocator_system(), |
| &hal_module)); |
| |
| // Load bytecode module from embedded data. |
| IREE_LOG(INFO) << "Loading simple_mul.mlir..."; |
| const struct iree_file_toc_t* module_file_toc = |
| iree_samples_vulkan_simple_mul_bytecode_module_create(); |
| iree_vm_module_t* bytecode_module = nullptr; |
| IREE_CHECK_OK(iree_vm_bytecode_module_create( |
| iree_const_byte_span_t{ |
| reinterpret_cast<const uint8_t*>(module_file_toc->data), |
| module_file_toc->size}, |
| iree_allocator_null(), iree_allocator_system(), &bytecode_module)); |
| // Query for details about what is in the loaded module. |
| iree_vm_module_signature_t bytecode_module_signature = |
| iree_vm_module_signature(bytecode_module); |
| IREE_LOG(INFO) << "Module loaded, have <" |
| << bytecode_module_signature.export_function_count |
| << "> exported functions:"; |
| for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) { |
| iree_vm_function_t function; |
| IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal( |
| bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function)); |
| auto function_name = iree_vm_function_name(&function); |
| auto function_signature = iree_vm_function_signature(&function); |
| IREE_LOG(INFO) << " " << i << ": '" |
| << std::string(function_name.data, function_name.size) |
| << "' with calling convention '" |
| << std::string(function_signature.calling_convention.data, |
| function_signature.calling_convention.size) |
| << "'"; |
| } |
| |
| // Allocate a context that will hold the module state across invocations. |
| iree_vm_context_t* iree_context = nullptr; |
| std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module}; |
| IREE_CHECK_OK(iree_vm_context_create_with_modules( |
| iree_instance, IREE_VM_CONTEXT_FLAG_NONE, modules.data(), modules.size(), |
| iree_allocator_system(), &iree_context)); |
| IREE_LOG(INFO) << "Context with modules is ready for use"; |
| |
| // Lookup the entry point function. |
| iree_vm_function_t main_function; |
| const char kMainFunctionName[] = "module.simple_mul"; |
| IREE_CHECK_OK(iree_vm_context_resolve_function( |
| iree_context, |
| iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1}, |
| &main_function)); |
| iree_string_view_t main_function_name = iree_vm_function_name(&main_function); |
| IREE_LOG(INFO) << "Resolved main function named '" |
| << std::string(main_function_name.data, |
| main_function_name.size) |
| << "'"; |
| // -------------------------------------------------------------------------- |
| |
| // -------------------------------------------------------------------------- |
| // Main loop. |
| bool done = false; |
| while (!done) { |
| SDL_Event event; |
| |
| while (SDL_PollEvent(&event)) { |
| if (event.type == SDL_QUIT) { |
| done = true; |
| } |
| |
| ImGui_ImplSDL2_ProcessEvent(&event); |
| if (event.type == SDL_QUIT) done = true; |
| if (event.type == SDL_WINDOWEVENT && |
| event.window.event == SDL_WINDOWEVENT_RESIZED && |
| event.window.windowID == SDL_GetWindowID(window)) { |
| g_SwapChainResizeWidth = (int)event.window.data1; |
| g_SwapChainResizeHeight = (int)event.window.data2; |
| g_SwapChainRebuild = true; |
| } |
| } |
| |
| if (g_SwapChainRebuild) { |
| g_SwapChainRebuild = false; |
| ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount); |
| ImGui_ImplVulkanH_CreateOrResizeWindow( |
| g_Instance, g_PhysicalDevice, g_Device, &g_MainWindowData, |
| g_QueueFamily, g_Allocator, g_SwapChainResizeWidth, |
| g_SwapChainResizeHeight, g_MinImageCount); |
| g_MainWindowData.FrameIndex = 0; |
| } |
| |
| // Start the Dear ImGui frame |
| ImGui_ImplVulkan_NewFrame(); |
| ImGui_ImplSDL2_NewFrame(window); |
| ImGui::NewFrame(); |
| |
| // Demo window. |
| if (show_demo_window) ImGui::ShowDemoWindow(&show_demo_window); |
| |
| // Custom window. |
| { |
| ImGui::Begin("IREE Vulkan Integration Demo", &show_iree_window, |
| ImGuiWindowFlags_AlwaysAutoResize); |
| |
| ImGui::Checkbox("Show ImGui Demo Window", &show_demo_window); |
| ImGui::Separator(); |
| |
| // ImGui Inputs for two input tensors. |
| // Run computation whenever any of the values changes. |
| static bool dirty = true; |
| static float input_x[] = {4.0f, 4.0f, 4.0f, 4.0f}; |
| static float input_y[] = {2.0f, 2.0f, 2.0f, 2.0f}; |
| static float latest_output[] = {0.0f, 0.0f, 0.0f, 0.0f}; |
| ImGui::Text("Multiply numbers using IREE"); |
| ImGui::PushItemWidth(60); |
| // clang-format off |
| if (ImGui::DragFloat("= x[0]", &input_x[0], 0.5f, 0.f, 0.f, "%.1f")) { dirty = true; } ImGui::SameLine(); // NOLINT |
| if (ImGui::DragFloat("= x[1]", &input_x[1], 0.5f, 0.f, 0.f, "%.1f")) { dirty = true; } ImGui::SameLine(); // NOLINT |
| if (ImGui::DragFloat("= x[2]", &input_x[2], 0.5f, 0.f, 0.f, "%.1f")) { dirty = true; } ImGui::SameLine(); // NOLINT |
| if (ImGui::DragFloat("= x[3]", &input_x[3], 0.5f, 0.f, 0.f, "%.1f")) { dirty = true; } // NOLINT |
| if (ImGui::DragFloat("= y[0]", &input_y[0], 0.5f, 0.f, 0.f, "%.1f")) { dirty = true; } ImGui::SameLine(); // NOLINT |
| if (ImGui::DragFloat("= y[1]", &input_y[1], 0.5f, 0.f, 0.f, "%.1f")) { dirty = true; } ImGui::SameLine(); // NOLINT |
| if (ImGui::DragFloat("= y[2]", &input_y[2], 0.5f, 0.f, 0.f, "%.1f")) { dirty = true; } ImGui::SameLine(); // NOLINT |
| if (ImGui::DragFloat("= y[3]", &input_y[3], 0.5f, 0.f, 0.f, "%.1f")) { dirty = true; } // NOLINT |
| // clang-format on |
| ImGui::PopItemWidth(); |
| |
| if (dirty) { |
| // Some input values changed, run the computation. |
| // This is synchronous and doesn't reuse buffers for now. |
| |
| // Write inputs into mappable buffers. |
| IREE_DLOG(INFO) << "Creating I/O buffers..."; |
| constexpr int32_t kElementCount = 4; |
| iree_hal_allocator_t* allocator = |
| iree_hal_device_allocator(iree_vk_device); |
| iree_hal_memory_type_t input_memory_type = |
| static_cast<iree_hal_memory_type_t>( |
| IREE_HAL_MEMORY_TYPE_HOST_LOCAL | |
| IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE); |
| iree_hal_buffer_usage_t input_buffer_usage = |
| static_cast<iree_hal_buffer_usage_t>( |
| IREE_HAL_BUFFER_USAGE_DISPATCH | |
| IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING | |
| IREE_HAL_BUFFER_USAGE_CONSTANT); |
| iree_hal_buffer_params_t buffer_params; |
| buffer_params.type = input_memory_type; |
| buffer_params.usage = input_buffer_usage; |
| // Wrap input buffers in buffer views. |
| iree_hal_buffer_view_t* input0_buffer_view = nullptr; |
| iree_hal_buffer_view_t* input1_buffer_view = nullptr; |
| IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer( |
| allocator, |
| /*shape=*/&kElementCount, /*shape_rank=*/1, |
| IREE_HAL_ELEMENT_TYPE_FLOAT_32, |
| IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, |
| iree_make_const_byte_span(&input_x, sizeof(input_x)), |
| &input0_buffer_view)); |
| IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer( |
| allocator, |
| /*shape=*/&kElementCount, /*shape_rank=*/1, |
| IREE_HAL_ELEMENT_TYPE_FLOAT_32, |
| IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, |
| iree_make_const_byte_span(&input_y, sizeof(input_y)), |
| &input1_buffer_view)); |
| // Marshal inputs through a VM variant list. |
| // [arg0|arg1] |
| vm::ref<iree_vm_list_t> inputs; |
| IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, |
| iree_allocator_system(), &inputs)); |
| auto input0_buffer_view_ref = |
| iree_hal_buffer_view_move_ref(input0_buffer_view); |
| auto input1_buffer_view_ref = |
| iree_hal_buffer_view_move_ref(input1_buffer_view); |
| IREE_CHECK_OK( |
| iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref)); |
| IREE_CHECK_OK( |
| iree_vm_list_push_ref_move(inputs.get(), &input1_buffer_view_ref)); |
| |
| // Prepare outputs list to accept results from the invocation. |
| vm::ref<iree_vm_list_t> outputs; |
| IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, |
| kElementCount * sizeof(float), |
| iree_allocator_system(), &outputs)); |
| |
| // Synchronously invoke the function. |
| IREE_CHECK_OK(iree_vm_invoke(iree_context, main_function, |
| IREE_VM_INVOCATION_FLAG_NONE, |
| /*policy=*/nullptr, inputs.get(), |
| outputs.get(), iree_allocator_system())); |
| |
| // Read back the results. |
| IREE_DLOG(INFO) << "Reading back results..."; |
| auto* output_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>( |
| iree_vm_list_get_ref_deref(outputs.get(), 0, |
| iree_hal_buffer_view_get_descriptor())); |
| IREE_CHECK_OK(iree_hal_device_transfer_d2h( |
| iree_vk_device, iree_hal_buffer_view_buffer(output_buffer_view), 0, |
| latest_output, sizeof(latest_output), |
| IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout())); |
| |
| dirty = false; |
| } |
| |
| // Display the latest computation output. |
| ImGui::Text("X * Y = [%f, %f, %f, %f]", |
| latest_output[0], // |
| latest_output[1], // |
| latest_output[2], // |
| latest_output[3]); |
| ImGui::Separator(); |
| |
| // Framerate counter. |
| ImGui::Text("Application average %.3f ms/frame (%.1f FPS)", |
| 1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate); |
| |
| ImGui::End(); |
| } |
| |
| // Rendering |
| ImGui::Render(); |
| RenderFrame(wd, g_Device, g_Queue); |
| |
| PresentFrame(wd, g_Queue); |
| } |
| // -------------------------------------------------------------------------- |
| |
| // -------------------------------------------------------------------------- |
| // Cleanup |
| iree_vm_module_release(hal_module); |
| iree_vm_module_release(bytecode_module); |
| iree_vm_context_release(iree_context); |
| iree_hal_device_release(iree_vk_device); |
| iree_hal_driver_release(iree_vk_driver); |
| iree_hal_vulkan_syms_release(iree_vk_syms); |
| iree_vm_instance_release(iree_instance); |
| |
| err = vkDeviceWaitIdle(g_Device); |
| check_vk_result(err); |
| ImGui_ImplVulkan_Shutdown(); |
| ImGui_ImplSDL2_Shutdown(); |
| ImGui::DestroyContext(); |
| |
| CleanupVulkanWindow(); |
| CleanupVulkan(); |
| |
| SDL_DestroyWindow(window); |
| SDL_Quit(); |
| // -------------------------------------------------------------------------- |
| |
| return 0; |
| } |
| |
| } // namespace iree |