[metal] Properly set resource index when cross compiling (9/9) (#3530)

SPIRV-Cross is not able to guess the Metal argument buffer index
to use for a SPIR-V resource variable for our use case. It may
just assign consecutive index numbers to the variables encountered
in traversal order. So we need to help it by explicitly setting
the expected index.

This right now only handles single descriptor set, where we just
map the binding number to the argument buffer index. This is based
on the assumption that IREE does not leave "holes" in binding
number usage for a descriptor set.

This completes the initial CL train to bring up Metal in IREE,
including both the runtime and CodeGen. \o/
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
index 2f5abb7..5d89aa3 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
@@ -14,6 +14,9 @@
 
 #include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
 
+#include <vector>
+
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "spirv_cross/spirv_msl.hpp"
@@ -37,6 +40,66 @@
     if (workgroupSize.constant != 0) return {0, 0, 0};
     return {workgroupSize.x, workgroupSize.y, workgroupSize.z};
   }
+
+  // A struct containing a resource descriptor's information.
+  struct Descriptor {
+    uint32_t set;
+    uint32_t binding;
+
+    Descriptor(uint32_t s, uint32_t b) : set(s), binding(b) {}
+
+    friend bool operator<(const Descriptor& l, const Descriptor& r) {
+      return std::tie(l.set, l.binding) < std::tie(r.set, r.binding);
+    }
+  };
+
+  // Returns all all resource buffer descriptors' set and binding number pairs
+  // in increasing order.
+  std::vector<Descriptor> getBufferSetBindingPairs() {
+    std::vector<Descriptor> descriptors;
+
+    // Iterate over all variables in the SPIR-V blob.
+    ir.for_each_typed_id<SPIRV_CROSS_NAMESPACE::SPIRVariable>(
+        [&](uint32_t id, SPIRV_CROSS_NAMESPACE::SPIRVariable& var) {
+          auto storage = var.storage;
+          switch (storage) {
+              // Non-interface variables. We don't care.
+            case spv::StorageClassFunction:
+            case spv::StorageClassPrivate:
+            case spv::StorageClassWorkgroup:
+              // Builtin variables. We don't care either.
+            case spv::StorageClassInput:
+              return;
+            default:
+              break;
+          }
+          if (storage == spv::StorageClassUniform ||
+              storage == spv::StorageClassStorageBuffer) {
+            uint32_t setNo = get_decoration(id, spv::DecorationDescriptorSet);
+            uint32_t bindingNo = get_decoration(id, spv::DecorationBinding);
+            descriptors.emplace_back(setNo, bindingNo);
+            return;
+          }
+          // TODO(antiagainst): push constant
+          llvm_unreachable("unspported storage class in SPIRVToMSLCompiler");
+        });
+
+    llvm::sort(descriptors);
+    return descriptors;
+  }
+
+  Options getCompilationOptions() {
+    // TODO(antiagainst): fill out the following according to the Metal GPU
+    // family.
+    SPIRVToMSLCompiler::Options spvCrossOptions;
+    spvCrossOptions.platform = SPIRVToMSLCompiler::Options::Platform::macOS;
+    spvCrossOptions.msl_version =
+        SPIRVToMSLCompiler::Options::make_msl_version(2, 0);
+    // Eanble using Metal argument buffers. It is more akin to Vulkan descriptor
+    // sets, which is how IREE HAL models resource bindings and mappings.
+    spvCrossOptions.argument_buffers = true;
+    return spvCrossOptions;
+  }
 };
 }  // namespace
 
@@ -49,15 +112,25 @@
   spvCrossCompiler.set_entry_point(
       entryPoint, spv::ExecutionModel::ExecutionModelGLCompute);
 
-  // TODO(antiagainst): fill out the following according to the Metal GPU
-  // family.
-  SPIRVToMSLCompiler::Options spvCrossOptions;
-  spvCrossOptions.platform = SPIRVToMSLCompiler::Options::Platform::macOS;
-  spvCrossOptions.msl_version =
-      SPIRVToMSLCompiler::Options::make_msl_version(2, 0);
-  // Eanble using Metal argument buffers. It is more akin to Vulkan descriptor
-  // sets, which is how IREE HAL models resource bindings and mappings.
-  spvCrossOptions.argument_buffers = true;
+  // Explicitly set the argument buffer index for each SPIR-V resource variable.
+  auto descriptors = spvCrossCompiler.getBufferSetBindingPairs();
+  for (const auto& descriptor : descriptors) {
+    if (descriptor.set != 0) {
+      llvm_unreachable(
+          "multiple descriptor set unimplemented in SPIRVToMSLCompiler");
+    }
+
+    SPIRV_CROSS_NAMESPACE::MSLResourceBinding binding = {};
+    binding.stage = spv::ExecutionModelGLCompute;
+    binding.desc_set = descriptor.set;
+    binding.binding = descriptor.binding;
+    // We only interact with buffers in IREE.
+    binding.msl_buffer = descriptor.binding;
+
+    spvCrossCompiler.add_msl_resource_binding(binding);
+  }
+
+  auto spvCrossOptions = spvCrossCompiler.getCompilationOptions();
   spvCrossCompiler.set_msl_options(spvCrossOptions);
 
   std::string mslSource = spvCrossCompiler.compile();
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index f60833f..a3c9060 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -37,24 +37,34 @@
     "abs.mlir"
     "add.mlir"
     "broadcast.mlir"
+    "broadcast_add.mlir"
     "broadcast_in_dim.mlir"
-    "convert.mlir"
+    "clamp.mlir"
+    "compare.mlir"
     "constant.mlir"
+    "convert.mlir"
     "cosine.mlir"
+    "divide.mlir"
     "exponential.mlir"
+    "gather.mlir"
     "log.mlir"
     "log_plus_one.mlir"
     "maximum.mlir"
     "minimum.mlir"
     "multiply.mlir"
     "negate.mlir"
+    "remainder.mlir"
     "reshape.mlir"
     "rsqrt.mlir"
+    "select.mlir"
     "sine.mlir"
     "slice.mlir"
     "sqrt.mlir"
+    "subtract.mlir"
     "tanh.mlir"
+    "torch_index_select.mlir"
     "transpose.mlir"
+    "while.mlir"
   TARGET_BACKEND
     metal-spirv
   DRIVER
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 7d8a568..1f89648 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -19,24 +19,34 @@
     "abs.mlir"
     "add.mlir"
     "broadcast.mlir"
+    "broadcast_add.mlir"
     "broadcast_in_dim.mlir"
-    "convert.mlir"
+    "clamp.mlir"
+    "compare.mlir"
     "constant.mlir"
+    "convert.mlir"
     "cosine.mlir"
+    "divide.mlir"
     "exponential.mlir"
+    "gather.mlir"
     "log.mlir"
     "log_plus_one.mlir"
     "maximum.mlir"
     "minimum.mlir"
     "multiply.mlir"
     "negate.mlir"
+    "remainder.mlir"
     "reshape.mlir"
     "rsqrt.mlir"
+    "select.mlir"
     "sine.mlir"
     "slice.mlir"
     "sqrt.mlir"
+    "subtract.mlir"
     "tanh.mlir"
+    "torch_index_select.mlir"
     "transpose.mlir"
+    "while.mlir"
   TARGET_BACKEND
     metal-spirv
   DRIVER