Adding some HAL ops and builders needed for command buffer recording.

PiperOrigin-RevId: 283774286
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
index 813ec7f..c7f19f0 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
@@ -79,10 +79,8 @@
     auto sourceShape = IREE::HAL::getShapeDims(loadOp.source(), rewriter);
     auto *sourceOffset =
         rewriter.createOrFold<IREE::HAL::BufferViewComputeOffsetOp>(
-            loadOp.getLoc(), rewriter.getIntegerType(32), operands.source(),
-            sourceShape, operands.indices(),
-            APInt(32, IREE::HAL::getRoundedElementByteWidth(
-                          sourceType.getElementType())));
+            loadOp.getLoc(), operands.source(), sourceShape, operands.indices(),
+            IREE::HAL::getRoundedElementByteWidth(sourceType.getElementType()));
     rewriter.replaceOpWithNewOp<IREE::HAL::BufferLoadOp>(
         loadOp, converter.convertType(loadOp.result()->getType()),
         operands.source(), sourceOffset);
@@ -107,10 +105,9 @@
     auto targetShape = IREE::HAL::getShapeDims(storeOp.target(), rewriter);
     auto *targetOffset =
         rewriter.createOrFold<IREE::HAL::BufferViewComputeOffsetOp>(
-            storeOp.getLoc(), rewriter.getIntegerType(32), operands.target(),
-            targetShape, operands.indices(),
-            APInt(32, IREE::HAL::getRoundedElementByteWidth(
-                          targetType.getElementType())));
+            storeOp.getLoc(), operands.target(), targetShape,
+            operands.indices(),
+            IREE::HAL::getRoundedElementByteWidth(targetType.getElementType()));
     rewriter.create<IREE::HAL::BufferStoreOp>(
         storeOp.getLoc(), operands.value(), operands.target(), targetOffset);
     rewriter.replaceOp(storeOp, {operands.value()});
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 71d1b77..f7bdacb 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -133,6 +133,60 @@
 }
 
 //===----------------------------------------------------------------------===//
+// hal.ex.push_binding
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseExPushBindingOp(OpAsmParser &parser,
+                                        OperationState *result) {
+  OpAsmParser::OperandType commandBuffer;
+  OpAsmParser::OperandType buffer;
+  SmallVector<OpAsmParser::OperandType, 4> shape;
+  IntegerAttr ordinalAttr;
+  IntegerAttr elementSizeAttr;
+  if (failed(parser.parseOperand(commandBuffer)) ||
+      failed(parser.parseComma()) ||
+      failed(parser.resolveOperand(
+          commandBuffer,
+          RefPtrType::get(CommandBufferType::get(result->getContext())),
+          result->operands)) ||
+      failed(parser.parseAttribute(ordinalAttr,
+                                   parser.getBuilder().getIntegerType(32),
+                                   "ordinal", result->attributes)) ||
+      failed(parser.parseComma()) || failed(parser.parseOperand(buffer)) ||
+      failed(parser.parseComma()) ||
+      failed(parser.resolveOperand(
+          buffer, RefPtrType::get(BufferType::get(result->getContext())),
+          result->operands)) ||
+      failed(parser.parseKeyword("shape")) || failed(parser.parseEqual()) ||
+      failed(parser.parseOperandList(shape, OpAsmParser::Delimiter::Square)) ||
+      failed(parser.resolveOperands(shape, getDimType(parser),
+                                    result->operands)) ||
+      failed(parser.parseComma()) ||
+      failed(parser.parseKeyword("element_size")) ||
+      failed(parser.parseEqual()) ||
+      failed(parser.parseAttribute(elementSizeAttr,
+                                   parser.getBuilder().getIntegerType(32),
+                                   "element_size", result->attributes)) ||
+      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
+    return failure();
+  }
+  return success();
+}
+
+static void printExPushBindingOp(OpAsmPrinter &p, ExPushBindingOp op) {
+  p << op.getOperationName() << ' ';
+  p.printOperand(op.command_buffer());
+  p << ", " << op.ordinal() << ", ";
+  p.printOperand(op.buffer());
+  p << ", shape=[";
+  interleaveComma(op.shape(), p, [&](Value *value) { p.printOperand(value); });
+  p << "], element_size=" << op.element_size();
+  p.printOptionalAttrDictWithKeyword(
+      op.getAttrs(),
+      /*elidedAttrs=*/{"ordinal", "element_size"});
+}
+
+//===----------------------------------------------------------------------===//
 // hal.ex.executable_descriptor_set_layout
 //===----------------------------------------------------------------------===//
 
