blob: 24b56e6197ecc299f17027d5c6791ddba59b724d [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
// Checks whether a value of type |accessType| may be loaded from or stored to
// a variable declared with |variableType|. Only shape compatibility is
// required, so e.g. a tensor<?xf32> variable accepts a tensor<4xf32> access.
static bool isVariableTypeCompatible(Type variableType, Type accessType) {
  LogicalResult shapeCheck =
      mlir::verifyCompatibleShape(variableType, accessType);
  return succeeded(shapeCheck);
}
//===----------------------------------------------------------------------===//
// flow.variable
//===----------------------------------------------------------------------===//
// Parses the custom assembly form of flow.variable:
//
//   flow.variable @name [mutable] [init(@initializer)]
//       (<initial-value-attr> | : <type>) [attributes {...}]
//
// When no `:` follows the optional clauses, an attribute is parsed as
// `initial_value` and the variable's `type` attribute is taken from that
// attribute's type. Otherwise the explicit type is stored as `type`.
static ParseResult parseVariableOp(OpAsmParser &parser,
                                   OperationState *result) {
  StringAttr nameAttr;
  if (failed(parser.parseSymbolName(nameAttr,
                                    mlir::SymbolTable::getSymbolAttrName(),
                                    result->attributes))) {
    return failure();
  }
  // `mutable` is represented as the unit attribute `is_mutable`.
  if (succeeded(parser.parseOptionalKeyword("mutable"))) {
    result->addAttribute("is_mutable", UnitAttr::get(result->getContext()));
  }
  // Optional `init(@fn)` clause naming an initializer function.
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    FlatSymbolRefAttr initializerAttr;
    if (failed(parser.parseLParen()) ||
        failed(parser.parseAttribute(initializerAttr, "initializer",
                                     result->attributes)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
  }
  if (failed(parser.parseOptionalColon())) {
    // No `:` means an initial value attribute follows; its type becomes the
    // variable type.
    Attribute initialValueAttr;
    if (failed(parser.parseAttribute(initialValueAttr, "initial_value",
                                     result->attributes))) {
      return failure();
    }
    result->addAttribute("type", TypeAttr::get(initialValueAttr.getType()));
  } else {
    // Explicitly typed variable without an initial value.
    Type type;
    if (failed(parser.parseType(type))) {
      return failure();
    }
    result->addAttribute("type", TypeAttr::get(type));
  }
  if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  return success();
}
// Prints flow.variable in the form accepted by parseVariableOp. When an
// initial value is present, it is printed in place of the `: type` suffix,
// because the parser recovers the type from the attribute. Every attribute
// already represented by the custom syntax is elided from the trailing
// `attributes` dictionary.
static void printVariableOp(OpAsmPrinter &p, VariableOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.sym_name());
  if (op.is_mutable()) {
    p << " mutable";
  }
  if (op.initializer().hasValue()) {
    p << " init(";
    p.printSymbolName(op.initializer().getValue());
    p << ')';
  }
  if (op.initial_value().hasValue()) {
    p << ' ';
    p.printAttribute(op.initial_value().getValue());
  } else {
    p << " : ";
    p.printType(op.type());
  }
  p.printOptionalAttrDictWithKeyword(op.getAttrs(), /*elidedAttrs=*/{
                                         "sym_name",
                                         "type",
                                         "is_mutable",
                                         "initializer",
                                         "initial_value",
                                     });
}
// Verifies a flow.variable op.
//
// A variable may have either an `initializer` function or an `initial_value`
// attribute, but not both.
// - With an initializer, the symbol must resolve to a function that takes no
//   arguments and returns exactly one value whose type equals the variable
//   type.
// - With an initial value, the value's type must be shape-compatible with the
//   variable type.
static LogicalResult verifyVariableOp(VariableOp op) {
  if (op.initializer().hasValue() && op.initial_value().hasValue()) {
    return op.emitOpError()
           << "variables can have either an initializer or an initial value";
  } else if (op.initializer().hasValue()) {
    // Ensure initializer returns the same type as the variable.
    auto *symbolOp =
        SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue());
    if (!symbolOp) {
      return op.emitOpError() << "initializer function "
                              << op.initializer().getValue() << " not found";
    }
    // The symbol may name something other than a function; dyn_cast yields a
    // null op in that case, which must not be dereferenced below.
    auto initializerOp = dyn_cast<FuncOp>(symbolOp);
    if (!initializerOp) {
      return op.emitOpError() << "initializer symbol "
                              << op.initializer().getValue()
                              << " is not a function";
    }
    if (initializerOp.getNumArguments() != 0 ||
        initializerOp.getNumResults() != 1 ||
        initializerOp.getType().getResult(0) != op.type()) {
      return op.emitOpError()
             << "initializer type mismatch; variable " << op.sym_name()
             << " is " << op.type() << " but initializer function "
             << initializerOp.getName() << " is " << initializerOp.getType();
    }
  } else if (op.initial_value().hasValue()) {
    // Ensure the value is something we can store in the variable.
    if (!isVariableTypeCompatible(op.type(), op.initial_value()->getType())) {
      return op.emitOpError()
             << "initial value type mismatch; variable " << op.sym_name()
             << " is " << op.type() << " but initial value provided is "
             << op.initial_value()->getType();
    }
  }
  return success();
}
// Builds a flow.variable whose value is produced by the |initializer|
// function. The variable type is taken from the initializer's first (and
// expected only) result. |attrs| are appended after the generated attributes.
void VariableOp::build(OpBuilder &builder, OperationState &state,
                       StringRef name, bool isMutable, FuncOp initializer,
                       ArrayRef<NamedAttribute> attrs) {
  StringAttr symName = builder.getStringAttr(name);
  state.addAttribute(SymbolTable::getSymbolAttrName(), symName);
  if (isMutable) state.addAttribute("is_mutable", builder.getUnitAttr());
  auto initializerRef = builder.getSymbolRefAttr(initializer);
  state.addAttribute("initializer", initializerRef);
  Type variableType = initializer.getType().getResult(0);
  state.addAttribute("type", TypeAttr::get(variableType));
  state.attributes.append(attrs.begin(), attrs.end());
}
// Builds a flow.variable of |type| that starts out holding |initialValue|.
// |attrs| are appended after the generated attributes.
void VariableOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, bool isMutable, Type type,
                       Attribute initialValue, ArrayRef<NamedAttribute> attrs) {
  StringAttr symName = builder.getStringAttr(name);
  result.addAttribute(SymbolTable::getSymbolAttrName(), symName);
  if (isMutable) result.addAttribute("is_mutable", builder.getUnitAttr());
  result.addAttribute("initial_value", initialValue);
  TypeAttr typeAttr = TypeAttr::get(type);
  result.addAttribute("type", typeAttr);
  result.attributes.append(attrs.begin(), attrs.end());
}
// Builds an uninitialized flow.variable of |type| (no initializer and no
// initial value). |attrs| are appended after the generated attributes.
void VariableOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, bool isMutable, Type type,
                       ArrayRef<NamedAttribute> attrs) {
  StringAttr symName = builder.getStringAttr(name);
  result.addAttribute(SymbolTable::getSymbolAttrName(), symName);
  if (isMutable) result.addAttribute("is_mutable", builder.getUnitAttr());
  TypeAttr typeAttr = TypeAttr::get(type);
  result.addAttribute("type", typeAttr);
  result.attributes.append(attrs.begin(), attrs.end());
}
//===----------------------------------------------------------------------===//
// flow.variable.load
//===----------------------------------------------------------------------===//
// Reports the memory effects of flow.variable.load.
//
// HACK: works around the lack of symbol side effects in mlir by only saying
// we have a side-effect if the variable we are loading is mutable.
// If the symbol cannot be resolved to a flow.variable (for example when
// effects are queried on IR that has not been verified yet), a read effect is
// reported conservatively instead of crashing.
void VariableLoadOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(*this, variable());
  auto variableOp = dyn_cast_or_null<VariableOp>(symbolOp);
  if (!variableOp || variableOp.is_mutable()) {
    effects.emplace_back(MemoryEffects::Read::get());
  }
}
// Verifies flow.variable.load. The referenced symbol must resolve to a
// flow.variable, and the loaded type must be shape-compatible with the
// variable's declared type.
static LogicalResult verifyVariableLoadOp(VariableLoadOp &op) {
  auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
  if (!symbolOp) {
    return op.emitOpError() << "undefined variable: " << op.variable();
  }
  // The symbol may resolve to a non-variable op (e.g. a function); dyn_cast
  // yields a null op then, which must not be dereferenced below.
  auto variableOp = dyn_cast<VariableOp>(symbolOp);
  if (!variableOp) {
    return op.emitOpError()
           << "symbol " << op.variable() << " is not a variable";
  }
  auto loadType = op.result().getType();
  if (!isVariableTypeCompatible(variableOp.type(), loadType)) {
    return op.emitOpError()
           << "variable type mismatch; variable " << op.variable() << " is "
           << variableOp.type() << " but load is " << loadType;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// flow.variable.load.indirect
//===----------------------------------------------------------------------===//
// Verifies flow.variable.load.indirect. The type pointed to by the variable
// pointer operand must be shape-compatible with the loaded result type.
static LogicalResult verifyVariableLoadIndirectOp(VariableLoadIndirectOp &op) {
  auto ptrType = op.variable().getType().cast<IREE::PtrType>();
  Type variableType = ptrType.getTargetType();
  Type loadType = op.result().getType();
  if (isVariableTypeCompatible(variableType, loadType)) return success();
  return op.emitOpError() << "variable type mismatch; variable pointer is "
                          << variableType << " but load is " << loadType;
}
//===----------------------------------------------------------------------===//
// flow.variable.store
//===----------------------------------------------------------------------===//
// Verifies flow.variable.store. The referenced symbol must resolve to a
// flow.variable, the stored value's type must be shape-compatible with the
// variable type, and the variable must be declared mutable.
static LogicalResult verifyVariableStoreOp(VariableStoreOp &op) {
  auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
  if (!symbolOp) {
    return op.emitOpError() << "undefined variable: " << op.variable();
  }
  // The symbol may resolve to a non-variable op (e.g. a function); dyn_cast
  // yields a null op then, which must not be dereferenced below.
  auto variableOp = dyn_cast<VariableOp>(symbolOp);
  if (!variableOp) {
    return op.emitOpError()
           << "symbol " << op.variable() << " is not a variable";
  }
  auto storeType = op.value().getType();
  if (!isVariableTypeCompatible(variableOp.type(), storeType)) {
    return op.emitOpError()
           << "variable type mismatch; variable " << op.variable() << " is "
           << variableOp.type() << " but store is " << storeType;
  }
  if (!variableOp.is_mutable()) {
    return op.emitOpError() << "variable " << op.variable()
                            << " is not mutable and cannot be stored to";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// flow.variable.store.indirect
//===----------------------------------------------------------------------===//
// Verifies flow.variable.store.indirect. The stored value's type must be
// shape-compatible with the type pointed to by the variable pointer operand.
static LogicalResult verifyVariableStoreIndirectOp(
    VariableStoreIndirectOp &op) {
  auto ptrType = op.variable().getType().cast<IREE::PtrType>();
  Type variableType = ptrType.getTargetType();
  Type storeType = op.value().getType();
  if (isVariableTypeCompatible(variableType, storeType)) return success();
  return op.emitOpError() << "variable type mismatch; variable pointer is "
                          << variableType << " but store is " << storeType;
}
//===----------------------------------------------------------------------===//
// flow.dispatch.region
//===----------------------------------------------------------------------===//
/// Clones operation |op| into |dispatchRegionOp| at the current insertion
/// point of |builder|.
///
/// Every operand of |op|, and every value defined above and captured
/// implicitly by |op|'s regions, that is not yet present in |map| becomes a
/// new trailing operand of the dispatch region. It also gets a matching new
/// argument on the builder's current insertion block, and |map| records
/// value -> block argument. The clone is then created with |map| so its uses
/// refer to the block arguments.
///
/// NOTE(review): callers are expected to position |builder| inside the
/// dispatch region's entry block. Block arguments are appended to whatever
/// block the builder is in — confirm at new call sites.
/// Returns the cloned operation.
static Operation *inlineOpIntoDispatchRegion(OpBuilder &builder,
                                             DispatchRegionOp dispatchRegionOp,
                                             Operation *op,
                                             BlockAndValueMapping &map) {
  llvm::SetVector<Value> capturedInputs(op->getOperands().begin(),
                                        op->getOperands().end());
  getUsedValuesDefinedAbove(op->getRegions(), capturedInputs);
  Block *block = builder.getInsertionBlock();
  for (Value capturedInput : capturedInputs) {
    // Already available inside the region (e.g. an existing block argument).
    if (map.contains(capturedInput)) continue;
    dispatchRegionOp.getOperation()->insertOperands(
        dispatchRegionOp.getOperation()->getNumOperands(), {capturedInput});
    Value newBlockArgument = block->addArgument(capturedInput.getType());
    map.map(capturedInput, newBlockArgument);
  }
  return builder.clone(*op, map);
}
// Wraps |anchorOp| in a new flow.dispatch.region created immediately before
// it. The region uses |workload|, and its result types mirror |anchorOp|'s.
// The anchor is cloned into a fresh body block, with its operands and
// implicitly used values captured as region operands. A flow.return of the
// clone's results follows the clone. All uses of |anchorOp| are redirected
// to the region's results and |anchorOp| is erased.
//
// Leaves |builder| positioned at the end of the new region's body.
// Returns the new region op and the cloned anchor inside it. This
// implementation always returns a value (never llvm::None).
llvm::Optional<std::pair<DispatchRegionOp, Operation *>>
DispatchRegionOp::formFromAnchorOp(Value workload, Operation *anchorOp,
                                   OpBuilder &builder) {
  builder.setInsertionPoint(anchorOp);
  auto loc = anchorOp->getLoc();
  // Map anchor into new dispatch region.
  auto drOp = builder.create<DispatchRegionOp>(
      loc, llvm::to_vector<1>(anchorOp->getResultTypes()), workload,
      ArrayRef<Value>());
  // The builder creates an empty region; add its entry block by hand.
  auto *drBlock = new Block();
  drOp.body().push_back(drBlock);
  BlockAndValueMapping mapping;
  builder.setInsertionPointToEnd(drBlock);
  Operation *newAnchorOp =
      inlineOpIntoDispatchRegion(builder, drOp, anchorOp, mapping);
  // Insert terminator
  builder.create<IREE::Flow::ReturnOp>(loc, newAnchorOp->getResults());
  // Replace anchor uses with region result.
  for (auto it : llvm::enumerate(anchorOp->getResults())) {
    it.value().replaceAllUsesWith(drOp.getResult(it.index()));
  }
  anchorOp->erase();
  return std::make_pair(drOp, newAnchorOp);
}
// Removes unused operands and results from |op| via ClosureOpDce. |op| is
// updated in place to refer to the (possibly replaced) resulting region op.
void DispatchRegionOp::dceOperandsAndResults(DispatchRegionOp &op) {
  OpBuilder rewriteBuilder(op.getContext());
  // Operand 0 is the workload, so the closure's captured operands start at 1.
  ClosureOpDce closureDce(op, op.body().front(), /*variadicOffset=*/1);
  Operation *optimizedOp = closureDce.optimize(rewriteBuilder);
  op = llvm::cast<DispatchRegionOp>(optimizedOp);
}
// Appends |addlResults| as new results of the dispatch region |self|. They
// are presumably values available at the body's terminator.
//
// An op's result count cannot change in place, so |self| is replaced by a
// clone carrying the extended result type list:
// - existing uses are redirected to the clone's leading results,
// - the old op is erased, and
// - |self| is updated to refer to the clone.
// The body terminator is then extended with |addlResults|.
// Returns the range of newly added results on the new op.
//
// NOTE(review): |block| is taken from the original op before it is erased.
// This presumes cloneWithNewResultTypes moves (rather than copies) the body
// region into the new op — confirm against its definition.
ResultRange DispatchRegionOp::appendResults(DispatchRegionOp &self,
                                            ValueRange addlResults,
                                            OpBuilder &builder) {
  Block &block = self.body().front();
  unsigned origNumResults = self.getNumResults();
  llvm::SmallVector<Type, 4> newTypes(self.getResultTypes().begin(),
                                      self.getResultTypes().end());
  for (auto r : addlResults) newTypes.push_back(r.getType());
  // Changing the arity of the results requires replacing the dispatch region.
  builder.setInsertionPoint(self);
  auto newDrOp = llvm::cast<DispatchRegionOp>(
      builder.insert(cloneWithNewResultTypes(self, newTypes)));
  self.replaceAllUsesWith(ResultRange(newDrOp, 0, origNumResults));
  self.erase();
  self = newDrOp;
  // Add results to the terminator.
  auto terminator = block.getTerminator();
  llvm::SmallVector<Value, 4> returns(terminator->getOperands());
  returns.append(addlResults.begin(), addlResults.end());
  terminator->setOperands(returns);
  return ResultRange(self, origNumResults, addlResults.size());
}
// Clones |origOp| into this dispatch region. With |positionAtEnd| the clone
// goes just before the terminator; otherwise it goes at the start of the
// body.
//
// The value mapping is seeded so that:
//  - region operands (args()) resolve to the corresponding entry block
//    arguments, and
//  - this op's results resolve to the values yielded by the terminator,
//    which lets a consumer of the region's results be inlined at the end.
// Any other values the clone needs are captured as new region operands by
// inlineOpIntoDispatchRegion. Afterwards, in-region values that stood in for
// |origOp|'s results (e.g. an entry block argument fed by |origOp|) have
// their uses redirected to the clone's results. |origOp| itself is not erased
// here. Returns the cloned op.
Operation *DispatchRegionOp::inlineOp(Operation *origOp, OpBuilder &builder,
                                      bool positionAtEnd) {
  Block &block = body().front();
  if (positionAtEnd) {
    builder.setInsertionPoint(block.getTerminator());
  } else {
    builder.setInsertionPointToStart(&block);
  }
  // Map existing dr args.
  // NOTE(review): assumes the entry block has exactly one argument per
  // element of args(), in the same order.
  BlockAndValueMapping mapping;
  for (unsigned i = 0, e = block.getNumArguments(); i < e; ++i) {
    mapping.map(args()[i], block.getArgument(i));
  }
  // Also map any terminator operands to support inlining at the end.
  for (auto it : llvm::enumerate(block.getTerminator()->getOperands())) {
    mapping.map(getResult(it.index()), it.value());
  }
  // Remember the values corresponding to original op results.
  // Null entries mean the result had no in-region stand-in.
  llvm::SmallVector<Value, 4> origOpResultValues;
  for (Value result : origOp->getResults()) {
    origOpResultValues.push_back(mapping.lookupOrNull(result));
  }
  Operation *inlinedOp =
      inlineOpIntoDispatchRegion(builder, *this, origOp, mapping);
  // Replace any results from the orig with results from the clone.
  for (unsigned i = 0, e = origOp->getNumResults(); i < e; ++i) {
    Value resultTo = origOpResultValues[i];
    if (resultTo) {
      resultTo.replaceAllUsesWith(inlinedOp->getResult(i));
    }
  }
  return inlinedOp;
}
// Builds a flow.dispatch.region with the given result types. Operand 0 is
// |workload| and |args| follow it. The body region is created empty; callers
// are responsible for adding its entry block and contents.
void DispatchRegionOp::build(OpBuilder &builder, OperationState &state,
                             ArrayRef<Type> resultTypes, Value workload,
                             ValueRange args,
                             ArrayRef<NamedAttribute> attributes) {
  state.addOperands({workload});
  state.addOperands(args);
  state.addTypes(resultTypes);
  state.addAttributes(attributes);
  (void)state.addRegion();
}
// Parses the custom assembly form of flow.dispatch.region:
//
//   flow.dispatch.region[%workload : type](%arg = %operand : type, ...)
//       [-> (result types)] { body } [{attrs}]
//
// Each `%arg = %operand : type` entry adds %operand as an op operand (after
// the workload) and declares %arg as a body block argument of the same type.
// Note that the attribute dictionary is parsed after the region.
ParseResult parseDispatchRegionOp(OpAsmParser &parser, OperationState *result) {
  // Parse required workload.
  OpAsmParser::OperandType workloadArg;
  Type workloadArgType;
  if (failed(parser.parseLSquare()) ||
      failed(parser.parseOperand(workloadArg)) ||
      failed(parser.parseColonType(workloadArgType)) ||
      failed(parser.parseRSquare()) ||
      failed(parser.resolveOperand(workloadArg, workloadArgType,
                                   result->operands))) {
    return failure();
  }
  // Parse (optional) args.
  SmallVector<OpAsmParser::OperandType, 16> regionArgs;
  SmallVector<Type, 16> regionArgTypes;
  if (failed(parser.parseLParen())) {
    return failure();
  }
  if (failed(parser.parseOptionalRParen())) {
    SmallVector<OpAsmParser::OperandType, 16> regionOperands;
    auto argsLoc = parser.getCurrentLocation();
    do {
      // Reserve entries in the lists.
      regionArgs.emplace_back();
      regionOperands.emplace_back();
      regionArgTypes.emplace_back();
      if (failed(parser.parseRegionArgument(regionArgs.back())) ||
          failed(parser.parseEqual()) ||
          failed(parser.parseOperand(regionOperands.back())) ||
          failed(parser.parseColonType(regionArgTypes.back()))) {
        return failure();
      }
    } while (succeeded(parser.parseOptionalComma()));
    // The same type list types both the outer operands and the block args.
    if (failed(parser.parseRParen()) ||
        failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc,
                                      result->operands))) {
      return failure();
    }
  }
  // Parse (optional) results.
  if (failed(parser.parseOptionalArrowTypeList(result->types))) {
    return failure();
  }
  // Parse region body.
  Region *body = result->addRegion();
  if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) ||
      failed(parser.parseOptionalAttrDict(result->attributes))) {
    return failure();
  }
  return success();
}
// Prints flow.dispatch.region in the form accepted by parseDispatchRegionOp.
// Body block arguments are printed in the `%arg = %operand : type` list
// instead of on the entry block. The attribute dictionary follows the region,
// and no attributes are elided.
void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) {
  p << op.getOperationName();
  // Print the workload argument.
  p << "[";
  p.printOperand(op.workload());
  p << " : ";
  p.printType(op.workload().getType());
  p << "]";
  // Print the data argument remapping.
  p << "(";
  interleaveComma(llvm::zip(op.body().getArguments(), op.args()), p,
                  [&](std::tuple<BlockArgument, Value> it) {
                    p << std::get<0>(it) << " = " << std::get<1>(it);
                    p << " : ";
                    p << std::get<1>(it).getType();
                  });
  p << ")";
  // Print the result types, if any.
  if (op.getNumResults() > 0) {
    p << " -> ";
    if (op.getNumResults() > 1) p << "(";
    interleaveComma(op.getResultTypes(), p);
    if (op.getNumResults() > 1) p << ")";
  }
  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{});
}
//===----------------------------------------------------------------------===//
// flow.dispatch.workgroups
//===----------------------------------------------------------------------===//
// Builds a flow.dispatch.workgroups op.
//
// Operands are laid out as [workgroupCount..., operands...]. The matching
// `operand_segment_sizes` attribute is added so the two variadic groups can
// be told apart. This keeps the builder consistent with the custom parser and
// with Dispatch2Op::build, which both set it.
//
// The body region receives a single entry block with one argument per
// operand followed by one per result:
// - tensor operands are exposed as DispatchInputType,
// - tensor results are exposed as DispatchOutputType,
// - all other types pass through unchanged.
// The block is otherwise left empty for the caller to populate.
void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state,
                                 ValueRange workgroupCount,
                                 TypeRange resultTypes, ValueRange operands,
                                 ArrayRef<NamedAttribute> attributes) {
  state.addOperands(workgroupCount);
  state.addTypes(resultTypes);
  state.addOperands(operands);
  state.addAttributes(attributes);
  state.addAttribute(
      "operand_segment_sizes",
      builder.getI32VectorAttr({static_cast<int32_t>(workgroupCount.size()),
                                static_cast<int32_t>(operands.size())}));
  auto *body = state.addRegion();
  // A freshly added region has no blocks, so create the entry block before
  // adding arguments to it.
  auto *entryBlock = new Block();
  body->push_back(entryBlock);
  for (auto operand : operands) {
    Type type = operand.getType();
    if (auto tensorType = type.dyn_cast<TensorType>()) {
      type = DispatchInputType::get(tensorType);
    }
    entryBlock->addArgument(type);
  }
  for (auto resultType : resultTypes) {
    Type type = resultType;
    if (auto tensorType = type.dyn_cast<TensorType>()) {
      type = DispatchOutputType::get(tensorType);
    }
    entryBlock->addArgument(type);
  }
}
// Parses the body of a flow.dispatch.workgroups op:
//
//   (%arg0 : type0, %arg1 : type1, ...) { ...body... }
//
// The explicit block argument list must have exactly one entry per dispatch
// operand followed by one per result. Only the total count is checked here;
// the argument types are taken verbatim from the assembly. Name shadowing is
// enabled, so the block arguments may reuse SSA names from the enclosing
// scope.
static ParseResult parseDispatchWorkgroupBody(OpAsmParser &parser,
                                              TypeRange operandTypes,
                                              TypeRange resultTypes,
                                              Region &body) {
  auto loc = parser.getCurrentLocation();
  SmallVector<OpAsmParser::OperandType, 16> regionArgs;
  SmallVector<Type, 16> regionArgTypes;
  if (failed(parser.parseLParen())) {
    return failure();
  }
  if (failed(parser.parseOptionalRParen())) {
    do {
      // Reserve entries in the lists.
      regionArgs.emplace_back();
      regionArgTypes.emplace_back();
      if (failed(parser.parseRegionArgument(regionArgs.back())) ||
          failed(parser.parseColonType(regionArgTypes.back()))) {
        return failure();
      }
    } while (succeeded(parser.parseOptionalComma()));
    if (failed(parser.parseRParen())) {
      return failure();
    }
  }
  if (regionArgs.size() != operandTypes.size() + resultTypes.size()) {
    return parser.emitError(loc,
                            "region operand list required to match "
                            "count of dispatch op operands + results");
  }
  return parser.parseRegion(body, regionArgs, regionArgTypes,
                            /*enableNameShadowing=*/true);
}
// Prints the body of a flow.dispatch.workgroups op in the form accepted by
// parseDispatchWorkgroupBody: a parenthesized `%arg : type` list of the entry
// block arguments, then the region with terminators shown. |op|,
// |operandTypes| and |resultTypes| are unused; they exist to match the
// custom directive signature.
static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op,
                                       TypeRange operandTypes,
                                       TypeRange resultTypes, Region &body) {
  p << "(";
  interleaveComma(body.getArguments(), p, [&](BlockArgument blockArg) {
    p << blockArg << " : " << blockArg.getType();
  });
  p << ")";
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}
// TODO(benvanik): remove after https://bugs.llvm.org/show_bug.cgi?id=48478
// The parser/printer are modified autogenerated values to work around the bug.
//
// Parses:
//   flow.dispatch.workgroups[%x, %y, ...](%operands...)
//       : (operand types) -> (result types) [attributes {...}]
//       = (%args : types...) { body }
// The workgroup count operands are resolved as `index`. The operand and
// result types come from the functional type. `operand_segment_sizes` is
// synthesized from the lengths of the two operand lists.
static ::mlir::ParseResult parseDispatchWorkgroupsOp(
    ::mlir::OpAsmParser &parser, ::mlir::OperationState *result) {
  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4>
      workgroup_countOperands;
  ::llvm::SMLoc workgroup_countOperandsLoc;
  (void)workgroup_countOperandsLoc;
  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> operandsOperands;
  ::llvm::SMLoc operandsOperandsLoc;
  (void)operandsOperandsLoc;
  ::llvm::ArrayRef<::mlir::Type> operandsTypes;
  ::llvm::ArrayRef<::mlir::Type> resultsTypes;
  std::unique_ptr<::mlir::Region> bodyRegion =
      std::make_unique<::mlir::Region>();
  if (parser.parseLSquare()) return ::mlir::failure();
  workgroup_countOperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(workgroup_countOperands))
    return ::mlir::failure();
  if (parser.parseRSquare()) return ::mlir::failure();
  if (parser.parseLParen()) return ::mlir::failure();
  operandsOperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(operandsOperands)) return ::mlir::failure();
  if (parser.parseRParen()) return ::mlir::failure();
  if (parser.parseColon()) return ::mlir::failure();
  // The functional type carries both the operand and the result types.
  ::mlir::FunctionType operands__results_functionType;
  if (parser.parseType(operands__results_functionType))
    return ::mlir::failure();
  operandsTypes = operands__results_functionType.getInputs();
  resultsTypes = operands__results_functionType.getResults();
  if (parser.parseOptionalAttrDictWithKeyword(result->attributes))
    return ::mlir::failure();
  if (parser.parseEqual()) return ::mlir::failure();
  {
    if (parseDispatchWorkgroupBody(parser, operandsTypes, resultsTypes,
                                   *bodyRegion))
      return ::mlir::failure();
  }
  // Workgroup counts are always `index` values.
  ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
  result->addTypes(resultsTypes);
  if (parser.resolveOperands(workgroup_countOperands, odsBuildableType0,
                             result->operands))
    return ::mlir::failure();
  if (parser.resolveOperands(operandsOperands, operandsTypes,
                             operandsOperandsLoc, result->operands))
    return ::mlir::failure();
  result->addRegion(std::move(bodyRegion));
  result->addAttribute(
      "operand_segment_sizes",
      parser.getBuilder().getI32VectorAttr(
          {static_cast<int32_t>(workgroup_countOperands.size()),
           static_cast<int32_t>(operandsOperands.size())}));
  return ::mlir::success();
}
// Prints flow.dispatch.workgroups in the form accepted by
// parseDispatchWorkgroupsOp. `operand_segment_sizes` is elided because the
// parser reconstructs it from the two operand lists.
static void printDispatchWorkgroupsOp(::mlir::OpAsmPrinter &p,
                                      DispatchWorkgroupsOp &op) {
  p << "flow.dispatch.workgroups";
  p << "[";
  p << op.workgroup_count();
  p << "]";
  p << ' ' << "(";
  p << op.operands();
  p << ")";
  p << ' ' << ":";
  p << ' ';
  p.printFunctionalType(op.operands().getTypes(), op.results().getTypes());
  p.printOptionalAttrDictWithKeyword(op.getAttrs(), /*elidedAttrs=*/{
                                         "operand_segment_sizes",
                                     });
  p << ' ' << "=";
  p << ' ';
  printDispatchWorkgroupBody(p, op, op.operands().getTypes(),
                             op.results().getTypes(), op.body());
}
// Verifies flow.dispatch.workgroups: the workgroup count must specify at
// least one dimension.
static LogicalResult verifyDispatchWorkgroupsOp(DispatchWorkgroupsOp op) {
  auto workgroupCount = op.workgroup_count();
  if (!workgroupCount.empty()) return success();
  return op.emitOpError() << "at least one workgroup dimension is required";
}
//===----------------------------------------------------------------------===//
// flow.dispatch.workgroup.*
//===----------------------------------------------------------------------===//
// Suggests the SSA name `%workgroup_rank` for this op's result.
void DispatchWorkgroupRankOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Value rankValue = result();
  setNameFn(rankValue, "workgroup_rank");
}
// Names |result| as |prefix| followed by the decimal value of |dimension|,
// e.g. `workgroup_id_0`.
static void getAsmResultNamesForDispatchWorkgroupInfoOp(
    StringRef prefix, const APInt &dimension, Value result,
    function_ref<void(Value, StringRef)> setNameFn) {
  std::string name = prefix.str();
  name += std::to_string(dimension.getZExtValue());
  setNameFn(result, name);
}
// Suggests `%workgroup_id_<dim>` for flow.dispatch.workgroup.id results.
void DispatchWorkgroupIDOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  const APInt &dim = dimension();
  getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_id_", dim, result(),
                                              setNameFn);
}
// Suggests `%workgroup_count_<dim>` for flow.dispatch.workgroup.count results.
void DispatchWorkgroupCountOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  const APInt &dim = dimension();
  getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_count_", dim,
                                              result(), setNameFn);
}
// Suggests `%workgroup_size_<dim>` for flow.dispatch.workgroup.size results.
void DispatchWorkgroupSizeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  const APInt &dim = dimension();
  getAsmResultNamesForDispatchWorkgroupInfoOp("workgroup_size_", dim,
                                              result(), setNameFn);
}
// Verifies a flow.dispatch.workgroup.{id,count,size} query op.
//
// When the op is nested inside a flow.dispatch.workgroups op, its
// |dimension| must index one of that op's workgroup count operands. If there
// is no such parent (or it has no dimensions), no bounds check is done.
template <typename T>
static LogicalResult verifyDispatchWorkgroupInfoOp(T op) {
  size_t dimCount = 0;
  if (auto dispatchOp = op.template getParentOfType<DispatchWorkgroupsOp>()) {
    dimCount = dispatchOp.workgroup_count().size();
  }
  // |dimension| is unsigned, so only the upper bound needs checking.
  uint64_t dimension = op.dimension().getZExtValue();
  if (dimCount != 0 && dimension >= dimCount) {
    return op.emitOpError()
           << "dimension " << dimension
           << " out of bounds of dispatch dimensions; expected [0, "
           << dimCount << ")";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.shape
//===----------------------------------------------------------------------===//
// Suggests the SSA name `%shape` for flow.dispatch.shape results.
void DispatchShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // TODO(benvanik): since we know these are arguments, we could map them based
  // on index (so we get arg0_shape, ret0_shape, etc).
  Value shapeValue = result();
  setNameFn(shapeValue, "shape");
}
//===----------------------------------------------------------------------===//
// flow.executable
//===----------------------------------------------------------------------===//
// Builds an empty flow.executable named |name|. The body region receives the
// implicit terminator required by the op.
void ExecutableOp::build(OpBuilder &builder, OperationState &state,
                         StringRef name) {
  Region *bodyRegion = state.addRegion();
  ensureTerminator(*bodyRegion, builder, state.location);
  StringAttr symName = builder.getStringAttr(name);
  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), symName);
}
// Parses `flow.executable @name [attributes {...}] { ... }`. The body region
// takes no arguments. If the parsed body lacks the implicit terminator, one
// is inserted.
static ParseResult parseExecutableOp(OpAsmParser &parser,
                                     OperationState *result) {
  StringAttr nameAttr;
  if (failed(parser.parseSymbolName(nameAttr,
                                    mlir::SymbolTable::getSymbolAttrName(),
                                    result->attributes)) ||
      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  // Parse the module body.
  auto *body = result->addRegion();
  if (failed(parser.parseRegion(*body, llvm::None, llvm::None))) {
    return failure();
  }
  // Ensure that this module has a valid terminator.
  ExecutableOp::ensureTerminator(*body, parser.getBuilder(), result->location);
  return success();
}
// Prints flow.executable in the form accepted by parseExecutableOp. Only the
// symbol name is elided from the attribute dictionary. Block terminators are
// not printed, because the parser re-inserts them.
static void printExecutableOp(OpAsmPrinter &p, ExecutableOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.sym_name());
  p.printOptionalAttrDictWithKeyword(
      op.getAttrs(),
      /*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});
  p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}
// Verifies flow.executable. Currently every executable is accepted.
static LogicalResult verifyExecutableOp(ExecutableOp op) {
  (void)op;  // Unused until the checks below are implemented.
  // TODO(benvanik): check export name conflicts.
  return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.entry
//===----------------------------------------------------------------------===//
// Parses `flow.dispatch.entry @function [as("export_name")] [attributes {...}]`.
// Without an `as(...)` clause, the exported `sym_name` defaults to the
// referenced function's name.
static ParseResult parseDispatchEntryOp(OpAsmParser &parser,
                                        OperationState *result) {
  FlatSymbolRefAttr functionRefAttr;
  if (failed(parser.parseAttribute(functionRefAttr, "function_ref",
                                   result->attributes))) {
    return failure();
  }
  if (succeeded(parser.parseOptionalKeyword("as"))) {
    StringAttr exportNameAttr;
    if (failed(parser.parseLParen()) ||
        failed(parser.parseAttribute(exportNameAttr, "sym_name",
                                     result->attributes)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
  } else {
    result->addAttribute("sym_name", parser.getBuilder().getStringAttr(
                                         functionRefAttr.getValue()));
  }
  if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  return success();
}
// Prints flow.dispatch.entry in the form accepted by parseDispatchEntryOp.
// The `as(...)` clause is emitted only when the export name differs from the
// referenced function. The name is printed as a string attribute, so quotes
// and backslashes are escaped and round-trip through the parser, which reads
// it with parseAttribute.
static void printDispatchEntryOp(OpAsmPrinter &p, DispatchEntryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.function_ref());
  if (op.sym_name() != op.function_ref()) {
    p << " as(";
    p.printAttribute(op.sym_nameAttr());
    p << ')';
  }
  p.printOptionalAttrDictWithKeyword(
      op.getAttrs(), /*elidedAttrs=*/{"function_ref", "sym_name"});
}
//===----------------------------------------------------------------------===//
// flow.dispatch
//===----------------------------------------------------------------------===//
// Parses the custom assembly form of flow.dispatch:
//
//   flow.dispatch @executable::@entry[%workload : type](%operands...)
//       [{attrs}] : (operand types) -> (result types)
//
// The trailing function type supplies the operand types (used to resolve
// the operand list) and the op's result types.
static ParseResult parseDispatchOp(OpAsmParser &parser,
                                   OperationState *result) {
  SymbolRefAttr entryPointAttr;
  if (failed(parser.parseAttribute(entryPointAttr, "entry_point",
                                   result->attributes))) {
    return failure();
  }
  OpAsmParser::OperandType workloadArg;
  Type workloadArgType;
  if (failed(parser.parseLSquare()) ||
      failed(parser.parseOperand(workloadArg)) ||
      failed(parser.parseColonType(workloadArgType)) ||
      failed(parser.parseRSquare()) ||
      failed(parser.resolveOperand(workloadArg, workloadArgType,
                                   result->operands))) {
    return failure();
  }
  SmallVector<OpAsmParser::OperandType, 4> operands;
  FunctionType entryPointType;
  if (failed(
          parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) ||
      failed(parser.parseOptionalAttrDict(result->attributes)) ||
      failed(parser.parseColonType(entryPointType)) ||
      failed(
          parser.addTypesToList(entryPointType.getResults(), result->types)) ||
      failed(parser.resolveOperands(operands, entryPointType.getInputs(),
                                    parser.getNameLoc(), result->operands))) {
    return failure();
  }
  return success();
}
// Prints flow.dispatch in the form accepted by parseDispatchOp. The trailing
// type is the entry point function type built from the operand types (the
// workload excluded) and the result types.
static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) {
  p << op.getOperationName() << ' ';
  p.printAttributeWithoutType(op.entry_point());
  p << "[";
  p.printOperand(op.workload());
  p << " : ";
  p.printType(op.workload().getType());
  p << "](";
  p.printOperands(op.operands());
  p << ')';
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"entry_point"});
  p << " : ";
  p.printType(op.getEntryPointType());
}
// Builds a flow.dispatch of |entryPoint|. Operand 0 is |workload| and the
// dispatch |operands| follow it. The entry point is referenced as the nested
// symbol @executable::@entry, where the executable is the op that directly
// encloses |entryPoint|.
void DispatchOp::build(OpBuilder &builder, OperationState &state,
                       DispatchEntryOp entryPoint, Value workload,
                       ArrayRef<Type> results, ValueRange operands) {
  state.addOperands({workload});
  state.addOperands(operands);
  Operation *executableOp = entryPoint->getParentOp();
  auto executableNameAttr = executableOp->getAttrOfType<StringAttr>(
      SymbolTable::getSymbolAttrName());
  auto entryRef = builder.getSymbolRefAttr(entryPoint);
  auto nestedRef =
      builder.getSymbolRefAttr(executableNameAttr.getValue(), {entryRef});
  state.addAttribute("entry_point", nestedRef);
  state.addTypes(results);
}
// Returns the name of the executable containing the dispatched entry point
// (the root of the nested `entry_point` symbol reference).
StringRef DispatchOp::executable() {
  auto entryPointRef = entry_point();
  return entryPointRef.getRootReference();
}
// Returns the function type of the dispatched entry point: the dispatch
// operand types (the workload excluded) mapped to the op's result types.
FunctionType DispatchOp::getEntryPointType() {
  auto inputTypes = llvm::to_vector<8>(operands().getTypes());
  return FunctionType::get(getContext(), inputTypes, getResultTypes());
}
//===----------------------------------------------------------------------===//
// flow.dispatch2
//===----------------------------------------------------------------------===//
// Builds a flow.dispatch2 of |entryPoint|. Operands are laid out as
// [workgroupCount..., operands...], described by `operand_segment_sizes`.
// The entry point is referenced as the nested symbol @executable::@entry,
// where the executable is the op that directly encloses |entryPoint|.
void Dispatch2Op::build(OpBuilder &builder, OperationState &state,
                        DispatchEntryOp entryPoint, ValueRange workgroupCount,
                        TypeRange results, ValueRange operands,
                        ArrayRef<NamedAttribute> attributes) {
  Operation *executableOp = entryPoint->getParentOp();
  auto executableNameAttr = executableOp->getAttrOfType<StringAttr>(
      SymbolTable::getSymbolAttrName());
  auto entryRef = builder.getSymbolRefAttr(entryPoint);
  state.addAttribute(
      "entry_point",
      builder.getSymbolRefAttr(executableNameAttr.getValue(), {entryRef}));
  state.addOperands(workgroupCount);
  state.addTypes(results);
  state.addOperands(operands);
  state.addAttributes(attributes);
  auto countSegment = static_cast<int32_t>(workgroupCount.size());
  auto operandSegment = static_cast<int32_t>(operands.size());
  state.addAttribute("operand_segment_sizes",
                     builder.getI32VectorAttr({countSegment, operandSegment}));
}
// Returns the name of the executable containing the dispatched entry point
// (the root of the nested `entry_point` symbol reference).
StringRef Dispatch2Op::executable() {
  auto entryPointRef = entry_point();
  return entryPointRef.getRootReference();
}
// Returns the function type of the dispatched entry point: the dispatch
// operand types (the workgroup counts excluded) mapped to the result types.
FunctionType Dispatch2Op::getEntryPointType() {
  auto inputTypes = llvm::to_vector<8>(operands().getTypes());
  return FunctionType::get(getContext(), inputTypes, getResultTypes());
}
// Verifies flow.dispatch2: the workgroup count must specify at least one
// dimension.
static LogicalResult verifyDispatch2Op(Dispatch2Op op) {
  auto workgroupCount = op.workgroup_count();
  if (!workgroupCount.empty()) return success();
  return op.emitOpError() << "at least one workgroup dimension is required";
}
//===----------------------------------------------------------------------===//
// flow.ex.stream.fragment
//===----------------------------------------------------------------------===//
// Builds a flow.ex.stream.fragment with the given results and captured
// |operands|. The body region is created empty; callers populate it.
void ExStreamFragmentOp::build(OpBuilder &builder, OperationState &state,
                               ArrayRef<Type> resultTypes, ValueRange operands,
                               ArrayRef<NamedAttribute> attributes) {
  state.addOperands(operands);
  state.addTypes(resultTypes);
  state.addAttributes(attributes);
  (void)state.addRegion();
}
// Parses the custom assembly form of flow.ex.stream.fragment:
//
//   flow.ex.stream.fragment(%arg = %operand : type, ...)
//       [-> (result types)] { body } [{attrs}]
//
// Each `%arg = %operand : type` entry adds %operand as an op operand and
// declares %arg as a body block argument of the same type. The attribute
// dictionary is parsed after the region.
ParseResult parseExStreamFragmentOp(OpAsmParser &parser,
                                    OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 16> regionArgs;
  SmallVector<Type, 16> regionArgTypes;
  if (failed(parser.parseLParen())) {
    return failure();
  }
  if (failed(parser.parseOptionalRParen())) {
    SmallVector<OpAsmParser::OperandType, 16> regionOperands;
    auto argsLoc = parser.getCurrentLocation();
    do {
      // Reserve entries in the lists.
      regionArgs.emplace_back();
      regionOperands.emplace_back();
      regionArgTypes.emplace_back();
      if (failed(parser.parseRegionArgument(regionArgs.back())) ||
          failed(parser.parseEqual()) ||
          failed(parser.parseOperand(regionOperands.back())) ||
          failed(parser.parseColonType(regionArgTypes.back()))) {
        return failure();
      }
    } while (succeeded(parser.parseOptionalComma()));
    // The same type list types both the outer operands and the block args.
    if (failed(parser.parseRParen()) ||
        failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc,
                                      result->operands))) {
      return failure();
    }
  }
  // Parse (optional) results.
  if (failed(parser.parseOptionalArrowTypeList(result->types))) {
    return failure();
  }
  // Parse region body.
  Region *body = result->addRegion();
  if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) ||
      failed(parser.parseOptionalAttrDict(result->attributes))) {
    return failure();
  }
  return success();
}
// Prints flow.ex.stream.fragment in the form accepted by
// parseExStreamFragmentOp. Body block arguments are printed in the
// `%arg = %operand : type` list instead of on the entry block. The attribute
// dictionary follows the region, and no attributes are elided.
void printExStreamFragmentOp(OpAsmPrinter &p, ExStreamFragmentOp op) {
  p << op.getOperationName();
  // Print the data argument remapping.
  p << "(";
  interleaveComma(llvm::zip(op.body().getArguments(), op.args()), p,
                  [&](std::tuple<BlockArgument, Value> it) {
                    p << std::get<0>(it) << " = " << std::get<1>(it);
                    p << " : ";
                    p << std::get<1>(it).getType();
                  });
  p << ")";
  // Print the result types, if any.
  if (op.getNumResults() > 0) {
    p << " -> ";
    if (op.getNumResults() > 1) p << "(";
    interleaveComma(op.getResultTypes(), p);
    if (op.getNumResults() > 1) p << ")";
  }
  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{});
}
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen definitions (intentionally last)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Flow/IR/FlowOps.cpp.inc"