| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/SMLoc.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/UB/IR/UBOps.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Interfaces/FunctionImplementation.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| |
| #include <numeric> |
| |
| namespace mlir::iree_compiler { |
| |
| //===----------------------------------------------------------------------===// |
| // Experimental |
| //===----------------------------------------------------------------------===// |
| |
| // For now we emit all cases and then select the first found (by selecting |
| // in reverse). So if selecting between case0, case1, and case2 we'd end up with |
| // %case0 = ... |
| // %case1 = ... |
| // %case2 = ... |
| // %0 = arith.select %case2, %c2, %c-1 |
| // %1 = arith.select %case1, %c1, %0 |
| // %2 = arith.select %case0, %c0, %1 |
| // // %2 is now -1 if nothing matched or the index of the match |
| Value buildIfElseTree( |
| Location loc, size_t count, |
| std::function<Value(Location, size_t, OpBuilder &)> caseBuilder, |
| OpBuilder &builder) { |
| SmallVector<Value> caseValues; |
| caseValues.reserve(count); |
| for (size_t i = 0; i < count; ++i) { |
| caseValues.push_back(caseBuilder(loc, i, builder)); |
| } |
| Value result = arith::ConstantIndexOp::create(builder, loc, -1); |
| for (int i = count - 1; i >= 0; --i) { |
| result = arith::SelectOp::create( |
| builder, loc, caseValues[i], |
| arith::ConstantIndexOp::create(builder, loc, i), result); |
| } |
| return result; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Utils |
| //===----------------------------------------------------------------------===// |
| |
| ArrayAttr deduplicateArrayElements(ArrayAttr arrayAttr) { |
| SetVector<Attribute> attrsSet(arrayAttr.begin(), arrayAttr.end()); |
| if (attrsSet.size() == arrayAttr.size()) { |
| return arrayAttr; |
| } |
| return ArrayAttr::get(arrayAttr.getContext(), attrsSet.takeVector()); |
| } |
| |
| int64_t findTiedOperand(OpAsmParser::UnresolvedOperand tiedResult, |
| ArrayRef<OpAsmParser::UnresolvedOperand> operands) { |
| int64_t operandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| for (int64_t i = 0; i < operands.size(); ++i) { |
| if (operands[i].name == tiedResult.name && |
| operands[i].number == tiedResult.number) { |
| operandIndex = i; |
| break; |
| } |
| } |
| return operandIndex; |
| } |
| |
| static int64_t findTiedArgument(OpAsmParser::UnresolvedOperand tiedResult, |
| ArrayRef<OpAsmParser::Argument> arguments) { |
| int64_t operandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| for (int64_t i = 0; i < arguments.size(); ++i) { |
| if (arguments[i].ssaName.name == tiedResult.name && |
| arguments[i].ssaName.number == tiedResult.number) { |
| operandIndex = i; |
| break; |
| } |
| } |
| return operandIndex; |
| } |
| |
| // Returns true if any attribute in |attr| references |symbolNameAttr|. |
| static bool hasAnyRefsToSymbol(DictionaryAttr attrs, |
| StringAttr symbolNameAttr) { |
| bool anyRefs = false; |
| attrs.walk([&](FlatSymbolRefAttr attr) { |
| anyRefs = attr.getValue() == symbolNameAttr.getValue(); |
| return anyRefs ? WalkResult::interrupt() : WalkResult::advance(); |
| }); |
| return anyRefs; |
| } |
| |
| // Returns true if any attribute on any ancestor of |baseOp| references |
| // |symbolNameAttr|. |
| static bool anyAncestorHasAnyRefsToSymbol(Operation *baseOp, |
| StringAttr symbolNameAttr) { |
| Operation *parentOp = baseOp->getParentOp(); |
| while (parentOp) { |
| // Check the op attributes for a reference. |
| if (hasAnyRefsToSymbol(parentOp->getAttrDictionary(), symbolNameAttr)) { |
| return true; // found a ref |
| } |
| if (parentOp->hasTrait<OpTrait::SymbolTable>()) { |
| break; // don't continue op past the first symbol table |
| } |
| parentOp = parentOp->getParentOp(); |
| } |
| return false; // none found |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<SymbolVisibility>($sym_visibility) |
| //===----------------------------------------------------------------------===// |
| // some.op custom<SymbolVisibility>($sym_visibility) $sym_name |
| // -> |
| // some.op @foo |
| // some.op private @foo |
| |
| ParseResult parseSymbolVisibility(OpAsmParser &parser, |
| StringAttr &symVisibilityAttr) { |
| StringRef symVisibility; |
| if (succeeded(parser.parseOptionalKeyword(&symVisibility, |
| {"public", "private", "nested"}))) { |
| symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); |
| } |
| return success(); |
| } |
| |
| void printSymbolVisibility(OpAsmPrinter &p, Operation *op, |
| StringAttr symVisibilityAttr) { |
| if (!symVisibilityAttr) { |
| p << "public"; |
| } else { |
| p << symVisibilityAttr.getValue(); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TypeOrAttr>($type, $attr) |
| //===----------------------------------------------------------------------===// |
| // some.op custom<TypeOrAttr>($type, $attr) |
| // -> |
| // some.op : i32 |
| // some.op = 42 : i32 |
| // some.op : i32 = 42 : index |
| |
| ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, |
| Attribute &attr) { |
| if (succeeded(parser.parseOptionalEqual())) { |
| if (failed(parser.parseAttribute(attr))) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "expected attribute"; |
| } |
| |
| if (auto typedAttr = dyn_cast<TypedAttr>(attr)) { |
| typeAttr = TypeAttr::get(typedAttr.getType()); |
| } |
| return success(); |
| } |
| |
| Type type; |
| if (failed(parser.parseColonType(type))) { |
| return parser.emitError(parser.getCurrentLocation()) << "expected type"; |
| } |
| typeAttr = TypeAttr::get(type); |
| |
| if (succeeded(parser.parseOptionalEqual())) { |
| if (failed(parser.parseAttribute(attr))) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "expected attribute"; |
| } |
| } |
| |
| return success(); |
| } |
| |
| void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, |
| Attribute attr) { |
| bool needsSpace = false; |
| auto typedAttr = dyn_cast_if_present<TypedAttr>(attr); |
| if (!typedAttr || typedAttr.getType() != type.getValue()) { |
| p << ": "; |
| p.printAttribute(type); |
| needsSpace = true; // subsequent attr value needs a space separator |
| } |
| if (attr) { |
| if (needsSpace) { |
| p << ' '; |
| } |
| p << "= "; |
| p.printAttribute(attr); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<SymbolAlias>($sym_name, $alias) |
| //===----------------------------------------------------------------------===// |
// @foo               sym_name: "foo", alias: @foo
// @foo as("bar")     sym_name: "bar", alias: @foo
| |
| ParseResult parseSymbolAlias(OpAsmParser &parser, StringAttr &sym_name, |
| FlatSymbolRefAttr &alias) { |
| if (failed(parser.parseAttribute(alias))) { |
| return failure(); |
| } |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| if (failed(parser.parseLParen()) || |
| failed(parser.parseAttribute(sym_name)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } else { |
| sym_name = StringAttr::get(parser.getContext(), alias.getValue()); |
| } |
| return success(); |
| } |
| |
| void printSymbolAlias(OpAsmPrinter &p, Operation *op, StringAttr sym_name, |
| FlatSymbolRefAttr alias) { |
| p.printAttributeWithoutType(alias); |
| if (sym_name.getValue() != alias.getValue()) { |
| p << " as(\"" << sym_name.getValue() << "\")"; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TypeAlias>($encoding_type, $storage_type) |
| //===----------------------------------------------------------------------===// |
| // tensor<4xf32> |
| // tensor<4xf32> as tensor<2xf64> |
| |
| ParseResult parseTypeAlias(OpAsmParser &parser, TypeAttr &encodingTypeAttr, |
| Type &storageType) { |
| Type encodingType; |
| if (failed(parser.parseType(encodingType))) { |
| return failure(); |
| } |
| storageType = encodingType; |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| if (failed(parser.parseType(storageType))) { |
| return failure(); |
| } |
| } |
| encodingTypeAttr = TypeAttr::get(encodingType); |
| return success(); |
| } |
| |
| void printTypeAlias(OpAsmPrinter &p, Operation *op, TypeAttr encodingTypeAttr, |
| Type storageType) { |
| if (encodingTypeAttr.getValue() != storageType) { |
| p.printType(encodingTypeAttr.getValue()); |
| p << " as "; |
| } |
| p.printType(storageType); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TypedValueList>(ref($type_value), $values) |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult |
| parseTypedValueList(OpAsmParser &parser, Type type, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
| SmallVectorImpl<Type> &valueTypes) { |
| if (failed(parser.parseOperandList(values, AsmParser::Delimiter::Square))) { |
| return failure(); |
| } |
| valueTypes.append(values.size(), type); |
| return success(); |
| } |
| |
| void printTypedValueList(OpAsmPrinter &p, Operation *op, Type type, |
| OperandRange values, TypeRange valueTypes) { |
| p << "["; |
| p.printOperands(values); |
| p << "]"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<RangeList>($offsets, $lengths) |
| //===----------------------------------------------------------------------===// |
| // [%offset for %length], [%offset for %length], ... |
| |
| ParseResult |
| parseRangeList(OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &offsets, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lengths) { |
| do { |
| OpAsmParser::UnresolvedOperand offset; |
| OpAsmParser::UnresolvedOperand length; |
| if (failed(parser.parseLSquare()) || failed(parser.parseOperand(offset)) || |
| failed(parser.parseKeyword("for")) || |
| failed(parser.parseOperand(length)) || failed(parser.parseRSquare())) { |
| return failure(); |
| } |
| offsets.push_back(offset); |
| lengths.push_back(length); |
| } while (succeeded(parser.parseOptionalComma())); |
| return success(); |
| } |
| |
| void printRangeList(OpAsmPrinter &p, Operation *op, OperandRange offsets, |
| OperandRange lengths) { |
| llvm::interleaveComma(llvm::zip_equal(offsets, lengths), p, [&](auto it) { |
| auto offset = std::get<0>(it); |
| auto length = std::get<1>(it); |
| p << "["; |
| p.printOperand(offset); |
| p << " for "; |
| p.printOperand(length); |
| p << "]"; |
| }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<SizeAwareType> |
| //===----------------------------------------------------------------------===// |
| // type{%size} |
| |
| ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type, |
| OpAsmParser::UnresolvedOperand &size) { |
| if (failed(parser.parseType(type)) || failed(parser.parseLBrace()) || |
| failed(parser.parseOperand(size)) || failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| return success(); |
| } |
| |
| void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, Value size) { |
| p.printType(type); |
| p << "{"; |
| p.printOperand(size); |
| p << "}"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<OperandTypeList> |
| //===----------------------------------------------------------------------===// |
| // () |
| // (type, type) |
| |
| ParseResult parseOperandTypeList(OpAsmParser &parser, |
| SmallVectorImpl<Type> &operandTypes) { |
| if (failed(parser.parseLParen())) { |
| return failure(); |
| } |
| if (succeeded(parser.parseOptionalRParen())) { |
| return success(); // empty |
| } |
| do { |
| Type type; |
| if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| operandTypes.push_back(type); |
| } while (succeeded(parser.parseOptionalComma())); |
| if (failed(parser.parseRParen())) { |
| return failure(); |
| } |
| return success(); |
| } |
| |
| void printOperandTypeList(OpAsmPrinter &p, Operation *op, |
| TypeRange operandTypes) { |
| p << '('; |
| llvm::interleaveComma(operandTypes, p.getStream()); |
| p << ')'; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TiedResultList> |
| //===----------------------------------------------------------------------===// |
| // type, %operand0, %operand1 as type |
| |
| ParseResult |
| parseTiedResultList(OpAsmParser &parser, |
| ArrayRef<OpAsmParser::UnresolvedOperand> operands, |
| TypeRange operandTypes, SmallVectorImpl<Type> &resultTypes, |
| ArrayAttr &tiedOperands) { |
| SmallVector<int64_t> tiedOperandIndices; |
| do { |
| OpAsmParser::UnresolvedOperand tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| Type type; |
| int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| if (res.has_value() && succeeded(res.value())) { |
| tiedOperandIndex = findTiedOperand(tiedResult, operands); |
| if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) { |
| return parser.emitError(tiedResult.location, |
| "tied operand not found for result reference ") |
| << tiedResult.name; |
| } |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| // Type _may_ differ from the operand. |
| if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| } else { |
| // Use the operands type. |
| type = operandTypes[tiedOperandIndex]; |
| } |
| } else if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| resultTypes.push_back(type); |
| tiedOperandIndices.push_back(tiedOperandIndex); |
| } while (succeeded(parser.parseOptionalComma())); |
| if (!tiedOperandIndices.empty()) { |
| tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices); |
| } |
| return success(); |
| } |
| |
| void printTiedResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, |
| TypeRange operandTypes, TypeRange resultTypes, |
| ArrayAttr tiedOperands) { |
| auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op); |
| for (unsigned i = 0; i < resultTypes.size(); ++i) { |
| auto resultType = resultTypes[i]; |
| auto tiedOperandIndex = |
| tiedOp ? tiedOp.getTiedResultOperandIndex(i) : std::nullopt; |
| bool printType = true; |
| if (tiedOperandIndex.has_value()) { |
| auto tiedOperand = op->getOperand(tiedOperandIndex.value()); |
| p.printOperand(tiedOperand); |
| if (tiedOperand.getType() != resultType) { |
| p << " as "; |
| } else { |
| // Type elided as it matches the operand. |
| printType = false; |
| } |
| } |
| if (printType) { |
| p.printType(resultType); |
| } |
| if (i < resultTypes.size() - 1) { |
| p << ", "; |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TiedFunctionResultList> |
| //===----------------------------------------------------------------------===// |
| // () |
| // type |
| // (type, %operand0 {some.attr}, %operand1 as type) |
| |
| static ParseResult |
| parseTiedFunctionResultListImpl(OpAsmParser &parser, |
| ArrayRef<OpAsmParser::Argument> arguments, |
| SmallVectorImpl<Type> &resultTypes, |
| SmallVectorImpl<DictionaryAttr> &resultAttrs, |
| ArrayAttr &tiedOperands, bool allowAttrs) { |
| SmallVector<int64_t> tiedOperandIndices; |
| do { |
| OpAsmParser::UnresolvedOperand tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| Type type; |
| int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| if (res.has_value() && succeeded(res.value())) { |
| tiedOperandIndex = findTiedArgument(tiedResult, arguments); |
| if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) { |
| return parser.emitError(tiedResult.location, |
| "tied operand not found for result reference ") |
| << tiedResult.name; |
| } |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| // Type _may_ differ from the operand. |
| if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| } else { |
| // Use the operands type. |
| type = arguments[tiedOperandIndex].type; |
| } |
| } else if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| DictionaryAttr resultAttrDict; |
| if (allowAttrs) { |
| NamedAttrList resultAttrList; |
| if (succeeded(parser.parseOptionalAttrDict(resultAttrList))) { |
| resultAttrDict = parser.getBuilder().getDictionaryAttr(resultAttrList); |
| } |
| } |
| resultTypes.push_back(type); |
| resultAttrs.push_back(resultAttrDict); |
| tiedOperandIndices.push_back(tiedOperandIndex); |
| } while (succeeded(parser.parseOptionalComma())); |
| if (!tiedOperandIndices.empty()) { |
| tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices); |
| } |
| return success(); |
| } |
| |
| static ParseResult parseTiedFunctionResultList( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::Argument> arguments, |
| SmallVectorImpl<Type> &resultTypes, |
| SmallVectorImpl<DictionaryAttr> &resultAttrs, ArrayAttr &tiedOperands) { |
| SmallVector<OpAsmParser::UnresolvedOperand> operands; |
| SmallVector<Type> operandTypes; |
| operands.reserve(arguments.size()); |
| operandTypes.reserve(arguments.size()); |
| for (auto argument : arguments) { |
| operands.push_back(argument.ssaName); |
| operandTypes.push_back(argument.type); |
| } |
| if (succeeded(parser.parseOptionalLParen())) { |
| if (succeeded(parser.parseOptionalRParen())) { |
| // Empty list/no results `()`. |
| } else { |
| // One or more result types. |
| if (failed(parseTiedFunctionResultListImpl(parser, arguments, resultTypes, |
| resultAttrs, tiedOperands, |
| /*allowAttrs=*/true)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } |
| } else { |
| // Single result with omitted `()`. |
| if (failed(parseTiedFunctionResultListImpl(parser, arguments, resultTypes, |
| resultAttrs, tiedOperands, |
| /*allowAttrs=*/false))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| ParseResult parseTiedFunctionResultList( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::UnresolvedOperand> operands, |
| ArrayRef<Type> operandTypes, SmallVectorImpl<Type> &resultTypes, |
| ArrayAttr &tiedOperands) { |
| if (succeeded(parser.parseOptionalLParen())) { |
| if (succeeded(parser.parseOptionalRParen())) { |
| // Empty list/no results `()`. |
| } else { |
| // One or more result types. |
| if (failed(parseTiedResultList(parser, operands, operandTypes, |
| resultTypes, tiedOperands)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } |
| } else { |
| // Single result with omitted `()`. |
| if (failed(parseTiedResultList(parser, operands, operandTypes, resultTypes, |
| tiedOperands))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| void printTiedFunctionResultList(OpAsmPrinter &p, Operation *op, |
| ValueRange operands, TypeRange operandTypes, |
| TypeRange resultTypes, |
| ArrayAttr tiedOperands) { |
| if (resultTypes.size() != 1) { |
| p << "("; |
| } |
| printTiedResultList(p, op, operands, operandTypes, resultTypes, tiedOperands); |
| if (resultTypes.size() != 1) { |
| p << ")"; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedTypeList> |
| //===----------------------------------------------------------------------===// |
| // type{%size0}, type, type{%size1} |
| |
| ParseResult |
| parseShapedTypeList(OpAsmParser &parser, SmallVectorImpl<Type> &types, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dims) { |
| do { |
| Type type; |
| if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| if (auto shapedType = dyn_cast<ShapedType>(type)) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::UnresolvedOperand> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| dims.append(dynamicDims); |
| } |
| } else if (isa<IREE::Util::SizeAwareTypeInterface>(type)) { |
| OpAsmParser::UnresolvedOperand size; |
| if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| dims.push_back(size); |
| } |
| types.push_back(type); |
| } while (succeeded(parser.parseOptionalComma())); |
| return success(); |
| } |
| |
| void printShapedTypeList(OpAsmPrinter &p, Operation *op, TypeRange types, |
| ValueRange dims) { |
| llvm::interleaveComma(types, p, [&](Type type) { |
| p.printType(type); |
| if (auto shapedType = dyn_cast<ShapedType>(type)) { |
| if (!shapedType.hasStaticShape()) { |
| if (dims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma(dims.take_front(shapedType.getNumDynamicDims()), |
| p, [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| dims = dims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } else if (isa<IREE::Util::SizeAwareTypeInterface>(type)) { |
| p << "{"; |
| p.printOperand(dims.front()); |
| p << "}"; |
| dims = dims.drop_front(1); |
| } |
| }); |
| } |
| |
| ParseResult |
| parseShapedTypeList(OpAsmParser &parser, SmallVectorImpl<Type> &types0, |
| SmallVectorImpl<Type> &types1, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dims) { |
| if (failed(parseShapedTypeList(parser, types0, dims))) { |
| return failure(); |
| } |
| types1 = types0; |
| return success(); |
| } |
| |
| void printShapedTypeList(OpAsmPrinter &p, Operation *op, TypeRange types0, |
| TypeRange types1, ValueRange dims) { |
| printShapedTypeList(p, op, types0, dims); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedTiedResult> |
| //===----------------------------------------------------------------------===// |
| // type{%dim0, %dim1} |
| // %arg0 as type{%dim0} |
| |
| ParseResult parseShapedTiedResult( |
| OpAsmParser &parser, Type &resultType, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims) { |
| ArrayAttr tiedOperands; |
| return parseShapedTiedResult(parser, resultType, resultDims, tiedOperands); |
| } |
| |
| ParseResult parseShapedTiedResult( |
| OpAsmParser &parser, Type &resultType, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims, |
| ArrayAttr &tiedOperands) { |
| OpAsmParser::UnresolvedOperand tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| if (res.has_value() && succeeded(res.value())) { |
| tiedOperandIndex = 0; |
| if (failed(parser.parseKeyword("as"))) { |
| return failure(); |
| } |
| } |
| if (failed(parser.parseType(resultType))) { |
| return failure(); |
| } |
| if (auto shapedType = dyn_cast<ShapedType>(resultType)) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::UnresolvedOperand> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.append(dynamicDims); |
| } |
| } else if (auto sizedType = |
| dyn_cast<IREE::Util::SizeAwareTypeInterface>(resultType)) { |
| OpAsmParser::UnresolvedOperand size; |
| if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.push_back(size); |
| } |
| tiedOperands = parser.getBuilder().getIndexArrayAttr({tiedOperandIndex}); |
| return success(); |
| } |
| |
| void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType, |
| ValueRange resultDims) { |
| auto tiedOp = cast<IREE::Util::TiedOpInterface>(op); |
| auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(0); |
| if (tiedOperandIndex.has_value()) { |
| auto tiedOperand = op->getOperand(tiedOperandIndex.value()); |
| p.printOperand(tiedOperand); |
| p << " as "; |
| } |
| p.printType(resultType); |
| if (auto shapedType = dyn_cast<ShapedType>(resultType)) { |
| if (!shapedType.hasStaticShape()) { |
| if (resultDims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma( |
| resultDims.take_front(shapedType.getNumDynamicDims()), p, |
| [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| resultDims = resultDims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } else if (auto sizedType = |
| dyn_cast<IREE::Util::SizeAwareTypeInterface>(resultType)) { |
| p << "{"; |
| p.printOperand(resultDims.front()); |
| p << "}"; |
| resultDims = resultDims.drop_front(1); |
| } |
| } |
| |
| void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType, |
| ValueRange resultDims, ArrayAttr tiedOperands) { |
| printShapedTiedResult(p, op, resultType, resultDims); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedResultList> |
| //===----------------------------------------------------------------------===// |
| // type{%dim2}, %operand4 |
| |
| ParseResult parseShapedResultList( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::UnresolvedOperand> operands, |
| TypeRange operandTypes, |
| ArrayRef<OpAsmParser::UnresolvedOperand> operandDims, |
| SmallVectorImpl<Type> &resultTypes, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims, |
| ArrayAttr &tiedOperands) { |
| SmallVector<int64_t> tiedOperandIndices; |
| do { |
| OpAsmParser::UnresolvedOperand tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| Type type; |
| int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| if (res.has_value() && succeeded(res.value())) { |
| tiedOperandIndex = findTiedOperand(tiedResult, operands); |
| if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) { |
| return parser.emitError(tiedResult.location, |
| "tied operand not found for result reference ") |
| << tiedResult.name; |
| } |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| // Type _may_ differ from the operand. |
| if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| } else { |
| // Use the operands type. |
| type = operandTypes[tiedOperandIndex]; |
| } |
| } else if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| if (auto shapedType = dyn_cast<ShapedType>(type)) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::UnresolvedOperand> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.append(dynamicDims); |
| } |
| } else if (auto sizedType = |
| dyn_cast<IREE::Util::SizeAwareTypeInterface>(type)) { |
| OpAsmParser::UnresolvedOperand size; |
| if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.push_back(size); |
| } |
| resultTypes.push_back(type); |
| tiedOperandIndices.push_back(tiedOperandIndex); |
| } while (succeeded(parser.parseOptionalComma())); |
| if (!tiedOperandIndices.empty()) { |
| tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices); |
| } |
| return success(); |
| } |
| |
| void printShapedResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, |
| TypeRange operandTypes, ValueRange operandDims, |
| TypeRange resultTypes, ValueRange resultDims, |
| ArrayAttr tiedOperands) { |
| auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op); |
| for (unsigned i = 0; i < resultTypes.size(); ++i) { |
| auto resultType = resultTypes[i]; |
| auto tiedOperandIndex = |
| tiedOp ? tiedOp.getTiedResultOperandIndex(i) : std::nullopt; |
| bool printType = true; |
| if (tiedOperandIndex.has_value()) { |
| auto tiedOperand = op->getOperand(tiedOperandIndex.value()); |
| p.printOperand(tiedOperand); |
| if (tiedOperand.getType() != resultType) { |
| p << " as "; |
| } else { |
| // Type elided as it matches the operand. |
| printType = false; |
| } |
| } |
| if (printType) { |
| p.printType(resultType); |
| } |
| if (auto shapedType = dyn_cast<ShapedType>(resultType)) { |
| if (!shapedType.hasStaticShape()) { |
| if (resultDims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma( |
| resultDims.take_front(shapedType.getNumDynamicDims()), p, |
| [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| resultDims = resultDims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } else if (auto sizedType = |
| dyn_cast<IREE::Util::SizeAwareTypeInterface>(resultType)) { |
| p << "{"; |
| p.printOperand(resultDims.front()); |
| p << "}"; |
| resultDims = resultDims.drop_front(1); |
| } |
| if (i < resultTypes.size() - 1) { |
| p << ", "; |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedFunctionType> |
| //===----------------------------------------------------------------------===// |
| // (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4) |
| |
| ParseResult parseShapedFunctionType( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::UnresolvedOperand> operands, |
| SmallVectorImpl<Type> &operandTypes, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandDims, |
| SmallVectorImpl<Type> &resultTypes, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims, |
| ArrayAttr &tiedOperands) { |
| if (failed(parser.parseLParen())) { |
| return failure(); |
| } |
| if (failed(parser.parseOptionalRParen())) { |
| if (failed(parseShapedTypeList(parser, operandTypes, operandDims)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } |
| if (failed(parser.parseArrow())) { |
| return failure(); |
| } |
| if (succeeded(parser.parseOptionalLParen())) { |
| if (succeeded(parser.parseOptionalRParen())) { |
| // Empty list/no results `()`. |
| } else { |
| // One or more result types. |
| if (failed(parseShapedResultList(parser, operands, operandTypes, |
| operandDims, resultTypes, resultDims, |
| tiedOperands)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } |
| } else { |
| // Single result with omitted `()`. |
| if (failed(parseShapedResultList(parser, operands, operandTypes, |
| operandDims, resultTypes, resultDims, |
| tiedOperands))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| void printShapedFunctionType(OpAsmPrinter &p, Operation *op, |
| ValueRange operands, TypeRange operandTypes, |
| OperandRange operandDims, TypeRange resultTypes, |
| OperandRange resultDims, ArrayAttr tiedOperands) { |
| p << "("; |
| printShapedTypeList(p, op, operandTypes, operandDims); |
| p << ") -> "; |
| if (resultTypes.size() != 1) { |
| p << "("; |
| } |
| printShapedResultList(p, op, operands, operandTypes, operandDims, resultTypes, |
| resultDims, tiedOperands); |
| if (resultTypes.size() != 1) { |
| p << ")"; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedFunctionSignature> |
| //===----------------------------------------------------------------------===// |
| // (%arg0: type {some.attr = 54 : index}, %arg1: type) -> (type, %arg1 as type) |
| |
| static ParseResult parseShapedFunctionArgumentList( |
| OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args, |
| SmallVectorImpl<Type> &types, ArrayAttr &attrs) { |
| SmallVector<Attribute> argAttrsVec; |
| do { |
| OpAsmParser::UnresolvedOperand arg; |
| Type type; |
| NamedAttrList attrsVec; |
| if (failed(parser.parseOperand(arg)) || |
| failed(parser.parseColonType(type)) || |
| failed(parser.parseOptionalAttrDict(attrsVec))) { |
| return failure(); |
| } |
| args.push_back(arg); |
| types.push_back(type); |
| argAttrsVec.push_back(parser.getBuilder().getDictionaryAttr(attrsVec)); |
| } while (succeeded(parser.parseOptionalComma())); |
| if (!argAttrsVec.empty()) { |
| attrs = parser.getBuilder().getArrayAttr(argAttrsVec); |
| } |
| return success(); |
| } |
| |
| static ParseResult parseShapedFunctionResultList( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::UnresolvedOperand> args, |
| TypeRange argTypes, SmallVectorImpl<Type> &resultTypes, |
| ArrayAttr &resultAttrs, ArrayAttr &tiedOperands, bool allowResultAttrs) { |
| SmallVector<Attribute> resultAttrsVec; |
| SmallVector<int64_t> tiedOperandIndices; |
| do { |
| OpAsmParser::UnresolvedOperand tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| Type type; |
| int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| if (res.has_value() && succeeded(res.value())) { |
| tiedOperandIndex = findTiedOperand(tiedResult, args); |
| if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) { |
| return parser.emitError(tiedResult.location, |
| "tied operand not found for result reference ") |
| << tiedResult.name; |
| } |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| // Type _may_ differ from the operand. |
| if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| } else { |
| // Use the operands type. |
| type = argTypes[tiedOperandIndex]; |
| } |
| } else if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| NamedAttrList attrs; |
| if (allowResultAttrs && failed(parser.parseOptionalAttrDict(attrs))) { |
| return failure(); |
| } |
| resultTypes.push_back(type); |
| resultAttrsVec.push_back(parser.getBuilder().getDictionaryAttr(attrs)); |
| tiedOperandIndices.push_back(tiedOperandIndex); |
| } while (succeeded(parser.parseOptionalComma())); |
| if (!resultAttrsVec.empty()) { |
| resultAttrs = parser.getBuilder().getArrayAttr(resultAttrsVec); |
| } |
| if (!tiedOperandIndices.empty()) { |
| tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices); |
| } |
| return success(); |
| } |
| |
| static void printShapedFunctionResultList(OpAsmPrinter &p, Operation *op, |
| TypeRange argTypes, |
| TypeRange resultTypes, |
| ArrayAttr resultAttrs, |
| ArrayAttr tiedOperands) { |
| for (unsigned i = 0; i < resultTypes.size(); ++i) { |
| auto resultType = resultTypes[i]; |
| auto tiedOperandIndex = |
| IREE::Util::detail::getTiedResultOperandIndex(op, i); |
| bool printType = true; |
| if (tiedOperandIndex.has_value()) { |
| p << "%arg" << tiedOperandIndex.value(); |
| if (argTypes[tiedOperandIndex.value()] != resultType) { |
| p << " as "; |
| } else { |
| // Type elided as it matches the operand. |
| printType = false; |
| } |
| } |
| if (printType) { |
| p.printType(resultType); |
| } |
| if (resultAttrs) { |
| auto attrs = |
| dyn_cast_if_present<DictionaryAttr>(resultAttrs.getValue()[i]); |
| if (attrs && !attrs.empty()) { |
| p.printOptionalAttrDict(attrs.getValue()); |
| } |
| } |
| if (i < resultTypes.size() - 1) { |
| p << ", "; |
| } |
| } |
| } |
| |
| ParseResult parseShapedFunctionSignature(OpAsmParser &parser, |
| TypeAttr &functionTypeAttr, |
| ArrayAttr &tiedOperands, |
| ArrayAttr &argAttrs, |
| ArrayAttr &resultAttrs) { |
| SmallVector<OpAsmParser::UnresolvedOperand> args; |
| SmallVector<Type> argTypes; |
| SmallVector<Type> resultTypes; |
| if (failed(parser.parseLParen())) { |
| return failure(); |
| } |
| if (failed(parser.parseOptionalRParen())) { |
| if (failed(parseShapedFunctionArgumentList(parser, args, argTypes, |
| argAttrs)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } |
| if (succeeded(parser.parseOptionalArrow())) { |
| if (succeeded(parser.parseOptionalLParen())) { |
| if (failed(parseShapedFunctionResultList( |
| parser, args, argTypes, resultTypes, resultAttrs, tiedOperands, |
| /*allowResultAttrs=*/true)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } else { |
| if (failed(parseShapedFunctionResultList( |
| parser, args, argTypes, resultTypes, resultAttrs, tiedOperands, |
| /*allowResultAttrs=*/false))) { |
| return failure(); |
| } |
| } |
| } |
| functionTypeAttr = TypeAttr::get( |
| FunctionType::get(parser.getContext(), argTypes, resultTypes)); |
| return success(); |
| } |
| |
| void printShapedFunctionSignature(OpAsmPrinter &p, Operation *op, |
| TypeAttr functionTypeAttr, |
| ArrayAttr tiedOperands, ArrayAttr argAttrs, |
| ArrayAttr resultAttrs) { |
| auto functionType = cast<FunctionType>(functionTypeAttr.getValue()); |
| p << "("; |
| int argIndex = 0; |
| llvm::interleaveComma(functionType.getInputs(), p, [&](auto type) { |
| p << "%arg"; |
| p << argIndex; |
| p << ": "; |
| p.printType(type); |
| if (argAttrs) { |
| auto attrs = |
| dyn_cast_if_present<DictionaryAttr>(argAttrs.getValue()[argIndex]); |
| if (attrs && !attrs.empty()) { |
| p.printOptionalAttrDict(attrs.getValue()); |
| } |
| } |
| ++argIndex; |
| }); |
| p << ")"; |
| auto resultTypes = functionType.getResults(); |
| if (!resultTypes.empty()) { |
| p << " -> "; |
| bool anyResultAttrs = |
| resultAttrs && !resultAttrs.empty() && |
| llvm::any_of(resultAttrs.getAsValueRange<DictionaryAttr>(), |
| [](auto attr) { return !attr.empty(); }); |
| if (resultTypes.size() != 1 || anyResultAttrs) { |
| p << "("; |
| } |
| printShapedFunctionResultList(p, op, functionType.getInputs(), resultTypes, |
| resultAttrs, tiedOperands); |
| if (resultTypes.size() != 1 || anyResultAttrs) { |
| p << ")"; |
| } |
| } |
| } |
| |
| } // namespace mlir::iree_compiler |
| |
| namespace mlir::iree_compiler::IREE::Util { |
| |
| //===----------------------------------------------------------------------===// |
| // util.align |
| //===----------------------------------------------------------------------===// |
| |
| void AlignOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
| SetIntRangeFn setResultRange) { |
| auto constantAlignment = argRanges[1].getConstantValue(); |
| // Note that for non constant alignment, there may still be something we |
| // want to infer, but this is left for the future. |
| if (constantAlignment && !constantAlignment->isZero()) { |
| // We can align the range directly. |
| // (value + (alignment - 1)) & ~(alignment - 1) |
| // https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding |
| APInt umin = argRanges[0].umin(); |
| APInt umax = argRanges[0].umax(); |
| APInt one(constantAlignment->getBitWidth(), 1); |
| APInt alignmentM1 = *constantAlignment - one; |
| APInt alignmentM1Inv = ~alignmentM1; |
| auto align = [&](APInt value, bool &invalid) -> APInt { |
| APInt aligned = (value + alignmentM1) & alignmentM1Inv; |
| // Detect overflow, which commonly happens at max range. |
| if (aligned.ult(value)) { |
| invalid = true; |
| } |
| return aligned; |
| }; |
| bool invalid = false; |
| auto alignedUmin = align(umin, invalid); |
| auto alignedUmax = align(umax, invalid); |
| if (!invalid) { |
| setResultRange(getResult(), |
| ConstantIntRanges::fromUnsigned(alignedUmin, alignedUmax)); |
| } |
| } |
| } |
| |
| void AlignOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs, |
| SetIntDivisibilityFn setResultDivs) { |
| auto alignmentDiv = argDivs[1]; |
| if (alignmentDiv.isUninitialized()) { |
| return; |
| } |
| setResultDivs(getResult(), alignmentDiv.getValue()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.assume.int |
| //===----------------------------------------------------------------------===// |
| |
| SmallVector<IntAssumptionAttr> |
| AssumeIntOp::getOperandAssumptions(unsigned operandIndex) { |
| assert(operandIndex < getNumOperands() && |
| "getUnionedUnsignedRange operand out of range"); |
| auto assumptions = cast<ArrayAttr>(getAssumptions()[operandIndex]); |
| SmallVector<IntAssumptionAttr> results; |
| for (auto assumption : assumptions) { |
| results.push_back(cast<IntAssumptionAttr>(assumption)); |
| } |
| return results; |
| } |
| |
| std::pair<std::optional<uint64_t>, std::optional<uint64_t>> |
| AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) { |
| auto assumptions = getOperandAssumptions(operandIndex); |
| std::optional<uint64_t> uminUnion; |
| int uminCount = 0; |
| std::optional<uint64_t> umaxUnion; |
| int umaxCount = 0; |
| |
| for (auto assumption : assumptions) { |
| auto umin = assumption.getUmin(); |
| auto umax = assumption.getUmax(); |
| if (umin) { |
| uminUnion = uminUnion ? std::min(*umin, *uminUnion) : *umin; |
| uminCount += 1; |
| } |
| if (umax) { |
| umaxUnion = umaxUnion ? std::max(*umax, *umaxUnion) : *umax; |
| umaxCount += 1; |
| } |
| } |
| return std::make_pair( |
| uminCount == assumptions.size() ? uminUnion : std::nullopt, |
| umaxCount == assumptions.size() ? umaxUnion : std::nullopt); |
| } |
| |
| static bool isConstantZero(IntAssumptionAttr assumption) { |
| std::optional<uint64_t> umin = assumption.getUmin(); |
| std::optional<uint64_t> umax = assumption.getUmax(); |
| if (!umin || !umax) { |
| return false; |
| } |
| return *umin == 0 && *umax == 0; |
| } |
| |
| std::optional<uint64_t> |
| AssumeIntOp::getUnionedUnsignedDivisor(unsigned operandIndex) { |
| auto assumptions = getOperandAssumptions(operandIndex); |
| std::optional<uint64_t> divisorUnion; |
| for (auto assumption : assumptions) { |
| auto divisor = assumption.getUdiv(); |
| if (!divisor) { |
| // Constant zero is divisible by anything |
| if (isConstantZero(assumption)) { |
| continue; |
| } |
| return std::nullopt; |
| } |
| if (divisorUnion) { |
| divisorUnion = std::gcd(*divisor, *divisorUnion); |
| } else { |
| divisorUnion = *divisor; |
| } |
| } |
| return divisorUnion; |
| } |
| |
| void AssumeIntOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, |
| SetIntRangeFn setResultRange) { |
| for (auto [index, result] : llvm::enumerate(getResults())) { |
| Type type = result.getType(); |
| unsigned bitWidth; |
| if (isa<IndexType>(type)) { |
| bitWidth = 64; |
| } else if (auto intType = dyn_cast<IntegerType>(type)) { |
| bitWidth = intType.getWidth(); |
| } else { |
| continue; |
| } |
| auto [umin, umax] = getUnionedUnsignedRange(index); |
| auto uminAp = APInt::getMinValue(bitWidth); |
| auto umaxAp = APInt::getMaxValue(bitWidth); |
| if (umin) { |
| uminAp = APInt(bitWidth, *umin); |
| } |
| if (umax) { |
| umaxAp = APInt(bitWidth, *umax); |
| } |
| |
| setResultRange(result, ConstantIntRanges::fromUnsigned(uminAp, umaxAp)); |
| } |
| } |
| |
| void AssumeIntOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs, |
| SetIntDivisibilityFn setResultDivs) { |
| for (auto [index, result] : llvm::enumerate(getResults())) { |
| Type type = result.getType(); |
| if (!isa<IndexType>(type) && !isa<IntegerType>(type)) { |
| continue; |
| } |
| auto udiv = getUnionedUnsignedDivisor(index); |
| if (udiv) { |
| setResultDivs(result, |
| ConstantIntDivisibility(/*udiv=*/*udiv, /*sdiv=*/*udiv)); |
| } |
| } |
| } |
| |
| void AssumeIntOp::build(OpBuilder &builder, OperationState &state, |
| Value singleOperand, |
| IntAssumptionAttr singleAssumption) { |
| state.addOperands({singleOperand}); |
| state.addTypes({singleOperand.getType()}); |
| state.addAttribute("assumptions", builder.getArrayAttr(builder.getArrayAttr( |
| {singleAssumption}))); |
| } |
| |
| void AssumeIntOp::build(OpBuilder &builder, OperationState &state, |
| ArrayRef<Value> operands, |
| ArrayRef<ArrayAttr> assumptions) { |
| state.addOperands(operands); |
| for (auto operand : operands) { |
| state.addTypes({operand.getType()}); |
| } |
| state.addAttribute("assumptions", |
| ArrayAttr::get(builder.getContext(), |
| ArrayRef<Attribute>(assumptions.begin(), |
| assumptions.end()))); |
| } |
| |
| LogicalResult AssumeIntOp::verify() { |
| ArrayAttr allOperandAssumptions = getAssumptions(); |
| // Verify that there is an assumption row per operand. |
| if (getNumOperands() != allOperandAssumptions.size()) { |
| return emitOpError() << "expected " << getNumOperands() |
| << " assumption rows to match number of operands"; |
| } |
| |
| std::optional<int> rank; |
| for (auto [index, operandAssumptionsAttr] : |
| llvm::enumerate(allOperandAssumptions)) { |
| auto operandAssumptions = cast<ArrayAttr>(operandAssumptionsAttr); |
| // We always allow a single row to broadcast to any requested size. |
| if (operandAssumptions.size() == 1) { |
| continue; |
| } |
| if (rank && *rank != operandAssumptions.size()) { |
| return emitOpError() << "expected operand #" << index << " to have " |
| << *rank << " assumptions but it has " |
| << operandAssumptions.size(); |
| } |
| rank = operandAssumptions.size(); |
| } |
| |
| return success(); |
| } |
| |
// Parses:
//   %a <assumption>, %b [<assumption>, ...], ... : type-a, type-b, ... {attrs}
// Each operand is followed by either a single assumption or a bracketed
// (possibly empty, `[]`) list of assumptions. Each operand's list becomes one
// row of the `assumptions` array attribute. Result types are the parsed
// operand types, and a trailing optional attribute dictionary is accepted.
ParseResult AssumeIntOp::parse(OpAsmParser &parser, OperationState &result) {
  // One ArrayAttr row per operand, in operand order.
  SmallVector<Attribute> allOperandAssumptions;
  SmallVector<OpAsmParser::UnresolvedOperand> parsedOperands;
  SmallVector<Type> parsedOperandTypes;

  if (parser.parseCommaSeparatedList([&]() {
        // Parse directly into a new slot at the back of parsedOperands.
        parsedOperands.emplace_back();
        OpAsmParser::UnresolvedOperand &parsedOperand = parsedOperands.back();
        SmallVector<Attribute> operandAssumptions;

        if (parser.parseOperand(parsedOperand)) {
          return failure();
        }

        // Parse as a single assumption or a list.
        if (failed(parser.parseOptionalLSquare())) {
          // Single assumption.
          IntAssumptionAttr singleAssumption;
          if (parser.parseCustomAttributeWithFallback(singleAssumption)) {
            return failure();
          }
          operandAssumptions.push_back(singleAssumption);
        } else {
          // Multiple assumptions. An immediate `]` yields an empty list.
          if (failed(parser.parseOptionalRSquare())) {
            if (parser.parseCommaSeparatedList([&]() {
                  IntAssumptionAttr singleAssumption;
                  if (parser.parseCustomAttributeWithFallback(
                          singleAssumption)) {
                    return failure();
                  }
                  operandAssumptions.push_back(singleAssumption);
                  return success();
                })) {
              return failure();
            }
            if (parser.parseRSquare()) {
              return failure();
            }
          }
        }

        // Finalize operand.
        allOperandAssumptions.push_back(
            parser.getBuilder().getArrayAttr(operandAssumptions));

        return success();
      })) {
    return failure();
  }

  // Parse `:` type. The same types are used for both operands and results.
  if (parser.parseColon() || parser.parseTypeList(parsedOperandTypes)) {
    return failure();
  }
  result.addTypes(parsedOperandTypes);

  // Fails (with a diagnostic at the op name) if the operand and type counts
  // differ.
  if (parser.resolveOperands(parsedOperands, parsedOperandTypes,
                             parser.getNameLoc(), result.operands)) {
    return failure();
  }

  result.attributes.append(
      "assumptions", parser.getBuilder().getArrayAttr(allOperandAssumptions));
  if (parser.parseOptionalAttrDict(result.attributes)) {
    return failure();
  }

  return success();
}
| |
// Prints the syntax parsed by AssumeIntOp::parse. With more than one operand,
// each operand goes on its own line and the `: types` line follows on a new
// line. With a single operand, everything stays on one line.
void AssumeIntOp::print(OpAsmPrinter &p) {
  p << " ";
  bool multiLine = getOperands().size() > 1;
  if (multiLine) {
    // Operand lines are indented two levels; one level is dropped again
    // before the `:` line and the other at the very end.
    p.increaseIndent();
    p.increaseIndent();
    p.printNewline();
  }
  ArrayAttr allOperandAssumptions = getAssumptions();
  for (auto [index, operand] : llvm::enumerate(getOperands())) {
    if (index > 0) {
      p << ", ";
      if (multiLine) {
        p.printNewline();
      }
    }
    ArrayAttr operandAssumptions =
        cast<ArrayAttr>(allOperandAssumptions[index]);
    p.printOperand(operand);

    // Print the assumptions, either as a single assumption or list. Any count
    // other than one (including zero) uses the bracketed list form.
    if (operandAssumptions.size() == 1) {
      p.printStrippedAttrOrType(cast<IntAssumptionAttr>(operandAssumptions[0]));
    } else {
      p << "[";
      llvm::interleaveComma(
          operandAssumptions, p.getStream(), [&](Attribute attr) {
            p.printStrippedAttrOrType(cast<IntAssumptionAttr>(attr));
          });
      p << "]";
    }
  }

  if (multiLine) {
    p.decreaseIndent();
    p.printNewline();
  } else {
    p << " ";
  }
  p << ": ";
  llvm::interleaveComma(getOperands(), p.getStream(),
                        [&](Value operand) { p.printType(operand.getType()); });
  // `assumptions` is already expressed by the custom syntax above.
  p.printOptionalAttrDict((*this)->getAttrs(), {"assumptions"});
  if (multiLine) {
    p.decreaseIndent();
  }
}
| |
| //===----------------------------------------------------------------------===// |
| // util.optimization_barrier |
| //===----------------------------------------------------------------------===// |
| |
| void OptimizationBarrierOp::build(OpBuilder &builder, OperationState &state, |
| ValueRange operands, |
| ArrayRef<NamedAttribute> attributes) { |
| state.addOperands(operands); |
| state.addTypes(llvm::to_vector<2>(operands.getTypes())); |
| state.addAttributes(attributes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.unfoldable_constant |
| //===----------------------------------------------------------------------===// |
| |
| // Parsing/printing copied from std.constant |
| |
| ParseResult UnfoldableConstantOp::parse(OpAsmParser &parser, |
| OperationState &state) { |
| Attribute valueAttr; |
| if (parser.parseOptionalAttrDict(state.attributes) || |
| parser.parseAttribute(valueAttr, "value", state.attributes)) { |
| return failure(); |
| } |
| |
| // If the attribute is a symbol reference, then we expect a trailing type. |
| Type type; |
| if (!isa<SymbolRefAttr>(valueAttr)) { |
| type = cast<TypedAttr>(valueAttr).getType(); |
| } else if (parser.parseColonType(type)) { |
| return failure(); |
| } |
| |
| // Add the attribute type to the list. |
| return parser.addTypeToList(type, state.types); |
| } |
| |
| void UnfoldableConstantOp::print(OpAsmPrinter &p) { |
| Operation *op = getOperation(); |
| p << " "; |
| p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); |
| |
| if (op->getAttrs().size() > 1) { |
| p << ' '; |
| } |
| p << getValue(); |
| |
| // If the value is a symbol reference, print a trailing type. |
| if (isa<SymbolRefAttr>(getValue())) { |
| p << " : " << getType(); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Type manipulation |
| //===----------------------------------------------------------------------===// |
| |
| bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
| if (inputs.size() != 1 || outputs.size() != 1) { |
| return false; |
| } |
| Type a = inputs.front(), b = outputs.front(); |
| if (a == b) { |
| // Both types are the same. |
| return true; |
| } |
| if (isa<IREE::Util::ObjectType>(a) || isa<IREE::Util::ObjectType>(b)) { |
| // Either type is an opaque object. |
| return true; |
| } |
| // Don't currently allow casting between types as we don't have runtime |
| // support for such operations (we don't generally care in the VM). |
| return false; |
| } |
| |
| LogicalResult CastOp::verify() { |
| auto operandType = getOperand().getType(); |
| if (!IREE::Util::ObjectType::isCompatible(operandType)) { |
| return this->emitOpError() << "operand type " << operandType |
| << " is not object cast compatible"; |
| } |
| auto resultType = getResult().getType(); |
| if (!IREE::Util::ObjectType::isCompatible(resultType)) { |
| return this->emitOpError() |
| << "result type " << resultType << " is not object cast compatible"; |
| } |
| return success(); |
| } |
| |
| Value CastOp::getTiedResult(unsigned resultIndex) { |
| return IREE::Util::TiedOpInterface::findTiedBaseValue(getOperand()); |
| } |
| |
| Value CastOp::getTiedResultOperand(Value result) { return getOperand(); } |
| |
| ::std::optional<unsigned> |
| CastOp::getTiedResultOperandIndex(unsigned resultIndex) { |
| return {0}; // operand |
| } |
| |
| SmallVector<int64_t> CastOp::getTiedResultOperandIndices() { |
| return {0}; // operand |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Numeric ops |
| //===----------------------------------------------------------------------===// |
| |
| std::optional<std::pair<int64_t, int64_t>> |
| NumericOptionalNarrowOp::getIntegerRange() { |
| if (!getMinValue() || !getMaxValue()) { |
| return {}; |
| } |
| bool signExtend = isSigned(); |
| // Note: Cannot sign extend 0 bit values. |
| int64_t minValue = signExtend && getMinValue()->getBitWidth() > 0 |
| ? getMinValue()->getSExtValue() |
| : getMinValue()->getZExtValue(); |
| int64_t maxValue = signExtend && getMaxValue()->getBitWidth() > 0 |
| ? getMaxValue()->getSExtValue() |
| : getMaxValue()->getZExtValue(); |
| return std::make_pair(minValue, maxValue); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.initializer |
| //===----------------------------------------------------------------------===// |
| |
| void InitializerOp::build(OpBuilder &builder, OperationState &result, |
| ArrayRef<NamedAttribute> attrs) { |
| result.addAttribute("function_type", TypeAttr::get(FunctionType::get( |
| builder.getContext(), {}, {}))); |
| result.addRegion(); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| ParseResult InitializerOp::parse(OpAsmParser &parser, OperationState &result) { |
| result.addAttribute("function_type", TypeAttr::get(FunctionType::get( |
| result.getContext(), {}, {}))); |
| if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) { |
| return failure(); |
| } |
| auto &body = *result.addRegion(); |
| if (failed(parser.parseRegion(body))) { |
| return failure(); |
| } |
| return success(); |
| } |
| |
| void InitializerOp::print(OpAsmPrinter &p) { |
| Operation *op = getOperation(); |
| p.printOptionalAttrDictWithKeyword(op->getAttrs(), |
| /*elidedAttrs=*/{"function_type"}); |
| p << " "; |
| p.printRegion(getBody()); |
| } |
| |
| Block *InitializerOp::addEntryBlock() { |
| assert(empty() && "function already has an entry block"); |
| auto *entry = new Block(); |
| push_back(entry); |
| return entry; |
| } |
| |
| Block *InitializerOp::addBlock() { |
| assert(!empty() && "function should at least have an entry block"); |
| push_back(new Block()); |
| return &back(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.func |
| //===----------------------------------------------------------------------===// |
| |
| FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
| ArrayRef<int64_t> tiedOperands, |
| ArrayRef<NamedAttribute> attrs, |
| ArrayRef<DictionaryAttr> argAttrs, |
| ArrayRef<DictionaryAttr> resAttrs) { |
| OpBuilder builder(location->getContext()); |
| OperationState state(location, getOperationName()); |
| FuncOp::build(builder, state, name, type, |
| tiedOperands.empty() ? ArrayAttr{} |
| : builder.getIndexArrayAttr(tiedOperands), |
| attrs, argAttrs, resAttrs); |
| return cast<FuncOp>(Operation::create(state)); |
| } |
| |
| void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
| FunctionType type, ArrayAttr tiedOperands, |
| ArrayRef<NamedAttribute> attrs, |
| ArrayRef<DictionaryAttr> argAttrs, |
| ArrayRef<DictionaryAttr> resAttrs) { |
| state.addAttribute(SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(name)); |
| state.addAttribute(SymbolTable::getVisibilityAttrName(), |
| builder.getStringAttr("public")); |
| state.addAttribute("function_type", TypeAttr::get(type)); |
| state.attributes.append(attrs.begin(), attrs.end()); |
| state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName()); |
| if (tiedOperands) { |
| state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(), |
| tiedOperands); |
| } |
| state.addRegion(); |
| if (!argAttrs.empty() || !resAttrs.empty()) { |
| assert(type.getNumInputs() == argAttrs.size()); |
| assert(type.getNumResults() == resAttrs.size()); |
| call_interface_impl::addArgAndResultAttrs( |
| builder, state, argAttrs, resAttrs, builder.getStringAttr("arg_attrs"), |
| builder.getStringAttr("res_attrs")); |
| } |
| } |
| |
| static ParseResult |
| parseFunctionArgumentList(OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::Argument> &arguments) { |
| return parser.parseCommaSeparatedList( |
| OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { |
| OpAsmParser::Argument argument; |
| auto argPresent = parser.parseOptionalArgument( |
| argument, /*allowType=*/true, /*allowAttrs=*/true); |
| if (argPresent.has_value()) { |
| if (failed(argPresent.value())) { |
| return failure(); // Present but malformed. |
| } |
| if (!arguments.empty() && arguments.back().ssaName.name.empty()) { |
| return parser.emitError(argument.ssaName.location, |
| "expected type instead of SSA identifier"); |
| } |
| |
| } else { |
| argument.ssaName.location = parser.getCurrentLocation(); |
| if (!arguments.empty() && !arguments.back().ssaName.name.empty()) { |
| return parser.emitError(argument.ssaName.location, |
| "expected SSA identifier"); |
| } |
| NamedAttrList attrs; |
| if (parser.parseType(argument.type) || |
| parser.parseOptionalAttrDict(attrs) || |
| parser.parseOptionalLocationSpecifier(argument.sourceLoc)) { |
| return failure(); |
| } |
| argument.attrs = attrs.getDictionary(parser.getContext()); |
| } |
| arguments.push_back(argument); |
| return success(); |
| }); |
| } |
| |
// Parses a util.func:
//   [visibility] @name(<argument-list>) [-> <result-list>]
//       [attributes {...}] [{ body }]
// Arguments are parsed by parseFunctionArgumentList. Results, which may be
// tied to arguments, are parsed by parseTiedFunctionResultList. The
// function_type attribute is built from the parsed argument and result types;
// tied_operands and arg_attrs/res_attrs also come from the signature syntax.
// The body region is optional, but when present must contain at least one
// block.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();

  StringAttr symVisibilityAttr;
  if (failed(parseSymbolVisibility(parser, symVisibilityAttr))) {
    return failure();
  }
  if (symVisibilityAttr) {
    result.addAttribute(SymbolTable::getVisibilityAttrName(),
                        symVisibilityAttr);
  }

  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes)) {
    return failure();
  }

  SmallVector<OpAsmParser::Argument> arguments;
  if (parseFunctionArgumentList(parser, arguments)) {
    return failure();
  }

  // Results are optional; without `->` the function has no results.
  SmallVector<Type> resultTypes;
  SmallVector<DictionaryAttr> resultAttrs;
  ArrayAttr tiedOperands;
  if (succeeded(parser.parseOptionalArrow())) {
    if (failed(parseTiedFunctionResultList(parser, arguments, resultTypes,
                                           resultAttrs, tiedOperands))) {
      return failure();
    }
  }
  if (tiedOperands) {
    result.addAttribute("tied_operands", tiedOperands);
  }

  SmallVector<Type> argumentTypes;
  for (auto argument : arguments) {
    argumentTypes.push_back(argument.type);
  }
  result.addAttribute("function_type", TypeAttr::get(builder.getFunctionType(
                                           argumentTypes, resultTypes)));

  NamedAttrList parsedAttributes;
  SMLoc attributeDictLocation = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) {
    return failure();
  }
  // Attributes derived from the syntax above may not be restated explicitly.
  for (StringRef disallowed : {
           SymbolTable::getVisibilityAttrName(),
           SymbolTable::getSymbolAttrName(),
           StringRef("function_type"),
       }) {
    if (parsedAttributes.get(disallowed)) {
      return parser.emitError(attributeDictLocation, "'")
             << disallowed
             << "' is an inferred attribute and should not be specified in the "
                "explicit attribute dictionary";
    }
  }
  result.attributes.append(parsedAttributes);

  assert(resultAttrs.size() == resultTypes.size());
  call_interface_impl::addArgAndResultAttrs(
      builder, result, arguments, resultAttrs,
      builder.getStringAttr("arg_attrs"), builder.getStringAttr("res_attrs"));

  // The body is optional. Arguments are bound as the entry block arguments
  // when it is present.
  auto *body = result.addRegion();
  SMLoc loc = parser.getCurrentLocation();
  auto parseResult = parser.parseOptionalRegion(*body, arguments,
                                                /*enableNameShadowing=*/false);
  if (parseResult.has_value()) {
    if (failed(*parseResult)) {
      return failure();
    }
    if (body->empty()) {
      return parser.emitError(loc, "expected non-empty function body");
    }
  }
  return success();
}
| |
| void FuncOp::print(OpAsmPrinter &p) { |
| p << ' '; |
| printSymbolVisibility(p, *this, getSymVisibilityAttr()); |
| p << ' '; |
| p.printSymbolName(getSymName()); |
| printShapedFunctionSignature(p, *this, getFunctionTypeAttr(), |
| getTiedOperandsAttr(), getArgAttrsAttr(), |
| getResAttrsAttr()); |
| p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), |
| /*elidedAttrs=*/{ |
| "sym_name", |
| "function_type", |
| "tied_operands", |
| "sym_visibility", |
| "arg_attrs", |
| "res_attrs", |
| }); |
| if (!getBody().empty()) { |
| p << ' '; |
| p.printRegion(getBody(), /*printEntryBlockArgs=*/false); |
| } |
| } |
| |
| bool IREE::Util::FuncOp::canDiscardOnUseEmpty() { |
| return getVisibility() != SymbolTable::Visibility::Public && |
| !anyAncestorHasAnyRefsToSymbol(this->getOperation(), getSymNameAttr()); |
| } |
| |
| bool IREE::Util::FuncOp::hasAnyTiedOperands() { |
| auto tiedOperandsAttr = getTiedOperandsAttr(); |
| if (!tiedOperandsAttr) { |
| return false; |
| } |
| return llvm::any_of( |
| tiedOperandsAttr.getAsRange<IntegerAttr>(), [](IntegerAttr attr) { |
| return attr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex; |
| }); |
| } |
| |
| void IREE::Util::FuncOp::expandSignature( |
| std::function<void(unsigned, Type, SmallVectorImpl<Type> &)> expandArgument, |
| std::function<void(unsigned, Type, SmallVectorImpl<Type> &)> expandResult) { |
| auto oldType = getFunctionType(); |
| |
| SmallVector<DictionaryAttr> oldArgumentAttrs; |
| getAllArgAttrs(oldArgumentAttrs); |
| SmallVector<DictionaryAttr> oldResultAttrs; |
| getAllResultAttrs(oldResultAttrs); |
| |
| SmallVector<int64_t> adjustedTiedOperands; |
| IREE::Util::detail::getAllTiedOperands(getOperation(), adjustedTiedOperands); |
| |
| SmallVector<Type> newArgumentTypes; |
| SmallVector<DictionaryAttr> newArgumentAttrs; |
| for (auto [oldIndex, argType] : llvm::enumerate(oldType.getInputs())) { |
| size_t newIndex = newArgumentTypes.size(); |
| expandArgument(oldIndex, argType, newArgumentTypes); |
| size_t expandedCount = newArgumentTypes.size() - newIndex; |
| if (expandedCount == 0) { |
| continue; |
| } |
| for (size_t i = 0; i < adjustedTiedOperands.size(); ++i) { |
| if (adjustedTiedOperands[i] == oldIndex) { |
| adjustedTiedOperands[i] = newIndex; |
| } |
| } |
| newArgumentAttrs.push_back(oldArgumentAttrs[oldIndex]); |
| newArgumentAttrs.append(expandedCount - 1, |
| DictionaryAttr::get(getContext())); |
| } |
| |
| SmallVector<Type> newResultTypes; |
| SmallVector<int64_t> newTiedOperands; |
| SmallVector<DictionaryAttr> newResultAttrs; |
| for (auto [oldIndex, resultType] : llvm::enumerate(oldType.getResults())) { |
| size_t newIndex = newResultTypes.size(); |
| expandResult(oldIndex, resultType, newResultTypes); |
| size_t expandedCount = newResultTypes.size() - newIndex; |
| if (expandedCount == 0) { |
| continue; |
| } |
| newTiedOperands.push_back(adjustedTiedOperands[oldIndex]); |
| newTiedOperands.append(expandedCount - 1, |
| IREE::Util::TiedOpInterface::kUntiedIndex); |
| newResultAttrs.push_back(oldResultAttrs[oldIndex]); |
| newResultAttrs.append(expandedCount - 1, DictionaryAttr::get(getContext())); |
| } |
| |
| auto newType = |
| FunctionType::get(getContext(), newArgumentTypes, newResultTypes); |
| if (newType != oldType) { |
| setFunctionType(newType); |
| setTiedOperandsAttr(ArrayAttr::get( |
| getContext(), |
| llvm::map_to_vector<8>(newTiedOperands, [&](int64_t v) -> Attribute { |
| return IntegerAttr::get(IndexType::get(getContext()), v); |
| }))); |
| setAllArgAttrs(newArgumentAttrs); |
| setAllResultAttrs(newResultAttrs); |
| } |
| } |
| |
| LogicalResult FuncOp::verify() { |
| // Get the tied_operands attribute if it exists |
| if (auto tiedOperandsAttr = getTiedOperandsAttr()) { |
| // Each index must be valid |
| unsigned numOperands = getFunctionType().getNumInputs(); |
| for (auto [resultIdx, attr] : llvm::enumerate(tiedOperandsAttr)) { |
| int64_t operandIdx = cast<IntegerAttr>(attr).getInt(); |
| |
| // Allow -1 (untied) |
| if (operandIdx == IREE::Util::TiedOpInterface::kUntiedIndex) { |
| continue; |
| } |
| |
| // Check if operand index is in valid range |
| if (operandIdx < 0 || operandIdx >= static_cast<int64_t>(numOperands)) { |
| return emitOpError() << "result #" << resultIdx |
| << " tied to invalid operand index " << operandIdx; |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.call |
| //===----------------------------------------------------------------------===// |
| |
| FunctionType CallOp::getCalleeType() { |
| return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
| } |
| |
| static bool areTiedOperandsEqual(ArrayAttr a, ArrayAttr b) { |
| auto hasAnyTied = [](ArrayAttr tiedOperandsAttr) { |
| if (!tiedOperandsAttr) { |
| return false; |
| } |
| return llvm::any_of( |
| tiedOperandsAttr.getAsRange<IntegerAttr>(), [](IntegerAttr attr) { |
| return attr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex; |
| }); |
| }; |
| bool hasAnyTiedA = hasAnyTied(a); |
| bool hasAnyTiedB = hasAnyTied(b); |
| if (hasAnyTiedA != hasAnyTiedB) { |
| return false; |
| } |
| if (!a || !b) { |
| return true; |
| } |
| return a == b; |
| } |
| |
| LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| Operation *op = getOperation(); |
| |
| // Only support calls to util.func. |
| auto calleeOp = symbolTable.lookupNearestSymbolFrom<IREE::Util::FuncOp>( |
| op, getCalleeAttr()); |
| if (!calleeOp) { |
| return op->emitOpError("undefined/incompatible callee: ") << getCallee(); |
| } |
| |
| // Ensure that the arg/result types match. |
| auto expectedType = getCalleeType(); |
| auto calleeType = calleeOp.getFunctionType(); |
| if (calleeType != expectedType) { |
| return emitOpError("function type mismatch; expected ") |
| << expectedType << " but callee is " << calleeType; |
| } |
| |
| // Ensure tied operands are consistent. |
| auto callerTiedOperands = getTiedOperandsAttr(); |
| auto calleeTiedOperands = calleeOp.getTiedOperandsAttr(); |
| if (!areTiedOperandsEqual(calleeTiedOperands, callerTiedOperands)) { |
| return emitOpError("function tied operands mismatch; have ") |
| << callerTiedOperands << " but callee is " << calleeTiedOperands; |
| } |
| |
| return success(); |
| } |
| |
| IREE::Util::CallOp IREE::Util::CallOp::cloneAndExpand( |
| std::function<void(unsigned, Value, SmallVectorImpl<Value> &)> |
| expandOperand, |
| std::function<void(unsigned, Type, SmallVectorImpl<Type> &)> expandResult, |
| OpBuilder &builder) { |
| SmallVector<int64_t> adjustedTiedOperands; |
| IREE::Util::detail::getAllTiedOperands(getOperation(), adjustedTiedOperands); |
| |
| SmallVector<Value> newOperands; |
| for (auto [oldIndex, operand] : llvm::enumerate(getOperands())) { |
| size_t newIndex = newOperands.size(); |
| expandOperand(oldIndex, operand, newOperands); |
| for (size_t i = 0; i < adjustedTiedOperands.size(); ++i) { |
| if (adjustedTiedOperands[i] == oldIndex) { |
| adjustedTiedOperands[i] = newIndex; |
| } |
| } |
| } |
| |
| SmallVector<Type> newResultTypes; |
| SmallVector<int64_t> newTiedOperands; |
| for (auto [oldIndex, resultType] : llvm::enumerate(getResultTypes())) { |
| size_t newIndex = newResultTypes.size(); |
| expandResult(oldIndex, resultType, newResultTypes); |
| size_t expandedCount = newResultTypes.size() - newIndex; |
| newTiedOperands.push_back(adjustedTiedOperands[oldIndex]); |
| newTiedOperands.append(expandedCount - 1, |
| IREE::Util::TiedOpInterface::kUntiedIndex); |
| } |
| |
| return IREE::Util::CallOp::create(builder, getLoc(), newResultTypes, |
| getCallee(), newOperands, |
| builder.getIndexArrayAttr(newTiedOperands), |
| getArgAttrsAttr(), getResAttrsAttr()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.return |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReturnOp::verify() { |
| Operation *op = getOperation(); |
| auto parentOp = cast<mlir::FunctionOpInterface>(op->getParentOp()); |
| auto expectedTypes = parentOp.getResultTypes(); |
| if (getNumOperands() != expectedTypes.size()) { |
| return emitOpError("has ") |
| << getNumOperands() |
| << " operands, but enclosing function-like op returns " |
| << expectedTypes.size(); |
| } |
| for (auto [i, expectedType, actualType] : |
| llvm::enumerate(expectedTypes, getOperandTypes())) { |
| if (expectedType != actualType) { |
| return emitOpError() << "type of return operand " << i << " (" |
| << actualType |
| << ") doesn't match function result type (" |
| << expectedType << ")"; |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.global |
| //===----------------------------------------------------------------------===// |
| |
| // TODO(benvanik): move entirely to the interface. |
| // Returns true if the given |accessType| is compatible with the |globalType|. |
| // For example, this will return true if the global type is a tensor<?xf32> |
| // and the access is tensor<4xf32>. |
| static bool isGlobalTypeCompatible(Type globalType, Type accessType) { |
| // If one is a shaped type, then they both must be and have compatible |
| // shapes. |
| if (isa<ShapedType>(globalType) && isa<ShapedType>(accessType)) { |
| return succeeded(mlir::verifyCompatibleShape(globalType, accessType)); |
| } |
| |
| if (auto knownType = dyn_cast<GlobalTypeInterface>(globalType)) { |
| return knownType.isAccessStorageCompatible(accessType); |
| } |
| |
| // Otherwise, the types must be the same. |
| return globalType == accessType; |
| } |
| |
| void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
| bool isMutable, Type type, |
| std::optional<TypedAttr> initialValue, |
| ArrayRef<NamedAttribute> attrs) { |
| result.addAttribute(SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(name)); |
| if (isMutable) { |
| result.addAttribute("is_mutable", builder.getUnitAttr()); |
| } |
| if (initialValue.has_value()) { |
| result.addAttribute("initial_value", initialValue.value()); |
| } |
| result.addAttribute("type", TypeAttr::get(type)); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
| bool isMutable, Type type, |
| ArrayRef<NamedAttribute> attrs) { |
| build(builder, result, name, isMutable, type, std::nullopt, attrs); |
| } |
| |
| // This is a workaround for SymbolDCE not handling attribute references on |
| // unnamed modules. This, for example, will fail and @foo will be DCEd: |
| // builtin.module attributes { some.attr = @foo } { |
| // util.global private @foo : i64 |
| // } |
| // While this succeeds: |
| // builtin.module @module attributes { some.attr = @module::@foo } { |
| // util.global private @foo : i64 |
| // } |
| // Since nearly all modules we see are anonymous we'll commonly end up with |
| // attributes that need to reference nested symbols via anonymous modules. |
| // |
| // During DCE this is called and for each symbol we want to preserve we then |
| // walk up to the module and see if it has any attributes referencing it to |
| // prevent the DCE. |
| bool GlobalOp::canDiscardOnUseEmpty() { |
| return getVisibility() != SymbolTable::Visibility::Public && |
| !anyAncestorHasAnyRefsToSymbol(this->getOperation(), getSymNameAttr()); |
| } |
| |
| IREE::Util::GlobalLoadOpInterface GlobalOp::createLoadOp(Location loc, |
| OpBuilder &builder) { |
| // TODO(benvanik): create with the immutable flag if the global is immutable. |
| // Today we avoid this and let analysis add the immutable flag when safe |
| // (not in initializers/etc). |
| return IREE::Util::GlobalLoadOp::create(builder, loc, getType(), |
| getSymName()); |
| } |
| |
| IREE::Util::GlobalStoreOpInterface |
| GlobalOp::createStoreOp(Location loc, Value value, OpBuilder &builder) { |
| return IREE::Util::GlobalStoreOp::create(builder, loc, value, getSymName()); |
| } |
| |
| void GlobalAddressOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), Twine("ptr_" + getGlobal()).str()); |
| } |
| |
| void GlobalLoadOp::build(OpBuilder &builder, OperationState &state, |
| IREE::Util::GlobalOpInterface globalOp, |
| ArrayRef<NamedAttribute> attrs) { |
| state.addTypes({globalOp.getGlobalType()}); |
| state.addAttribute("global", SymbolRefAttr::get(globalOp)); |
| state.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| void GlobalLoadOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), getGlobal()); |
| } |
| |
| void GlobalLoadOp::getEffects( |
| SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| // HACK: mlir doesn't have symbol side effects so we have to mark as a global |
| // read if not immutable and not in an initializer. |
| if (!isGlobalImmutable()) { |
| effects.emplace_back(MemoryEffects::Read::get()); |
| } |
| } |
| |
| LogicalResult |
| verifyGlobalLoadIndirectOp(IREE::Util::GlobalLoadIndirectOpInterface op) { |
| auto globalType = |
| cast<IREE::Util::PtrType>(op.getGlobal().getType()).getTargetType(); |
| auto loadType = op.getLoadedGlobalValue().getType(); |
| if (!isGlobalTypeCompatible(globalType, loadType)) { |
| return op->emitOpError() << "global type mismatch; global pointer is " |
| << globalType << " but load is " << loadType; |
| } |
| return success(); |
| } |
| |
| LogicalResult GlobalLoadIndirectOp::verify() { |
| return verifyGlobalLoadIndirectOp(*this); |
| } |
| |
| void GlobalStoreOp::build(OpBuilder &builder, OperationState &state, |
| Value value, IREE::Util::GlobalOpInterface globalOp, |
| ArrayRef<NamedAttribute> attrs) { |
| state.addOperands({value}); |
| state.addAttribute("global", SymbolRefAttr::get(globalOp)); |
| state.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| LogicalResult GlobalStoreIndirectOp::verify() { |
| Operation *op = getOperation(); |
| auto globalType = |
| cast<IREE::Util::PtrType>(getGlobal().getType()).getTargetType(); |
| auto storeType = getValue().getType(); |
| if (!isGlobalTypeCompatible(globalType, storeType)) { |
| return op->emitOpError() << "global type mismatch; global pointer is " |
| << globalType << " but store is " << storeType; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // !util.list<T> |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult |
| parseValueTypeList(OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
| SmallVectorImpl<Type> &types) { |
| if (parser.parseLSquare()) { |
| return failure(); |
| } |
| if (succeeded(parser.parseOptionalRSquare())) { |
| return success(); // empty list |
| } |
| do { |
| OpAsmParser::UnresolvedOperand value; |
| Type type; |
| if (parser.parseOperand(value) || parser.parseColon() || |
| parser.parseType(type)) { |
| return failure(); |
| } |
| values.push_back(value); |
| types.push_back(type); |
| } while (succeeded(parser.parseOptionalComma())); |
| return parser.parseRSquare(); |
| } |
| |
| static void printValueTypeList(OpAsmPrinter &p, Operation *, |
| OperandRange values, TypeRange types) { |
| p << "["; |
| llvm::interleaveComma(llvm::zip(values, types), p, [&](auto pair) { |
| p << std::get<0>(pair) << " : " << std::get<1>(pair); |
| }); |
| p << "]"; |
| } |
| |
| static ParseResult parseListTypeGet(OpAsmParser &parser, Type &listType, |
| Type &elementType) { |
| if (failed(parser.parseType(listType))) { |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected !util.list<T> type"); |
| } |
| auto listElementType = cast<ListType>(listType).getElementType(); |
| if (succeeded(parser.parseOptionalArrow())) { |
| // Use overridden type - required for variants only. |
| if (failed(parser.parseType(elementType))) { |
| return parser.emitError( |
| parser.getCurrentLocation(), |
| "expected an element type when specifying list access types"); |
| } |
| if (!ListType::canImplicitlyCast(listElementType, elementType)) { |
| return parser.emitError( |
| parser.getCurrentLocation(), |
| "list access types must match the same base type as the list element " |
| "type (when not variant)"); |
| } |
| } else { |
| // Use list element type as the result element type. |
| elementType = listElementType; |
| } |
| return success(); |
| } |
| |
| static void printListTypeGet(OpAsmPrinter &printer, Operation *, Type listType, |
| Type elementType) { |
| printer.printType(listType); |
| auto listElementType = cast<ListType>(listType).getElementType(); |
| if (listElementType != elementType) { |
| printer.printArrowTypeList(ArrayRef<Type>{elementType}); |
| } |
| } |
| |
| static ParseResult parseListTypeSet(OpAsmParser &parser, Type &listType, |
| Type &elementType) { |
| Type leadingType; |
| if (failed(parser.parseType(leadingType))) { |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected element type or !util.list<T> type"); |
| } |
| if (succeeded(parser.parseOptionalArrow())) { |
| elementType = leadingType; |
| if (failed(parser.parseType(listType)) || !isa<ListType>(listType)) { |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected an !util.list<T> type"); |
| } |
| } else { |
| if (!isa<ListType>(leadingType)) { |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected an !util.list<T> type"); |
| } |
| listType = leadingType; |
| elementType = cast<ListType>(listType).getElementType(); |
| } |
| return success(); |
| } |
| |
| static void printListTypeSet(OpAsmPrinter &printer, Operation *, Type listType, |
| Type elementType) { |
| auto listElementType = cast<ListType>(listType).getElementType(); |
| if (listElementType != elementType) { |
| printer.printType(elementType); |
| printer.printArrowTypeList(ArrayRef<Type>{listType}); |
| } else { |
| printer.printType(listType); |
| } |
| } |
| |
| LogicalResult ListConstructOp::verify() { |
| Operation *op = getOperation(); |
| auto listType = cast<IREE::Util::ListType>(getResult().getType()); |
| Type elementType = listType.getElementType(); |
| for (auto [idx, value] : llvm::enumerate(getValues())) { |
| Type valueType = value.getType(); |
| if (!ListType::canImplicitlyCast(valueType, elementType)) { |
| return op->emitError() |
| << "list[" << idx << "] type " << valueType |
| << " cannot be be cast to list type " << elementType; |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult ListGetOp::verify() { |
| Operation *op = getOperation(); |
| auto listType = cast<IREE::Util::ListType>(getList().getType()); |
| auto elementType = listType.getElementType(); |
| auto resultType = getResult().getType(); |
| if (!ListType::canImplicitlyCast(elementType, resultType)) { |
| return op->emitError() << "list contains " << elementType |
| << " and cannot be accessed as " << resultType; |
| } |
| return success(); |
| } |
| |
| LogicalResult ListSetOp::verify() { |
| Operation *op = getOperation(); |
| auto listType = cast<IREE::Util::ListType>(getList().getType()); |
| auto elementType = listType.getElementType(); |
| auto valueType = getValue().getType(); |
| if (!ListType::canImplicitlyCast(valueType, elementType)) { |
| return op->emitError() << "list contains " << elementType |
| << " and cannot be mutated as " << valueType; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // !util.buffer |
| //===----------------------------------------------------------------------===// |
| |
| void BufferConstantOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), getName().value_or("buffer_cst")); |
| } |
| |
| void BufferConstantOp::build(OpBuilder &builder, OperationState &state, |
| Attribute value) { |
| state.addTypes({builder.getType<IREE::Util::BufferType>()}); |
| state.addAttribute("value", value); |
| } |
| |
| void BufferConstantOp::build(OpBuilder &builder, OperationState &state, |
| StringRef value) { |
| state.addTypes({builder.getType<IREE::Util::BufferType>()}); |
| state.addAttribute("value", builder.getStringAttr(value)); |
| } |
| |
| void BufferConstantOp::build(OpBuilder &builder, OperationState &state, |
| ArrayRef<uint8_t> value) { |
| state.addTypes({builder.getType<IREE::Util::BufferType>()}); |
| state.addAttribute("value", |
| DenseIntElementsAttr::get( |
| VectorType::get(static_cast<int64_t>(value.size()), |
| builder.getI8Type()), |
| value)); |
| } |
| |
| // static |
| Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, |
| Attribute value) { |
| if (!value) { |
| auto bufferType = builder.getType<IREE::Util::BufferType>(); |
| return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); |
| } |
| return IREE::Util::BufferConstantOp::create(builder, loc, value); |
| } |
| |
| // static |
| Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, |
| StringRef value) { |
| if (value.empty()) { |
| auto bufferType = builder.getType<IREE::Util::BufferType>(); |
| return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); |
| } |
| return IREE::Util::BufferConstantOp::create(builder, loc, value); |
| } |
| |
| // static |
| Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, |
| ArrayRef<uint8_t> value) { |
| if (value.empty()) { |
| auto bufferType = builder.getType<IREE::Util::BufferType>(); |
| return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); |
| } |
| return IREE::Util::BufferConstantOp::create(builder, loc, value); |
| } |
| |
| LogicalResult BufferConstantOp::verify() { |
| if (!isa<IREE::Util::SerializableAttrInterface>(getValue())) { |
| return emitOpError("unsupported non-serializable constant attribute type"); |
| } |
| if (auto minAlignmentAttr = getAlignmentAttr()) { |
| int64_t minAlignment = minAlignmentAttr.getInt(); |
| if (minAlignment > 0 && !llvm::isPowerOf2_64(minAlignment)) { |
| return emitOpError("invalid alignment; must be a power of two"); |
| } |
| } |
| return success(); |
| } |
| |
| void BufferAllocOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer"); |
| } |
| |
| LogicalResult BufferAllocOp::verify() { |
| if (auto minAlignmentAttr = getAlignmentAttr()) { |
| int64_t minAlignment = minAlignmentAttr.getInt(); |
| if (minAlignment > 0 && !llvm::isPowerOf2_64(minAlignment)) { |
| return emitOpError("invalid alignment; must be a power of two"); |
| } |
| } |
| return success(); |
| } |
| |
| void BufferSliceOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer"); |
| } |
| |
| SubrangeOperand BufferSliceOp::getSubrangeOperand(unsigned operandIndex) { |
| if (operandIndex == 0) { |
| return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(), |
| getResultSize()}; |
| } else { |
| assert(false && "only source is a subrange"); |
| return {}; |
| } |
| } |
| |
| void BufferSliceOp::setSubrangeOperand(unsigned operandIndex, |
| SubrangeOperand operand) { |
| assert(operandIndex == 0 && "only source is a subrange"); |
| getSourceMutable().assign(operand.resource); |
| getSourceSizeMutable().assign(operand.resourceSize); |
| getSourceOffsetMutable().assign(operand.offset); |
| getResultSizeMutable().assign(operand.length); |
| } |
| |
| void BufferSubspanOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer_span"); |
| } |
| |
| Value BufferSubspanOp::getViewSource() { return getSource(); } |
| |
| Value BufferSubspanOp::getTiedResult(unsigned resultIndex) { |
| return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); |
| } |
| |
| SubrangeOperand BufferSubspanOp::getSubrangeOperand(unsigned operandIndex) { |
| if (operandIndex == 0) { |
| return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(), |
| getResultSize()}; |
| } else { |
| assert(false && "only source is a subrange"); |
| return {}; |
| } |
| } |
| |
| void BufferSubspanOp::setSubrangeOperand(unsigned operandIndex, |
| SubrangeOperand operand) { |
| assert(operandIndex == 0 && "only source is a subrange"); |
| getSourceMutable().assign(operand.resource); |
| getSourceSizeMutable().assign(operand.resourceSize); |
| getSourceOffsetMutable().assign(operand.offset); |
| getResultSizeMutable().assign(operand.length); |
| } |
| |
| ::std::optional<unsigned> |
| BufferSubspanOp::getTiedResultOperandIndex(unsigned resultIndex) { |
| return {0}; // source |
| } |
| |
| SmallVector<int64_t> BufferSubspanOp::getTiedResultOperandIndices() { |
| return {0}; // source |
| } |
| |
| // static |
| IREE::Util::BufferSubspanOp BufferSubspanOp::findSubspanOp(Value value) { |
| while (value) { |
| auto *definingOp = value.getDefiningOp(); |
| if (!definingOp) { |
| // Defined as a block argument - stop walk. |
| break; |
| } else if (auto subviewOp = |
| dyn_cast<IREE::Util::BufferSubspanOp>(definingOp)) { |
| // Found! |
| return subviewOp; |
| } else if (auto tiedOp = |
| dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) { |
| // Continue walking up through the tied operand. |
| value = tiedOp.getTiedResultOperand(value); |
| } else { |
| break; |
| } |
| } |
| return {}; |
| } |
| |
| void BufferSizeOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer_size"); |
| } |
| |
| void BufferStorageOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer_storage"); |
| setNameFn(getOffset(), "buffer_offset"); |
| } |
| |
| SubrangeOperand BufferCopyOp::getSubrangeOperand(unsigned operandIndex) { |
| if (operandIndex == 0) { |
| return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(), |
| getLength()}; |
| } else if (operandIndex == 3) { |
| return SubrangeOperand{getTarget(), getTargetSize(), getTargetOffset(), |
| getLength()}; |
| } else { |
| assert(false && "only source/target are subranges"); |
| return {}; |
| } |
| } |
| |
| void BufferCopyOp::setSubrangeOperand(unsigned operandIndex, |
| SubrangeOperand operand) { |
| if (operandIndex == 0) { |
| getSourceMutable().assign(operand.resource); |
| getSourceSizeMutable().assign(operand.resourceSize); |
| getSourceOffsetMutable().assign(operand.offset); |
| getLengthMutable().assign(operand.length); |
| } else if (operandIndex == 3) { |
| getTargetMutable().assign(operand.resource); |
| getTargetSizeMutable().assign(operand.resourceSize); |
| getTargetOffsetMutable().assign(operand.offset); |
| getLengthMutable().assign(operand.length); |
| } else { |
| assert(false && "only source/target are subranges"); |
| } |
| } |
| |
| SubrangeOperand BufferCompareOp::getSubrangeOperand(unsigned operandIndex) { |
| if (operandIndex == 0) { |
| return SubrangeOperand{getLhs(), getLhsSize(), getLhsOffset(), getLength()}; |
| } else if (operandIndex == 3) { |
| return SubrangeOperand{getRhs(), getRhsSize(), getRhsOffset(), getLength()}; |
| } else { |
| assert(false && "only lhs/rhs are subranges"); |
| return {}; |
| } |
| } |
| |
| void BufferCompareOp::setSubrangeOperand(unsigned operandIndex, |
| SubrangeOperand operand) { |
| if (operandIndex == 0) { |
| getLhsMutable().assign(operand.resource); |
| getLhsSizeMutable().assign(operand.resourceSize); |
| getLhsOffsetMutable().assign(operand.offset); |
| getLengthMutable().assign(operand.length); |
| } else if (operandIndex == 3) { |
| getRhsMutable().assign(operand.resource); |
| getRhsSizeMutable().assign(operand.resourceSize); |
| getRhsOffsetMutable().assign(operand.offset); |
| getLengthMutable().assign(operand.length); |
| } else { |
| assert(false && "only lhs/rhs are subranges"); |
| } |
| } |
| |
| SubrangeOperand BufferFillOp::getSubrangeOperand(unsigned operandIndex) { |
| if (operandIndex == 1) { |
| return SubrangeOperand{getTarget(), getTargetSize(), getTargetOffset(), |
| getLength()}; |
| } else { |
| assert(false && "only target is a subrange"); |
| return {}; |
| } |
| } |
| |
| void BufferFillOp::setSubrangeOperand(unsigned operandIndex, |
| SubrangeOperand operand) { |
| assert(operandIndex == 1 && "only target is a subrange"); |
| getTargetMutable().assign(operand.resource); |
| getTargetSizeMutable().assign(operand.resourceSize); |
| getTargetOffsetMutable().assign(operand.offset); |
| getLengthMutable().assign(operand.length); |
| } |
| |
| SubrangeOperand BufferLoadOp::getSubrangeOperand(unsigned operandIndex) { |
| if (operandIndex == 0) { |
| return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(), |
| getLength()}; |
| } else { |
| assert(false && "only source is a subrange"); |
| return {}; |
| } |
| } |
| |
| void BufferLoadOp::setSubrangeOperand(unsigned operandIndex, |
| SubrangeOperand operand) { |
| assert(operandIndex == 0 && "only source is a subrange"); |
| getSourceMutable().assign(operand.resource); |
| getSourceSizeMutable().assign(operand.resourceSize); |
| getSourceOffsetMutable().assign(operand.offset); |
| getLengthMutable().assign(operand.length); |
| } |
| |
| SubrangeOperand BufferStoreOp::getSubrangeOperand(unsigned operandIndex) { |
| if (operandIndex == 1) { |
| return SubrangeOperand{getTarget(), getTargetSize(), getTargetOffset(), |
| getLength()}; |
| } else { |
| assert(false && "only target is a subrange"); |
| return {}; |
| } |
| } |
| |
| void BufferStoreOp::setSubrangeOperand(unsigned operandIndex, |
| SubrangeOperand operand) { |
| assert(operandIndex == 1 && "only target is a subrange"); |
| getTargetMutable().assign(operand.resource); |
| getTargetSizeMutable().assign(operand.resourceSize); |
| getTargetOffsetMutable().assign(operand.offset); |
| getLengthMutable().assign(operand.length); |
| } |
| |
| SubrangeOperand BufferHashOp::getSubrangeOperand(unsigned operandIndex) { |
| if (operandIndex == 0) { |
| return SubrangeOperand{getSource(), getSourceSize(), getSourceOffset(), |
| getLength()}; |
| } else { |
| assert(false && "only source is a subrange"); |
| return {}; |
| } |
| } |
| |
| void BufferHashOp::setSubrangeOperand(unsigned operandIndex, |
| SubrangeOperand operand) { |
| assert(operandIndex == 0 && "only source is a subrange"); |
| getSourceMutable().assign(operand.resource); |
| getSourceSizeMutable().assign(operand.resourceSize); |
| getSourceOffsetMutable().assign(operand.offset); |
| getLengthMutable().assign(operand.length); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.scf.unreachable |
| //===----------------------------------------------------------------------===// |
| |
| // static |
| SmallVector<Value> SCFUnreachableOp::createPoisonValues(OpBuilder &builder, |
| Location loc, |
| TypeRange resultTypes) { |
| SmallVector<Value> poisonValues; |
| for (Type type : resultTypes) { |
| poisonValues.push_back( |
| mlir::ub::PoisonOp::create(builder, loc, type, nullptr).getResult()); |
| } |
| return poisonValues; |
| } |
| |
| // static |
| scf::YieldOp SCFUnreachableOp::createRegionTerminator(OpBuilder &builder, |
| Location loc, |
| TypeRange resultTypes, |
| StringAttr message) { |
| SCFUnreachableOp::create(builder, loc, message); |
| return mlir::scf::YieldOp::create( |
| builder, loc, createPoisonValues(builder, loc, resultTypes)); |
| } |
| |
| // static |
| scf::YieldOp SCFUnreachableOp::createRegionTerminator(OpBuilder &builder, |
| Location loc, |
| TypeRange resultTypes, |
| StringRef message) { |
| return createRegionTerminator( |
| builder, loc, resultTypes, |
| message.empty() ? StringAttr{} : builder.getStringAttr(message)); |
| } |
| |
| // static |
| Operation *SCFUnreachableOp::createWithTerminator(OpBuilder &builder, |
| Location loc, |
| TypeRange resultTypes, |
| StringAttr message) { |
| // Create scf.yield with poison values for SCF regions. |
| auto *parentOp = builder.getInsertionPoint()->getParentOp(); |
| if (parentOp && isa<scf::SCFDialect>(parentOp->getDialect())) { |
| return createRegionTerminator(builder, loc, resultTypes, message); |
| } |
| |
| // For non-SCF regions, convert to util.unreachable terminator. |
| return IREE::Util::UnreachableOp::create(builder, loc, message); |
| } |
| |
| } // namespace mlir::iree_compiler::IREE::Util |
| |
| #define GET_OP_CLASSES |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.cpp.inc" |