blob: f92986a59486f98c3c57e7fa7453cc1376db7661 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.h"
#include "llvm/IR/IRBuilder.h"
// =============================================================================
//
// NOTE: these structures model 1:1 those in iree/hal/local/executable_library.h
//
// This file must always track the latest version. Backwards compatibility with
// existing runtimes using older versions of the header is maintained by
// emitting variants of the structs matching those in the older headers and
// selecting between them in the query function based on the requested version.
//
// =============================================================================
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
// Rounds |value| up to the nearest multiple of |alignment|.
// NOTE: the mask-based rounding is only correct when |alignment| is a power of
// two; callers are assumed to pass one.
static inline int64_t RoundUpToAlignment(int64_t value, int64_t alignment) {
  const int64_t lowBitsMask = alignment - 1;
  const int64_t bumped = value + lowBitsMask;
  return bumped & ~lowBitsMask;
}
//===----------------------------------------------------------------------===//
// iree/hal/local/executable_library.h structure types
//===----------------------------------------------------------------------===//
// The IR snippets below were pulled from clang running with `-S -emit-llvm`
// on the executable_library.h header: https://godbolt.org/z/6bMv5jfvf
// Returns the named struct mirroring iree_hal_executable_import_table_v0_t.
// If a type with this name already exists in |context> it is reused as-is;
// otherwise it is created:
//
// %struct.iree_hal_executable_import_table_v0_t = type {
//   i32,   ; count
//   i8**   ; symbols
// }
static llvm::StructType *makeImportTableType(llvm::LLVMContext &context) {
  const char *const typeName = "iree_hal_executable_import_table_v0_t";
  llvm::StructType *importTableType =
      llvm::StructType::getTypeByName(context, typeName);
  if (importTableType != nullptr) return importTableType;
  llvm::Type *countType = llvm::IntegerType::getInt32Ty(context);
  llvm::Type *symbolsType =
      llvm::IntegerType::getInt8PtrTy(context)->getPointerTo();
  return llvm::StructType::create(context, {countType, symbolsType}, typeName,
                                  /*isPacked=*/false);
}
// %struct.iree_hal_executable_environment_v0_t = type {
//   ...
// }
//
// Looks up the environment struct type by name. This builder never creates
// it; the type must already have been defined (by ConvertToLLVM). Builds with
// assertions enabled fail if it is missing; otherwise nullptr is returned.
static llvm::StructType *makeEnvironmentType(llvm::LLVMContext &context) {
  llvm::StructType *environmentType = llvm::StructType::getTypeByName(
      context, "iree_hal_executable_environment_v0_t");
  assert(environmentType &&
         "environment type must be defined by ConvertToLLVM");
  return environmentType;
}
// clang lowers the C union iree_hal_vec3_t as:
//   %struct.anon = type { i32, i32, i32 }
//   %union.iree_hal_vec3_t = type { %struct.anon }
// This builder flattens that into one named struct with the same layout:
//   %iree_hal_vec3_t = type { i32, i32, i32 }
// A type already registered under "iree_hal_vec3_t" in |context| is reused.
static llvm::StructType *makeVec3Type(llvm::LLVMContext &context) {
  llvm::StructType *vec3Type =
      llvm::StructType::getTypeByName(context, "iree_hal_vec3_t");
  if (vec3Type == nullptr) {
    llvm::Type *lane = llvm::IntegerType::getInt32Ty(context);
    vec3Type = llvm::StructType::create(context, {lane, lane, lane},
                                        "iree_hal_vec3_t",
                                        /*isPacked=*/false);
  }
  return vec3Type;
}
// %struct.iree_hal_executable_dispatch_state_v0_t = type {
//   ...
// }
//
// Looks up the dispatch state struct type by name. It is never created here;
// ConvertToLLVM is expected to have defined it. Builds with assertions enabled
// fail if it is missing; otherwise nullptr is returned.
static llvm::StructType *makeDispatchStateType(llvm::LLVMContext &context) {
  llvm::StructType *dispatchStateType = llvm::StructType::getTypeByName(
      context, "iree_hal_executable_dispatch_state_v0_t");
  assert(dispatchStateType && "state type must be defined by ConvertToLLVM");
  return dispatchStateType;
}
// i32 (%struct.iree_hal_executable_dispatch_state_v0_t*,
//      %union.iree_hal_vec3_t*,
//      i8*)
//
// Returns the LLVM function type of an executable dispatch entry point.
//
// The second parameter uses the named vec3 struct from makeVec3Type, as
// documented above, rather than an anonymous [3 x i32]. With typed pointers
// the two pointer types are distinct. Using the named type keeps this
// signature consistent with the iree_hal_vec3_t definition (reused if already
// present in |context>). It also keeps it consistent with every table type
// derived from this function type.
static llvm::FunctionType *makeDispatchFunctionType(
    llvm::LLVMContext &context) {
  auto *dispatchStateType = makeDispatchStateType(context);
  auto *vec3Type = makeVec3Type(context);
  auto *i8PtrType = llvm::IntegerType::getInt8PtrTy(context);
  auto *i32Type = llvm::IntegerType::getInt32Ty(context);
  return llvm::FunctionType::get(i32Type,
                                 {
                                     dispatchStateType->getPointerTo(),
                                     vec3Type->getPointerTo(),
                                     i8PtrType,
                                 },
                                 /*isVarArg=*/false);
}
// Returns the named struct mirroring iree_hal_executable_dispatch_attrs_v0_t,
// creating it in |context| if no type of that name exists yet:
//
// %struct.iree_hal_executable_dispatch_attrs_v0_t = type {
//   i16,  ; local_memory_pages
//   i16   ; reserved
// }
static llvm::StructType *makeDispatchAttrsType(llvm::LLVMContext &context) {
  const llvm::StringRef typeName = "iree_hal_executable_dispatch_attrs_v0_t";
  llvm::StructType *attrsType =
      llvm::StructType::getTypeByName(context, typeName);
  if (!attrsType) {
    llvm::Type *halfWord = llvm::IntegerType::getInt16Ty(context);
    attrsType = llvm::StructType::create(context, {halfWord, halfWord},
                                         typeName, /*isPacked=*/false);
  }
  return attrsType;
}
// Returns the named struct mirroring iree_hal_executable_export_table_v0_t,
// creating it in |context| if no type of that name exists yet. The field order
// matches the initializer built by LibraryBuilder::buildLibraryV0ExportTable:
//
// %struct.iree_hal_executable_export_table_v0_t = type {
//   i32,                                               ; count
//   <dispatch function type>**,                        ; ptrs
//   %struct.iree_hal_executable_dispatch_attrs_v0_t*,  ; attrs
//   i8**,                                              ; names
//   i8**                                               ; tags
// }
static llvm::StructType *makeExportTableType(llvm::LLVMContext &context) {
  constexpr const char *kTypeName = "iree_hal_executable_export_table_v0_t";
  if (auto *cached = llvm::StructType::getTypeByName(context, kTypeName))
    return cached;
  llvm::Type *countType = llvm::IntegerType::getInt32Ty(context);
  llvm::Type *ptrsType =
      makeDispatchFunctionType(context)->getPointerTo()->getPointerTo();
  llvm::Type *attrsType = makeDispatchAttrsType(context)->getPointerTo();
  llvm::Type *stringListType =
      llvm::IntegerType::getInt8PtrTy(context)->getPointerTo();
  return llvm::StructType::create(
      context, {countType, ptrsType, attrsType, stringListType, stringListType},
      kTypeName, /*isPacked=*/false);
}
// Returns the named struct mirroring iree_hal_executable_library_header_t,
// creating it in |context| if no type of that name exists yet:
//
// %struct.iree_hal_executable_library_header_t = type {
//   i32,  ; version
//   i8*,  ; name
//   i32,  ; features
//   i32   ; sanitizer
// }
static llvm::StructType *makeLibraryHeaderType(llvm::LLVMContext &context) {
  const char *const headerTypeName = "iree_hal_executable_library_header_t";
  llvm::StructType *headerType =
      llvm::StructType::getTypeByName(context, headerTypeName);
  if (!headerType) {
    llvm::Type *word = llvm::IntegerType::getInt32Ty(context);
    llvm::Type *cString = llvm::IntegerType::getInt8PtrTy(context);
    llvm::Type *fields[] = {word, cString, word, word};
    headerType = llvm::StructType::create(context, fields, headerTypeName,
                                          /*isPacked=*/false);
  }
  return headerType;
}
// Returns the named struct mirroring iree_hal_executable_library_v0_t in the
// context that owns |libraryHeaderType>, creating it if it does not yet exist:
//
// %struct.iree_hal_executable_library_v0_t = type {
//   %struct.iree_hal_executable_library_header_t*,  ; header
//   %struct.iree_hal_executable_import_table_v0_t,  ; imports (by value)
//   %struct.iree_hal_executable_export_table_v0_t   ; exports (by value)
// }
static llvm::StructType *makeLibraryType(llvm::StructType *libraryHeaderType) {
  llvm::LLVMContext &ctx = libraryHeaderType->getContext();
  llvm::StructType *libraryType =
      llvm::StructType::getTypeByName(ctx, "iree_hal_executable_library_v0_t");
  if (libraryType != nullptr) return libraryType;
  llvm::Type *fields[] = {
      libraryHeaderType->getPointerTo(),
      makeImportTableType(ctx),
      makeExportTableType(ctx),
  };
  return llvm::StructType::create(ctx, fields,
                                  "iree_hal_executable_library_v0_t",
                                  /*isPacked=*/false);
}
//===----------------------------------------------------------------------===//
// IR construction utilities
//===----------------------------------------------------------------------===//
// Emits |value| as a private, constant, 1-byte-aligned global holding the
// characters plus a trailing NUL. Returns an inbounds GEP to its first
// character (an i8* constant). The global is unnamed, so it prints as a
// numbered global, e.g.:
//
//   @0 = private constant [6 x i8] c"lib_a\00", align 1
static llvm::Constant *getStringConstant(StringRef value,
                                         llvm::Module *module) {
  llvm::LLVMContext &ctx = module->getContext();
  // getString appends the NUL terminator, hence the +1 on the array length.
  llvm::Constant *initializer = llvm::ConstantDataArray::getString(ctx, value);
  auto *arrayType = llvm::ArrayType::get(llvm::IntegerType::getInt8Ty(ctx),
                                         value.size() + /*NUL*/ 1);
  auto *storage = new llvm::GlobalVariable(
      *module, arrayType, /*isConstant=*/true,
      llvm::GlobalVariable::PrivateLinkage, initializer, /*Name=*/"");
  storage->setAlignment(llvm::MaybeAlign(1));
  llvm::Constant *firstCharIndices[] = {
      llvm::ConstantInt::get(llvm::IntegerType::getInt32Ty(ctx), 0),
      llvm::ConstantInt::get(llvm::IntegerType::getInt32Ty(ctx), 0),
  };
  return llvm::ConstantExpr::getInBoundsGetElementPtr(arrayType, storage,
                                                      firstCharIndices);
}
//===----------------------------------------------------------------------===//
// Builder interface
//===----------------------------------------------------------------------===//
// Emits the internal-linkage library query function |queryFuncName| and the v0
// library it exposes, and returns the function.
//
// IR signature:
//   %struct.iree_hal_executable_library_header_t*
//   @<queryFuncName>(i32 max_version,
//                    %struct.iree_hal_executable_environment_v0_t*)
//
// The function returns the address of the v0 library global (whose first
// field is the header pointer) when max_version == 0, and NULL otherwise.
// NOTE(review): the C declaration this mirrors spells the result as a
// header**; here the IR return type is header*.
llvm::Function *LibraryBuilder::build(StringRef queryFuncName) {
  llvm::LLVMContext &ctx = module->getContext();
  llvm::IntegerType *versionType = llvm::IntegerType::getInt32Ty(ctx);
  llvm::Type *envPtrType = makeEnvironmentType(ctx)->getPointerTo();
  llvm::PointerType *headerPtrType =
      makeLibraryHeaderType(ctx)->getPointerTo();

  llvm::FunctionType *queryType = llvm::FunctionType::get(
      headerPtrType, {versionType, envPtrType}, /*isVarArg=*/false);
  llvm::Function *queryFunc = llvm::Function::Create(
      queryType, llvm::GlobalValue::InternalLinkage, queryFuncName, *module);
  llvm::IRBuilder<> irb(llvm::BasicBlock::Create(ctx, "entry", queryFunc));

  // Only one library version exists today, so selection is a single compare:
  //   return max_version == 0 ? &library : NULL;
  llvm::Constant *libraryV0 = buildLibraryV0((queryFuncName + "_v0").str());
  llvm::Value *wantsV0 = irb.CreateICmpEQ(
      queryFunc->getArg(0), llvm::ConstantInt::get(versionType, 0));
  llvm::Value *libraryPtr = irb.CreatePointerCast(libraryV0, headerPtrType);
  llvm::Value *nullPtr = llvm::ConstantPointerNull::get(headerPtrType);
  irb.CreateRet(irb.CreateSelect(wantsV0, libraryPtr, nullPtr));
  return queryFunc;
}
// Builds the iree_hal_executable_import_table_v0_t constant for the library:
//   { count, symbols }
//
// Each import becomes a NUL-terminated symbol name string. Weak imports get a
// trailing '?'. When there is at least one import, the name pointers are
// stored in a private `<libraryName>_import_names` array and |symbols| points
// at its first element. With no imports, |symbols| is NULL.
llvm::Constant *LibraryBuilder::buildLibraryV0ImportTable(
    std::string libraryName) {
  auto &context = module->getContext();
  auto *importTableType = makeImportTableType(context);
  auto *i8PtrType = llvm::IntegerType::getInt8Ty(context)->getPointerTo();
  auto *i32Type = llvm::IntegerType::getInt32Ty(context);
  llvm::Constant *zero = llvm::ConstantInt::get(i32Type, 0);

  // The |symbols| field is declared as i8** (see makeImportTableType), and the
  // non-empty path below produces an i8**. The NULL fallback must have that
  // same type, or the ConstantStruct::get element-type check fails with
  // typed pointers.
  llvm::Constant *symbolNames =
      llvm::Constant::getNullValue(i8PtrType->getPointerTo());
  if (!imports.empty()) {
    SmallVector<llvm::Constant *, 4> symbolNameValues;
    for (const auto &import : imports) {
      auto symbolName = import.symbol_name;
      if (import.weak) {
        // Weak imports are marked by a trailing '?' on the symbol name.
        symbolName += "?";
      }
      symbolNameValues.push_back(getStringConstant(symbolName, module));
    }
    auto *symbolNamesType =
        llvm::ArrayType::get(i8PtrType, symbolNameValues.size());
    auto *global = new llvm::GlobalVariable(
        *module, symbolNamesType, /*isConstant=*/true,
        llvm::GlobalVariable::PrivateLinkage,
        llvm::ConstantArray::get(symbolNamesType, symbolNameValues),
        /*Name=*/libraryName + "_import_names");
    symbolNames = llvm::ConstantExpr::getInBoundsGetElementPtr(
        symbolNamesType, global, ArrayRef<llvm::Constant *>{zero, zero});
  }
  return llvm::ConstantStruct::get(
      importTableType, {
                           // count=
                           llvm::ConstantInt::get(i32Type, imports.size()),
                           // symbols=
                           symbolNames,
                       });
}
// Builds the iree_hal_executable_export_table_v0_t constant for the library:
//   { count, ptrs, attrs, names, tags }
//
// Private constant globals, prefixed with |libraryName|, back the non-NULL
// fields:
//   <libraryName>_funcs: dispatch function pointers, one per export (always).
//   <libraryName>_attrs: per-export dispatch attributes. Emitted only when at
//                        least one export has non-default attributes;
//                        otherwise |attrs| is NULL.
//   <libraryName>_names: export names; only in INCLUDE_REFLECTION_ATTRS mode.
//   <libraryName>_tags:  export tags; only in INCLUDE_REFLECTION_ATTRS mode.
// Each non-NULL field points at element 0 of its backing array.
llvm::Constant *LibraryBuilder::buildLibraryV0ExportTable(
    std::string libraryName) {
  auto &context = module->getContext();
  auto *exportTableType = makeExportTableType(context);
  auto *dispatchFunctionType = makeDispatchFunctionType(context);
  auto *dispatchAttrsType = makeDispatchAttrsType(context);
  auto *i8Type = llvm::IntegerType::getInt8Ty(context);
  auto *i16Type = llvm::IntegerType::getInt16Ty(context);
  auto *i32Type = llvm::IntegerType::getInt32Ty(context);
  llvm::Constant *zero = llvm::ConstantInt::get(i32Type, 0);

  // iree_hal_executable_export_table_v0_t::ptrs
  SmallVector<llvm::Constant *, 4> exportPtrValues;
  for (const auto &dispatch : exports) {
    exportPtrValues.push_back(dispatch.func);
  }
  auto *exportPtrsType = llvm::ArrayType::get(
      dispatchFunctionType->getPointerTo(), exportPtrValues.size());
  llvm::Constant *exportPtrs = new llvm::GlobalVariable(
      *module, exportPtrsType, /*isConstant=*/true,
      llvm::GlobalVariable::PrivateLinkage,
      llvm::ConstantArray::get(exportPtrsType, exportPtrValues),
      /*Name=*/libraryName + "_funcs");
  // TODO(benvanik): force alignment (16? natural pointer width *2?)
  exportPtrs = llvm::ConstantExpr::getInBoundsGetElementPtr(
      exportPtrsType, exportPtrs, ArrayRef<llvm::Constant *>{zero, zero});

  // iree_hal_executable_export_table_v0_t::attrs
  // The NULL fallback must use the field's declared type
  // (iree_hal_executable_dispatch_attrs_v0_t*), not i32*. Otherwise the export
  // table initializer below has a mismatched element type under typed pointers.
  llvm::Constant *exportAttrs =
      llvm::Constant::getNullValue(dispatchAttrsType->getPointerTo());
  // Only materialize the attrs table when at least one export carries
  // non-default attributes; otherwise |attrs| stays NULL.
  bool hasNonDefaultAttrs = llvm::any_of(exports, [](const Dispatch &dispatch) {
    return !dispatch.attrs.isDefault();
  });
  if (hasNonDefaultAttrs) {
    SmallVector<llvm::Constant *, 4> exportAttrValues;
    for (const auto &dispatch : exports) {
      // Local memory is expressed in whole kWorkgroupLocalMemoryPageSize pages,
      // rounding any partial page up.
      // NOTE(review): the field is i16, so page counts above 0xFFFF cannot be
      // represented; confirm this is bounded upstream.
      auto localMemoryPages =
          RoundUpToAlignment(dispatch.attrs.localMemorySize,
                             kWorkgroupLocalMemoryPageSize) /
          kWorkgroupLocalMemoryPageSize;
      exportAttrValues.push_back(llvm::ConstantStruct::get(
          dispatchAttrsType,
          {
              // local_memory_pages=
              llvm::ConstantInt::get(i16Type, localMemoryPages),
              // reserved=
              llvm::ConstantInt::get(i16Type, 0),
          }));
    }
    auto *exportAttrsType =
        llvm::ArrayType::get(dispatchAttrsType, exportAttrValues.size());
    auto *global = new llvm::GlobalVariable(
        *module, exportAttrsType, /*isConstant=*/true,
        llvm::GlobalVariable::PrivateLinkage,
        llvm::ConstantArray::get(exportAttrsType, exportAttrValues),
        /*Name=*/libraryName + "_attrs");
    // TODO(benvanik): force alignment (16? natural pointer width?)
    exportAttrs = llvm::ConstantExpr::getInBoundsGetElementPtr(
        exportAttrsType, global, ArrayRef<llvm::Constant *>{zero, zero});
  }

  // iree_hal_executable_export_table_v0_t::names
  llvm::Constant *exportNames =
      llvm::Constant::getNullValue(i8Type->getPointerTo()->getPointerTo());
  if (mode == Mode::INCLUDE_REFLECTION_ATTRS) {
    SmallVector<llvm::Constant *, 4> exportNameValues;
    for (const auto &dispatch : exports) {
      exportNameValues.push_back(getStringConstant(dispatch.name, module));
    }
    auto *exportNamesType =
        llvm::ArrayType::get(i8Type->getPointerTo(), exportNameValues.size());
    auto *global = new llvm::GlobalVariable(
        *module, exportNamesType, /*isConstant=*/true,
        llvm::GlobalVariable::PrivateLinkage,
        llvm::ConstantArray::get(exportNamesType, exportNameValues),
        /*Name=*/libraryName + "_names");
    // TODO(benvanik): force alignment (16? natural pointer width *2?)
    exportNames = llvm::ConstantExpr::getInBoundsGetElementPtr(
        exportNamesType, global, ArrayRef<llvm::Constant *>{zero, zero});
  }

  // iree_hal_executable_export_table_v0_t::tags
  llvm::Constant *exportTags =
      llvm::Constant::getNullValue(i8Type->getPointerTo()->getPointerTo());
  if (mode == Mode::INCLUDE_REFLECTION_ATTRS) {
    SmallVector<llvm::Constant *, 4> exportTagValues;
    for (const auto &dispatch : exports) {
      exportTagValues.push_back(getStringConstant(dispatch.tag, module));
    }
    auto *exportTagsType =
        llvm::ArrayType::get(i8Type->getPointerTo(), exportTagValues.size());
    auto *global = new llvm::GlobalVariable(
        *module, exportTagsType, /*isConstant=*/true,
        llvm::GlobalVariable::PrivateLinkage,
        llvm::ConstantArray::get(exportTagsType, exportTagValues),
        /*Name=*/libraryName + "_tags");
    // TODO(benvanik): force alignment (16? natural pointer width *2?)
    exportTags = llvm::ConstantExpr::getInBoundsGetElementPtr(
        exportTagsType, global, ArrayRef<llvm::Constant *>{zero, zero});
  }

  return llvm::ConstantStruct::get(
      exportTableType, {
                           // count=
                           llvm::ConstantInt::get(i32Type, exports.size()),
                           // ptrs=
                           exportPtrs,
                           // attrs=
                           exportAttrs,
                           // names=
                           exportNames,
                           // tags=
                           exportTags,
                       });
}
// Emits the v0 library globals into the module and returns the library global.
//
// Two private constant globals are created:
//   <libraryName>_header: iree_hal_executable_library_header_t
//       { version = Version::V_0, name = module name, features, sanitizer }
//   <libraryName>:        iree_hal_executable_library_v0_t
//       { header = &<libraryName>_header, imports, exports }
// The import/export table constants, and any globals backing them, come from
// buildLibraryV0ImportTable and buildLibraryV0ExportTable (in that order).
llvm::Constant *LibraryBuilder::buildLibraryV0(std::string libraryName) {
  llvm::LLVMContext &ctx = module->getContext();
  llvm::StructType *headerType = makeLibraryHeaderType(ctx);
  llvm::StructType *libraryType = makeLibraryType(headerType);
  llvm::IntegerType *i32Type = llvm::IntegerType::getInt32Ty(ctx);

  // ----- Header -----
  llvm::Constant *versionField =
      llvm::ConstantInt::get(i32Type, static_cast<int64_t>(Version::V_0));
  llvm::Constant *nameField = getStringConstant(module->getName(), module);
  llvm::Constant *featuresField =
      llvm::ConstantInt::get(i32Type, static_cast<int64_t>(features));
  llvm::Constant *sanitizerField =
      llvm::ConstantInt::get(i32Type, static_cast<int64_t>(sanitizerKind));
  auto *headerGlobal = new llvm::GlobalVariable(
      *module, headerType, /*isConstant=*/true,
      llvm::GlobalVariable::PrivateLinkage,
      llvm::ConstantStruct::get(
          headerType,
          {versionField, nameField, featuresField, sanitizerField}),
      /*Name=*/libraryName + "_header");
  // TODO(benvanik): force alignment (8? natural pointer width?)

  // ----- Library -----
  // Import table first, then export table, so their backing globals are
  // emitted in that order ahead of the library global itself.
  llvm::Constant *importTable = buildLibraryV0ImportTable(libraryName);
  llvm::Constant *exportTable = buildLibraryV0ExportTable(libraryName);
  auto *libraryGlobal = new llvm::GlobalVariable(
      *module, libraryType, /*isConstant=*/true,
      llvm::GlobalVariable::PrivateLinkage,
      llvm::ConstantStruct::get(libraryType,
                                {headerGlobal, importTable, exportTable}),
      /*Name=*/libraryName);
  // TODO(benvanik): force alignment (8? natural pointer width?)
  return libraryGlobal;
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir