| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Support/SMLoc.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| //===----------------------------------------------------------------------===// |
| // Utils |
| //===----------------------------------------------------------------------===// |
| |
| Value findValueSizeInList(unsigned index, ValueRange values, ValueRange sizes) { |
| assert(values[index].getType().isa<IREE::Util::SizeAwareTypeInterface>() && |
| "must be a size-aware type to get dims"); |
| unsigned sizeIndex = 0; |
| for (unsigned i = 0; i < index; ++i) { |
| if (values[i].getType().isa<IREE::Util::SizeAwareTypeInterface>()) { |
| ++sizeIndex; |
| } |
| } |
| return sizes[sizeIndex]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<SymbolVisibility>($sym_visibility) |
| //===----------------------------------------------------------------------===// |
| // some.op custom<SymbolVisibility>($sym_visibility) $sym_name |
| // -> |
| // some.op @foo |
| // some.op private @foo |
| |
| ParseResult parseSymbolVisibility(OpAsmParser &parser, |
| StringAttr &symVisibilityAttr) { |
| StringRef symVisibility; |
| parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"}); |
| if (!symVisibility.empty()) { |
| symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); |
| } |
| return success(); |
| } |
| |
| void printSymbolVisibility(OpAsmPrinter &p, Operation *op, |
| StringAttr symVisibilityAttr) { |
| if (!symVisibilityAttr) { |
| p << "public"; |
| } else { |
| p << symVisibilityAttr.getValue(); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TypeOrAttr>($type, $attr) |
| //===----------------------------------------------------------------------===// |
| // some.op custom<TypeOrAttr>($type, $attr) |
| // -> |
| // some.op : i32 |
| // some.op = 42 : i32 |
| // some.op : i32 = 42 : index |
| |
| ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, |
| Attribute &attr) { |
| if (succeeded(parser.parseOptionalEqual())) { |
| if (failed(parser.parseAttribute(attr))) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "expected attribute"; |
| } |
| typeAttr = TypeAttr::get(attr.getType()); |
| return success(); |
| } |
| |
| Type type; |
| if (failed(parser.parseColonType(type))) { |
| return parser.emitError(parser.getCurrentLocation()) << "expected type"; |
| } |
| typeAttr = TypeAttr::get(type); |
| |
| if (succeeded(parser.parseOptionalEqual())) { |
| if (failed(parser.parseAttribute(attr))) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "expected attribute"; |
| } |
| } |
| |
| return success(); |
| } |
| |
| void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, |
| Attribute attr) { |
| bool needsSpace = false; |
| if (!attr || attr.getType() != type.getValue()) { |
| p << ": "; |
| p.printAttribute(type); |
| needsSpace = true; // subsequent attr value needs a space separator |
| } |
| if (attr) { |
| if (needsSpace) p << ' '; |
| p << "= "; |
| p.printAttribute(attr); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<TypeAlias>($encoding_type, $storage_type) |
| //===----------------------------------------------------------------------===// |
| // tensor<4xf32> |
| // tensor<4xf32> as tensor<2xf64> |
| |
| ParseResult parseTypeAlias(OpAsmParser &parser, TypeAttr &encodingTypeAttr, |
| Type &storageType) { |
| Type encodingType; |
| if (failed(parser.parseType(encodingType))) return failure(); |
| storageType = encodingType; |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| if (failed(parser.parseType(storageType))) return failure(); |
| } |
| encodingTypeAttr = TypeAttr::get(encodingType); |
| return success(); |
| } |
| |
| void printTypeAlias(OpAsmPrinter &p, Operation *op, TypeAttr encodingTypeAttr, |
| Type storageType) { |
| if (encodingTypeAttr.getValue() != storageType) { |
| p.printType(encodingTypeAttr.getValue()); |
| p << " as "; |
| } |
| p.printType(storageType); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<RangeList>($offsets, $lengths) |
| //===----------------------------------------------------------------------===// |
| // [%offset for %length], [%offset for %length], ... |
| |
| ParseResult parseRangeList(OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::OperandType> &offsets, |
| SmallVectorImpl<OpAsmParser::OperandType> &lengths) { |
| do { |
| OpAsmParser::OperandType offset; |
| OpAsmParser::OperandType length; |
| if (failed(parser.parseLSquare()) || failed(parser.parseOperand(offset)) || |
| failed(parser.parseKeyword("for")) || |
| failed(parser.parseOperand(length)) || failed(parser.parseRSquare())) { |
| return failure(); |
| } |
| offsets.push_back(offset); |
| lengths.push_back(length); |
| } while (succeeded(parser.parseOptionalComma())); |
| return success(); |
| } |
| |
| void printRangeList(OpAsmPrinter &p, Operation *op, OperandRange offsets, |
| OperandRange lengths) { |
| llvm::interleaveComma(llvm::zip(offsets, lengths), p, [&](auto it) { |
| auto offset = std::get<0>(it); |
| auto length = std::get<1>(it); |
| p << "["; |
| p.printOperand(offset); |
| p << " for "; |
| p.printOperand(length); |
| p << "]"; |
| }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<SizeAwareType> |
| //===----------------------------------------------------------------------===// |
| // type{%size} |
| |
| ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type, |
| OpAsmParser::OperandType &size) { |
| if (failed(parser.parseType(type)) || failed(parser.parseLBrace()) || |
| failed(parser.parseOperand(size)) || failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| return success(); |
| } |
| |
| void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, Value size) { |
| p.printType(type); |
| p << "{"; |
| p.printOperand(size); |
| p << "}"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<SizeAwareTypeList> |
| //===----------------------------------------------------------------------===// |
| // type{%size0}, type, type{%size1} |
| |
| ParseResult parseSizeAwareTypeList( |
| OpAsmParser &parser, SmallVectorImpl<Type> &types, |
| SmallVectorImpl<OpAsmParser::OperandType> &sizes) { |
| do { |
| Type type; |
| if (failed(parser.parseType(type))) return failure(); |
| if (type.isa<IREE::Util::SizeAwareTypeInterface>()) { |
| OpAsmParser::OperandType size; |
| if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| sizes.push_back(size); |
| } |
| types.push_back(type); |
| } while (succeeded(parser.parseOptionalComma())); |
| return success(); |
| } |
| |
| void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op, TypeRange types, |
| OperandRange sizes) { |
| int sizeIndex = 0; |
| llvm::interleaveComma(types, p, [&](Type type) { |
| p.printType(type); |
| if (type.isa<IREE::Util::SizeAwareTypeInterface>()) { |
| p << "{"; |
| p.printOperand(sizes[sizeIndex++]); |
| p << "}"; |
| } |
| }); |
| } |
| |
| ParseResult parseSizeAwareTypeList( |
| OpAsmParser &parser, SmallVectorImpl<Type> &types0, |
| SmallVectorImpl<Type> &types1, |
| SmallVectorImpl<OpAsmParser::OperandType> &sizes) { |
| if (failed(parseSizeAwareTypeList(parser, types0, sizes))) return failure(); |
| types1 = types0; |
| return success(); |
| } |
| |
| void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op, TypeRange types0, |
| TypeRange types1, OperandRange sizes) { |
| printSizeAwareTypeList(p, op, types0, sizes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedTiedResult> |
| //===----------------------------------------------------------------------===// |
| // type{%dim0, %dim1} |
| // %arg0 as type{%dim0} |
| |
| ParseResult parseShapedTiedResult( |
| OpAsmParser &parser, Type &resultType, |
| SmallVectorImpl<OpAsmParser::OperandType> &resultDims) { |
| ArrayAttr tiedOperands; |
| return parseShapedTiedResult(parser, resultType, resultDims, tiedOperands); |
| } |
| |
| ParseResult parseShapedTiedResult( |
| OpAsmParser &parser, Type &resultType, |
| SmallVectorImpl<OpAsmParser::OperandType> &resultDims, |
| ArrayAttr &tiedOperands) { |
| OpAsmParser::OperandType tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| if (res.hasValue() && succeeded(res.getValue())) { |
| tiedOperandIndex = 0; |
| if (failed(parser.parseKeyword("as"))) return failure(); |
| } |
| if (failed(parser.parseType(resultType))) return failure(); |
| if (auto shapedType = resultType.dyn_cast<ShapedType>()) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::OperandType, 4> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.append(dynamicDims); |
| } |
| } else if (auto sizedType = |
| resultType.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) { |
| OpAsmParser::OperandType size; |
| if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.push_back(size); |
| } |
| tiedOperands = parser.getBuilder().getIndexArrayAttr({tiedOperandIndex}); |
| return success(); |
| } |
| |
| void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType, |
| ValueRange resultDims) { |
| auto tiedOp = cast<IREE::Util::TiedOpInterface>(op); |
| auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(0); |
| if (tiedOperandIndex.hasValue()) { |
| auto tiedOperand = op->getOperand(tiedOperandIndex.getValue()); |
| p.printOperand(tiedOperand); |
| p << " as "; |
| } |
| p.printType(resultType); |
| if (auto shapedType = resultType.dyn_cast<ShapedType>()) { |
| if (!shapedType.hasStaticShape()) { |
| if (resultDims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma( |
| resultDims.take_front(shapedType.getNumDynamicDims()), p, |
| [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| resultDims = resultDims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } else if (auto sizedType = |
| resultType.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) { |
| p << "{"; |
| p.printOperand(resultDims.front()); |
| p << "}"; |
| resultDims = resultDims.drop_front(1); |
| } |
| } |
| |
| void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType, |
| ValueRange resultDims, ArrayAttr tiedOperands) { |
| printShapedTiedResult(p, op, resultType, resultDims); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // custom<ShapedFunctionType> |
| //===----------------------------------------------------------------------===// |
| // (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4) |
| |
| static ParseResult parseShapedOperandList( |
| OpAsmParser &parser, SmallVectorImpl<Type> &types, |
| SmallVectorImpl<OpAsmParser::OperandType> &dims) { |
| do { |
| Type type; |
| if (failed(parser.parseType(type))) return failure(); |
| if (auto shapedType = type.dyn_cast<ShapedType>()) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::OperandType, 4> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| dims.append(dynamicDims); |
| } |
| } else if (auto sizedType = |
| type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) { |
| OpAsmParser::OperandType size; |
| if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| dims.push_back(size); |
| } |
| types.push_back(type); |
| } while (succeeded(parser.parseOptionalComma())); |
| return success(); |
| } |
| |
| // Finds the operand index in |operands| that |tiedResult| references. |
| // Returns TiedOpInterface::kUntiedIndex if no operand is found. |
| static int64_t findTiedOperand(OpAsmParser::OperandType tiedResult, |
| ArrayRef<OpAsmParser::OperandType> operands) { |
| int64_t operandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| for (int64_t i = 0; i < operands.size(); ++i) { |
| if (operands[i].name == tiedResult.name && |
| operands[i].number == tiedResult.number) { |
| operandIndex = i; |
| break; |
| } |
| } |
| return operandIndex; |
| } |
| |
| ParseResult parseShapedResultList( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands, |
| TypeRange operandTypes, ArrayRef<OpAsmParser::OperandType> operandDims, |
| SmallVectorImpl<Type> &resultTypes, |
| SmallVectorImpl<OpAsmParser::OperandType> &resultDims, |
| ArrayAttr &tiedOperands) { |
| SmallVector<int64_t, 4> tiedOperandIndices; |
| do { |
| OpAsmParser::OperandType tiedResult; |
| auto res = parser.parseOptionalOperand(tiedResult); |
| Type type; |
| int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; |
| if (res.hasValue() && succeeded(res.getValue())) { |
| tiedOperandIndex = findTiedOperand(tiedResult, operands); |
| if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) { |
| return parser.emitError(tiedResult.location, |
| "tied operand not found for result reference ") |
| << tiedResult.name; |
| } |
| if (succeeded(parser.parseOptionalKeyword("as"))) { |
| // Type _may_ differ from the operand. |
| if (failed(parser.parseType(type))) return failure(); |
| } else { |
| // Use the operands type. |
| type = operandTypes[tiedOperandIndex]; |
| } |
| } else if (failed(parser.parseType(type))) { |
| return failure(); |
| } |
| if (auto shapedType = type.dyn_cast<ShapedType>()) { |
| if (!shapedType.hasStaticShape()) { |
| SmallVector<OpAsmParser::OperandType, 4> dynamicDims; |
| if (failed(parser.parseLBrace()) || |
| failed(parser.parseOperandList(dynamicDims, |
| shapedType.getNumDynamicDims(), |
| OpAsmParser::Delimiter::None)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.append(dynamicDims); |
| } |
| } else if (auto sizedType = |
| type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) { |
| OpAsmParser::OperandType size; |
| if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) || |
| failed(parser.parseRBrace())) { |
| return failure(); |
| } |
| resultDims.push_back(size); |
| } |
| resultTypes.push_back(type); |
| tiedOperandIndices.push_back(tiedOperandIndex); |
| } while (succeeded(parser.parseOptionalComma())); |
| if (!tiedOperandIndices.empty()) { |
| tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices); |
| } |
| return success(); |
| } |
| |
| void printShapedResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, |
| TypeRange operandTypes, ValueRange operandDims, |
| TypeRange resultTypes, ValueRange resultDims, |
| ArrayAttr tiedOperands) { |
| auto tiedOp = cast<IREE::Util::TiedOpInterface>(op); |
| for (unsigned i = 0; i < resultTypes.size(); ++i) { |
| auto resultType = resultTypes[i]; |
| auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i); |
| bool printType = true; |
| if (tiedOperandIndex.hasValue()) { |
| auto tiedOperand = op->getOperand(tiedOperandIndex.getValue()); |
| p.printOperand(tiedOperand); |
| if (tiedOperand.getType() != resultType) { |
| p << " as "; |
| } else { |
| // Type elided as it matches the operand. |
| printType = false; |
| } |
| } |
| if (printType) { |
| p.printType(resultType); |
| } |
| if (auto shapedType = resultType.dyn_cast<ShapedType>()) { |
| if (!shapedType.hasStaticShape()) { |
| if (resultDims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma( |
| resultDims.take_front(shapedType.getNumDynamicDims()), p, |
| [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| resultDims = resultDims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } else if (auto sizedType = |
| resultType.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) { |
| p << "{"; |
| p.printOperand(resultDims.front()); |
| p << "}"; |
| resultDims = resultDims.drop_front(1); |
| } |
| if (i < resultTypes.size() - 1) p << ", "; |
| } |
| } |
| |
| ParseResult parseShapedFunctionType( |
| OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands, |
| SmallVectorImpl<Type> &operandTypes, |
| SmallVectorImpl<OpAsmParser::OperandType> &operandDims, |
| SmallVectorImpl<Type> &resultTypes, |
| SmallVectorImpl<OpAsmParser::OperandType> &resultDims, |
| ArrayAttr &tiedOperands) { |
| if (failed(parser.parseLParen())) return failure(); |
| if (failed(parser.parseOptionalRParen())) { |
| if (failed(parseShapedOperandList(parser, operandTypes, operandDims)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } |
| if (failed(parser.parseArrow())) return failure(); |
| if (succeeded(parser.parseOptionalLParen())) { |
| if (failed(parseShapedResultList(parser, operands, operandTypes, |
| operandDims, resultTypes, resultDims, |
| tiedOperands)) || |
| failed(parser.parseRParen())) { |
| return failure(); |
| } |
| } else { |
| if (failed(parseShapedResultList(parser, operands, operandTypes, |
| operandDims, resultTypes, resultDims, |
| tiedOperands))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| void printShapedFunctionType(OpAsmPrinter &p, Operation *op, |
| ValueRange operands, TypeRange operandTypes, |
| OperandRange operandDims, TypeRange resultTypes, |
| OperandRange resultDims, ArrayAttr tiedOperands) { |
| p << "("; |
| llvm::interleaveComma(operandTypes, p, [&](Type type) { |
| p.printType(type); |
| if (auto shapedType = type.dyn_cast<ShapedType>()) { |
| if (!shapedType.hasStaticShape()) { |
| if (operandDims.empty()) { |
| p << "{<<INVALID>>}"; |
| return; |
| } |
| p << "{"; |
| llvm::interleaveComma( |
| operandDims.take_front(shapedType.getNumDynamicDims()), p, |
| [&](Value value) { p.printOperand(value); }); |
| p << "}"; |
| operandDims = operandDims.drop_front(shapedType.getNumDynamicDims()); |
| } |
| } else if (auto sizedType = |
| type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) { |
| p << "{"; |
| p.printOperand(operandDims.front()); |
| p << "}"; |
| operandDims = operandDims.drop_front(1); |
| } |
| }); |
| p << ") -> "; |
| if (resultTypes.size() != 1) p << "("; |
| printShapedResultList(p, op, operands, operandTypes, operandDims, resultTypes, |
| resultDims, tiedOperands); |
| if (resultTypes.size() != 1) p << ")"; |
| } |
| |
| namespace IREE { |
| namespace Util { |
| |
| //===----------------------------------------------------------------------===// |
| // util.do_not_optimize |
| //===----------------------------------------------------------------------===// |
| |
| void DoNotOptimizeOp::build(OpBuilder &builder, OperationState &state, |
| ValueRange operands, |
| ArrayRef<NamedAttribute> attributes) { |
| state.addOperands(operands); |
| state.addTypes(llvm::to_vector<2>(operands.getTypes())); |
| state.addAttributes(attributes); |
| } |
| |
| ParseResult parseDoNotOptimizeOp(OpAsmParser &parser, OperationState &state) { |
| SmallVector<OpAsmParser::OperandType, 2> args; |
| // Operands and results have the same types. |
| auto &operandTypes = state.types; |
| |
| if (failed(parser.parseLParen()) || failed(parser.parseOperandList(args)) || |
| failed(parser.parseRParen()) || |
| failed(parser.parseOptionalAttrDict(state.attributes)) || |
| failed(parser.parseOptionalColonTypeList(state.types)) || |
| failed(parser.resolveOperands( |
| args, operandTypes, parser.getCurrentLocation(), state.operands))) { |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| void printDoNotOptimizeOp(OpAsmPrinter &p, Operation *op) { |
| p << "("; |
| p.printOperands(op->getOperands()); |
| p << ")"; |
| p.printOptionalAttrDict(op->getAttrs()); |
| |
| if (op->getNumOperands() != 0) { |
| p << " : "; |
| interleaveComma(op->getOperandTypes(), p); |
| } |
| } |
| |
| static LogicalResult verifyDoNotOptimizeOp(DoNotOptimizeOp op) { |
| if (op.getNumOperands() != op.getNumResults()) { |
| return op.emitOpError() |
| << "must have same number of operands and results, but has " |
| << op.getNumOperands() << " and " << op.getNumResults() |
| << ", respectively"; |
| } |
| |
| for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
| if (op.getOperand(i).getType() != op.getResult(i).getType()) { |
| op.emitOpError() << "must have same operand and result types, but they " |
| "differ at index " |
| << i; |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // util.unfoldable_constant |
| //===----------------------------------------------------------------------===// |
| |
| // Parsing/printing copied from std.constant |
| |
| ParseResult parseUnfoldableConstantOp(OpAsmParser &parser, |
| OperationState &state) { |
| Attribute valueAttr; |
| if (parser.parseOptionalAttrDict(state.attributes) || |
| parser.parseAttribute(valueAttr, "value", state.attributes)) |
| return failure(); |
| |
| // If the attribute is a symbol reference, then we expect a trailing type. |
| Type type; |
| if (!valueAttr.isa<SymbolRefAttr>()) |
| type = valueAttr.getType(); |
| else if (parser.parseColonType(type)) |
| return failure(); |
| |
| // Add the attribute type to the list. |
| return parser.addTypeToList(type, state.types); |
| } |
| |
| void printUnfoldableConstantOp(OpAsmPrinter &p, Operation *op) { |
| auto constOp = cast<IREE::Util::UnfoldableConstantOp>(op); |
| p << " "; |
| p.printOptionalAttrDict(constOp->getAttrs(), /*elidedAttrs=*/{"value"}); |
| |
| if (constOp->getAttrs().size() > 1) p << ' '; |
| p << constOp.value(); |
| |
| // If the value is a symbol reference, print a trailing type. |
| if (constOp.value().isa<SymbolRefAttr>()) p << " : " << constOp.getType(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Structural ops |
| //===----------------------------------------------------------------------===// |
| |
| void InitializerOp::build(OpBuilder &builder, OperationState &result, |
| ArrayRef<NamedAttribute> attrs) { |
| result.addAttribute( |
| "type", TypeAttr::get(FunctionType::get(builder.getContext(), {}, {}))); |
| result.addRegion(); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| static ParseResult parseInitializerOp(OpAsmParser &parser, |
| OperationState &result) { |
| result.addAttribute( |
| "type", TypeAttr::get(FunctionType::get(result.getContext(), {}, {}))); |
| if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) { |
| return failure(); |
| } |
| auto &body = *result.addRegion(); |
| if (failed(parser.parseRegion(body))) { |
| return failure(); |
| } |
| return success(); |
| } |
| |
| static void printInitializerOp(OpAsmPrinter &p, InitializerOp &op) { |
| p.printOptionalAttrDictWithKeyword(op->getAttrs(), /*elidedAttrs=*/{"type"}); |
| p.printRegion(op.body()); |
| } |
| |
| Block *InitializerOp::addEntryBlock() { |
| assert(empty() && "function already has an entry block"); |
| auto *entry = new Block(); |
| push_back(entry); |
| return entry; |
| } |
| |
| Block *InitializerOp::addBlock() { |
| assert(!empty() && "function should at least have an entry block"); |
| push_back(new Block()); |
| return &back(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Globals |
| //===----------------------------------------------------------------------===// |
| |
| // Returns true if the given |accessType| is compatible with the |globalType|. |
| // For example, this will return true if the global type is a tensor<?xf32> |
| // and the access is tensor<4xf32>. |
| static bool isGlobalTypeCompatible(Type globalType, Type accessType) { |
| // If one is a shaped type, then they both must be and have compatible |
| // shapes. |
| if (globalType.isa<ShapedType>() && accessType.isa<ShapedType>()) { |
| return succeeded(mlir::verifyCompatibleShape(globalType, accessType)); |
| } |
| |
| if (auto knownType = globalType.dyn_cast<GlobalTypeInterface>()) { |
| return knownType.isAccessStorageCompatible(accessType); |
| } |
| |
| // Otherwise, the types must be the same. |
| return globalType == accessType; |
| } |
| |
| void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
| bool isMutable, Type type, |
| Optional<Attribute> initialValue, |
| ArrayRef<NamedAttribute> attrs) { |
| result.addAttribute(SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(name)); |
| if (isMutable) { |
| result.addAttribute("is_mutable", builder.getUnitAttr()); |
| } |
| if (initialValue.hasValue()) { |
| result.addAttribute("initial_value", initialValue.getValue()); |
| } |
| result.addAttribute("type", TypeAttr::get(type)); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
| bool isMutable, Type type, |
| ArrayRef<NamedAttribute> attrs) { |
| build(builder, result, name, isMutable, type, llvm::None, attrs); |
| } |
| |
| static LogicalResult verifyGlobalOp(GlobalOp op) { |
| if (op.initial_value().hasValue()) { |
| // Ensure the value is something we can convert to a const. |
| if (!isGlobalTypeCompatible(op.type(), op.initial_valueAttr().getType())) { |
| return op->emitOpError() |
| << "initial value type mismatch; global " << op.getSymbolName() |
| << " is " << op.type() << " but initial value provided is " |
| << op.initial_valueAttr().getType(); |
| } |
| } |
| return success(); |
| } |
| |
| IREE::Util::GlobalOp GlobalAddressOp::getGlobalOp() { |
| return SymbolTable::lookupNearestSymbolFrom<IREE::Util::GlobalOp>( |
| getOperation()->getParentOp(), globalAttr()); |
| } |
| |
| FlatSymbolRefAttr GlobalAddressOp::getGlobalRefAttr() { return globalAttr(); } |
| |
| void GlobalAddressOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(result(), Twine("ptr_" + global()).str()); |
| } |
| |
| static LogicalResult verifyGlobalAddressOp(GlobalAddressOp op) { |
| auto globalOp = op.getGlobalOp(); |
| if (!globalOp) { |
| return op.emitOpError() << "undefined global: " << op.global(); |
| } |
| return success(); |
| } |
| |
| void GlobalLoadOp::build(OpBuilder &builder, OperationState &state, |
| GlobalOp globalOp, ArrayRef<NamedAttribute> attrs) { |
| state.addTypes({globalOp.type()}); |
| state.addAttribute("global", SymbolRefAttr::get(globalOp)); |
| state.attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| IREE::Util::GlobalOp GlobalLoadOp::getGlobalOp() { |
| return SymbolTable::lookupNearestSymbolFrom<IREE::Util::GlobalOp>( |
| getOperation()->getParentOp(), globalAttr()); |
| } |
| |
| FlatSymbolRefAttr GlobalLoadOp::getGlobalRefAttr() { return globalAttr(); } |
| |
| bool GlobalLoadOp::isGlobalImmutable() { return !getGlobalOp().is_mutable(); } |
| |
| void GlobalLoadOp::getAsmResultNames( |
| function_ref<void(Value, StringRef)> setNameFn) { |
| setNameFn(result(), global()); |
| } |
| |
| void GlobalLoadOp::getEffects( |
| SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| // HACK: works around the lack of symbol side effects in mlir by only saying |
| // we have a side-effect if the variable we are loading is mutable. |
| auto globalOp = |
| SymbolTable::lookupNearestSymbolFrom<GlobalOp>(*this, globalAttr()); |
| assert(globalOp); |
| if (globalOp.is_mutable()) { |
| effects.emplace_back(MemoryEffects::Read::get()); |
| } |
| } |
| |
| static LogicalResult verifyGlobalLoadOp(GlobalLoadOp op) { |
| auto globalOp = op.getGlobalOp(); |
| if (!globalOp) { |
| return op->emitOpError() << "undefined global: " << op.global(); |
| } |
| auto loadType = op->getResult(0).getType(); |
| if (!isGlobalTypeCompatible(globalOp.type(), loadType)) { |
| return op->emitOpError() |
| << "global type mismatch; global " << op.global() << " is " |
| << globalOp.type() << " but load is " << loadType; |
| } |
| return success(); |
| } |
| |
| static LogicalResult verifyGlobalLoadIndirectOp(GlobalLoadIndirectOp &op) { |
| auto globalType = |
| op.global().getType().cast<IREE::Util::PtrType>().getTargetType(); |
| auto loadType = op.result().getType(); |
| if (!isGlobalTypeCompatible(globalType, loadType)) { |
| return op.emitOpError() << "global type mismatch; global pointer is " |
| << globalType << " but load is " << loadType; |
| } |
| return success(); |
| } |
| |
| IREE::Util::GlobalOp GlobalStoreOp::getGlobalOp() { |
| return SymbolTable::lookupNearestSymbolFrom<IREE::Util::GlobalOp>( |
| getOperation()->getParentOp(), globalAttr()); |
| } |
| |
| FlatSymbolRefAttr GlobalStoreOp::getGlobalRefAttr() { return globalAttr(); } |
| |
| static LogicalResult verifyGlobalStoreOp(GlobalStoreOp op) { |
| auto globalOp = op.getGlobalOp(); |
| if (!globalOp) { |
| return op->emitOpError() << "undefined global: " << op.global(); |
| } |
| auto storeType = op->getOperand(0).getType(); |
| if (globalOp.type() != storeType) { |
| return op->emitOpError() |
| << "global type mismatch; global " << op.global() << " is " |
| << globalOp.type() << " but store is " << storeType; |
| } |
| if (!globalOp.isMutable()) { |
| // Allow stores to immutable globals in initializers. |
| if (!op->getParentOfType<InitializerOp>()) { |
| return op->emitOpError() << "global " << op.global() |
| << " is not mutable and cannot be stored to"; |
| } |
| } |
| return success(); |
| } |
| |
| static LogicalResult verifyGlobalStoreIndirectOp(GlobalStoreIndirectOp &op) { |
| auto globalType = |
| op.global().getType().cast<IREE::Util::PtrType>().getTargetType(); |
| auto storeType = op.value().getType(); |
| if (!isGlobalTypeCompatible(globalType, storeType)) { |
| return op.emitOpError() << "global type mismatch; global pointer is " |
| << globalType << " but store is " << storeType; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Lists |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseListTypeGet(OpAsmParser &parser, Type &listType, |
| Type &elementType) { |
| if (failed(parser.parseType(listType))) { |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected !util.list<T> type"); |
| } |
| auto listElementType = listType.cast<ListType>().getElementType(); |
| if (succeeded(parser.parseOptionalArrow())) { |
| // Use overridden type - required for variants only. |
| if (failed(parser.parseType(elementType))) { |
| return parser.emitError( |
| parser.getCurrentLocation(), |
| "expected an element type when specifying list access types"); |
| } |
| if (!ListType::canImplicitlyCast(listElementType, elementType)) { |
| return parser.emitError( |
| parser.getCurrentLocation(), |
| "list access types must match the same base type as the list element " |
| "type (when not variant)"); |
| } |
| } else { |
| // Use list element type as the result element type. |
| elementType = listElementType; |
| } |
| return success(); |
| } |
| |
| static void printListTypeGet(OpAsmPrinter &printer, Operation *, Type listType, |
| Type elementType) { |
| printer.printType(listType); |
| auto listElementType = listType.cast<ListType>().getElementType(); |
| if (listElementType != elementType) { |
| printer.printArrowTypeList(ArrayRef<Type>{elementType}); |
| } |
| } |
| |
| static ParseResult parseListTypeSet(OpAsmParser &parser, Type &listType, |
| Type &elementType) { |
| Type leadingType; |
| if (failed(parser.parseType(leadingType))) { |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected element type or !util.list<T> type"); |
| } |
| if (succeeded(parser.parseOptionalArrow())) { |
| elementType = leadingType; |
| if (failed(parser.parseType(listType)) || !listType.isa<ListType>()) { |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected an !util.list<T> type"); |
| } |
| } else { |
| if (!leadingType.isa<ListType>()) { |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected an !util.list<T> type"); |
| } |
| listType = leadingType; |
| elementType = listType.cast<ListType>().getElementType(); |
| } |
| return success(); |
| } |
| |
| static void printListTypeSet(OpAsmPrinter &printer, Operation *, Type listType, |
| Type elementType) { |
| auto listElementType = listType.cast<ListType>().getElementType(); |
| if (listElementType != elementType) { |
| printer.printType(elementType); |
| printer.printArrowTypeList(ArrayRef<Type>{listType}); |
| } else { |
| printer.printType(listType); |
| } |
| } |
| |
| static LogicalResult verifyListGetOp(ListGetOp &op) { |
| auto listType = op.list().getType().cast<IREE::Util::ListType>(); |
| auto elementType = listType.getElementType(); |
| auto resultType = op.result().getType(); |
| if (!ListType::canImplicitlyCast(elementType, resultType)) { |
| return op.emitError() << "list contains " << elementType |
| << " and cannot be accessed as " << resultType; |
| } |
| return success(); |
| } |
| |
| static LogicalResult verifyListSetOp(ListSetOp &op) { |
| auto listType = op.list().getType().cast<IREE::Util::ListType>(); |
| auto elementType = listType.getElementType(); |
| auto valueType = op.value().getType(); |
| if (!ListType::canImplicitlyCast(valueType, elementType)) { |
| return op.emitError() << "list contains " << elementType |
| << " and cannot be mutated as " << valueType; |
| } |
| return success(); |
| } |
| |
| } // namespace Util |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |
| |
| #define GET_OP_CLASSES |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.cpp.inc" |