| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| |
| #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/SymbolTable.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Interfaces/FunctionImplementation.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace HAL { |
| |
| //===----------------------------------------------------------------------===// |
| // custom<DescriptorType>($descriptor_type) |
| //===----------------------------------------------------------------------===// |
| |
// Custom parser/printer that, unlike the autogenerated attribute
// parser/printer, omits the wrapping `<` and `>`.
| static ParseResult parseDescriptorType(OpAsmParser &parser, |
| DescriptorTypeAttr &dtAttr) { |
| StringRef enumKeyword; |
| if (failed(parser.parseKeyword(&enumKeyword))) |
| return failure(); |
| std::optional<DescriptorType> maybeEnum = |
| symbolizeDescriptorType(enumKeyword); |
| if (!maybeEnum) |
| return failure(); |
| dtAttr = DescriptorTypeAttr::get(parser.getContext(), *maybeEnum); |
| return success(); |
| } |
| |
| static void printDescriptorType(OpAsmPrinter &p, Operation *, |
| DescriptorTypeAttr dtAttr) { |
| p << stringifyDescriptorType(dtAttr.getValue()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<DescriptorSetBindings>($binding_ordinals, |
| // $binding_buffers, |
| // type($binding_buffers), |
| // $binding_offsets, |
| // $binding_lengths) |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseDescriptorSetBindings( |
| OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &ordinals, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &buffers, |
| SmallVectorImpl<Type> &bufferTypes, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferOffsets, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferLengths) { |
| do { |
| OpAsmParser::UnresolvedOperand ordinal; |
| OpAsmParser::UnresolvedOperand buffer; |
| Type bufferType; |
| OpAsmParser::UnresolvedOperand bufferOffset; |
| OpAsmParser::UnresolvedOperand bufferLength; |
| if (failed(parser.parseOperand(ordinal)) || failed(parser.parseEqual()) || |
| failed(parser.parseLParen()) || failed(parser.parseOperand(buffer)) || |
| failed(parser.parseColonType(bufferType)) || |
| failed(parser.parseRParen()) || failed(parser.parseLSquare()) || |
| failed(parser.parseOperand(bufferOffset)) || |
| failed(parser.parseComma()) || |
| failed(parser.parseOperand(bufferLength)) || |
| failed(parser.parseRSquare())) { |
| return failure(); |
| } |
| ordinals.push_back(ordinal); |
| buffers.push_back(buffer); |
| bufferTypes.push_back(bufferType); |
| bufferOffsets.push_back(bufferOffset); |
| bufferLengths.push_back(bufferLength); |
| } while (succeeded(parser.parseOptionalComma())); |
| return success(); |
| } |
| |
| static void printDescriptorSetBindings(OpAsmPrinter &p, Operation *op, |
| ValueRange ordinals, ValueRange buffers, |
| TypeRange bufferTypes, |
| ValueRange bufferOffsets, |
| ValueRange bufferLengths) { |
| llvm::interleaveComma(llvm::zip_equal(ordinals, buffers, bufferTypes, |
| bufferOffsets, bufferLengths), |
| p, |
| [&](std::tuple<Value, Value, Type, Value, Value> it) { |
| p.printNewline(); |
| p << " "; |
| p.printOperand(std::get<0>(it)); |
| p << " = ("; |
| p.printOperand(std::get<1>(it)); |
| p << " : "; |
| p.printType(std::get<2>(it)); |
| p << ")["; |
| p.printOperand(std::get<3>(it)); |
| p << ", "; |
| p.printOperand(std::get<4>(it)); |
| p << "]"; |
| }); |
| p.printNewline(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TargetConditionRegion>($body) |
| //===----------------------------------------------------------------------===// |
| |
| static FunctionType getTargetConditionRegionType(MLIRContext *context) { |
| return FunctionType::get(context, |
| { |
| IREE::HAL::DeviceType::get(context), |
| }, |
| { |
| IntegerType::get(context, 1), |
| }); |
| } |
| |
| static LogicalResult verifyTargetConditionRegion(Operation *op, |
| Region ®ion) { |
| // Ignore if empty. |
| if (region.empty()) |
| return success(); |
| |
| // Verify region takes a !hal.device. |
| if (region.getNumArguments() != 1 || |
| !isa<IREE::HAL::DeviceType>(region.getArgumentTypes().front())) { |
| return op->emitOpError() |
| << "target condition region must take a !hal.device"; |
| } |
| |
| // Verify i1 return. |
| for (auto returnOp : region.getOps<IREE::HAL::ReturnOp>()) { |
| if (returnOp.getNumOperands() != 1) { |
| return returnOp.emitOpError() |
| << "target condition region must return a single i1 result"; |
| } |
| for (auto returnType : returnOp.getOperandTypes()) { |
| if (!returnType.isInteger(1)) { |
| return returnOp.emitOpError() |
| << "target condition region must return a single i1 result"; |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| static ParseResult parseTargetConditionRegion(OpAsmParser &parser, |
| Region &body) { |
| SmallVector<OpAsmParser::Argument> args; |
| if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren, |
| /*allowType=*/true, |
| /*allowAttrs=*/true))) { |
| return failure(); |
| } |
| |
| SmallVector<Type> returnTypes; |
| if (failed(parser.parseArrowTypeList(returnTypes))) { |
| return failure(); |
| } |
| if (returnTypes.size() != 1 || |
| !llvm::all_of(returnTypes, [](Type type) { return type.isInteger(1); })) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "target condition region must return one i1"; |
| } |
| |
| return parser.parseRegion(body, args, /*enableNameShadowing=*/false); |
| } |
| |
| static void printTargetConditionRegion(OpAsmPrinter &p, Operation *op, |
| Region &body) { |
| if (body.empty()) |
| return; |
| p << "("; |
| llvm::interleaveComma(body.getArguments(), p, |
| [&](BlockArgument arg) { p.printRegionArgument(arg); }); |
| p << ")"; |
| p.printArrowTypeList(TypeRange{IntegerType::get(body.getContext(), 1)}); |
| p << " "; |
| p.printRegion(body, /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<WorkgroupCountRegion>($body) |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser, |
| Region &body) { |
| SmallVector<OpAsmParser::Argument> args; |
| if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren, |
| /*allowType=*/true, |
| /*allowAttrs=*/true))) { |
| return failure(); |
| } |
| |
| // Return types must be 3 dimensions (workgroup count XYZ). |
| SmallVector<Type> returnTypes; |
| if (failed(parser.parseArrowTypeList(returnTypes))) { |
| return failure(); |
| } |
| if (returnTypes.size() != 3 || |
| !llvm::all_of(returnTypes, [](Type type) { return type.isIndex(); })) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "workgroup count region must return the XYZ dimension counts"; |
| } |
| |
| // Parse region contents. |
| if (failed(parser.parseRegion(body, args, /*enableNameShadowing=*/false))) { |
| return failure(); |
| } |
| |
| // Verify the return types match. |
| for (auto returnOp : body.getOps<IREE::HAL::ReturnOp>()) { |
| for (auto [resultType, returnType] : |
| llvm::zip_equal(returnTypes, returnOp.getOperandTypes())) { |
| if (resultType != returnType) { |
| return returnOp.emitOpError() |
| << "operands do not match expected region return types"; |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, |
| Region &body) { |
| if (body.empty()) |
| return; |
| p << "("; |
| llvm::interleaveComma(body.getArguments(), p, |
| [&](BlockArgument arg) { p.printRegionArgument(arg); }); |
| p << ")"; |
| Type indexType = IndexType::get(body.getContext()); |
| p.printArrowTypeList(TypeRange{indexType, indexType, indexType}); |
| p << " "; |
| p.printRegion(body, /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.ex.* |
| //===----------------------------------------------------------------------===// |
| |
| void ExSharedDeviceOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "device"); |
| } |
| |
| void ExFileFromMemoryOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "memory_file"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.return |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReturnOp::verify() { |
| ReturnOp op = *this; |
| |
| auto parentFuncOp = dyn_cast_or_null<FunctionOpInterface>(op->getParentOp()); |
| if (parentFuncOp) { |
| auto expectedTypes = parentFuncOp.getResultTypes(); |
| if (op.getNumOperands() != expectedTypes.size()) { |
| return op.emitOpError() << "return must have the same number of operands " |
| "as the parent result signature (have " |
| << op.getNumOperands() << ", expected " |
| << expectedTypes.size() << ")"; |
| } |
| for (auto &&[index, values] : |
| llvm::enumerate(llvm::zip_equal(op.getOperands(), expectedTypes))) { |
| auto [operand, expectedType] = values; |
| if (operand.getType() != expectedType) { |
| return op.emitOpError() |
| << "parent expected result " << index << " to be " |
| << expectedType << " but returning " << operand.getType(); |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.tensor.import/export |
| //===----------------------------------------------------------------------===// |
| |
| void TensorImportOp::build(OpBuilder &builder, OperationState &result, |
| Type resultType, Value source, |
| TypeAttr targetEncoding, StringAttr name) { |
| build(builder, result, resultType, source, targetEncoding, |
| /*waitFence=*/Value{}, name); |
| } |
| |
| void TensorImportOp::build(OpBuilder &builder, OperationState &result, |
| Type resultType, Value source, |
| TypeAttr targetEncoding, Value waitFence, |
| StringAttr name) { |
| auto shapedType = llvm::cast<ShapedType>(resultType); |
| assert((source.getType().isa<IREE::HAL::BufferViewType>() || |
| shapedType.hasStaticShape()) && |
| "can only use this constructor for buffer views when shape " |
| "information is required"); |
| SmallVector<Value> dynamicDims; |
| for (int64_t i = 0; i < shapedType.getRank(); ++i) { |
| if (!shapedType.isDynamicDim(i)) |
| continue; |
| dynamicDims.push_back(builder.createOrFold<IREE::HAL::BufferViewDimOp>( |
| result.location, builder.getIndexType(), source, |
| builder.getIndexAttr(i))); |
| } |
| build(builder, result, resultType, source, targetEncoding, dynamicDims, |
| waitFence, name); |
| } |
| |
| Value TensorImportOp::getTiedResult(unsigned resultIndex) { |
| return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); |
| } |
| |
| ::std::optional<unsigned> |
| TensorImportOp::getTiedResultOperandIndex(unsigned resultIndex) { |
| return {0}; // source |
| } |
| |
| SmallVector<int64_t> TensorImportOp::getTiedResultOperandIndices() { |
| return {0}; // source |
| } |
| |
| static LogicalResult verifyTypeStorageCompatibility(Operation *op, |
| Type encodingType, |
| Type storageType) { |
| if (encodingType == storageType) |
| return success(); |
| auto encodingShapedType = llvm::dyn_cast<ShapedType>(encodingType); |
| auto storageShapedType = llvm::dyn_cast<ShapedType>(storageType); |
| if (!encodingShapedType || !storageShapedType) |
| return success(); |
| |
| if (IREE::Util::getRoundedElementByteWidth( |
| encodingShapedType.getElementType()) != |
| IREE::Util::getRoundedElementByteWidth( |
| storageShapedType.getElementType())) { |
| // TODO(benvanik): more sophisticated logic here. There are a lot of valid |
| // cases that are difficult to account for here statically; for example, |
| // packing 8xi1 into 1xi8 or complex<f32> into 2xf32. We could try to guess |
| // the element count (at least the static part of it) and ensure the scaling |
| // matches but that wouldn't account for user variance. Really with this op |
| // we are letting the _user_ control the bitcasting and type reflection and |
| // purposefully don't want to mess with it (users should be able to put |
| // custom types here, etc). |
| // |
| // NOTE: we round to bytes first as the base type (such as i1) may not be |
| // representable in an external form. |
| // return op->emitOpError() << "encoding and storage types must be " |
| // "bitcastable; adjusted encoding bit width " |
| // "of " |
| // << encodingShapedType.getElementTypeBitWidth() |
| // << " != adjusted storage bit width of " |
| // << storageShapedType.getElementTypeBitWidth(); |
| } |
| |
| if (encodingShapedType.getNumDynamicDims() != |
| storageShapedType.getNumDynamicDims()) { |
| // NOTE: we implicitly require that the dimensions are equivalent but |
| // dont actually care about their order. For example, tensor<?x1xf32> is |
| // compatible with tensor<?xf32>. |
| return op->emitOpError() |
| << "encoding and storage types must have the same " |
| "dynamic dimension values; encoding shape " |
| << encodingShapedType << " incompatible with storage shape " |
| << storageShapedType; |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult TensorImportOp::verify() { |
| TensorImportOp op = *this; |
| auto targetType = llvm::cast<TensorType>(op.getTarget().getType()); |
| if (targetType.getNumDynamicDims() != op.getTargetDims().size()) { |
| return op->emitOpError() << "number of target_dims must match number of " |
| "dynamic dims in target type"; |
| } |
| return verifyTypeStorageCompatibility(op, op.getTargetEncoding(), targetType); |
| } |
| |
| void TensorExportOp::build(OpBuilder &builder, OperationState &result, |
| Type resultType, Value source, |
| TypeAttr sourceEncoding, StringAttr name) { |
| auto dynamicDims = |
| IREE::Util::buildDynamicDimsForValue(result.location, source, builder); |
| build(builder, result, resultType, source, sourceEncoding, dynamicDims, |
| /*target_storage=*/nullptr, name); |
| } |
| |
| Value TensorExportOp::getTiedResult(unsigned resultIndex) { |
| return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); |
| } |
| |
| ::std::optional<unsigned> |
| TensorExportOp::getTiedResultOperandIndex(unsigned resultIndex) { |
| return {0}; // source |
| } |
| |
| SmallVector<int64_t> TensorExportOp::getTiedResultOperandIndices() { |
| return {0}; // source |
| } |
| |
| LogicalResult TensorExportOp::verify() { |
| TensorExportOp op = *this; |
| auto sourceType = llvm::cast<TensorType>(op.getSource().getType()); |
| if (sourceType.getNumDynamicDims() != op.getSourceDims().size()) { |
| return op->emitOpError() << "number of source_dims must match number of " |
| "dynamic dims in source type"; |
| } |
| return verifyTypeStorageCompatibility(op, op.getSourceEncoding(), |
| op.getSource().getType()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.tensor.barrier |
| //===----------------------------------------------------------------------===// |
| |
| void TensorBarrierOp::build(OpBuilder &builder, OperationState &result, |
| ValueRange sources, Value signalFence) { |
| auto resultTypes = llvm::map_to_vector( |
| sources, [](Value source) { return source.getType(); }); |
| build(builder, result, resultTypes, sources, signalFence); |
| } |
| |
| Value TensorBarrierOp::getTiedResult(unsigned resultIndex) { |
| return IREE::Util::TiedOpInterface::findTiedBaseValue( |
| getSources()[resultIndex]); |
| } |
| |
| ::std::optional<unsigned> |
| TensorBarrierOp::getTiedResultOperandIndex(unsigned resultIndex) { |
| return {resultIndex}; // sources[i] |
| } |
| |
| SmallVector<int64_t> TensorBarrierOp::getTiedResultOperandIndices() { |
| size_t numSources = getSources().size(); |
| return llvm::to_vector(llvm::seq<int64_t>(0, numSources)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.dispatch.extern |
| //===----------------------------------------------------------------------===// |
| |
| void DispatchExternOp::build(OpBuilder &builder, OperationState &state, |
| ValueRange workload, TypeRange resultTypes, |
| ValueRange resultDims, ValueRange arguments, |
| ValueRange argumentDims, |
| ArrayRef<int64_t> tiedOperands, |
| ArrayRef<NamedAttribute> attributes) { |
| state.addTypes(resultTypes); |
| state.addOperands(workload); |
| state.addOperands(arguments); |
| state.addOperands(argumentDims); |
| state.addOperands(resultDims); |
| state.addAttributes(attributes); |
| state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName()); |
| state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(), |
| builder.getIndexArrayAttr(tiedOperands)); |
| state.attributes.erase(getOperandSegmentSizeAttr()); |
| state.addAttribute(getOperandSegmentSizeAttr(), |
| builder.getDenseI32ArrayAttr({ |
| static_cast<int32_t>(workload.size()), |
| static_cast<int32_t>(arguments.size()), |
| static_cast<int32_t>(argumentDims.size()), |
| static_cast<int32_t>(resultDims.size()), |
| })); |
| |
| // NOTE: workgroup count region is empty; callers are expected to populate it. |
| state.addRegion(); |
| } |
| |
| // Verifies that |dynamicDims| contains the appropriate number of dims for all |
| // of the dynamic dimensions in |values|. |
| static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values, |
| ValueRange dynamicDims) { |
| unsigned requiredCount = 0; |
| for (auto value : values) { |
| if (auto shapedType = llvm::dyn_cast<ShapedType>(value.getType())) { |
| requiredCount += shapedType.getNumDynamicDims(); |
| } |
| } |
| if (dynamicDims.size() != requiredCount) { |
| return op->emitOpError() |
| << "value set has " << requiredCount |
| << " dynamic dimensions but only " << dynamicDims.size() |
| << " dimension values are attached"; |
| } |
| return success(); |
| } |
| |
| static LogicalResult |
| verifyWorkgroupCountRegion(Operation *op, ValueRange workload, Region ®ion) { |
| // Verify the workload operands match the expected capture args. |
| auto regionArguments = |
| llvm::make_filter_range(region.getArgumentTypes(), [](Type type) { |
| return !type.isa<IREE::HAL::DeviceType>(); |
| }); |
| if (workload.size() != llvm::range_size(regionArguments)) { |
| return op->emitOpError() |
| << "workload operands and workgroup count args mismatch (" |
| << workload.size() << " vs " << llvm::range_size(regionArguments) |
| << ")"; |
| } |
| for (auto [index, values] : |
| llvm::enumerate(llvm::zip_equal(workload, regionArguments))) { |
| auto [workloadValue, capturedType] = values; |
| if (workloadValue.getType() != capturedType) { |
| return op->emitOpError() |
| << "workload value " << index << " type mismatch; operand is " |
| << workloadValue.getType() << " but region captures " |
| << capturedType; |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult DispatchExternOp::verify() { |
| Operation *op = getOperation(); |
| |
| if (failed(verifyOpDynamicDims(getOperation(), getArguments(), |
| getArgumentDims())) || |
| failed( |
| verifyOpDynamicDims(getOperation(), getResults(), getResultDims()))) { |
| return failure(); |
| } |
| |
| auto verifyIOType = [&](Type type) -> LogicalResult { |
| if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) { |
| if (shapedType.getElementType().isIndex()) { |
| return op->emitOpError() << "I/O type " << type |
| << " is invalid: index types must not cross " |
| "the dispatch boundary"; |
| } |
| } |
| return success(); |
| }; |
| for (auto type : getOperandTypes()) { |
| if (failed(verifyIOType(type))) |
| return failure(); |
| } |
| for (auto type : getResultTypes()) { |
| if (failed(verifyIOType(type))) |
| return failure(); |
| } |
| |
| if (failed( |
| verifyWorkgroupCountRegion(op, getWorkload(), getWorkgroupCount()))) { |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| std::pair<unsigned, unsigned> |
| DispatchExternOp::getTiedOperandsIndexAndLength() { |
| return getODSOperandIndexAndLength(1); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.allocator.allocate |
| //===----------------------------------------------------------------------===// |
| |
| void AllocatorAllocateOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer"); |
| } |
| |
| Value AllocatorAllocateOp::getOperandSize(unsigned idx) { return {}; } |
| |
| Value AllocatorAllocateOp::getResultSize(unsigned idx) { |
| return getResultSize(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.allocator.import |
| //===----------------------------------------------------------------------===// |
| |
| void AllocatorImportOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getDidImport(), "did_import"); |
| setNameFn(getResult(), "mapped"); |
| } |
| |
| Value AllocatorImportOp::getOperandSize(unsigned idx) { return {}; } |
| |
| Value AllocatorImportOp::getResultSize(unsigned idx) { return getLength(); } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.buffer.subspan |
| //===----------------------------------------------------------------------===// |
| |
| void BufferSubspanOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer"); |
| } |
| |
| Value BufferSubspanOp::getOperandSize(unsigned idx) { return getLength(); } |
| |
| Value BufferSubspanOp::getResultSize(unsigned idx) { return getLength(); } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.buffer.byte_length |
| //===----------------------------------------------------------------------===// |
| |
| void BufferLengthOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "len"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.buffer_view.create |
| //===----------------------------------------------------------------------===// |
| |
| void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state, |
| Value sourceBuffer, Value sourceOffset, |
| Value sourceLength, int32_t elementType, |
| int32_t encodingType, ValueRange shape) { |
| build(builder, state, sourceBuffer, sourceOffset, sourceLength, |
| builder.createOrFold<arith::ConstantIntOp>(state.location, elementType, |
| 32), |
| builder.createOrFold<arith::ConstantIntOp>(state.location, encodingType, |
| 32), |
| shape); |
| } |
| |
| void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state, |
| Value sourceBuffer, Value sourceOffset, |
| Value sourceLength, Value elementType, |
| Value encodingType, ValueRange shape) { |
| state.addOperands( |
| {sourceBuffer, sourceOffset, sourceLength, elementType, encodingType}); |
| state.addOperands(shape); |
| state.addTypes({BufferViewType::get(builder.getContext())}); |
| } |
| |
| void BufferViewCreateOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "view"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.buffer_view.buffer |
| //===----------------------------------------------------------------------===// |
| |
| void BufferViewBufferOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.channel.create |
| //===----------------------------------------------------------------------===// |
| |
| void ChannelCreateOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "channel"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.channel.split |
| //===----------------------------------------------------------------------===// |
| |
| void ChannelSplitOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "channel"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.channel.rank_and_count |
| //===----------------------------------------------------------------------===// |
| |
| void ChannelRankAndCountOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getRank(), "ccl_rank"); |
| setNameFn(getCount(), "ccl_count"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.command_buffer.create |
| //===----------------------------------------------------------------------===// |
| |
| void CommandBufferCreateOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "cmd"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.command_buffer.push_descriptor_set |
| //===----------------------------------------------------------------------===// |
| |
| void CommandBufferPushDescriptorSetOp::build( |
| OpBuilder &builder, OperationState &state, Value commandBuffer, |
| Value pipelineLayout, int64_t set, |
| ArrayRef<DescriptorSetBindingValue> bindings) { |
| build(builder, state, commandBuffer, pipelineLayout, |
| builder.createOrFold<arith::ConstantIndexOp>(state.location, set), |
| bindings); |
| } |
| |
| void CommandBufferPushDescriptorSetOp::build( |
| OpBuilder &builder, OperationState &state, Value commandBuffer, |
| Value pipelineLayout, Value set, |
| ArrayRef<DescriptorSetBindingValue> bindings) { |
| state.addOperands({commandBuffer, pipelineLayout, set}); |
| SmallVector<Value> bindingOrdinals; |
| SmallVector<Value> bindingBuffers; |
| SmallVector<Value> bindingOffsets; |
| SmallVector<Value> bindingLengths; |
| for (auto binding : bindings) { |
| bindingOrdinals.push_back(binding.ordinal); |
| bindingBuffers.push_back(binding.buffer); |
| bindingOffsets.push_back(binding.byteOffset); |
| bindingLengths.push_back(binding.byteLength); |
| } |
| state.addOperands(bindingOrdinals); |
| state.addOperands(bindingBuffers); |
| state.addOperands(bindingOffsets); |
| state.addOperands(bindingLengths); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.descriptor_set_layout.create |
| //===----------------------------------------------------------------------===// |
| |
| void DescriptorSetLayoutCreateOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "descriptor_set_layout"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.descriptor_set_layout.lookup |
| //===----------------------------------------------------------------------===// |
| |
| void DescriptorSetLayoutLookupOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "descriptor_set_layout"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.device.allocator |
| //===----------------------------------------------------------------------===// |
| |
| void DeviceAllocatorOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "allocator"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.device.query |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult DeviceQueryOp::verify() { |
| DeviceQueryOp op = *this; |
| if (op.getDefaultValue().has_value()) { |
| if (op.getDefaultValue()->getType() != op.getValue().getType()) { |
| return op.emitOpError() |
| << "type mismatch between result and default value"; |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.device.queue.* |
| //===----------------------------------------------------------------------===// |
| |
| void DeviceQueueAllocaOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "transient_buffer"); |
| } |
| |
| Value DeviceQueueAllocaOp::getOperandSize(unsigned idx) { return {}; } |
| |
| Value DeviceQueueAllocaOp::getResultSize(unsigned idx) { |
| return getResultSize(); |
| } |
| |
| static LogicalResult verifyDeviceQueueFences(Operation *queueOp, |
| Value waitFence, |
| Value signalFence) { |
| if (waitFence == signalFence && |
| !isa<IREE::Util::NullOp>(waitFence.getDefiningOp())) { |
| return queueOp->emitOpError() << "device queue operations cannot wait and " |
| "signal on the same fence."; |
| } |
| return success(); |
| } |
| |
| LogicalResult DeviceQueueAllocaOp::verify() { |
| return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); |
| } |
| |
| LogicalResult DeviceQueueDeallocaOp::verify() { |
| return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); |
| } |
| |
| LogicalResult DeviceQueueReadOp::verify() { |
| return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); |
| } |
| |
| LogicalResult DeviceQueueWriteOp::verify() { |
| return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); |
| } |
| |
| LogicalResult DeviceQueueExecuteOp::verify() { |
| return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable.source |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ExecutableSourceOp::verify() { |
| ExecutableSourceOp op = *this; |
| |
| auto conditionOps = getOps<IREE::HAL::ExecutableConditionOp>(); |
| if (llvm::range_size(conditionOps) > 1) |
| return op.emitOpError() |
| << "only one condition op is allowed in an executable"; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable |
| //===----------------------------------------------------------------------===// |
| |
| void ExecutableOp::build(OpBuilder &builder, OperationState &state, |
| StringRef name) { |
| ensureTerminator(*state.addRegion(), builder, state.location); |
| state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(name)); |
| } |
| |
| LogicalResult ExecutableOp::verify() { |
| // TODO(benvanik): check export name conflicts. |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable.export |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult ExecutableExportOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| StringAttr visibilityAttr; |
| if (failed(parseSymbolVisibility(parser, visibilityAttr))) { |
| return failure(); |
| } |
| |
| StringAttr nameAttr; |
| IREE::HAL::PipelineLayoutAttr layoutAttr; |
| if (failed(parser.parseSymbolName(nameAttr, |
| mlir::SymbolTable::getSymbolAttrName(), |
| result.attributes))) { |
| return failure(); |
| } |
| if (succeeded(parser.parseOptionalKeyword("ordinal"))) { |
| IntegerAttr ordinalAttr; |
| if (failed(parser.parseLParen()) || |
| failed(parser.parseAttribute(ordinalAttr, |
| parser.getBuilder().getIndexType())) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| result.addAttribute("ordinal", ordinalAttr); |
| } |
| if (failed(parser.parseKeyword("layout")) || failed(parser.parseLParen()) || |
| failed(parser.parseAttribute(layoutAttr)) || |
| failed(parser.parseRParen()) || |
| failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) { |
| return failure(); |
| } |
| result.addAttribute("layout", layoutAttr); |
| |
| std::unique_ptr<Region> region; |
| SmallVector<OpAsmParser::Argument> regionOperands; |
| // A missing optional region is materialized as an empty region. |
| (void)parser.parseOptionalRegion(region, regionOperands); |
| result.addRegion(std::move(region)); |
| |
| return success(); |
| } |
| |
| void ExecutableExportOp::print(OpAsmPrinter &p) { |
| Operation *op = getOperation(); |
| p << ' '; |
| printSymbolVisibility(p, op, op->getAttrOfType<StringAttr>("sym_visibility")); |
| p << ' '; |
| p.printSymbolName(getSymName()); |
| if (getOrdinalAttr()) { |
| p << " ordinal("; |
| p.printAttributeWithoutType(getOrdinalAttr()); |
| p << ")"; |
| } |
| p << " layout("; |
| p.printAttribute(getLayout()); |
| p << ")"; |
| p.printOptionalAttrDictWithKeyword( |
| op->getAttrs(), |
| /*elidedAttrs=*/{"sym_name", "layout", "ordinal"}); |
| if (getWorkgroupCount().empty()) |
| return; |
| p << " "; |
| p.printRegion(getWorkgroupCount()); |
| } |
| |
| LogicalResult ExecutableExportOp::verify() { |
| ExecutableExportOp op = *this; |
| Block *body = getWorkgroupCountBody(); |
| // When there is no body, nothing to verify. |
| if (!body) |
| return success(); |
| |
| if (!llvm::hasSingleElement(getWorkgroupCount())) { |
| return op.emitOpError() << "expected a single region block"; |
| } |
| bool validArguments = true; |
| if (body->getNumArguments() == 0) { |
| // Need at least a !hal.device. |
| validArguments = false; |
| } else if (!llvm::isa<IREE::HAL::DeviceType>( |
| body->getArgument(0).getType())) { |
| // !hal.device must come first. |
| validArguments = false; |
| } else { |
| // All remaining arguments need to be of type index (today). |
| for (BlockArgument &blockArg : body->getArguments().drop_front(1)) { |
| if (!llvm::isa<IndexType>(blockArg.getType())) { |
| validArguments = false; |
| break; |
| } |
| } |
| } |
| if (!validArguments) { |
| return op.emitOpError( |
| "expected workgroup_count to take (%device: !hal.device, " |
| "%workload_0: index, %workload_1: index, ..."); |
| } |
| // Check that the last statement in the block is `hal.return` operation. |
| // TODO(ravishankarm): The SingleBlockImplicitTerminator<"HAL::ReturnOp"> |
| // should generate this check, but it doesnt. |
| auto returnOp = dyn_cast<ReturnOp>(body->getTerminator()); |
| if (!returnOp || returnOp.getOperands().size() != getNumWorkgroupDims()) { |
| return op.emitOpError("expected operation to yield ") |
| << getNumWorkgroupDims() << " values"; |
| } |
| return success(); |
| } |
| |
| // Calculates the workgroup count (x, y, z) given the total N-dimensional |
| // |workload| and specific |workgroupSize|. |
// Calculates the workgroup count (x, y, z) given the total N-dimensional
// |workload| and specific |workgroupSize|.
//
// Rank <= 3: each workload dimension is ceil-divided by the matching workgroup
// size, and missing trailing dimensions use a workload of 1.
// Rank > 3: the workload is flattened into one product and then distributed
// over x, y, z (see TODO(#4140) below).
// All divisions are unsigned (arith.divui). |workgroupSize| entries are
// presumably expected to be non-zero; TODO confirm with callers.
static std::array<Value, 3>
calculateWorkloadWorkgroupCount(Location loc, ValueRange workload,
                                const std::array<Value, 3> &workgroupSize,
                                OpBuilder &builder) {
  std::array<Value, 3> result;

  auto constantOne = builder.createOrFold<arith::ConstantIndexOp>(loc, 1);
  if (workload.size() <= 3) {
    // 1-D to 3-D are easy (pad 2 to 0 dimensions) and divide by workgroup
    // size.
    for (int i = 0; i < 3; ++i) {
      // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
      // Dimensions past workload.size() are padded with a workload of 1.
      Value workloadI = i < workload.size() ? workload[i] : constantOne;
      workloadI = builder.createOrFold<arith::SubIOp>(
          loc,
          builder.createOrFold<arith::AddIOp>(loc, workloadI, workgroupSize[i]),
          constantOne);
      result[i] = builder.createOrFold<arith::DivUIOp>(loc, workloadI,
                                                       workgroupSize[i]);
    }
  } else {
    // TODO(#4140): remapping of N-D to 3-D: this is not how you do this!
    Value flatWorkload = constantOne;
    for (auto workloadI : workload) {
      flatWorkload =
          builder.createOrFold<arith::MulIOp>(loc, flatWorkload, workloadI);
    }
    for (int i = 0; i < 3; ++i) {
      // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
      auto rounded = builder.createOrFold<arith::SubIOp>(
          loc,
          builder.createOrFold<arith::AddIOp>(loc, flatWorkload,
                                              workgroupSize[i]),
          constantOne);
      auto workgroupCountI =
          builder.createOrFold<arith::DivUIOp>(loc, rounded, workgroupSize[i]);
      result[i] = workgroupCountI;

      // Multiply back out and subtract from invocations.
      // NOTE(review): this multiplies the count by `rounded`, not by
      // `workgroupSize[i]`. With a workgroup size of 1, `rounded` equals the
      // remaining workload W, so the new remainder is W - W*W. That is
      // negative for W >= 2 and is then fed into the unsigned (divui)
      // rounding for the next dimension. Confirm the intended remapping
      // before relying on this path (see TODO(#4140) above).
      flatWorkload = builder.createOrFold<arith::SubIOp>(
          loc, flatWorkload,
          builder.createOrFold<arith::MulIOp>(loc, workgroupCountI, rounded));
    }
  }

  return result;
}
| |
| static std::array<Value, 3> |
| calculateWorkgroupCountFromRegion(Location loc, Block *body, Value device, |
| ValueRange workload, OpBuilder &builder) { |
| // TODO(benvanik): replace with region inlining util. |
| IRMapping bvm; |
| bvm.map(body->getArgument(0), device); |
| // For now use the number of args to minimum of number of args used by |
| // the body, and number of workload entries. When there is a more explicit |
| // propagation of number of workload entries to the `hal.executable.variant` |
| // this will be the same by construction. |
| unsigned numArgs = |
| std::min<unsigned>(body->getNumArguments() - 1, workload.size()); |
| for (unsigned argNum : llvm::seq<unsigned>(0, numArgs)) { |
| bvm.map(body->getArgument(/*device*/ 1 + argNum), workload[argNum]); |
| } |
| for (Operation &op : body->without_terminator()) { |
| builder.clone(op, bvm); |
| } |
| auto returnOp = cast<IREE::HAL::ReturnOp>(body->getTerminator()); |
| assert(returnOp.getNumOperands() == 3 && "must return xyz"); |
| return { |
| bvm.lookup(returnOp.getOperands()[0]), |
| bvm.lookup(returnOp.getOperands()[1]), |
| bvm.lookup(returnOp.getOperands()[2]), |
| }; |
| } |
| |
| // Calculates the workgroup count (x, y, z) for dispatching to the entry point. |
| // The provided N-dimensional |workload| is the total number of invocations |
| // required as calculated by the generic workload logic (basically, number of |
| // output elements in tensors). |
| std::array<Value, 3> ExecutableExportOp::calculateWorkgroupCount( |
| Location loc, Value device, ValueRange workload, OpBuilder &builder) { |
| Block *body = getWorkgroupCountBody(); |
| if (body) { |
| return calculateWorkgroupCountFromRegion(loc, body, device, workload, |
| builder); |
| } |
| auto workgroupSize = calculateWorkgroupSize(loc, device, workload, builder); |
| return calculateWorkloadWorkgroupCount(loc, workload, workgroupSize, builder); |
| } |
| |
| // Calculates the workgroup size (x, y, z). These are the dimension numbers |
| // for a single workgroup. |
| std::array<Value, 3> ExecutableExportOp::calculateWorkgroupSize( |
| Location loc, Value device, ValueRange workload, OpBuilder &builder) { |
| // When no workgroup size is specified we just assume [1,1,1]. |
| // This yields a workgroup count that models the extents of the workload. |
| return { |
| builder.createOrFold<arith::ConstantIndexOp>(loc, 1), |
| builder.createOrFold<arith::ConstantIndexOp>(loc, 1), |
| builder.createOrFold<arith::ConstantIndexOp>(loc, 1), |
| }; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable.variant |
| //===----------------------------------------------------------------------===// |
| |
| void ExecutableVariantOp::build(OpBuilder &builder, OperationState &state, |
| StringRef symName, |
| IREE::HAL::ExecutableTargetAttr target) { |
| ensureTerminator(*state.addRegion(), builder, state.location); |
| state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(symName)); |
| state.addAttribute("target", target); |
| } |
| |
| LogicalResult ExecutableVariantOp::verify() { |
| ExecutableVariantOp op = *this; |
| |
| auto conditionOps = getOps<IREE::HAL::ExecutableConditionOp>(); |
| if (llvm::range_size(conditionOps) > 1) |
| return op.emitOpError() << "only one condition op is allowed in a variant"; |
| |
| return success(); |
| } |
| |
| DenseMap<Attribute, int> ExecutableVariantOp::gatherConstantOrdinals() { |
| DenseMap<Attribute, int> map; |
| for (auto blockOp : getConstantBlockOps()) { |
| int baseCount = map.size(); |
| for (auto [i, keyAttr] : llvm::enumerate(blockOp.getKeys())) { |
| map.try_emplace(keyAttr, baseCount + i); |
| } |
| } |
| return map; |
| } |
| |
| Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) { |
| // Base case dependent on target information. |
| auto matchAttr = |
| cast<IREE::HAL::MatchAttrInterface>(getTarget().getMatchExpression()); |
| auto selected = matchAttr.buildConditionExpression(getLoc(), device, builder); |
| |
| // Factor in variant condition region, if any. |
| auto conditionOp = getConditionOp(); |
| if (conditionOp) { |
| auto regionOp = builder.create<scf::ExecuteRegionOp>(conditionOp.getLoc(), |
| builder.getI1Type()); |
| |
| IRMapping mapper; |
| mapper.map(conditionOp.getRegion().getArgument(0), device); |
| conditionOp.getRegion().cloneInto(®ionOp.getRegion(), mapper); |
| |
| for (auto returnOp : |
| llvm::make_early_inc_range(regionOp.getOps<IREE::HAL::ReturnOp>())) { |
| OpBuilder(returnOp).create<scf::YieldOp>(returnOp.getLoc(), |
| returnOp.getOperands()); |
| returnOp.erase(); |
| } |
| |
| selected = builder.create<arith::AndIOp>(getLoc(), selected, |
| regionOp.getResult(0)); |
| } |
| |
| return selected; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable.condition |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ExecutableConditionOp::verify() { |
| ExecutableConditionOp op = *this; |
| return verifyTargetConditionRegion(op, op.getBody()); |
| } |
| |
| void ExecutableConditionOp::build(OpBuilder &builder, OperationState &result, |
| ArrayRef<NamedAttribute> attrs) { |
| result.addAttribute( |
| "function_type", |
| TypeAttr::get(getTargetConditionRegionType(builder.getContext()))); |
| result.addRegion(); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| ParseResult ExecutableConditionOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| if (parseTargetConditionRegion(parser, *result.addRegion())) |
| return failure(); |
| result.addAttribute( |
| "function_type", |
| TypeAttr::get(getTargetConditionRegionType(parser.getContext()))); |
| if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
| return failure(); |
| return success(); |
| } |
| |
| void ExecutableConditionOp::print(OpAsmPrinter &p) { |
| Operation *op = getOperation(); |
| printTargetConditionRegion(p, op, getBody()); |
| p.printOptionalAttrDictWithKeyword(op->getAttrs(), |
| /*elidedAttrs=*/{"function_type"}); |
| } |
| |
| Block *ExecutableConditionOp::addEntryBlock() { |
| assert(empty() && "function already has an entry block"); |
| auto *entry = new Block(); |
| auto argTypes = getArgumentTypes(); |
| SmallVector<Location> argLocs(argTypes.size(), getLoc()); |
| entry->addArguments(argTypes, argLocs); |
| push_back(entry); |
| return entry; |
| } |
| |
| Block *ExecutableConditionOp::addBlock() { |
| assert(!empty() && "function should at least have an entry block"); |
| push_back(new Block()); |
| return &back(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable.constant.block |
| //===----------------------------------------------------------------------===// |
| |
// Parses the custom form:
//   (%arg: type, ...) -> (type, ...) as <keys> [attributes {...}] { body }
// <keys> is a bare string when there is exactly one result, and a
// parenthesized comma-separated list of strings otherwise.
ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser,
                                             OperationState &result) {
  auto &builder = parser.getBuilder();

  // Parse the function signature (arguments with optional attributes and the
  // result list). Variadic signatures are not accepted.
  SmallVector<OpAsmParser::Argument> entryArgs;
  bool isVariadic = false;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  if (mlir::function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)) {
    return failure();
  }
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto fnType = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(fnType));

  // Parse the keys used for each yielded constant value.
  // With exactly one result the key must be a bare string (no parens).
  // Otherwise a comma-separated list of strings is parsed with
  // Delimiter::OptionalParen.
  // The 1:1 key/result count is not checked here; verify() enforces it.
  SmallVector<Attribute> keyAttrs;
  if (failed(parser.parseKeyword("as")))
    return failure();
  if (resultTypes.size() == 1) {
    std::string key;
    if (failed(parser.parseString(&key)))
      return failure();
    keyAttrs.push_back(builder.getStringAttr(key));
  } else {
    if (failed(parser.parseCommaSeparatedList(
            AsmParser::Delimiter::OptionalParen,
            [&]() {
              std::string key;
              if (failed(parser.parseString(&key)))
                return failure();
              keyAttrs.push_back(builder.getStringAttr(key));
              return success();
            },
            "containing a 1:1 list of keys per yielded value"))) {
      return failure();
    }
  }
  result.addAttribute("keys", builder.getArrayAttr(keyAttrs));

  // If function attributes are present, parse them.
  NamedAttrList parsedAttributes;
  if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) {
    return failure();
  }
  result.attributes.append(parsedAttributes);

  // Attach the per-argument and per-result attribute dictionaries parsed from
  // the signature as the arg_attrs/res_attrs attributes.
  assert(resultAttrs.size() == resultTypes.size());
  mlir::function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  // Parse the (required) body region, using the signature's arguments as the
  // entry block arguments. An empty body is rejected below.
  auto *body = result.addRegion();
  SMLoc loc = parser.getCurrentLocation();
  if (failed(parser.parseRegion(*body, entryArgs,
                                /*enableNameShadowing=*/false))) {
    return failure();
  }
  // Function body was parsed; make sure it's not empty.
  if (body->empty()) {
    return parser.emitError(loc, "expected non-empty function body");
  }

  return success();
}
| |
| void ExecutableConstantBlockOp::print(OpAsmPrinter &p) { |
| Operation *op = getOperation(); |
| ArrayRef<Type> argTypes = getArgumentTypes(); |
| ArrayRef<Type> resultTypes = getResultTypes(); |
| mlir::function_interface_impl::printFunctionSignature( |
| p, cast<FunctionOpInterface>(op), argTypes, /*isVariadic=*/false, |
| resultTypes); |
| p << " as "; |
| if (resultTypes.size() != 1) |
| p << '('; |
| llvm::interleaveComma(getKeys().getValue(), p, |
| [&](Attribute attr) { p << attr; }); |
| if (resultTypes.size() != 1) |
| p << ')'; |
| mlir::function_interface_impl::printFunctionAttributes( |
| p, op, {getFunctionTypeAttrName(), getKeysAttrName()}); |
| p << " "; |
| p.printRegion(getBody(), /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| } |
| |
| LogicalResult ExecutableConstantBlockOp::verify() { |
| ExecutableConstantBlockOp op = *this; |
| |
| // Verify the function takes either nothing or a device. |
| auto argTypes = op.getArgumentTypes(); |
| if (!argTypes.empty() && |
| (argTypes.size() > 1 || !llvm::isa<IREE::HAL::DeviceType>(argTypes[0]))) { |
| return op->emitOpError() |
| << "initializer must take a !hal.device or nothing"; |
| } |
| |
| // Verify the return types are all i32 (today). |
| for (auto resultType : llvm::enumerate(op.getResultTypes())) { |
| if (!resultType.value().isInteger(32)) { |
| return op->emitOpError() |
| << "initializer must return only i32 values (result " |
| << resultType.index() << " is " << resultType.value() << ")"; |
| } |
| } |
| |
| // Verify there's a key for every result. |
| if (op.getNumResults() != op.getKeys().size()) { |
| return op->emitOpError() << "must have one key for every result"; |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable.binary |
| //===----------------------------------------------------------------------===// |
| |
| void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state, |
| StringRef symName, StringRef format, |
| std::vector<uint8_t> data) { |
| state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(symName)); |
| state.addAttribute("format", builder.getStringAttr(format)); |
| state.addAttribute("data", |
| DenseIntElementsAttr::get( |
| VectorType::get({static_cast<int64_t>(data.size())}, |
| builder.getIntegerType(8)), |
| data)); |
| } |
| |
| void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state, |
| StringRef symName, StringAttr format, |
| DenseIntElementsAttr data) { |
| state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(symName)); |
| state.addAttribute("format", format); |
| state.addAttribute("data", data); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable.create |
| //===----------------------------------------------------------------------===// |
| |
| void ExecutableCreateOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), StringRef("exe")); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.executable.lookup |
| //===----------------------------------------------------------------------===// |
| |
| void ExecutableLookupOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "exe"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.interface.binding.subspan |
| //===----------------------------------------------------------------------===// |
| void InterfaceBindingSubspanOp::build( |
| OpBuilder &builder, OperationState &result, Type resultType, APInt set, |
| APInt binding, IREE::HAL::DescriptorType descriptor_type, Value byte_offset, |
| ValueRange dynamic_dims, IntegerAttr alignment, |
| std::optional<DescriptorFlags> flags) { |
| IREE::HAL::DescriptorFlagsAttr descriptorAttr; |
| if (flags.has_value()) { |
| descriptorAttr = IREE::HAL::DescriptorFlagsAttr::get(builder.getContext(), |
| flags.value()); |
| } |
| build(builder, result, resultType, set, binding, descriptor_type, byte_offset, |
| dynamic_dims, alignment, descriptorAttr); |
| } |
| |
| LogicalResult InterfaceBindingSubspanOp::verify() { |
| InterfaceBindingSubspanOp op = *this; |
| if (ShapedType shapedType = llvm::dyn_cast<ShapedType>(op.getType())) { |
| if (shapedType.getNumDynamicDims() != op.getDynamicDims().size()) { |
| return op.emitOpError("result type ") |
| << op.getType() << " has " << shapedType.getNumDynamicDims() |
| << " dynamic dimensions but " << op.getDynamicDims().size() |
| << " associated dimension SSA values"; |
| } |
| } |
| |
| return success(); |
| } |
| |
| llvm::MaybeAlign InterfaceBindingSubspanOp::getBaseAlignment() { |
| if (auto baseAlignmentInt = getAlignment()) { |
| return llvm::MaybeAlign(baseAlignmentInt.value().getZExtValue()); |
| } |
| return std::nullopt; |
| } |
| |
| llvm::Align InterfaceBindingSubspanOp::calculateAlignment() { |
| // If we can't calculate an alignment we fall back to the natural alignment of |
| // the element type (for example, a memref<?xi32> is known to be at least |
| // 4-byte aligned). |
| llvm::Align naturalAlignment(1); |
| auto resultType = getType(); |
| if (auto shapedType = llvm::dyn_cast<ShapedType>(resultType)) { |
| naturalAlignment = llvm::Align( |
| IREE::Util::getRoundedElementByteWidth(shapedType.getElementType())); |
| } |
| |
| // If the binding has no assigned alignment we fall back to natural alignment. |
| auto baseAlignment = getBaseAlignment(); |
| if (!baseAlignment) |
| return naturalAlignment; |
| |
| // If there's no offset specified then we can use the binding alignment |
| // directly. |
| if (!getByteOffset()) |
| return baseAlignment.value(); |
| |
| // Try to get the alignment of the byte offset. If it's a constant then we can |
| // find a common alignment between it and the base and otherwise we need to |
| // try to infer the alignment from the IR - otherwise we fall back. |
| auto offsetOrAlignment = lookupOffsetOrAlignment(getByteOffset()); |
| if (!offsetOrAlignment.has_value()) |
| return naturalAlignment; |
| |
| // Compute the common alignment between that of the binding base and that of |
| // the byte offset. |
| return llvm::commonAlignment(baseAlignment.value(), |
| offsetOrAlignment.value()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.interface.workgroup.* |
| //===----------------------------------------------------------------------===// |
| |
| static void getAsmResultNamesForInterfaceWorkgroupOp( |
| StringRef prefix, const APInt &dimension, Value result, |
| function_ref<void(Value, StringRef)> setNameFn) { |
| switch (dimension.getZExtValue()) { |
| case 0: |
| setNameFn(result, (prefix + "x").str()); |
| return; |
| case 1: |
| setNameFn(result, (prefix + "y").str()); |
| return; |
| case 2: |
| setNameFn(result, (prefix + "z").str()); |
| return; |
| } |
| } |
| |
| void InterfaceWorkgroupIDOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| getAsmResultNamesForInterfaceWorkgroupOp("workgroup_id_", getDimension(), |
| getResult(), setNameFn); |
| } |
| |
| void InterfaceWorkgroupCountOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| getAsmResultNamesForInterfaceWorkgroupOp("workgroup_count_", getDimension(), |
| getResult(), setNameFn); |
| } |
| |
| void InterfaceWorkgroupSizeOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| getAsmResultNamesForInterfaceWorkgroupOp("workgroup_size_", getDimension(), |
| getResult(), setNameFn); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.pipeline_layout.create |
| //===----------------------------------------------------------------------===// |
| |
| void PipelineLayoutCreateOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "pipeline_layout"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.pipeline_layout.lookup |
| //===----------------------------------------------------------------------===// |
| |
| void PipelineLayoutLookupOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "pipeline_layout"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // hal.fence.* |
| //===----------------------------------------------------------------------===// |
| |
| void FenceCreateOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "fence"); |
| } |
| |
| void FenceJoinOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "fence"); |
| } |
| |
| void FenceAwaitOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getStatus(), "status"); |
| } |
| |
| } // namespace HAL |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |
| |
| //===----------------------------------------------------------------------===// |
| // TableGen definitions (intentionally last) |
| //===----------------------------------------------------------------------===// |
| |
| #define GET_OP_CLASSES |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.cpp.inc" |