Revert "[spirv] Switch to use common target description" (#17698)
Reverts iree-org/iree#17623
This appears to have broken some benchmark builds.
diff --git a/compiler/plugins/target/MetalSPIRV/BUILD.bazel b/compiler/plugins/target/MetalSPIRV/BUILD.bazel
index ede5566..9773eff 100644
--- a/compiler/plugins/target/MetalSPIRV/BUILD.bazel
+++ b/compiler/plugins/target/MetalSPIRV/BUILD.bazel
@@ -26,7 +26,6 @@
":SPIRVToMSL",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
- "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets",
"//compiler/src/iree/compiler/Codegen/SPIRV",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
diff --git a/compiler/plugins/target/MetalSPIRV/CMakeLists.txt b/compiler/plugins/target/MetalSPIRV/CMakeLists.txt
index 678a37a..4dd1b06 100644
--- a/compiler/plugins/target/MetalSPIRV/CMakeLists.txt
+++ b/compiler/plugins/target/MetalSPIRV/CMakeLists.txt
@@ -36,7 +36,6 @@
MLIRVectorDialect
iree::compiler::Codegen::Common
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
- iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets
iree::compiler::Codegen::SPIRV
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
diff --git a/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
index 25e8e51..6ea1cbf 100644
--- a/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -8,7 +8,6 @@
#include "compiler/plugins/target/MetalSPIRV/MetalTargetPlatform.h"
#include "compiler/plugins/target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
@@ -20,7 +19,9 @@
#include "llvm/TargetParser/Triple.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Target/SPIRV/Serialization.h"
@@ -51,6 +52,60 @@
};
} // namespace
+/// Builds the fixed SPIR-V target environment used for the Metal backend.
+///
+/// The environment targets SPIR-V 1.3 with the Metal client API and is tagged
+/// as an Apple integrated GPU with an unknown device ID. The extension,
+/// capability, and resource-limit values are hard-coded below.
+///
+/// \param context MLIR context in which the attributes are created.
+/// \returns the assembled `spirv::TargetEnvAttr`.
+static spirv::TargetEnvAttr getMetalTargetEnv(MLIRContext *context) {
+  using spirv::Capability;
+  using spirv::Extension;
+
+  // Capabilities and limits according to Metal 3 devices.
+  const std::array<Extension, 4> extensions = {
+      Extension::SPV_KHR_16bit_storage,
+      Extension::SPV_KHR_8bit_storage,
+      Extension::SPV_KHR_storage_buffer_storage_class,
+      Extension::SPV_KHR_variable_pointers,
+  };
+  // Every capability appears exactly once. StoragePushConstant16 used to be
+  // listed twice, which only padded the attribute with a redundant entry.
+  const std::array<Capability, 20> capabilities = {
+      Capability::Shader,
+      Capability::Int8,
+      Capability::Int16,
+      Capability::Int64,
+      Capability::Float16,
+      Capability::UniformAndStorageBuffer8BitAccess,
+      Capability::StorageBuffer8BitAccess,
+      Capability::StoragePushConstant8,
+      Capability::StorageUniform16,
+      Capability::StorageBuffer16BitAccess,
+      Capability::StoragePushConstant16,
+      Capability::GroupNonUniform,
+      Capability::GroupNonUniformVote,
+      Capability::GroupNonUniformArithmetic,
+      Capability::GroupNonUniformBallot,
+      Capability::GroupNonUniformShuffle,
+      Capability::GroupNonUniformShuffleRelative,
+      Capability::GroupNonUniformQuad,
+      Capability::VariablePointers,
+      Capability::VariablePointersStorageBuffer,
+  };
+  auto limits = spirv::ResourceLimitsAttr::get(
+      context,
+      /*max_compute_shared_memory_size=*/32768,
+      /*max_compute_workgroup_invocations=*/1024,
+      /*max_compute_workgroup_size=*/
+      Builder(context).getI32ArrayAttr({1024, 1024, 1024}),
+      /*subgroup_size=*/32,
+      /*min_subgroup_size=*/std::nullopt,
+      /*max_subgroup_size=*/std::nullopt,
+      /*cooperative_matrix_properties_khr=*/ArrayAttr{},
+      /*cooperative_matrix_properties_nv=*/ArrayAttr{});
+
+  auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_3, capabilities,
+                                          extensions, context);
+  // Further assuming Apple GPUs.
+  return spirv::TargetEnvAttr::get(
+      triple, limits, spirv::ClientAPI::Metal, spirv::Vendor::Apple,
+      spirv::DeviceType::IntegratedGPU, spirv::TargetEnvAttr::kUnknownDeviceID);
+}
+
// TODO: MetalOptions for choosing the Metal version.
class MetalTargetDevice : public TargetDevice {
public:
@@ -90,20 +145,20 @@
MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
SmallVectorImpl<IREE::HAL::ExecutableTargetAttr> &executableTargetAttrs)
const override {
- executableTargetAttrs.push_back(getExecutableTarget(context));
+ executableTargetAttrs.push_back(
+ getExecutableTarget(context, getMetalTargetEnv(context)));
}
IREE::HAL::ExecutableTargetAttr
- getExecutableTarget(MLIRContext *context) const {
+ getExecutableTarget(MLIRContext *context,
+ spirv::TargetEnvAttr targetEnv) const {
Builder b(context);
SmallVector<NamedAttribute> configItems;
auto addConfig = [&](StringRef name, Attribute value) {
configItems.emplace_back(b.getStringAttr(name), value);
};
- if (auto target = GPU::getMetalTargetDetails(context)) {
- addConfig("iree.gpu.target", target);
- }
+ addConfig(spirv::getTargetEnvAttrName(), targetEnv);
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
b.getStringAttr("metal-spirv"), b.getStringAttr("metal-msl-fb"),
diff --git a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
index 720e00b..84dc61e 100644
--- a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
@@ -4,9 +4,7 @@
hal.device.targets = [
#hal.device.target<"metal", [
#hal.executable.target<"metal-spirv", "metal-msl-fb", {
- iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
- compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
- max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]>
]
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index 0f427c5..9f01246 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -15,7 +15,7 @@
// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>]
// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
-// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
+// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>]
// GFX1100-SAME: subgroup_size_choices = [32, 64]
// GFX941: target = #iree_gpu.target<arch = "gfx941",
diff --git a/compiler/plugins/target/VulkanSPIRV/BUILD.bazel b/compiler/plugins/target/VulkanSPIRV/BUILD.bazel
index 984bef9..4e51dee 100644
--- a/compiler/plugins/target/VulkanSPIRV/BUILD.bazel
+++ b/compiler/plugins/target/VulkanSPIRV/BUILD.bazel
@@ -25,10 +25,11 @@
deps = [
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
- "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets",
"//compiler/src/iree/compiler/Codegen/SPIRV",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
+ "//compiler/src/iree/compiler/Dialect/Vulkan/IR",
+ "//compiler/src/iree/compiler/Dialect/Vulkan/Utils",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
"//runtime/src/iree/schemas:spirv_executable_def_c_fbs",
diff --git a/compiler/plugins/target/VulkanSPIRV/CMakeLists.txt b/compiler/plugins/target/VulkanSPIRV/CMakeLists.txt
index 958e277..c14b76a 100644
--- a/compiler/plugins/target/VulkanSPIRV/CMakeLists.txt
+++ b/compiler/plugins/target/VulkanSPIRV/CMakeLists.txt
@@ -33,10 +33,11 @@
MLIRSupport
iree::compiler::Codegen::Common
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
- iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets
iree::compiler::Codegen::SPIRV
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::Vulkan::IR
+ iree::compiler::Dialect::Vulkan::Utils
iree::compiler::PluginAPI
iree::compiler::Utils
iree::schemas::spirv_executable_def_c_fbs
diff --git a/compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 5fdeb54..49bd44d 100644
--- a/compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -5,9 +5,11 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
+#include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/compiler/Utils/ModuleUtils.h"
@@ -32,19 +34,20 @@
namespace {
struct VulkanSPIRVTargetOptions {
- // Use vp_android_baseline_2022 profile as the default target--it's a good
- // lowest common denominator to guarantee the generated SPIR-V is widely
- // accepted for now. Eventually we want to use a list for multi-targeting.
- std::string targetTriple = "vp_android_baseline_2022";
+ std::string targetTriple = "";
+ std::string targetEnv = "";
bool indirectBindings = false;
void bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category("VulkanSPIRV HAL Target");
binder.opt<std::string>(
- // TODO: Rename this as target given it's not a triple anymore.
"iree-vulkan-target-triple", targetTriple,
llvm::cl::desc(
"Vulkan target triple controlling the SPIR-V environment."));
+ binder.opt<std::string>(
+ "iree-vulkan-target-env", targetEnv,
+ llvm::cl::desc(
+ "Vulkan target environment as #vk.target_env attribute assembly."));
binder.opt<bool>(
"iree-vulkan-experimental-indirect-bindings", indirectBindings,
llvm::cl::desc(
@@ -53,6 +56,31 @@
};
} // namespace
+// Resolves a user-supplied Vulkan target description into the SPIR-V target
+// environment used for conversion.
+//
+// A string starting with '#' is parsed as `#vk.target_env<...>` attribute
+// assembly; any other non-empty string is treated as a target triple. An empty
+// string yields a null attribute. Assembly that does not parse to a
+// Vulkan::TargetEnvAttr emits an error at an unknown location and also yields
+// a null attribute.
+static spirv::TargetEnvAttr
+getSPIRVTargetEnv(const std::string &vulkanTargetTripleOrEnv,
+                  MLIRContext *context) {
+  // Nothing requested: hand back a null attribute without diagnostics.
+  if (vulkanTargetTripleOrEnv.empty())
+    return {};
+
+  // Target triple form.
+  const bool isAttrAssembly = vulkanTargetTripleOrEnv[0] == '#';
+  if (!isAttrAssembly) {
+    auto vkEnvFromTriple =
+        Vulkan::getTargetEnvForTriple(context, vulkanTargetTripleOrEnv);
+    return convertTargetEnv(vkEnvFromTriple);
+  }
+
+  // `#vk.target_env<...` attribute assembly form.
+  auto parsedAttr = parseAttribute(vulkanTargetTripleOrEnv, context);
+  if (parsedAttr) {
+    if (auto vkTargetEnv = llvm::dyn_cast<Vulkan::TargetEnvAttr>(parsedAttr))
+      return convertTargetEnv(vkTargetEnv);
+  }
+
+  emitError(Builder(context).getUnknownLoc())
+      << "cannot parse vulkan target environment as #vk.target_env "
+         "attribute: '"
+      << vulkanTargetTripleOrEnv << "'";
+  return {};
+}
+
// TODO: VulkanOptions for choosing the Vulkan version and extensions/features.
class VulkanTargetDevice : public TargetDevice {
public:
@@ -91,32 +119,35 @@
MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
SmallVectorImpl<IREE::HAL::ExecutableTargetAttr> &executableTargetAttrs)
const override {
- executableTargetAttrs.push_back(
- getExecutableTarget(context, options_.indirectBindings));
+ std::string targetTripleOrEnv;
+ if (!options_.targetEnv.empty()) {
+ // TODO(scotttodd): assert if triple is set too? (mutually exclusive)
+ targetTripleOrEnv = options_.targetEnv;
+ } else if (!options_.targetTriple.empty()) {
+ targetTripleOrEnv = options_.targetTriple;
+ } else {
+ targetTripleOrEnv = "unknown-unknown-unknown";
+ }
+
+ executableTargetAttrs.push_back(getExecutableTarget(
+ context, getSPIRVTargetEnv(targetTripleOrEnv, context),
+ options_.indirectBindings));
}
IREE::HAL::ExecutableTargetAttr
- getExecutableTarget(MLIRContext *context, bool indirectBindings) const {
+ getExecutableTarget(MLIRContext *context, spirv::TargetEnvAttr targetEnv,
+ bool indirectBindings) const {
Builder b(context);
SmallVector<NamedAttribute> configItems;
auto addConfig = [&](StringRef name, Attribute value) {
configItems.emplace_back(b.getStringAttr(name), value);
};
+ addConfig(spirv::getTargetEnvAttrName(), targetEnv);
if (indirectBindings) {
addConfig("hal.bindings.indirect", b.getUnitAttr());
}
- // We only care about the architecture right now.
- StringRef arch = StringRef(options_.targetTriple).split("-").first;
- if (auto target = GPU::getVulkanTargetDetails(arch, context)) {
- addConfig("iree.gpu.target", target);
- } else {
- emitError(b.getUnknownLoc(), "Unknown Vulkan target '")
- << options_.targetTriple << "'";
- return nullptr;
- }
-
return IREE::HAL::ExecutableTargetAttr::get(
context, b.getStringAttr("vulkan-spirv"),
indirectBindings ? b.getStringAttr("vulkan-spirv-fb-ptr")
@@ -125,8 +156,8 @@
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<IREE::Codegen::IREECodegenDialect, spirv::SPIRVDialect,
- gpu::GPUDialect>();
+ registry.insert<IREE::Codegen::IREECodegenDialect, Vulkan::VulkanDialect,
+ spirv::SPIRVDialect, gpu::GPUDialect>();
}
void
diff --git a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
index f8d8159..68d6542 100644
--- a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
@@ -4,9 +4,7 @@
hal.device.targets = [
#hal.device.target<"vulkan", [
#hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
- iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
- compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32],
- max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]>
]
diff --git a/compiler/plugins/target/WebGPUSPIRV/CMakeLists.txt b/compiler/plugins/target/WebGPUSPIRV/CMakeLists.txt
index caf4460..d98dcf2 100644
--- a/compiler/plugins/target/WebGPUSPIRV/CMakeLists.txt
+++ b/compiler/plugins/target/WebGPUSPIRV/CMakeLists.txt
@@ -48,7 +48,6 @@
MLIRSPIRVTransforms
SPIRV-Tools
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
- iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets
iree::compiler::Codegen::SPIRV
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
diff --git a/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp b/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp
index 0a37691..8397eb1 100644
--- a/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp
+++ b/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp
@@ -6,16 +6,18 @@
#include "compiler/plugins/target/WebGPUSPIRV/SPIRVToWGSL.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/WGSL/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/wgsl_executable_def_builder.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@@ -41,6 +43,18 @@
}
};
+// Builds the placeholder SPIR-V target environment used for WebGPU: SPIR-V 1.0
+// with the Shader capability, the storage-buffer storage-class extension, and
+// default resource limits.
+// TODO(scotttodd): provide a proper target environment for WebGPU.
+static spirv::TargetEnvAttr getWebGPUTargetEnv(MLIRContext *context) {
+  // TODO(scotttodd): find list of SPIR-V extensions supported by WebGPU/WGSL
+  const spirv::Capability capabilities[] = {spirv::Capability::Shader};
+  const spirv::Extension extensions[] = {
+      spirv::Extension::SPV_KHR_storage_buffer_storage_class};
+  auto versionCapsExts = spirv::VerCapExtAttr::get(
+      spirv::Version::V_1_0, capabilities, extensions, context);
+  auto resourceLimits = spirv::getDefaultResourceLimits(context);
+
+  // Vendor, device type, and device ID are all left unknown.
+  return spirv::TargetEnvAttr::get(
+      versionCapsExts, resourceLimits, spirv::ClientAPI::WebGPU,
+      spirv::Vendor::Unknown, spirv::DeviceType::Unknown,
+      spirv::TargetEnvAttr::kUnknownDeviceID);
+}
+
// TODO: WebGPUOptions for choosing the version/extensions/etc.
class WebGPUTargetDevice : public TargetDevice {
public:
@@ -80,20 +94,20 @@
MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
SmallVectorImpl<IREE::HAL::ExecutableTargetAttr> &executableTargetAttrs)
const override {
- executableTargetAttrs.push_back(getExecutableTarget(context));
+ executableTargetAttrs.push_back(
+ getExecutableTarget(context, getWebGPUTargetEnv(context)));
}
IREE::HAL::ExecutableTargetAttr
- getExecutableTarget(MLIRContext *context) const {
+ getExecutableTarget(MLIRContext *context,
+ spirv::TargetEnvAttr targetEnv) const {
Builder b(context);
SmallVector<NamedAttribute> configItems;
auto addConfig = [&](StringRef name, Attribute value) {
configItems.emplace_back(b.getStringAttr(name), value);
};
- if (auto target = GPU::getWebGPUTargetDetails(context)) {
- addConfig("iree.gpu.target", target);
- }
+ addConfig(spirv::getTargetEnvAttrName(), targetEnv);
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
b.getStringAttr("webgpu-spirv"), b.getStringAttr("webgpu-wgsl-fb"),
diff --git a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
index 31f361b..1a17240 100644
--- a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
@@ -5,9 +5,7 @@
hal.device.targets = [
#hal.device.target<"webgpu", [
#hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
- iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.0,cap:Shader,ext:SPV_KHR_storage_buffer_storage_class", wgp = <
- compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
- max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
}>
]>
]
diff --git a/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp
index 4cb61ec..7e0d201 100644
--- a/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp
@@ -8,13 +8,13 @@
#include "iree/compiler/tool_entry_points_api.h"
#include "iree/compiler/Tools/init_dialects.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
@@ -50,11 +50,11 @@
static LogicalResult ireeReduceMainFromCL(int argc, char **argv,
MLIRContext &registry) {
- cl::OptionCategory ireeReduceCategory("iree-reduce options");
+ llvm::cl::OptionCategory ireeReduceCategory("iree-reduce options");
- cl::opt<std::string> testScript(cl::Positional, cl::Required,
- cl::desc("<test script>"),
- cl::cat(ireeReduceCategory));
+ llvm::cl::opt<std::string> testScript(cl::Positional, cl::Required,
+ cl::desc("<test script>"),
+ cl::cat(ireeReduceCategory));
cl::opt<std::string> inputFilename(cl::Positional, cl::desc("<input file>"),
cl::init("-"),
@@ -74,11 +74,12 @@
"output-bytecode", cl::desc("Output the final output as bytecode."),
cl::init(false), llvm::cl::cat(ireeReduceCategory));
- cl::HideUnrelatedOptions(ireeReduceCategory);
+ llvm::cl::HideUnrelatedOptions(ireeReduceCategory);
InitLLVM y(argc, argv);
- cl::ParseCommandLineOptions(argc, argv, "IREE test case reduction tool.\n");
+ llvm::cl::ParseCommandLineOptions(argc, argv,
+ "IREE test case reduction tool.\n");
// When reading from stdin and the input is a tty, it is often a user mistake
// and the process "appears to be stuck". Print a message to let the user know
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index d88dc84..5df8a36 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -8,6 +8,7 @@
#include <numeric>
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
@@ -16,6 +17,7 @@
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
@@ -214,9 +216,6 @@
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
- return OpaqueMmaLayout{16, 16, 16, f16, f16, f16};
- }
}
llvm_unreachable("unhandled mfma layout type");
return OpaqueMmaLayout{};
@@ -279,8 +278,7 @@
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -371,8 +369,7 @@
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
auto aType = VectorType::get({16}, getAType());
auto bType = VectorType::get({16}, getBType());
auto cType = VectorType::get({8}, getCType());
@@ -395,7 +392,6 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return 1;
}
@@ -410,8 +406,7 @@
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
return 64;
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return 32;
}
}
@@ -425,8 +420,7 @@
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
break;
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return {2, 1};
}
}
@@ -440,8 +434,7 @@
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
break;
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return {1, 2};
}
}
@@ -462,8 +455,7 @@
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*element=*/{1, 4}};
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*element=*/{1, 16}};
}
}
@@ -478,8 +470,7 @@
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*element=*/{4, 1}};
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*element=*/{16, 1}};
}
}
@@ -494,8 +485,7 @@
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*element=*/{4, 1}};
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return {/*outer=*/{8, 1}, /*thread=*/{2, 16}, /*element=*/{1, 1}};
}
}
@@ -506,8 +496,7 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return {/*outer=*/{0, 1}, /*thread=*/{1, 0}, /*element=*/{0, 1}};
}
}
@@ -518,8 +507,7 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return {/*outer=*/{0, 1}, /*thread=*/{0, 1}, /*element=*/{1, 0}};
}
}
@@ -530,8 +518,7 @@
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return {/*outer=*/{0, 1}, /*thread=*/{0, 1}, /*element=*/{1, 0}};
}
}
@@ -562,8 +549,7 @@
rhs, acc)
.getResult();
}
- case MMAIntrinsic::WMMA_F16_16x16x16_F32:
- case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
+ case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
.getResult();
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index a7abbb6..5c1bead 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -98,19 +98,15 @@
let genSpecializedAttr = 0;
}
-// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
-def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
-// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
+def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 2>;
-def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 3>;
def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
- WMMA_F16_16x16x16_F32,
- WMMA_F16_16x16x16_F16
+ WMMA_F16_16x16x16_F32
]>;
def MMA_LHS : I32EnumAttrCase<"Lhs", 0>;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index bdfcace..e02e759 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -8,7 +8,6 @@
#include <optional>
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -157,7 +156,6 @@
const WgpDetails *getRDNA3WgpDetails() {
static const MMAIntrinsic rdna3MMAOps[] = {
MMAIntrinsic::WMMA_F16_16x16x16_F32,
- MMAIntrinsic::WMMA_F16_16x16x16_F16,
};
static const WgpDetails rdna3Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
@@ -167,29 +165,11 @@
return &rdna3Wgp;
}
-const WgpDetails *getRDNA2WgpDetails() {
- static const WgpDetails rdna2Wgp = {
- allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps,
- /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
- 1024, 64 * 1024};
- return &rdna2Wgp;
-}
-
-const WgpDetails *getRDNA1WgpDetails() {
- static const WgpDetails rdna1Wgp = {
- allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
- /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
- 1024, 64 * 1024};
- return &rdna1Wgp;
-}
-
std::optional<TargetDetails> getAMDGPUTargetDetails(StringRef target) {
const WgpDetails *cdna3Wgp = getCDNA3WgpDetails();
const WgpDetails *cdna2Wgp = getCDNA2WgpDetails();
const WgpDetails *cdna1Wgp = getCDNA1WgpDetails();
const WgpDetails *rdna3Wgp = getRDNA3WgpDetails();
- const WgpDetails *rdna2Wgp = getRDNA2WgpDetails();
- const WgpDetails *rdna1Wgp = getRDNA1WgpDetails();
// "AMD Instinct MI300 Series Product Offerings" in Page 23 of
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-3-white-paper.pdf
@@ -235,10 +215,6 @@
.Case("rx7700xt", TargetDetails{rdna3Wgp, &rx7700xtChip})
.Cases("rdna3", "gfx1100", "gfx1101", "gfx1102", "gfx1103", "gfx1150",
"gfx1151", TargetDetails{rdna3Wgp, nullptr})
- .Cases("rdna2", "gfx1030", "gfx1031", "gfx1032", "gfx1033", "gfx1034",
- "gfx1035", "gfx1036", TargetDetails{rdna2Wgp, nullptr})
- .Cases("rdna1", "gfx1010", "gfx1011", "gfx1012", "gfx1013",
- TargetDetails{rdna1Wgp, nullptr})
.Default(std::nullopt);
}
@@ -246,136 +222,41 @@
if (target.starts_with("gfx"))
return target;
- // We cannot accept rdnaN as a target for LLVM AMDGPU backend; so the
- // following is only meant for Vulkan but not HIP.
- if (target.starts_with("rdna"))
- return target;
-
return llvm::StringSwitch<StringRef>(target.lower())
.Case("mi300x", "gfx942")
.Case("mi300a", "gfx940")
.Cases("mi250x", "mi250", "mi210", "cdna2", "gfx90a")
- .Case("cdna1", "gfx908")
.Cases("rx7900xtx", "rx7900xt", "gfx1100")
.Cases("rx7800xt", "rx7700xt", "gfx1101")
.Default(StringRef());
}
//===----------------------------------------------------------------------===//
-// Known Apple target details
-//===----------------------------------------------------------------------===//
-
-std::optional<TargetDetails> getAppleTargetDetails() {
- ComputeBitwidths computeBitwdiths =
- allIntComputeBits | ComputeBitwidths::FP32 | ComputeBitwidths::FP16;
- // clang-format off
- static const WgpDetails wgp = {
- computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
- /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 32},
- {1024, 1024, 1024}, 1024, 32 * 1024};
- // clang-format on
-
- return TargetDetails{&wgp, nullptr};
-}
-
-//===----------------------------------------------------------------------===//
-// Known ARM target details
-//===----------------------------------------------------------------------===//
-
-const WgpDetails *getValhallWgpDetails() {
- ComputeBitwidths computeBitwdiths =
- allIntComputeBits | ComputeBitwidths::FP32 | ComputeBitwidths::FP16;
- // clang-format off
- static const WgpDetails valhallWgp = {
- computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
- /*mmaCount=*/0, /*mmaOps=*/nullptr, {16}, {512, 512, 512},
- 512, 32 * 1024};
- // clang-format on
- return &valhallWgp;
-}
-
-std::optional<TargetDetails> getARMGPUTargetDetails(StringRef target) {
- const WgpDetails *valhallWgp = getValhallWgpDetails();
-
- // Note that the underlying GPU may have certain capabilities but the Android
- // version and driver stack may not expose them. So the following is just and
- // will always be approximate.
-
- return llvm::StringSwitch<std::optional<TargetDetails>>(target.lower())
- // Mali-G715: https://vulkan.gpuinfo.org/displayreport.php?id=29754
- .Cases("mali-g715", "mali-g615", "valhall4",
- TargetDetails{valhallWgp, nullptr})
- // Mali-G710: https://vulkan.gpuinfo.org/displayreport.php?id=30471
- .Cases("mali-g710", "mali-g510", "mali-g310", "valhall3",
- TargetDetails{valhallWgp, nullptr})
- // Mali-G78: https://vulkan.gpuinfo.org/displayreport.php?id=29994
- .Cases("mali-g78", "valhall2", TargetDetails{valhallWgp, nullptr})
- // Mali-G57: https://vulkan.gpuinfo.org/displayreport.php?id=24636
- .Cases("mali-g77", "mali-g57", "valhall1", "valhall",
- TargetDetails{valhallWgp, nullptr})
- .Default(std::nullopt);
-}
-
-StringRef normalizeARMGPUTarget(StringRef target) {
- if (target == "valhall")
- return "valhall1";
- if (target.starts_with("valhall"))
- return target;
-
- return llvm::StringSwitch<StringRef>(target.lower())
- .Cases("mali-g715", "mali-g615", "valhall4")
- .Cases("mali-g710", "mali-g510", "mali-g310", "valhall3")
- .Case("mali-78", "valhall2")
- .Cases("mali-g77", "mali-g57", "valhall1")
- .Default("");
-}
-
-//===----------------------------------------------------------------------===//
// Known NVIDIA target details
//===----------------------------------------------------------------------===//
-// FIXME: In the following query functions, we are using AMD WMMA intrinsics
-// that have different layout from NVIDIA WMMA intrinsics. This is fine given
-// right now we only use this to indicate target features for Vulkan, where all
-// cooperative matrix layouts are opaque. We need to create NVIDIA specific WMMA
-// intrinsics if we need to have explicit layout analysis and register mapping.
-
const WgpDetails *getAmpereWgpDetails() {
- static const MMAIntrinsic mmaOps[] = {
- MMAIntrinsic::WMMA_F16_16x16x16_F32,
- MMAIntrinsic::WMMA_F16_16x16x16_F16,
- };
static const WgpDetails ampereWgp = {
- allComputeBits, allStorageBits, allSubgroupOps,
- allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
- {32, 32}, {1024, 1024, 1024}, 1024,
- 163 * 1024};
+ allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps, 0,
+ nullptr, // TODO: Add tensor core operations
+ {32, 32}, {1024, 1024, 1024}, 1024, 163 * 1024};
return &ereWgp;
}
const WgpDetails *getTuringWgpDetails() {
- static const MMAIntrinsic mmaOps[] = {
- MMAIntrinsic::WMMA_F16_16x16x16_F32,
- MMAIntrinsic::WMMA_F16_16x16x16_F16,
- };
static const WgpDetails turingWgp = {
- allComputeBits, allStorageBits, allSubgroupOps,
- allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
- {32, 32}, {1024, 1024, 1024}, 1024,
- 64 * 1024};
+ allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps, 0,
+ nullptr, // TODO: Add tensor core operations
+ {32, 32}, {1024, 1024, 1024}, 1024, 64 * 1024};
return &turingWgp;
}
const WgpDetails *getVoltaWgpDetails() {
- static const MMAIntrinsic mmaOps[] = {
- MMAIntrinsic::WMMA_F16_16x16x16_F32,
- MMAIntrinsic::WMMA_F16_16x16x16_F16,
- };
// clang-format off
static const WgpDetails voltaWgp = {
- allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
- ARRAY_SIZE(mmaOps), mmaOps, {32, 32}, {1024, 1024, 1024},
- 1024, 96 * 1024};
+ allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
+ 0, nullptr, // TODO: Add tensor core operations
+ {32, 32}, {1024, 1024, 1024}, 1024, 96 * 1024};
// clang-format on
return &voltaWgp;
}
@@ -451,126 +332,15 @@
.Default(StringRef());
}
-//===----------------------------------------------------------------------===//
-// Known Qualcomm target details
-//===----------------------------------------------------------------------===//
-
-const WgpDetails *getAdrenoWgpDetails() {
- auto computeBitwdiths = ComputeBitwidths::Int32 | ComputeBitwidths::Int16 |
- ComputeBitwidths::Int8 | ComputeBitwidths::FP32 |
- ComputeBitwidths::FP16;
- auto storageBitwidths =
- StorageBitwidths::B64 | StorageBitwidths::B32 | StorageBitwidths::B16;
- // clang-format off
- static const WgpDetails adrenoWgp = {
- computeBitwdiths, storageBitwidths, allSubgroupOps,
- allDotProductOps, /*mmaCount=*/0, /*mmaOps=*/nullptr,
- {64}, {1024, 1024, 1024}, 1024,
- 32 * 1024};
- // clang-format on
- return &adrenoWgp;
-}
-
-bool verifyQualcommGPUTarget(StringRef target) {
- if (target == "adreno")
- return true;
-
- StringRef t = target;
- if (!t.consume_front("adreno-"))
- return false;
-
- // The can exist an optional L at the end.
- if (t.ends_with("l"))
- t = t.drop_back();
-
- // Check whether we have a product number
- unsigned number = 0;
- // StringRef::consumeInteger() returns true to signify errors.
- if (t.size() != 3 || t.consumeInteger(10, number))
- return false;
-
- return true;
-}
-
-std::optional<TargetDetails> getQualcommGPUTargetDetails(StringRef target) {
- const WgpDetails *adrenoWgp = getAdrenoWgpDetails();
-
- // Note that the underlying GPU may have certain capabilities but the Android
- // version and driver stack may not expose them. So the following is just and
- // will always be approximate.
-
- // Adreno GPUs are quite opaque regarding their generational information.
- // So right now we only have one target description for all cases.
- //
- // Though some example Adreno GPUs:
- // Adreno-750: https://vulkan.gpuinfo.org/displayreport.php?id=27414
- // Adreno-740: https://vulkan.gpuinfo.org/displayreport.php?id=19218
- // Adreno-730: https://vulkan.gpuinfo.org/displayreport.php?id=19382
- if (verifyQualcommGPUTarget(target))
- return TargetDetails{adrenoWgp, nullptr};
-
- return std::nullopt;
-}
-
-//===----------------------------------------------------------------------===//
-// Vulkan profile details
-//===----------------------------------------------------------------------===//
-
-const WgpDetails *getAndroidBaseline2022WgpDetails() {
- // The following details are from
- // https://github.com/KhronosGroup/Vulkan-Profiles/blob/main/profiles/VP_ANDROID_baseline_2022.json
-
- auto computeBitwdiths = ComputeBitwidths::Int32 | ComputeBitwidths::FP32;
- auto storageBitwidths = StorageBitwidths::B32;
- // FIXME: We cannot have a fixed subgroup size to target a profile; need to
- // have different targets for different subgroup sizes, or change CodeGen to
- // use symbolic subgroup size values, which can be hard for reduction.
- // It's kinda fine now given we don't allow any subgroup ops anyway here..
-
- // clang-format off
- static const WgpDetails androidWgp = {
- computeBitwdiths, storageBitwidths, SubgroupOps::None,
- DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
- {64, 64}, {128, 128, 64}, 128,
- 16 * 1024};
- // clang-format on
- return &androidWgp;
-}
-
-std::optional<TargetDetails> getAndroidProfileDetails(StringRef target) {
- const WgpDetails *baseline2022Wgp = getAndroidBaseline2022WgpDetails();
-
- return llvm::StringSwitch<std::optional<TargetDetails>>(target.lower())
- .Case("vp_android_baseline_2022", TargetDetails{baseline2022Wgp, nullptr})
- .Default(std::nullopt);
-}
-
} // namespace
//===----------------------------------------------------------------------===//
// Query functions
//===----------------------------------------------------------------------===//
-TargetAttr getMetalTargetDetails(MLIRContext *context) {
- return createTargetAttr(*getAppleTargetDetails(), /*arch=*/"",
- /*features=*/"spirv:v1.3,cap:Shader", context);
-}
-
-TargetAttr getCUDATargetDetails(StringRef target, StringRef features,
- MLIRContext *context) {
- if (std::optional<TargetDetails> details = getNVIDIAGPUTargetDetails(target))
- return createTargetAttr(*details, normalizeNVIDIAGPUTarget(target),
- features, context);
- return nullptr;
-}
-
-StringRef normalizeCUDATarget(StringRef target) {
- return normalizeNVIDIAGPUTarget(target);
-}
-
TargetAttr getHIPTargetDetails(StringRef target, StringRef features,
MLIRContext *context) {
- if (std::optional<TargetDetails> details = getAMDGPUTargetDetails(target)) {
+ if (auto details = getAMDGPUTargetDetails(target)) {
return createTargetAttr(*details, normalizeAMDGPUTarget(target), features,
context);
}
@@ -581,62 +351,16 @@
return normalizeAMDGPUTarget(target);
}
-TargetAttr getVulkanTargetDetails(llvm::StringRef target,
- MLIRContext *context) {
- // Go through each vendor's target details. This assumes we won't have
- // duplicated product or microarchitecture names among vendors, which should
- // be the case.
-
- // For mobile GPUs we target Vulkan 1.1, which accepts SPIR-V 1.3 as the
- // maximum. But the VK_KHR_spirv_1_4 extension is commonly available so we use
- // SPIR-V 1.4. For non-mobile GPUs we target Vulkan 1.3, which accepts
- // SPIR-V 1.6 as the maximum.
-
- if (std::optional<TargetDetails> details = getAMDGPUTargetDetails(target)) {
- return createTargetAttr(*details, normalizeAMDGPUTarget(target),
- /*features=*/"spirv:v1.6,cap:Shader", context);
- }
- if (std::optional<TargetDetails> details = getARMGPUTargetDetails(target)) {
- return createTargetAttr(*details, normalizeARMGPUTarget(target),
- /*features=*/"spirv:v1.4,cap:Shader", context);
- }
- if (std::optional<TargetDetails> details =
- getNVIDIAGPUTargetDetails(target)) {
+TargetAttr getCUDATargetDetails(StringRef target, StringRef features,
+ MLIRContext *context) {
+ if (auto details = getNVIDIAGPUTargetDetails(target))
return createTargetAttr(*details, normalizeNVIDIAGPUTarget(target),
- /*features=*/"spirv:v1.6,cap:Shader", context);
- }
- if (std::optional<TargetDetails> details =
- getQualcommGPUTargetDetails(target)) {
- return createTargetAttr(*details, target,
- /*features=*/"spirv:v1.4,cap:Shader", context);
- }
-
- // Go through common profiles if not hit in the above.
-
- if (std::optional<TargetDetails> details = getAndroidProfileDetails(target)) {
- return createTargetAttr(*details, target,
- /*features=*/"spirv:v1.3,cap:Shader", context);
- }
+ features, context);
return nullptr;
}
-TargetAttr getWebGPUTargetDetails(MLIRContext *context) {
- // TODO(scotttodd): find list of SPIR-V capabilities and extensions supported
- // by WebGPU/WGSL.
- auto computeBitwdiths = ComputeBitwidths::Int32 | ComputeBitwidths::FP32;
- auto storageBitwidths = StorageBitwidths::B32;
- // clang-format off
- static const WgpDetails wgp = {
- computeBitwdiths, storageBitwidths, SubgroupOps::None,
- DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
- {32, 32}, {128, 128, 64}, 128,
- 16 * 1024};
- // clang-format on
-
- return createTargetAttr(
- {&wgp, nullptr}, /*arch=*/"",
- "spirv:v1.0,cap:Shader,ext:SPV_KHR_storage_buffer_storage_class",
- context);
+StringRef normalizeCUDATarget(StringRef target) {
+ return normalizeNVIDIAGPUTarget(target);
}
TargetAttr getFullTarget(StringRef targetAPI, StringRef aliasTarget,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h
index d9698cc..ffe9a15 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h
@@ -12,22 +12,6 @@
namespace mlir::iree_compiler::IREE::GPU {
-// Returns a TargetAttr to target Metal via SPIR-V CodeGen.
-TargetAttr getMetalTargetDetails(MLIRContext *context);
-
-// Returns a TargetAttr to describe the details of the given |target|, which can
-// be a product name like "rtx3090", an microarchitecture name like "ampere", or
-// a compute capability like "sm_80", with a list of comma-separated target
-// |features|. Returns a null TargetAttr if the given |target| is not
-// recognized.
-TargetAttr getCUDATargetDetails(llvm::StringRef target,
- llvm::StringRef features, MLIRContext *context);
-
-// Normalizes the given CUDA |target| to the gfx target commonly used for
-// compiling towards CUDA. For example, "sm_80" for "a100", "sm_89" for "ada".
-// if the given |target| is not recognized.
-StringRef normalizeCUDATarget(StringRef target);
-
// Returns a TargetAttr to describe the details of the given |target|, which can
// be a product name like "rx7900xtx", an microarchitecture name like "rdna3",
// or a compiler target like "gfx1100", with a list of comma-separated
@@ -42,13 +26,16 @@
StringRef normalizeHIPTarget(StringRef target);
// Returns a TargetAttr to describe the details of the given |target|, which can
-// be a product name like "rtx3090"/"mali-g710"/"adreno" or an microarchitecture
-// name like "ampere"/"valhall". Returns a null TargetAttr if the given |target|
-// is not recognized.
-TargetAttr getVulkanTargetDetails(llvm::StringRef target, MLIRContext *context);
+// be a product name like "rtx3090", a microarchitecture name like "ampere", or
+// a compute capability like "sm_80", with a list of comma-separated target
+// |features|. Returns a null TargetAttr if the given |target| is unrecognized.
+TargetAttr getCUDATargetDetails(llvm::StringRef target,
+ llvm::StringRef features, MLIRContext *context);
-// Returns a TargetAttr to target WebGPU via SPIR-V CodeGen.
-TargetAttr getWebGPUTargetDetails(MLIRContext *context);
+// Normalizes the given CUDA |target| to the sm_XX target commonly used for
+// compiling towards CUDA. For example, "sm_80" for "a100", "sm_89" for "ada".
+// Returns an empty StringRef if the given |target| is not recognized.
+StringRef normalizeCUDATarget(StringRef target);
// Returns the full target of the given |aliasTarget| with a list of
// comma-separated target |features|. Returns null target if unknown.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 39678b9..d73740f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -74,6 +74,9 @@
using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
+constexpr StringLiteral kCudaTarget = "cuda";
+constexpr StringLiteral kRocmTarget = "rocm";
+
// Threshold used to determine whether a matmul dimension is 'very skinny'.
constexpr int64_t kVerySkinnyDimThreshold = 4;
@@ -89,10 +92,6 @@
} // namespace
-bool isROCmBackend(IREE::GPU::TargetAttr target) {
- return target.getArch().starts_with("gfx");
-}
-
//====---------------------------------------------------------------------===//
// Matmul Configuration Helpers
//====---------------------------------------------------------------------===//
@@ -577,8 +576,6 @@
setVectorDistributionConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *computeOp) {
- if (!isROCmBackend(target))
- return failure();
if (!clGPUEnableVectorDistribution) {
LDBG("Vector Distribution not enabled, skipping...");
@@ -1191,6 +1188,15 @@
// Warp Reduction Pipeline Configuration
//====---------------------------------------------------------------------===//
+bool isROCmBackend(mlir::FunctionOpInterface entryPoint) {
+ if (auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(entryPoint)) {
+ if (auto backend = targetAttr.getBackend()) {
+ return backend.getValue() == "rocm";
+ }
+ }
+ return false;
+}
+
/// Set the configuration for reductions that can be mapped to warp reductions.
static LogicalResult
setWarpReductionConfig(IREE::GPU::TargetAttr target,
@@ -1361,8 +1367,8 @@
//
// TODO: This is enabled for matvec on ROCm for now. We should
// validate this strategy and extend to more linalg generics and to CUDA.
- if (isROCmBackend(target) && llvm::none_of(bounds, ShapedType::isDynamic) &&
- isMatvecLike(op)) {
+ if (isROCmBackend(entryPoint) &&
+ llvm::none_of(bounds, ShapedType::isDynamic) && isMatvecLike(op)) {
int64_t lastParallelBound = bounds[parallelDims.back()];
int64_t numParallelReductions = 1;
const int64_t maxParallelFactor = groupSize / 4;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
index e72fdc5..5c80445 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
@@ -58,7 +58,6 @@
"Passes.cpp",
"SPIRVAnnotateWinogradLoops.cpp",
"SPIRVBreakDownLargeVector.cpp",
- "SPIRVConvertGPUTarget.cpp",
"SPIRVEmulateI64.cpp",
"SPIRVEraseStorageBufferStaticShape.cpp",
"SPIRVFinalVectorLowering.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index 1378bbc..13632df 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -57,7 +57,6 @@
"Passes.cpp"
"SPIRVAnnotateWinogradLoops.cpp"
"SPIRVBreakDownLargeVector.cpp"
- "SPIRVConvertGPUTarget.cpp"
"SPIRVEmulateI64.cpp"
"SPIRVEraseStorageBufferStaticShape.cpp"
"SPIRVFinalVectorLowering.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 75eb17d..538e3ba 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -631,8 +631,6 @@
void buildSPIRVCodegenConfigurationPassPipeline(
OpPassManager &variantPassManager) {
- // TODO: move the following pass to be immediately before ConvertToSPIRVPass.
- variantPassManager.addPass(createSPIRVConvertGPUTargetPass());
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
buildSPIRVCodegenConfigurationPassPipelineImpl(modulePassManager);
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index a0b0d16..e1dc830 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -86,10 +86,6 @@
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createSPIRVBreakDownLargeVectorPass();
-// Converts #iree_gpu.target into #spirv.target_env.
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVConvertGPUTargetPass();
-
/// Emulates bfloat 16 ops with 32-bit float ops.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createSPIRVEmulateBf16Pass();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
index dc94eb2..29b396d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
@@ -33,13 +33,6 @@
let constructor = "mlir::iree_compiler::createSPIRVBreakDownLargeVectorPass()";
}
-def SPIRVConvertGPUTarget :
- Pass<"iree-spirv-convert-gpu-target",
- "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
- let summary = "Convert #iree_gpu.target into #spirv.target_env";
- let constructor = "mlir::iree_compiler::createSPIRVConvertGPUTargetPass()";
-}
-
def SPIRVEmulateI64 :
InterfacePass<"iree-spirv-emulate-i64", "mlir::FunctionOpInterface"> {
let summary = "Emulate 64-bit integer ops with 32-bit integer ops";
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp
deleted file mode 100644
index 3b8c3ba..0000000
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp
+++ /dev/null
@@ -1,288 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
-#include "iree/compiler/Codegen/SPIRV/Passes.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringSwitch.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
-#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir::iree_compiler {
-
-namespace {
-
-using IREE::GPU::ComputeBitwidths;
-using IREE::GPU::DotProductOps;
-using IREE::GPU::StorageBitwidths;
-using IREE::GPU::SubgroupOps;
-
-using spirv::Capability;
-using spirv::ClientAPI;
-using spirv::Extension;
-using spirv::Vendor;
-using spirv::Version;
-
-//===----------------------------------------------------------------------===//
-// Freeform features
-//===----------------------------------------------------------------------===//
-
-// Scans the given |features| list and pushes SPIR-V version specification of
-// 'spirv:v1.x' format into |caps|.
-std::optional<Version> deduceVersion(ArrayRef<StringRef> features) {
- for (StringRef feature : features) {
- if (feature.consume_front("spirv:v1.")) {
- return llvm::StringSwitch<std::optional<Version>>(feature)
- .Case("6", Version::V_1_6)
- .Case("5", Version::V_1_5)
- .Case("4", Version::V_1_4)
- .Case("3", Version::V_1_3)
- .Case("2", Version::V_1_2)
- .Case("1", Version::V_1_1)
- .Case("0", Version::V_1_0)
- .Default(std::nullopt);
- }
- }
- return std::nullopt;
-}
-
-// Scans the given |features| list and pushes capability specification with
-// 'cap:' prefix into |caps|.
-std::optional<Version> processCapabilities(ArrayRef<StringRef> features,
- SetVector<Capability> &caps) {
- for (StringRef feature : features) {
- if (feature.consume_front("cap:")) {
- if (std::optional<Capability> cap = spirv::symbolizeCapability(feature))
- caps.insert(*cap);
- }
- }
- return std::nullopt;
-}
-
-// Scans the given |features| list and pushes extension specification with
-// 'ext:' prefix into |exts|.
-std::optional<Version> processExtensions(ArrayRef<StringRef> features,
- SetVector<Extension> &exts) {
- for (StringRef feature : features) {
- if (feature.consume_front("ext:")) {
- if (std::optional<Extension> ext = spirv::symbolizeExtension(feature))
- exts.insert(*ext);
- }
- }
- return std::nullopt;
-}
-
-//===----------------------------------------------------------------------===//
-// Client API and vendor
-//===----------------------------------------------------------------------===//
-
-ClientAPI deduceClientAPI(StringRef backend) {
- return llvm::StringSwitch<ClientAPI>(backend)
- .Case("vulkan", ClientAPI::Vulkan)
- .Case("metal", ClientAPI::Metal)
- .Case("webgpu", ClientAPI::WebGPU)
- .Case("opencl", ClientAPI::OpenCL)
- .Default(ClientAPI::Unknown);
-}
-
-Vendor deduceVendor(StringRef arch) {
- if (arch.starts_with("gfx") || arch.starts_with("rdna"))
- return Vendor::AMD;
- if (arch.starts_with("mali"))
- return Vendor::ARM;
- if (arch.starts_with("sm_"))
- return Vendor::NVIDIA;
- if (arch.starts_with("adreno"))
- return Vendor::Qualcomm;
- return Vendor::Unknown;
-}
-
-//===----------------------------------------------------------------------===//
-// Workgroup-processor features and limits
-//===----------------------------------------------------------------------===//
-
-void addComputeFeatures(ComputeBitwidths compute, SetVector<Capability> &caps,
- SetVector<Extension> &exts) {
- if (bitEnumContainsAny(compute, ComputeBitwidths::FP64))
- caps.insert(Capability::Float64);
- // FP32 does not need special capabilities or extensions.
- if (bitEnumContainsAny(compute, ComputeBitwidths::FP16))
- caps.insert(Capability::Float16);
-
- if (bitEnumContainsAny(compute, ComputeBitwidths::Int64))
- caps.insert(Capability::Int64);
- // Int32 does not need special capabilities or extensions.
- if (bitEnumContainsAny(compute, ComputeBitwidths::Int16))
- caps.insert(Capability::Int16);
- if (bitEnumContainsAny(compute, ComputeBitwidths::Int8))
- caps.insert(Capability::Int8);
-}
-
-void addStorageFeatures(StorageBitwidths storage, SetVector<Capability> &caps,
- SetVector<Extension> &exts) {
- // 64bit does not need special capabilities or extensions.
- // 32bit does not need special capabilities or extensions.
- if (bitEnumContainsAny(storage, StorageBitwidths::B16)) {
- caps.insert(Capability::StorageBuffer16BitAccess);
- caps.insert(Capability::StorageUniform16);
- caps.insert(Capability::StoragePushConstant16);
- exts.insert(Extension::SPV_KHR_16bit_storage);
- }
- if (bitEnumContainsAny(storage, StorageBitwidths::B8)) {
- caps.insert(Capability::StorageBuffer8BitAccess);
- caps.insert(Capability::UniformAndStorageBuffer8BitAccess);
- caps.insert(Capability::StoragePushConstant8);
- exts.insert(Extension::SPV_KHR_8bit_storage);
- }
-}
-
-void addSubgroupFeatures(SubgroupOps subgroup, SetVector<Capability> &caps,
- SetVector<Extension> &exts) {
- if (bitEnumContainsAny(subgroup, SubgroupOps::Shuffle)) {
- caps.insert(Capability::GroupNonUniformShuffle);
- caps.insert(Capability::GroupNonUniformShuffleRelative);
- }
- if (bitEnumContainsAny(subgroup, SubgroupOps::Arithmetic)) {
- caps.insert(Capability::GroupNonUniformArithmetic);
- }
-}
-
-void addDotProductFeatures(ComputeBitwidths compute, DotProductOps dotProduct,
- SetVector<Capability> &caps,
- SetVector<Extension> &exts) {
- if (bitEnumContainsAny(dotProduct, DotProductOps::DP4xI8ToI32)) {
- caps.insert(Capability::DotProduct);
- caps.insert(Capability::DotProductInput4x8BitPacked); // Use i32 input
- caps.insert(Capability::DotProductInputAll); // Use vector<*> input
- if (bitEnumContainsAny(compute, ComputeBitwidths::Int8)) {
- caps.insert(Capability::DotProductInput4x8Bit); // Use vector<4xi8> input
- }
- exts.insert(Extension::SPV_KHR_integer_dot_product);
- }
-}
-
-void addMatrixFeatures(IREE::GPU::MMAOpsArrayAttr mmaOps,
- SetVector<Capability> &caps, SetVector<Extension> &exts,
- SetVector<Attribute> &coopMatAttrs) {
- if (!mmaOps.empty()) {
- caps.insert(Capability::CooperativeMatrixKHR);
- exts.insert(Extension::SPV_KHR_cooperative_matrix);
- }
-}
-
-spirv::ResourceLimitsAttr convertLimits(StringRef arch,
- IREE::GPU::TargetWgpAttr wgp) {
- MLIRContext *context = wgp.getContext();
- Builder b(context);
-
- SmallVector<Attribute, 4> coopMatAttrs;
- for (IREE::GPU::MMAAttr mmaOp : wgp.getMma()) {
- auto [mSize, nSize, kSize] = mmaOp.getMNKShape();
- auto [aType, bType, cType] = mmaOp.getABCElementTypes();
- coopMatAttrs.push_back(spirv::CooperativeMatrixPropertiesKHRAttr::get(
- context, mSize, nSize, kSize, aType, bType, cType, cType,
- false /*saturatingAccumulation*/,
- spirv::ScopeAttr::get(context, spirv::Scope::Subgroup)));
- }
-
- ArrayRef<int> subgroupSizes = wgp.getSubgroupSizeChoices().asArrayRef();
- const int minSubgroupSize = *llvm::min_element(subgroupSizes);
- const int maxSubgroupSize = *llvm::max_element(subgroupSizes);
- // This is mostly to match RDNA behavior on Vulkan--RDNA supports either 32 or
- // 64 as subgroup sizes; the default subgroup size is 64.
- const int preferredSubgroupSize = maxSubgroupSize;
-
- return spirv::ResourceLimitsAttr::get(
- context, wgp.getMaxWorkgroupMemoryBytes(),
- wgp.getMaxThreadCountPerWorkgroup(),
- b.getI32ArrayAttr(wgp.getMaxWorkgroupSizes().asArrayRef()),
- preferredSubgroupSize, minSubgroupSize, maxSubgroupSize,
- ArrayAttr::get(context, coopMatAttrs), ArrayAttr{});
-}
-
-//===----------------------------------------------------------------------===//
-// Target specification conversion
-//===----------------------------------------------------------------------===//
-
-FailureOr<spirv::TargetEnvAttr>
-convertGPUTarget(IREE::HAL::ExecutableVariantOp variant) {
- IREE::HAL::ExecutableTargetAttr target = variant.getTarget();
- IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(target);
-
- SmallVector<StringRef> features;
- llvm::SplitString(gpuTarget.getFeatures(), features, ",");
-
- SetVector<Capability> caps;
- SetVector<Extension> exts;
- SetVector<Attribute> coopMatAttrs;
-
- std::optional<Version> version = deduceVersion(features);
- if (!version) {
- return variant.emitError("cannot deduce spirv version from target "
- "features; need to specify 'spirv1.x'");
- }
- processCapabilities(features, caps);
- processExtensions(features, exts);
-
- IREE::GPU::TargetWgpAttr wgp = gpuTarget.getWgp();
- ComputeBitwidths compute = wgp.getCompute().getValue();
- addComputeFeatures(compute, caps, exts);
- addStorageFeatures(wgp.getStorage().getValue(), caps, exts);
- addSubgroupFeatures(wgp.getSubgroup().getValue(), caps, exts);
- addDotProductFeatures(compute, wgp.getDot().getValue(), caps, exts);
- addMatrixFeatures(wgp.getMma(), caps, exts, coopMatAttrs);
-
- auto triple = spirv::VerCapExtAttr::get(
- *version, caps.getArrayRef(), exts.getArrayRef(), variant.getContext());
- return spirv::TargetEnvAttr::get(
- triple, convertLimits(gpuTarget.getArch(), wgp),
- deduceClientAPI(target.getBackend()), deduceVendor(gpuTarget.getArch()),
- spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
-}
-
-struct SPIRVConvertGPUTargetPass final
- : SPIRVConvertGPUTargetBase<SPIRVConvertGPUTargetPass> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<spirv::SPIRVDialect>();
- }
-
- void runOnOperation() override {
- IREE::HAL::ExecutableVariantOp variant = getOperation();
- IREE::HAL::ExecutableTargetAttr target = variant.getTarget();
-
- FailureOr<spirv::TargetEnvAttr> spirvTarget = convertGPUTarget(variant);
- if (failed(spirvTarget))
- return signalPassFailure();
-
- Builder b(&getContext());
- auto attrs = llvm::to_vector(target.getConfiguration().getValue());
- attrs.emplace_back(b.getStringAttr(spirv::getTargetEnvAttrName()),
- *spirvTarget);
- auto configAttr = b.getDictionaryAttr(attrs);
-
- auto halTarget = IREE::HAL::ExecutableTargetAttr::get(
- target.getContext(), target.getBackend(), target.getFormat(),
- configAttr);
- variant.setTargetAttr(halTarget);
- }
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVConvertGPUTargetPass() {
- return std::make_unique<SPIRVConvertGPUTargetPass>();
-}
-
-} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index eb59474..ed6a8be 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -39,7 +39,6 @@
"config_nvidia_matmul_cooperative_ops.mlir",
"config_user.mlir",
"convert_to_spirv.mlir",
- "convert_gpu_target.mlir",
"emulate_i64.mlir",
"erase_storage_buffer_static_shape.mlir",
"illegal_configuration.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 273e581..3dc8277 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -34,7 +34,6 @@
"config_nvidia_matmul.mlir"
"config_nvidia_matmul_cooperative_ops.mlir"
"config_user.mlir"
- "convert_gpu_target.mlir"
"convert_to_spirv.mlir"
"emulate_i64.mlir"
"erase_storage_buffer_static_shape.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
deleted file mode 100644
index b1f8092..0000000
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
+++ /dev/null
@@ -1,36 +0,0 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-spirv-convert-gpu-target)))' %s | FileCheck %s
-
-hal.executable @dispatch {
-hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
- iree.gpu.target = #iree_gpu.target<arch = "rdna3", features = "spirv:v1.6,cap:Shader",
- wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>],
- subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>}>) {
- hal.executable.export public @dispatch ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer>]>]>) {
- ^bb0(%arg0: !hal.device):
- %x, %y, %z = flow.dispatch.workgroup_count_from_slice
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @dispatch() {
- return
- }
- }
-}
-}
-
-// CHECK: spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
-// CHECK-SAME: [Shader, Float64, Float16, Int64, Int16, Int8,
-// CHECK-SAME: StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
-// CHECK-SMAE: StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
-// CHECK-SAME: GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformArithmetic,
-// CHECK-SAME: DotProduct, DotProductInput4x8BitPacked, DotProductInputAll, DotProductInput4x8Bit,
-// CHECK-SAME: CooperativeMatrixKHR],
-// CHECK-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_cooperative_matrix]>,
-// CHECK-SAME: AMD,
-// CHECK-SAME: #spirv.resource_limits<max_compute_shared_memory_size = 65536,
-// CHECK-SAME: max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32],
-// CHECK-SAME: subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64,
-// CHECK-SAME: cooperative_matrix_properties_khr = [
-// CHECK-SAME: #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>,
-// CHECK-SAME: #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>
-// CHECK-SAME: ]>>
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Vulkan/BUILD.bazel
new file mode 100644
index 0000000..236a474
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/BUILD.bazel
@@ -0,0 +1,11 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Vulkan/CMakeLists.txt
new file mode 100644
index 0000000..487e4f1
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/CMakeLists.txt
@@ -0,0 +1,13 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Vulkan/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Vulkan/IR/BUILD.bazel
new file mode 100644
index 0000000..da4b65e
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/BUILD.bazel
@@ -0,0 +1,87 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library", "iree_td_library")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_td_library(
+ name = "td_files",
+ srcs = enforce_glob(
+ [
+ "VulkanAttributes.td",
+ "VulkanBase.td",
+ ],
+ include = ["*.td"],
+ ),
+ deps = ["@llvm-project//mlir:OpBaseTdFiles"],
+)
+
+iree_compiler_cc_library(
+ name = "IR",
+ srcs = [
+ "VulkanAttributes.cpp",
+ "VulkanAttributes.cpp.inc",
+ "VulkanDialect.cpp",
+ "VulkanEnums.cpp.inc",
+ "VulkanTypes.cpp",
+ ],
+ hdrs = [
+ "VulkanAttributes.h",
+ "VulkanAttributes.h.inc",
+ "VulkanDialect.h",
+ "VulkanEnums.h.inc",
+ "VulkanTypes.h",
+ ],
+ deps = [
+ ":VulkanAttrsGen",
+ ":VulkanEnumsGen",
+ "//compiler/src/iree/compiler/Dialect/Util/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:SPIRVDialect",
+ "@llvm-project//mlir:Support",
+ ],
+)
+
+iree_gentbl_cc_library(
+ name = "VulkanAttrsGen",
+ tbl_outs = [
+ (
+ ["--gen-attrdef-decls"],
+ "VulkanAttributes.h.inc",
+ ),
+ (
+ ["--gen-attrdef-defs"],
+ "VulkanAttributes.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "VulkanAttributes.td",
+ deps = [":td_files"],
+)
+
+iree_gentbl_cc_library(
+ name = "VulkanEnumsGen",
+ tbl_outs = [
+ (
+ ["--gen-enum-decls"],
+ "VulkanEnums.h.inc",
+ ),
+ (
+ ["--gen-enum-defs"],
+ "VulkanEnums.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "VulkanBase.td",
+ deps = [":td_files"],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt
new file mode 100644
index 0000000..3b03c56
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/CMakeLists.txt
@@ -0,0 +1,59 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Vulkan/IR/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ IR
+ HDRS
+ "VulkanAttributes.h"
+ "VulkanAttributes.h.inc"
+ "VulkanDialect.h"
+ "VulkanEnums.h.inc"
+ "VulkanTypes.h"
+ SRCS
+ "VulkanAttributes.cpp"
+ "VulkanAttributes.cpp.inc"
+ "VulkanDialect.cpp"
+ "VulkanEnums.cpp.inc"
+ "VulkanTypes.cpp"
+ DEPS
+ ::VulkanAttrsGen
+ ::VulkanEnumsGen
+ LLVMSupport
+ MLIRIR
+ MLIRSPIRVDialect
+ MLIRSupport
+ iree::compiler::Dialect::Util::IR
+ PUBLIC
+)
+
+iree_tablegen_library(
+ NAME
+ VulkanAttrsGen
+ TD_FILE
+ "VulkanAttributes.td"
+ OUTS
+ --gen-attrdef-decls VulkanAttributes.h.inc
+ --gen-attrdef-defs VulkanAttributes.cpp.inc
+)
+
+iree_tablegen_library(
+ NAME
+ VulkanEnumsGen
+ TD_FILE
+ "VulkanBase.td"
+ OUTS
+ --gen-enum-decls VulkanEnums.h.inc
+ --gen-enum-defs VulkanEnums.cpp.inc
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp
new file mode 100644
index 0000000..dc33c2b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp
@@ -0,0 +1,359 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/SMLoc.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/IR/AttributeSupport.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+
+#define GET_ATTRDEF_CLASSES
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp.inc" // IWYU pragma: keep
+
+namespace mlir::iree_compiler::IREE::Vulkan {
+
+//===----------------------------------------------------------------------===//
+// TargetEnv
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+struct TargetEnvAttributeStorage : public AttributeStorage {
+ using KeyTy = std::tuple<Attribute, Attribute, Attribute, spirv::Vendor,
+ spirv::DeviceType, uint32_t, Attribute>;
+
+ TargetEnvAttributeStorage(Attribute version, Attribute revision,
+ Attribute extensions, spirv::Vendor vendorID,
+ spirv::DeviceType deviceType, uint32_t deviceID,
+ Attribute capabilities)
+ : version(version), revision(revision), extensions(extensions),
+ capabilities(capabilities), vendorID(vendorID), deviceType(deviceType),
+ deviceID(deviceID) {}
+
+ bool operator==(const KeyTy &key) const {
+ return key == std::make_tuple(version, revision, extensions, vendorID,
+ deviceType, deviceID, capabilities);
+ }
+
+ static TargetEnvAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ return new (allocator.allocate<TargetEnvAttributeStorage>())
+ TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
+ std::get<2>(key), std::get<3>(key),
+ std::get<4>(key), std::get<5>(key),
+ std::get<6>(key));
+ }
+
+ Attribute version;
+ Attribute revision;
+ Attribute extensions;
+ Attribute capabilities;
+ spirv::Vendor vendorID;
+ spirv::DeviceType deviceType;
+ uint32_t deviceID;
+};
+} // namespace detail
+
+TargetEnvAttr TargetEnvAttr::get(Vulkan::Version version, uint32_t revision,
+ ArrayRef<Extension> extensions,
+ spirv::Vendor vendorID,
+ spirv::DeviceType deviceType,
+ uint32_t deviceID,
+ CapabilitiesAttr capabilities) {
+ mlir::Builder builder(capabilities.getContext());
+ llvm::SmallVector<Attribute, 0> extAttrs;
+ extAttrs.reserve(extensions.size());
+ for (auto ext : extensions) {
+ extAttrs.push_back(ExtensionAttr::get(builder.getContext(), ext));
+ }
+ return get(builder.getI32IntegerAttr(static_cast<uint32_t>(version)),
+ builder.getI32IntegerAttr(revision),
+ builder.getArrayAttr(extAttrs), vendorID, deviceType, deviceID,
+ capabilities);
+}
+
+TargetEnvAttr TargetEnvAttr::get(IntegerAttr version, IntegerAttr revision,
+ ArrayAttr extensions, spirv::Vendor vendorID,
+ spirv::DeviceType deviceType,
+ uint32_t deviceID,
+ CapabilitiesAttr capabilities) {
+ assert(version && revision && extensions && capabilities);
+ MLIRContext *context = version.getContext();
+ return Base::get(context, version, revision, extensions, vendorID, deviceType,
+ deviceID, capabilities);
+}
+
+StringRef TargetEnvAttr::getKindName() { return "target_env"; }
+
+Version TargetEnvAttr::getVersion() {
+ return static_cast<Version>(
+ llvm::cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue());
+}
+
+unsigned TargetEnvAttr::getRevision() {
+ return llvm::cast<IntegerAttr>(getImpl()->revision).getValue().getZExtValue();
+}
+
+TargetEnvAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
+ : llvm::mapped_iterator<ArrayAttr::iterator, Extension (*)(Attribute)>(
+ it, [](Attribute attr) {
+ return llvm::cast<ExtensionAttr>(attr).getValue();
+ }) {}
+
+TargetEnvAttr::ext_range TargetEnvAttr::getExtensions() {
+ auto range = getExtensionsAttr().getValue();
+ return {ext_iterator(range.begin()), ext_iterator(range.end())};
+}
+
+ArrayAttr TargetEnvAttr::getExtensionsAttr() {
+ return llvm::cast<ArrayAttr>(getImpl()->extensions);
+}
+
+spirv::Vendor TargetEnvAttr::getVendorID() { return getImpl()->vendorID; }
+
+spirv::DeviceType TargetEnvAttr::getDeviceType() {
+ return getImpl()->deviceType;
+}
+
+uint32_t TargetEnvAttr::getDeviceID() { return getImpl()->deviceID; }
+
+CapabilitiesAttr TargetEnvAttr::getCapabilitiesAttr() {
+ return llvm::cast<CapabilitiesAttr>(getImpl()->capabilities);
+}
+
+LogicalResult
+TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ IntegerAttr version, IntegerAttr revision,
+ ArrayAttr extensions, spirv::Vendor /*vendorID*/,
+ spirv::DeviceType /*deviceType*/, uint32_t /*deviceID*/,
+ CapabilitiesAttr capabilities) {
+ if (!version.getType().isInteger(32))
+ return emitError() << "expected 32-bit integer for version";
+
+ if (!revision.getType().isInteger(32))
+ return emitError() << "expected 32-bit integer for revision";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute Parsing
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Parses a comma-separated list of keywords, invokes `processKeyword` on each
+/// of the parsed keyword, and returns failure if any error occurs.
+ParseResult parseKeywordList(
+ DialectAsmParser &parser,
+ function_ref<LogicalResult(llvm::SMLoc, StringRef)> processKeyword) {
+ if (parser.parseLSquare())
+ return failure();
+
+ // Special case for empty list.
+ if (succeeded(parser.parseOptionalRSquare()))
+ return success();
+
+ // Keep parsing the keyword and an optional comma following it. If the comma
+ // is successfully parsed, then we have more keywords to parse.
+ do {
+ auto loc = parser.getCurrentLocation();
+ StringRef keyword;
+ if (parser.parseKeyword(&keyword) || failed(processKeyword(loc, keyword)))
+ return failure();
+ } while (succeeded(parser.parseOptionalComma()));
+
+ if (parser.parseRSquare())
+ return failure();
+
+ return success();
+}
+
+/// Parses a TargetEnvAttr.
+Attribute parseTargetAttr(DialectAsmParser &parser) {
+ if (parser.parseLess())
+ return {};
+
+ Builder &builder = parser.getBuilder();
+
+ IntegerAttr versionAttr;
+ {
+ auto loc = parser.getCurrentLocation();
+ StringRef version;
+ if (parser.parseKeyword(&version) || parser.parseComma())
+ return {};
+
+ if (auto versionSymbol = symbolizeVersion(version)) {
+ versionAttr =
+ builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
+ } else {
+ parser.emitError(loc, "unknown Vulkan version: ") << version;
+ return {};
+ }
+ }
+
+ IntegerAttr revisionAttr;
+ {
+ unsigned revision = 0;
+ // TODO(antiagainst): it would be nice to parse rN instad of r(N).
+ if (parser.parseKeyword("r") || parser.parseLParen() ||
+ parser.parseInteger(revision) || parser.parseRParen() ||
+ parser.parseComma())
+ return {};
+ revisionAttr = builder.getI32IntegerAttr(revision);
+ }
+
+ ArrayAttr extensionsAttr;
+ {
+ SmallVector<Attribute, 1> extensions;
+ llvm::SMLoc errorloc;
+ StringRef errorKeyword;
+
+ MLIRContext *context = parser.getContext();
+ auto processExtension = [&](llvm::SMLoc loc, StringRef extension) {
+ if (std::optional<Extension> symbol = symbolizeExtension(extension)) {
+ extensions.push_back(ExtensionAttr::get(context, *symbol));
+ return success();
+ }
+ return errorloc = loc, errorKeyword = extension, failure();
+ };
+ if (parseKeywordList(parser, processExtension) || parser.parseComma()) {
+ if (!errorKeyword.empty())
+ parser.emitError(errorloc, "unknown Vulkan extension: ")
+ << errorKeyword;
+ return {};
+ }
+
+ extensionsAttr = builder.getArrayAttr(extensions);
+ }
+
+ // Parse vendor:device-type[:device-id]
+ spirv::Vendor vendorID = spirv::Vendor::Unknown;
+ spirv::DeviceType deviceType = spirv::DeviceType::Unknown;
+ uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
+ {
+ auto loc = parser.getCurrentLocation();
+ StringRef vendorStr;
+ if (parser.parseKeyword(&vendorStr))
+ return {};
+ if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
+ vendorID = *vendorSymbol;
+ } else {
+ parser.emitError(loc, "unknown vendor: ") << vendorStr;
+ }
+
+ loc = parser.getCurrentLocation();
+ StringRef deviceTypeStr;
+ if (parser.parseColon() || parser.parseKeyword(&deviceTypeStr))
+ return {};
+ if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
+ deviceType = *deviceTypeSymbol;
+ } else {
+ parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
+ }
+
+ loc = parser.getCurrentLocation();
+ if (succeeded(parser.parseOptionalColon())) {
+ if (parser.parseInteger(deviceID))
+ return {};
+ }
+
+ if (parser.parseComma())
+ return {};
+ }
+
+ CapabilitiesAttr capabilities;
+ if (parser.parseAttribute(capabilities))
+ return {};
+
+ if (parser.parseGreater())
+ return {};
+
+ return TargetEnvAttr::get(versionAttr, revisionAttr, extensionsAttr, vendorID,
+ deviceType, deviceID, capabilities);
+}
+} // namespace
+
+Attribute VulkanDialect::parseAttribute(DialectAsmParser &parser,
+                                        Type type) const {
+  // Vulkan attributes do not have type.
+  if (type) {
+    parser.emitError(parser.getNameLoc(), "unexpected type");
+    return {};
+  }
+
+  // Parse the kind keyword first; a result means a generated parser ran.
+  StringRef attrKind;
+  Attribute attr;
+  OptionalParseResult result =
+      generatedAttributeParser(parser, &attrKind, type, attr);
+  if (result.has_value()) {
+    if (failed(result.value()))
+      return {};
+    return attr;
+  }
+
+  if (attrKind == TargetEnvAttr::getKindName())
+    return parseTargetAttr(parser);
+
+  parser.emitError(parser.getNameLoc(), "unknown Vulkan attribute kind: ")
+      << attrKind;
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute Printing
+//===----------------------------------------------------------------------===//
+
+namespace {
+void print(TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
+ auto &os = printer.getStream();
+ printer << TargetEnvAttr::getKindName() << "<"
+ << stringifyVersion(targetEnv.getVersion()) << ", r("
+ << targetEnv.getRevision() << "), [";
+ interleaveComma(targetEnv.getExtensions(), os,
+ [&](Extension ext) { os << stringifyExtension(ext); });
+ printer << "], " << spirv::stringifyVendor(targetEnv.getVendorID());
+ printer << ":" << spirv::stringifyDeviceType(targetEnv.getDeviceType());
+ auto deviceID = targetEnv.getDeviceID();
+ if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID) {
+ printer << ":" << targetEnv.getDeviceID();
+ }
+ printer << ", " << targetEnv.getCapabilitiesAttr() << ">";
+}
+} // namespace
+
+void VulkanDialect::printAttribute(Attribute attr,
+ DialectAsmPrinter &printer) const {
+ if (succeeded(generatedAttributePrinter(attr, printer)))
+ return;
+
+ if (auto targetEnv = llvm::dyn_cast<TargetEnvAttr>(attr))
+ return print(targetEnv, printer);
+
+ assert(false && "unhandled Vulkan attribute kind");
+}
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+void VulkanDialect::registerAttributes() {
+ addAttributes<TargetEnvAttr,
+#define GET_ATTRDEF_LIST
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.cpp.inc"
+ >();
+}
+
+} // namespace mlir::iree_compiler::IREE::Vulkan
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h
new file mode 100644
index 0000000..1175db6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h
@@ -0,0 +1,89 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_VULKAN_IR_VULKANATTRIBUTES_H_
+#define IREE_COMPILER_DIALECT_VULKAN_IR_VULKANATTRIBUTES_H_
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+#define GET_ATTRDEF_CLASSES
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h.inc" // IWYU pragma: export
+
+namespace mlir::iree_compiler::IREE::Vulkan {
+
+namespace detail {
+struct TargetEnvAttributeStorage;
+} // namespace detail
+
+/// An attribute that specifies the target version, supported extensions, and
+/// resource limits. This information describes a Vulkan target environment.
+class TargetEnvAttr
+ : public Attribute::AttrBase<TargetEnvAttr, Attribute,
+ detail::TargetEnvAttributeStorage> {
+public:
+ using Base::Base;
+
+ static constexpr StringLiteral name = "vk.target_env";
+
+ /// Gets a TargetEnvAttr instance.
+ // TODO(antiagainst): support other physical device core properties, physical
+ // device core features and per-extension features.
+ static TargetEnvAttr get(Version version, uint32_t revision,
+ ArrayRef<Extension> extensions,
+ spirv::Vendor vendorID, spirv::DeviceType deviceType,
+ uint32_t deviceID, CapabilitiesAttr capabilities);
+ static TargetEnvAttr get(IntegerAttr version, IntegerAttr revision,
+ ArrayAttr extensions, spirv::Vendor vendorID,
+ spirv::DeviceType deviceType, uint32_t deviceID,
+ CapabilitiesAttr capabilities);
+
+ /// Returns the attribute kind's name (without the 'vk.' prefix).
+ static StringRef getKindName();
+
+ /// Returns the target Vulkan version; e.g., for 1.1.120, it should be V_1_1.
+ Version getVersion();
+
+ /// Returns the target Vulkan revision; e.g., for 1.1.120, it should be 120.
+ unsigned getRevision();
+
+ struct ext_iterator final
+ : public llvm::mapped_iterator<ArrayAttr::iterator,
+ Extension (*)(Attribute)> {
+ explicit ext_iterator(ArrayAttr::iterator it);
+ };
+ using ext_range = llvm::iterator_range<ext_iterator>;
+
+ /// Returns the target Vulkan instance and device extensions.
+ ext_range getExtensions();
+  /// Returns the target Vulkan instance and device extensions as an array
+  /// attribute of ExtensionAttr elements.
+ ArrayAttr getExtensionsAttr();
+
+ /// Returns the vendor ID.
+ spirv::Vendor getVendorID();
+
+ /// Returns the device type.
+ spirv::DeviceType getDeviceType();
+
+ /// Returns the device ID.
+ uint32_t getDeviceID();
+
+  /// Returns the capabilities attribute containing various Vulkan capability
+  /// bits.
+ CapabilitiesAttr getCapabilitiesAttr();
+
+ static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+ IntegerAttr version, IntegerAttr revision,
+ ArrayAttr extensions, spirv::Vendor vendorID,
+ spirv::DeviceType deviceType, uint32_t deviceID,
+ CapabilitiesAttr capabilities);
+};
+
+} // namespace mlir::iree_compiler::IREE::Vulkan
+
+#endif // IREE_COMPILER_DIALECT_VULKAN_IR_VULKANATTRIBUTES_H_
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td
new file mode 100644
index 0000000..fcd0ccf
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td
@@ -0,0 +1,134 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_VULKAN_VULKANATTRIBUTES
+#define IREE_DIALECT_VULKAN_VULKANATTRIBUTES
+
+include "iree/compiler/Dialect/Vulkan/IR/VulkanBase.td"
+
+class VK_Attr<string attrName, string attrMnemonic>
+ : AttrDef<VK_Dialect, attrName> {
+ let mnemonic = attrMnemonic;
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+// Attribute that can be used to specify the configuration of the
+// cooperative matrix multiply instructions supported by the target
+// device. This corresponds to `VkCooperativeMatrixPropertiesKHR` structure:
+// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkCooperativeMatrixPropertiesKHR.html
+def VK_CooperativeMatrixPropertiesKHRAttr :
+ VK_Attr<"CooperativeMatrixPropertiesKHR", "coop_matrix_props"> {
+ let parameters = (ins
+ "uint32_t":$mSize,
+ "uint32_t":$nSize,
+ "uint32_t":$kSize,
+ "::mlir::Type":$aType,
+ "::mlir::Type":$bType,
+ "::mlir::Type":$cType,
+ "::mlir::Type":$resultType,
+ "bool":$accSat,
+ "::mlir::iree_compiler::IREE::Vulkan::ScopeKHRAttr":$scope
+ );
+}
+
+// TODO(antiagainst): consider auto-generating this file (or part of it) from
+// vk.xml:
+// https://raw.githubusercontent.com/KhronosGroup/Vulkan-Docs/main/xml/vk.xml
+
+// Dictionary attribute containing various Vulkan capability bits. This is
+// aggregated from various Vulkan properties, limits, features from the spec.
+//
+// Note that we are using UnitAttr for booleans to allow omitting to mean false.
+// TODO(antiagainst): support DefaultValuedAttr in StructAttr to allow
+// specifying defaults for non-boolean fields.
+def VK_CapabilitiesAttr : VK_Attr<"Capabilities", "caps"> {
+ let parameters = (ins
+ // Core Vulkan 1.0 physical device properties.
+ //
+ // This corresponds to the `VkPhysicalDeviceProperties` structure:
+ // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDeviceProperties.html
+ "int":$maxComputeSharedMemorySize,
+ "int":$maxComputeWorkGroupInvocations,
+ "::mlir::DenseIntElementsAttr":$maxComputeWorkGroupSize,
+
+ // Core Vulkan 1.0 physical device features.
+ //
+ // This corresponds to the `VkPhysicalDeviceFeatures` structure:
+ // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDeviceFeatures.html
+ OptionalParameter<"::mlir::UnitAttr">:$shaderFloat64,
+ OptionalParameter<"::mlir::UnitAttr">:$shaderInt16,
+ OptionalParameter<"::mlir::UnitAttr">:$shaderInt64,
+
+ // Core Vulkan 1.1 physical device subgroup properties.
+ //
+ // This corresponds to the `VkPhysicalDeviceSubgroupProperties` structure:
+ // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDeviceSubgroupProperties.html
+
+ // TODO(antiagainst): StructAttr does not actually support attribute kinds
+ // that are derived from IntegerAttr well. So the nice parsing/printing for
+ // VK_SubgroupFeatureAttr does not really kick in here. We need to enhance
+ // upstream MLIR.
+ "::mlir::iree_compiler::IREE::Vulkan::SubgroupFeatureAttr":$subgroupFeatures,
+ "int":$subgroupSize,
+
+ // VK_EXT_subgroup_size_control features.
+ //
+ // This corresponds to the `VkPhysicalDeviceSubgroupSizeControlProperties` structure:
+ // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkPhysicalDeviceSubgroupSizeControlPropertiesEXT.html
+ OptionalParameter<"::std::optional<int>">:$minSubgroupSize,
+ OptionalParameter<"::std::optional<int>">:$maxSubgroupSize,
+
+ // VK_KHR_16bit_storage features.
+ //
+ // This corresponds to the `VkPhysicalDevice16BitStorageFeatures` structure:
+ // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDevice16BitStorageFeatures.html
+ OptionalParameter<"::mlir::UnitAttr">:$storageBuffer16BitAccess,
+ OptionalParameter<"::mlir::UnitAttr">:$storagePushConstant16,
+ OptionalParameter<"::mlir::UnitAttr">:$uniformAndStorageBuffer16BitAccess,
+
+ // VK_KHR_8bit_storage features.
+ //
+ // This corresponds to the `VkPhysicalDevice8BitStorageFeatures` structure:
+ // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDevice8BitStorageFeatures.html
+ OptionalParameter<"::mlir::UnitAttr">:$storageBuffer8BitAccess,
+ OptionalParameter<"::mlir::UnitAttr">:$storagePushConstant8,
+ OptionalParameter<"::mlir::UnitAttr">:$uniformAndStorageBuffer8BitAccess,
+
+    // VK_KHR_buffer_device_address features.
+ // This corresponds to the only capability implied by the extensions:
+ // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_buffer_device_address.html#_new_spir_v_capabilities
+ OptionalParameter<"::mlir::UnitAttr">:$physicalDeviceBufferAddresses,
+
+ // VK_KHR_shader_float16_int8 features.
+ //
+ // This corresponds to the `VkPhysicalDeviceShaderFloat16Int8Features`
+ // structure:
+ // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDeviceShaderFloat16Int8Features.html
+ OptionalParameter<"::mlir::UnitAttr">:$shaderFloat16,
+ OptionalParameter<"::mlir::UnitAttr">:$shaderInt8,
+
+ // VK_KHR_shader_integer_dot_product features.
+ //
+ // This corresponds to the `VkPhysicalDeviceShaderIntegerDotProductFeatures`
+ // structure:
+ // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR.html
+ OptionalParameter<"::mlir::UnitAttr">:$shaderIntegerDotProduct,
+
+ // VK_KHR_variable_pointers features.
+ // This corresponds to the `VkPhysicalDeviceVariablePointersFeatures`
+ // structure:
+ // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDeviceVariablePointersFeatures.html
+ OptionalParameter<"::mlir::UnitAttr">:$variablePointersStorageBuffer,
+ OptionalParameter<"::mlir::UnitAttr">:$variablePointers,
+
+ // VkCooperativeMatrixPropertiesKHR features.
+    // This corresponds to `VkCooperativeMatrixPropertiesKHR` structure:
+ // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_cooperative_matrix.html
+ DefaultValuedParameter<"ArrayAttr", "nullptr">:$cooperativeMatrixPropertiesKHR
+ );
+}
+
+#endif // IREE_DIALECT_VULKAN_VULKANATTRIBUTES
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
new file mode 100644
index 0000000..c256111
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
@@ -0,0 +1,199 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_VULKAN_BASE
+#define IREE_DIALECT_VULKAN_BASE
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+
+//===----------------------------------------------------------------------===//
+// Vulkan dialect definition
+//===----------------------------------------------------------------------===//
+
+def VK_Dialect : Dialect {
+ let name = "vk";
+ let cppNamespace = "::mlir::iree_compiler::IREE::Vulkan";
+
+ let summary = "The Vulkan dialect in IREE";
+ let description = [{
+ Vulkan is a new generation graphics and compute API that provides
+ high-efficiency, cross-platform access to modern GPUs used in a wide
+ variety of devices from PCs and consoles to mobile phones and embedded
+ platforms. See https://www.khronos.org/vulkan for more details regarding
+ Vulkan itself.
+
+ This is not a full-fledged Vulkan dialect that models common Vulkan concepts
+ in intermediate representation to be amenable to compiler analysis and
+ transformation. IREE has the HAL dialect for that purpose. Instead, this
+ dialect contains useful utilities for targeting Vulkan both in CodeGen and
+ runtime.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Utility definitions
+//===----------------------------------------------------------------------===//
+
+// Predicates that check whether `$_self` is a known enum case for the enum
+// class `name`; the generated symbolize function returns a std::optional.
+class VK_IsKnownBitEnumCaseFor<string name> :
+    CPred<"::mlir::iree_compiler::IREE::Vulkan::symbolize" # name # "("
+          "cast<IntegerAttr>($_self).getValue().getZExtValue()).has_value()">;
+class VK_IsKnownIntEnumCaseFor<string name> :
+    CPred<"::mlir::iree_compiler::IREE::Vulkan::symbolize" # name # "("
+          "cast<IntegerAttr>($_self).getValue().getZExtValue()).has_value()">;
+
+// Wrapper over base I32BitEnumAttr to set common fields.
+class VK_BitEnumAttr<string name, string description,
+ list<I32BitEnumAttrCase> cases> :
+ I32BitEnumAttr<name, description, cases> {
+ let predicate = And<[I32Attr.predicate, VK_IsKnownBitEnumCaseFor<name>]>;
+ let cppNamespace = "::mlir::iree_compiler::IREE::Vulkan";
+}
+
+class VK_I32Enum<string name, string description, list<I32EnumAttrCase> cases> :
+ I32EnumAttr<name, description, cases> {
+ let predicate = And<[I32Attr.predicate, VK_IsKnownIntEnumCaseFor<name>]>;
+ let cppNamespace = "::mlir::iree_compiler::IREE::Vulkan";
+}
+
+class VK_I32EnumAttr<string name, string description, string mnemonic,
+ list<I32EnumAttrCase> cases> :
+ EnumAttr<VK_Dialect, I32EnumAttr<name, description, cases>, mnemonic> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Vulkan";
+ let assemblyFormat = "`<` $value `>`";
+}
+
+//===----------------------------------------------------------------------===//
+// Target environment
+//===----------------------------------------------------------------------===//
+
+def VK_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "v1.0">;
+def VK_V_1_1 : I32EnumAttrCase<"V_1_1", 1, "v1.1">;
+def VK_V_1_2 : I32EnumAttrCase<"V_1_2", 2, "v1.2">;
+def VK_V_1_3 : I32EnumAttrCase<"V_1_3", 3, "v1.3">;
+
+def VK_VersionAttr : VK_I32Enum<"Version", "valid Vulkan version", [
+ VK_V_1_0, VK_V_1_1, VK_V_1_2, VK_V_1_3
+]>;
+
+def VK_KHR_16bit_storage : I32EnumAttrCase<"VK_KHR_16bit_storage", 0>;
+def VK_KHR_8bit_storage : I32EnumAttrCase<"VK_KHR_8bit_storage", 1>;
+def VK_KHR_shader_float16_int8 : I32EnumAttrCase<"VK_KHR_shader_float16_int8", 2>;
+def VK_KHR_shader_integer_dot_product : I32EnumAttrCase<"VK_KHR_shader_integer_dot_product", 3>;
+def VK_KHR_spirv_1_4 : I32EnumAttrCase<"VK_KHR_spirv_1_4", 4>;
+def VK_KHR_storage_buffer_storage_class : I32EnumAttrCase<"VK_KHR_storage_buffer_storage_class", 5>;
+def VK_KHR_variable_pointers: I32EnumAttrCase<"VK_KHR_variable_pointers", 6>;
+def VK_EXT_subgroup_size_control : I32EnumAttrCase<"VK_EXT_subgroup_size_control", 7>;
+def VK_KHR_cooperative_matrix : I32EnumAttrCase<"VK_KHR_cooperative_matrix", 8>;
+def VK_KHR_buffer_device_address : I32EnumAttrCase<"VK_KHR_buffer_device_address", 9>;
+
+def VK_ExtensionAttr :
+ VK_I32EnumAttr<"Extension", "supported Vulkan extension", "extension", [
+ VK_KHR_16bit_storage, VK_KHR_8bit_storage, VK_KHR_shader_float16_int8,
+ VK_KHR_shader_integer_dot_product, VK_KHR_spirv_1_4,
+ VK_KHR_storage_buffer_storage_class, VK_KHR_variable_pointers,
+ VK_EXT_subgroup_size_control, VK_KHR_cooperative_matrix,
+ VK_KHR_buffer_device_address
+ ]>;
+
+//===----------------------------------------------------------------------===//
+// Target triple
+//===----------------------------------------------------------------------===//
+
+def VK_TTA_Unknown : I32EnumAttrCase<"Unknown", 0, "unknown">;
+// Software emulated GPU
+def VK_TTA_CPU : I32EnumAttrCase<"CPU", 1, "cpu">;
+// AMD GPU
+def VK_TTA_RDNAv1 : I32EnumAttrCase<"AMD_RDNAv1", 100, "rdna1">;
+def VK_TTA_RDNAv2 : I32EnumAttrCase<"AMD_RDNAv2", 101, "rdna2">;
+def VK_TTA_RDNAv3 : I32EnumAttrCase<"AMD_RDNAv3", 102, "rdna3">;
+// Apple Silicon GPU
+def VK_TTA_M1 : I32EnumAttrCase<"Apple_M1", 200, "m1">;
+// ARM Mali GPU
+def VK_TTA_Valhall : I32EnumAttrCase<"ARM_Valhall", 300, "valhall">;
+// NVIDIA GPU
+def VK_TTA_Turing : I32EnumAttrCase<"NV_Turing", 400, "turing">;
+def VK_TTA_Ampere : I32EnumAttrCase<"NV_Ampere", 401, "ampere">;
+def VK_TTA_Pascal : I32EnumAttrCase<"NV_Pascal", 402, "pascal">;
+// Qualcomm Adreno GPU
+def VK_TTA_Adreno : I32EnumAttrCase<"QC_Adreno", 500, "adreno">;
+// Intel ARC GPU
+def VK_TTA_Arc : I32EnumAttrCase<"Intel_Arc", 600, "arc">;
+
+def VK_TargetArchAttr : VK_I32Enum<
+ "TargetTripleArch", "recognized target architecture", [
+ VK_TTA_Unknown, VK_TTA_CPU, VK_TTA_RDNAv1, VK_TTA_RDNAv2,
+ VK_TTA_RDNAv3, VK_TTA_M1, VK_TTA_Valhall, VK_TTA_Turing, VK_TTA_Ampere,
+ VK_TTA_Pascal, VK_TTA_Adreno, VK_TTA_Arc,
+ ]>;
+
+def VK_TTP_Unknown : I32EnumAttrCase<"Unknown", 0, "unknown">;
+// Qualcomm Adreno GPU
+def VK_TTP_Adreno640 : I32EnumAttrCase<"Adreno_640", 100, "a640">;
+def VK_TTP_Adreno650 : I32EnumAttrCase<"Adreno_650", 101, "a650">;
+def VK_TTP_Adreno660 : I32EnumAttrCase<"Adreno_660", 102, "a660">;
+// Software emulated GPU
+def VK_TTP_SwiftShader : I32EnumAttrCase<"SwiftShader", 200, "swiftshader">;
+// Translation layers
+def VK_TTP_MoltenVK : I32EnumAttrCase<"MoltenVK", 300, "moltenvk">;
+
+def VK_TargetProductAttr : VK_I32Enum<
+ "TargetTripleProduct", "recognized target product", [
+ VK_TTP_Unknown, VK_TTP_Adreno640, VK_TTP_Adreno650, VK_TTP_Adreno660,
+ VK_TTP_SwiftShader, VK_TTP_MoltenVK,
+ ]>;
+
+def VK_TTOS_Unknown : I32EnumAttrCase<"Unknown", 0, "unknown">;
+def VK_TTOS_Linux : I32EnumAttrCase<"Linux", 1, "linux">;
+def VK_TTOS_iOS : I32EnumAttrCase<"iOS", 2, "iOS">;
+def VK_TTOS_macOS : I32EnumAttrCase<"macOS", 3, "macos">;
+def VK_TTOS_Windows : I32EnumAttrCase<"Windows", 4, "windows">;
+// API Level 30 => Android 11
+def VK_TTOS_Android30 : I32EnumAttrCase<"Android30", 5, "android30">;
+// API Level 31 => Android 12
+def VK_TTOS_Android31 : I32EnumAttrCase<"Android31", 6, "android31">;
+
+def VK_TargetOSAttr : VK_I32Enum<
+ "TargetTripleOS", "recognized target operating system", [
+ VK_TTOS_Unknown, VK_TTOS_Linux, VK_TTOS_iOS, VK_TTOS_macOS,
+ VK_TTOS_Windows, VK_TTOS_Android30, VK_TTOS_Android31,
+ ]>;
+
+//===----------------------------------------------------------------------===//
+// Subgroup features
+//===----------------------------------------------------------------------===//
+
+def VK_SF_Basic : I32BitEnumAttrCase<"Basic", 0x001>;
+def VK_SF_Vote : I32BitEnumAttrCase<"Vote", 0x002>;
+def VK_SF_Arithmetic : I32BitEnumAttrCase<"Arithmetic", 0x004>;
+def VK_SF_Ballot : I32BitEnumAttrCase<"Ballot", 0x008>;
+def VK_SF_Shuffle : I32BitEnumAttrCase<"Shuffle", 0x010>;
+def VK_SF_ShuffleRelative : I32BitEnumAttrCase<"ShuffleRelative", 0x020>;
+def VK_SF_Clustered : I32BitEnumAttrCase<"Clustered", 0x040>;
+def VK_SF_Quad : I32BitEnumAttrCase<"Quad", 0x080>;
+def VK_SF_PartitionedNV : I32BitEnumAttrCase<"PartitionedNV", 0x100>;
+
+def VK_SubgroupFeatureAttr : VK_BitEnumAttr<
+ "SubgroupFeature", "supported Vulkan subgroup feature", [
+ VK_SF_Basic, VK_SF_Vote, VK_SF_Arithmetic, VK_SF_Ballot, VK_SF_Shuffle,
+ VK_SF_ShuffleRelative, VK_SF_Clustered, VK_SF_Quad, VK_SF_PartitionedNV
+ ]>;
+
+// Matches VkScopeKHR and VkScopeNV.
+def VK_SKHR_Device : I32EnumAttrCase<"Device", 1>;
+def VK_SKHR_Workgroup : I32EnumAttrCase<"Workgroup", 2>;
+def VK_SKHR_Subgroup : I32EnumAttrCase<"Subgroup", 3>;
+def VK_SKHR_QueueFamily : I32EnumAttrCase<"QueueFamily", 5>;
+
+def VK_ScopeKHR_Attr :
+ VK_I32EnumAttr<"ScopeKHR", "valid VkScopeKHR", "scope", [
+ VK_SKHR_Device, VK_SKHR_Workgroup, VK_SKHR_Subgroup,
+ VK_SKHR_QueueFamily
+ ]>;
+
+#endif // IREE_DIALECT_VULKAN_BASE
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.cpp
new file mode 100644
index 0000000..2e78feb
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.cpp
@@ -0,0 +1,18 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
+
+namespace mlir::iree_compiler::IREE::Vulkan {
+
+VulkanDialect::VulkanDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context, TypeID::get<VulkanDialect>()) {
+ registerAttributes();
+}
+
+} // namespace mlir::iree_compiler::IREE::Vulkan
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h
new file mode 100644
index 0000000..9cb3d01
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h
@@ -0,0 +1,37 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_VULKAN_IR_VULKANDIALECT_H_
+#define IREE_COMPILER_DIALECT_VULKAN_IR_VULKANDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+
+namespace mlir::iree_compiler::IREE::Vulkan {
+
+class VulkanDialect : public Dialect {
+public:
+ explicit VulkanDialect(MLIRContext *context);
+
+ static StringRef getDialectNamespace() { return "vk"; }
+
+ //===--------------------------------------------------------------------===//
+ // Attribute
+ //===--------------------------------------------------------------------===//
+
+ /// Parses an attribute registered to this dialect.
+ Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
+
+ /// Prints an attribute registered to this dialect.
+ void printAttribute(Attribute, DialectAsmPrinter &printer) const override;
+
+private:
+ /// Register the attributes of this dialect.
+ void registerAttributes();
+};
+
+} // namespace mlir::iree_compiler::IREE::Vulkan
+
+#endif // IREE_COMPILER_DIALECT_VULKAN_IR_VULKANDIALECT_H_
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanTypes.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanTypes.cpp
new file mode 100644
index 0000000..fc67767
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanTypes.cpp
@@ -0,0 +1,13 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
+
+#include "llvm/ADT/StringExtras.h" // IWYU pragma: keep
+
+// clang-format off: must be included after all LLVM/MLIR headers.
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanEnums.cpp.inc" // IWYU pragma: keep
+// clang-format on
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h
new file mode 100644
index 0000000..2422a85
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h
@@ -0,0 +1,20 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_VULKAN_IR_VULKANTYPES_H_
+#define IREE_COMPILER_DIALECT_VULKAN_IR_VULKANTYPES_H_
+
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+// clang-format off: must be included after all LLVM/MLIR headers.
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanEnums.h.inc" // IWYU pragma: export
+// clang-format on
+
+#endif // IREE_COMPILER_DIALECT_VULKAN_IR_VULKANTYPES_H_
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/BUILD.bazel
new file mode 100644
index 0000000..bbddf7d
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/BUILD.bazel
@@ -0,0 +1,26 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ ["target_env.mlir"],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt
new file mode 100644
index 0000000..cebe847
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/CMakeLists.txt
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Vulkan/IR/test/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "target_env.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/target_env.mlir b/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/target_env.mlir
new file mode 100644
index 0000000..343f1aa
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/test/target_env.mlir
@@ -0,0 +1,150 @@
+// Test parsing and printing Vulkan target environment attribute.
+
+// RUN: iree-opt --allow-unregistered-dialect --split-input-file --verify-diagnostics %s | FileCheck %s
+
+"vk_configure_op"() {
+ // CHECK: #vk.target_env<v1.1, r(120), [VK_KHR_spirv_1_4, VK_KHR_storage_buffer_storage_class], AMD:DiscreteGPU, #vk.caps<
+ // CHECK-SAME: maxComputeSharedMemorySize = 16384,
+ // CHECK-SAME: maxComputeWorkGroupInvocations = 1024,
+ // CHECK-SAME: maxComputeWorkGroupSize = dense<[128, 8, 4]> : vector<3xi32>
+ // CHECK-SAME: subgroupFeatures = 63 : i32,
+ // CHECK-SAME: subgroupSize = 4
+ // CHECK-SAME: >>
+ target_env = #vk.target_env<v1.1, r(120), [VK_KHR_spirv_1_4, VK_KHR_storage_buffer_storage_class], AMD:DiscreteGPU, #vk.caps<
+ maxComputeSharedMemorySize = 16384,
+ maxComputeWorkGroupInvocations = 1024,
+ maxComputeWorkGroupSize = dense<[128, 8, 4]>: vector<3xi32>,
+ subgroupFeatures = 63 : i32,
+ subgroupSize = 4
+ >>
+} : () -> ()
+
+// -----
+
+"vk_configure_op"() {
+ // CHECK: #vk.target_env
+ // CHECK-SAME: VK_KHR_cooperative_matrix
+ // CHECK-SAME: cooperativeMatrixPropertiesKHR =
+ // CHECK-SAME: #vk.coop_matrix_props<mSize = 8, nSize = 8, kSize = 32,
+ // CHECK-SAME: aType = i8, bType = i8, cType = i32, resultType = i32,
+ // CHECK-SAME: accSat = false, scope = <Subgroup>>
+ // CHECK-SAME: #vk.coop_matrix_props<mSize = 8, nSize = 8, kSize = 16,
+ // CHECK-SAME: aType = f16, bType = f16, cType = f16, resultType = f16,
+ // CHECK-SAME: accSat = false, scope = <Subgroup>>
+ target_env =
+ #vk.target_env<v1.2, r(133),
+ [VK_KHR_storage_buffer_storage_class, VK_KHR_cooperative_matrix],
+ NVIDIA:DiscreteGPU,
+ #vk.caps<maxComputeSharedMemorySize = 49152,
+ maxComputeWorkGroupInvocations = 1024,
+ maxComputeWorkGroupSize = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+ subgroupFeatures = 63: i32, subgroupSize = 32,
+ cooperativeMatrixPropertiesKHR = [
+ #vk.coop_matrix_props<
+ mSize = 8, nSize = 8, kSize = 32,
+ aType = i8, bType = i8, cType = i32, resultType = i32,
+ accSat = false, scope = #vk.scope<Subgroup>>,
+ #vk.coop_matrix_props<
+ mSize = 8, nSize = 8, kSize = 16,
+ aType = f16, bType = f16, cType = f16, resultType = f16,
+ accSat = false, scope = #vk.scope<Subgroup>>
+ ]
+ >>
+} : () -> ()
+
+// -----
+
+"vk_configure_op"() {
+ // CHECK: Qualcomm:IntegratedGPU:100925441
+ // CHECK-SAME: shaderFloat64
+ // CHECK-SAME: shaderInt16
+ target_env = #vk.target_env<v1.1, r(120), [VK_KHR_spirv_1_4, VK_KHR_storage_buffer_storage_class], Qualcomm:IntegratedGPU:0x6040001, #vk.caps<
+ maxComputeSharedMemorySize = 16384,
+ maxComputeWorkGroupInvocations = 1024,
+ maxComputeWorkGroupSize = dense<[128, 8, 4]>: vector<3xi32>,
+ subgroupFeatures = 63: i32,
+ subgroupSize = 4,
+ shaderFloat64 = unit, shaderInt16 = unit
+ >>
+} : () -> ()
+
+// -----
+
+"unknown_vulkan_version"() {
+ // expected-error @+1 {{unknown Vulkan version: v10.8}}
+ target_env = #vk.target_env<v10.8, r(0), [], #vk.caps<
+ maxComputeWorkGroupInvocations = 128,
+ maxComputeWorkGroupSize = dense<[64, 4, 4]>: vector<3xi32>
+ >>
+} : () -> ()
+
+// -----
+
+"unknown_vulkan_extension"() {
+ // expected-error @+1 {{unknown Vulkan extension: VK_KHR_something}}
+ target_env = #vk.target_env<v1.0, r(10), [VK_KHR_something], #vk.caps<
+ maxComputeWorkGroupInvocations = 128,
+ maxComputeWorkGroupSize = dense<[64, 4, 4]>: vector<3xi32>
+ >>
+} : () -> ()
+
+// -----
+
+"wrong_vendor_id"() {
+ // expected-error @+1 {{unknown vendor: AVendor}}
+ target_env = #vk.target_env<v1.0, r(10), [], AVendor:Unknown, #vk.caps<
+ maxComputeSharedMemorySize = 16384,
+ maxComputeWorkGroupInvocations = 1024,
+ maxComputeWorkGroupSize = dense<[128, 8, 4]>: vector<3xi32>,
+ subgroupFeatures = 63: i32,
+ subgroupSize = 4
+ >>
+} : () -> ()
+
+// -----
+
+"wrong_device_type"() {
+ // expected-error @+1 {{unknown device type: ADeviceType}}
+ target_env = #vk.target_env<v1.0, r(10), [], NVIDIA:ADeviceType, #vk.caps<
+ maxComputeSharedMemorySize = 16384,
+ maxComputeWorkGroupInvocations = 1024,
+ maxComputeWorkGroupSize = dense<[128, 8, 4]>: vector<3xi32>,
+ subgroupFeatures = 63: i32,
+ subgroupSize = 4
+ >>
+} : () -> ()
+
+// -----
+
+"missing_core_1_1_properties_field"() {
+ target_env = #vk.target_env<v1.0, r(10), [], Unknown:Unknown, #vk.caps<
+ maxComputeWorkGroupInvocations = 128
+ // expected-error @+1 {{struct is missing required parameter: maxComputeSharedMemorySize}}
+ >>
+} : () -> ()
+
+// -----
+
+"unknown_core_1_1_properties_field"() {
+ target_env = #vk.target_env<v1.0, r(10), [], Unknown:Unknown, #vk.caps<
+ maxComputeSharedMemorySize = 16384,
+ maxComputeWorkGroupInvocations = 128,
+ maxComputeWorkGroupSize = dense<[64, 4, 4]>: vector<3xi32>,
+ // expected-error @+1 {{duplicate or unknown struct parameter name: moreStuff}}
+ moreStuff = 8: i32
+ >>
+} : () -> ()
+
+// -----
+
+"wrong_subgroup_bit"() {
+ target_env = #vk.target_env<v1.0, r(10), [], Unknown:Unknown, #vk.caps<
+ maxComputeSharedMemorySize = 16384,
+ maxComputeWorkGroupInvocations = 1024,
+ maxComputeWorkGroupSize = dense<[128, 8, 4]>: vector<3xi32>,
+ // expected-error @+2 {{invalid kind of attribute specified}}
+ // expected-error @+1 {{failed to parse VK_CapabilitiesAttr parameter 'subgroupFeatures'}}
+ subgroupFeatures = 0xffffffff: i32,
+ subgroupSize = 4
+ >>
+} : () -> ()
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/BUILD.bazel
new file mode 100644
index 0000000..cbbd06f
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/BUILD.bazel
@@ -0,0 +1,32 @@
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "Utils",
+ srcs = [
+ "TargetEnvironment.cpp",
+ "TargetTriple.cpp",
+ ],
+ hdrs = [
+ "TargetEnvironment.h",
+ "TargetTriple.h",
+ ],
+ deps = [
+ "//compiler/src/iree/compiler/Dialect/Vulkan/IR",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:SPIRVDialect",
+ "@llvm-project//mlir:Support",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/CMakeLists.txt
new file mode 100644
index 0000000..8435767
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/CMakeLists.txt
@@ -0,0 +1,31 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Vulkan/Utils/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ Utils
+ HDRS
+ "TargetEnvironment.h"
+ "TargetTriple.h"
+ SRCS
+ "TargetEnvironment.cpp"
+ "TargetTriple.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ MLIRSPIRVDialect
+ MLIRSupport
+ iree::compiler::Dialect::Vulkan::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp
new file mode 100644
index 0000000..bcf3b55
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp
@@ -0,0 +1,222 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.h"
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
+#include "iree/compiler/Dialect/Vulkan/Utils/TargetTriple.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::iree_compiler::IREE::Vulkan {
+
+namespace {
+
+/// Gets the corresponding SPIR-V version for the given Vulkan target
+/// environment.
+spirv::Version convertVersion(Vulkan::TargetEnvAttr vkTargetEnv) {
+ // Special extension to enable SPIR-V 1.4.
+ const bool has14Ext = (llvm::is_contained(vkTargetEnv.getExtensions(),
+ Extension::VK_KHR_spirv_1_4));
+
+ switch (vkTargetEnv.getVersion()) {
+ case Version::V_1_0:
+ // Vulkan 1.0 only supports SPIR-V 1.0 by default.
+ return has14Ext ? spirv::Version::V_1_4 : spirv::Version::V_1_0;
+ case Version::V_1_1:
+ // Vulkan 1.1 supports up to SPIR-V 1.3 by default.
+ return has14Ext ? spirv::Version::V_1_4 : spirv::Version::V_1_3;
+ case Version::V_1_2:
+ // Vulkan 1.2 supports up to SPIR-V 1.5 by default.
+ return spirv::Version::V_1_5;
+ case Version::V_1_3:
+ // Vulkan 1.3 supports up to SPIR-V 1.6 by default.
+ return spirv::Version::V_1_6;
+ }
+ return spirv::Version::V_1_0;
+}
+
+/// Gets the corresponding SPIR-V extensions for the given Vulkan target
+/// environment.
+void convertExtensions(Vulkan::TargetEnvAttr vkTargetEnv,
+ SmallVectorImpl<spirv::Extension> &extensions) {
+ extensions.clear();
+
+ for (Extension ext : vkTargetEnv.getExtensions()) {
+ switch (ext) {
+ case Extension::VK_KHR_16bit_storage:
+ extensions.push_back(spirv::Extension::SPV_KHR_16bit_storage);
+ break;
+ case Extension::VK_KHR_8bit_storage:
+ extensions.push_back(spirv::Extension::SPV_KHR_8bit_storage);
+ break;
+ case Extension::VK_KHR_shader_float16_int8:
+ // This extension allows using certain SPIR-V capabilities.
+ break;
+ case Extension::VK_KHR_shader_integer_dot_product:
+ extensions.push_back(spirv::Extension::SPV_KHR_integer_dot_product);
+ break;
+ case Extension::VK_KHR_spirv_1_4:
+ // This extension only affects SPIR-V version.
+ break;
+ case Extension::VK_KHR_storage_buffer_storage_class:
+ extensions.push_back(
+ spirv::Extension::SPV_KHR_storage_buffer_storage_class);
+ break;
+ case Extension::VK_KHR_variable_pointers:
+ extensions.push_back(spirv::Extension::SPV_KHR_variable_pointers);
+ break;
+ case Extension::VK_EXT_subgroup_size_control:
+ // This extension allows specifying min/max subgroup size.
+ break;
+ case Extension::VK_KHR_cooperative_matrix:
+ extensions.push_back(spirv::Extension::SPV_KHR_cooperative_matrix);
+ break;
+ case Extension::VK_KHR_buffer_device_address:
+ extensions.push_back(spirv::Extension::SPV_KHR_physical_storage_buffer);
+ }
+ }
+}
+
+/// Gets the corresponding SPIR-V capabilities for the given Vulkan target
+/// environment.
+void convertCapabilities(Vulkan::TargetEnvAttr vkTargetEnv,
+ SmallVectorImpl<spirv::Capability> &capabilities) {
+ // Add unconditionally supported capabilities.
+ // Note that "Table 54. List of SPIR-V Capabilities and enabling features or
+ // extensions" in the Vulkan spec contains the full list. Right now omit those
+ // implicitly declared or not useful for us.
+ capabilities.assign({spirv::Capability::Shader});
+
+ auto vkCapabilities = vkTargetEnv.getCapabilitiesAttr();
+
+#define MAP_PRIMITIVE_TYPE(type) \
+ if (vkCapabilities.getShader##type()) \
+ capabilities.push_back(spirv::Capability::type)
+
+ MAP_PRIMITIVE_TYPE(Float64);
+ MAP_PRIMITIVE_TYPE(Float16);
+ MAP_PRIMITIVE_TYPE(Int64);
+ MAP_PRIMITIVE_TYPE(Int16);
+ MAP_PRIMITIVE_TYPE(Int8);
+#undef MAP_PRIMITIVE_TYPE
+
+#define MAP_8_16_BIT_STORAGE(vkFeature, spvCap) \
+ if (vkCapabilities.vkFeature()) \
+ capabilities.push_back(spirv::Capability::spvCap)
+
+ MAP_8_16_BIT_STORAGE(getStorageBuffer16BitAccess, StorageBuffer16BitAccess);
+ MAP_8_16_BIT_STORAGE(getUniformAndStorageBuffer16BitAccess, StorageUniform16);
+ MAP_8_16_BIT_STORAGE(getStoragePushConstant16, StoragePushConstant16);
+ MAP_8_16_BIT_STORAGE(getStorageBuffer8BitAccess, StorageBuffer8BitAccess);
+ MAP_8_16_BIT_STORAGE(getUniformAndStorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess);
+ MAP_8_16_BIT_STORAGE(getStoragePushConstant8, StoragePushConstant8);
+#undef MAP_8_16_BIT_STORAGE
+
+ auto subgroupFeatures = vkCapabilities.getSubgroupFeatures().getValue();
+
+#define MAP_SUBGROUP_FEATURE(featureBit) \
+ if ((subgroupFeatures & SubgroupFeature::featureBit) == \
+ SubgroupFeature::featureBit) \
+ capabilities.push_back(spirv::Capability::GroupNonUniform##featureBit)
+
+ if ((subgroupFeatures & SubgroupFeature::Basic) == SubgroupFeature::Basic) {
+ capabilities.push_back(spirv::Capability::GroupNonUniform);
+ }
+ MAP_SUBGROUP_FEATURE(Vote);
+ MAP_SUBGROUP_FEATURE(Arithmetic);
+ MAP_SUBGROUP_FEATURE(Ballot);
+ MAP_SUBGROUP_FEATURE(Shuffle);
+ MAP_SUBGROUP_FEATURE(ShuffleRelative);
+ MAP_SUBGROUP_FEATURE(Clustered);
+ MAP_SUBGROUP_FEATURE(Quad);
+ MAP_SUBGROUP_FEATURE(PartitionedNV);
+#undef MAP_SUBGROUP_FEATURE
+ if (vkCapabilities.getPhysicalDeviceBufferAddresses()) {
+ capabilities.push_back(spirv::Capability::PhysicalStorageBufferAddresses);
+ }
+ if (vkCapabilities.getVariablePointers()) {
+ capabilities.push_back(spirv::Capability::VariablePointers);
+ }
+ if (vkCapabilities.getVariablePointersStorageBuffer()) {
+ capabilities.push_back(spirv::Capability::VariablePointersStorageBuffer);
+ }
+ if (vkCapabilities.getShaderIntegerDotProduct()) {
+ llvm::append_values(capabilities, spirv::Capability::DotProduct,
+ spirv::Capability::DotProductInputAll,
+ spirv::Capability::DotProductInput4x8BitPacked);
+ if (vkCapabilities.getShaderInt8()) {
+ capabilities.push_back(spirv::Capability::DotProductInput4x8Bit);
+ }
+ }
+ if (ArrayAttr attr = vkCapabilities.getCooperativeMatrixPropertiesKHR()) {
+ if (!attr.empty()) {
+ capabilities.push_back(spirv::Capability::CooperativeMatrixKHR);
+ }
+ }
+}
+
+/// Gets the corresponding SPIR-V resource limits for the given Vulkan target
+/// environment.
+spirv::ResourceLimitsAttr
+convertResourceLimits(Vulkan::TargetEnvAttr vkTargetEnv) {
+ MLIRContext *context = vkTargetEnv.getContext();
+ Builder builder(context);
+ auto vkCapabilities = vkTargetEnv.getCapabilitiesAttr();
+ SmallVector<Attribute, 1> khrCoopAttrs;
+ if (ArrayAttr attr = vkCapabilities.getCooperativeMatrixPropertiesKHR()) {
+ for (auto props :
+ attr.getAsRange<Vulkan::CooperativeMatrixPropertiesKHRAttr>()) {
+ auto scope = static_cast<spirv::Scope>(props.getScope().getValue());
+ khrCoopAttrs.push_back(spirv::CooperativeMatrixPropertiesKHRAttr::get(
+ context, props.getMSize(), props.getNSize(), props.getKSize(),
+ props.getAType(), props.getBType(), props.getCType(),
+ props.getResultType(), props.getAccSat(),
+ spirv::ScopeAttr::get(context, scope)));
+ }
+ }
+ auto sizeValues =
+ vkCapabilities.getMaxComputeWorkGroupSize().getValues<int32_t>();
+ SmallVector<int64_t> sizes;
+ sizes.insert(sizes.end(), sizeValues.begin(), sizeValues.end());
+ return spirv::ResourceLimitsAttr::get(
+ context, vkCapabilities.getMaxComputeSharedMemorySize(),
+ vkCapabilities.getMaxComputeWorkGroupInvocations(),
+ builder.getI64ArrayAttr(sizes), vkCapabilities.getSubgroupSize(),
+ vkCapabilities.getMinSubgroupSize(), vkCapabilities.getMaxSubgroupSize(),
+ ArrayAttr::get(context, khrCoopAttrs), ArrayAttr{});
+}
+
+} // namespace
+// Copies `triple` first: StringRef::data() need not be NUL-terminated.
+Vulkan::TargetEnvAttr getTargetEnvForTriple(MLIRContext *context,
+ llvm::StringRef triple) {
+ return TargetTriple::get(triple.str().c_str()).getTargetEnv(context);
+}
+
+spirv::TargetEnvAttr convertTargetEnv(Vulkan::TargetEnvAttr vkTargetEnv) {
+ auto spvVersion = convertVersion(vkTargetEnv);
+
+ SmallVector<spirv::Extension> spvExtensions;
+ convertExtensions(vkTargetEnv, spvExtensions);
+
+ SmallVector<spirv::Capability, 8> spvCapabilities;
+ convertCapabilities(vkTargetEnv, spvCapabilities);
+
+ auto spvLimits = convertResourceLimits(vkTargetEnv);
+
+ auto triple = spirv::VerCapExtAttr::get(
+ spvVersion, spvCapabilities, spvExtensions, vkTargetEnv.getContext());
+ return spirv::TargetEnvAttr::get(
+ triple, spvLimits, spirv::ClientAPI::Vulkan, vkTargetEnv.getVendorID(),
+ vkTargetEnv.getDeviceType(), vkTargetEnv.getDeviceID());
+}
+
+} // namespace mlir::iree_compiler::IREE::Vulkan
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.h b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.h
new file mode 100644
index 0000000..cc1d62a
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.h
@@ -0,0 +1,36 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_VULKAN_UTILS_TARGETENVIRONMENT_H_
+#define IREE_COMPILER_DIALECT_VULKAN_UTILS_TARGETENVIRONMENT_H_
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+
+namespace mlir::iree_compiler::IREE::Vulkan {
+
+/// Returns the Vulkan target environment attribute for the given GPU triple.
+Vulkan::TargetEnvAttr getTargetEnvForTriple(MLIRContext *context,
+ llvm::StringRef triple);
+
+/// Converts the given Vulkan target environment into the corresponding SPIR-V
+/// target environment.
+///
+/// Vulkan and SPIR-V are two different domains working closely. A Vulkan target
+/// environment specifies the Vulkan version, extensions, features, and resource
+/// limits queried from a Vulkan implementation. These properties typically have
+/// corresponding SPIR-V bits, directly or indirectly. For example, by default,
+/// Vulkan 1.0 supports SPIR-V 1.0 and Vulkan 1.1 supports up to SPIR-V 1.3.
+/// If the VK_KHR_spirv_1_4 extension is available, then SPIR-V 1.4 can be used.
+/// Similarly, if the VK_KHR_variable_pointers extension is available, then
+/// the VariablePointersStorageBuffer capabilities on SPIR-V side can be
+/// activated. The function handles the mapping relationship between these two
+/// domains.
+spirv::TargetEnvAttr convertTargetEnv(Vulkan::TargetEnvAttr vkTargetEnv);
+
+} // namespace mlir::iree_compiler::IREE::Vulkan
+
+#endif // IREE_COMPILER_DIALECT_VULKAN_UTILS_TARGETENVIRONMENT_H_
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
new file mode 100644
index 0000000..9564bf7
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
@@ -0,0 +1,539 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Vulkan/Utils/TargetTriple.h"
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir::iree_compiler::IREE::Vulkan {
+
+namespace {
+
+/// Returns the SPIR-V vendor implied by the arch/product of target `triple`.
+spirv::Vendor getVendor(const TargetTriple &triple) {
+  const TargetTripleArch arch = triple.getArch();
+  // A CPU target only maps to a known vendor when it is SwiftShader.
+  if (arch == TargetTripleArch::CPU) {
+    const bool isSwiftShader =
+        triple.getProduct() == TargetTripleProduct::SwiftShader;
+    return isSwiftShader ? spirv::Vendor::SwiftShader : spirv::Vendor::Unknown;
+  }
+  switch (arch) {
+  case TargetTripleArch::AMD_RDNAv1:
+  case TargetTripleArch::AMD_RDNAv2:
+  case TargetTripleArch::AMD_RDNAv3:
+    return spirv::Vendor::AMD;
+  case TargetTripleArch::NV_Pascal:
+  case TargetTripleArch::NV_Turing:
+  case TargetTripleArch::NV_Ampere:
+    return spirv::Vendor::NVIDIA;
+  case TargetTripleArch::Apple_M1:
+    return spirv::Vendor::Apple;
+  case TargetTripleArch::ARM_Valhall:
+    return spirv::Vendor::ARM;
+  case TargetTripleArch::Intel_Arc:
+    return spirv::Vendor::Intel;
+  case TargetTripleArch::QC_Adreno:
+    return spirv::Vendor::Qualcomm;
+  case TargetTripleArch::Unknown:
+    return spirv::Vendor::Unknown;
+  default:
+    assert(false && "unhandled vendor");
+    return spirv::Vendor::Unknown;
+  }
+}
+
+/// Returns the SPIR-V device type implied by the architecture of `triple`.
+spirv::DeviceType getDeviceType(const TargetTriple &triple) {
+  switch (triple.getArch()) {
+  case TargetTripleArch::Apple_M1:
+  case TargetTripleArch::ARM_Valhall:
+  case TargetTripleArch::QC_Adreno:
+    return spirv::DeviceType::IntegratedGPU;
+  case TargetTripleArch::AMD_RDNAv1:
+  case TargetTripleArch::AMD_RDNAv2:
+  case TargetTripleArch::AMD_RDNAv3:
+  case TargetTripleArch::Intel_Arc:
+  case TargetTripleArch::NV_Pascal:
+  case TargetTripleArch::NV_Turing:
+  case TargetTripleArch::NV_Ampere:
+    return spirv::DeviceType::DiscreteGPU;
+  case TargetTripleArch::CPU:
+    return spirv::DeviceType::CPU;
+  case TargetTripleArch::Unknown:
+    return spirv::DeviceType::Unknown;
+  default:
+    assert(false && "unhandled device type");
+    return spirv::DeviceType::Unknown;
+  }
+}
+
+/// Picks the Vulkan version to assume for `triple`: 1.1 for the conservative
+/// cases listed below and 1.3 for everything else.
+Vulkan::Version getVersion(const TargetTriple &triple) {
+  const TargetTripleOS os = triple.getOS();
+  // Android 11/12 (API level 30/31) stay at Vulkan 1.1.
+  const bool isVulkan11Android =
+      os == TargetTripleOS::Android30 || os == TargetTripleOS::Android31;
+
+  const TargetTripleProduct product = triple.getProduct();
+  // SwiftShader and MoltenVK stay at Vulkan 1.1.
+  const bool isVulkan11Product = product == TargetTripleProduct::SwiftShader ||
+                                 product == TargetTripleProduct::MoltenVK;
+
+  // For an unknown architecture, be conservative and use a reasonable lowest
+  // common denominator.
+  const bool isUnknownArch = triple.getArch() == TargetTripleArch::Unknown;
+
+  if (isVulkan11Android || isVulkan11Product || isUnknownArch) {
+    return Version::V_1_1;
+  }
+  return Version::V_1_3;
+}
+
+/// Appends the Vulkan extensions assumed to be available on `triple` to
+/// `extensions`.
+///
+/// The result is only an approximation: Android compatibility gives some
+/// minimal guarantees, yet Android devices can still expose different
+/// extension sets depending on the Android and GPU driver versions. A GPU
+/// triple is a handy way to name a target, but it cannot encode everything.
+void getExtensions(const TargetTriple &triple,
+                   llvm::SmallVectorImpl<Extension> &extensions) {
+  const TargetTripleArch arch = triple.getArch();
+
+  // Mobile GPUs need to take the Android version into consideration.
+  if (arch == TargetTripleArch::Apple_M1) {
+    // Example: https://vulkan.gpuinfo.org/displayreport.php?id=14673
+    llvm::append_values(extensions, Extension::VK_KHR_16bit_storage,
+                        Extension::VK_KHR_8bit_storage,
+                        Extension::VK_KHR_shader_float16_int8,
+                        Extension::VK_KHR_storage_buffer_storage_class,
+                        Extension::VK_KHR_buffer_device_address,
+                        Extension::VK_KHR_variable_pointers);
+    return;
+  }
+  if (arch == TargetTripleArch::ARM_Valhall) {
+    // Example: https://vulkan.gpuinfo.org/displayreport.php?id=10312
+    llvm::append_values(extensions, Extension::VK_KHR_16bit_storage,
+                        Extension::VK_KHR_8bit_storage,
+                        Extension::VK_KHR_shader_float16_int8,
+                        Extension::VK_KHR_shader_integer_dot_product,
+                        Extension::VK_KHR_spirv_1_4,
+                        Extension::VK_KHR_storage_buffer_storage_class,
+                        Extension::VK_KHR_variable_pointers);
+    return;
+  }
+  if (arch == TargetTripleArch::QC_Adreno) {
+    // Example: https://vulkan.gpuinfo.org/displayreport.php?id=10983 (11)
+    // Example: https://vulkan.gpuinfo.org/displayreport.php?id=16312 (12)
+    llvm::append_values(extensions, Extension::VK_KHR_16bit_storage,
+                        Extension::VK_KHR_shader_float16_int8,
+                        Extension::VK_KHR_spirv_1_4,
+                        Extension::VK_KHR_storage_buffer_storage_class,
+                        Extension::VK_KHR_variable_pointers);
+    if (triple.getOS() == TargetTripleOS::Android31) {
+      extensions.push_back(Extension::VK_KHR_8bit_storage);
+    }
+    return;
+  }
+
+  // SwiftShader is very limited regarding functionalities.
+  const spirv::Vendor vendor = getVendor(triple);
+  if (vendor == spirv::Vendor::SwiftShader) {
+    extensions.push_back(Extension::VK_KHR_storage_buffer_storage_class);
+    return;
+  }
+
+  // For an unknown architecture, be conservative and use a reasonable lowest
+  // common denominator.
+  if (arch == TargetTripleArch::Unknown) {
+    // The following extensions have 90%+ device coverage from
+    // https://vulkan.gpuinfo.org/listextensions.php.
+    llvm::append_values(extensions,
+                        Extension::VK_KHR_storage_buffer_storage_class,
+                        Extension::VK_KHR_variable_pointers);
+    return;
+  }
+
+  // Desktop GPUs typically support all extensions we care about.
+  llvm::append_values(extensions, Extension::VK_KHR_16bit_storage,
+                      Extension::VK_KHR_8bit_storage,
+                      Extension::VK_KHR_shader_float16_int8,
+                      Extension::VK_KHR_shader_integer_dot_product,
+                      Extension::VK_KHR_spirv_1_4,
+                      Extension::VK_KHR_storage_buffer_storage_class,
+                      Extension::VK_KHR_buffer_device_address,
+                      Extension::VK_KHR_variable_pointers,
+                      Extension::VK_EXT_subgroup_size_control);
+  if (vendor == spirv::Vendor::NVIDIA || arch == TargetTripleArch::AMD_RDNAv3) {
+    extensions.push_back(Extension::VK_KHR_cooperative_matrix);
+  }
+}
+
+/// Returns the Vulkan features/limits/capabilities assumed for the given
+/// `triple`, as a CapabilitiesAttr created in `context`.
+///
+/// Every field starts from a conservative default and is then overridden per
+/// architecture in the switch below. Note that this is an "approximation":
+/// Android compatibility provides some minimal guarantees, but devices can
+/// still differ depending on the Android and GPU driver versions. The GPU
+/// triple is a handy way to specify the target, not a full description of it.
+CapabilitiesAttr getCapabilities(const TargetTriple &triple,
+ MLIRContext *context) {
+ // Default to Vulkan required limits.
+ int maxComputeSharedMemorySize = 16384;
+ int maxComputeWorkGroupInvocations = 128;
+ std::array<int, 3> maxComputeWorkGroupSize = {128, 128, 64};
+
+ int subgroupSize = 32;
+ SubgroupFeature subgroupFeatures = SubgroupFeature::Basic;
+ std::optional<int> minSubgroupSize, maxSubgroupSize;
+
+ bool shaderFloat16 = false, shaderFloat64 = false;
+ bool shaderInt8 = false, shaderInt16 = false, shaderInt64 = false;
+
+ bool shaderIntegerDotProduct = false;
+
+ bool storageBuffer16BitAccess = false, storagePushConstant16 = false;
+ bool uniformAndStorageBuffer16BitAccess = false;
+ bool storageBuffer8BitAccess = false, storagePushConstant8 = false;
+ bool uniformAndStorageBuffer8BitAccess = false;
+ bool physicalStorageBufferAddresses = false;
+
+ bool variablePointers = false, variablePointersStorageBuffer = false;
+
+ SmallVector<Attribute> coopmatCases;
+
+ Builder builder(context);
+
+ switch (triple.getArch()) {
+ case TargetTripleArch::AMD_RDNAv3: {
+ auto i8t = builder.getIntegerType(8);
+ auto i32t = builder.getIntegerType(32);
+ auto f16t = builder.getF16Type();
+ auto f32t = builder.getF32Type();
+ auto scope = ScopeKHRAttr::get(context, ScopeKHR::Subgroup);
+
+ // Note: The driver also advertises saturating arithmetic, so we can
+ // declare this when needed.
+ coopmatCases.push_back(CooperativeMatrixPropertiesKHRAttr::get(
+ context,
+ /*mSize=*/16, /*nSize=*/16, /*kSize=*/16, /*aType=*/i8t,
+ /*bType=*/i8t, /*cType=*/i32t, /*resultType=*/i32t, /*accSat=*/false,
+ /*scope=*/scope));
+ coopmatCases.push_back(CooperativeMatrixPropertiesKHRAttr::get(
+ context,
+ /*mSize=*/16, /*nSize=*/16, /*kSize=*/16, /*aType=*/f16t,
+ /*bType=*/f16t, /*cType=*/f16t, /*resultType=*/f16t, /*accSat=*/false,
+ /*scope=*/scope));
+ coopmatCases.push_back(CooperativeMatrixPropertiesKHRAttr::get(
+ context,
+ /*mSize=*/16, /*nSize=*/16, /*kSize=*/16, /*aType=*/f16t,
+ /*bType=*/f16t, /*cType=*/f32t, /*resultType=*/f32t, /*accSat=*/false,
+ /*scope=*/scope));
+ }
+ LLVM_FALLTHROUGH;
+ case TargetTripleArch::AMD_RDNAv1:
+ case TargetTripleArch::AMD_RDNAv2:
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=10906
+ maxComputeSharedMemorySize = 65536;
+ maxComputeWorkGroupInvocations = 1024;
+ maxComputeWorkGroupSize = {1024, 1024, 1024};
+
+ subgroupSize = 64, minSubgroupSize = 32, maxSubgroupSize = 64;
+ subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
+ SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
+ SubgroupFeature::Shuffle |
+ SubgroupFeature::ShuffleRelative |
+ SubgroupFeature::Clustered | SubgroupFeature::Quad;
+
+ shaderFloat16 = shaderFloat64 = true;
+ shaderInt8 = shaderInt16 = shaderInt64 = true;
+
+ shaderIntegerDotProduct = true;
+
+ storageBuffer16BitAccess = storagePushConstant16 = true;
+ uniformAndStorageBuffer16BitAccess = true;
+ storageBuffer8BitAccess = true, storagePushConstant8 = true;
+ uniformAndStorageBuffer8BitAccess = true;
+ physicalStorageBufferAddresses = true;
+
+ variablePointers = variablePointersStorageBuffer = true;
+ break;
+ case TargetTripleArch::Apple_M1:
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=14673
+ maxComputeSharedMemorySize = 32768;
+ maxComputeWorkGroupInvocations = 1024;
+ maxComputeWorkGroupSize = {1024, 1024, 1024};
+
+ subgroupSize = 32;
+ subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
+ SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
+ SubgroupFeature::Shuffle |
+ SubgroupFeature::ShuffleRelative | SubgroupFeature::Quad;
+
+ shaderFloat16 = true;
+ shaderFloat64 = false;
+ shaderInt8 = shaderInt16 = shaderInt64 = true;
+
+ storageBuffer16BitAccess = storagePushConstant16 = true;
+ uniformAndStorageBuffer16BitAccess = true;
+ storageBuffer8BitAccess = true, storagePushConstant8 = true;
+ uniformAndStorageBuffer8BitAccess = true;
+ physicalStorageBufferAddresses = true;
+
+ variablePointers = variablePointersStorageBuffer = true;
+ break;
+ case TargetTripleArch::ARM_Valhall:
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=10312 (11)
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=15142 (12)
+ maxComputeSharedMemorySize = 32768;
+ maxComputeWorkGroupInvocations = 512;
+ maxComputeWorkGroupSize = {512, 512, 512};
+
+ subgroupSize = 16;
+ subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
+ SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
+ SubgroupFeature::Clustered | SubgroupFeature::Quad;
+
+ if (triple.getOS() == TargetTripleOS::Android31) {
+ subgroupFeatures = subgroupFeatures | SubgroupFeature::Shuffle |
+ SubgroupFeature::ShuffleRelative;
+ }
+
+ shaderFloat16 = shaderInt8 = shaderInt16 = true;
+
+ shaderIntegerDotProduct = true;
+
+ storageBuffer16BitAccess = storagePushConstant16 = true;
+ uniformAndStorageBuffer16BitAccess = true;
+ storageBuffer8BitAccess = true, storagePushConstant8 = true;
+ uniformAndStorageBuffer8BitAccess = true;
+
+ variablePointers = variablePointersStorageBuffer = true;
+ break;
+ case TargetTripleArch::CPU:
+ if (triple.getProduct() == TargetTripleProduct::SwiftShader) {
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=11023
+ maxComputeSharedMemorySize = 16384;
+
+ subgroupSize = 4;
+ subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
+ SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
+ SubgroupFeature::Shuffle |
+ SubgroupFeature::ShuffleRelative;
+ }
+ break;
+ case TargetTripleArch::NV_Turing:
+ case TargetTripleArch::NV_Ampere: {
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=11252
+ maxComputeSharedMemorySize = 49152;
+ maxComputeWorkGroupInvocations = 1024;
+ maxComputeWorkGroupSize = {1024, 1024, 64};
+
+ subgroupSize = 32, minSubgroupSize = 32, maxSubgroupSize = 32;
+ subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
+ SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
+ SubgroupFeature::Shuffle |
+ SubgroupFeature::ShuffleRelative |
+ SubgroupFeature::Clustered | SubgroupFeature::Quad;
+
+ shaderFloat16 = shaderFloat64 = true;
+ shaderInt8 = shaderInt16 = shaderInt64 = true;
+
+ shaderIntegerDotProduct = true;
+
+ storageBuffer16BitAccess = storagePushConstant16 = true;
+ uniformAndStorageBuffer16BitAccess = true;
+ storageBuffer8BitAccess = true, storagePushConstant8 = true;
+ uniformAndStorageBuffer8BitAccess = true;
+ physicalStorageBufferAddresses = true;
+
+ variablePointers = variablePointersStorageBuffer = true;
+
+ auto i8t = builder.getIntegerType(8);
+ auto i32t = builder.getIntegerType(32);
+ auto f16t = builder.getF16Type();
+ auto f32t = builder.getF32Type();
+ auto scope = ScopeKHRAttr::get(context, ScopeKHR::Subgroup);
+
+ // Note: the driver also advertises other shapes that can be enabled when
+ // needed.
+ coopmatCases.push_back(CooperativeMatrixPropertiesKHRAttr::get(
+ context,
+ /*mSize=*/8, /*nSize=*/8, /*kSize=*/32, /*aType=*/i8t,
+ /*bType=*/i8t, /*cType=*/i32t, /*resultType=*/i32t, /*accSat=*/false,
+ /*scope=*/scope));
+ coopmatCases.push_back(CooperativeMatrixPropertiesKHRAttr::get(
+ context,
+ /*mSize=*/16, /*nSize=*/16, /*kSize=*/16, /*aType=*/f16t,
+ /*bType=*/f16t, /*cType=*/f16t, /*resultType=*/f16t, /*accSat=*/false,
+ /*scope=*/scope));
+ coopmatCases.push_back(CooperativeMatrixPropertiesKHRAttr::get(
+ context,
+ /*mSize=*/16, /*nSize=*/16, /*kSize=*/16, /*aType=*/f16t,
+ /*bType=*/f16t, /*cType=*/f32t, /*resultType=*/f32t, /*accSat=*/false,
+ /*scope=*/scope));
+ } break;
+ case TargetTripleArch::NV_Pascal:
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=17937
+ maxComputeSharedMemorySize = 49152;
+ maxComputeWorkGroupInvocations = 1536;
+ maxComputeWorkGroupSize = {1536, 1024, 64};
+
+ subgroupSize = 32, minSubgroupSize = 32, maxSubgroupSize = 32;
+ subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
+ SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
+ SubgroupFeature::Shuffle |
+ SubgroupFeature::ShuffleRelative |
+ SubgroupFeature::Clustered | SubgroupFeature::Quad;
+
+ shaderFloat16 = shaderFloat64 = true;
+ shaderInt8 = shaderInt16 = shaderInt64 = true;
+
+ shaderIntegerDotProduct = true;
+
+ storageBuffer16BitAccess = storagePushConstant16 = true;
+ uniformAndStorageBuffer16BitAccess = true;
+ storageBuffer8BitAccess = true, storagePushConstant8 = true;
+ uniformAndStorageBuffer8BitAccess = true;
+ physicalStorageBufferAddresses = true;
+
+ variablePointers = variablePointersStorageBuffer = true;
+ break;
+ case TargetTripleArch::QC_Adreno:
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=10983 (11)
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=16312 (12)
+ maxComputeSharedMemorySize = 32768;
+ maxComputeWorkGroupInvocations = 1024;
+ maxComputeWorkGroupSize = {1024, 1024, 64};
+
+ subgroupSize = 64;
+ subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
+ SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
+ SubgroupFeature::Shuffle |
+ SubgroupFeature::ShuffleRelative | SubgroupFeature::Quad;
+
+ shaderFloat16 = shaderInt8 = shaderInt16 = true;
+
+ storageBuffer16BitAccess = true;
+ if (triple.getOS() == TargetTripleOS::Android31) {
+ storageBuffer8BitAccess = true;
+ }
+
+ variablePointers = variablePointersStorageBuffer = true;
+ break;
+ case TargetTripleArch::Intel_Arc:
+ // Example: https://vulkan.gpuinfo.org/displayreport.php?id=19818
+ maxComputeSharedMemorySize = 32768;
+ maxComputeWorkGroupInvocations = 1024;
+ maxComputeWorkGroupSize = {1024, 1024, 64};
+
+ subgroupSize = 32, minSubgroupSize = 8, maxSubgroupSize = 32;
+ subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
+ SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
+ SubgroupFeature::Shuffle |
+ SubgroupFeature::ShuffleRelative |
+ SubgroupFeature::Clustered | SubgroupFeature::Quad;
+
+ shaderFloat16 = true;
+ shaderFloat64 = false;
+ shaderInt8 = shaderInt16 = true;
+ shaderInt64 = false;
+
+ shaderIntegerDotProduct = true;
+
+ storageBuffer16BitAccess = storagePushConstant16 = true;
+ uniformAndStorageBuffer16BitAccess = true;
+ storageBuffer8BitAccess = true, storagePushConstant8 = true;
+ uniformAndStorageBuffer8BitAccess = true;
+ physicalStorageBufferAddresses = true;
+
+ variablePointers = variablePointersStorageBuffer = true;
+ break;
+ case TargetTripleArch::Unknown:
+ // Use the largest subgroup size we can find across various vendors.
+ subgroupSize = 64;
+ // The following capabilities have 90%+ device coverage (Vulkan 1.1+)
+ // from https://vulkan.gpuinfo.org/listfeaturesextensions.php.
+ variablePointers = variablePointersStorageBuffer = false;
+ // Use Vulkan default for others.
+ break;
+ }
+
+ auto getBoolAttr = [context](bool value) {
+ return value ? UnitAttr::get(context) : UnitAttr();
+ };
+
+ return CapabilitiesAttr::get(
+ context, maxComputeSharedMemorySize, maxComputeWorkGroupInvocations,
+ builder.getI32VectorAttr(maxComputeWorkGroupSize),
+ getBoolAttr(shaderFloat64), getBoolAttr(shaderInt16),
+ getBoolAttr(shaderInt64),
+ SubgroupFeatureAttr::get(context, subgroupFeatures), subgroupSize,
+ minSubgroupSize, maxSubgroupSize, getBoolAttr(storageBuffer16BitAccess),
+ getBoolAttr(storagePushConstant16),
+ getBoolAttr(uniformAndStorageBuffer16BitAccess),
+ getBoolAttr(storageBuffer8BitAccess), getBoolAttr(storagePushConstant8),
+ getBoolAttr(uniformAndStorageBuffer8BitAccess),
+ getBoolAttr(physicalStorageBufferAddresses), getBoolAttr(shaderFloat16),
+ getBoolAttr(shaderInt8), getBoolAttr(shaderIntegerDotProduct),
+ getBoolAttr(variablePointersStorageBuffer), getBoolAttr(variablePointers),
+ builder.getArrayAttr(coopmatCases));
+}
+} // namespace
+
+/// Parses a "<arch>-<product>-<os>" string into a TargetTriple. Any field that
+/// is missing or not recognized is set to `Unknown`.
+TargetTriple TargetTriple::get(const char *triple) {
+  llvm::SmallVector<llvm::StringRef, 3> fragments;
+  llvm::SplitString(triple, fragments, "-");
+  // SplitString can yield fewer than three fragments (e.g. for "valhall"), so
+  // never index past what was actually parsed.
+  auto fragmentAt = [&fragments](size_t index) {
+    return index < fragments.size() ? fragments[index] : llvm::StringRef();
+  };
+  const auto arch = symbolizeTargetTripleArch(fragmentAt(0));
+  const auto product = symbolizeTargetTripleProduct(fragmentAt(1));
+  const auto os = symbolizeTargetTripleOS(fragmentAt(2));
+  return TargetTriple(arch.value_or(TargetTripleArch::Unknown),
+                      product.value_or(TargetTripleProduct::Unknown),
+                      os.value_or(TargetTripleOS::Unknown));
+}
+
+// Stores the given components as-is; no validation is performed here.
+TargetTriple::TargetTriple(TargetTripleArch a, TargetTripleProduct p,
+                           TargetTripleOS o) : arch(a), product(p), os(o) {}
+
+/// Renders this triple back into its "<arch>-<product>-<os>" string form.
+std::string TargetTriple::getTriple() const {
+  return llvm::formatv("{0}-{1}-{2}", stringifyTargetTripleArch(arch),
+                       stringifyTargetTripleProduct(product),
+                       stringifyTargetTripleOS(os));
+}
+
+TargetEnvAttr TargetTriple::getTargetEnv(MLIRContext *context) const {
+  SmallVector<Extension> supported; // Filled in by getExtensions() below.
+  getExtensions(*this, supported);
+  return TargetEnvAttr::get(
+      getVersion(*this), /*revision=*/0, supported, getVendor(*this),
+      getDeviceType(*this), spirv::TargetEnvAttr::kUnknownDeviceID,
+      getCapabilities(*this, context));
+}
+
+} // namespace mlir::iree_compiler::IREE::Vulkan
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.h b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.h
new file mode 100644
index 0000000..7ea5e0d
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.h
@@ -0,0 +1,67 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_VULKAN_UTILS_TARGETTRIPLE_H_
+#define IREE_COMPILER_DIALECT_VULKAN_UTILS_TARGETTRIPLE_H_
+
+#include <string>
+
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace mlir::iree_compiler::IREE::Vulkan {
+
+/// GPU triple definitions to describe GPU targets for compilers.
+///
+/// We use "triple" here to match common compiler language: historically one
+/// would describe a CPU compiler target as a string containing exactly three
+/// fields. But here the configuration is for GPUs, and there can exist a lot
+/// of architectures/vendors/products/systems. What matters differs from CPU
+/// triples. We define it in the form of:
+///
+///   <vendor/arch>-<product>-<os>
+///
+/// For example:
+///   ampere-rtx3080-windows
+///   rdna1-5700xt-linux
+///   adreno-a650-android30
+///   valhall-unknown-android30
+///   cpu-swiftshader-unknown
+///
+/// Vendor and architecture are combined together because:
+/// * Typically each GPU vendor has its own set of architectures. So given the
+///   architecture we know which vendor it is from. This is different from CPU
+///   land, where the same architecture can be implemented by multiple
+///   vendors.
+/// * There are vendors for which we don't have public information regarding
+///   their architectures.
+/// We need a field for product to differentiate the cases where the
+/// architecture is unknown or ambiguous.
+class TargetTriple {
+public:
+ static TargetTriple get(const char *triple);
+
+ TargetTriple(TargetTripleArch, TargetTripleProduct, TargetTripleOS);
+
+ TargetTripleArch getArch() const { return arch; }
+ TargetTripleProduct getProduct() const { return product; }
+ TargetTripleOS getOS() const { return os; }
+
+ /// Returns the triple in its "<arch>-<product>-<os>" string form.
+ std::string getTriple() const;
+
+ TargetEnvAttr getTargetEnv(MLIRContext *context) const;
+
+private:
+ TargetTripleArch arch;
+ TargetTripleProduct product;
+ TargetTripleOS os;
+};
+
+} // namespace mlir::iree_compiler::IREE::Vulkan
+
+#endif // IREE_COMPILER_DIALECT_VULKAN_UTILS_TARGETTRIPLE_H_
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/BUILD.bazel
new file mode 100644
index 0000000..687fa49
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/BUILD.bazel
@@ -0,0 +1,37 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(  # Adds CMake-only content to the generated file.
+ content = """
+if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV)
+ return()
+endif()
+""",  # CMake returns early unless IREE_TARGET_BACKEND_VULKAN_SPIRV is set.
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(  # NOTE(review): presumably checks the list against `include` - confirm.
+ [
+ "target_env_conversion.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ cfg = "//compiler:lit.cfg.py",
+ tools = [  # The test's RUN lines invoke iree-opt and FileCheck.
+ "//tools:iree-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt
new file mode 100644
index 0000000..bb5cbe5
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt
@@ -0,0 +1,27 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV)
+ return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "target_env_conversion.mlir"
+ TOOLS
+ FileCheck
+ iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
new file mode 100644
index 0000000..3a23031
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
@@ -0,0 +1,86 @@
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' %s | FileCheck %s --check-prefix=DEFAULT
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-triple=adreno-a650-android30 %s | FileCheck %s --check-prefix=ADRENO
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-triple=valhall-unknown-android31 %s | FileCheck %s --check-prefix=VALHALL
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-triple=turing-t4-linux %s | FileCheck %s --check-prefix=TURING
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-triple=rdna1-5700xt-windows %s | FileCheck %s --check-prefix=RDNA1
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-triple=rdna3-6900xtx-windows %s | FileCheck %s --check-prefix=RDNA3
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-triple=m1-moltenvk-macos %s | FileCheck %s --check-prefix=M1
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-triple=arc-770-windows %s | FileCheck %s --check-prefix=ARC
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-triple=pascal-1080-windows %s | FileCheck %s --check-prefix=PASCAL
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vulkan-spirv},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-vulkan-target-env="#vk.target_env<v1.1, r(120), [VK_KHR_spirv_1_4, VK_KHR_storage_buffer_storage_class], AMD:DiscreteGPU, #vk.caps<maxComputeSharedMemorySize = 16384, maxComputeWorkGroupInvocations = 1024, maxComputeWorkGroupSize = dense<[128, 8, 4]>: vector<3xi32>, subgroupFeatures = 63 : i32, subgroupSize = 4 >>" %s | FileCheck %s --check-prefix=ENV
+
+// TODO(antiagainst): Passing in lengthy strings as command-line options is not
+// optimal. We should consider creating a dedicated test pass to pick up
+// #vk.target_env in input assembly and convert them.
+
+// DEFAULT: #spirv.target_env<#spirv.vce<v1.3,
+// DEFAULT-SAME: [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// DEFAULT-SAME: api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>
+
+// ADRENO: #spirv.target_env<#spirv.vce<v1.4,
+// ADRENO-SAME: [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer],
+// ADRENO-SAME: [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// ADRENO-SAME: api=Vulkan, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>
+
+// VALHALL: #spirv.target_env<#spirv.vce<v1.4,
+// VALHALL-SAME: [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit],
+// VALHALL-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// VALHALL-SAME: api=Vulkan, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16, cooperative_matrix_properties_khr = []>>
+
+// TURING: #spirv.target_env<#spirv.vce<v1.6,
+// TURING-SAME: [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixKHR],
+// TURING-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>,
+// TURING-SAME: api=Vulkan, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], min_subgroup_size = 32, max_subgroup_size = 32, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>
+
+// RDNA1: #spirv.target_env<#spirv.vce<v1.6,
+// RDNA1-SAME: [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit],
+// RDNA1-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers]>,
+// RDNA1-SAME: api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = []>>
+
+// RDNA3: #spirv.target_env<#spirv.vce<v1.6,
+// RDNA3-SAME: [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixKHR],
+// RDNA3-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>,
+// RDNA3-SAME: api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>
+
+// M1: #spirv.target_env<#spirv.vce<v1.3,
+// M1-SAME: [Shader, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer],
+// M1-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers]>,
+// M1-SAME: api=Vulkan, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], cooperative_matrix_properties_khr = []>>
+
+// ARC: #spirv.target_env<#spirv.vce<v1.6,
+// ARC-SAME: [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit],
+// ARC-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers]>,
+// ARC-SAME: api=Vulkan, Intel:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], min_subgroup_size = 8, max_subgroup_size = 32, cooperative_matrix_properties_khr = []>>}>
+
+// PASCAL: #spirv.target_env<#spirv.vce<v1.6,
+// PASCAL-SAME: [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit],
+// PASCAL-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>,
+// PASCAL-SAME: api=Vulkan, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1536, max_compute_workgroup_size = [1536, 1024, 64], min_subgroup_size = 32, max_subgroup_size = 32, cooperative_matrix_properties_khr = []>>}>
+
+// ENV: #spirv.target_env<#spirv.vce<v1.4,
+// ENV-SAME: [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative],
+// ENV-SAME: [SPV_KHR_storage_buffer_storage_class]>,
+// ENV-SAME: api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [128, 8, 4], subgroup_size = 4, cooperative_matrix_properties_khr = []>>
+
+stream.executable public @reduce_dispatch {
+ stream.executable.export @reduce_dispatch workgroups(%arg0: index) -> (index, index, index) {
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ stream.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @reduce_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding) {
+ %c0 = arith.constant 0 : index
+ %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
+ %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<f32>>
+ %0 = tensor.empty() : tensor<f32>
+ %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1 : tensor<16xf32>) outs(%0 : tensor<f32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ %4 = arith.addf %arg2, %arg3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<f32>
+ flow.dispatch.tensor.store %3, %arg1, offsets=[], sizes=[], strides=[] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>>
+ return
+ }
+ }
+}
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 95c620c..314a09c 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -56,6 +56,7 @@
"//compiler/src/iree/compiler/Dialect/VM/Transforms",
"//compiler/src/iree/compiler/Dialect/VMVX/IR:VMVXDialect",
"//compiler/src/iree/compiler/Dialect/VMVX/Transforms",
+ "//compiler/src/iree/compiler/Dialect/Vulkan/IR",
"//compiler/src/iree/compiler/ExternalInterfaces:ExternalModels",
"//compiler/src/iree/compiler/GlobalOptimization/Interfaces",
"//compiler/src/iree/compiler/InputConversion/Common",
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt
index ee8d820..a38a4db 100644
--- a/compiler/src/iree/compiler/Tools/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -53,6 +53,7 @@
iree::compiler::Dialect::VM::Transforms
iree::compiler::Dialect::VMVX::IR::VMVXDialect
iree::compiler::Dialect::VMVX::Transforms
+ iree::compiler::Dialect::Vulkan::IR
iree::compiler::ExternalInterfaces::ExternalModels
iree::compiler::GlobalOptimization::Interfaces::Interfaces
iree::compiler::InputConversion::Common
diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
index 4d33387..0472723 100644
--- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h
+++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h
@@ -22,17 +22,20 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.h"
#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
#include "iree/compiler/Dialect/VMVX/IR/VMVXDialect.h"
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "iree/compiler/ExternalInterfaces/Interfaces.h"
#include "iree/compiler/GlobalOptimization/Interfaces/Interfaces.h"
#include "iree/compiler/Modules/HAL/Inline/IR/HALInlineDialect.h"
#include "iree/compiler/Modules/HAL/Loader/IR/HALLoaderDialect.h"
#include "iree/compiler/Modules/IO/Parameters/IR/IOParametersDialect.h"
#include "iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.h"
+#include "mlir/IR/Dialect.h"
namespace mlir::iree_compiler {
@@ -53,7 +56,8 @@
IREE::Util::UtilDialect,
IREE::VM::VMDialect,
IREE::VMVX::VMVXDialect,
- IREE::VectorExt::IREEVectorExtDialect>();
+ IREE::VectorExt::IREEVectorExtDialect,
+ IREE::Vulkan::VulkanDialect>();
// clang-format on
// External models.
diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir
index ef10fb7..d655b98 100644
--- a/samples/custom_dispatch/vulkan/shaders/example.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example.mlir
@@ -14,12 +14,9 @@
// and compilation options (architectures, etc) can be embedded for runtime
// selection.
#spirv_target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
- iree.gpu.target = #iree_gpu.target<
- arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
- compute = fp32|int32, storage = b32, subgroup = none,
- dot = none, mma = [], subgroup_size_choices = [64, 64],
- max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
- max_workgroup_memory_bytes = 16384>
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
>
}>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index 36912bb..5cdbcac 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -14,12 +14,9 @@
// and compilation options (architectures, etc) can be embedded for runtime
// selection.
#spirv_target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
- iree.gpu.target = #iree_gpu.target<
- arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
- compute = fp32|int32, storage = b32, subgroup = none,
- dot = none, mma = [], subgroup_size_choices = [64, 64],
- max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
- max_workgroup_memory_bytes = 16384>
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
>
}>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
index b4885a0..3766a30 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
@@ -18,12 +18,10 @@
// custom kernel. For things to be truly portable, we need to be able to compare
// executable configurations.
#spirv_target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
- iree.gpu.target = #iree_gpu.target<
- arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
- compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
- dot = none, mma = [], subgroup_size_choices = [64, 64],
- max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
- max_workgroup_memory_bytes = 16384>
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformArithmetic, GroupNonUniformBallot],
+ [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
>
}>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
index 5bcdafe..70ad898 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
@@ -7,12 +7,10 @@
// The configuration used for executable compilation.
// This specifies the device configurations that support this custom kernel.
#spirv_target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
- iree.gpu.target = #iree_gpu.target<
- arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
- compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic,
- dot = none, mma = [], subgroup_size_choices = [64, 64],
- max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
- max_workgroup_memory_bytes = 16384>
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformArithmetic, GroupNonUniformBallot],
+ [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64>
>
}>
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index 585bb25..13128e1 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -25,21 +25,19 @@
// }
// }
-#target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
- compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [64, 64],
- max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+#target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>
module attributes {
hal.device.targets = [
#hal.device.target<"vulkan", [
#hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
- iree.gpu.target = #target
+ spirv.target_env = #target_env
}>
]>
]
} {
hal.executable private @example_module_dispatch_0 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree.gpu.target = #target}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #target_env}>) {
hal.executable.export public @example_module_dispatch_0_generic_80_f32 ordinal(0) layout(
#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):
@@ -65,7 +63,7 @@
}
}
hal.executable private @example_module_dispatch_1 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree.gpu.target = #target}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #target_env}>) {
hal.executable.export public @example_module_dispatch_1_matmul_16x16x5_f32 ordinal(0) layout(
#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):
@@ -89,7 +87,7 @@
}
}
hal.executable private @example_module_dispatch_2 {
- hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {iree.gpu.target = #target}>) {
+ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #target_env}>) {
hal.executable.export public @example_module_dispatch_2_generic_16x16_f32 ordinal(0) layout(
#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):