blob: d18360fcf65a794e3ecebee8a5fdda58e1b91b1c [file]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"
// Command-line knob bounding the size of dense (non-splat) constants that may
// be cloned into dispatch regions (consulted by canDispatchRegionContainOp).
static llvm::cl::opt<int> clInlineConstantByteLength(
    "iree-flow-inline-constants-max-byte-length",
    llvm::cl::desc("Maximum byte-length of constant that can be inlined into a "
                   "dispatch region"),
    llvm::cl::init(256));
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
//===----------------------------------------------------------------------===//
// Op utilities used within the Flow dialect
//===----------------------------------------------------------------------===//
// Returns true if the given |accessType| is compatible with the |variableType|.
// For example, this will return true if the variable type is a tensor<?xf32>
// and the access is tensor<4xf32>.
static bool isVariableTypeCompatible(Type variableType, Type accessType) {
  // Shapes only need to be compatible (dynamic dims match anything), not
  // identical; element types must still agree.
  auto shapeCheck = mlir::verifyCompatibleShape(variableType, accessType);
  return succeeded(shapeCheck);
}
// Verifies that |dynamicDims| contains the appropriate number of dims for all
// of the dynamic dimensions in |values|.
// Verifies that |dynamicDims| carries exactly one SSA value per dynamic
// dimension appearing across all shaped types in |values|.
static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
                                         ValueRange dynamicDims) {
  // Tally the number of dynamic dims demanded by the value list.
  unsigned expectedCount = 0;
  for (auto value : values) {
    auto shapedType = value.getType().dyn_cast<ShapedType>();
    if (shapedType) expectedCount += shapedType.getNumDynamicDims();
  }
  if (dynamicDims.size() == expectedCount) return success();
  return op->emitOpError()
         << "value set has " << expectedCount << " dynamic dimensions but only "
         << dynamicDims.size() << " dimension values are attached";
}
//===----------------------------------------------------------------------===//
// custom<TiedResult>
//===----------------------------------------------------------------------===//
// type{%dim0, %dim1}
// %arg0
// Parses a result type with optional dynamic dimension operands:
//   type{%dim0, %dim1}
// Dynamic dim operands are appended to |resultDims|; the result is recorded as
// tied to operand 0 in |tiedOperands|.
static ParseResult parseTiedResult(
    OpAsmParser &parser, Type &resultType,
    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
    ArrayAttr &tiedOperands) {
  if (failed(parser.parseType(resultType))) return failure();
  // Only shaped types with dynamic dims carry a trailing {%dim, ...} list.
  if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
    if (!shapedType.hasStaticShape()) {
      SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
      // Exactly one SSA operand is expected per dynamic dimension.
      if (failed(parser.parseLBrace()) ||
          failed(parser.parseOperandList(dynamicDims,
                                         shapedType.getNumDynamicDims(),
                                         OpAsmParser::Delimiter::None)) ||
          failed(parser.parseRBrace())) {
        return failure();
      }
      resultDims.append(dynamicDims);
    }
  }
  // This custom form always ties the single result to operand 0.
  tiedOperands = parser.getBuilder().getIndexArrayAttr({0});
  return success();
}
// Prints the inverse of parseTiedResult: the result type followed by its
// dynamic dimension operands (if any), e.g. `tensor<?xf32>{%dim0}`.
static void printTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
                            ValueRange resultDims, ArrayAttr tiedOperands) {
  p.printType(resultType);
  auto shapedType = resultType.dyn_cast<ShapedType>();
  if (!shapedType || shapedType.hasStaticShape()) return;
  // A dynamically-shaped result with no attached dims is malformed IR; make
  // that visible in the printed form rather than crashing.
  if (resultDims.empty()) {
    p << "{<<INVALID>>}";
    return;
  }
  p << "{";
  auto dims = resultDims.take_front(shapedType.getNumDynamicDims());
  llvm::interleaveComma(dims, p, [&](Value dim) { p.printOperand(dim); });
  p << "}";
}
//===----------------------------------------------------------------------===//
// custom<ShapedFunctionType>
//===----------------------------------------------------------------------===//
// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)
// Parses a comma-separated list of operand types, each optionally followed by
// its dynamic dimension operands:
//   type, type{%dim0, %dim1}, type
// Types are appended to |types| and dim operands (flattened, in order) to
// |dims|.
static ParseResult parseShapedOperandList(
    OpAsmParser &parser, SmallVectorImpl<Type> &types,
    SmallVectorImpl<OpAsmParser::OperandType> &dims) {
  do {
    Type type;
    if (failed(parser.parseType(type))) return failure();
    if (auto shapedType = type.dyn_cast<ShapedType>()) {
      if (!shapedType.hasStaticShape()) {
        SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
        // One SSA operand per dynamic dimension, enclosed in braces.
        if (failed(parser.parseLBrace()) ||
            failed(parser.parseOperandList(dynamicDims,
                                           shapedType.getNumDynamicDims(),
                                           OpAsmParser::Delimiter::None)) ||
            failed(parser.parseRBrace())) {
          return failure();
        }
        dims.append(dynamicDims);
      }
    }
    types.push_back(type);
  } while (succeeded(parser.parseOptionalComma()));
  return success();
}
// Finds the operand index in |operands| that |tiedResult| references.
// Returns TiedOpInterface::kUntiedIndex if no operand is found.
// Finds the operand index in |operands| that |tiedResult| references.
// Returns TiedOpInterface::kUntiedIndex if no operand is found.
//
// Matching is purely textual on the SSA name (e.g. `%arg0`) since at parse
// time no SSA values have been resolved yet.
static int64_t findTiedOperand(OpAsmParser::OperandType tiedResult,
                               ArrayRef<OpAsmParser::OperandType> operands) {
  // NOTE: use size_t for the induction variable to avoid a signed/unsigned
  // comparison against operands.size(); the index fits in int64_t regardless.
  for (size_t i = 0; i < operands.size(); ++i) {
    if (operands[i].name == tiedResult.name) {
      return static_cast<int64_t>(i);
    }
  }
  return IREE::Util::TiedOpInterface::kUntiedIndex;
}
// Parses a comma-separated result list where each result is either a plain
// type (`type{%dims...}`) or a reference to an operand it is tied to
// (`%operand` or `%operand as type{%dims...}`). Populates |resultTypes|,
// |resultDims| (flattened dynamic dims), and |tiedOperands| (one index per
// result; kUntiedIndex when untied).
static ParseResult parseShapedResultList(
    OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
    TypeRange operandTypes, ArrayRef<OpAsmParser::OperandType> operandDims,
    SmallVectorImpl<Type> &resultTypes,
    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
    ArrayAttr &tiedOperands) {
  SmallVector<int64_t, 4> tiedOperandIndices;
  do {
    OpAsmParser::OperandType tiedResult;
    // An optional leading `%name` marks this result as tied to an operand.
    auto res = parser.parseOptionalOperand(tiedResult);
    Type type;
    int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
    if (res.hasValue() && succeeded(res.getValue())) {
      // Resolve the referenced operand by its SSA name.
      tiedOperandIndex = findTiedOperand(tiedResult, operands);
      if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) {
        return parser.emitError(tiedResult.location,
                                "tied operand not found for result reference ")
               << tiedResult.name;
      }
      if (succeeded(parser.parseOptionalKeyword("as"))) {
        // Type _may_ differ from the operand.
        if (failed(parser.parseType(type))) return failure();
      } else {
        // Use the operands type.
        type = operandTypes[tiedOperandIndex];
      }
    } else if (failed(parser.parseType(type))) {
      return failure();
    }
    // Dynamically-shaped results carry their dim operands in braces.
    if (auto shapedType = type.dyn_cast<ShapedType>()) {
      if (!shapedType.hasStaticShape()) {
        SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
        if (failed(parser.parseLBrace()) ||
            failed(parser.parseOperandList(dynamicDims,
                                           shapedType.getNumDynamicDims(),
                                           OpAsmParser::Delimiter::None)) ||
            failed(parser.parseRBrace())) {
          return failure();
        }
        resultDims.append(dynamicDims);
      }
    }
    resultTypes.push_back(type);
    tiedOperandIndices.push_back(tiedOperandIndex);
  } while (succeeded(parser.parseOptionalComma()));
  // Only materialize the attribute when at least one result was parsed.
  if (!tiedOperandIndices.empty()) {
    tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices);
  }
  return success();
}
// Parses a function-style signature with shaped/tied results:
//   (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)
// The result list may omit the parentheses when it is a single entry.
static ParseResult parseShapedFunctionType(
    OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
    SmallVectorImpl<Type> &operandTypes,
    SmallVectorImpl<OpAsmParser::OperandType> &operandDims,
    SmallVectorImpl<Type> &resultTypes,
    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
    ArrayAttr &tiedOperands) {
  if (failed(parser.parseLParen())) return failure();
  // NOTE: parseOptionalRParen *succeeds* when `)` is present, so a failure
  // here means the operand list is non-empty and must be parsed.
  if (failed(parser.parseOptionalRParen())) {
    if (failed(parseShapedOperandList(parser, operandTypes, operandDims)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
  }
  if (failed(parser.parseArrow())) return failure();
  if (succeeded(parser.parseOptionalLParen())) {
    // Parenthesized (possibly multi-entry) result list.
    if (failed(parseShapedResultList(parser, operands, operandTypes,
                                     operandDims, resultTypes, resultDims,
                                     tiedOperands)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
  } else {
    // Bare single-result form without parentheses.
    if (failed(parseShapedResultList(parser, operands, operandTypes,
                                     operandDims, resultTypes, resultDims,
                                     tiedOperands))) {
      return failure();
    }
  }
  return success();
}
// Prints the inverse of parseShapedFunctionType. |operandDims| and
// |resultDims| are consumed front-to-back as each dynamically-shaped type is
// printed, so the ranges are re-sliced (drop_front) after each use.
static void printShapedFunctionType(OpAsmPrinter &p, Operation *op,
                                    ValueRange operands, TypeRange operandTypes,
                                    OperandRange operandDims,
                                    TypeRange resultTypes,
                                    OperandRange resultDims,
                                    ArrayAttr tiedOperands) {
  p << "(";
  llvm::interleaveComma(operandTypes, p, [&](Type type) {
    p.printType(type);
    if (auto shapedType = type.dyn_cast<ShapedType>()) {
      if (!shapedType.hasStaticShape()) {
        // Running out of dims indicates malformed IR; surface it in the text.
        if (operandDims.empty()) {
          p << "{<<INVALID>>}";
          return;
        }
        p << "{";
        llvm::interleaveComma(
            operandDims.take_front(shapedType.getNumDynamicDims()), p,
            [&](Value value) { p.printOperand(value); });
        p << "}";
        // Advance past the dims just printed for the next dynamic type.
        operandDims = operandDims.drop_front(shapedType.getNumDynamicDims());
      }
    }
  });
  p << ") -> ";
  // Single results are printed without surrounding parentheses.
  if (resultTypes.size() != 1) p << "(";
  auto tiedOp = cast<IREE::Util::TiedOpInterface>(op);
  for (unsigned i = 0; i < resultTypes.size(); ++i) {
    auto resultType = resultTypes[i];
    auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i);
    bool printType = true;
    if (tiedOperandIndex.hasValue()) {
      // Tied results print as `%operand` (or `%operand as type` when the
      // result type differs from the operand type).
      auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
      p.printOperand(tiedOperand);
      if (tiedOperand.getType() != resultType) {
        p << " as ";
      } else {
        // Type elided as it matches the operand.
        printType = false;
      }
    }
    if (printType) {
      p.printType(resultType);
    }
    if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
      if (!shapedType.hasStaticShape()) {
        if (resultDims.empty()) {
          p << "{<<INVALID>>}";
          return;
        }
        p << "{";
        llvm::interleaveComma(
            resultDims.take_front(shapedType.getNumDynamicDims()), p,
            [&](Value value) { p.printOperand(value); });
        p << "}";
        resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
      }
    }
    if (i < resultTypes.size() - 1) p << ", ";
  }
  if (resultTypes.size() != 1) p << ")";
}
//===----------------------------------------------------------------------===//
// flow.variable
//===----------------------------------------------------------------------===//
// Parses a flow.variable op:
//   flow.variable @name [mutable] [init(@initializer)] (value | : type)
//       [attributes {...}]
// Exactly one of an inline initial value attribute or an explicit `: type` is
// expected; the variable type is derived from whichever is present.
static ParseResult parseVariableOp(OpAsmParser &parser,
                                   OperationState *result) {
  StringAttr nameAttr;
  if (failed(parser.parseSymbolName(nameAttr,
                                    mlir::SymbolTable::getSymbolAttrName(),
                                    result->attributes))) {
    return failure();
  }
  if (succeeded(parser.parseOptionalKeyword("mutable"))) {
    result->addAttribute("is_mutable", UnitAttr::get(result->getContext()));
  }
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    // `init(@symbol)` references an initializer function.
    FlatSymbolRefAttr initializerAttr;
    if (failed(parser.parseLParen()) ||
        failed(parser.parseAttribute(initializerAttr, "initializer",
                                     result->attributes)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
  }
  // No `:` means an inline initial value follows; its type doubles as the
  // variable type.
  if (failed(parser.parseOptionalColon())) {
    Attribute initialValueAttr;
    if (failed(parser.parseAttribute(initialValueAttr, "initial_value",
                                     result->attributes))) {
      return failure();
    }
    result->addAttribute("type", TypeAttr::get(initialValueAttr.getType()));
  } else {
    Type type;
    if (failed(parser.parseType(type))) {
      return failure();
    }
    result->addAttribute("type", TypeAttr::get(type));
  }
  if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  return success();
}
// Prints a flow.variable in the form accepted by parseVariableOp; attributes
// rendered positionally are elided from the trailing attribute dict.
static void printVariableOp(OpAsmPrinter &p, VariableOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.sym_name());
  if (op.is_mutable()) p << " mutable";
  auto initializer = op.initializer();
  if (initializer.hasValue()) {
    p << " init(";
    p.printSymbolName(initializer.getValue());
    p << ')';
  }
  auto initialValue = op.initial_value();
  if (initialValue.hasValue()) {
    // Inline value form: the type is implied by the attribute.
    p << ' ';
    p.printAttribute(initialValue.getValue());
  } else {
    // Explicit type form.
    p << " : ";
    p.printType(op.type());
  }
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(),
      /*elidedAttrs=*/{"sym_name", "type", "is_mutable", "initializer",
                       "initial_value"});
}
// Verifies flow.variable: at most one of initializer/initial_value may be set
// and whichever is present must produce a value compatible with the declared
// variable type.
static LogicalResult verifyVariableOp(VariableOp op) {
  if (op.initializer().hasValue() && op.initial_value().hasValue()) {
    return op.emitOpError()
           << "variables can have either an initializer or an initial value";
  } else if (op.initializer().hasValue()) {
    // Ensure initializer returns the same type as the variable.
    auto *symbolOp =
        SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue());
    if (!symbolOp) {
      return op.emitOpError() << "initializer function "
                              << op.initializer().getValue() << " not found";
    }
    auto initializerOp = dyn_cast<FuncOp>(symbolOp);
    // BUGFIX: dyn_cast returns null when the symbol is not a FuncOp; the
    // previous code dereferenced it unconditionally.
    if (!initializerOp) {
      return op.emitOpError()
             << "initializer " << op.initializer().getValue()
             << " must reference a function";
    }
    if (initializerOp.getNumArguments() != 0 ||
        initializerOp.getNumResults() != 1 ||
        initializerOp.getType().getResult(0) != op.type()) {
      return op.emitOpError()
             << "initializer type mismatch; variable " << op.sym_name()
             << " is " << op.type() << " but initializer function "
             << initializerOp.getName() << " is " << initializerOp.getType();
    }
  } else if (op.initial_value().hasValue()) {
    // Ensure the value is something we can store in the variable
    if (!isVariableTypeCompatible(op.type(), op.initial_value()->getType())) {
      return op.emitOpError()
             << "initial value type mismatch; variable " << op.sym_name()
             << " is " << op.type() << " but initial value provided is "
             << op.initial_value()->getType();
    }
  }
  return success();
}
// Builds a variable whose type and value come from an initializer function;
// the variable type is the initializer's single result type.
void VariableOp::build(OpBuilder &builder, OperationState &state,
                       StringRef name, bool isMutable, FuncOp initializer,
                       ArrayRef<NamedAttribute> attrs) {
  auto symbolName = builder.getStringAttr(name);
  state.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);
  if (isMutable) state.addAttribute("is_mutable", builder.getUnitAttr());
  state.addAttribute("initializer", builder.getSymbolRefAttr(initializer));
  auto variableType = initializer.getType().getResult(0);
  state.addAttribute("type", TypeAttr::get(variableType));
  state.attributes.append(attrs.begin(), attrs.end());
}
// Builds a variable with an explicit type and inline initial value attribute.
void VariableOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, bool isMutable, Type type,
                       Attribute initialValue, ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  if (isMutable) result.addAttribute("is_mutable", builder.getUnitAttr());
  result.addAttribute("initial_value", initialValue);
  result.addAttribute("type", TypeAttr::get(type));
  result.attributes.append(attrs.begin(), attrs.end());
}
// Builds an uninitialized variable with only a declared type.
void VariableOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, bool isMutable, Type type,
                       ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  if (isMutable) result.addAttribute("is_mutable", builder.getUnitAttr());
  result.addAttribute("type", TypeAttr::get(type));
  result.attributes.append(attrs.begin(), attrs.end());
}
//===----------------------------------------------------------------------===//
// flow.variable.load
//===----------------------------------------------------------------------===//
// Builds a load of |variableOp| yielding a value of the variable's type.
void VariableLoadOp::build(OpBuilder &builder, OperationState &state,
                           VariableOp variableOp,
                           ArrayRef<NamedAttribute> attrs) {
  state.addAttribute("variable", builder.getSymbolRefAttr(variableOp));
  state.addTypes({variableOp.type()});
  state.attributes.append(attrs.begin(), attrs.end());
}
void VariableLoadOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // HACK: works around the lack of symbol side effects in mlir by only saying
  // we have a side-effect if the variable we are loading is mutable.
  auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(*this, variable());
  assert(symbolOp);
  auto variableOp = dyn_cast<VariableOp>(symbolOp);
  // BUGFIX: guard against the symbol not being a flow.variable (dyn_cast
  // yields null); the verifier reports that error path separately.
  if (variableOp && variableOp.is_mutable()) {
    effects.emplace_back(MemoryEffects::Read::get());
  }
}
// Verifies that the referenced variable exists and that the loaded type is
// compatible with the variable's declared type.
static LogicalResult verifyVariableLoadOp(VariableLoadOp &op) {
  auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
  if (!symbolOp) {
    return op.emitOpError() << "undefined variable: " << op.variable();
  }
  auto variableOp = dyn_cast<VariableOp>(symbolOp);
  // BUGFIX: the symbol may resolve to a non-variable op; dyn_cast returns
  // null in that case and must not be dereferenced.
  if (!variableOp) {
    return op.emitOpError()
           << "symbol " << op.variable() << " does not reference a variable";
  }
  auto loadType = op.result().getType();
  if (!isVariableTypeCompatible(variableOp.type(), loadType)) {
    return op.emitOpError()
           << "variable type mismatch; variable " << op.variable() << " is "
           << variableOp.type() << " but load is " << loadType;
  }
  return success();
}
// Resolves the `variable` symbol reference to its defining flow.variable op,
// searching outward from the parent op's symbol tables.
VariableOp VariableLoadOp::getLoadedVariable() {
  return SymbolTable::lookupNearestSymbolFrom<IREE::Flow::VariableOp>(
      getOperation()->getParentOp(), variable());
}
//===----------------------------------------------------------------------===//
// flow.variable.load.indirect
//===----------------------------------------------------------------------===//
// Verifies that the pointee type of the variable pointer is compatible with
// the type being loaded.
static LogicalResult verifyVariableLoadIndirectOp(VariableLoadIndirectOp &op) {
  auto ptrType = op.variable().getType().cast<IREE::Util::PtrType>();
  auto variableType = ptrType.getTargetType();
  auto loadType = op.result().getType();
  if (isVariableTypeCompatible(variableType, loadType)) return success();
  return op.emitOpError() << "variable type mismatch; variable pointer is "
                          << variableType << " but load is " << loadType;
}
//===----------------------------------------------------------------------===//
// flow.variable.store
//===----------------------------------------------------------------------===//
// Verifies that the referenced variable exists, accepts the stored type, and
// is mutable.
static LogicalResult verifyVariableStoreOp(VariableStoreOp &op) {
  auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
  if (!symbolOp) {
    return op.emitOpError() << "undefined variable: " << op.variable();
  }
  auto variableOp = dyn_cast<VariableOp>(symbolOp);
  // BUGFIX: the symbol may resolve to a non-variable op; dyn_cast returns
  // null in that case and must not be dereferenced.
  if (!variableOp) {
    return op.emitOpError()
           << "symbol " << op.variable() << " does not reference a variable";
  }
  auto storeType = op.value().getType();
  if (!isVariableTypeCompatible(variableOp.type(), storeType)) {
    return op.emitOpError()
           << "variable type mismatch; variable " << op.variable() << " is "
           << variableOp.type() << " but store is " << storeType;
  }
  if (!variableOp.is_mutable()) {
    return op.emitOpError() << "variable " << op.variable()
                            << " is not mutable and cannot be stored to";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// flow.variable.store.indirect
//===----------------------------------------------------------------------===//
// Verifies that the pointee type of the variable pointer is compatible with
// the type being stored.
static LogicalResult verifyVariableStoreIndirectOp(
    VariableStoreIndirectOp &op) {
  auto ptrType = op.variable().getType().cast<IREE::Util::PtrType>();
  auto variableType = ptrType.getTargetType();
  auto storeType = op.value().getType();
  if (isVariableTypeCompatible(variableType, storeType)) return success();
  return op.emitOpError() << "variable type mismatch; variable pointer is "
                          << variableType << " but store is " << storeType;
}
//===----------------------------------------------------------------------===//
// flow.dispatch.tensor.load
//===----------------------------------------------------------------------===//
/// Extracts static and dynamic values from list of `OpFoldResult`.
static void processMixedOperands(ArrayRef<OpFoldResult> valueOrAttrs,
SmallVectorImpl<Value> &dynamicValues,
SmallVectorImpl<int64_t> &staticValues,
int64_t dynamicIndexValue) {
for (OpFoldResult valueOrAttr : valueOrAttrs) {
if (auto value = valueOrAttr.dyn_cast<Value>()) {
dynamicValues.push_back(value);
staticValues.push_back(dynamicIndexValue);
} else {
auto operandValue =
valueOrAttr.dyn_cast<Attribute>().cast<IntegerAttr>().getInt();
staticValues.push_back(operandValue);
}
}
}
// Infers the result tensor type of a load from |sourceType| given the mixed
// static/dynamic |mixedSizes|: attribute sizes become static dims, SSA value
// sizes become dynamic dims.
RankedTensorType DispatchTensorLoadOp::inferResultType(
    IREE::Flow::DispatchTensorType sourceType,
    ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t, 4> shape;
  shape.reserve(mixedSizes.size());
  for (OpFoldResult valueOrAttr : mixedSizes) {
    if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
      shape.push_back(attr.cast<IntegerAttr>().getInt());
    } else {
      shape.push_back(DispatchTensorType::kDynamicSize);
    }
  }
  return RankedTensorType::get(shape, sourceType.getElementType());
}
// Builds a full-tensor load (no offsets/sizes/strides): all static arrays are
// empty and no dynamic operands are attached.
void DispatchTensorLoadOp::build(OpBuilder &builder, OperationState &state,
                                 RankedTensorType returnType, Value source,
                                 ArrayRef<NamedAttribute> attributes) {
  build(builder, state, returnType, source, ArrayRef<Value>(),
        ArrayRef<Value>(), ArrayRef<Value>(), builder.getI64ArrayAttr({}),
        builder.getI64ArrayAttr({}), builder.getI64ArrayAttr({}));
}
// Builds a sliced load from mixed static/dynamic offsets, sizes, and strides.
// Each mixed list is split into SSA operands plus a static attribute array
// containing either the constant value or a "dynamic" sentinel.
void DispatchTensorLoadOp::build(OpBuilder &builder, OperationState &state,
                                 RankedTensorType returnType, Value source,
                                 ArrayRef<OpFoldResult> mixedOffsets,
                                 ArrayRef<OpFoldResult> mixedSizes,
                                 ArrayRef<OpFoldResult> mixedStrides,
                                 ArrayRef<NamedAttribute> attributes) {
  SmallVector<Value> offsets;
  SmallVector<Value> sizes;
  SmallVector<Value> strides;
  SmallVector<int64_t> staticOffsets;
  SmallVector<int64_t> staticSizes;
  SmallVector<int64_t> staticStrides;
  // Offsets/strides share one sentinel; sizes use the shape sentinel.
  processMixedOperands(mixedOffsets, offsets, staticOffsets,
                       ShapedType::kDynamicStrideOrOffset);
  processMixedOperands(mixedSizes, sizes, staticSizes,
                       ShapedType::kDynamicSize);
  processMixedOperands(mixedStrides, strides, staticStrides,
                       ShapedType::kDynamicStrideOrOffset);
  build(builder, state, returnType, source, offsets, sizes, strides,
        builder.getI64ArrayAttr(staticOffsets),
        builder.getI64ArrayAttr(staticSizes),
        builder.getI64ArrayAttr(staticStrides));
}
// Builds a sliced load whose result type is inferred from the source dispatch
// tensor type and the requested sizes.
void DispatchTensorLoadOp::build(OpBuilder &builder, OperationState &state,
                                 Value source,
                                 ArrayRef<OpFoldResult> mixedOffsets,
                                 ArrayRef<OpFoldResult> mixedSizes,
                                 ArrayRef<OpFoldResult> mixedStrides,
                                 ArrayRef<NamedAttribute> attributes) {
  auto returnType =
      inferResultType(source.getType().cast<DispatchTensorType>(), mixedSizes);
  build(builder, state, returnType, source, mixedOffsets, mixedSizes,
        mixedStrides);
}
//===----------------------------------------------------------------------===//
// flow.dispatch.tensor.store
//===----------------------------------------------------------------------===//
// Builds a full-tensor store (no offsets/sizes/strides): all static arrays
// are empty and no dynamic operands are attached.
void DispatchTensorStoreOp::build(OpBuilder &builder, OperationState &state,
                                  Value value, Value target,
                                  ArrayRef<NamedAttribute> attributes) {
  build(builder, state, ArrayRef<Type>(), value, target, ArrayRef<Value>(),
        ArrayRef<Value>(), ArrayRef<Value>(), builder.getI64ArrayAttr({}),
        builder.getI64ArrayAttr({}), builder.getI64ArrayAttr({}));
}
//===----------------------------------------------------------------------===//
// flow.dispatch.workgroups
//===----------------------------------------------------------------------===//
// Builds a flow.dispatch.workgroups op with an empty entry block whose
// arguments mirror the operands (as !flow.dispatch.tensor with access derived
// from tying) followed by any untied results (write-only tensors).
void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state,
                                 ValueRange workgroupCount,
                                 TypeRange resultTypes, ValueRange resultDims,
                                 ValueRange operands, ValueRange operandDims,
                                 ArrayRef<int64_t> tiedOperands,
                                 ArrayRef<NamedAttribute> attributes) {
  state.addTypes(resultTypes);
  // Operand order must match the operand_segment_sizes attribute below.
  state.addOperands(workgroupCount);
  state.addOperands(operands);
  state.addOperands(operandDims);
  state.addOperands(resultDims);
  state.addAttributes(attributes);
  // Replace (not duplicate) any tied-operands attribute carried in by
  // |attributes|.
  state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
  state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
                     builder.getIndexArrayAttr(tiedOperands));
  state.attributes.erase("operand_segment_sizes");
  state.addAttribute("operand_segment_sizes",
                     builder.getI32VectorAttr({
                         static_cast<int32_t>(workgroupCount.size()),
                         static_cast<int32_t>(operands.size()),
                         static_cast<int32_t>(operandDims.size()),
                         static_cast<int32_t>(resultDims.size()),
                     }));
  auto *body = state.addRegion();
  assert(body->begin() == body->end());
  {
    OpBuilder::InsertionGuard g(builder);
    builder.createBlock(body);  // createBlock implicitly moves IP, RAII away...
  }
  // Mark which operands/results participate in a tie so the block argument
  // access modes below can be chosen accordingly.
  llvm::BitVector operandAliases(llvm::size(operands), false);
  llvm::BitVector resultAliases(llvm::size(resultTypes), false);
  for (unsigned resultIndex = 0; resultIndex < tiedOperands.size();
       ++resultIndex) {
    int64_t tiedOperandIndex = tiedOperands[resultIndex];
    if (tiedOperandIndex != IREE::Util::TiedOpInterface::kUntiedIndex) {
      operandAliases[tiedOperandIndex] = true;
      resultAliases[resultIndex] = true;
    }
  }
  // Operands: tensors become dispatch tensors (read-write when tied to a
  // result, read-only otherwise); other types pass through unchanged.
  for (auto operand : llvm::enumerate(operands)) {
    Type type = operand.value().getType();
    if (auto tensorType = type.dyn_cast<TensorType>()) {
      type = DispatchTensorType::get(operandAliases[operand.index()]
                                         ? TensorAccess::ReadWrite
                                         : TensorAccess::ReadOnly,
                                     tensorType);
    }
    body->addArgument(type);
  }
  // Untied results get their own write-only block arguments.
  for (auto resultType : llvm::enumerate(resultTypes)) {
    if (resultAliases[resultType.index()]) {
      // Already handled by an aliased operand.
      continue;
    }
    Type type = resultType.value();
    if (auto tensorType = type.dyn_cast<TensorType>()) {
      type = DispatchTensorType::get(TensorAccess::WriteOnly, tensorType);
    }
    body->addArgument(type);
  }
  assert(std::next(body->begin()) == body->end());
}
// Parses the dispatch workgroup body region, including its explicit argument
// list: `(%arg0: type, %arg1: type, ...) { ... }`. Region argument names may
// shadow the outer scope's SSA names.
static ParseResult parseDispatchWorkgroupBody(OpAsmParser &parser,
                                              TypeRange operandTypes,
                                              TypeRange resultTypes,
                                              Region &body) {
  SmallVector<OpAsmParser::OperandType, 16> regionArgs;
  SmallVector<Type, 16> regionArgTypes;
  if (failed(parser.parseLParen())) {
    return failure();
  }
  // An immediate `)` means the argument list is empty.
  if (failed(parser.parseOptionalRParen())) {
    do {
      // Reserve entries in the lists.
      regionArgs.emplace_back();
      regionArgTypes.emplace_back();
      if (failed(parser.parseRegionArgument(regionArgs.back())) ||
          failed(parser.parseColonType(regionArgTypes.back()))) {
        return failure();
      }
    } while (succeeded(parser.parseOptionalComma()));
    if (failed(parser.parseRParen())) {
      return failure();
    }
  }
  return parser.parseRegion(body, regionArgs, regionArgTypes,
                            /*enableNameShadowing=*/true);
}
// Prints the dispatch workgroup body region with its explicit argument list,
// mirroring parseDispatchWorkgroupBody.
static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op,
                                       TypeRange operandTypes,
                                       TypeRange resultTypes, Region &body) {
  p << "(";
  llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) {
    p << arg << ": " << arg.getType();
  });
  p << ")";
  // Entry block args were just printed inline; terminators stay visible.
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}
// Verifies the workgroup count is non-empty and that operand/result dynamic
// dimension operands match the shapes they describe.
static LogicalResult verifyDispatchWorkgroupsOp(DispatchWorkgroupsOp op) {
  if (op.workgroup_count().empty()) {
    return op.emitOpError() << "at least one workgroup dimension is required";
  }
  if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims()))) {
    return failure();
  }
  return verifyOpDynamicDims(op, op.results(), op.result_dims());
}
// Materializes a ranked-shape value for operand |idx| using the op's attached
// operand_dims() dynamic dimension operands.
Value DispatchWorkgroupsOp::buildOperandRankedShape(unsigned idx,
                                                    OpBuilder &builder) {
  return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
                                               operand_dims(), builder);
}
// Materializes a ranked-shape value for result |idx| using the op's attached
// result_dims() dynamic dimension operands.
Value DispatchWorkgroupsOp::buildResultRankedShape(unsigned idx,
                                                   OpBuilder &builder) {
  return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(),
                                               result_dims(), builder);
}
// ClosureOpInterface: the closure captures exactly the dispatch operands.
Operation::operand_range DispatchWorkgroupsOp::getClosureOperands() {
  return operands();
}
// ClosureOpInterface: the closure produces exactly the dispatch results.
Operation::result_range DispatchWorkgroupsOp::getClosureResults() {
  return results();
}
// Inline operations that the dispatch region can handle natively.
// Inline operations that the dispatch region can handle natively.
static bool canDispatchRegionContainOp(Operation *op) {
  // Inline constant operations that are splat or small constants.
  if (auto constantOp = dyn_cast<ConstantOp>(op)) {
    auto constantValueAttr = constantOp.getValue();
    auto constantType = constantOp.getType();
    if (constantValueAttr.isa<SplatElementsAttr>()) {
      // Splats are cheap to rematerialize regardless of logical size.
      return true;
    } else if (auto denseAttr =
                   constantValueAttr.dyn_cast<DenseElementsAttr>()) {
      // TODO(GH-4897): Non-splat constants seems to have an issue on the LLVM
      // side. Uncomment after that is fixed.
      auto shapedType = constantOp.getType().cast<ShapedType>();
      // NOTE: truncating division; sub-byte element types underestimate the
      // size (e.g. i1 contributes 0 for <8 elements).
      uint64_t estimatedByteLength =
          (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) /
          8;
      // Size limit is tunable via -iree-flow-inline-constants-max-byte-length.
      return denseAttr.isSplat() ||
             estimatedByteLength <= clInlineConstantByteLength;
    } else if (constantType.isIntOrIndexOrFloat()) {
      // Scalar primitives are always inlinable.
      return true;
    }
  }
  return false;
}
// ClosureOpInterface: delegates to the shared dispatch-region inlining policy.
bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) {
  return canDispatchRegionContainOp(op);
}
// Returns true if the block argument backing result |resultIndex| may be read
// inside the region (as opposed to being purely written to).
bool DispatchWorkgroupsOp::isOutputReadWithinRegion(unsigned resultIndex) {
  // Result-backing arguments trail the operand arguments in the entry block.
  unsigned startIndex = getBody()->getNumArguments() - getNumResults();
  BlockArgument arg = body().front().getArgument(startIndex + resultIndex);
  // If argument is of `writeonly` access, then it is not read by construction.
  if (arg.getType().cast<DispatchTensorType>().getAccess() ==
      TensorAccess::WriteOnly) {
    return false;
  }
  // If the argument is a result with `readwrite` access, return false if the
  // value is only written to. Check this by looking at the uses of the argument
  // being only the `target` of `flow.dispatch.tensor.store` ops.
  for (OpOperand &uses : arg.getUses()) {
    auto storeOp = dyn_cast<DispatchTensorStoreOp>(uses.getOwner());
    // Any use that is not the store target counts as a (potential) read.
    if (!(storeOp && storeOp.target() == uses.get())) {
      return true;
    }
  }
  return false;
}
// ClosureOpInterface: clones this dispatch with the given operands/results
// removed, moving the body over and erasing the block arguments (and any
// store-only uses) that correspond to the dropped entries.
ClosureOpInterface
DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults(
    ArrayRef<unsigned> excludedOperandIndices,
    ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
  SmallVector<Type, 4> newResultTypes = llvm::to_vector<4>(getResultTypes());
  SmallVector<Value, 4> newResultDims = llvm::to_vector<4>(result_dims());
  SmallVector<Value, 4> newOperandsValues = llvm::to_vector<4>(operands());
  SmallVector<Value, 4> newOperandDims = llvm::to_vector<4>(operand_dims());
  // Drop the excluded operands/results (and their dynamic dims) in place.
  excludeClosureOperandsAndResults(newOperandsValues, newOperandDims,
                                   excludedOperandIndices, newResultTypes,
                                   newResultDims, excludedResultIndices);
  auto newTiedOperandIndices =
      llvm::to_vector<4>(getTiedResultOperandIndices());
  // TODO(benvanik): all this offset stuff is confusing and should be reworked.
  // We should probably have absolute indices and relative indices, or just one
  // or the other, and not be crossing the streams. The way things are offset
  // is the same as variadic ODS operands for consistency, but just like ODS
  // operands half of the code assumes its within a particular ODS operand and
  // half the code assumes it's within the flattened set of all Operation
  // operands.
  unsigned tiedOperandOffset = getTiedOperandsIndexAndLength().first;
  // Rebase the tied indices from flattened-operand space to the ODS operand
  // segment so they line up with the exclusion indices below.
  for (unsigned i = 0; i < newTiedOperandIndices.size(); ++i) {
    if (newTiedOperandIndices[i] != IREE::Util::TiedOpInterface::kUntiedIndex) {
      newTiedOperandIndices[i] -= tiedOperandOffset;
    }
  }
  // This need to happen *after* accounting for tied operand offset, given that
  // all excluded operand/result indices are relative ranges.
  IREE::Util::excludeTiedOperandAndResultIndices(
      excludedOperandIndices, excludedResultIndices, newTiedOperandIndices);
  auto newOp = rewriter.create<DispatchWorkgroupsOp>(
      getLoc(), workgroup_count(), newResultTypes, newResultDims,
      newOperandsValues, newOperandDims, newTiedOperandIndices,
      getOperation()->getAttrs());
  // Steal the body rather than cloning it; the old op is being replaced.
  auto &newBody = newOp.getClosureBodyRegion();
  newBody.takeBody(getClosureBodyRegion());
  // Use old index when erasing ops.
  unsigned baseResultIndex = operands().size();
  // For dropped results, erase all the store-op uses. It is a pre-requisite
  // that the result can be dropped only if it is written within the dispatch
  // region op.
  auto erasedArguments = llvm::to_vector<4>(excludedOperandIndices);
  for (unsigned i = baseResultIndex, e = newBody.getNumArguments(); i != e;
       ++i) {
    if (!is_contained(excludedResultIndices, i - baseResultIndex)) continue;
    auto arg = newBody.front().getArgument(i);
    for (OpOperand &user : llvm::make_early_inc_range(arg.getUses())) {
      auto storeOp = dyn_cast<DispatchTensorStoreOp>(user.getOwner());
      if (storeOp && storeOp.target() == user.get()) {
        rewriter.eraseOp(storeOp);
      }
    }
    erasedArguments.push_back(i);
  }
  // Erase operand- and result-backing arguments in one batch.
  newBody.front().eraseArguments(erasedArguments);
  return newOp;
}
// Returns the (start, length) of the ODS operand segment that tied result
// indices are relative to (segment 1: the dispatch operands).
std::pair<unsigned, unsigned>
DispatchWorkgroupsOp::getTiedOperandsIndexAndLength() {
  return getODSOperandIndexAndLength(1);
}
//===----------------------------------------------------------------------===//
// flow.dispatch.workgroup.*
//===----------------------------------------------------------------------===//
// Names the result SSA value "workgroup_rank" in printed IR.
void DispatchWorkgroupRankOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(result(), "workgroup_rank");
}
// Names |result| as |prefix| followed by the dimension index, producing names
// like "workgroup_id_0".
static void getAsmResultNamesForDispatchWorkgroupInfoOp(
    StringRef prefix, const APInt &dimension, Value result,
    function_ref<void(Value, StringRef)> setNameFn) {
  auto suffix = std::to_string(dimension.getZExtValue());
  setNameFn(result, (prefix + suffix).str());
}
// Names the result "workgroup_id_<dim>" in printed IR.
void DispatchWorkgroupIDOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_id_", dimension(),
                                              result(), setNameFn);
}
// Names the result "workgroup_count_<dim>" in printed IR.
void DispatchWorkgroupCountOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_count_", dimension(),
                                              result(), setNameFn);
}
// Names the result "workgroup_size_<dim>" in printed IR.
void DispatchWorkgroupSizeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_size_", dimension(),
                                              result(), setNameFn);
}
// Verifies that a workgroup info op's dimension attribute is within the
// bounds of the enclosing dispatch's workgroup dimensions. When the op is not
// (yet) nested in a dispatch region the check is skipped (dimCount == 0).
template <typename T>
static LogicalResult verifyDispatchWorkgroupInfoOp(T op) {
  size_t dimCount = 0;
  if (auto dispatchOp = op->template getParentOfType<DispatchWorkgroupsOp>()) {
    dimCount = dispatchOp.workgroup_count().size();
  }
  uint64_t dimension = op.dimension().getZExtValue();
  // BUGFIX: |dimension| is unsigned so the old `dimension < 0` check was
  // always false; only the upper bound needs checking. The message also
  // previously printed dimCount-1 with a half-open bracket, misstating the
  // valid range [0, dimCount).
  if (dimCount != 0 && dimension >= dimCount) {
    return op.emitOpError()
           << "dimension " << dimension
           << " out of bounds of dispatch dimensions; expected [0, "
           << dimCount << ")";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.shape
//===----------------------------------------------------------------------===//
// Names the op's SSA result "shape" for readable IR dumps.
void DispatchShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // TODO(benvanik): since we know these are arguments, we could map them based
  // on index (so we get arg0_shape, ret0_shape, etc).
  setNameFn(result(), "shape");
}
// Infers the single result type of flow.dispatch.shape: a ranked-shape type
// mirroring the (possibly dynamic) shape of the !flow.dispatch.tensor source.
LogicalResult DispatchShapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = operands[0].getType().cast<DispatchTensorType>();
  auto shapeType = Shape::RankedShapeType::get(sourceType.getShape(), context);
  inferredReturnTypes.assign({shapeType});
  return success();
}
//===----------------------------------------------------------------------===//
// flow.executable
//===----------------------------------------------------------------------===//
// Builds a flow.executable with the given symbol |name| and an empty body
// region (with its required terminator inserted).
void ExecutableOp::build(OpBuilder &builder, OperationState &state,
                         StringRef name) {
  ensureTerminator(*state.addRegion(), builder, state.location);
  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
                     builder.getStringAttr(name));
}
// Parses a flow.executable:
//   flow.executable @sym attributes {...} { ...body... }
// Must remain the exact inverse of printExecutableOp.
static ParseResult parseExecutableOp(OpAsmParser &parser,
                                     OperationState *result) {
  StringAttr nameAttr;
  if (failed(parser.parseSymbolName(nameAttr,
                                    mlir::SymbolTable::getSymbolAttrName(),
                                    result->attributes)) ||
      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  // Parse the module body.
  auto *body = result->addRegion();
  if (failed(parser.parseRegion(*body, llvm::None, llvm::None))) {
    return failure();
  }
  // Ensure that this module has a valid terminator.
  ExecutableOp::ensureTerminator(*body, parser.getBuilder(), result->location);
  return success();
}
// Prints a flow.executable. The symbol name attribute is elided from the
// attribute dictionary since it is printed inline; the output format must
// remain the exact inverse of parseExecutableOp.
static void printExecutableOp(OpAsmPrinter &p, ExecutableOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.sym_name());
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(),
      /*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});
  p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}
// Verifies a flow.executable; currently a no-op placeholder.
static LogicalResult verifyExecutableOp(ExecutableOp op) {
  // TODO(benvanik): check export name conflicts.
  return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.entry
//===----------------------------------------------------------------------===//
// Parses a flow.dispatch.entry:
//   flow.dispatch.entry @fn [as("export_name")] [attributes {...}]
// When the optional `as("...")` clause is omitted, the export symbol name
// defaults to the referenced function's name. Must remain the exact inverse
// of printDispatchEntryOp.
static ParseResult parseDispatchEntryOp(OpAsmParser &parser,
                                        OperationState *result) {
  FlatSymbolRefAttr functionRefAttr;
  if (failed(parser.parseAttribute(functionRefAttr, "function_ref",
                                   result->attributes))) {
    return failure();
  }
  if (succeeded(parser.parseOptionalKeyword("as"))) {
    StringAttr exportNameAttr;
    if (failed(parser.parseLParen()) ||
        failed(parser.parseAttribute(exportNameAttr, "sym_name",
                                     result->attributes)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
  } else {
    // No explicit export name; alias the function's own symbol name.
    result->addAttribute("sym_name", parser.getBuilder().getStringAttr(
                                         functionRefAttr.getValue()));
  }
  if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  return success();
}
// Prints a flow.dispatch.entry. The `as("...")` clause is only emitted when
// the export name differs from the referenced function name (mirroring the
// parser's defaulting behavior); both attributes are elided from the
// attribute dictionary since they are printed inline.
static void printDispatchEntryOp(OpAsmPrinter &p, DispatchEntryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.function_ref());
  if (op.sym_name() != op.function_ref()) {
    p << " as(\"" << op.sym_name() << "\")";
  }
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), /*elidedAttrs=*/{"function_ref", "sym_name"});
}
//===----------------------------------------------------------------------===//
// flow.dispatch
//===----------------------------------------------------------------------===//
// Builds a flow.dispatch targeting |entryPoint| with the given workgroup
// count, operands/results (plus their dynamic dims), and tied-operand
// mapping. |attributes| may carry extra attrs to splat onto the op.
void DispatchOp::build(OpBuilder &builder, OperationState &state,
                       DispatchEntryOp entryPoint, ValueRange workgroupCount,
                       TypeRange resultTypes, ValueRange resultDims,
                       ValueRange operands, ValueRange operandDims,
                       ArrayAttr tiedOperands,
                       ArrayRef<NamedAttribute> attributes) {
  // Reference the entry point as a nested symbol:
  //   @executable_sym::@entry_point_sym
  StringRef executableOpSymName =
      entryPoint->getParentOp()
          ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
          .getValue();
  state.addAttribute(
      "entry_point",
      builder.getSymbolRefAttr(executableOpSymName,
                               {builder.getSymbolRefAttr(entryPoint)}));
  state.addOperands(workgroupCount);
  state.addTypes(resultTypes);
  state.addOperands(operands);
  state.addOperands(operandDims);
  state.addOperands(resultDims);
  state.addAttributes(attributes);
  // |attributes| may carry stale copies of the derived attributes below;
  // erase-then-add so the values computed here always win.
  state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
  state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
                     tiedOperands);
  state.attributes.erase("operand_segment_sizes");
  // NOTE: the segment size order must match the operand-add order above and
  // the op's ODS operand declaration order.
  state.addAttribute("operand_segment_sizes",
                     builder.getI32VectorAttr({
                         static_cast<int32_t>(workgroupCount.size()),
                         static_cast<int32_t>(operands.size()),
                         static_cast<int32_t>(operandDims.size()),
                         static_cast<int32_t>(resultDims.size()),
                     }));
}
// Returns the symbol name of the flow.executable this dispatch targets (the
// root of the nested entry_point symbol reference).
StringRef DispatchOp::executable() { return entry_point().getRootReference(); }
// Synthesizes the function signature implied by this dispatch site:
// (operand types) -> (result types).
FunctionType DispatchOp::getEntryPointType() {
  SmallVector<Type, 8> argumentTypes;
  for (Type type : operand_type_range{operands()}) {
    argumentTypes.push_back(type);
  }
  return FunctionType::get(getContext(), argumentTypes, getResultTypes());
}
// Verifies a flow.dispatch: at least one workgroup dimension must be present
// and the dynamic dim operands must line up with the dynamic dimensions of
// the operand/result tensor types.
static LogicalResult verifyDispatchOp(DispatchOp op) {
  if (op.workgroup_count().empty()) {
    return op.emitOpError() << "at least one workgroup dimension is required";
  }
  if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims()))) {
    return failure();
  }
  if (failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) {
    return failure();
  }
  return success();
}
// Builds a ranked shape for operand |idx| from the op's dynamic operand dims.
Value DispatchOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
                                               operand_dims(), builder);
}
// Builds a ranked shape for result |idx| from the op's dynamic result dims.
Value DispatchOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(),
                                               result_dims(), builder);
}
// Returns the {index, length} of the ODS operand segment used for tied
// operand tracking (segment 1, the dispatch's tensor operands).
std::pair<unsigned, unsigned> DispatchOp::getTiedOperandsIndexAndLength() {
  return getODSOperandIndexAndLength(1);
}
//===----------------------------------------------------------------------===//
// flow.tensor.reshape
//===----------------------------------------------------------------------===//
// Builds a ranked shape for the reshape source from its dynamic dims.
Value TensorReshapeOp::buildOperandRankedShape(unsigned idx,
                                               OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
                                         builder);
}
// Builds a ranked shape for the reshape result from its dynamic dims.
Value TensorReshapeOp::buildResultRankedShape(unsigned idx,
                                              OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
                                         builder);
}
// The result aliases the source storage; walk back to the base tied value.
Value TensorReshapeOp::getTiedResult(unsigned resultIndex) {
  return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
}
// The single result is always tied to operand 0 (the source).
::llvm::Optional<unsigned> TensorReshapeOp::getTiedResultOperandIndex(
    unsigned resultIndex) {
  return {0};  // source
}
SmallVector<int64_t, 4> TensorReshapeOp::getTiedResultOperandIndices() {
  return {0};  // source
}
//===----------------------------------------------------------------------===//
// flow.tensor.*
//===----------------------------------------------------------------------===//
// Builds a ranked shape for the load source from its dynamic dims.
Value TensorLoadOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
                                         builder);
}
// The load result carries no dynamic dims here so there is no shape to build.
Value TensorLoadOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
  return {};
}
// Builds a ranked shape for the store target from its dynamic dims.
Value TensorStoreOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(),
                                         builder);
}
// The store result shares the target's shape (note the reuse of
// target_dims()).
Value TensorStoreOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), result(), target_dims(),
                                         builder);
}
// Splat takes only a scalar primitive; no operand shape to build.
Value TensorSplatOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
  return {};
}
// Builds a ranked shape for the splat result from its dynamic dims.
Value TensorSplatOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
                                         builder);
}
// Builds a ranked shape for the clone operand from its dynamic dims.
Value TensorCloneOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), operand(), operand_dims(),
                                         builder);
}
// The clone result is shape-identical to the operand (note the reuse of
// operand_dims()).
Value TensorCloneOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), result(), operand_dims(),
                                         builder);
}
// Builds a ranked shape for the slice source from its dynamic dims.
Value TensorSliceOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
                                         builder);
}
// Builds a ranked shape for the slice result from its dynamic dims.
Value TensorSliceOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
                                         builder);
}
//===----------------------------------------------------------------------===//
// flow.tensor.update
//===----------------------------------------------------------------------===//
// Builds a flow.tensor.update in-place-style update of |target| at
// |startIndices| with |update|, materializing (or reusing) the dynamic dims
// for both tensors at the insertion point.
void TensorUpdateOp::build(OpBuilder &builder, OperationState &state,
                           Value target, ValueRange startIndices,
                           Value update) {
  auto targetDims =
      Shape::buildOrFindDynamicDimsForValue(state.location, target, builder);
  auto updateDims =
      Shape::buildOrFindDynamicDimsForValue(state.location, update, builder);
  // The trailing index array ties result 0 to operand 0 (the target).
  build(builder, state, target.getType(), target, targetDims, startIndices,
        update, updateDims, builder.getIndexArrayAttr({0}));
}
// Verifies that the dynamic dim operands line up with the dynamic dimensions
// of the update and target tensor types.
static LogicalResult verifyTensorUpdateOp(TensorUpdateOp op) {
  if (failed(verifyOpDynamicDims(op, {op.update()}, op.update_dims()))) {
    return failure();
  }
  if (failed(verifyOpDynamicDims(op, {op.target()}, op.target_dims()))) {
    return failure();
  }
  return success();
}
// Builds a ranked shape for operand |idx|: operand 0 is the target tensor
// (target_dims); any other shaped operand is the update (update_dims).
Value TensorUpdateOp::buildOperandRankedShape(unsigned idx,
                                              OpBuilder &builder) {
  auto dims = (idx == 0) ? target_dims() : update_dims();
  return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
                                               dims, builder);
}
// The result is shape-identical to the target (note the reuse of
// target_dims()).
Value TensorUpdateOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
  return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(),
                                         builder);
}
// The result aliases the target storage; walk back to the base tied value.
Value TensorUpdateOp::getTiedResult(unsigned resultIndex) {
  return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
// The single result is always tied to operand 0 (the target).
::llvm::Optional<unsigned> TensorUpdateOp::getTiedResultOperandIndex(
    unsigned resultIndex) {
  return {0};  // target
}
SmallVector<int64_t, 4> TensorUpdateOp::getTiedResultOperandIndices() {
  return {0};  // target
}
//===----------------------------------------------------------------------===//
// flow.ex.stream.fragment
//===----------------------------------------------------------------------===//
// Builds a flow.ex.stream.fragment with the given operands/results (plus
// their dynamic dims), tied-operand mapping, and an empty body region.
void ExStreamFragmentOp::build(OpBuilder &builder, OperationState &state,
                               TypeRange resultTypes, ValueRange resultDims,
                               ValueRange operands, ValueRange operandDims,
                               ArrayRef<int64_t> tiedOperands,
                               ArrayRef<NamedAttribute> attributes) {
  state.addTypes(resultTypes);
  state.addOperands(operands);
  state.addOperands(operandDims);
  state.addOperands(resultDims);
  state.addAttributes(attributes);
  // |attributes| may carry stale copies of the derived attributes below;
  // erase-then-add so the values computed here always win.
  state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
  state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
                     builder.getIndexArrayAttr(tiedOperands));
  state.attributes.erase("operand_segment_sizes");
  // NOTE: the segment size order must match the operand-add order above and
  // the op's ODS operand declaration order.
  state.addAttribute("operand_segment_sizes",
                     builder.getI32VectorAttr({
                         static_cast<int32_t>(operands.size()),
                         static_cast<int32_t>(operandDims.size()),
                         static_cast<int32_t>(resultDims.size()),
                     }));
  state.addRegion();
}
// Verifies that the dynamic dim operands line up with the dynamic dimensions
// of the fragment's operand/result tensor types.
static LogicalResult verifyExStreamFragmentOp(ExStreamFragmentOp op) {
  if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims()))) {
    return failure();
  }
  if (failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) {
    return failure();
  }
  return success();
}
// Parses a stream fragment body of the form:
//   (%arg0: type0, %arg1: type1, ...) -> (result types) { region }
// Region arguments shadow the outer operand names; counts must match the
// already-parsed operand/result type lists. Must remain the exact inverse of
// printStreamFragmentBody.
static ParseResult parseStreamFragmentBody(OpAsmParser &parser,
                                           TypeRange operandTypes,
                                           TypeRange resultTypes,
                                           ArrayAttr tiedOperands,
                                           Region &body) {
  auto loc = parser.getCurrentLocation();
  SmallVector<OpAsmParser::OperandType, 16> regionArgs;
  SmallVector<Type, 16> regionArgTypes;
  if (failed(parser.parseLParen())) {
    return failure();
  }
  // Empty capture list `()` is allowed; otherwise parse `%arg: type` pairs.
  if (failed(parser.parseOptionalRParen())) {
    do {
      // Reserve entries in the lists.
      regionArgs.emplace_back();
      regionArgTypes.emplace_back();
      if (failed(parser.parseRegionArgument(regionArgs.back())) ||
          failed(parser.parseColonType(regionArgTypes.back()))) {
        return failure();
      }
    } while (succeeded(parser.parseOptionalComma()));
    if (failed(parser.parseRParen())) {
      return failure();
    }
  }
  SmallVector<Type, 4> regionResultTypes;
  if (failed(parser.parseArrowTypeList(regionResultTypes))) return failure();
  // The explicit lists must agree with the op's own operand/result counts.
  if (regionArgs.size() != operandTypes.size()) {
    return parser.emitError(loc, "region operand list mismatch");
  }
  if (regionResultTypes.size() != resultTypes.size()) {
    return parser.emitError(loc, "region result list mismatch");
  }
  return parser.parseRegion(body, regionArgs, regionArgTypes,
                            /*enableNameShadowing=*/true);
}
// Prints a stream fragment body:
//   (%arg0: type0, ...) -> (result types) { region }
// Must remain the exact inverse of parseStreamFragmentBody.
static void printStreamFragmentBody(OpAsmPrinter &p, Operation *op,
                                    TypeRange operandTypes,
                                    TypeRange resultTypes,
                                    ArrayAttr tiedOperands, Region &body) {
  p << "(";
  llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) {
    p << arg;
    p << ": ";
    p << arg.getType();
  });
  p << ") -> ";
  // Single results are printed bare; zero or multiple get parenthesized.
  bool needsParens = resultTypes.size() != 1;
  if (needsParens) p << "(";
  llvm::interleaveComma(resultTypes, p,
                        [&](Type resultType) { p.printType(resultType); });
  if (needsParens) p << ")";
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}
// Builds a ranked shape for operand |idx| from the op's dynamic operand dims.
Value ExStreamFragmentOp::buildOperandRankedShape(unsigned idx,
                                                  OpBuilder &builder) {
  return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
                                               operand_dims(), builder);
}
// Builds a ranked shape for result |idx| from the op's dynamic result dims.
Value ExStreamFragmentOp::buildResultRankedShape(unsigned idx,
                                                 OpBuilder &builder) {
  return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(),
                                               result_dims(), builder);
}
// ClosureOpInterface: the closure captures exactly the op's operands/results.
Operation::operand_range ExStreamFragmentOp::getClosureOperands() {
  return operands();
}
Operation::result_range ExStreamFragmentOp::getClosureResults() {
  return results();
}
// Returns true if |op| may be cloned/moved into the stream fragment closure.
//
// NOTE: support is only widened on the new stream ops - this legacy path
// isn't worth extending further.
bool ExStreamFragmentOp::canClosureContainOp(Operation *op) {
  // Primitive (int/index/float) constants are cheap to pull in.
  if (auto constantOp = dyn_cast<ConstantOp>(op)) {
    return constantOp.getType().isIntOrIndexOrFloat();
  }
  // Only loads of immutable variables may move into the stream: since the
  // value never changes no synchronization at the stream entry/exit boundary
  // is required.
  //
  // Loads of mutable variables may sometimes be safe to move in as well but
  // that is best done once there is better cross-stream synchronization
  // support so the guarantees can be made structurally.
  auto loadOp = dyn_cast<VariableLoadOp>(op);
  if (!loadOp) return false;
  auto variableOp =
      SymbolTable::lookupNearestSymbolFrom<VariableOp>(op, loadOp.variable());
  return !variableOp.is_mutable();
}
// Stream fragment results are never considered read back inside the region,
// so closure optimization may freely drop/replace them.
bool ExStreamFragmentOp::isOutputReadWithinRegion(unsigned resultIndex) {
  return false;
}
// Clones this op into a replacement that omits the operands/results at the
// given indices, moving the body region over and erasing the corresponding
// region arguments/returned values. Used by closure optimization to shed
// unused captures.
ClosureOpInterface
ExStreamFragmentOp::cloneReplacementExcludingOperandsAndResults(
    ArrayRef<unsigned> excludedOperandIndices,
    ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
  SmallVector<Type, 4> newResultTypes = llvm::to_vector<4>(getResultTypes());
  SmallVector<Value, 4> newResultDims = llvm::to_vector<4>(result_dims());
  SmallVector<Value, 4> newOperandsValues = llvm::to_vector<4>(operands());
  SmallVector<Value, 4> newOperandDims = llvm::to_vector<4>(operand_dims());
  // Drop the excluded entries (and their dynamic dims) from the lists.
  excludeClosureOperandsAndResults(newOperandsValues, newOperandDims,
                                   excludedOperandIndices, newResultTypes,
                                   newResultDims, excludedResultIndices);
  // Remap the tied operand indices to account for the removed entries.
  auto newTiedOperandIndices =
      llvm::to_vector<4>(getTiedResultOperandIndices());
  IREE::Util::excludeTiedOperandAndResultIndices(
      excludedOperandIndices, excludedResultIndices, newTiedOperandIndices);
  assert(getTiedOperandsIndexAndLength().first == 0 &&
         "operands must be the first ODS group");
  auto newOp = rewriter.create<ExStreamFragmentOp>(
      getLoc(), newResultTypes, newResultDims, newOperandsValues,
      newOperandDims, newTiedOperandIndices, getOperation()->getAttrs());
  // Steal the body and prune the now-unused results/arguments. Order matters:
  // results are erased from the terminator before block arguments are erased.
  auto &newBody = newOp.getClosureBodyRegion();
  newBody.takeBody(getClosureBodyRegion());
  eraseRegionResults(newBody, excludedResultIndices);
  newBody.front().eraseArguments(excludedOperandIndices);
  return newOp;
}
//===----------------------------------------------------------------------===//
// Public methods
//===----------------------------------------------------------------------===//
// Populates |results| with the canonicalization patterns for the Flow
// dispatch-related ops (currently only flow.dispatch.tensor.load).
void populateFlowDispatchCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  DispatchTensorLoadOp::getCanonicalizationPatterns(results, context);
}
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen definitions (intentionally last)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Flow/IR/FlowOps.cpp.inc" // IWYU pragma: keep