@@ -1155,6 +1209,17 @@
 // hal.buffer_view.compute_offset
 //===----------------------------------------------------------------------===//
 
+void BufferViewComputeOffsetOp::build(Builder *builder, OperationState &state,
+                                      Value *buffer, ArrayRef<Value *> shape,
+                                      ArrayRef<Value *> indices,
+                                      int32_t elementSize) {
+  state.addOperands({buffer});
+  state.addOperands(shape);
+  state.addOperands(indices);
+  state.addAttribute("element_size", builder->getI32IntegerAttr(elementSize));
+  state.addTypes({builder->getIntegerType(32)});
+}
+
 void BufferViewComputeOffsetOp::getAsmResultNames(
     function_ref<void(Value *, StringRef)> setNameFn) {
   setNameFn(offset(), "off");
@@ -1208,9 +1273,78 @@
 }
 
 //===----------------------------------------------------------------------===//
+// hal.buffer_view.compute_length
+//===----------------------------------------------------------------------===//
+
+void BufferViewComputeLengthOp::build(Builder *builder, OperationState &state,
+                                      Value *buffer, ArrayRef<Value *> shape,
+                                      int32_t elementSize) {
+  state.addOperands({buffer});
+  state.addOperands(shape);
+  state.addAttribute("element_size", builder->getI32IntegerAttr(elementSize));
+  state.addTypes({builder->getIntegerType(32)});
+}
+
+void BufferViewComputeLengthOp::getAsmResultNames(
+    function_ref<void(Value *, StringRef)> setNameFn) {
+  setNameFn(length(), "len");
+}
+
+static ParseResult parseBufferViewComputeLengthOp(OpAsmParser &parser,
+                                                  OperationState *result) {
+  OpAsmParser::OperandType buffer;
+  SmallVector<OpAsmParser::OperandType, 4> shape;
+  IntegerAttr elementSize;
+  if (failed(parser.parseOperand(buffer)) ||
+      failed(parser.resolveOperand(
+          buffer, RefPtrType::get(BufferType::get(result->getContext())),
+          result->operands)) ||
+      failed(parser.parseComma()) || failed(parser.parseKeyword("shape")) ||
+      failed(parser.parseEqual()) ||
+      failed(parser.parseOperandList(shape, OpAsmParser::Delimiter::Square)) ||
+      failed(parser.resolveOperands(shape, getDimType(parser),
+                                    result->operands)) ||
+      failed(parser.parseComma()) ||
+      failed(parser.parseKeyword("element_size")) ||
+      failed(parser.parseEqual()) ||
+      failed(parser.parseAttribute(elementSize,
+                                   parser.getBuilder().getIntegerType(32),
+                                   "element_size", result->attributes)) ||
+      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
+    return failure();
+  }
+  result->addTypes(getDeviceSizeType(parser));
+  return success();
+}
+
+static void printBufferViewComputeLengthOp(OpAsmPrinter &p,
+                                           BufferViewComputeLengthOp op) {
+  p << op.getOperationName() << ' ';
+  p.printOperand(op.buffer());
+  p << ", shape=[";
+  p.printOperands(op.shape());
+  p << "], element_size=" << op.element_size();
+  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
+                                     /*elidedAttrs=*/{"element_size"});
+}
+
+//===----------------------------------------------------------------------===//
 // hal.buffer_view.compute_range
 //===----------------------------------------------------------------------===//
 
