// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"
static llvm::cl::opt<int> clInlineConstantByteLength(
"iree-flow-inline-constants-max-byte-length",
llvm::cl::desc("Maximum byte-length of constant that can be inlined into a "
"dispatch region"),
llvm::cl::init(256));
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
//===----------------------------------------------------------------------===//
// Op utilities used within the Flow dialect
//===----------------------------------------------------------------------===//
// Verifies that |dynamicDims| contains the appropriate number of dims for all
// of the dynamic dimensions in |values|.
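//
// For example (illustrative): for |values| of types tensor<4x?x?xf32> and
// tensor<8xf32> there are two dynamic dimensions in total, so |dynamicDims|
// must contain exactly two index values (one per `?`, in order).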
static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
ValueRange dynamicDims) {
unsigned requiredCount = 0;
for (auto value : values) {
if (auto shapedType = value.getType().dyn_cast<ShapedType>()) {
requiredCount += shapedType.getNumDynamicDims();
}
}
if (dynamicDims.size() != requiredCount) {
return op->emitOpError()
<< "value set has " << requiredCount
<< " dynamic dimensions but only " << dynamicDims.size()
<< " dimension values are attached";
}
return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.tensor.load
//===----------------------------------------------------------------------===//
/// Extracts static and dynamic values from a list of `OpFoldResult`s.
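///
/// For example (illustrative): given {%a, 5 : index, %b} with a
/// |dynamicIndexValue| sentinel of ShapedType::kDynamicSize, this yields
/// dynamicValues = {%a, %b} and staticValues = {kDynamicSize, 5, kDynamicSize}.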
static void processMixedOperands(ArrayRef<OpFoldResult> valueOrAttrs,
SmallVectorImpl<Value> &dynamicValues,
SmallVectorImpl<int64_t> &staticValues,
int64_t dynamicIndexValue) {
for (OpFoldResult valueOrAttr : valueOrAttrs) {
if (auto value = valueOrAttr.dyn_cast<Value>()) {
dynamicValues.push_back(value);
staticValues.push_back(dynamicIndexValue);
} else {
auto operandValue =
valueOrAttr.dyn_cast<Attribute>().cast<IntegerAttr>().getInt();
staticValues.push_back(operandValue);
}
}
}
RankedTensorType DispatchTensorLoadOp::inferResultType(
IREE::Flow::DispatchTensorType sourceType,
ArrayRef<OpFoldResult> mixedSizes) {
auto shape = llvm::to_vector<4>(
llvm::map_range(mixedSizes, [&](OpFoldResult valueOrAttr) -> int64_t {
if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
return attr.cast<IntegerAttr>().getInt();
}
return DispatchTensorType::kDynamicSize;
}));
return RankedTensorType::get(shape, sourceType.getElementType());
}
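// For example (hypothetical values): with mixedSizes = {%d, 16} and a source
// of type !flow.dispatch.tensor<readonly:?x16xf32>, the inferred result type
// is tensor<?x16xf32>, with the dynamic size carried by %d at runtime.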
void DispatchTensorLoadOp::build(OpBuilder &builder, OperationState &state,
RankedTensorType returnType, Value source,
ArrayRef<NamedAttribute> attributes) {
build(builder, state, returnType, source, ArrayRef<Value>(),
ArrayRef<Value>(), ArrayRef<Value>(), builder.getI64ArrayAttr({}),
builder.getI64ArrayAttr({}), builder.getI64ArrayAttr({}));
}
void DispatchTensorLoadOp::build(OpBuilder &builder, OperationState &state,
RankedTensorType returnType, Value source,
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides,
ArrayRef<NamedAttribute> attributes) {
SmallVector<Value> offsets;
SmallVector<Value> sizes;
SmallVector<Value> strides;
SmallVector<int64_t> staticOffsets;
SmallVector<int64_t> staticSizes;
SmallVector<int64_t> staticStrides;
processMixedOperands(mixedOffsets, offsets, staticOffsets,
ShapedType::kDynamicStrideOrOffset);
processMixedOperands(mixedSizes, sizes, staticSizes,
ShapedType::kDynamicSize);
processMixedOperands(mixedStrides, strides, staticStrides,
ShapedType::kDynamicStrideOrOffset);
build(builder, state, returnType, source, offsets, sizes, strides,
builder.getI64ArrayAttr(staticOffsets),
builder.getI64ArrayAttr(staticSizes),
builder.getI64ArrayAttr(staticStrides));
}
void DispatchTensorLoadOp::build(OpBuilder &builder, OperationState &state,
Value source,
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides,
ArrayRef<NamedAttribute> attributes) {
auto returnType =
inferResultType(source.getType().cast<DispatchTensorType>(), mixedSizes);
build(builder, state, returnType, source, mixedOffsets, mixedSizes,
mixedStrides);
}
//===----------------------------------------------------------------------===//
// flow.dispatch.tensor.store
//===----------------------------------------------------------------------===//
void DispatchTensorStoreOp::build(OpBuilder &builder, OperationState &state,
Value value, Value target,
ArrayRef<NamedAttribute> attributes) {
build(builder, state, ArrayRef<Type>(), value, target, ArrayRef<Value>(),
ArrayRef<Value>(), ArrayRef<Value>(), builder.getI64ArrayAttr({}),
builder.getI64ArrayAttr({}), builder.getI64ArrayAttr({}));
}
//===----------------------------------------------------------------------===//
// flow.dispatch.workgroups
//===----------------------------------------------------------------------===//
void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state,
ValueRange workgroupCount,
TypeRange resultTypes, ValueRange resultDims,
ValueRange operands, ValueRange operandDims,
ArrayRef<int64_t> tiedOperands,
ArrayRef<NamedAttribute> attributes) {
state.addTypes(resultTypes);
state.addOperands(workgroupCount);
state.addOperands(operands);
state.addOperands(operandDims);
state.addOperands(resultDims);
state.addAttributes(attributes);
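// Replace any tied-operand and segment-size attributes that may have been
// carried in |attributes| with the authoritative values computed here.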
state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
builder.getIndexArrayAttr(tiedOperands));
state.attributes.erase("operand_segment_sizes");
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr({
static_cast<int32_t>(workgroupCount.size()),
static_cast<int32_t>(operands.size()),
static_cast<int32_t>(operandDims.size()),
static_cast<int32_t>(resultDims.size()),
}));
auto *body = state.addRegion();
assert(body->begin() == body->end());
{
OpBuilder::InsertionGuard g(builder);
// createBlock implicitly moves the insertion point into the new block;
// the guard above restores it on scope exit.
builder.createBlock(body);
}
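// Mark which operands are tied to (alias) results so that the body arguments
// added below can be typed readwrite rather than readonly/writeonly.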
llvm::BitVector operandAliases(llvm::size(operands), false);
llvm::BitVector resultAliases(llvm::size(resultTypes), false);
for (unsigned resultIndex = 0; resultIndex < tiedOperands.size();
++resultIndex) {
int64_t tiedOperandIndex = tiedOperands[resultIndex];
if (tiedOperandIndex != IREE::Util::TiedOpInterface::kUntiedIndex) {
operandAliases[tiedOperandIndex] = true;
resultAliases[resultIndex] = true;
}
}
for (auto operand : llvm::enumerate(operands)) {
Type type = operand.value().getType();
if (auto tensorType = type.dyn_cast<TensorType>()) {
type = DispatchTensorType::get(operandAliases[operand.index()]
? TensorAccess::ReadWrite
: TensorAccess::ReadOnly,
tensorType);
}
body->addArgument(type);
}
for (auto resultType : llvm::enumerate(resultTypes)) {
if (resultAliases[resultType.index()]) {
// Already handled by an aliased operand.
continue;
}
Type type = resultType.value();
if (auto tensorType = type.dyn_cast<TensorType>()) {
type = DispatchTensorType::get(TensorAccess::WriteOnly, tensorType);
}
body->addArgument(type);
}
assert(std::next(body->begin()) == body->end());
}
static ParseResult parseDispatchWorkgroupBody(OpAsmParser &parser,
TypeRange operandTypes,
TypeRange resultTypes,
Region &body) {
SmallVector<OpAsmParser::OperandType, 16> regionArgs;
SmallVector<Type, 16> regionArgTypes;
if (failed(parser.parseLParen())) {
return failure();
}
if (failed(parser.parseOptionalRParen())) {
do {
// Reserve entries in the lists.
regionArgs.emplace_back();
regionArgTypes.emplace_back();
if (failed(parser.parseRegionArgument(regionArgs.back())) ||
failed(parser.parseColonType(regionArgTypes.back()))) {
return failure();
}
} while (succeeded(parser.parseOptionalComma()));
if (failed(parser.parseRParen())) {
return failure();
}
}
return parser.parseRegion(body, regionArgs, regionArgTypes,
/*enableNameShadowing=*/true);
}
static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op,
TypeRange operandTypes,
TypeRange resultTypes, Region &body) {
p << "(";
llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) {
p << arg;
p << ": ";
p << arg.getType();
});
p << ")";
p.printRegion(body, /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
static LogicalResult verifyDispatchWorkgroupsOp(DispatchWorkgroupsOp op) {
if (op.workgroup_count().empty()) {
return op.emitOpError() << "at least one workgroup dimension is required";
}
if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) ||
failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) {
return failure();
}
return success();
}
Value DispatchWorkgroupsOp::buildOperandRankedShape(unsigned idx,
OpBuilder &builder) {
return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
operand_dims(), builder);
}
Value DispatchWorkgroupsOp::buildResultRankedShape(unsigned idx,
OpBuilder &builder) {
return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(),
result_dims(), builder);
}
Operation::operand_range DispatchWorkgroupsOp::getClosureOperands() {
return operands();
}
Operation::result_range DispatchWorkgroupsOp::getClosureResults() {
return results();
}
// Returns true if the dispatch region can contain (inline) |op| natively.
static bool canDispatchRegionContainOp(Operation *op) {
// Inline constants that are splats or sufficiently small.
if (auto constantOp = dyn_cast<ConstantOp>(op)) {
auto constantValueAttr = constantOp.getValue();
auto constantType = constantOp.getType();
if (constantValueAttr.isa<SplatElementsAttr>()) {
return true;
} else if (auto denseAttr =
constantValueAttr.dyn_cast<DenseElementsAttr>()) {
// TODO(GH-4897): Non-splat constants seem to have an issue on the LLVM
// side. Uncomment after that is fixed.
auto shapedType = constantOp.getType().cast<ShapedType>();
uint64_t estimatedByteLength =
(shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) /
8;
return denseAttr.isSplat() ||
estimatedByteLength <= clInlineConstantByteLength;
} else if (constantType.isIntOrIndexOrFloat()) {
return true;
}
}
return false;
}
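// As a rough illustration of the above with the default 256-byte limit: a
// splat like dense<1.0> : tensor<1024xf32> is always inlined, while a
// non-splat dense<...> : tensor<128xf32> (~512 bytes of f32 data) stays
// outside the region.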
bool DispatchWorkgroupsOp::canClosureContainOp(Operation *op) {
return canDispatchRegionContainOp(op);
}
// Refines the tensor access declared on |type| based on actual usage. We
// expect the access to have been set correctly to begin with, but today it is
// sometimes declared more broadly than necessary.
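//
// For example (illustrative IR):
//   flow.dispatch.tensor.store %value, %arg0, ... // sole use of %arg0
// refines %arg0 from `readwrite` to `writeonly`.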
static TensorAccess refineTensorAccess(Value value, DispatchTensorType type) {
auto tensorAccess = type.getAccess();
if (tensorAccess == TensorAccess::ReadWrite) {
// If the argument has `readwrite` access, narrow it to `writeonly` when the
// value is only ever written. Check this by verifying that every use of the
// argument is as the `target` of a `flow.dispatch.tensor.store` op.
bool onlyWrites = true;
for (OpOperand &use : value.getUses()) {
auto storeOp = dyn_cast<DispatchTensorStoreOp>(use.getOwner());
if (!(storeOp && storeOp.target() == use.get())) {
onlyWrites = false;
break;
}
}
if (onlyWrites) tensorAccess = TensorAccess::WriteOnly;
}
return tensorAccess;
}
IREE::Util::ValueAccess DispatchWorkgroupsOp::getOperandAccess(
unsigned operandIndex) {
BlockArgument arg = body().front().getArgument(operandIndex);
if (auto tensorType = arg.getType().dyn_cast<DispatchTensorType>()) {
auto tensorAccess = refineTensorAccess(arg, tensorType);
return IREE::Util::ValueAccess(
/*isRead=*/(tensorAccess == TensorAccess::ReadOnly) ||
(tensorAccess == TensorAccess::ReadWrite),
/*isWrite=*/(tensorAccess == TensorAccess::ReadWrite) ||
(tensorAccess == TensorAccess::WriteOnly),
/*isDiscard=*/(tensorAccess == TensorAccess::WriteOnly));
} else {
return IREE::Util::ValueAccess(/*isRead=*/!arg.use_empty(),
/*isWrite=*/false,
/*isDiscard=*/false);
}
}
IREE::Util::ValueAccess DispatchWorkgroupsOp::getResultAccess(
unsigned resultIndex) {
unsigned startIndex = getBody()->getNumArguments() - getNumResults();
BlockArgument arg = body().front().getArgument(startIndex + resultIndex);
if (auto tensorType = arg.getType().dyn_cast<DispatchTensorType>()) {
auto tensorAccess = refineTensorAccess(arg, tensorType);
return IREE::Util::ValueAccess(
/*isRead=*/(tensorAccess == TensorAccess::ReadOnly) ||
(tensorAccess == TensorAccess::ReadWrite),
/*isWrite=*/(tensorAccess == TensorAccess::ReadWrite) ||
(tensorAccess == TensorAccess::WriteOnly),
/*isDiscard=*/(tensorAccess == TensorAccess::WriteOnly));
} else {
return IREE::Util::ValueAccess(/*isRead=*/!arg.use_empty(),
/*isWrite=*/false,
/*isDiscard=*/false);
}
}
IREE::Util::ClosureOpInterface
DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults(
ArrayRef<unsigned> excludedOperandIndices,
ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
SmallVector<Type, 4> newResultTypes = llvm::to_vector<4>(getResultTypes());
SmallVector<Value, 4> newResultDims = llvm::to_vector<4>(result_dims());
SmallVector<Value, 4> newOperandsValues = llvm::to_vector<4>(operands());
SmallVector<Value, 4> newOperandDims = llvm::to_vector<4>(operand_dims());
IREE::Util::excludeClosureOperandsAndResults(
newOperandsValues, newOperandDims, excludedOperandIndices, newResultTypes,
newResultDims, excludedResultIndices);
auto newTiedOperandIndices =
llvm::to_vector<4>(getTiedResultOperandIndices());
// TODO(benvanik): all this offset stuff is confusing and should be reworked.
// We should probably have either absolute indices or relative indices, not a
// mix of both. The offsetting here matches variadic ODS operands for
// consistency, but, just as with ODS operands, half of the code assumes an
// index is relative to a particular ODS operand group while the other half
// assumes it's relative to the flattened set of all Operation operands.
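// As a worked example (illustrative): with a 3-D workgroup count,
// tiedOperandOffset below is 3, so a stored absolute tied index of 4 maps to
// relative closure operand index 1 after the subtraction.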
unsigned tiedOperandOffset = getTiedOperandsIndexAndLength().first;
for (unsigned i = 0; i < newTiedOperandIndices.size(); ++i) {
if (newTiedOperandIndices[i] != IREE::Util::TiedOpInterface::kUntiedIndex) {
newTiedOperandIndices[i] -= tiedOperandOffset;
}
}
// This needs to happen *after* accounting for the tied operand offset, given
// that all excluded operand/result indices are relative.
IREE::Util::excludeTiedOperandAndResultIndices(
excludedOperandIndices, excludedResultIndices, newTiedOperandIndices);
auto newOp = rewriter.create<DispatchWorkgroupsOp>(
getLoc(), workgroup_count(), newResultTypes, newResultDims,
newOperandsValues, newOperandDims, newTiedOperandIndices,
getOperation()->getAttrs());
auto &newBody = newOp.getClosureBodyRegion();
newBody.takeBody(getClosureBodyRegion());
// Use the original (pre-exclusion) argument indices when erasing ops.
unsigned baseResultIndex = operands().size();
// For dropped results, erase all the store-op uses. A result may only be
// dropped if it is exclusively written to (and never read) within the
// dispatch region.
auto erasedArguments = llvm::to_vector<4>(excludedOperandIndices);
for (unsigned i = baseResultIndex, e = newBody.getNumArguments(); i != e;
++i) {
if (!llvm::is_contained(excludedResultIndices, i - baseResultIndex)) continue;
auto arg = newBody.front().getArgument(i);
for (OpOperand &user : llvm::make_early_inc_range(arg.getUses())) {
auto storeOp = dyn_cast<DispatchTensorStoreOp>(user.getOwner());
if (storeOp && storeOp.target() == user.get()) {
rewriter.eraseOp(storeOp);
}
}
erasedArguments.push_back(i);
}
newBody.front().eraseArguments(erasedArguments);
return newOp;
}
std::pair<unsigned, unsigned>
DispatchWorkgroupsOp::getTiedOperandsIndexAndLength() {
return getODSOperandIndexAndLength(1);
}
//===----------------------------------------------------------------------===//
// flow.dispatch.workgroup.*
//===----------------------------------------------------------------------===//
void DispatchWorkgroupRankOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "workgroup_rank");
}
static void getAsmResultNamesForDispatchWorkgroupInfoOp(
StringRef prefix, const APInt &dimension, Value result,
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result, (prefix + std::to_string(dimension.getZExtValue())).str());
}
void DispatchWorkgroupIDOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_id_", dimension(),
result(), setNameFn);
}
void DispatchWorkgroupCountOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_count_", dimension(),
result(), setNameFn);
}
void DispatchWorkgroupSizeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_size_", dimension(),
result(), setNameFn);
}
template <typename T>
static LogicalResult verifyDispatchWorkgroupInfoOp(T op) {
size_t dimCount = 0;
if (auto dispatchOp = op->template getParentOfType<DispatchWorkgroupsOp>()) {
dimCount = dispatchOp.workgroup_count().size();
}
uint64_t dimension = op.dimension().getZExtValue();
// |dimension| is unsigned, so only the upper bound needs checking.
if (dimCount != 0 && dimension >= dimCount) {
return op.emitOpError()
<< "dimension " << dimension
<< " out of bounds of dispatch dimensions; expected [0, " << dimCount
<< ")";
}
return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.shape
//===----------------------------------------------------------------------===//
void DispatchShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
// TODO(benvanik): since we know these are arguments, we could map them based
// on index (so we get arg0_shape, ret0_shape, etc).
setNameFn(result(), "shape");
}
LogicalResult DispatchShapeOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto dispatchTensorType = operands[0].getType().cast<DispatchTensorType>();
auto shape = dispatchTensorType.getShape();
auto rankedShapeType = Shape::RankedShapeType::get(shape, context);
inferredReturnTypes.assign({rankedShapeType});
return success();
}
//===----------------------------------------------------------------------===//
// flow.executable
//===----------------------------------------------------------------------===//
void ExecutableOp::build(OpBuilder &builder, OperationState &state,
StringRef name) {
ensureTerminator(*state.addRegion(), builder, state.location);
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
}
static ParseResult parseExecutableOp(OpAsmParser &parser,
OperationState *result) {
StringAttr nameAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
// Parse the module body.
auto *body = result->addRegion();
if (failed(parser.parseRegion(*body, llvm::None, llvm::None))) {
return failure();
}
// Ensure that this module has a valid terminator.
ExecutableOp::ensureTerminator(*body, parser.getBuilder(), result->location);
return success();
}
static void printExecutableOp(OpAsmPrinter &p, ExecutableOp op) {
p << ' ';
p.printSymbolName(op.sym_name());
p.printOptionalAttrDictWithKeyword(
op->getAttrs(),
/*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});
p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
static LogicalResult verifyExecutableOp(ExecutableOp op) {
// TODO(benvanik): check export name conflicts.
return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.entry
//===----------------------------------------------------------------------===//
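// Custom assembly, as produced by the printer below (names hypothetical):
//   flow.dispatch.entry @main_dispatch as("main") attributes {...}
// The `as("...")` clause is elided when the export name matches the function.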
static ParseResult parseDispatchEntryOp(OpAsmParser &parser,
OperationState *result) {
FlatSymbolRefAttr functionRefAttr;
if (failed(parser.parseAttribute(functionRefAttr, "function_ref",
result->attributes))) {
return failure();
}
if (succeeded(parser.parseOptionalKeyword("as"))) {
StringAttr exportNameAttr;
if (failed(parser.parseLParen()) ||
failed(parser.parseAttribute(exportNameAttr, "sym_name",
result->attributes)) ||
failed(parser.parseRParen())) {
return failure();
}
} else {
result->addAttribute("sym_name", parser.getBuilder().getStringAttr(
functionRefAttr.getValue()));
}
if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
return success();
}
static void printDispatchEntryOp(OpAsmPrinter &p, DispatchEntryOp op) {
p << ' ';
p.printSymbolName(op.function_ref());
if (op.sym_name() != op.function_ref()) {
p << " as(\"" << op.sym_name() << "\")";
}
p.printOptionalAttrDictWithKeyword(
op->getAttrs(), /*elidedAttrs=*/{"function_ref", "sym_name"});
}
//===----------------------------------------------------------------------===//
// flow.dispatch
//===----------------------------------------------------------------------===//
void DispatchOp::build(OpBuilder &builder, OperationState &state,
DispatchEntryOp entryPoint, ValueRange workgroupCount,
TypeRange resultTypes, ValueRange resultDims,
ValueRange operands, ValueRange operandDims,
ArrayAttr tiedOperands,
ArrayRef<NamedAttribute> attributes) {
StringRef executableOpSymName =
entryPoint->getParentOp()
->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
.getValue();
state.addAttribute(
"entry_point",
SymbolRefAttr::get(builder.getContext(), executableOpSymName,
{SymbolRefAttr::get(entryPoint)}));
state.addOperands(workgroupCount);
state.addTypes(resultTypes);
state.addOperands(operands);
state.addOperands(operandDims);
state.addOperands(resultDims);
state.addAttributes(attributes);
state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
tiedOperands);
state.attributes.erase("operand_segment_sizes");
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr({
static_cast<int32_t>(workgroupCount.size()),
static_cast<int32_t>(operands.size()),
static_cast<int32_t>(operandDims.size()),
static_cast<int32_t>(resultDims.size()),
}));
}
StringAttr DispatchOp::executable() { return entry_point().getRootReference(); }
FunctionType DispatchOp::getEntryPointType() {
SmallVector<Type, 8> argTypes(operand_type_range{operands()});
return FunctionType::get(getContext(), argTypes, getResultTypes());
}
static LogicalResult verifyDispatchOp(DispatchOp op) {
if (op.workgroup_count().empty()) {
return op.emitOpError() << "at least one workgroup dimension is required";
}
if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) ||
failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) {
return failure();
}
return success();
}
Value DispatchOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
operand_dims(), builder);
}
Value DispatchOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(),
result_dims(), builder);
}
std::pair<unsigned, unsigned> DispatchOp::getTiedOperandsIndexAndLength() {
return getODSOperandIndexAndLength(1);
}
//===----------------------------------------------------------------------===//
// flow.tensor.reshape
//===----------------------------------------------------------------------===//
Value TensorReshapeOp::buildOperandRankedShape(unsigned idx,
OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
builder);
}
Value TensorReshapeOp::buildResultRankedShape(unsigned idx,
OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
builder);
}
Value TensorReshapeOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
}
::llvm::Optional<unsigned> TensorReshapeOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // source
}
SmallVector<int64_t, 4> TensorReshapeOp::getTiedResultOperandIndices() {
return {0}; // source
}
//===----------------------------------------------------------------------===//
// flow.tensor.*
//===----------------------------------------------------------------------===//
Value TensorLoadOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
builder);
}
Value TensorLoadOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
return {};
}
Value TensorStoreOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(),
builder);
}
Value TensorStoreOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), result(), target_dims(),
builder);
}
Value TensorSplatOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
return {};
}
Value TensorSplatOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
builder);
}
Value TensorCloneOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), operand(), operand_dims(),
builder);
}
Value TensorCloneOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), result(), operand_dims(),
builder);
}
Value TensorSliceOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
builder);
}
Value TensorSliceOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
builder);
}
//===----------------------------------------------------------------------===//
// flow.tensor.update
//===----------------------------------------------------------------------===//
void TensorUpdateOp::build(OpBuilder &builder, OperationState &state,
Value target, ValueRange startIndices,
Value update) {
auto targetDims =
Shape::buildOrFindDynamicDimsForValue(state.location, target, builder);
auto updateDims =
Shape::buildOrFindDynamicDimsForValue(state.location, update, builder);
build(builder, state, target.getType(), target, targetDims, startIndices,
update, updateDims, builder.getIndexArrayAttr({0}));
}
static LogicalResult verifyTensorUpdateOp(TensorUpdateOp op) {
if (failed(verifyOpDynamicDims(op, {op.update()}, op.update_dims())) ||
failed(verifyOpDynamicDims(op, {op.target()}, op.target_dims()))) {
return failure();
}
return success();
}
Value TensorUpdateOp::buildOperandRankedShape(unsigned idx,
OpBuilder &builder) {
if (idx == 0) {
return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
target_dims(), builder);
} else {
return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
update_dims(), builder);
}
}
Value TensorUpdateOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(),
builder);
}
Value TensorUpdateOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> TensorUpdateOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> TensorUpdateOp::getTiedResultOperandIndices() {
return {0}; // target
}
//===----------------------------------------------------------------------===//
// flow.ex.stream.fragment
//===----------------------------------------------------------------------===//
void ExStreamFragmentOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes, ValueRange resultDims,
ValueRange operands, ValueRange operandDims,
ArrayRef<int64_t> tiedOperands,
ArrayRef<NamedAttribute> attributes) {
state.addTypes(resultTypes);
state.addOperands(operands);
state.addOperands(operandDims);
state.addOperands(resultDims);
state.addAttributes(attributes);
state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
builder.getIndexArrayAttr(tiedOperands));
state.attributes.erase("operand_segment_sizes");
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr({
static_cast<int32_t>(operands.size()),
static_cast<int32_t>(operandDims.size()),
static_cast<int32_t>(resultDims.size()),
}));
state.addRegion();
}
static LogicalResult verifyExStreamFragmentOp(ExStreamFragmentOp op) {
if (failed(verifyOpDynamicDims(op, op.operands(), op.operand_dims())) ||
failed(verifyOpDynamicDims(op, op.results(), op.result_dims()))) {
return failure();
}
return success();
}
static ParseResult parseStreamFragmentBody(OpAsmParser &parser,
TypeRange operandTypes,
TypeRange resultTypes,
ArrayAttr tiedOperands,
Region &body) {
auto loc = parser.getCurrentLocation();
SmallVector<OpAsmParser::OperandType, 16> regionArgs;
SmallVector<Type, 16> regionArgTypes;
if (failed(parser.parseLParen())) {
return failure();
}
if (failed(parser.parseOptionalRParen())) {
do {
// Reserve entries in the lists.
regionArgs.emplace_back();
regionArgTypes.emplace_back();
if (failed(parser.parseRegionArgument(regionArgs.back())) ||
failed(parser.parseColonType(regionArgTypes.back()))) {
return failure();
}
} while (succeeded(parser.parseOptionalComma()));
if (failed(parser.parseRParen())) {
return failure();
}
}
SmallVector<Type, 4> regionResultTypes;
if (failed(parser.parseArrowTypeList(regionResultTypes))) return failure();
if (regionArgs.size() != operandTypes.size()) {
return parser.emitError(loc, "region operand list mismatch");
}
if (regionResultTypes.size() != resultTypes.size()) {
return parser.emitError(loc, "region result list mismatch");
}
return parser.parseRegion(body, regionArgs, regionArgTypes,
/*enableNameShadowing=*/true);
}
static void printStreamFragmentBody(OpAsmPrinter &p, Operation *op,
TypeRange operandTypes,
TypeRange resultTypes,
ArrayAttr tiedOperands, Region &body) {
p << "(";
llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) {
p << arg;
p << ": ";
p << arg.getType();
});
p << ") -> ";
if (resultTypes.size() != 1) p << "(";
for (unsigned i = 0; i < resultTypes.size(); ++i) {
p.printType(resultTypes[i]);
if (i < resultTypes.size() - 1) p << ", ";
}
if (resultTypes.size() != 1) p << ")";
p.printRegion(body, /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
Value ExStreamFragmentOp::buildOperandRankedShape(unsigned idx,
OpBuilder &builder) {
return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
operand_dims(), builder);
}
Value ExStreamFragmentOp::buildResultRankedShape(unsigned idx,
OpBuilder &builder) {
return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(),
result_dims(), builder);
}
Operation::operand_range ExStreamFragmentOp::getClosureOperands() {
return operands();
}
Operation::result_range ExStreamFragmentOp::getClosureResults() {
return results();
}
bool ExStreamFragmentOp::canClosureContainOp(Operation *op) {
// NOTE: we only widen support for the new stream ops; the legacy path isn't
// worth upgrading to support more.
if (auto constantOp = dyn_cast<ConstantOp>(op)) {
return constantOp.getType().isIntOrIndexOrFloat();
}
if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOp>(op)) {
// Only allow loads of immutable variables to move into the stream.
// As they are immutable it's always safe to do so as no synchronization at
// the stream entry/exit boundary is required.
//
// Loads of mutable variables may sometimes be safe to move in as well
// however that is best done when we have better cross-stream
// synchronization support and can make those guarantees structurally.
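//
// For example (illustrative): a load of an immutable global such as
//   %cst = util.global.load @_constant : tensor<4xf32>
// may sink into the fragment, while loads of mutable globals may not.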
return loadOp.isGlobalImmutable();
}
return false;
}
IREE::Util::ValueAccess ExStreamFragmentOp::getOperandAccess(
unsigned operandIndex) {
return !isOperandTied(operandIndex) ? IREE::Util::ValueAccess::ReadOnly()
: IREE::Util::ValueAccess::ReadWrite();
}
IREE::Util::ValueAccess ExStreamFragmentOp::getResultAccess(
unsigned resultIndex) {
return getTiedResultOperandIndex(resultIndex).hasValue()
? IREE::Util::ValueAccess::ReadWrite()
: IREE::Util::ValueAccess::DiscardWrite();
}
IREE::Util::ClosureOpInterface
ExStreamFragmentOp::cloneReplacementExcludingOperandsAndResults(
ArrayRef<unsigned> excludedOperandIndices,
ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
SmallVector<Type, 4> newResultTypes = llvm::to_vector<4>(getResultTypes());
SmallVector<Value, 4> newResultDims = llvm::to_vector<4>(result_dims());
SmallVector<Value, 4> newOperandsValues = llvm::to_vector<4>(operands());
SmallVector<Value, 4> newOperandDims = llvm::to_vector<4>(operand_dims());
IREE::Util::excludeClosureOperandsAndResults(
newOperandsValues, newOperandDims, excludedOperandIndices, newResultTypes,
newResultDims, excludedResultIndices);
auto newTiedOperandIndices =
llvm::to_vector<4>(getTiedResultOperandIndices());
IREE::Util::excludeTiedOperandAndResultIndices(
excludedOperandIndices, excludedResultIndices, newTiedOperandIndices);
assert(getTiedOperandsIndexAndLength().first == 0 &&
"operands must be the first ODS group");
auto newOp = rewriter.create<ExStreamFragmentOp>(
getLoc(), newResultTypes, newResultDims, newOperandsValues,
newOperandDims, newTiedOperandIndices, getOperation()->getAttrs());
auto &newBody = newOp.getClosureBodyRegion();
newBody.takeBody(getClosureBodyRegion());
IREE::Util::eraseRegionResults(newBody, excludedResultIndices);
newBody.front().eraseArguments(excludedOperandIndices);
return newOp;
}
//===----------------------------------------------------------------------===//
// Public methods
//===----------------------------------------------------------------------===//
void populateFlowDispatchCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
DispatchTensorLoadOp::getCanonicalizationPatterns(results, context);
}
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen definitions (intentionally last)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Flow/IR/FlowOps.cpp.inc" // IWYU pragma: keep