[metal] Populate proper capability and limits for SPIR-V CodeGen (#12391)
These capabilities and limits are for Apple GPUs supporting Metal3.
Progress towards https://github.com/openxla/iree/issues/4370
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
index fa40189..71ba118 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -26,15 +27,57 @@
namespace IREE {
namespace HAL {
-// TODO(antiagainst): provide a proper target environment for Metal.
static spirv::TargetEnvAttr getMetalTargetEnv(MLIRContext *context) {
- auto triple = spirv::VerCapExtAttr::get(
- spirv::Version::V_1_0, {spirv::Capability::Shader},
- {spirv::Extension::SPV_KHR_storage_buffer_storage_class}, context);
+ using spirv::Capability;
+ using spirv::Extension;
+
+ // Capabilities and limits according to Metal 3 devices.
+ const std::array<Extension, 4> extensions = {
+ Extension::SPV_KHR_16bit_storage,
+ Extension::SPV_KHR_8bit_storage,
+ Extension::SPV_KHR_storage_buffer_storage_class,
+ Extension::SPV_KHR_variable_pointers,
+ };
+ const std::array<Capability, 21> capabilities = {
+ Capability::Shader,
+ Capability::Int8,
+ Capability::Int16,
+ Capability::Int64,
+ Capability::Float16,
+ Capability::UniformAndStorageBuffer8BitAccess,
+ Capability::StorageBuffer8BitAccess,
+ Capability::StoragePushConstant8,
+ Capability::StorageUniform16,
+ Capability::StorageBuffer16BitAccess,
+ Capability::StoragePushConstant16,
+ Capability::GroupNonUniform,
+ Capability::GroupNonUniformVote,
+ Capability::GroupNonUniformArithmetic,
+ Capability::GroupNonUniformBallot,
+ Capability::GroupNonUniformShuffle,
+ Capability::GroupNonUniformShuffleRelative,
+ Capability::GroupNonUniformQuad,
+ Capability::StoragePushConstant16,
+ Capability::VariablePointers,
+ Capability::VariablePointersStorageBuffer,
+ };
+ auto limits = spirv::ResourceLimitsAttr::get(
+ context,
+ /*max_compute_shared_memory_size=*/32768,
+ /*max_compute_workgroup_invocations=*/1024,
+ /*max_compute_workgroup_size=*/
+ Builder(context).getI32ArrayAttr({1024, 1024, 1024}),
+ /*subgroup_size=*/32,
+ /*min_subgroup_size=*/std::nullopt,
+ /*max_subgroup_size=*/std::nullopt,
+ /*cooperative_matrix_properties_nv=*/ArrayAttr());
+
+ auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0, capabilities,
+ extensions, context);
+ // Further assuming Apple GPUs.
return spirv::TargetEnvAttr::get(
- triple, spirv::getDefaultResourceLimits(context), spirv::ClientAPI::Metal,
- spirv::Vendor::Unknown, spirv::DeviceType::Unknown,
- spirv::TargetEnvAttr::kUnknownDeviceID);
+ triple, limits, spirv::ClientAPI::Metal, spirv::Vendor::Apple,
+ spirv::DeviceType::IntegratedGPU, spirv::TargetEnvAttr::kUnknownDeviceID);
}
class MetalSPIRVTargetBackend : public TargetBackend {