+void BufferViewComputeRangeOp::build(Builder *builder, OperationState &state,
+                                     Value *buffer, ArrayRef<Value *> shape,
+                                     ArrayRef<Value *> indices,
+                                     ArrayRef<Value *> lengths,
+                                     int32_t elementSize) {
+  state.addOperands({buffer});
+  state.addOperands(shape);
+  state.addOperands(indices);
+  state.addOperands(lengths);
+  state.addAttribute("element_size", builder->getI32IntegerAttr(elementSize));
+  state.addTypes({builder->getIntegerType(32), builder->getIntegerType(32)});
+}
+
 void BufferViewComputeRangeOp::getAsmResultNames(
     function_ref<void(Value *, StringRef)> setNameFn) {
   setNameFn(offset(), "off");
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index d0e7784..d24f5a9 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -64,6 +64,17 @@
   ];
 }
 
+// TODO(benvanik): remove and replace with descriptor sets.
+def HAL_ExPushBindingOp : HAL_Op<"ex.push_binding"> {
+  let arguments = (ins
+    RefPtrOf<HAL_CommandBuffer>:$command_buffer,
+    I32Attr:$ordinal,
+    RefPtrOf<HAL_Buffer>:$buffer,
+    HAL_Shape:$shape,
+    I32Attr:$element_size
+  );
+}
+
 def HAL_ExExecutableDescriptorSetLayoutOp :
   HAL_PureOp<"ex.executable_descriptor_set_layout", [
     DeclareOpInterfaceMethods<OpAsmOpInterface>,
@@ -498,6 +509,42 @@
   let results = (outs
     HAL_DeviceSize:$offset
   );
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<[{
+      Builder *builder, OperationState &state, Value *buffer,
+      ArrayRef<Value *> shape, ArrayRef<Value *> indices, int32_t elementSize
+    }]>,
+  ];
+}
+
+def HAL_BufferViewComputeLengthOp : HAL_PureOp<"buffer_view.compute_length", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface>,
+    SameVariadicOperandSize,
+  ]> {
+  let summary = [{buffer view shape to byte size computation operation}];
+  let description = [{
+    Computes a shaped buffer view length in bytes.
+  }];
+
+  let arguments = (ins
+    RefPtrOf<HAL_Buffer>:$buffer,
+    HAL_Shape:$shape,
+    I32Attr:$element_size
+  );
+
+  let results = (outs
+    HAL_DeviceSize:$length
+  );
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<[{
+      Builder *builder, OperationState &state, Value *buffer,
+      ArrayRef<Value *> shape, int32_t elementSize
+    }]>,
+  ];
 }
 
 def HAL_BufferViewComputeRangeOp : HAL_PureOp<"buffer_view.compute_range", [
@@ -521,6 +568,15 @@
     HAL_DeviceSize:$offset,
     HAL_DeviceSize:$length
   );
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<[{
+      Builder *builder, OperationState &state, Value *buffer,
+      ArrayRef<Value *> shape, ArrayRef<Value *> indices,
+      ArrayRef<Value *> lengths, int32_t elementSize
+    }]>,
+  ];
 }
 
 def HAL_BufferViewSliceOp : HAL_PureOp<"buffer_view.slice", [
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir
index d9e3ced..7593286 100644
--- a/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir
@@ -110,6 +110,16 @@
 
 // -----
 
+// CHECK-LABEL: @buffer_view_compute_length
+func @buffer_view_compute_length(%arg0 : !ireex.ref<!hal.buffer>) -> i32 {
+  %0:2 = "test_hal.shape"() : () -> (i32, i32)
+  // CHECK: %len = hal.buffer_view.compute_length %arg0, shape=[%0#0, %0#1], element_size=4
+  %len = hal.buffer_view.compute_length %arg0, shape=[%0#0, %0#1], element_size=4
+  return %len : i32
+}
+
+// -----
+
 // CHECK-LABEL: @buffer_view_compute_range
 func @buffer_view_compute_range(%arg0 : !ireex.ref<!hal.buffer>) -> (i32, i32) {
   %0:2 = "test_hal.shape"() : () -> (i32, i32)