blob: 1e3f42b4f76eb9624092efb6ab4cec8b53b3cbd8 [file]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Vulkan GUI utility functions
// Other matters here: we need to pull in this first to make sure Vulkan API
// prototypes are defined so that we can statically link against them.
#include "iree/testing/vulkan/vulkan_gui_util.h"
// Other dependencies (helpers, etc.)
#include "iree/base/internal/file_io.h"
#include "iree/base/internal/flags.h"
#include "iree/base/internal/main.h"
#include "iree/base/status.h"
#include "iree/hal/vulkan/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
IREE_FLAG(string, module_file, "-",
"File containing the module to load that contains the entry "
"function. Defaults to stdin.");
IREE_FLAG(string, entry_function, "",
"Name of a function contained in the module specified by input_file "
"to run.");
static iree_status_t parse_function_input(iree_string_view_t flag_name,
void* storage,
iree_string_view_t value) {
auto* list = (std::vector<std::string>*)storage;
list->push_back(std::string(value.data, value.size));
return iree_ok_status();
}
static void print_function_input(iree_string_view_t flag_name, void* storage,
FILE* file) {
auto* list = (std::vector<std::string>*)storage;
if (list->empty()) {
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
} else {
for (size_t i = 0; i < list->size(); ++i) {
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
list->at(i).c_str());
}
}
}
static std::vector<std::string> FLAG_function_inputs;
IREE_FLAG_CALLBACK(
parse_function_input, print_function_input, &FLAG_function_inputs,
function_input,
"An input value or buffer of the format:\n"
" [shape]xtype=[value]\n"
" 2x2xi32=1 2 3 4\n"
"Optionally, brackets may be used to separate the element values:\n"
" 2x2xi32=[[1 2][3 4]]\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");
static VkAllocationCallbacks* g_Allocator = NULL;
static VkInstance g_Instance = VK_NULL_HANDLE;
static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE;
static VkDevice g_Device = VK_NULL_HANDLE;
static uint32_t g_QueueFamily = (uint32_t)-1;
static VkQueue g_Queue = VK_NULL_HANDLE;
static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE;
static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE;
static ImGui_ImplVulkanH_Window g_MainWindowData;
static uint32_t g_MinImageCount = 2;
static bool g_SwapChainRebuild = false;
static int g_SwapChainResizeWidth = 0;
static int g_SwapChainResizeHeight = 0;
namespace iree {
namespace {
void check_vk_result(VkResult err) {
if (err == 0) return;
IREE_LOG(FATAL) << "VkResult: " << err;
}
void CleanupVulkan() {
vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator);
vkDestroyDevice(g_Device, g_Allocator);
vkDestroyInstance(g_Instance, g_Allocator);
}
void CleanupVulkanWindow() {
ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData,
g_Allocator);
}
Status GetModuleContentsFromFlags(std::string* out_contents) {
auto module_file = std::string(FLAG_module_file);
if (module_file == "-") {
*out_contents = std::string{std::istreambuf_iterator<char>(std::cin),
std::istreambuf_iterator<char>()};
} else {
IREE_RETURN_IF_ERROR(
file_io::GetFileContents(module_file.c_str(), out_contents));
}
return OkStatus();
}
// Runs the current IREE bytecode module and renders its result to a window
// using ImGui.
Status RunModuleAndUpdateImGuiWindow(
iree_hal_device_t* device, iree_vm_context_t* context,
iree_vm_function_t function, const std::string& function_name,
const vm::ref<iree_vm_list_t>& function_inputs,
const std::vector<RawSignatureParser::Description>& output_descs,
const std::string& window_title) {
vm::ref<iree_vm_list_t> outputs;
IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr,
output_descs.size(),
iree_allocator_system(), &outputs));
IREE_LOG(INFO) << "EXEC @" << function_name;
IREE_RETURN_IF_ERROR(iree_vm_invoke(context, function, /*policy=*/nullptr,
function_inputs.get(), outputs.get(),
iree_allocator_system()));
std::ostringstream oss;
IREE_RETURN_IF_ERROR(PrintVariantList(output_descs, outputs.get(), &oss));
outputs.reset();
ImGui::Begin(window_title.c_str(), /*p_open=*/nullptr,
ImGuiWindowFlags_AlwaysAutoResize);
ImGui::Text("Entry function:");
ImGui::Text("%s", function_name.c_str());
ImGui::Separator();
ImGui::Text("Invocation result:");
ImGui::Text("%s", oss.str().c_str());
ImGui::Separator();
// Framerate counter.
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
ImGui::End();
return OkStatus();
}
} // namespace
extern "C" int iree_main(int argc, char** argv) {
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(
iree_hal_driver_registry_default()));
// --------------------------------------------------------------------------
// Create a window.
if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
IREE_LOG(FATAL) << "Failed to initialize SDL";
return 1;
}
// Setup window
SDL_WindowFlags window_flags = (SDL_WindowFlags)(
SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
SDL_Window* window = SDL_CreateWindow(
"IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
// Setup Vulkan
iree_hal_vulkan_features_t iree_vulkan_features =
static_cast<iree_hal_vulkan_features_t>(
IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS |
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
std::vector<const char*> extensions =
GetInstanceExtensions(window, iree_vulkan_features);
SetupVulkan(iree_vulkan_features, layers.data(), layers.size(),
extensions.data(), extensions.size(), g_Allocator, &g_Instance,
&g_QueueFamily, &g_PhysicalDevice, &g_Queue, &g_Device,
&g_DescriptorPool);
// Create Window Surface
VkSurfaceKHR surface;
VkResult err;
if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
printf("Failed to create Vulkan surface.\n");
return 1;
}
// Create Framebuffers
int w, h;
SDL_GetWindowSize(window, &w, &h);
ImGui_ImplVulkanH_Window* wd = &g_MainWindowData;
SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily,
g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount);
// Setup Dear ImGui context
IMGUI_CHECKVERSION();
ImGui::CreateContext();
ImGuiIO& io = ImGui::GetIO();
(void)io;
ImGui::StyleColorsDark();
// Setup Platform/Renderer bindings
ImGui_ImplSDL2_InitForVulkan(window);
ImGui_ImplVulkan_InitInfo init_info = {};
init_info.Instance = g_Instance;
init_info.PhysicalDevice = g_PhysicalDevice;
init_info.Device = g_Device;
init_info.QueueFamily = g_QueueFamily;
init_info.Queue = g_Queue;
init_info.PipelineCache = g_PipelineCache;
init_info.DescriptorPool = g_DescriptorPool;
init_info.Allocator = g_Allocator;
init_info.MinImageCount = g_MinImageCount;
init_info.ImageCount = wd->ImageCount;
init_info.CheckVkResultFn = check_vk_result;
ImGui_ImplVulkan_Init(&init_info, wd->RenderPass);
// Upload Fonts
{
// Use any command queue
VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool;
VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer;
err = vkResetCommandPool(g_Device, command_pool, 0);
check_vk_result(err);
VkCommandBufferBeginInfo begin_info = {};
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
err = vkBeginCommandBuffer(command_buffer, &begin_info);
check_vk_result(err);
ImGui_ImplVulkan_CreateFontsTexture(command_buffer);
VkSubmitInfo end_info = {};
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
end_info.commandBufferCount = 1;
end_info.pCommandBuffers = &command_buffer;
err = vkEndCommandBuffer(command_buffer);
check_vk_result(err);
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
check_vk_result(err);
err = vkDeviceWaitIdle(g_Device);
check_vk_result(err);
ImGui_ImplVulkan_DestroyFontUploadObjects();
}
// --------------------------------------------------------------------------
// --------------------------------------------------------------------------
// Setup IREE.
// Check API version.
iree_api_version_t actual_version;
iree_status_t status =
iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version);
if (iree_status_is_ok(status)) {
IREE_LOG(INFO) << "IREE runtime API version " << actual_version;
} else {
IREE_LOG(FATAL) << "Unsupported runtime API version " << actual_version;
}
// Register HAL module types.
IREE_CHECK_OK(iree_hal_module_register_types());
// Create a runtime Instance.
iree_vm_instance_t* iree_instance = nullptr;
IREE_CHECK_OK(
iree_vm_instance_create(iree_allocator_system(), &iree_instance));
// Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice.
IREE_LOG(INFO) << "Creating Vulkan driver/device";
// Load symbols from our static `vkGetInstanceProcAddr` for IREE to use.
iree_hal_vulkan_syms_t* iree_vk_syms = nullptr;
IREE_CHECK_OK(iree_hal_vulkan_syms_create(
reinterpret_cast<void*>(&vkGetInstanceProcAddr), iree_allocator_system(),
&iree_vk_syms));
// Create the driver sharing our VkInstance.
iree_hal_driver_t* iree_vk_driver = nullptr;
iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan");
iree_hal_vulkan_driver_options_t driver_options;
driver_options.api_version = VK_API_VERSION_1_0;
driver_options.requested_features = static_cast<iree_hal_vulkan_features_t>(
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance(
driver_identifier, &driver_options, iree_vk_syms, g_Instance,
iree_allocator_system(), &iree_vk_driver));
// Create a device sharing our VkDevice and queue. This makes capturing with
// vendor tools easier because we will have sync compute residing in the
// rendered frame.
iree_string_view_t device_identifier = iree_make_cstring_view("vulkan");
iree_hal_vulkan_queue_set_t compute_queue_set;
compute_queue_set.queue_family_index = g_QueueFamily;
compute_queue_set.queue_indices = 1 << 0;
iree_hal_vulkan_queue_set_t transfer_queue_set;
transfer_queue_set.queue_indices = 0;
iree_hal_device_t* iree_vk_device = nullptr;
IREE_CHECK_OK(iree_hal_vulkan_wrap_device(
device_identifier, &driver_options.device_options, iree_vk_syms,
g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set,
&transfer_queue_set, iree_allocator_system(), &iree_vk_device));
// Create a HAL module using the HAL device.
iree_vm_module_t* hal_module = nullptr;
IREE_CHECK_OK(iree_hal_module_create(iree_vk_device, iree_allocator_system(),
&hal_module));
// Load bytecode module from embedded data.
IREE_LOG(INFO) << "Loading IREE byecode module...";
std::string module_file;
IREE_CHECK_OK(iree::GetModuleContentsFromFlags(&module_file));
iree_vm_module_t* bytecode_module = nullptr;
IREE_CHECK_OK(iree_vm_bytecode_module_create(
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(module_file.data()),
module_file.size()},
iree_allocator_null(), iree_allocator_system(), &bytecode_module));
// Allocate a context that will hold the module state across invocations.
iree_vm_context_t* iree_context = nullptr;
std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
iree_instance, modules.data(), modules.size(), iree_allocator_system(),
&iree_context));
IREE_LOG(INFO) << "Context with modules is ready for use";
// Lookup the entry point function.
std::string entry_function = FLAG_entry_function;
iree_vm_function_t main_function;
IREE_CHECK_OK(bytecode_module->lookup_function(
bytecode_module->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_string_view_t{entry_function.data(), entry_function.size()},
&main_function));
iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
IREE_LOG(INFO) << "Resolved main function named '"
<< std::string(main_function_name.data,
main_function_name.size)
<< "'";
IREE_CHECK_OK(ValidateFunctionAbi(main_function));
std::vector<RawSignatureParser::Description> main_function_input_descs;
IREE_CHECK_OK(ParseInputSignature(main_function, &main_function_input_descs));
vm::ref<iree_vm_list_t> main_function_inputs;
IREE_CHECK_OK(ParseToVariantList(
main_function_input_descs, iree_hal_device_allocator(iree_vk_device),
FLAG_function_inputs, &main_function_inputs));
std::vector<RawSignatureParser::Description> main_function_output_descs;
IREE_CHECK_OK(
ParseOutputSignature(main_function, &main_function_output_descs));
const std::string window_title = std::string(FLAG_module_file);
// --------------------------------------------------------------------------
// --------------------------------------------------------------------------
// Main loop.
bool done = false;
while (!done) {
SDL_Event event;
while (SDL_PollEvent(&event)) {
if (event.type == SDL_QUIT) {
done = true;
}
ImGui_ImplSDL2_ProcessEvent(&event);
if (event.type == SDL_QUIT) done = true;
if (event.type == SDL_WINDOWEVENT &&
event.window.event == SDL_WINDOWEVENT_RESIZED &&
event.window.windowID == SDL_GetWindowID(window)) {
g_SwapChainResizeWidth = (int)event.window.data1;
g_SwapChainResizeHeight = (int)event.window.data2;
g_SwapChainRebuild = true;
}
}
if (g_SwapChainRebuild) {
g_SwapChainRebuild = false;
ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount);
ImGui_ImplVulkanH_CreateOrResizeWindow(
g_Instance, g_PhysicalDevice, g_Device, &g_MainWindowData,
g_QueueFamily, g_Allocator, g_SwapChainResizeWidth,
g_SwapChainResizeHeight, g_MinImageCount);
g_MainWindowData.FrameIndex = 0;
}
// Start the Dear ImGui frame
ImGui_ImplVulkan_NewFrame();
ImGui_ImplSDL2_NewFrame(window);
ImGui::NewFrame();
// Custom window.
auto status = RunModuleAndUpdateImGuiWindow(
iree_vk_device, iree_context, main_function, entry_function,
main_function_inputs, main_function_output_descs, window_title);
if (!status.ok()) {
IREE_LOG(FATAL) << status;
done = true;
continue;
}
// Rendering
ImGui::Render();
RenderFrame(wd, g_Device, g_Queue);
PresentFrame(wd, g_Queue);
}
// --------------------------------------------------------------------------
// --------------------------------------------------------------------------
// Cleanup
iree_vm_ref_release(main_function_inputs);
iree_vm_module_release(hal_module);
iree_vm_module_release(bytecode_module);
iree_vm_context_release(iree_context);
iree_hal_device_release(iree_vk_device);
iree_hal_driver_release(iree_vk_driver);
iree_hal_vulkan_syms_release(iree_vk_syms);
iree_vm_instance_release(iree_instance);
err = vkDeviceWaitIdle(g_Device);
check_vk_result(err);
ImGui_ImplVulkan_Shutdown();
ImGui_ImplSDL2_Shutdown();
ImGui::DestroyContext();
CleanupVulkanWindow();
CleanupVulkan();
SDL_DestroyWindow(window);
SDL_Quit();
// --------------------------------------------------------------------------
return 0;
}
} // namespace iree