Changing descriptor ops to use index instead of int32 and SSA values. (#4765)

diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
index 59eb6bc..af4c7c3 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -305,8 +305,10 @@
     auto byteLength = value->getByteLength();
     if (!byteLength) return failure();
 
-    bindings.push_back(std::make_tuple(bindingOrdinal++, value->getBuffer(),
-                                       zeroOffset, byteLength));
+    bindings.push_back(
+        std::make_tuple(rewriter.createOrFold<mlir::ConstantIndexOp>(
+                            dispatchOp.getLoc(), bindingOrdinal++),
+                        value->getBuffer(), zeroOffset, byteLength));
     return success();
   };
   for (auto it : llvm::enumerate(dispatchOp.operands())) {
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
index 2ff6847..abc7d9f 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -print-ir-after-all -split-input-file -iree-convert-to-hal -canonicalize %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-convert-to-hal -canonicalize %s | IreeFileCheck %s
 
 hal.executable @ex0 {
   hal.interface @interface {
@@ -26,7 +26,7 @@
   // CHECK-NEXT: hal.command_buffer.begin %[[CMD]]
   %0 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<128xf32>) -> tensor<128xf32> {
     //  CHECK-DAG: %[[EXE_LAYOUT:.+]] = hal.executable_layout.lookup
-    //      CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set=0, bindings=[0 = (%arg0, %c0, %c512), 1 = (%[[TMP_BUF]], %c0, %c512)]
+    //      CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set = %c0, bindings = [%c0 = (%arg0, %c0, %c512), %c1 = (%[[TMP_BUF]], %c0, %c512)]
     //      CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz
     //      CHECK: hal.command_buffer.execution_barrier
     %1 = flow.dispatch @ex0::@entry0[%arg1] (%arg2) : (tensor<128xf32>) -> tensor<128xf32>
@@ -137,7 +137,7 @@
       %arg6 = %c1024 : index,
       %arg7 = %c512 : index
     ) -> tensor<4x7x1024xf32> {
-    // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set=0, bindings=[0 = (%arg0, %c0, %c2688), 1 = (%buffer, %c0, %c114688)]
+    // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %c2688), %c1 = (%buffer, %c0, %c114688)]
     // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex::@tgt::@entry, workgroup_xyz
     %0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7] (%arg3) : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32>
     flow.return %0 : tensor<4x7x1024xf32>
@@ -180,7 +180,7 @@
     %4 = shapex.make_ranked_shape %arg5, %arg4 : (index, index) -> !shapex.ranked_shape<[?,?,1024]>
     %5 = shapex.tie_shape %arg3, %3 : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>
     // CHECK: hal.command_buffer.push_constants %[[CMD]], %executable_layout, offset = 0, values = [%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] : i32
-    // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set=0, bindings=[0 = (%arg0, %c0, %9), 1 = (%buffer, %c0, %12)]
+    // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %9), %c1 = (%buffer, %c0, %12)]
 
     // CHECK: #hal.device.match.id<"dylib*">(
     // CHECK-SAME: %[[CMD_INNER:.+]] = %cmd : !hal.command_buffer,
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
index e2c44da..eff6109 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -43,35 +43,20 @@
     SmallVector<Value, 8> callOperands = {
         newOperands.command_buffer(),
         newOperands.executable_layout(),
-        rewriter.create<mlir::ConstantOp>(
-            op.getLoc(), rewriter.getI32IntegerAttr(
-                             static_cast<int32_t>(op.setAttr().getInt()))),
+        newOperands.set(),
     };
     SmallVector<int16_t, 5> segmentSizes = {
         /*command_buffer=*/-1,
         /*executable_layout=*/-1,
         /*set=*/-1,
-        /*bindings_ordinals=*/
-        static_cast<int16_t>(op.bindings().size()),
-        /*bindings_buffers=*/
-        static_cast<int16_t>(op.bindings().size()),
-        /*bindings_offsets=*/
-        static_cast<int16_t>(op.bindings().size()),
-        /*bindings_lengths=*/
-        static_cast<int16_t>(op.bindings().size()),
+        /*bindings=*/
+        static_cast<int16_t>(newOperands.binding_ordinals().size()),
     };
-    for (auto bindingAttr : op.bindings()) {
-      callOperands.push_back(
-          rewriter.create<mlir::ConstantOp>(op.getLoc(), bindingAttr));
-    }
-    for (auto bindingBuffer : newOperands.binding_buffers()) {
-      callOperands.push_back(bindingBuffer);
-    }
-    for (auto bindingOffset : newOperands.binding_offsets()) {
-      callOperands.push_back(bindingOffset);
-    }
-    for (auto bindingLength : newOperands.binding_lengths()) {
-      callOperands.push_back(bindingLength);
+    for (size_t i = 0; i < newOperands.binding_ordinals().size(); ++i) {
+      callOperands.push_back(newOperands.binding_ordinals()[i]);
+      callOperands.push_back(newOperands.binding_buffers()[i]);
+      callOperands.push_back(newOperands.binding_offsets()[i]);
+      callOperands.push_back(newOperands.binding_lengths()[i]);
     }
 
     rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
index d0615eb..f46de67 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -58,11 +58,12 @@
     %arg0 : !hal.command_buffer,
     %arg1 : !hal.executable_layout,
     %arg2 : !hal.descriptor_set) {
+  %c0 = constant 0 : index
   %c100 = constant 100 : index
   // CHECK: vm.call.variadic @hal.command_buffer.bind_descriptor_set(%arg0, %arg1, %zero, %arg2, []) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable_layout>, i32, !vm.ref<!hal.descriptor_set>, i32 ...)
-  hal.command_buffer.bind_descriptor_set %arg0, %arg1, set=0, %arg2
-  // CHECK: vm.call.variadic @hal.command_buffer.bind_descriptor_set(%arg0, %arg1, %zero_0, %arg2, [%c100]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable_layout>, i32, !vm.ref<!hal.descriptor_set>, i32 ...)
-  hal.command_buffer.bind_descriptor_set %arg0, %arg1, set=0, %arg2, offsets=[%c100]
+  hal.command_buffer.bind_descriptor_set %arg0, %arg1, set = %c0, %arg2
+  // CHECK: vm.call.variadic @hal.command_buffer.bind_descriptor_set(%arg0, %arg1, %zero, %arg2, [%c100]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable_layout>, i32, !vm.ref<!hal.descriptor_set>, i32 ...)
+  hal.command_buffer.bind_descriptor_set %arg0, %arg1, set = %c0, %arg2, offsets = [%c100]
   return
 }
 
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 86f8521..0a1d25d 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -724,11 +724,18 @@
 
 void CommandBufferPushDescriptorSetOp::build(
     OpBuilder &builder, OperationState &state, Value commandBuffer,
-    Value executableLayout, uint32_t set,
+    Value executableLayout, int64_t set,
     ArrayRef<DescriptorSetBindingValue> bindings) {
-  state.addOperands({commandBuffer, executableLayout});
-  state.addAttribute("set", builder.getI32IntegerAttr(set));
-  SmallVector<int32_t, 4> bindingOrdinals;
+  build(builder, state, commandBuffer, executableLayout,
+        builder.createOrFold<ConstantIndexOp>(state.location, set), bindings);
+}
+
+void CommandBufferPushDescriptorSetOp::build(
+    OpBuilder &builder, OperationState &state, Value commandBuffer,
+    Value executableLayout, Value set,
+    ArrayRef<DescriptorSetBindingValue> bindings) {
+  state.addOperands({commandBuffer, executableLayout, set});
+  SmallVector<Value, 4> bindingOrdinals;
   SmallVector<Value, 4> bindingBuffers;
   SmallVector<Value, 4> bindingOffsets;
   SmallVector<Value, 4> bindingLengths;
@@ -738,7 +745,7 @@
     bindingOffsets.push_back(std::get<2>(binding));
     bindingLengths.push_back(std::get<3>(binding));
   }
-  state.addAttribute("bindings", builder.getI32ArrayAttr(bindingOrdinals));
+  state.addOperands(bindingOrdinals);
   state.addOperands(bindingBuffers);
   state.addOperands(bindingOffsets);
   state.addOperands(bindingLengths);
@@ -746,20 +753,19 @@
 
 static ParseResult parseDescriptorSetBindings(OpAsmParser &parser,
                                               OperationState *result) {
-  auto i32Type = parser.getBuilder().getIntegerType(32);
   auto indexType = parser.getBuilder().getIndexType();
-  SmallVector<Attribute, 4> bindingAttrs;
+  SmallVector<Value, 4> bindingOrdinals;
   SmallVector<Value, 4> bindingBuffers;
   SmallVector<Value, 4> bindingOffsets;
   SmallVector<Value, 4> bindingLengths;
   do {
-    IntegerAttr bindingAttr;
     NamedAttrList attrList;
+    OpAsmParser::OperandType ordinal;
     OpAsmParser::OperandType buffer;
     OpAsmParser::OperandType bufferOffset;
     OpAsmParser::OperandType bufferLength;
-    if (failed(
-            parser.parseAttribute(bindingAttr, i32Type, "binding", attrList)) ||
+    if (failed(parser.parseOperand(ordinal)) ||
+        failed(parser.resolveOperand(ordinal, indexType, bindingOrdinals)) ||
         failed(parser.parseEqual()) || failed(parser.parseLParen()) ||
         failed(parser.parseOperand(buffer)) ||
         failed(parser.resolveOperand(
@@ -775,10 +781,8 @@
         failed(parser.parseRParen())) {
       return failure();
     }
-    bindingAttrs.push_back(bindingAttr);
   } while (succeeded(parser.parseOptionalComma()));
-  result->addAttribute("bindings",
-                       parser.getBuilder().getArrayAttr(bindingAttrs));
+  result->addOperands(bindingOrdinals);
   result->addOperands(bindingBuffers);
   result->addOperands(bindingOffsets);
   result->addOperands(bindingLengths);
@@ -789,25 +793,24 @@
     OpAsmParser &parser, OperationState *result) {
   OpAsmParser::OperandType commandBuffer;
   OpAsmParser::OperandType executableLayout;
-  IntegerAttr setAttr;
+  OpAsmParser::OperandType set;
   auto operandsLoc = parser.getCurrentLocation();
   if (failed(parser.parseOperand(commandBuffer)) ||
       failed(parser.parseComma()) ||
       failed(parser.parseOperand(executableLayout)) ||
       failed(parser.parseComma()) || failed(parser.parseKeyword("set")) ||
-      failed(parser.parseEqual()) ||
-      failed(parser.parseAttribute(setAttr,
-                                   parser.getBuilder().getIntegerType(32),
-                                   "set", result->attributes)) ||
+      failed(parser.parseEqual()) || failed(parser.parseOperand(set)) ||
       failed(parser.parseComma()) ||
       failed(parser.resolveOperands(
           ArrayRef<OpAsmParser::OperandType>{
               commandBuffer,
               executableLayout,
+              set,
           },
           ArrayRef<Type>{
               CommandBufferType::get(result->getContext()),
               ExecutableLayoutType::get(result->getContext()),
+              IndexType::get(result->getContext()),
           },
           operandsLoc, result->operands)) ||
       failed(parser.parseKeyword("bindings")) || failed(parser.parseEqual()) ||
@@ -822,8 +825,8 @@
 
 template <typename T>
 static void printDescriptorSetBindings(OpAsmPrinter &p, T op) {
-  for (int i = 0; i < op.bindings().size(); ++i) {
-    p << op.bindings()[i].template cast<IntegerAttr>().getValue();
+  for (int i = 0; i < op.binding_ordinals().size(); ++i) {
+    p.printOperand(op.binding_ordinals()[i]);
     p << " = (";
     p.printOperand(op.binding_buffers()[i]);
     p << ", ";
@@ -831,7 +834,7 @@
     p << ", ";
     p.printOperand(op.binding_lengths()[i]);
     p << ")";
-    if (i < op.bindings().size() - 1) p << ", ";
+    if (i < op.binding_ordinals().size() - 1) p << ", ";
   }
 }
 
@@ -841,8 +844,9 @@
   p.printOperand(op.command_buffer());
   p << ", ";
   p.printOperand(op.executable_layout());
-  p << ", set=" << op.set();
-  p << ", bindings=[";
+  p << ", set = ";
+  p.printOperand(op.set());
+  p << ", bindings = [";
   printDescriptorSetBindings(p, op);
   p << "]";
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{
@@ -859,11 +863,20 @@
                                              OperationState &state,
                                              Value commandBuffer,
                                              Value executableLayout,
-                                             uint32_t set, Value descriptorSet,
+                                             int64_t set, Value descriptorSet,
                                              ValueRange dynamicOffsets) {
-  state.addOperands({commandBuffer, executableLayout, descriptorSet});
-  state.addAttribute("set",
-                     builder.getIntegerAttr(builder.getIntegerType(32), set));
+  build(builder, state, commandBuffer, executableLayout,
+        builder.createOrFold<ConstantIndexOp>(state.location, set),
+        descriptorSet, dynamicOffsets);
+}
+
+void CommandBufferBindDescriptorSetOp::build(OpBuilder &builder,
+                                             OperationState &state,
+                                             Value commandBuffer,
+                                             Value executableLayout, Value set,
+                                             Value descriptorSet,
+                                             ValueRange dynamicOffsets) {
+  state.addOperands({commandBuffer, executableLayout, set, descriptorSet});
   state.addOperands(dynamicOffsets);
 }
 
@@ -966,7 +979,7 @@
     OpBuilder &builder, OperationState &state, Value device, Value setLayout,
     ArrayRef<DescriptorSetBindingValue> bindings) {
   state.addOperands({device, setLayout});
-  SmallVector<int32_t, 4> bindingOrdinals;
+  SmallVector<Value, 4> bindingOrdinals;
   SmallVector<Value, 4> bindingBuffers;
   SmallVector<Value, 4> bindingOffsets;
   SmallVector<Value, 4> bindingLengths;
@@ -976,7 +989,7 @@
     bindingOffsets.push_back(std::get<2>(binding));
     bindingLengths.push_back(std::get<3>(binding));
   }
-  state.addAttribute("bindings", builder.getI32ArrayAttr(bindingOrdinals));
+  state.addOperands(bindingOrdinals);
   state.addOperands(bindingBuffers);
   state.addOperands(bindingOffsets);
   state.addOperands(bindingLengths);
@@ -1015,7 +1028,7 @@
   p.printOperand(op.device());
   p << ", ";
   p.printOperand(op.set_layout());
-  p << ", bindings=[";
+  p << ", bindings = [";
   printDescriptorSetBindings(p, op);
   p << "]";
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index bea5b3e..68f052c 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1139,10 +1139,10 @@
     Pushes an inline-defined descriptor set to the command buffer.
 
     ```mlir
-    hal.command_buffer.push_descriptor_set %cmd, %executable_layout, set = 0, bindings = [
-      0 = (%buffer_0, %buffer_offset_0, %buffer_length_0),
-      1 = (%buffer_1, %buffer_offset_1, %buffer_length_1),
-      2 = (%buffer_2, %buffer_offset_2, %buffer_length_2)
+    hal.command_buffer.push_descriptor_set %cmd, %executable_layout, set = %c0, bindings = [
+      %c0 = (%buffer_0, %buffer_offset_0, %buffer_length_0),
+      %c1 = (%buffer_1, %buffer_offset_1, %buffer_length_1),
+      %c2 = (%buffer_2, %buffer_offset_2, %buffer_length_2)
     ]
     ```
   }];
@@ -1150,8 +1150,8 @@
   let arguments = (ins
     HAL_CommandBuffer:$command_buffer,
     HAL_ExecutableLayout:$executable_layout,
-    I32Attr:$set,
-    I32ArrayAttr:$bindings,
+    Index:$set,
+    Variadic<Index>:$binding_ordinals,
     Variadic<HAL_Buffer>:$binding_buffers,
     Variadic<HAL_DeviceSize>:$binding_offsets,
     Variadic<HAL_DeviceSize>:$binding_lengths
@@ -1160,7 +1160,9 @@
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilderDAG<(ins "Value":$commandBuffer, "Value":$executableLayout,
-      "uint32_t":$set, "ArrayRef<DescriptorSetBindingValue>":$bindings)>,
+      "int64_t":$set, "ArrayRef<DescriptorSetBindingValue>":$bindings)>,
+    OpBuilderDAG<(ins "Value":$commandBuffer, "Value":$executableLayout,
+      "Value":$set, "ArrayRef<DescriptorSetBindingValue>":$bindings)>,
   ];
 
   let hasCanonicalizer = 1;
@@ -1177,7 +1179,7 @@
   let arguments = (ins
     HAL_CommandBuffer:$command_buffer,
     HAL_ExecutableLayout:$executable_layout,
-    I32Attr:$set,
+    Index:$set,
     HAL_DescriptorSet:$descriptor_set,
     Variadic<HAL_DeviceSize>:$dynamic_offsets
   );
@@ -1191,7 +1193,10 @@
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilderDAG<(ins "Value":$commandBuffer, "Value":$executableLayout,
-      "uint32_t":$set, "Value":$descriptorSet,
+      "int64_t":$set, "Value":$descriptorSet,
+      CArg<"ValueRange", "{}">:$dynamicOffsets)>,
+    OpBuilderDAG<(ins "Value":$commandBuffer, "Value":$executableLayout,
+      "Value":$set, "Value":$descriptorSet,
       CArg<"ValueRange", "{}">:$dynamicOffsets)>,
   ];
 }
@@ -1551,7 +1556,7 @@
   let arguments = (ins
     HAL_Device:$device,
     HAL_DescriptorSetLayout:$set_layout,
-    I32ArrayAttr:$bindings,
+    Variadic<Index>:$binding_ordinals,
     Variadic<HAL_Buffer>:$binding_buffers,
     Variadic<HAL_DeviceSize>:$binding_offsets,
     Variadic<HAL_DeviceSize>:$binding_lengths
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h
index 752ce04..1e0bed6 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -182,7 +182,7 @@
 
 // A tuple containing runtime values for a descriptor set binding:
 // <binding ordinal, hal.buffer, buffer byte offset, buffer byte length>
-using DescriptorSetBindingValue = std::tuple<uint32_t, Value, Value, Value>;
+using DescriptorSetBindingValue = std::tuple<Value, Value, Value, Value>;
 
 }  // namespace HAL
 }  // namespace IREE
diff --git a/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
index 6ede3e1..84f8ed0 100644
--- a/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
@@ -23,23 +23,25 @@
     %buffer : !hal.buffer
   ) {
   %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
   %c4 = constant 4 : index
   %c4096 = constant 4096 : index
   %c8000 = constant 8000 : index
   %c262140 = constant 262140 : index
   %c262144 = constant 262144 : index
   %subspan = hal.buffer.subspan %buffer, %c4096, %c262144 : !hal.buffer
-  // CHECK: hal.command_buffer.push_descriptor_set {{.+}}, bindings=[
-  hal.command_buffer.push_descriptor_set %cmd, %layout, set=0, bindings=[
+  // CHECK: hal.command_buffer.push_descriptor_set {{.+}}, set = %c0, bindings = [
+  hal.command_buffer.push_descriptor_set %cmd, %layout, set = %c0, bindings = [
     // 0 + 4096:
-    // CHECK-SAME: 0 = ([[BASE_BUFFER]], %c4096, %c8000)
-    0 = (%subspan, %c0, %c8000),
+    // CHECK-SAME: %c0 = ([[BASE_BUFFER]], %c4096, %c8000)
+    %c0 = (%subspan, %c0, %c8000),
     // 4096 + 4:
-    // CHECK-SAME: 1 = ([[BASE_BUFFER]], %c4100, %c262140)
-    1 = (%subspan, %c4, %c262140),
+    // CHECK-SAME: %c1 = ([[BASE_BUFFER]], %c4100, %c262140)
+    %c1 = (%subspan, %c4, %c262140),
     // No change:
-    // CHECK-SAME: 2 = ([[BASE_BUFFER]], %c4096, %c262144)
-    2 = (%buffer, %c4096, %c262144)
+    // CHECK-SAME: %c2 = ([[BASE_BUFFER]], %c4096, %c262144)
+    %c2 = (%buffer, %c4096, %c262144)
   ]
   return
 }
diff --git a/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
index 5291bf8..9730bfe 100644
--- a/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
@@ -91,10 +91,11 @@
   %0 = "test_hal.executable_layout"() : () -> !hal.executable_layout
   %1 = "test_hal.descriptor_set"() : () -> !hal.descriptor_set
   %2 = "test_hal.offset"() : () -> index
-  // CHECK: hal.command_buffer.bind_descriptor_set %arg0, %0, set = 0, %1
-  hal.command_buffer.bind_descriptor_set %arg0, %0, set = 0, %1
-  // CHECK-NEXT: hal.command_buffer.bind_descriptor_set %arg0, %0, set = 0, %1, offsets = [%2]
-  hal.command_buffer.bind_descriptor_set %arg0, %0, set = 0, %1, offsets = [%2]
+  %c0 = constant 0 : index
+  // CHECK: hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1
+  hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1
+  // CHECK-NEXT: hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1, offsets = [%2]
+  hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1, offsets = [%2]
   return
 }
 
diff --git a/iree/compiler/Dialect/HAL/hal.imports.mlir b/iree/compiler/Dialect/HAL/hal.imports.mlir
index 39de1ec..c75e756 100644
--- a/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -194,10 +194,8 @@
   %command_buffer : !vm.ref<!hal.command_buffer>,
   %executable_layout : !vm.ref<!hal.executable_layout>,
   %set : i32,
-  %bindings : i32 ...,
-  %binding_buffers : !vm.ref<!hal.buffer>...,
-  %binding_offsets : i32 ...,
-  %binding_lengths : i32 ...
+  // <binding, buffer, offset, length>
+  %bindings : tuple<i32, !vm.ref<!hal.buffer>, i32, i32>...
 )
 
 // Binds a descriptor set to the given set number.
@@ -237,10 +235,8 @@
 vm.import @descriptor_set.create(
   %device : !vm.ref<!hal.device>,
   %set_layout : !vm.ref<!hal.descriptor_set_layout>,
-  %bindings : i32 ...,
-  %binding_buffers : !vm.ref<!hal.buffer>...,
-  %binding_offsets : i32 ...,
-  %binding_lengths : i32 ...
+  // <binding, buffer, offset, length>
+  %bindings : tuple<i32, !vm.ref<!hal.buffer>, i32, i32>...
 ) -> !vm.ref<!hal.descriptor_set>
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc
index 9994903..9f3c164 100644
--- a/iree/modules/hal/hal_module.cc
+++ b/iree/modules/hal/hal_module.cc
@@ -467,19 +467,19 @@
   Status CommandBufferPushDescriptorSet(
       const vm::ref<iree_hal_command_buffer_t>& command_buffer,
       const vm::ref<iree_hal_executable_layout_t>& executable_layout,
-      uint32_t set, absl::Span<const uint32_t> binding_ordinals,
-      absl::Span<const vm::ref<iree_hal_buffer_t>> binding_buffers,
-      absl::Span<const int32_t> binding_offsets,
-      absl::Span<const int32_t> binding_lengths) {
+      uint32_t set,
+      absl::Span<const std::tuple<uint32_t, vm::ref<iree_hal_buffer_t>, int32_t,
+                                  int32_t>>
+          bindings) {
     ExDeferRelease(executable_layout);
     absl::InlinedVector<iree_hal_descriptor_set_binding_t, 16> binding_structs(
-        binding_ordinals.size());
-    for (int i = 0; i < binding_ordinals.size(); ++i) {
+        bindings.size());
+    for (int i = 0; i < bindings.size(); ++i) {
       binding_structs[i] = {
-          binding_ordinals[i], binding_buffers[i].get(),
-          static_cast<iree_device_size_t>(binding_offsets[i]),
-          static_cast<iree_device_size_t>(binding_lengths[i])};
-      ExDeferRelease(binding_buffers[i]);
+          std::get<0>(bindings[i]), std::get<1>(bindings[i]).get(),
+          static_cast<iree_device_size_t>(std::get<2>(bindings[i])),
+          static_cast<iree_device_size_t>(std::get<3>(bindings[i]))};
+      ExDeferRelease(std::get<1>(bindings[i]));
     }
     return iree_hal_command_buffer_push_descriptor_set(
         command_buffer.get(), executable_layout.get(), set,
@@ -534,18 +534,17 @@
   StatusOr<vm::ref<iree_hal_descriptor_set_t>> DescriptorSetCreate(
       const vm::ref<iree_hal_device_t>& device,
       const vm::ref<iree_hal_descriptor_set_layout_t>& set_layout,
-      absl::Span<const uint32_t> binding_ordinals,
-      absl::Span<const vm::ref<iree_hal_buffer_t>> binding_buffers,
-      absl::Span<const uint32_t> binding_offsets,
-      absl::Span<const uint32_t> binding_lengths) {
+      absl::Span<const std::tuple<uint32_t, vm::ref<iree_hal_buffer_t>, int32_t,
+                                  int32_t>>
+          bindings) {
     absl::InlinedVector<iree_hal_descriptor_set_binding_t, 4> binding_structs(
-        binding_ordinals.size());
-    for (int i = 0; i < binding_ordinals.size(); ++i) {
+        bindings.size());
+    for (int i = 0; i < bindings.size(); ++i) {
       binding_structs[i] = {
-          binding_ordinals[i],                                   // binding
-          binding_buffers[i].get(),                              // buffer
-          static_cast<iree_device_size_t>(binding_offsets[i]),   // offset
-          static_cast<iree_device_size_t>(binding_lengths[i])};  // length
+          /*ordinal=*/std::get<0>(bindings[i]),
+          /*buffer=*/std::get<1>(bindings[i]).get(),
+          /*offset=*/static_cast<iree_device_size_t>(std::get<2>(bindings[i])),
+          /*length=*/static_cast<iree_device_size_t>(std::get<3>(bindings[i]))};
     }
     vm::ref<iree_hal_descriptor_set_t> descriptor_set;
     IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_create(