| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree-dialects/Dialect/Input/InputOps.h" |
| |
| #include "iree-dialects/Dialect/Input/InputDialect.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinAttributeInterfaces.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/TypeUtilities.h" |
| |
| using namespace mlir; |
| using namespace mlir::iree_compiler::IREE::Input; |
| |
| #include "iree-dialects/Dialect/Input/InputOpInterfaces.cpp.inc" |
| |
| //===----------------------------------------------------------------------===// |
| // IREE::Input::TiedOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir::iree_compiler::IREE::Input::detail { |
| |
| std::optional<unsigned> getTiedResultOperandIndex(Operation *op, |
| unsigned resultIndex) { |
| auto storageAttr = |
| op->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName()); |
| if (!storageAttr) |
| return std::nullopt; |
| auto valueAttrs = storageAttr.getValue(); |
| if (valueAttrs.empty()) |
| return std::nullopt; |
| if (auto tiedOp = dyn_cast<TiedOpInterface>(op)) { |
| auto indexAndLength = tiedOp.getTiedResultsIndexAndLength(); |
| if (resultIndex < indexAndLength.first) |
| return std::nullopt; |
| resultIndex -= indexAndLength.first; |
| if (resultIndex >= indexAndLength.second) |
| return std::nullopt; |
| } |
| int64_t value = llvm::cast<IntegerAttr>(valueAttrs[resultIndex]).getInt(); |
| if (value == TiedOpInterface::kUntiedIndex) |
| return std::nullopt; |
| if (auto tiedOp = dyn_cast<TiedOpInterface>(op)) { |
| unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; |
| return tiedOperandsOffset + static_cast<unsigned>(value); |
| } else { |
| return static_cast<unsigned>(value); |
| } |
| } |
| |
| SmallVector<int64_t> getTiedResultOperandIndices(Operation *op) { |
| SmallVector<int64_t> indices; |
| auto storageAttr = |
| op->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName()); |
| if (!storageAttr) |
| return indices; |
| auto valueAttrs = storageAttr.getValue(); |
| if (valueAttrs.empty()) |
| return indices; |
| auto tiedOp = cast<TiedOpInterface>(op); |
| auto resultRange = tiedOp.getTiedResultsIndexAndLength(); |
| unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; |
| indices.resize(resultRange.second); |
| for (unsigned i = 0; i < valueAttrs.size(); ++i) { |
| int64_t index = llvm::cast<IntegerAttr>(valueAttrs[i]).getInt(); |
| indices[i] = index != TiedOpInterface::kUntiedIndex |
| ? tiedOperandsOffset + index |
| : TiedOpInterface::kUntiedIndex; |
| } |
| return indices; |
| } |
| |
| void setTiedResultOperandIndex(Operation *op, unsigned resultIndex, |
| std::optional<unsigned> operandIndex) { |
| auto tiedOp = cast<TiedOpInterface>(op); |
| auto resultRange = tiedOp.getTiedResultsIndexAndLength(); |
| resultIndex -= resultRange.first; |
| |
| auto indices = getTiedResultOperandIndices(op); |
| if (indices.empty()) { |
| indices.resize(resultRange.second, TiedOpInterface::kUntiedIndex); |
| } else { |
| // Well, getTiedResultOperandIndices() returns indices into the full range |
| // of the op, but in the attribute, we expect to store ranges into the range |
| // returned by `getTiedOperandsIndexAndLength`. |
| unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; |
| for (auto &index : indices) { |
| if (index != TiedOpInterface::kUntiedIndex) |
| index -= tiedOperandsOffset; |
| } |
| } |
| |
| indices[resultIndex] = operandIndex.value_or(TiedOpInterface::kUntiedIndex); |
| op->setAttr(TiedOpInterface::getStorageAttrName(), |
| Builder(op).getIndexArrayAttr(indices)); |
| } |
| |
| bool isOperandTied(Operation *op, unsigned operandIndex) { |
| auto tiedOp = dyn_cast<TiedOpInterface>(op); |
| if (!tiedOp) |
| return false; |
| auto tiedIndices = tiedOp.getTiedResultOperandIndices(); |
| for (unsigned i = 0; i < tiedIndices.size(); ++i) { |
| if (tiedIndices[i] == operandIndex) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| SmallVector<Value> getOperandTiedResults(Operation *op, unsigned operandIndex) { |
| auto tiedOp = dyn_cast<TiedOpInterface>(op); |
| if (!tiedOp) |
| return {}; |
| auto resultRange = tiedOp.getTiedResultsIndexAndLength(); |
| SmallVector<Value> results; |
| auto tiedIndices = tiedOp.getTiedResultOperandIndices(); |
| for (unsigned i = 0; i < tiedIndices.size(); ++i) { |
| if (tiedIndices[i] == operandIndex) { |
| results.push_back(op->getResult(resultRange.first + i)); |
| } |
| } |
| return results; |
| } |
| |
| LogicalResult verifyTiedOp(TiedOpInterface tiedOp) { |
| auto tiedOperandIndices = tiedOp.getTiedResultOperandIndices(); |
| if (tiedOperandIndices.empty()) |
| return success(); |
| auto resultRange = tiedOp.getTiedResultsIndexAndLength(); |
| if (tiedOperandIndices.size() != resultRange.second) { |
| return tiedOp.emitError("op results/tied operand indices mismatch"); |
| } |
| return success(); |
| } |
| |
| } // namespace mlir::iree_compiler::IREE::Input::detail |
| |
| Value TiedOpInterface::findTiedBaseValue(Value derivedValue) { |
| Value baseValue = derivedValue; |
| while (auto definingOp = |
| dyn_cast_or_null<TiedOpInterface>(baseValue.getDefiningOp())) { |
| auto tiedValue = definingOp.getTiedResultOperand(baseValue); |
| if (!tiedValue) |
| break; |
| baseValue = tiedValue; |
| } |
| return baseValue; |
| } |
| |
| bool TiedOpInterface::hasAnyTiedUses(Value value) { |
| return llvm::any_of(value.getUses(), [](auto &use) { |
| if (auto tiedOp = dyn_cast<TiedOpInterface>(use.getOwner())) { |
| return tiedOp.isOperandTied(use.getOperandNumber()); |
| } |
| return false; |
| }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<SymbolVisibility>($sym_visibility) |
| //===----------------------------------------------------------------------===// |
| // some.op custom<SymbolVisibility>($sym_visibility) $sym_name |
| // -> |
| // some.op @foo |
| // some.op private @foo |
| |
| static ParseResult parseSymbolVisibility(OpAsmParser &parser, |
| StringAttr &symVisibilityAttr) { |
| StringRef symVisibility; |
| if (succeeded(parser.parseOptionalKeyword(&symVisibility, |
| {"public", "private", "nested"}))) { |
| symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); |
| } |
| return success(); |
| } |
| |
| static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, |
| StringAttr symVisibilityAttr) { |
| if (!symVisibilityAttr) { |
| p << "public"; |
| } else { |
| p << symVisibilityAttr.getValue(); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TypeOrAttr>($type, $attr) |
| //===----------------------------------------------------------------------===// |
| // some.op custom<TypeOrAttr>($type, $attr) |
| // -> |
| // some.op : i32 |
| // some.op = 42 : i32 |
| // some.op : i32 = 42 : index |
| |
| static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, |
| TypedAttr &attr) { |
| if (succeeded(parser.parseOptionalEqual())) { |
| if (failed(parser.parseAttribute(attr))) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "expected attribute"; |
| } |
| typeAttr = TypeAttr::get(attr.getType()); |
| return success(); |
| } |
| |
| Type type; |
| if (failed(parser.parseColonType(type))) { |
| return parser.emitError(parser.getCurrentLocation()) << "expected type"; |
| } |
| typeAttr = TypeAttr::get(type); |
| |
| if (succeeded(parser.parseOptionalEqual())) { |
| if (failed(parser.parseAttribute(attr))) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "expected attribute"; |
| } |
| } |
| |
| return success(); |
| } |
| |
| static void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, |
| TypedAttr attr) { |
| if (!attr || attr.getType() != type.getValue()) { |
| p << " : "; |
| p.printAttribute(type); |
| } |
| if (attr) { |
| p << " = "; |
| p.printAttribute(attr); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedTiedResult> |
| //===----------------------------------------------------------------------===// |
| // type{%dim0, %dim1} |
| // %arg0 as type{%dim0} |
| |
| static ParseResult parseShapedTiedResult( |
| OpAsmParser &parser, Type &resultType, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims, |
| ArrayAttr &tiedOperands) { |
| OpAsmParser::UnresolvedOperand tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| int64_t tiedOperandIndex = TiedOpInterface::kUntiedIndex; |
| if (res.has_value() && succeeded(res.value())) { |
| tiedOperandIndex = 0; |
| if (failed(parser.parseKeyword("as"))) |
| return failure(); |
| } |
| if (failed(parser.parseType(resultType))) |
| return failure(); |
| if (auto shapedType = dyn_cast<ShapedType>(resultType)) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::UnresolvedOperand> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.append(dynamicDims); |
| } |
| } |
| tiedOperands = parser.getBuilder().getIndexArrayAttr({tiedOperandIndex}); |
| return success(); |
| } |
| |
| static ParseResult parseShapedTiedResult( |
| OpAsmParser &parser, Type &resultType, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims) { |
| ArrayAttr tiedOperands; |
| return parseShapedTiedResult(parser, resultType, resultDims, tiedOperands); |
| } |
| |
| void printShapedTiedResult(OpAsmPrinter &p, TiedOpInterface op, Type resultType, |
| ValueRange resultDims) { |
| auto tiedOperandIndex = op.getTiedResultOperandIndex(0); |
| if (tiedOperandIndex.has_value()) { |
| auto tiedOperand = op->getOperand(tiedOperandIndex.value()); |
| p.printOperand(tiedOperand); |
| p << " as "; |
| } |
| p.printType(resultType); |
| if (auto shapedType = dyn_cast<ShapedType>(resultType)) { |
| if (!shapedType.hasStaticShape()) { |
| if (resultDims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma( |
| resultDims.take_front(shapedType.getNumDynamicDims()), p, |
| [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| resultDims = resultDims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } |
| } |
| |
| static void printShapedTiedResult(OpAsmPrinter &p, TiedOpInterface op, |
| Type resultType, ValueRange resultDims, |
| ArrayAttr tiedOperands) { |
| printShapedTiedResult(p, op, resultType, resultDims); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedOperandList>($values, type($values), $value_dims) |
| //===----------------------------------------------------------------------===// |
| // %value : type{%dynamic_dims}, ... |
| |
| ParseResult parseShapedOperandList( |
| OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
| SmallVectorImpl<Type> &valueTypes, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &valueDims) { |
| do { |
| values.emplace_back(); |
| valueTypes.emplace_back(); |
| if (failed(parser.parseOperand(values.back())) || |
| failed(parser.parseColon()) || |
| failed(parser.parseType(valueTypes.back()))) |
| return failure(); |
| if (int64_t dynamicDimCount = |
| cast<ShapedType>(valueTypes.back()).getNumDynamicDims()) { |
| if (failed(parser.parseOperandList(valueDims, dynamicDimCount, |
| AsmParser::Delimiter::Braces))) |
| return failure(); |
| } |
| } while (succeeded(parser.parseOptionalComma())); |
| return success(); |
| } |
| |
| void printShapedOperandList(OpAsmPrinter &p, Operation *op, ValueRange values, |
| TypeRange valueTypes, ValueRange valueDims) { |
| llvm::interleaveComma(llvm::zip_equal(values, valueTypes), p, [&](auto it) { |
| auto [value, valueType] = it; |
| p << value; |
| p << " : "; |
| p << valueType; |
| if (int64_t dynamicDimCount = |
| cast<ShapedType>(valueType).getNumDynamicDims()) { |
| p << "{"; |
| llvm::interleaveComma(valueDims.take_front(dynamicDimCount), p); |
| valueDims = valueDims.drop_front(dynamicDimCount); |
| p << "}"; |
| } |
| }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedFunctionType> |
| //===----------------------------------------------------------------------===// |
| // (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4) |
| |
| static ParseResult |
| parseShapedTypeList(OpAsmParser &parser, SmallVectorImpl<Type> &types, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dims) { |
| do { |
| Type type; |
| if (failed(parser.parseType(type))) |
| return failure(); |
| if (auto shapedType = dyn_cast<ShapedType>(type)) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::UnresolvedOperand> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| dims.append(dynamicDims); |
| } |
| } |
| types.push_back(type); |
| } while (succeeded(parser.parseOptionalComma())); |
| return success(); |
| } |
| |
| // Finds the operand index in |operands| that |tiedResult| references. |
| // Returns TiedOpInterface::kUntiedIndex if no operand is found. |
| static int64_t |
| findTiedOperand(OpAsmParser::UnresolvedOperand tiedResult, |
| ArrayRef<OpAsmParser::UnresolvedOperand> operands) { |
| int64_t operandIndex = TiedOpInterface::kUntiedIndex; |
| for (int64_t i = 0; i < operands.size(); ++i) { |
| if (operands[i].name == tiedResult.name && |
| operands[i].number == tiedResult.number) { |
| operandIndex = i; |
| break; |
| } |
| } |
| return operandIndex; |
| } |
| |
| static ParseResult parseShapedResultList( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::UnresolvedOperand> operands, |
| TypeRange operandTypes, |
| ArrayRef<OpAsmParser::UnresolvedOperand> operandDims, |
| SmallVectorImpl<Type> &resultTypes, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims, |
| ArrayAttr &tiedOperands) { |
| SmallVector<int64_t> tiedOperandIndices; |
| do { |
| OpAsmParser::UnresolvedOperand tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| Type type; |
| int64_t tiedOperandIndex = TiedOpInterface::kUntiedIndex; |
| if (res.has_value() && succeeded(res.value())) { |
| tiedOperandIndex = findTiedOperand(tiedResult, operands); |
| if (tiedOperandIndex == TiedOpInterface::kUntiedIndex) { |
| return parser.emitError(tiedResult.location, |
| "tied operand not found for result reference ") |
| << tiedResult.name; |
| } |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| // Type _may_ differ from the operand. |
| if (failed(parser.parseType(type))) |
| return failure(); |
| } else { |
| // Use the operands type. |
| type = operandTypes[tiedOperandIndex]; |
| } |
| } else if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| if (auto shapedType = dyn_cast<ShapedType>(type)) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::UnresolvedOperand> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.append(dynamicDims); |
| } |
| } |
| resultTypes.push_back(type); |
| tiedOperandIndices.push_back(tiedOperandIndex); |
| } while (succeeded(parser.parseOptionalComma())); |
| if (!tiedOperandIndices.empty()) { |
| tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices); |
| } |
| return success(); |
| } |
| |
| static void printShapedResultList(OpAsmPrinter &p, TiedOpInterface tiedOp, |
| ValueRange operands, TypeRange operandTypes, |
| ValueRange operandDims, TypeRange resultTypes, |
| ValueRange resultDims, |
| ArrayAttr tiedOperands) { |
| for (unsigned i = 0; i < resultTypes.size(); ++i) { |
| auto resultType = resultTypes[i]; |
| auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i); |
| bool printType = true; |
| if (tiedOperandIndex.has_value()) { |
| auto tiedOperand = tiedOp->getOperand(tiedOperandIndex.value()); |
| p.printOperand(tiedOperand); |
| if (tiedOperand.getType() != resultType) { |
| p << " as "; |
| } else { |
| // Type elided as it matches the operand. |
| printType = false; |
| } |
| } |
| if (printType) { |
| p.printType(resultType); |
| } |
| if (auto shapedType = dyn_cast<ShapedType>(resultType)) { |
| if (!shapedType.hasStaticShape()) { |
| if (resultDims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma( |
| resultDims.take_front(shapedType.getNumDynamicDims()), p, |
| [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| resultDims = resultDims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } |
| if (i < resultTypes.size() - 1) |
| p << ", "; |
| } |
| } |
| |
| static ParseResult parseShapedFunctionType( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::UnresolvedOperand> operands, |
| SmallVectorImpl<Type> &operandTypes, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandDims, |
| SmallVectorImpl<Type> &resultTypes, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims, |
| ArrayAttr &tiedOperands) { |
| if (failed(parser.parseLParen())) |
| return failure(); |
| if (failed(parser.parseOptionalRParen())) { |
| if (failed(parseShapedTypeList(parser, operandTypes, operandDims)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } |
| if (failed(parser.parseArrow())) |
| return failure(); |
| if (succeeded(parser.parseOptionalLParen())) { |
| if (succeeded(parser.parseOptionalRParen())) { |
| // Empty list/no results `()`. |
| } else { |
| // One or more result types. |
| if (failed(parseShapedResultList(parser, operands, operandTypes, |
| operandDims, resultTypes, resultDims, |
| tiedOperands)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } |
| } else { |
| // Single result with omitted `()`. |
| if (failed(parseShapedResultList(parser, operands, operandTypes, |
| operandDims, resultTypes, resultDims, |
| tiedOperands))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| static void printShapedFunctionType(OpAsmPrinter &p, TiedOpInterface tiedOp, |
| ValueRange operands, TypeRange operandTypes, |
| OperandRange operandDims, |
| TypeRange resultTypes, |
| OperandRange resultDims, |
| ArrayAttr tiedOperands) { |
| p << "("; |
| llvm::interleaveComma(operandTypes, p, [&](Type type) { |
| p.printType(type); |
| if (auto shapedType = dyn_cast<ShapedType>(type)) { |
| if (!shapedType.hasStaticShape()) { |
| if (operandDims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma( |
| operandDims.take_front(shapedType.getNumDynamicDims()), p, |
| [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| operandDims = operandDims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } |
| }); |
| p << ") -> "; |
| if (resultTypes.size() != 1) |
| p << "("; |
| printShapedResultList(p, tiedOp, operands, operandTypes, operandDims, |
| resultTypes, resultDims, tiedOperands); |
| if (resultTypes.size() != 1) |
| p << ")"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GlobalOp |
| //===----------------------------------------------------------------------===// |
| |
| void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
| bool isMutable, Type type, |
| std::optional<TypedAttr> initialValue, |
| ArrayRef<NamedAttribute> attrs) { |
| result.addAttribute(SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(name)); |
| if (isMutable) { |
| result.addAttribute("is_mutable", builder.getUnitAttr()); |
| } |
| if (initialValue.has_value()) { |
| result.addAttribute("initial_value", initialValue.value()); |
| } |
| result.addAttribute("type", TypeAttr::get(type)); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
| bool isMutable, Type type, |
| ArrayRef<NamedAttribute> attrs) { |
| build(builder, result, name, isMutable, type, std::nullopt, attrs); |
| } |
| |
| // Returns true if the given |accessType| is compatible with the |globalType|. |
| // For example, this will return true if the global type is a tensor<?xf32> |
| // and the access is tensor<4xf32>. |
| static bool isGlobalTypeCompatible(Type globalType, Type accessType) { |
| // If one is a shaped type, then they both must be and have compatible |
| // shapes. |
| if (isa<ShapedType>(globalType) && isa<ShapedType>(accessType)) { |
| return succeeded(mlir::verifyCompatibleShape(globalType, accessType)) && |
| cast<ShapedType>(globalType).getElementType() == |
| cast<ShapedType>(accessType).getElementType(); |
| } |
| |
| // Permissively allow any other types to be marked compatible as long as |
| // neither are shaped type. |
| return !isa<ShapedType>(globalType) && !isa<ShapedType>(accessType); |
| } |
| |
| LogicalResult |
| GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| auto globalOp = |
| symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getGlobalAttr()); |
| if (!globalOp) { |
| return emitOpError() << "undefined global: " << getGlobal(); |
| } |
| auto loadType = getResult().getType(); |
| if (!isGlobalTypeCompatible(globalOp.getType(), loadType)) { |
| return emitOpError() << "global type mismatch; global " << getGlobal() |
| << " is " << globalOp.getType() << " but load is " |
| << loadType; |
| } |
| return success(); |
| } |
| |
| LogicalResult GlobalLoadIndirectOp::verify() { |
| auto globalType = cast<PtrType>(getGlobal().getType()).getTargetType(); |
| auto loadType = getResult().getType(); |
| if (!isGlobalTypeCompatible(globalType, loadType)) { |
| return emitOpError() << "global type mismatch; global pointer is " |
| << globalType << " but load is " << loadType; |
| } |
| return success(); |
| } |
| |
| LogicalResult |
| GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| auto globalOp = |
| symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getGlobalAttr()); |
| if (!globalOp) { |
| return emitOpError() << "undefined global: " << getGlobal(); |
| } |
| auto storeType = getValue().getType(); |
| if (!isGlobalTypeCompatible(globalOp.getType(), storeType)) { |
| return emitOpError() << "global type mismatch; global " << getGlobal() |
| << " is " << globalOp.getType() << " but store is " |
| << storeType; |
| } |
| return success(); |
| } |
| |
| LogicalResult GlobalStoreIndirectOp::verify() { |
| auto globalType = cast<PtrType>(getGlobal().getType()).getTargetType(); |
| auto storeType = getValue().getType(); |
| if (!isGlobalTypeCompatible(globalType, storeType)) { |
| return emitOpError() << "global type mismatch; global pointer is " |
| << globalType << " but store is " << storeType; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// !iree_input.byte_buffer
| //===----------------------------------------------------------------------===// |
| |
| void ByteBufferConstantOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), getName().value_or("bytes_cst")); |
| } |
| |
| LogicalResult ByteBufferConstantOp::verify() { |
| if (auto minAlignmentAttr = getAlignmentAttr()) { |
| int64_t minAlignment = minAlignmentAttr.getInt(); |
| if (minAlignment > 0 && !llvm::isPowerOf2_64(minAlignment)) { |
| return emitOpError("invalid alignment; must be a power of two"); |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // iree_input.buffer.subspan |
| //===----------------------------------------------------------------------===// |
| |
| void BufferSubspanOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(getResult(), "buffer"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // iree_input.buffer_view.create |
| //===----------------------------------------------------------------------===// |
| |
| void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state, |
| Value sourceBuffer, Value sourceOffset, |
| Value sourceLength, int32_t elementType, |
| int32_t encodingType, ValueRange shape) { |
| build(builder, state, sourceBuffer, sourceOffset, sourceLength, |
| builder.createOrFold<arith::ConstantIntOp>(state.location, elementType, |
| 32), |
| builder.createOrFold<arith::ConstantIntOp>(state.location, encodingType, |
| 32), |
| shape); |
| } |
| |
| void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state, |
| Value sourceBuffer, Value sourceOffset, |
| Value sourceLength, Value elementType, |
| Value encodingType, ValueRange shape) { |
| state.addOperands( |
| {sourceBuffer, sourceOffset, sourceLength, elementType, encodingType}); |
| state.addOperands(shape); |
| state.addTypes({BufferViewType::get(builder.getContext())}); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // iree_input.tensor.update |
| //===----------------------------------------------------------------------===// |
| |
| Value TensorUpdateOp::getTiedResult(unsigned resultIndex) { |
| return TiedOpInterface::findTiedBaseValue(getTarget()); |
| } |
| |
| std::optional<unsigned> |
| TensorUpdateOp::getTiedResultOperandIndex(unsigned resultIndex) { |
| return {0}; // $target |
| } |
| |
| SmallVector<int64_t> TensorUpdateOp::getTiedResultOperandIndices() { |
| return {0}; // $target |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // iree_input.tensor.trace |
| //===----------------------------------------------------------------------===// |
| |
| void TensorTraceOp::build(OpBuilder &builder, OperationState &state, |
| StringRef key, ValueRange values) { |
| SmallVector<Value> dynamicDims; |
| for (auto value : values) { |
| auto valueType = cast<ShapedType>(value.getType()); |
| for (unsigned i = 0; i < valueType.getRank(); ++i) { |
| if (valueType.isDynamicDim(i)) { |
| dynamicDims.push_back( |
| builder.createOrFold<tensor::DimOp>(state.location, value, i)); |
| } |
| } |
| } |
| build(builder, state, key, values, dynamicDims); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // iree_input.dispatch |
| //===----------------------------------------------------------------------===// |
| |
| void DispatchOp::build(OpBuilder &builder, OperationState &state, |
| ExecutableExportOp exportOp, ValueRange workload, |
| TypeRange resultTypes, ValueRange resultDims, |
| ValueRange operands, ValueRange operandDims, |
| ArrayAttr tiedOperands, |
| ArrayRef<NamedAttribute> attributes) { |
| StringRef executableOpSymName = |
| exportOp->getParentOp() |
| ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()) |
| .getValue(); |
| auto entryPoint = |
| SymbolRefAttr::get(builder.getContext(), executableOpSymName, |
| {SymbolRefAttr::get(exportOp)}); |
| state.addAttribute("entry_point", entryPoint); |
| state.addOperands(workload); |
| state.addTypes(resultTypes); |
| state.addOperands(operands); |
| state.addOperands(operandDims); |
| state.addOperands(resultDims); |
| state.addAttributes(attributes); |
| state.attributes.erase(TiedOpInterface::getStorageAttrName()); |
| state.addAttribute(TiedOpInterface::getStorageAttrName(), tiedOperands); |
| state.attributes.erase(getOperandSegmentSizeAttr()); |
| state.addAttribute(getOperandSegmentSizeAttr(), |
| builder.getDenseI32ArrayAttr({ |
| static_cast<int32_t>(workload.size()), |
| static_cast<int32_t>(operands.size()), |
| static_cast<int32_t>(operandDims.size()), |
| static_cast<int32_t>(resultDims.size()), |
| })); |
| } |
| |
| LogicalResult DispatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| Operation *op = getOperation(); |
| auto exportOp = symbolTable.lookupNearestSymbolFrom<ExecutableExportOp>( |
| op, getEntryPoint()); |
| if (!exportOp) { |
| return op->emitOpError() << "undefined entry point: " << getEntryPoint(); |
| } |
| |
| // TODO(ezhulenev): verify that the exported function has matching operands. |
| return success(); |
| } |
| |
| std::pair<unsigned, unsigned> DispatchOp::getTiedOperandsIndexAndLength() { |
| return getODSOperandIndexAndLength(1); // $arguments |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // iree_input.optimization_barrier |
| //===----------------------------------------------------------------------===// |
| |
| void OptimizationBarrierOp::build(OpBuilder &builder, OperationState &state, |
| ValueRange operands, |
| ArrayRef<NamedAttribute> attributes) { |
| state.addOperands(operands); |
| state.addTypes(llvm::to_vector<2>(operands.getTypes())); |
| state.addAttributes(attributes); |
| } |
| |
| LogicalResult OptimizationBarrierOp::verify() { |
| Operation *op = getOperation(); |
| if (op->getNumOperands() != op->getNumResults()) { |
| return op->emitOpError() |
| << "must have same number of operands and results, but has " |
| << op->getNumOperands() << " and " << op->getNumResults() |
| << ", respectively"; |
| } |
| |
| for (int i = 0, e = op->getNumOperands(); i < e; ++i) { |
| if (op->getOperand(i).getType() != op->getResult(i).getType()) { |
| op->emitOpError() << "must have same operand and result types, but they " |
| "differ at index " |
| << i; |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| |
| #define GET_OP_CLASSES |
| #include "iree-dialects/Dialect/Input/InputOps.cpp.inc" |