Adding some HAL ops and builders needed for command buffer recording. PiperOrigin-RevId: 283774286
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp index 813ec7f..c7f19f0 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
@@ -79,10 +79,8 @@ auto sourceShape = IREE::HAL::getShapeDims(loadOp.source(), rewriter); auto *sourceOffset = rewriter.createOrFold<IREE::HAL::BufferViewComputeOffsetOp>( - loadOp.getLoc(), rewriter.getIntegerType(32), operands.source(), - sourceShape, operands.indices(), - APInt(32, IREE::HAL::getRoundedElementByteWidth( - sourceType.getElementType()))); + loadOp.getLoc(), operands.source(), sourceShape, operands.indices(), + IREE::HAL::getRoundedElementByteWidth(sourceType.getElementType())); rewriter.replaceOpWithNewOp<IREE::HAL::BufferLoadOp>( loadOp, converter.convertType(loadOp.result()->getType()), operands.source(), sourceOffset); @@ -107,10 +105,9 @@ auto targetShape = IREE::HAL::getShapeDims(storeOp.target(), rewriter); auto *targetOffset = rewriter.createOrFold<IREE::HAL::BufferViewComputeOffsetOp>( - storeOp.getLoc(), rewriter.getIntegerType(32), operands.target(), - targetShape, operands.indices(), - APInt(32, IREE::HAL::getRoundedElementByteWidth( - targetType.getElementType()))); + storeOp.getLoc(), operands.target(), targetShape, + operands.indices(), + IREE::HAL::getRoundedElementByteWidth(targetType.getElementType())); rewriter.create<IREE::HAL::BufferStoreOp>( storeOp.getLoc(), operands.value(), operands.target(), targetOffset); rewriter.replaceOp(storeOp, {operands.value()});
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 71d1b77..f7bdacb 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -133,6 +133,60 @@ } //===----------------------------------------------------------------------===// +// hal.ex.push_binding +//===----------------------------------------------------------------------===// + +static ParseResult parseExPushBindingOp(OpAsmParser &parser, + OperationState *result) { + OpAsmParser::OperandType commandBuffer; + OpAsmParser::OperandType buffer; + SmallVector<OpAsmParser::OperandType, 4> shape; + IntegerAttr ordinalAttr; + IntegerAttr elementSizeAttr; + if (failed(parser.parseOperand(commandBuffer)) || + failed(parser.parseComma()) || + failed(parser.resolveOperand( + commandBuffer, + RefPtrType::get(CommandBufferType::get(result->getContext())), + result->operands)) || + failed(parser.parseAttribute(ordinalAttr, + parser.getBuilder().getIntegerType(32), + "ordinal", result->attributes)) || + failed(parser.parseComma()) || failed(parser.parseOperand(buffer)) || + failed(parser.parseComma()) || + failed(parser.resolveOperand( + buffer, RefPtrType::get(BufferType::get(result->getContext())), + result->operands)) || + failed(parser.parseKeyword("shape")) || failed(parser.parseEqual()) || + failed(parser.parseOperandList(shape, OpAsmParser::Delimiter::Square)) || + failed(parser.resolveOperands(shape, getDimType(parser), + result->operands)) || + failed(parser.parseComma()) || + failed(parser.parseKeyword("element_size")) || + failed(parser.parseEqual()) || + failed(parser.parseAttribute(elementSizeAttr, + parser.getBuilder().getIntegerType(32), + "element_size", result->attributes)) || + failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) { + return failure(); + } + return success(); +} + +static void printExPushBindingOp(OpAsmPrinter &p, ExPushBindingOp op) { + p << op.getOperationName() << ' '; + p.printOperand(op.command_buffer()); + p << ", " << op.ordinal() << ", "; + p.printOperand(op.buffer()); + p << ", shape=["; + interleaveComma(op.shape(), p, [&](Value *value) { p.printOperand(value); }); + p << "], element_size=" << op.element_size(); + p.printOptionalAttrDictWithKeyword( + op.getAttrs(), + /*elidedAttrs=*/{"ordinal", "element_size"}); +} + +//===----------------------------------------------------------------------===// // hal.ex.executable_descriptor_set_layout //===----------------------------------------------------------------------===// @@ -1155,6 +1209,17 @@ // hal.buffer_view.compute_offset //===----------------------------------------------------------------------===// +void BufferViewComputeOffsetOp::build(Builder *builder, OperationState &state, + Value *buffer, ArrayRef<Value *> shape, + ArrayRef<Value *> indices, + int32_t elementSize) { + state.addOperands({buffer}); + state.addOperands(shape); + state.addOperands(indices); + state.addAttribute("element_size", builder->getI32IntegerAttr(elementSize)); + state.addTypes({builder->getIntegerType(32)}); +} + void BufferViewComputeOffsetOp::getAsmResultNames( function_ref<void(Value *, StringRef)> setNameFn) { setNameFn(offset(), "off"); @@ -1208,9 +1273,78 @@ } //===----------------------------------------------------------------------===// +// hal.buffer_view.compute_length +//===----------------------------------------------------------------------===// + +void BufferViewComputeLengthOp::build(Builder *builder, OperationState &state, + Value *buffer, ArrayRef<Value *> shape, + int32_t elementSize) { + state.addOperands({buffer}); + state.addOperands(shape); + state.addAttribute("element_size", builder->getI32IntegerAttr(elementSize)); + state.addTypes({builder->getIntegerType(32)}); +} + +void BufferViewComputeLengthOp::getAsmResultNames( + function_ref<void(Value *, StringRef)> setNameFn) { + setNameFn(length(), "len"); +} + +static ParseResult parseBufferViewComputeLengthOp(OpAsmParser &parser, + OperationState *result) { + OpAsmParser::OperandType buffer; + SmallVector<OpAsmParser::OperandType, 4> shape; + IntegerAttr elementSize; + if (failed(parser.parseOperand(buffer)) || + failed(parser.resolveOperand( + buffer, RefPtrType::get(BufferType::get(result->getContext())), + result->operands)) || + failed(parser.parseComma()) || failed(parser.parseKeyword("shape")) || + failed(parser.parseEqual()) || + failed(parser.parseOperandList(shape, OpAsmParser::Delimiter::Square)) || + failed(parser.resolveOperands(shape, getDimType(parser), + result->operands)) || + failed(parser.parseComma()) || + failed(parser.parseKeyword("element_size")) || + failed(parser.parseEqual()) || + failed(parser.parseAttribute(elementSize, + parser.getBuilder().getIntegerType(32), + "element_size", result->attributes)) || + failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) { + return failure(); + } + result->addTypes(getDeviceSizeType(parser)); + return success(); +} + +static void printBufferViewComputeLengthOp(OpAsmPrinter &p, + BufferViewComputeLengthOp op) { + p << op.getOperationName() << ' '; + p.printOperand(op.buffer()); + p << ", shape=["; + p.printOperands(op.shape()); + p << "], element_size=" << op.element_size(); + p.printOptionalAttrDictWithKeyword(op.getAttrs(), + /*elidedAttrs=*/{"element_size"}); +} + +//===----------------------------------------------------------------------===// // hal.buffer_view.compute_range //===----------------------------------------------------------------------===// +void BufferViewComputeRangeOp::build(Builder *builder, OperationState &state, + Value *buffer, ArrayRef<Value *> shape, + ArrayRef<Value *> indices, + ArrayRef<Value *> lengths, + int32_t elementSize) { + state.addOperands({buffer}); + state.addOperands(shape); + state.addOperands(indices); + state.addOperands(lengths); + state.addAttribute("element_size", builder->getI32IntegerAttr(elementSize)); + state.addTypes({builder->getIntegerType(32), builder->getIntegerType(32)}); +} + void BufferViewComputeRangeOp::getAsmResultNames( function_ref<void(Value *, StringRef)> setNameFn) { setNameFn(offset(), "off");
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td index d0e7784..d24f5a9 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -64,6 +64,17 @@ ]; } +// TODO(benvanik): remove and replace with descriptor sets. +def HAL_ExPushBindingOp : HAL_Op<"ex.push_binding"> { + let arguments = (ins + RefPtrOf<HAL_CommandBuffer>:$command_buffer, + I32Attr:$ordinal, + RefPtrOf<HAL_Buffer>:$buffer, + HAL_Shape:$shape, + I32Attr:$element_size + ); +} + def HAL_ExExecutableDescriptorSetLayoutOp : HAL_PureOp<"ex.executable_descriptor_set_layout", [ DeclareOpInterfaceMethods<OpAsmOpInterface>, @@ -498,6 +509,42 @@ let results = (outs HAL_DeviceSize:$offset ); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &state, Value *buffer, + ArrayRef<Value *> shape, ArrayRef<Value *> indices, int32_t elementSize + }]>, + ]; +} + +def HAL_BufferViewComputeLengthOp : HAL_PureOp<"buffer_view.compute_length", [ + DeclareOpInterfaceMethods<OpAsmOpInterface>, + SameVariadicOperandSize, + ]> { + let summary = [{buffer view shape to byte size computation operation}]; + let description = [{ + Computes a shaped buffer view length in bytes. + }]; + + let arguments = (ins + RefPtrOf<HAL_Buffer>:$buffer, + HAL_Shape:$shape, + I32Attr:$element_size + ); + + let results = (outs + HAL_DeviceSize:$length + ); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &state, Value *buffer, + ArrayRef<Value *> shape, int32_t elementSize + }]>, + ]; } def HAL_BufferViewComputeRangeOp : HAL_PureOp<"buffer_view.compute_range", [ @@ -521,6 +568,15 @@ HAL_DeviceSize:$offset, HAL_DeviceSize:$length ); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &state, Value *buffer, + ArrayRef<Value *> shape, ArrayRef<Value *> indices, + ArrayRef<Value *> lengths, int32_t elementSize + }]>, + ]; } def HAL_BufferViewSliceOp : HAL_PureOp<"buffer_view.slice", [
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir index d9e3ced..7593286 100644 --- a/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir
@@ -110,6 +110,16 @@ // ----- +// CHECK-LABEL: @buffer_view_compute_length +func @buffer_view_compute_length(%arg0 : !ireex.ref<!hal.buffer>) -> i32 { + %0:2 = "test_hal.shape"() : () -> (i32, i32) + // CHECK: %len = hal.buffer_view.compute_length %arg0, shape=[%0#0, %0#1], element_size=4 + %len = hal.buffer_view.compute_length %arg0, shape=[%0#0, %0#1], element_size=4 + return %len : i32 +} + +// ----- + // CHECK-LABEL: @buffer_view_compute_range func @buffer_view_compute_range(%arg0 : !ireex.ref<!hal.buffer>) -> (i32, i32) { %0:2 = "test_hal.shape"() : () -> (i32, i32)