blob: c9db68a8e6debf64466e432da6e04590a29a2233 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/Builtins/Builtins.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Utils/ModuleUtils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Stream {
//===----------------------------------------------------------------------===//
// Op utilities used within the stream dialect
//===----------------------------------------------------------------------===//
// Verifies that |dynamicDims| contains the appropriate number of dims for all
// of the dynamic dimensions in |values|.
static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
ValueRange dynamicDims) {
unsigned requiredCount = 0;
for (auto value : values) {
if (auto shapedType = value.getType().dyn_cast<ShapedType>()) {
requiredCount += shapedType.getNumDynamicDims();
}
}
if (dynamicDims.size() != requiredCount) {
return op->emitOpError()
<< "value set has " << requiredCount
<< " dynamic dimensions but only " << dynamicDims.size()
<< " dimension values are attached";
}
return success();
}
// Verifies that |dynamicDims| contains the appropriate number of dims for all
// the dynamic dimensions in |type|.
static LogicalResult verifyOpDynamicDims(Operation *op, TypeRange types,
ValueRange dynamicDims) {
unsigned requiredCount = 0;
for (auto type : types) {
if (auto shapedType = type.dyn_cast<ShapedType>()) {
requiredCount += shapedType.getNumDynamicDims();
}
}
if (dynamicDims.size() != requiredCount) {
return op->emitOpError()
<< "type set has " << requiredCount
<< " dynamic dimensions but only " << dynamicDims.size()
<< " dimension values are attached";
}
return success();
}
// Verifies that |sizes| contains the appropriate number of sizes for all of the
// sized types in |values|.
static LogicalResult verifyOpValueSizes(Operation *op, ValueRange values,
ValueRange sizes) {
unsigned requiredCount = 0;
for (auto value : values) {
if (value.getType().isa<IREE::Util::SizeAwareTypeInterface>()) {
++requiredCount;
}
}
if (sizes.size() != requiredCount) {
return op->emitOpError() << "value set has " << requiredCount
<< " dynamic dimensions but only " << sizes.size()
<< " dimension values are attached";
}
return success();
}
// Verifies that all !stream.resources used within |region| are captured by
// the entry arguments to the region.
static LogicalResult verifyAllResourcesCaptured(Region &region) {
SetVector<Value> availableResources;
for (auto arg : region.front().getArguments()) {
availableResources.insert(arg);
}
for (auto &op : region.front()) {
for (auto result : op.getResults()) {
availableResources.insert(result);
}
for (auto operand : op.getOperands()) {
if (!operand.getType().isa<IREE::Stream::ResourceType>()) continue;
if (!availableResources.contains(operand)) {
return op.emitOpError() << "used resource not listed in explicit "
"captures (or produced internally)";
}
}
}
return success();
}
// Verifies that escaping !stream.resources have the sizes when they are
// yielded match the sizes declared on the parent op. This information is
// redundant but keeps analysis local and agnostic to the parent op structure
// which is useful for when we outline things.
static LogicalResult verifyEscapingResources(Region &region,
ResultRange results,
ValueRange resultSizes) {
// Ensure yielded resources match the signature.
for (auto yieldOp : region.getOps<IREE::Stream::YieldOp>()) {
if (results.size() != yieldOp.operands().size()) {
return yieldOp.emitOpError()
<< "yield result count mismatch with parent op";
}
for (auto it : llvm::zip(results, yieldOp.operands())) {
auto outerValue = std::get<0>(it);
auto innerValue = std::get<1>(it);
if (outerValue.getType() != innerValue.getType()) {
return yieldOp.emitOpError()
<< "result type mismatch: expected " << outerValue.getType()
<< " but got " << innerValue.getType();
}
}
for (auto it : llvm::zip(resultSizes, yieldOp.operand_sizes())) {
auto outerSize = std::get<0>(it);
auto innerSize = std::get<1>(it);
if (outerSize != innerSize) {
return yieldOp.emitOpError() << "result size mismatch";
}
}
}
return success();
}
// Computes the value access bits starting from |rootValue|.
// Traverses the IR graph along tied ops but does not handle branches.
static IREE::Util::ValueAccess computeValueAccess(Value rootValue) {
IREE::Util::ValueAccess access;
DenseSet<Value> processedValues;
SmallVector<Value> worklist;
auto enqueueValue = [&](Value value) {
if (processedValues.contains(value)) return;
processedValues.insert(value);
worklist.push_back(value);
};
enqueueValue(rootValue);
while (!worklist.empty()) {
Value value = worklist.back();
worklist.pop_back();
// Walk up the definition chain.
if (auto definingOp = value.getDefiningOp()) {
// Value is produced within the region and thus written.
access.isWrite = true;
if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
access.isRead = true;
auto operand = tiedOp.getTiedResultOperand(value);
if (operand) {
// Value is tied back to another value; continue analyzing past it.
enqueueValue(operand);
} else {
// Value contents are fully produced by this op.
access.isDiscard = true;
}
} else if (isa<IREE::Stream::SubviewEffectOpInterface>(definingOp)) {
// TODO(benvanik): actually query; for now assume *.
access.isRead = true;
access.isWrite = true;
} else {
// Value contents are fully produced by this op.
access.isDiscard = true;
}
}
// Walk down the use chain.
for (auto user : value.getUsers()) {
// Used by an op.
access.isRead = true;
if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(user)) {
auto tiedIndices = tiedOp.getTiedResultOperandIndices();
for (int64_t tiedIndex : tiedIndices) {
if (tiedIndex == IREE::Util::TiedOpInterface::kUntiedIndex) continue;
auto operand = user->getOperand(tiedIndex);
if (operand == value) {
// Tied operand.
access.isRead = true;
access.isWrite = true;
enqueueValue(operand);
}
}
} else if (isa<IREE::Stream::SubviewEffectOpInterface>(user)) {
// TODO(benvanik): actually query; for now assume *.
access.isRead = true;
access.isWrite = true;
}
}
}
return access;
}
static void eraseStreamRegionResults(Region &region,
ArrayRef<unsigned> excludedResultIndices) {
for (auto &block : region.getBlocks()) {
auto yieldOp = dyn_cast<IREE::Stream::YieldOp>(block.getTerminator());
if (!yieldOp) continue;
llvm::SmallVector<Value, 4> newOperands;
for (auto i : llvm::reverse(excludedResultIndices)) {
yieldOp.operandsMutable().erase(i);
yieldOp.operand_sizesMutable().erase(i);
}
}
}
//===----------------------------------------------------------------------===//
// custom<ResourceRegion>($operands, type($operands), $operand_sizes,
// type($results), $result_sizes,
// $tied_operands, $body)
//===----------------------------------------------------------------------===//
static ParseResult parseResourceRegion(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
SmallVectorImpl<Type> &operandTypes,
SmallVectorImpl<OpAsmParser::OperandType> &operandSizes,
SmallVectorImpl<Type> &resultTypes,
SmallVectorImpl<OpAsmParser::OperandType> &resultSizes,
ArrayAttr &tiedOperands, Region &body) {
SmallVector<OpAsmParser::OperandType, 16> regionArgs;
if (failed(parser.parseLParen())) {
return failure();
}
if (failed(parser.parseOptionalRParen())) {
do {
// Reserve entries in the lists.
operands.emplace_back();
operandTypes.emplace_back();
operandSizes.emplace_back();
regionArgs.emplace_back();
if (failed(parser.parseOperand(operands.back())) ||
failed(parser.parseKeyword("as")) ||
failed(parser.parseRegionArgument(regionArgs.back())) ||
failed(parser.parseColon()) ||
failed(parseSizeAwareType(parser, operandTypes.back(),
operandSizes.back()))) {
return failure();
}
} while (succeeded(parser.parseOptionalComma()));
if (failed(parser.parseRParen())) {
return failure();
}
}
if (succeeded(parser.parseOptionalArrow())) {
if (succeeded(parser.parseOptionalLParen())) {
if (succeeded(parser.parseOptionalRParen())) {
// -> ()
} else if (failed(parseShapedResultList(parser, operands, operandTypes,
operandSizes, resultTypes,
resultSizes, tiedOperands)) ||
failed(parser.parseRParen())) {
return failure();
}
} else {
if (failed(parseShapedResultList(parser, operands, operandTypes,
operandSizes, resultTypes, resultSizes,
tiedOperands))) {
return failure();
}
}
}
return parser.parseRegion(body, regionArgs, operandTypes,
/*argLocations=*/{},
/*enableNameShadowing=*/false);
}
static void printResourceRegion(OpAsmPrinter &p, Operation *op,
ValueRange operands, TypeRange operandTypes,
ValueRange operandSizes, TypeRange resultTypes,
ValueRange resultSizes, ArrayAttr tiedOperands,
Region &body) {
p << "(";
llvm::interleaveComma(
llvm::zip(operands, body.getArguments()), p, [&](auto it) {
auto operand = std::get<0>(it);
auto arg = std::get<1>(it);
p << operand;
p << " as ";
p << arg;
p << ": ";
p << arg.getType();
if (arg.getType().template isa<IREE::Util::SizeAwareTypeInterface>()) {
p << "{" << operandSizes.front() << "}";
operandSizes = operandSizes.drop_front(1);
}
});
p << ")";
if (!resultTypes.empty()) {
p << " -> ";
if (resultTypes.size() != 1) p << "(";
printShapedResultList(p, op, operands, operandTypes, operandSizes,
resultTypes, resultSizes, tiedOperands);
if (resultTypes.size() != 1) p << ")";
}
p << " ";
p.printRegion(body, /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
//===----------------------------------------------------------------------===//
// custom<ExplicitResourceRegion>($operands, type($operands), $operand_sizes,
// $body)
//===----------------------------------------------------------------------===//
static ParseResult parseExplicitResourceRegion(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
SmallVectorImpl<Type> &operandTypes,
SmallVectorImpl<OpAsmParser::OperandType> &operandSizes, Region &body) {
SmallVector<OpAsmParser::OperandType, 16> regionArgs;
if (failed(parser.parseLParen())) {
return failure();
}
if (failed(parser.parseOptionalRParen())) {
do {
// Reserve entries in the lists.
operands.emplace_back();
operandTypes.emplace_back();
operandSizes.emplace_back();
regionArgs.emplace_back();
if (failed(parser.parseOperand(operands.back())) ||
failed(parser.parseKeyword("as")) ||
failed(parser.parseRegionArgument(regionArgs.back())) ||
failed(parser.parseColon()) ||
failed(parseSizeAwareType(parser, operandTypes.back(),
operandSizes.back()))) {
return failure();
}
} while (succeeded(parser.parseOptionalComma()));
if (failed(parser.parseRParen())) {
return failure();
}
}
if (failed(parser.parseRegion(body, regionArgs, operandTypes,
/*argLocations=*/{},
/*enableNameShadowing=*/false))) {
return failure();
}
// HACK: I can't figure out how to make this work with the default parsing -
// it doesn't call this like it should.
IREE::Stream::CmdExecuteOp::ensureTerminator(
body, parser.getBuilder(),
parser.getEncodedSourceLoc(parser.getCurrentLocation()));
return success();
}
static void printExplicitResourceRegion(OpAsmPrinter &p, Operation *op,
ValueRange operands,
TypeRange operandTypes,
ValueRange operandSizes, Region &body) {
p << "(";
llvm::interleaveComma(
llvm::zip(operands, body.getArguments()), p, [&](auto it) {
auto operand = std::get<0>(it);
auto arg = std::get<1>(it);
p << operand;
p << " as ";
p << arg;
p << ": ";
p << arg.getType();
if (arg.getType().template isa<IREE::Util::SizeAwareTypeInterface>()) {
p << "{" << operandSizes.front() << "}";
operandSizes = operandSizes.drop_front(1);
}
});
p << ") ";
p.printRegion(body, /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
//===----------------------------------------------------------------------===//
// custom<PackSliceRanges>($lifetime_intervals,
// $dynamic_slice_sizes,
// type($packed_offsets))
//===----------------------------------------------------------------------===//
static ParseResult parsePackSliceRanges(
OpAsmParser &parser, ArrayAttr &lifetimeIntervals,
SmallVectorImpl<OpAsmParser::OperandType> &dynamicSliceSizes,
SmallVectorImpl<Type> &packedOffsetTypes) {
auto indexType = parser.getBuilder().getIndexType();
SmallVector<Attribute> lifetimeRangeValues;
do {
if (failed(parser.parseOptionalLSquare())) break;
IntegerAttr lifetimeStart;
IntegerAttr lifetimeEnd;
OpAsmParser::OperandType dynamicSliceSize;
if (failed(parser.parseAttribute(lifetimeStart, indexType)) ||
failed(parser.parseComma()) ||
failed(parser.parseAttribute(lifetimeEnd, indexType)) ||
failed(parser.parseRSquare()) || failed(parser.parseEqual()) ||
failed(parser.parseOperand(dynamicSliceSize))) {
return failure();
}
lifetimeRangeValues.push_back(lifetimeStart);
lifetimeRangeValues.push_back(lifetimeEnd);
dynamicSliceSizes.push_back(dynamicSliceSize);
packedOffsetTypes.push_back(indexType);
} while (succeeded(parser.parseOptionalComma()));
lifetimeIntervals = parser.getBuilder().getArrayAttr(lifetimeRangeValues);
return success();
}
static void printPackSliceRanges(OpAsmPrinter &p, Operation *op,
ArrayAttr lifetimeIntervals,
ValueRange dynamicSliceSizes,
TypeRange packedOffsetTypes) {
if (packedOffsetTypes.empty()) return;
for (unsigned i = 0; i < packedOffsetTypes.size(); ++i) {
auto lifetimeStart = lifetimeIntervals[i * 2];
auto lifetimeEnd = lifetimeIntervals[i * 2 + 1];
auto sliceSize = dynamicSliceSizes[i];
p.printNewline();
p << " [";
p.printAttributeWithoutType(lifetimeStart);
p << ", ";
p.printAttributeWithoutType(lifetimeEnd);
p << "] = ";
p.printOperand(sliceSize);
if (i < packedOffsetTypes.size() - 1) p << ",";
}
p.printNewline();
}
//===----------------------------------------------------------------------===//
// custom<ConstantValueList>(type($results),
// $result_sizes,
// $values)
//===----------------------------------------------------------------------===//
// !stream.resource<constant>{%sz} = #value,
// !stream.resource<constant>{%sz} = #value
static ParseResult parseConstantValueList(
OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
SmallVectorImpl<OpAsmParser::OperandType> &resultSizes, ArrayAttr &values) {
SmallVector<Attribute> valueAttrs;
do {
Type resultType;
OpAsmParser::OperandType resultSize;
Attribute valueAttr;
if (failed(parseSizeAwareType(parser, resultType, resultSize)) ||
failed(parser.parseEqual()) ||
failed(parser.parseAttribute(valueAttr))) {
return failure();
}
resultTypes.push_back(resultType);
resultSizes.push_back(resultSize);
valueAttrs.push_back(valueAttr);
} while (succeeded(parser.parseOptionalComma()));
values = parser.getBuilder().getArrayAttr(valueAttrs);
return success();
}
static void printConstantValueList(OpAsmPrinter &p, Operation *op,
TypeRange resultTypes,
ValueRange resultSizes, ArrayAttr values) {
if (resultTypes.empty()) return;
for (unsigned i = 0; i < resultTypes.size(); ++i) {
p.printNewline();
p << " ";
printSizeAwareType(p, op, resultTypes[i], resultSizes[i]);
p << " = ";
p.printAttribute(values[i]);
if (i < resultTypes.size() - 1) p << ",";
}
}
//===----------------------------------------------------------------------===//
// custom<SymbolAlias>($sym_name, $alias)
//===----------------------------------------------------------------------===//
// @foo sym_name: @foo, alias: @foo
// @foo as @bar sym_name: @bar, alias: @foo
static ParseResult parseSymbolAlias(OpAsmParser &parser, StringAttr &sym_name,
FlatSymbolRefAttr &alias) {
if (failed(parser.parseAttribute(alias))) {
return failure();
}
if (succeeded(parser.parseOptionalKeyword("as"))) {
if (failed(parser.parseLParen()) ||
failed(parser.parseAttribute(sym_name)) ||
failed(parser.parseRParen())) {
return failure();
}
} else {
sym_name = StringAttr::get(parser.getContext(), alias.getValue());
}
return success();
}
static void printSymbolAlias(OpAsmPrinter &p, Operation *op,
StringAttr sym_name, FlatSymbolRefAttr alias) {
p.printAttributeWithoutType(alias);
if (sym_name.getValue() != alias.getValue()) {
p << " as(\"";
p.printSymbolName(sym_name.getValue());
p << "\")";
}
}
//===----------------------------------------------------------------------===//
// stream.resource.alloc
//===----------------------------------------------------------------------===//
LogicalResult ResourceAllocOp::verify() {
ResourceAllocOp op = *this;
if (failed(verifyOpValueSizes(op, op.results(), op.storage_sizes()))) {
return failure();
}
// All allocated resources must have the same lifetime.
auto anyType = op.results().front().getType();
for (auto type : op.getResultTypes()) {
if (type != anyType) {
return op.emitError()
<< "all allocated resources must have the same lifetime";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.resource.map
//===----------------------------------------------------------------------===//
LogicalResult ResourceMapOp::verify() {
ResourceMapOp op = *this;
if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.resource.try_map
//===----------------------------------------------------------------------===//
LogicalResult ResourceTryMapOp::verify() {
ResourceTryMapOp op = *this;
if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.resource.load
//===----------------------------------------------------------------------===//
LogicalResult ResourceLoadOp::verify() {
ResourceLoadOp op = *this;
if (failed(verifyOpValueSizes(op, op.source(), op.source_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.resource.store
//===----------------------------------------------------------------------===//
LogicalResult ResourceStoreOp::verify() {
ResourceStoreOp op = *this;
if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.resource.pack
//===----------------------------------------------------------------------===//
void ResourcePackOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
// TODO(benvanik): figure out if we can get the names to coalesce when there
// are multiple results. Ideally we'd have `%total_length, %offsets:123` but
// unfortunately all get splatted out and create 10k+ char lines that are a
// pain to read.
// setNameFn(total_length(), "total_length");
// for (auto packedOffset : llvm::enumerate(packed_offsets())) {
// setNameFn(packedOffset.value(),
// "offset" + std::to_string(packedOffset.index()));
// }
}
LogicalResult ResourcePackOp::verify() {
ResourcePackOp op = *this;
size_t sliceCount = op.packed_offsets().size();
if (op.lifetime_intervals().size() != sliceCount * 2) {
return op.emitOpError() << "requires a [start, end] range for each slice";
}
if (op.dynamic_slice_sizes().size() != sliceCount) {
return op.emitOpError() << "requires a size for each slice";
}
return success();
}
SmallVector<ResourcePackOp::Slice> ResourcePackOp::getSlices() {
auto intervalPairs = lifetime_intervals().getValue();
auto sizes = dynamic_slice_sizes();
auto offsets = packed_offsets();
SmallVector<ResourcePackOp::Slice> slices(offsets.size());
for (size_t i = 0; i < offsets.size(); ++i) {
int64_t start = intervalPairs[i * 2 + 0].cast<IntegerAttr>().getInt();
int64_t end = intervalPairs[i * 2 + 1].cast<IntegerAttr>().getInt();
slices[i] = {start, end, sizes[i], offsets[i]};
}
return slices;
}
//===----------------------------------------------------------------------===//
// stream.resource.constants
//===----------------------------------------------------------------------===//
LogicalResult ResourceConstantsOp::verify() {
ResourceConstantsOp op = *this;
size_t count = op.results().size();
if (op.result_sizes().size() != count || op.values().size() != count) {
return op.emitOpError() << "mismatched constant/result counts";
}
// All resources must have the same lifetime.
auto anyType = op.results().front().getType();
for (auto result : op.results()) {
if (result.getType() != anyType) {
return op.emitError()
<< "all constant resources must have the same lifetime";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.resource.subview
//===----------------------------------------------------------------------===//
LogicalResult ResourceSubviewOp::verify() {
ResourceSubviewOp op = *this;
if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
bool ResourceSubviewOp::isMetadata() { return true; }
Value ResourceSubviewOp::getViewSource() { return source(); }
Value ResourceSubviewOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
}
::llvm::Optional<unsigned> ResourceSubviewOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // source
}
SmallVector<int64_t, 4> ResourceSubviewOp::getTiedResultOperandIndices() {
return {0}; // source
}
// static
IREE::Stream::ResourceSubviewOp ResourceSubviewOp::findSubviewOp(Value value) {
while (value) {
auto *definingOp = value.getDefiningOp();
if (!definingOp) {
// Defined as a block argument - stop walk.
break;
} else if (auto subviewOp =
dyn_cast<IREE::Stream::ResourceSubviewOp>(definingOp)) {
// Found!
return subviewOp;
} else if (auto tiedOp =
dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
// Continue walking up through the tied operand.
value = tiedOp.getTiedResultOperand(value);
} else {
break;
}
}
return {};
}
//===----------------------------------------------------------------------===//
// stream.tensor.import
//===----------------------------------------------------------------------===//
LogicalResult TensorImportOp::verify() {
TensorImportOp op = *this;
if (failed(verifyOpDynamicDims(op, op.result_encoding(),
op.result_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
Value TensorImportOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
}
::llvm::Optional<unsigned> TensorImportOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // source
}
SmallVector<int64_t, 4> TensorImportOp::getTiedResultOperandIndices() {
return {0}; // source
}
//===----------------------------------------------------------------------===//
// stream.tensor.export
//===----------------------------------------------------------------------===//
LogicalResult TensorExportOp::verify() {
TensorExportOp op = *this;
if (failed(verifyOpDynamicDims(op, op.source_encoding(),
op.source_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.source(), op.source_size()))) {
return failure();
}
return success();
}
Value TensorExportOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
}
::llvm::Optional<unsigned> TensorExportOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // source
}
SmallVector<int64_t, 4> TensorExportOp::getTiedResultOperandIndices() {
return {0}; // source
}
//===----------------------------------------------------------------------===//
// stream.tensor.sizeof
//===----------------------------------------------------------------------===//
LogicalResult TensorSizeOfOp::verify() {
TensorSizeOfOp op = *this;
if (failed(verifyOpDynamicDims(op, op.encoding(), op.encoding_dims()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.tensor.constant
//===----------------------------------------------------------------------===//
void TensorConstantOp::getAsmResultNames(mlir::OpAsmSetValueNameFn setNameFn) {
setNameFn(result(), "cst");
}
LogicalResult TensorConstantOp::verify() {
TensorConstantOp op = *this;
if (failed(verifyOpDynamicDims(op, op.result_encoding(),
op.result_encoding_dims()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.tensor.splat
//===----------------------------------------------------------------------===//
LogicalResult TensorSplatOp::verify() {
TensorSplatOp op = *this;
if (failed(verifyOpDynamicDims(op, op.result_encoding(),
op.result_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.tensor.clone
//===----------------------------------------------------------------------===//
LogicalResult TensorCloneOp::verify() {
TensorCloneOp op = *this;
// Clones can't change encodings but they can change shape information.
auto sourceEncoding = op.source_encoding().cast<RankedTensorType>();
auto resultEncoding = op.result_encoding().cast<RankedTensorType>();
if (sourceEncoding.getEncoding() != resultEncoding.getEncoding()) {
return op.emitOpError() << "clones changing tensor encoding from "
<< sourceEncoding.getEncoding() << " to "
<< resultEncoding.getEncoding() << "; not allowed";
}
if (failed(verifyOpDynamicDims(op, op.source_encoding(),
op.source_encoding_dims())) ||
failed(verifyOpDynamicDims(op, op.result_encoding(),
op.result_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.tensor.slice
//===----------------------------------------------------------------------===//
LogicalResult TensorSliceOp::verify() {
TensorSliceOp op = *this;
if (failed(verifyOpDynamicDims(op, op.source_encoding(),
op.source_encoding_dims())) ||
failed(verifyOpDynamicDims(op, op.result_encoding(),
op.result_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
auto sourceType = op.source_encoding().cast<ShapedType>();
if (op.start_indices().size() != sourceType.getRank() ||
op.lengths().size() != sourceType.getRank()) {
return op.emitOpError() << "start_indices/lengths rank mismatch";
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.tensor.update
//===----------------------------------------------------------------------===//
LogicalResult TensorUpdateOp::verify() {
TensorUpdateOp op = *this;
if (failed(verifyOpDynamicDims(op, op.update_encoding(),
op.update_encoding_dims())) ||
failed(verifyOpDynamicDims(op, op.target_encoding(),
op.target_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.update(), op.update_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
return failure();
}
return success();
}
Value TensorUpdateOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> TensorUpdateOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> TensorUpdateOp::getTiedResultOperandIndices() {
return {0}; // target
}
//===----------------------------------------------------------------------===//
// stream.tensor.fill
//===----------------------------------------------------------------------===//
LogicalResult TensorFillOp::verify() {
TensorFillOp op = *this;
if (failed(verifyOpDynamicDims(op, op.target_encoding(),
op.target_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
return failure();
}
return success();
}
Value TensorFillOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> TensorFillOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> TensorFillOp::getTiedResultOperandIndices() {
return {0}; // target
}
//===----------------------------------------------------------------------===//
// stream.tensor.load
//===----------------------------------------------------------------------===//
LogicalResult TensorLoadOp::verify() {
TensorLoadOp op = *this;
if (failed(verifyOpDynamicDims(op, op.source_encoding(),
op.source_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.source(), op.source_size()))) {
return failure();
}
auto sourceType = op.source_encoding().cast<ShapedType>();
if (op.indices().size() != sourceType.getRank()) {
return op.emitOpError() << "indices rank mismatch";
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.tensor.store
//===----------------------------------------------------------------------===//
LogicalResult TensorStoreOp::verify() {
TensorStoreOp op = *this;
if (failed(verifyOpDynamicDims(op, op.target_encoding(),
op.target_encoding_dims())) ||
failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
return failure();
}
auto targetType = op.target_encoding().cast<ShapedType>();
if (op.indices().size() != targetType.getRank()) {
return op.emitOpError() << "indices rank mismatch";
}
return success();
}
Value TensorStoreOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> TensorStoreOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> TensorStoreOp::getTiedResultOperandIndices() {
return {0}; // target
}
//===----------------------------------------------------------------------===//
// stream.builtin.* utilities
//===----------------------------------------------------------------------===//
// Merges a builtin module from iree/compiler/Dialect/Stream/Builtins/*.mlir
// into the user module; this allows for host functions and multiple
// executables.
//
// Fails if there's a name conflict; we have a __ prefix and things outside the
// compiler shouldn't use it.
static LogicalResult mergeBuiltinModuleSource(Location loc, StringRef fileName,
Operation *targetOp,
OpBuilder &targetBuilder) {
// Find the file in the embedded data.
const iree_file_toc_t *toc = iree_compiler_Stream_Builtins_create();
const iree_file_toc_t *file = nullptr;
for (size_t i = 0; i < iree_compiler_Stream_Builtins_size(); ++i) {
if (fileName == toc[i].name) {
file = &toc[i];
break;
}
}
if (!file) {
return mlir::emitError(
loc, "unable to merge builtin module; file not found " + fileName);
}
SymbolTable targetSymbols(targetOp);
return mergeSourceModuleInto(loc, StringRef(file->data, file->size), targetOp,
targetSymbols, targetBuilder);
}
//===----------------------------------------------------------------------===//
// stream.builtin.splat.i64
//===----------------------------------------------------------------------===//
LogicalResult BuiltinSplatI64Op::verify() {
BuiltinSplatI64Op op = *this;
if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
LogicalResult BuiltinSplatI64Op::mergeBuiltinModule(Operation *targetOp,
OpBuilder &targetBuilder) {
return mergeBuiltinModuleSource(getLoc(), "splat_i64.mlir", targetOp,
targetBuilder);
}
LogicalResult BuiltinSplatI64Op::convertBuiltinOp(OpBuilder &builder) {
auto c8 = builder.createOrFold<arith::ConstantIndexOp>(getLoc(), 8);
auto count =
builder.createOrFold<arith::DivUIOp>(getLoc(), result_size(), c8);
auto one = builder.create<arith::ConstantIndexOp>(getLoc(), 1);
Value workgroupCount[3] = {
count,
one,
one,
};
SmallVector<Value> operands = {
value(),
count,
};
SmallVector<Value> operandSizes = {};
SmallVector<int64_t> tiedOperands = {
-1,
};
SmallVector<Value> resultSizes = {
result_size(),
};
SmallVector<Type> resultTypes{
result().getType(),
};
auto dispatchOp = builder.create<IREE::Stream::AsyncDispatchOp>(
getLoc(), resultTypes, workgroupCount,
SymbolRefAttr::get(
builder.getStringAttr("__builtin_splat_i64"),
FlatSymbolRefAttr::get(builder.getContext(), "__builtin_splat_i64")),
operands, operandSizes, resultSizes,
builder.getIndexArrayAttr(tiedOperands), affinityAttr());
result().replaceAllUsesWith(dispatchOp.results().front());
return success();
}
//===----------------------------------------------------------------------===//
// stream.builtin.fill.i64
//===----------------------------------------------------------------------===//
LogicalResult BuiltinFillI64Op::verify() {
BuiltinFillI64Op op = *this;
if (failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
return failure();
}
return success();
}
Value BuiltinFillI64Op::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> BuiltinFillI64Op::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> BuiltinFillI64Op::getTiedResultOperandIndices() {
return {0}; // target
}
LogicalResult BuiltinFillI64Op::mergeBuiltinModule(Operation *targetOp,
OpBuilder &targetBuilder) {
return mergeBuiltinModuleSource(getLoc(), "fill_i64.mlir", targetOp,
targetBuilder);
}
LogicalResult BuiltinFillI64Op::convertBuiltinOp(OpBuilder &builder) {
auto c8 = builder.createOrFold<arith::ConstantIndexOp>(getLoc(), 8);
auto count =
builder.createOrFold<arith::DivUIOp>(getLoc(), target_length(), c8);
auto one = builder.create<arith::ConstantIndexOp>(getLoc(), 1);
Value workgroupCount[3] = {
count,
one,
one,
};
SmallVector<Value> operands = {
target(),
value(),
target_offset(),
count,
};
SmallVector<Value> operandSizes = {
target_size(),
};
SmallVector<int64_t> tiedOperands = {
0,
};
SmallVector<Value> resultSizes = {
target_size(),
};
SmallVector<Type> resultTypes{
result().getType(),
};
auto dispatchOp = builder.create<IREE::Stream::AsyncDispatchOp>(
getLoc(), resultTypes, workgroupCount,
SymbolRefAttr::get(
builder.getStringAttr("__builtin_fill_i64"),
FlatSymbolRefAttr::get(builder.getContext(), "__builtin_fill_i64")),
operands, operandSizes, resultSizes,
builder.getIndexArrayAttr(tiedOperands), affinityAttr());
result().replaceAllUsesWith(dispatchOp.results().front());
return success();
}
//===----------------------------------------------------------------------===//
// stream.async.alloca
//===----------------------------------------------------------------------===//
bool AsyncAllocaOp::isMetadata() { return true; }
bool AsyncAllocaOp::preferCloneToConsumers() { return true; }
//===----------------------------------------------------------------------===//
// stream.async.constant
//===----------------------------------------------------------------------===//
bool AsyncConstantOp::isMetadata() { return true; }
void AsyncConstantOp::getAsmResultNames(mlir::OpAsmSetValueNameFn setNameFn) {
setNameFn(result(), "cst");
}
LogicalResult AsyncConstantOp::verify() {
AsyncConstantOp op = *this;
if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.async.splat
//===----------------------------------------------------------------------===//
LogicalResult AsyncSplatOp::verify() {
AsyncSplatOp op = *this;
if (failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
bool AsyncSplatOp::preferCloneToConsumers() { return true; }
//===----------------------------------------------------------------------===//
// stream.async.clone
//===----------------------------------------------------------------------===//
LogicalResult AsyncCloneOp::verify() {
AsyncCloneOp op = *this;
if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
bool AsyncCloneOp::preferCloneToConsumers() { return true; }
//===----------------------------------------------------------------------===//
// stream.async.slice
//===----------------------------------------------------------------------===//
LogicalResult AsyncSliceOp::verify() {
AsyncSliceOp op = *this;
if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
bool AsyncSliceOp::isMetadata() { return true; }
//===----------------------------------------------------------------------===//
// stream.async.fill
//===----------------------------------------------------------------------===//
LogicalResult AsyncFillOp::verify() {
AsyncFillOp op = *this;
if (failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
return failure();
}
return success();
}
Value AsyncFillOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> AsyncFillOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> AsyncFillOp::getTiedResultOperandIndices() {
return {0}; // target
}
//===----------------------------------------------------------------------===//
// stream.async.update
//===----------------------------------------------------------------------===//
LogicalResult AsyncUpdateOp::verify() {
AsyncUpdateOp op = *this;
if (failed(verifyOpValueSizes(op, op.update(), op.update_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
return failure();
}
return success();
}
bool AsyncUpdateOp::isMetadata() { return true; }
Value AsyncUpdateOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> AsyncUpdateOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> AsyncUpdateOp::getTiedResultOperandIndices() {
return {0}; // target
}
//===----------------------------------------------------------------------===//
// stream.async.copy
//===----------------------------------------------------------------------===//
LogicalResult AsyncCopyOp::verify() {
AsyncCopyOp op = *this;
if (op.source() == op.target()) {
// If we want to perform memmove-like operations where it's safe to copy
// overlapping ranges we'll need to emit some runtime checks. We can in
// many cases statically detect a lack of overlap just based on symbolic
// offset equality but that requires some analysis we don't have yet.
return op.emitOpError() << "cannot copy within the same resource (yet)";
}
if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.target_size()))) {
return failure();
}
return success();
}
Value AsyncCopyOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> AsyncCopyOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> AsyncCopyOp::getTiedResultOperandIndices() {
return {0}; // target
}
//===----------------------------------------------------------------------===//
// stream.async.transfer
//===----------------------------------------------------------------------===//
LogicalResult AsyncTransferOp::verify() {
AsyncTransferOp op = *this;
if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
failed(verifyOpValueSizes(op, op.result(), op.result_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.async.load
//===----------------------------------------------------------------------===//
LogicalResult AsyncLoadOp::verify() {
AsyncLoadOp op = *this;
if (failed(verifyOpValueSizes(op, op.source(), op.source_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.async.store
//===----------------------------------------------------------------------===//
LogicalResult AsyncStoreOp::verify() {
AsyncStoreOp op = *this;
if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
return failure();
}
return success();
}
Value AsyncStoreOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
::llvm::Optional<unsigned> AsyncStoreOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // target
}
SmallVector<int64_t, 4> AsyncStoreOp::getTiedResultOperandIndices() {
return {0}; // target
}
//===----------------------------------------------------------------------===//
// stream.async.dispatch
//===----------------------------------------------------------------------===//
LogicalResult AsyncDispatchOp::verify() {
AsyncDispatchOp op = *this;
if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes())) ||
failed(verifyOpValueSizes(op, op.results(), op.result_sizes()))) {
return failure();
}
return success();
}
std::pair<unsigned, unsigned> AsyncDispatchOp::getTiedOperandsIndexAndLength() {
return getODSOperandIndexAndLength(1); // $operands
}
//===----------------------------------------------------------------------===//
// stream.async.execute
//===----------------------------------------------------------------------===//
void AsyncExecuteOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes, ValueRange resultSizes,
Value awaitTimepoint, ValueRange operands,
ValueRange operandSizes,
ArrayRef<int64_t> tiedOperands,
ArrayRef<NamedAttribute> attributes) {
state.addTypes(resultTypes);
state.addTypes(IREE::Stream::TimepointType::get(builder.getContext()));
state.addOperands(operands);
state.addOperands(operandSizes);
state.addOperands(resultSizes);
if (awaitTimepoint) state.addOperands(awaitTimepoint);
state.addAttributes(attributes);
state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
builder.getIndexArrayAttr(tiedOperands));
state.attributes.erase("operand_segment_sizes");
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr({
static_cast<int32_t>(operands.size()),
static_cast<int32_t>(operandSizes.size()),
static_cast<int32_t>(resultSizes.size()),
awaitTimepoint ? 1 : 0,
}));
state.addRegion();
}
LogicalResult AsyncExecuteOp::verify() {
AsyncExecuteOp op = *this;
if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes())) ||
failed(verifyOpValueSizes(op, op.results(), op.result_sizes()))) {
return failure();
}
if (failed(verifyAllResourcesCaptured(op.body())) ||
failed(verifyEscapingResources(op.body(), op.results(),
op.result_sizes()))) {
return failure();
}
return success();
}
std::pair<unsigned, unsigned> AsyncExecuteOp::getTiedResultsIndexAndLength() {
return {0, results().size()};
}
OperandRange AsyncExecuteOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "invalid region index");
return operands();
}
void AsyncExecuteOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
// Unconditional control flow into the region and back to the parent, so
// return the correct RegionSuccessor purely based on the index being None or
// 0.
if (index.hasValue()) {
regions.push_back(RegionSuccessor(results()));
} else {
regions.push_back(RegionSuccessor(&body(), body().getArguments()));
}
}
Operation::operand_range AsyncExecuteOp::getClosureOperands() {
return operands();
}
Operation::result_range AsyncExecuteOp::getClosureResults() {
return results();
}
bool AsyncExecuteOp::canClosureContainOp(Operation *op) { return false; }
IREE::Util::ValueAccess AsyncExecuteOp::getOperandAccess(
unsigned operandIndex) {
auto arg = body().getArgument(operandIndex);
return computeValueAccess(arg);
}
IREE::Util::ValueAccess AsyncExecuteOp::getResultAccess(unsigned resultIndex) {
auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
return computeValueAccess(yieldOp.getOperand(resultIndex));
}
IREE::Util::ClosureOpInterface
AsyncExecuteOp::cloneReplacementExcludingOperandsAndResults(
ArrayRef<unsigned> excludedOperandIndices,
ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
auto newResultTypes = llvm::to_vector<4>(
llvm::map_range(results(), [](auto value) { return value.getType(); }));
auto newResultSizes = llvm::to_vector<4>(result_sizes());
auto newOperandsValues = llvm::to_vector<4>(operands());
auto newOperandSizes = llvm::to_vector<4>(operand_sizes());
IREE::Util::excludeClosureOperandsAndResults(
newOperandsValues, newOperandSizes, excludedOperandIndices,
newResultTypes, newResultSizes, excludedResultIndices);
auto newTiedOperandIndices =
llvm::to_vector<4>(getTiedResultOperandIndices());
IREE::Util::excludeTiedOperandAndResultIndices(
excludedOperandIndices, excludedResultIndices, newTiedOperandIndices);
assert(getTiedOperandsIndexAndLength().first == 0 &&
"operands must be the first ODS group");
auto newOp = rewriter.create<AsyncExecuteOp>(
getLoc(), newResultTypes, newResultSizes, await_timepoint(),
newOperandsValues, newOperandSizes, newTiedOperandIndices,
getOperation()->getAttrs());
auto &newBody = newOp.getClosureBodyRegion();
newBody.takeBody(getClosureBodyRegion());
eraseStreamRegionResults(newBody, excludedResultIndices);
newBody.front().eraseArguments(excludedOperandIndices);
return newOp;
}
//===----------------------------------------------------------------------===//
// stream.async.concurrent
//===----------------------------------------------------------------------===//
void AsyncConcurrentOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes, ValueRange resultSizes,
ValueRange operands, ValueRange operandSizes,
ArrayRef<int64_t> tiedOperands,
ArrayRef<NamedAttribute> attributes) {
state.addTypes(resultTypes);
state.addOperands(operands);
state.addOperands(operandSizes);
state.addOperands(resultSizes);
state.addAttributes(attributes);
state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName());
state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(),
builder.getIndexArrayAttr(tiedOperands));
state.attributes.erase("operand_segment_sizes");
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr({
static_cast<int32_t>(operands.size()),
static_cast<int32_t>(operandSizes.size()),
static_cast<int32_t>(resultSizes.size()),
}));
state.addRegion();
}
LogicalResult AsyncConcurrentOp::verify() {
AsyncConcurrentOp op = *this;
if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes())) ||
failed(verifyOpValueSizes(op, op.results(), op.result_sizes()))) {
return failure();
}
if (failed(verifyAllResourcesCaptured(op.body())) ||
failed(verifyEscapingResources(op.body(), op.results(),
op.result_sizes()))) {
return failure();
}
return success();
}
OperandRange AsyncConcurrentOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "invalid region index");
return operands();
}
void AsyncConcurrentOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
// Unconditional control flow into the region and back to the parent, so
// return the correct RegionSuccessor purely based on the index being None or
// 0.
if (index.hasValue()) {
regions.push_back(RegionSuccessor(results()));
} else {
regions.push_back(RegionSuccessor(&body(), body().getArguments()));
}
}
Operation::operand_range AsyncConcurrentOp::getClosureOperands() {
return operands();
}
Operation::result_range AsyncConcurrentOp::getClosureResults() {
return results();
}
bool AsyncConcurrentOp::canClosureContainOp(Operation *op) { return false; }
IREE::Util::ValueAccess AsyncConcurrentOp::getOperandAccess(
unsigned operandIndex) {
auto arg = body().getArgument(operandIndex);
return computeValueAccess(arg);
}
IREE::Util::ValueAccess AsyncConcurrentOp::getResultAccess(
unsigned resultIndex) {
auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
return computeValueAccess(yieldOp.getOperand(resultIndex));
}
IREE::Util::ClosureOpInterface
AsyncConcurrentOp::cloneReplacementExcludingOperandsAndResults(
ArrayRef<unsigned> excludedOperandIndices,
ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
auto newResultTypes = llvm::to_vector<4>(getResultTypes());
auto newResultSizes = llvm::to_vector<4>(result_sizes());
auto newOperandsValues = llvm::to_vector<4>(operands());
auto newOperandSizes = llvm::to_vector<4>(operand_sizes());
IREE::Util::excludeClosureOperandsAndResults(
newOperandsValues, newOperandSizes, excludedOperandIndices,
newResultTypes, newResultSizes, excludedResultIndices);
auto newTiedOperandIndices =
llvm::to_vector<4>(getTiedResultOperandIndices());
IREE::Util::excludeTiedOperandAndResultIndices(
excludedOperandIndices, excludedResultIndices, newTiedOperandIndices);
assert(getTiedOperandsIndexAndLength().first == 0 &&
"operands must be the first ODS group");
auto newOp = rewriter.create<AsyncConcurrentOp>(
getLoc(), newResultTypes, newResultSizes, newOperandsValues,
newOperandSizes, newTiedOperandIndices, getOperation()->getAttrs());
auto &newBody = newOp.getClosureBodyRegion();
newBody.takeBody(getClosureBodyRegion());
eraseStreamRegionResults(newBody, excludedResultIndices);
newBody.front().eraseArguments(excludedOperandIndices);
return newOp;
}
//===----------------------------------------------------------------------===//
// stream.cmd.flush
//===----------------------------------------------------------------------===//
LogicalResult CmdFlushOp::verify() {
CmdFlushOp op = *this;
if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.cmd.invalidate
//===----------------------------------------------------------------------===//
LogicalResult CmdInvalidateOp::verify() {
CmdInvalidateOp op = *this;
if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.cmd.discard
//===----------------------------------------------------------------------===//
LogicalResult CmdDiscardOp::verify() {
CmdDiscardOp op = *this;
if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.cmd.fill
//===----------------------------------------------------------------------===//
LogicalResult CmdFillOp::verify() {
CmdFillOp op = *this;
if (failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.cmd.copy
//===----------------------------------------------------------------------===//
LogicalResult CmdCopyOp::verify() {
CmdCopyOp op = *this;
if (failed(verifyOpValueSizes(op, op.source(), op.source_size())) ||
failed(verifyOpValueSizes(op, op.target(), op.target_size()))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.cmd.dispatch
//===----------------------------------------------------------------------===//
LogicalResult CmdDispatchOp::verify() {
CmdDispatchOp op = *this;
size_t resourceCount = op.resources().size();
if (op.resource_sizes().size() != resourceCount ||
op.resource_offsets().size() != resourceCount ||
op.resource_lengths().size() != resourceCount ||
op.resource_accesses().size() != resourceCount) {
return op->emitOpError() << "dispatch with " << resourceCount
<< " resources has mismatched associated ranges";
}
return success();
}
static ParseResult parseDispatchResources(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &resources,
SmallVectorImpl<Type> &resourceTypes,
SmallVectorImpl<OpAsmParser::OperandType> &resourceSizes,
SmallVectorImpl<OpAsmParser::OperandType> &resourceOffsets,
SmallVectorImpl<OpAsmParser::OperandType> &resourceLengths,
ArrayAttr &resourceAccesses) {
SmallVector<Attribute> accessAttrs;
do {
// Reserve entries in the lists.
resources.emplace_back();
resourceTypes.emplace_back();
resourceSizes.emplace_back();
resourceOffsets.emplace_back();
resourceLengths.emplace_back();
StringRef accessStr;
if (failed(parser.parseKeyword(&accessStr)) ||
failed(parser.parseOperand(resources.back())) ||
failed(parser.parseLSquare()) ||
failed(parser.parseOperand(resourceOffsets.back())) ||
failed(parser.parseKeyword("for")) ||
failed(parser.parseOperand(resourceLengths.back())) ||
failed(parser.parseRSquare()) || failed(parser.parseColon()) ||
failed(parseSizeAwareType(parser, resourceTypes.back(),
resourceSizes.back()))) {
return failure();
}
IREE::Stream::ResourceAccessBitfield accessBits =
IREE::Stream::ResourceAccessBitfield::None;
if (accessStr == "ro") {
accessBits = IREE::Stream::ResourceAccessBitfield::Read;
} else if (accessStr == "wo") {
accessBits = IREE::Stream::ResourceAccessBitfield::Write;
} else if (accessStr == "rw") {
accessBits = IREE::Stream::ResourceAccessBitfield::Read |
IREE::Stream::ResourceAccessBitfield::Write;
}
accessAttrs.push_back(IREE::Stream::ResourceAccessBitfieldAttr::get(
parser.getBuilder().getContext(), accessBits));
} while (succeeded(parser.parseOptionalComma()));
resourceAccesses = parser.getBuilder().getArrayAttr(accessAttrs);
return success();
}
static void printDispatchResources(OpAsmPrinter &p, Operation *op,
ValueRange resources,
TypeRange resourceTypes,
ValueRange resourceSizes,
ValueRange resourceOffsets,
ValueRange resourceLengths,
ArrayAttr resourceAccesses) {
for (size_t i = 0; i < resources.size(); ++i) {
auto resource = resources[i];
auto resourceType = resourceTypes[i];
auto resourceSize = resourceSizes[i];
auto resourceOffset = resourceOffsets[i];
auto resourceLength = resourceLengths[i];
auto resourceAccess = resourceAccesses[i]
.cast<IREE::Stream::ResourceAccessBitfieldAttr>()
.getValue();
p.printNewline();
p << " ";
if (bitEnumContains(resourceAccess,
IREE::Stream::ResourceAccessBitfield::Read) &&
bitEnumContains(resourceAccess,
IREE::Stream::ResourceAccessBitfield::Write)) {
p << "rw";
} else if (bitEnumContains(resourceAccess,
IREE::Stream::ResourceAccessBitfield::Read)) {
p << "ro";
} else if (bitEnumContains(resourceAccess,
IREE::Stream::ResourceAccessBitfield::Write)) {
p << "wo";
}
p << ' ';
p.printOperand(resource);
p << "[";
p.printOperand(resourceOffset);
p << " for ";
p.printOperand(resourceLength);
p << "] : ";
printSizeAwareType(p, op, resourceType, resourceSize);
if (i < resources.size() - 1) p << ",";
}
}
// This is sloppy because the function has interleaved bindings and operands;
// if we had our own op we could just reuse the map we have for operands.
// static
SmallVector<unsigned> CmdDispatchOp::makeOperandToArgMap(mlir::FuncOp funcOp) {
unsigned operandCount = llvm::count_if(
funcOp.getArgumentTypes(),
[](Type type) { return !type.isa<IREE::Stream::BindingType>(); });
SmallVector<unsigned> map(operandCount);
unsigned operandIdx = 0;
for (auto it : llvm::enumerate(funcOp.getArgumentTypes())) {
unsigned argIdx = it.index();
auto argType = it.value();
if (!argType.isa<IREE::Stream::BindingType>()) {
map[operandIdx++] = argIdx;
}
}
return map;
}
// static
SmallVector<unsigned> CmdDispatchOp::makeResourceToArgMap(mlir::FuncOp funcOp) {
unsigned operandCount = llvm::count_if(
funcOp.getArgumentTypes(),
[](Type type) { return type.isa<IREE::Stream::BindingType>(); });
SmallVector<unsigned> map(operandCount);
unsigned operandIdx = 0;
for (auto it : llvm::enumerate(funcOp.getArgumentTypes())) {
unsigned argIdx = it.index();
auto argType = it.value();
if (argType.isa<IREE::Stream::BindingType>()) {
map[operandIdx++] = argIdx;
}
}
return map;
}
//===----------------------------------------------------------------------===//
// stream.cmd.execute
//===----------------------------------------------------------------------===//
void CmdExecuteOp::build(OpBuilder &builder, OperationState &state,
Value awaitTimepoint, ValueRange operands,
ValueRange operandSizes,
ArrayRef<NamedAttribute> attributes) {
state.addTypes(IREE::Stream::TimepointType::get(builder.getContext()));
state.addOperands(operands);
state.addOperands(operandSizes);
if (awaitTimepoint) state.addOperands(awaitTimepoint);
state.addAttributes(attributes);
state.attributes.erase("operand_segment_sizes");
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr({
static_cast<int32_t>(operands.size()),
static_cast<int32_t>(operandSizes.size()),
awaitTimepoint ? 1 : 0,
}));
state.addRegion();
}
// Returns success if the given op is a known valid stream.cmd.* op for use
// within an execution region.
static LogicalResult verifyCmdOp(Operation *op) {
// TODO(benvanik): add a trait that lets us avoid this switch.
if (!TypeSwitch<Operation *, bool>(op)
.Case<IREE::Stream::CmdFlushOp, IREE::Stream::CmdInvalidateOp,
IREE::Stream::CmdDiscardOp, IREE::Stream::CmdFillOp,
IREE::Stream::CmdCopyOp, IREE::Stream::CmdDispatchOp,
IREE::Stream::CmdSerialOp, IREE::Stream::CmdConcurrentOp>(
[](auto op) { return true; })
.Case<IREE::Stream::YieldOp>([](auto op) { return true; })
.Default(false)) {
return op->emitOpError()
<< "explicit execution regions must only contain explicit ops";
}
return success();
}
LogicalResult CmdExecuteOp::verify() {
CmdExecuteOp op = *this;
if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes()))) {
return failure();
}
if (failed(verifyAllResourcesCaptured(op.body()))) {
return failure();
}
for (auto &nestedOp : op.body().front()) {
if (failed(verifyCmdOp(&nestedOp))) return failure();
}
return success();
}
OperandRange CmdExecuteOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "invalid region index");
return operands();
}
void CmdExecuteOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
// Unconditional control flow into the region and back to the parent, so
// return the correct RegionSuccessor purely based on the index being None or
// 0.
if (index.hasValue()) {
regions.push_back(RegionSuccessor({}));
} else {
regions.push_back(RegionSuccessor(&body(), body().getArguments()));
}
}
Operation::operand_range CmdExecuteOp::getClosureOperands() {
return operands();
}
Operation::result_range CmdExecuteOp::getClosureResults() {
return Operation::result_range(nullptr, 0);
}
bool CmdExecuteOp::canClosureContainOp(Operation *op) { return false; }
IREE::Util::ValueAccess CmdExecuteOp::getOperandAccess(unsigned operandIndex) {
auto arg = body().getArgument(operandIndex);
return computeValueAccess(arg);
}
IREE::Util::ValueAccess CmdExecuteOp::getResultAccess(unsigned resultIndex) {
return IREE::Util::ValueAccess::None();
}
IREE::Util::ClosureOpInterface
CmdExecuteOp::cloneReplacementExcludingOperandsAndResults(
ArrayRef<unsigned> excludedOperandIndices,
ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
SmallVector<Type, 4> newResultTypes;
SmallVector<Value, 4> newResultSizes;
auto newOperandsValues = llvm::to_vector<4>(operands());
auto newOperandSizes = llvm::to_vector<4>(operand_sizes());
IREE::Util::excludeClosureOperandsAndResults(
newOperandsValues, newOperandSizes, excludedOperandIndices,
newResultTypes, newResultSizes, excludedResultIndices);
auto newOp = rewriter.create<CmdExecuteOp>(getLoc(), await_timepoint(),
newOperandsValues, newOperandSizes,
getOperation()->getAttrs());
auto &newBody = newOp.getClosureBodyRegion();
newBody.takeBody(getClosureBodyRegion());
newBody.front().eraseArguments(excludedOperandIndices);
return newOp;
}
//===----------------------------------------------------------------------===//
// stream.cmd.serial
//===----------------------------------------------------------------------===//
LogicalResult CmdSerialOp::verify() {
CmdSerialOp op = *this;
for (auto &nestedOp : op.body().front()) {
if (failed(verifyCmdOp(&nestedOp))) return failure();
}
return success();
}
void CmdSerialOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
// Unconditional control flow into the region and back to the parent, so
// return the correct RegionSuccessor purely based on the index being None or
// 0.
if (index.hasValue()) {
regions.push_back(RegionSuccessor({}));
} else {
regions.push_back(RegionSuccessor(&body(), {}));
}
}
//===----------------------------------------------------------------------===//
// stream.cmd.concurrent
//===----------------------------------------------------------------------===//
LogicalResult CmdConcurrentOp::verify() {
CmdConcurrentOp op = *this;
for (auto &nestedOp : op.body().front()) {
if (failed(verifyCmdOp(&nestedOp))) return failure();
}
return success();
}
void CmdConcurrentOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
// Unconditional control flow into the region and back to the parent, so
// return the correct RegionSuccessor purely based on the index being None or
// 0.
if (index.hasValue()) {
regions.push_back(RegionSuccessor({}));
} else {
regions.push_back(RegionSuccessor(&body(), {}));
}
}
//===----------------------------------------------------------------------===//
// stream.timepoint.join
//===----------------------------------------------------------------------===//
LogicalResult TimepointJoinOp::verify() {
// We could test if timepoints all come from the same place - this is not
// strictly required but if we could avoid it things will be easier to
// implement at runtime (won't have to do a cuda<->vulkan sync, etc).
return success();
}
//===----------------------------------------------------------------------===//
// stream.timepoint.await
//===----------------------------------------------------------------------===//
void TimepointAwaitOp::build(OpBuilder &builder, OperationState &state,
ValueRange operands, ValueRange operandSizes,
Value timepoint,
ArrayRef<NamedAttribute> attributes) {
state.addTypes(llvm::map_range(
operands, [&](Value operand) { return operand.getType(); }));
state.addOperands(operands);
state.addOperands(operandSizes);
state.addOperands(timepoint);
state.addAttributes(attributes);
state.attributes.erase("operand_segment_sizes");
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr({
static_cast<int32_t>(operands.size()),
static_cast<int32_t>(operandSizes.size()),
static_cast<int32_t>(1), // timepoint
}));
}
LogicalResult TimepointAwaitOp::verify() {
TimepointAwaitOp op = *this;
if (failed(verifyOpValueSizes(op, op.operands(), op.operand_sizes())) ||
failed(verifyOpValueSizes(op, op.results(), op.operand_sizes()))) {
return failure();
}
return success();
}
::llvm::Optional<unsigned> TimepointAwaitOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {resultIndex};
}
SmallVector<int64_t, 4> TimepointAwaitOp::getTiedResultOperandIndices() {
return llvm::to_vector<4>(llvm::seq<int64_t>(0, operands().size()));
}
//===----------------------------------------------------------------------===//
// stream.executable
//===----------------------------------------------------------------------===//
void ExecutableOp::build(OpBuilder &builder, OperationState &state,
StringRef sym_name) {
ensureTerminator(*state.addRegion(), builder, state.location);
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(sym_name));
}
LogicalResult ExecutableOp::verify() {
// TODO(benvanik): check export name conflicts.
return success();
}
//===----------------------------------------------------------------------===//
// stream.executable.entry
//===----------------------------------------------------------------------===//
void ExecutableExportOp::build(OpBuilder &builder, OperationState &state,
StringRef sym_name,
FlatSymbolRefAttr function_ref) {
build(builder, state, /*sym_visibility=*/nullptr,
builder.getStringAttr(sym_name), function_ref);
}
::mlir::FuncOp ExecutableExportOp::getFunctionRef() {
auto executableOp =
this->getOperation()->getParentOfType<IREE::Stream::ExecutableOp>();
if (!executableOp) return {};
return executableOp.getInnerModule().lookupSymbol<::mlir::FuncOp>(
function_ref());
}
//===----------------------------------------------------------------------===//
// stream.binding.subspan
//===----------------------------------------------------------------------===//
LogicalResult BindingSubspanOp::verify() {
BindingSubspanOp op = *this;
if (auto shapedType = op.getType().dyn_cast<ShapedType>()) {
if (failed(verifyOpDynamicDims(op, shapedType, op.dynamic_dims()))) {
return failure();
}
}
return success();
}
//===----------------------------------------------------------------------===//
// stream.yield
//===----------------------------------------------------------------------===//
MutableOperandRange YieldOp::getMutableSuccessorOperands(
Optional<unsigned> index) {
return operandsMutable();
}
} // namespace Stream
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen definitions (intentionally last)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Stream/IR/StreamOps.cpp.inc" // IWYU pragma: keep