| // Copyright 2020 The IREE Authors | 
 | // | 
 | // Licensed under the Apache License v2.0 with LLVM Exceptions. | 
 | // See https://llvm.org/LICENSE.txt for license information. | 
 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
 |  | 
 | #include "compiler/plugins/target/MetalSPIRV/SPIRVToMSL.h" | 
 |  | 
 | #include "llvm/ADT/STLExtras.h" | 
 | #include "llvm/Support/Debug.h" | 
 | #include "llvm/Support/raw_ostream.h" | 
 |  | 
 | // Disable exception handling in favor of assertions. | 
 | #define SPIRV_CROSS_EXCEPTIONS_TO_ASSERTIONS | 
 | #include "third_party/spirv_cross/spirv_msl.hpp" | 
 |  | 
 | #define DEBUG_TYPE "spirv-to-msl" | 
 |  | 
/// The [[buffer(N)]] index for push constants.
/// Note that this MUST be kept consistent with the Metal HAL driver.
/// crossCompileSPIRVToMSL() assigns this value to the `msl_buffer` field of
/// the resource binding it registers when the module contains a variable in
/// the PushConstant storage class.
#define IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX 3
 |  | 
 | namespace mlir::iree_compiler { | 
 |  | 
 | namespace { | 
 | class SPIRVToMSLCompiler : public SPIRV_CROSS_NAMESPACE::CompilerMSL { | 
 | public: | 
 |   using CompilerMSL::CompilerMSL; | 
 |  | 
 |   MetalShader::ThreadGroupSize | 
 |   getWorkgroupSizeForEntryPoint(StringRef entryName) { | 
 |     const auto &entryPoint = get_entry_point( | 
 |         entryName.str(), spv::ExecutionModel::ExecutionModelGLCompute); | 
 |     const auto &workgroupSize = entryPoint.workgroup_size; | 
 |     // TODO(antiagainst): support specialization constant. | 
 |     if (workgroupSize.constant != 0) | 
 |       return {0, 0, 0}; | 
 |     return {workgroupSize.x, workgroupSize.y, workgroupSize.z}; | 
 |   } | 
 |  | 
 |   // A struct containing a resource descriptor's information. | 
 |   struct Descriptor { | 
 |     uint32_t set; | 
 |     uint32_t binding; | 
 |  | 
 |     Descriptor(uint32_t s, uint32_t b) : set(s), binding(b) {} | 
 |  | 
 |     friend bool operator<(const Descriptor &l, const Descriptor &r) { | 
 |       return std::tie(l.set, l.binding) < std::tie(r.set, r.binding); | 
 |     } | 
 |   }; | 
 |  | 
 |   // Updates `descriptors` with resource set and binding number pairs in | 
 |   // increasing order, and `hasPushConstant` if with push constants. | 
 |   // Returns true if no unsupported cases are encountered. | 
 |   bool getResources(SmallVectorImpl<Descriptor> *descriptors, | 
 |                     bool *hasPushConstant) { | 
 |     descriptors->clear(); | 
 |     *hasPushConstant = false; | 
 |  | 
 |     // Iterate over all variables in the SPIR-V blob. | 
 |     bool hasUnknownCase = false; | 
 |     ir.for_each_typed_id<SPIRV_CROSS_NAMESPACE::SPIRVariable>( | 
 |         [&](uint32_t id, SPIRV_CROSS_NAMESPACE::SPIRVariable &var) { | 
 |           auto storage = var.storage; | 
 |           switch (storage) { | 
 |             // Non-interface variables. We don't care. | 
 |           case spv::StorageClassFunction: | 
 |           case spv::StorageClassPrivate: | 
 |           case spv::StorageClassWorkgroup: | 
 |             // Builtin variables. We don't care either. | 
 |           case spv::StorageClassInput: | 
 |             return; | 
 |           case spv::StorageClassPushConstant: | 
 |             *hasPushConstant = true; | 
 |             return; | 
 |           case spv::StorageClassUniform: | 
 |           case spv::StorageClassStorageBuffer: { | 
 |             uint32_t setNo = get_decoration(id, spv::DecorationDescriptorSet); | 
 |             uint32_t bindingNo = get_decoration(id, spv::DecorationBinding); | 
 |             descriptors->emplace_back(setNo, bindingNo); | 
 |             return; | 
 |           } | 
 |           default: | 
 |             break; | 
 |           } | 
 |           hasUnknownCase = true; | 
 |         }); | 
 |  | 
 |     llvm::sort(*descriptors); | 
 |     return !hasUnknownCase; | 
 |   } | 
 |  | 
 |   Options getCompilationOptions(IREE::HAL::MetalTargetPlatform platform) { | 
 |     // TODO(antiagainst): fill out the following according to the Metal GPU | 
 |     // family. | 
 |     SPIRVToMSLCompiler::Options spvCrossOptions; | 
 |     switch (platform) { | 
 |     case IREE::HAL::MetalTargetPlatform::macOS: | 
 |       spvCrossOptions.platform = SPIRVToMSLCompiler::Options::Platform::macOS; | 
 |       break; | 
 |     case IREE::HAL::MetalTargetPlatform::iOS: | 
 |     case IREE::HAL::MetalTargetPlatform::iOSSimulator: | 
 |       spvCrossOptions.platform = SPIRVToMSLCompiler::Options::Platform::iOS; | 
 |       break; | 
 |     } | 
 |     spvCrossOptions.msl_version = | 
 |         SPIRVToMSLCompiler::Options::make_msl_version(3, 0); | 
 |     // Eanble using Metal argument buffers. It is more akin to Vulkan descriptor | 
 |     // sets, which is how IREE HAL models resource bindings and mappings. | 
 |     spvCrossOptions.argument_buffers = true; | 
 |     return spvCrossOptions; | 
 |   } | 
 | }; | 
 | } // namespace | 
 |  | 
 | std::optional<std::pair<MetalShader, std::string>> | 
 | crossCompileSPIRVToMSL(IREE::HAL::MetalTargetPlatform targetPlatform, | 
 |                        llvm::ArrayRef<uint32_t> spvBinary, | 
 |                        StringRef entryPoint) { | 
 |   SPIRVToMSLCompiler spvCrossCompiler(spvBinary.data(), spvBinary.size()); | 
 |  | 
 |   // All spirv-cross operations work on the current entry point. It should be | 
 |   // set right after the cross compiler construction. | 
 |   spvCrossCompiler.set_entry_point( | 
 |       entryPoint.str(), spv::ExecutionModel::ExecutionModelGLCompute); | 
 |  | 
 |   SmallVector<SPIRVToMSLCompiler::Descriptor> descriptors; | 
 |   bool hasPushConstant = false; | 
 |   if (!spvCrossCompiler.getResources(&descriptors, &hasPushConstant)) | 
 |     return std::nullopt; | 
 |  | 
 |   // Explicitly set the argument buffer [[id(N)]] location for each SPIR-V | 
 |   // resource variable. | 
 |   for (const auto &descriptor : descriptors) { | 
 |     SPIRV_CROSS_NAMESPACE::MSLResourceBinding binding = {}; | 
 |     binding.stage = spv::ExecutionModelGLCompute; | 
 |     binding.desc_set = descriptor.set; | 
 |     binding.binding = descriptor.binding; | 
 |     // We only interact with buffers in IREE. | 
 |     binding.msl_buffer = descriptor.binding; | 
 |  | 
 |     spvCrossCompiler.add_msl_resource_binding(binding); | 
 |   } | 
 |   // If push constants are used, explicitly set its [[buffer(N)]] location too. | 
 |   if (hasPushConstant) { | 
 |     SPIRV_CROSS_NAMESPACE::MSLResourceBinding binding = {}; | 
 |     binding.stage = spv::ExecutionModelGLCompute; | 
 |     binding.desc_set = | 
 |         SPIRV_CROSS_NAMESPACE::ResourceBindingPushConstantDescriptorSet; | 
 |     binding.binding = SPIRV_CROSS_NAMESPACE::ResourceBindingPushConstantBinding; | 
 |     binding.msl_buffer = IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX; | 
 |  | 
 |     spvCrossCompiler.add_msl_resource_binding(binding); | 
 |   } | 
 |  | 
 |   auto spvCrossOptions = spvCrossCompiler.getCompilationOptions(targetPlatform); | 
 |   spvCrossCompiler.set_msl_options(spvCrossOptions); | 
 |  | 
 |   std::string mslSource = spvCrossCompiler.compile(); | 
 |   // Get the revised entry point name. Cross compiling to MSL generates source | 
 |   // code, where we may run into the case that we are using reserved keyword for | 
 |   // the entry point name, e.g., `abs`. Under such circumstances, it will be | 
 |   // revised to avoid collision. | 
 |   const auto &spirvEntryPoint = spvCrossCompiler.get_entry_point( | 
 |       entryPoint.str(), spv::ExecutionModel::ExecutionModelGLCompute); | 
 |   LLVM_DEBUG({ | 
 |     llvm::dbgs() << "Original entry point name: '" << spirvEntryPoint.orig_name | 
 |                  << "'\n"; | 
 |     llvm::dbgs() << "Revised entry point name: '" << spirvEntryPoint.name | 
 |                  << "'\n"; | 
 |     llvm::dbgs() << "Generated MSL:\n-----\n" << mslSource << "\n-----\n"; | 
 |   }); | 
 |  | 
 |   auto workgroupSize = | 
 |       spvCrossCompiler.getWorkgroupSizeForEntryPoint(entryPoint); | 
 |   if (!workgroupSize.x || !workgroupSize.y || !workgroupSize.z) { | 
 |     return std::nullopt; | 
 |   } | 
 |   return std::make_pair(MetalShader{std::move(mslSource), workgroupSize}, | 
 |                         spirvEntryPoint.name); | 
 | } | 
 |  | 
 | } // namespace mlir::iree_compiler |