blob: 19a5a7f73b10bc8a5a98a8a8204b4229b0907bcb [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
// Checks whether a value of |accessType| may be loaded from or stored into a
// variable declared with |variableType|. The decision is delegated entirely to
// mlir::verifyCompatibleShape, so e.g. a tensor<?xf32> variable accepts a
// tensor<4xf32> access.
// NOTE(review): verifyCompatibleShape compares shapes; element types are
// presumably not checked here — confirm if that matters to callers.
static bool isVariableTypeCompatible(Type variableType, Type accessType) {
  LogicalResult shapeCheck =
      mlir::verifyCompatibleShape(variableType, accessType);
  return succeeded(shapeCheck);
}
//===----------------------------------------------------------------------===//
// flow.variable
//===----------------------------------------------------------------------===//
// Parses the custom form of flow.variable:
//   flow.variable @name [mutable] [init(@initializer)] (initial-attr | : type)
// When no `:` follows, an initial value attribute is parsed instead, and its
// type becomes the variable's `type` attribute.
static ParseResult parseVariableOp(OpAsmParser &parser,
                                   OperationState *result) {
  auto &attrs = result->attributes;

  StringAttr symNameAttr;
  if (failed(parser.parseSymbolName(
          symNameAttr, mlir::SymbolTable::getSymbolAttrName(), attrs))) {
    return failure();
  }

  const bool isMutable = succeeded(parser.parseOptionalKeyword("mutable"));
  if (isMutable) {
    result->addAttribute("is_mutable", UnitAttr::get(result->getContext()));
  }

  const bool hasInitializer = succeeded(parser.parseOptionalKeyword("init"));
  if (hasInitializer) {
    FlatSymbolRefAttr initializerRef;
    const bool initFailed =
        failed(parser.parseLParen()) ||
        failed(parser.parseAttribute(initializerRef, "initializer", attrs)) ||
        failed(parser.parseRParen());
    if (initFailed) return failure();
  }

  Type variableType;
  if (succeeded(parser.parseOptionalColon())) {
    // Explicit `: type` form.
    if (failed(parser.parseType(variableType))) return failure();
  } else {
    // Initial-value form: the attribute carries the variable type.
    Attribute initialValueAttr;
    if (failed(parser.parseAttribute(initialValueAttr, "initial_value",
                                     attrs))) {
      return failure();
    }
    variableType = initialValueAttr.getType();
  }
  result->addAttribute("type", TypeAttr::get(variableType));
  return success();
}
// Prints the custom form read by parseVariableOp:
//   flow.variable @name [mutable] [init(@fn)] (initial-attr | : type)
// The type is printed explicitly only when there is no initial value.
static void printVariableOp(OpAsmPrinter &p, VariableOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.sym_name());
  if (op.is_mutable()) p << " mutable";

  auto initializer = op.initializer();
  if (initializer.hasValue()) {
    p << " init(";
    p.printSymbolName(*initializer);
    p << ')';
  }

  auto initialValue = op.initial_value();
  if (!initialValue.hasValue()) {
    p << " : ";
    p.printType(op.type());
    return;
  }
  p << ' ';
  p.printAttribute(*initialValue);
}
// Verifies a flow.variable op:
//  - `initializer` and `initial_value` are mutually exclusive;
//  - an `initializer` must resolve (via the nearest symbol table) to a FuncOp
//    taking no arguments and returning exactly one value whose type equals
//    the variable type;
//  - an `initial_value` must satisfy isVariableTypeCompatible with the
//    variable type.
static LogicalResult verifyVariableOp(VariableOp op) {
  if (op.initializer().hasValue() && op.initial_value().hasValue()) {
    return op.emitOpError()
           << "variables can have either an initializer or an initial value";
  } else if (op.initializer().hasValue()) {
    // Ensure initializer returns the same type as the variable.
    auto *symbolOp =
        SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue());
    if (!symbolOp) {
      return op.emitOpError() << "initializer function "
                              << op.initializer().getValue() << " not found";
    }
    // The symbol may resolve to a non-function op (e.g. another variable);
    // dyn_cast then yields a null FuncOp that must not be dereferenced.
    auto initializerOp = dyn_cast<FuncOp>(symbolOp);
    if (!initializerOp) {
      return op.emitOpError()
             << "initializer " << op.initializer().getValue()
             << " does not reference a function";
    }
    if (initializerOp.getNumArguments() != 0 ||
        initializerOp.getNumResults() != 1 ||
        initializerOp.getType().getResult(0) != op.type()) {
      return op.emitOpError()
             << "initializer type mismatch; variable " << op.sym_name()
             << " is " << op.type() << " but initializer function "
             << initializerOp.getName() << " is " << initializerOp.getType();
    }
  } else if (op.initial_value().hasValue()) {
    // Ensure the value is something we can store in the variable
    if (!isVariableTypeCompatible(op.type(), op.initial_value()->getType())) {
      return op.emitOpError()
             << "initial value type mismatch; variable " << op.sym_name()
             << " is " << op.type() << " but initial value provided is "
             << op.initial_value()->getType();
    }
  }
  return success();
}
// Builds a variable whose value comes from calling |initializer|. The variable
// type is taken from the initializer's first result type.
void VariableOp::build(Builder *builder, OperationState &state, StringRef name,
                       bool isMutable, FuncOp initializer,
                       ArrayRef<NamedAttribute> attrs) {
  Type valueType = initializer.getType().getResult(0);
  state.addAttribute(SymbolTable::getSymbolAttrName(),
                     builder->getStringAttr(name));
  if (isMutable) state.addAttribute("is_mutable", builder->getUnitAttr());
  state.addAttribute("initializer", builder->getSymbolRefAttr(initializer));
  state.addAttribute("type", TypeAttr::get(valueType));
  // Caller-supplied attributes are appended last.
  state.attributes.append(attrs.begin(), attrs.end());
}
// Builds a variable of |type| whose `initial_value` attribute is set to
// |initialValue|.
void VariableOp::build(Builder *builder, OperationState &result, StringRef name,
                       bool isMutable, Type type, Attribute initialValue,
                       ArrayRef<NamedAttribute> attrs) {
  StringAttr symbolName = builder->getStringAttr(name);
  result.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);
  if (isMutable) result.addAttribute("is_mutable", builder->getUnitAttr());
  result.addAttribute("initial_value", initialValue);
  result.addAttribute("type", TypeAttr::get(type));
  // Caller-supplied attributes are appended last.
  result.attributes.append(attrs.begin(), attrs.end());
}
// Builds a variable of |type| with neither an `initializer` nor an
// `initial_value` attribute.
void VariableOp::build(Builder *builder, OperationState &result, StringRef name,
                       bool isMutable, Type type,
                       ArrayRef<NamedAttribute> attrs) {
  StringAttr symbolName = builder->getStringAttr(name);
  TypeAttr typeAttr = TypeAttr::get(type);
  result.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);
  if (isMutable) result.addAttribute("is_mutable", builder->getUnitAttr());
  result.addAttribute("type", typeAttr);
  // Caller-supplied attributes are appended last.
  result.attributes.append(attrs.begin(), attrs.end());
}
//===----------------------------------------------------------------------===//
// flow.variable.load
//===----------------------------------------------------------------------===//
// Parses: flow.variable.load @variable [{attrs}] : result-type
static ParseResult parseVariableLoadOp(OpAsmParser &parser,
                                       OperationState *result) {
  FlatSymbolRefAttr variableRef;
  Type loadedType;
  const bool parseFailed =
      failed(parser.parseAttribute(variableRef, "variable",
                                   result->attributes)) ||
      failed(parser.parseOptionalAttrDict(result->attributes)) ||
      failed(parser.parseColonType(loadedType));
  if (parseFailed) return failure();
  result->addTypes({loadedType});
  return success();
}
// Prints: flow.variable.load @variable [{attrs}] : result-type
// The `variable` attribute is printed positionally, so it is elided from the
// attribute dictionary.
static void printVariableLoadOp(OpAsmPrinter &p, VariableLoadOp &op) {
  Type resultType = op.result()->getType();
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.variable());
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
  p << " : ";
  p.printType(resultType);
}
// Verifies a flow.variable.load. The referenced symbol must exist, must be a
// VariableOp, and its type must be compatible (isVariableTypeCompatible) with
// the loaded result type.
static LogicalResult verifyVariableLoadOp(VariableLoadOp &op) {
  auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
  if (!symbolOp) {
    return op.emitOpError() << "undefined variable: " << op.variable();
  }
  // The symbol may name some other kind of op (e.g. a function); dyn_cast then
  // returns a null VariableOp, which must be rejected before it is used.
  auto variableOp = dyn_cast<VariableOp>(symbolOp);
  if (!variableOp) {
    return op.emitOpError() << "symbol " << op.variable()
                            << " is not a variable";
  }
  auto loadType = op.result()->getType();
  if (!isVariableTypeCompatible(variableOp.type(), loadType)) {
    return op.emitOpError()
           << "variable type mismatch; variable " << op.variable() << " is "
           << variableOp.type() << " but load is " << loadType;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// flow.variable.store
//===----------------------------------------------------------------------===//
// Parses: flow.variable.store %value, @variable [{attrs}] : value-type
// The stored operand is resolved against the trailing type.
static ParseResult parseVariableStoreOp(OpAsmParser &parser,
                                        OperationState *result) {
  OpAsmParser::OperandType storedValue;
  FlatSymbolRefAttr variableRef;
  Type storedType;
  if (failed(parser.parseOperand(storedValue))) return failure();
  if (failed(parser.parseComma())) return failure();
  if (failed(parser.parseAttribute(variableRef, "variable",
                                   result->attributes))) {
    return failure();
  }
  if (failed(parser.parseOptionalAttrDict(result->attributes))) {
    return failure();
  }
  if (failed(parser.parseColonType(storedType))) return failure();
  return parser.resolveOperand(storedValue, storedType, result->operands);
}
// Prints: flow.variable.store %value, @variable [{attrs}] : value-type
static void printVariableStoreOp(OpAsmPrinter &p, VariableStoreOp &op) {
  auto storedValue = op.value();
  p << op.getOperationName() << ' ';
  p.printOperand(storedValue);
  p << ", ";
  p.printSymbolName(op.variable());
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
  p << " : ";
  p.printType(storedValue->getType());
}
// Verifies a flow.variable.store. The referenced symbol must exist, must be a
// VariableOp, must have a type compatible with the stored value, and must be
// declared mutable.
static LogicalResult verifyVariableStoreOp(VariableStoreOp &op) {
  auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
  if (!symbolOp) {
    return op.emitOpError() << "undefined variable: " << op.variable();
  }
  // The symbol may name some other kind of op (e.g. a function); dyn_cast then
  // returns a null VariableOp, which must be rejected before it is used.
  auto variableOp = dyn_cast<VariableOp>(symbolOp);
  if (!variableOp) {
    return op.emitOpError() << "symbol " << op.variable()
                            << " is not a variable";
  }
  auto storeType = op.value()->getType();
  if (!isVariableTypeCompatible(variableOp.type(), storeType)) {
    return op.emitOpError()
           << "variable type mismatch; variable " << op.variable() << " is "
           << variableOp.type() << " but store is " << storeType;
  }
  if (!variableOp.is_mutable()) {
    return op.emitOpError() << "variable " << op.variable()
                            << " is not mutable and cannot be stored to";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.region
//===----------------------------------------------------------------------===//
// Builds a flow.dispatch.region with operand layout [workload, operands...].
// A single empty body region is added here.
void DispatchRegionOp::build(Builder *builder, OperationState &state,
                             ArrayRef<Type> resultTypes, Value workload,
                             ValueRange operands,
                             ArrayRef<NamedAttribute> attributes) {
  state.addOperands({workload});
  state.addOperands(operands);
  state.addTypes(resultTypes);
  state.addAttributes(attributes);
  state.addRegion();
  state.setOperandListToResizable();
}
// Parses the custom form of flow.dispatch.region:
//   flow.dispatch.region[%workload : type](%arg = %operand : type, ...)
//       [-> result-types] { region } [{attrs}]
// Operands are appended in order: the workload first, then one operand per
// `%arg = %operand : type` entry. Each `%arg` becomes an entry-block argument
// of the region with the same type as its paired operand.
ParseResult parseDispatchRegionOp(OpAsmParser &parser, OperationState *result) {
// Parse required workload.
OpAsmParser::OperandType workloadArg;
Type workloadArgType;
if (failed(parser.parseLSquare()) ||
failed(parser.parseOperand(workloadArg)) ||
failed(parser.parseColonType(workloadArgType)) ||
failed(parser.parseRSquare()) ||
failed(parser.resolveOperand(workloadArg, workloadArgType,
result->operands))) {
return failure();
}
// Parse (optional) args.
// regionArgs/regionArgTypes stay empty when the list is `()`.
SmallVector<OpAsmParser::OperandType, 16> regionArgs;
SmallVector<Type, 16> regionArgTypes;
if (failed(parser.parseLParen())) {
return failure();
}
if (failed(parser.parseOptionalRParen())) {
SmallVector<OpAsmParser::OperandType, 16> regionOperands;
// Location reported if resolving the collected operands fails.
auto argsLoc = parser.getCurrentLocation();
do {
// Reserve entries in the lists.
regionArgs.emplace_back();
regionOperands.emplace_back();
regionArgTypes.emplace_back();
if (failed(parser.parseRegionArgument(regionArgs.back())) ||
failed(parser.parseEqual()) ||
failed(parser.parseOperand(regionOperands.back())) ||
failed(parser.parseColonType(regionArgTypes.back()))) {
return failure();
}
} while (succeeded(parser.parseOptionalComma()));
// The same type list types both the outer operands and the region args.
if (failed(parser.parseRParen()) ||
failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc,
result->operands))) {
return failure();
}
}
// NOTE(review): presumably lets later rewrites change the operand count of
// this op — confirm.
result->setOperandListToResizable();
// Parse (optional) results.
if (failed(parser.parseOptionalArrowTypeList(result->types))) {
return failure();
}
// Parse region body.
// The attribute dictionary trails the region.
Region *body = result->addRegion();
if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) ||
failed(parser.parseOptionalAttrDict(result->attributes))) {
return failure();
}
return success();
}
// Prints a flow.dispatch.region as:
//   flow.dispatch.region[%workload : type](%arg = %operand : type, ...)
//       [-> result-types] { region } [{attrs}]
// Entry-block arguments are printed inline in the `(...)` list paired with
// op.args(), so printRegion is told not to print them again.
void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) {
p << op.getOperationName();
// Print the workload argument.
p << "[";
p.printOperand(op.workload());
p << " : ";
p.printType(op.workload()->getType());
p << "]";
// Print the data argument remapping.
// Each entry-block argument is zipped with its corresponding outer operand.
p << "(";
interleaveComma(llvm::zip(op.body().front().getArguments(), op.args()), p,
[&](std::tuple<BlockArgument, Value> it) {
p << *std::get<0>(it) << " = " << *std::get<1>(it);
p << " : ";
p << std::get<1>(it)->getType();
});
p << ")";
// Print the result types, if any.
// A single result type is printed bare; multiple results are parenthesized.
if (op.getNumResults() > 0) {
p << " -> ";
if (op.getNumResults() > 1) p << "(";
interleaveComma(op.getResultTypes(), p);
if (op.getNumResults() > 1) p << ")";
}
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{});
}
//===----------------------------------------------------------------------===//
// flow.reduction.region
//===----------------------------------------------------------------------===//
// Builds a flow.reduction.region with operand layout
// [workload, operands..., initialValues...]. |dimensions| is stored as a
// dense vector<N x i32> attribute named `dimensions`. A single empty body
// region is added.
void ReductionRegionOp::build(Builder *builder, OperationState &state,
                              ArrayRef<Type> resultTypes, Value workload,
                              ValueRange operands, ValueRange initialValues,
                              ArrayRef<int32_t> dimensions,
                              ArrayRef<NamedAttribute> attributes) {
  auto dimensionsType =
      VectorType::get({static_cast<int32_t>(dimensions.size())},
                      builder->getIntegerType(32));

  state.addTypes(resultTypes);
  state.addOperands({workload});
  state.addOperands(operands);
  state.addOperands(initialValues);
  state.addAttribute("dimensions",
                     DenseIntElementsAttr::get(dimensionsType, dimensions));
  state.addAttributes(attributes);
  state.addRegion();
  state.setOperandListToResizable();
}
// Parses the custom form of flow.reduction.region (also used by
// flow.windowed_reduction.region):
//   flow.reduction.region[%workload : type](%operands...)
//       : (operand-types...) -> result-type(s)
//       invocation((%lhs, %rhs) = %init : type, ...) { region } [{attrs}]
// Operands are appended as [workload, operands..., initial values...]. Each
// invocation entry contributes two region arguments, both of the entry's type.
ParseResult parseReductionRegionOp(OpAsmParser &parser,
                                   OperationState *result) {
  OpAsmParser::OperandType workloadArg;
  Type workloadArgType;
  if (failed(parser.parseLSquare()) ||
      failed(parser.parseOperand(workloadArg)) ||
      failed(parser.parseColonType(workloadArgType)) ||
      failed(parser.parseRSquare()) ||
      failed(parser.resolveOperand(workloadArg, workloadArgType,
                                   result->operands))) {
    return failure();
  }

  // Parse the signature as a FunctionType directly. The typed parseColonType
  // emits a diagnostic for a non-function type, whereas the previous
  // unchecked Type::cast<FunctionType>() asserted (or was UB in release).
  SmallVector<OpAsmParser::OperandType, 8> reductionOperands;
  FunctionType reductionType;
  auto operandsLoc = parser.getCurrentLocation();
  if (failed(parser.parseLParen()) ||
      failed(parser.parseOperandList(reductionOperands)) ||
      failed(parser.parseRParen()) ||
      failed(parser.parseColonType(reductionType)) ||
      failed(parser.resolveOperands(reductionOperands,
                                    reductionType.getInputs(), operandsLoc,
                                    result->operands))) {
    return failure();
  }
  result->addTypes(reductionType.getResults());
  result->setOperandListToResizable();

  SmallVector<OpAsmParser::OperandType, 8> regionArgs;
  SmallVector<Type, 8> regionArgTypes;
  if (failed(parser.parseKeyword("invocation")) ||
      failed(parser.parseLParen())) {
    return failure();
  }
  do {
    Type argType;
    SmallVector<OpAsmParser::OperandType, 2> reductionRegionArgs;
    OpAsmParser::OperandType initialValue;
    // `(%lhs, %rhs) = %init : type`; exactly two region args are required.
    if (failed(parser.parseLParen()) ||
        failed(parser.parseOperandList(reductionRegionArgs, 2)) ||
        failed(parser.parseRParen()) || failed(parser.parseEqual()) ||
        failed(parser.parseOperand(initialValue)) ||
        failed(parser.parseColonType(argType)) ||
        failed(
            parser.resolveOperand(initialValue, argType, result->operands))) {
      return failure();
    }
    regionArgs.push_back(reductionRegionArgs[0]);
    regionArgTypes.push_back(argType);
    regionArgs.push_back(reductionRegionArgs[1]);
    regionArgTypes.push_back(argType);
  } while (succeeded(parser.parseOptionalComma()));
  if (failed(parser.parseRParen())) {
    return failure();
  }

  // Parse region body; the attribute dictionary trails it.
  Region *body = result->addRegion();
  if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) ||
      failed(parser.parseOptionalAttrDict(result->attributes))) {
    return failure();
  }
  return success();
}
// Prints a flow.reduction.region in the form read by parseReductionRegionOp:
//   flow.reduction.region[%workload : type](%operands...)
//       : (operand-types...) -> result-type(s)
//       invocation((%lhs, %rhs) = %init : type, ...) { region } {attrs}
// The `: (...) -> ...` signature is always printed because the parser
// requires it unconditionally. Previously it was omitted for zero results,
// and that output could not be re-parsed.
void printReductionRegionOp(OpAsmPrinter &p, ReductionRegionOp op) {
  p << op.getOperationName();
  // Print the workload argument.
  p << "[";
  p.printOperand(op.workload());
  p << " : ";
  p.printType(op.workload()->getType());
  p << "]";
  p << "(";
  p.printOperands(op.operands());
  p << ")";
  p << " : (";
  interleaveComma(op.operands(), p,
                  [&](Value operand) { p.printType(operand->getType()); });
  p << ")";
  p << " -> ";
  // A single result may be printed bare; zero or several need parentheses so
  // the result list parses as a function type.
  const bool parenthesizeResults = op.getNumResults() != 1;
  if (parenthesizeResults) p << "(";
  interleaveComma(op.getResultTypes(), p);
  if (parenthesizeResults) p << ")";
  p << "\n";
  p << " invocation(";
  // Region arguments come in (lhs, rhs) pairs, one pair per initial value.
  auto &entryBlock = op.body().getBlocks().front();
  int regionArgIndex = 0;
  interleaveComma(op.initial_values(), p, [&](Value operand) {
    p << "(";
    p.printOperand(entryBlock.getArgument(regionArgIndex++));
    p << ", ";
    p.printOperand(entryBlock.getArgument(regionArgIndex++));
    p << ") = ";
    p.printOperand(operand);
    p << " : ";
    p.printType(operand->getType());
  });
  p << ") ";
  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{});
}
//===----------------------------------------------------------------------===//
// flow.windowed_reduction.region
//===----------------------------------------------------------------------===//
// Builds a flow.windowed_reduction.region with operand layout
// [workload, operands..., initialValues...]. Each window configuration array
// is stored as a dense vector<N x i32> attribute, and |paddingMode| is stored
// as an i32 attribute. A single empty body region is added.
void WindowedReductionRegionOp::build(
    Builder *builder, OperationState &state, ArrayRef<Type> resultTypes,
    Value workload, ValueRange operands, ValueRange initialValues,
    ArrayRef<int32_t> windowDimensions, ArrayRef<int32_t> windowStrides,
    ArrayRef<int32_t> baseDilations, ArrayRef<int32_t> windowDilations,
    PaddingMode paddingMode, ArrayRef<NamedAttribute> attributes) {
  // Wraps |values| as a dense vector<N x i32> attribute.
  auto i32VectorAttr = [&](ArrayRef<int32_t> values) {
    auto vectorType = VectorType::get({static_cast<int32_t>(values.size())},
                                      builder->getIntegerType(32));
    return DenseIntElementsAttr::get(vectorType, values);
  };

  state.addTypes(resultTypes);
  state.addOperands({workload});
  state.addOperands(operands);
  state.addOperands(initialValues);
  state.addAttribute("window_dimensions", i32VectorAttr(windowDimensions));
  state.addAttribute("window_strides", i32VectorAttr(windowStrides));
  state.addAttribute("base_dilations", i32VectorAttr(baseDilations));
  state.addAttribute("window_dilations", i32VectorAttr(windowDilations));
  state.addAttribute("padding_mode", builder->getI32IntegerAttr(
                                         static_cast<int32_t>(paddingMode)));
  state.addAttributes(attributes);
  state.addRegion();
  state.setOperandListToResizable();
}
// The windowed variant shares the reduction region's textual form. Its window
// configuration attributes are carried in the trailing attribute dictionary
// that parseReductionRegionOp reads after the region.
ParseResult parseWindowedReductionRegionOp(OpAsmParser &parser,
                                           OperationState *result) {
  ParseResult parsed = parseReductionRegionOp(parser, result);
  return parsed;
}
// Prints a flow.windowed_reduction.region in the form read by
// parseWindowedReductionRegionOp (which reuses parseReductionRegionOp):
//   flow.windowed_reduction.region[%workload : type](%operands...)
//       : (operand-types...) -> (result-types...)
//       invocation((%lhs, %rhs) = %init : type, ...) { region } {attrs}
// The signature is always printed because the parser requires it. Previously
// it was skipped when the op had no results, which broke round-tripping.
void printWindowedReductionRegionOp(OpAsmPrinter &p,
                                    WindowedReductionRegionOp op) {
  p << op.getOperationName();
  // Print the workload argument.
  p << "[";
  p.printOperand(op.workload());
  p << " : ";
  p.printType(op.workload()->getType());
  p << "]";
  p << "(";
  p.printOperands(op.operands());
  p << ")";
  p << " : (";
  interleaveComma(op.operands(), p,
                  [&](Value operand) { p.printType(operand->getType()); });
  p << ")";
  // Results are always parenthesized, which also yields `()` for none.
  p << " -> (";
  interleaveComma(op.getResultTypes(), p);
  p << ")";
  p << "\n";
  p << " invocation(";
  // Region arguments come in (lhs, rhs) pairs, one pair per initial value.
  auto &entryBlock = op.body().getBlocks().front();
  int regionArgIndex = 0;
  interleaveComma(op.initial_values(), p, [&](Value operand) {
    p << "(";
    p.printOperand(entryBlock.getArgument(regionArgIndex++));
    p << ", ";
    p.printOperand(entryBlock.getArgument(regionArgIndex++));
    p << ") = ";
    p.printOperand(operand);
    p << " : ";
    p.printType(operand->getType());
  });
  p << ") ";
  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{});
}
//===----------------------------------------------------------------------===//
// flow.return
//===----------------------------------------------------------------------===//
// Parses: flow.return [%v0, %v1, ... : type0, type1, ...]
// The colon/type list is only expected when at least one operand is present.
static ParseResult parseReturnOp(OpAsmParser &parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
  SmallVector<Type, 2> operandTypes;
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(operandInfos)) return failure();
  if (!operandInfos.empty() && parser.parseColonTypeList(operandTypes)) {
    return failure();
  }
  return parser.resolveOperands(operandInfos, operandTypes, operandsLoc,
                                result->operands);
}
// Prints: flow.return [%v0, ... : type0, ...]
static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
  p << op.getOperationName();
  if (op.getNumOperands() == 0) return;
  p << ' ';
  p.printOperands(op.operand_begin(), op.operand_end());
  p << " : ";
  interleaveComma(op.getOperandTypes(), p);
}
//===----------------------------------------------------------------------===//
// flow.executable
//===----------------------------------------------------------------------===//
// Builds an empty flow.executable named |name|. Its body region is created
// with a terminator already in place.
void ExecutableOp::build(Builder *builder, OperationState &state,
                         StringRef name) {
  Region *bodyRegion = state.addRegion();
  ensureTerminator(*bodyRegion, *builder, state.location);
  StringAttr symbolName = builder->getStringAttr(name);
  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), symbolName);
}
// Parses: flow.executable @name [attributes {...}] { body }
// The body terminator may be omitted in the text; it is re-inserted here.
static ParseResult parseExecutableOp(OpAsmParser &parser,
                                     OperationState *result) {
  StringAttr symNameAttr;
  const bool headerFailed =
      failed(parser.parseSymbolName(symNameAttr,
                                    mlir::SymbolTable::getSymbolAttrName(),
                                    result->attributes)) ||
      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes));
  if (headerFailed) return failure();

  Region *bodyRegion = result->addRegion();
  if (failed(parser.parseRegion(*bodyRegion, llvm::None, llvm::None))) {
    return failure();
  }
  ExecutableOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                 result->location);
  return success();
}
// Prints: flow.executable @name [attributes {...}] { body }
static void printExecutableOp(OpAsmPrinter &p, ExecutableOp op) {
  StringRef symbolAttrName = mlir::SymbolTable::getSymbolAttrName();
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.sym_name());
  // The symbol name is printed positionally above, so keep it out of the dict.
  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                     /*elidedAttrs=*/{symbolAttrName});
  // Terminators are omitted; parseExecutableOp re-inserts them.
  p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}
// Verifies a flow.executable. No checks are performed yet; every executable
// is accepted.
static LogicalResult verifyExecutableOp(ExecutableOp op) {
  // TODO(benvanik): check export name conflicts.
  (void)op;
  return success();
}
// Parses a region terminator whose only payload is an optional attribute
// dictionary.
static ParseResult parseRegionEndOp(OpAsmParser &parser,
                                    OperationState *result) {
  auto &attrs = result->attributes;
  return parser.parseOptionalAttrDict(attrs);
}
// Prints a region terminator as its op name followed by any attributes.
static void printRegionEndOp(OpAsmPrinter &p, Operation *op) {
  auto opName = op->getName();
  p << opName;
  p.printOptionalAttrDict(op->getAttrs());
}
//===----------------------------------------------------------------------===//
// flow.dispatch.entry
//===----------------------------------------------------------------------===//
// Parses: flow.dispatch.entry @function [as("export_name")] [attributes {...}]
// Without an `as(...)` clause, `sym_name` defaults to the function name.
static ParseResult parseDispatchEntryOp(OpAsmParser &parser,
                                        OperationState *result) {
  auto &attrs = result->attributes;
  FlatSymbolRefAttr functionRefAttr;
  if (failed(parser.parseAttribute(functionRefAttr, "function_ref", attrs))) {
    return failure();
  }

  if (failed(parser.parseOptionalKeyword("as"))) {
    result->addAttribute("sym_name", parser.getBuilder().getStringAttr(
                                         functionRefAttr.getValue()));
  } else {
    StringAttr exportNameAttr;
    if (failed(parser.parseLParen()) ||
        failed(parser.parseAttribute(exportNameAttr, "sym_name", attrs)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
  }
  return parser.parseOptionalAttrDictWithKeyword(attrs);
}
// Prints: flow.dispatch.entry @function [as("export_name")] [attributes {...}]
// The `as(...)` clause is emitted only when the export name differs from the
// referenced function name, which is the default the parser falls back to.
static void printDispatchEntryOp(OpAsmPrinter &p, DispatchEntryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.function_ref());
  if (op.sym_name() != op.function_ref()) {
    // The parser reads the export name back as a string attribute. Printing it
    // through the attribute printer escapes quotes, backslashes and
    // non-printable characters; raw `"` + name + `"` output did not round-trip.
    p << " as(";
    p.printAttribute(op.getAttr("sym_name"));
    p << ")";
  }
  p.printOptionalAttrDictWithKeyword(
      op.getAttrs(), /*elidedAttrs=*/{"function_ref", "sym_name"});
}
//===----------------------------------------------------------------------===//
// flow.reduction.entry / flow.windowed_reduction.entry
//===----------------------------------------------------------------------===//
// Parses:
//   flow.reduction.entry @function apply(@apply_fn) [as("export_name")]
//       [attributes {...}]
// Without an `as(...)` clause, `sym_name` defaults to the function name.
static ParseResult parseReductionEntryOp(OpAsmParser &parser,
                                         OperationState *result) {
  auto &attrs = result->attributes;
  FlatSymbolRefAttr functionRefAttr;
  FlatSymbolRefAttr applyRefAttr;
  const bool headerFailed =
      failed(parser.parseAttribute(functionRefAttr, "function_ref", attrs)) ||
      failed(parser.parseKeyword("apply")) || failed(parser.parseLParen()) ||
      failed(parser.parseAttribute(applyRefAttr, "apply_ref", attrs)) ||
      failed(parser.parseRParen());
  if (headerFailed) return failure();

  if (failed(parser.parseOptionalKeyword("as"))) {
    result->addAttribute("sym_name", parser.getBuilder().getStringAttr(
                                         functionRefAttr.getValue()));
  } else {
    StringAttr exportNameAttr;
    if (failed(parser.parseLParen()) ||
        failed(parser.parseAttribute(exportNameAttr, "sym_name", attrs)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
  }
  return parser.parseOptionalAttrDictWithKeyword(attrs);
}
// Prints:
//   flow.reduction.entry @function apply(@apply_fn) [as("export_name")]
//       [attributes {...}]
// `as(...)` is emitted only when the export name differs from the function
// name, which is the parser's default.
static void printReductionEntryOp(OpAsmPrinter &p, ReductionEntryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.function_ref());
  p << " apply(";
  p.printSymbolName(op.apply_ref());
  p << ")";
  if (op.sym_name() != op.function_ref()) {
    // Print via the attribute printer so the string is escaped and re-parses
    // as the same string attribute.
    p << " as(";
    p.printAttribute(op.getAttr("sym_name"));
    p << ")";
  }
  p.printOptionalAttrDictWithKeyword(
      op.getAttrs(), /*elidedAttrs=*/{"apply_ref", "function_ref", "sym_name"});
}
// flow.windowed_reduction.entry has the same textual form as
// flow.reduction.entry, so parsing is delegated.
static ParseResult parseWindowedReductionEntryOp(OpAsmParser &parser,
                                                 OperationState *result) {
  ParseResult parsed = parseReductionEntryOp(parser, result);
  return parsed;
}
// Prints:
//   flow.windowed_reduction.entry @function apply(@apply_fn)
//       [as("export_name")] [attributes {...}]
// `as(...)` is emitted only when the export name differs from the function
// name, which is the parser's default.
static void printWindowedReductionEntryOp(OpAsmPrinter &p,
                                          WindowedReductionEntryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.function_ref());
  p << " apply(";
  p.printSymbolName(op.apply_ref());
  p << ")";
  if (op.sym_name() != op.function_ref()) {
    // Print via the attribute printer so the string is escaped and re-parses
    // as the same string attribute.
    p << " as(";
    p.printAttribute(op.getAttr("sym_name"));
    p << ")";
  }
  p.printOptionalAttrDictWithKeyword(
      op.getAttrs(), /*elidedAttrs=*/{"apply_ref", "function_ref", "sym_name"});
}
//===----------------------------------------------------------------------===//
// flow.dispatch
//===----------------------------------------------------------------------===//
// Parses the custom form of flow.dispatch:
//   flow.dispatch @executable::@entry_point[%workload : type](%operands...)
//       [{attrs}] : (operand-types...) -> result-types
// Operands are appended as [workload, operands...]. Result types come from
// the trailing function type.
static ParseResult parseDispatchOp(OpAsmParser &parser,
OperationState *result) {
// Used as the error location when resolving the operand list below.
auto executableLoc = parser.getNameLoc();
// TODO(benvanik): replace with SymbolRefAttr.
StringAttr executableAttr;
StringAttr entryPointAttr;
if (failed(parser.parseSymbolName(executableAttr, "executable",
result->attributes)) ||
failed(parser.parseColon()) || failed(parser.parseColon()) ||
failed(parser.parseSymbolName(entryPointAttr, "entry_point",
result->attributes))) {
return failure();
}
// parseSymbolName stored plain StringAttrs; overwrite them in place with
// symbol references.
// NOTE(review): indexing [0]/[1] assumes result->attributes was empty on
// entry, so "executable" and "entry_point" are the first two entries — confirm.
result->attributes[0].second =
parser.getBuilder().getSymbolRefAttr(executableAttr.getValue());
result->attributes[1].second =
parser.getBuilder().getSymbolRefAttr(entryPointAttr.getValue());
OpAsmParser::OperandType workloadArg;
Type workloadArgType;
if (failed(parser.parseLSquare()) ||
failed(parser.parseOperand(workloadArg)) ||
failed(parser.parseColonType(workloadArgType)) ||
failed(parser.parseRSquare()) ||
failed(parser.resolveOperand(workloadArg, workloadArgType,
result->operands))) {
return failure();
}
SmallVector<OpAsmParser::OperandType, 4> operands;
FunctionType entryPointType;
// The parenthesized operands are typed by the inputs of the trailing
// function type; its results become the op's result types.
if (failed(
parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) ||
failed(parser.parseOptionalAttrDict(result->attributes)) ||
failed(parser.parseColonType(entryPointType)) ||
failed(
parser.addTypesToList(entryPointType.getResults(), result->types)) ||
failed(parser.resolveOperands(operands, entryPointType.getInputs(),
executableLoc, result->operands))) {
return failure();
}
return success();
}
// Prints a flow.dispatch as:
//   flow.dispatch @executable::@entry_point[%workload : type](%operands...)
//       [{attrs}] : (operand-types...) -> result-types
static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) {
  Value workload = op.workload();
  p << op.getOperationName() << ' ';
  // TODO(benvanik): replace with SymbolRefAttr.
  p.printSymbolName(op.executable());
  p << "::";
  p.printSymbolName(op.entry_point());

  p << "[";
  p.printOperand(workload);
  p << " : ";
  p.printType(workload->getType());
  p << "](";
  p.printOperands(op.operands());
  p << ')';

  // executable/entry_point are printed positionally above.
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{"executable", "entry_point"});
  p << " : ";
  p.printType(op.getEntryPointType());
}
// Returns the function type implied by this dispatch: the types of the
// dispatched operands (excluding the workload) mapped to the result types.
FunctionType DispatchOp::getEntryPointType() {
  SmallVector<Type, 8> inputTypes;
  for (auto operand : operands()) inputTypes.push_back(operand->getType());
  SmallVector<Type, 4> outputTypes(getResultTypes());
  return FunctionType::get(inputTypes, outputTypes, getContext());
}
//===----------------------------------------------------------------------===//
// flow.tensor.reshape
//===----------------------------------------------------------------------===//
// Parses: flow.tensor.reshape %source : source-type -> result-type
//             [attributes {...}]
static ParseResult parseTensorReshapeOp(OpAsmParser &parser,
                                        OperationState *result) {
  OpAsmParser::OperandType source;
  ShapedType fromType;
  ShapedType toType;
  if (failed(parser.parseOperand(source))) return failure();
  if (failed(parser.parseColonType(fromType))) return failure();
  if (failed(parser.parseArrow())) return failure();
  if (failed(parser.parseType(toType))) return failure();
  if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  if (failed(parser.resolveOperand(source, fromType, result->operands))) {
    return failure();
  }
  result->addTypes({toType});
  return success();
}
// Prints: flow.tensor.reshape %source : source-type -> result-type
//             [attributes {...}]
static void printTensorReshapeOp(OpAsmPrinter &p, TensorReshapeOp &op) {
  auto source = op.source();
  p << op.getOperationName() << ' ';
  p.printOperand(source);
  p << " : ";
  p.printType(source->getType());
  p << " -> ";
  p.printType(op.result()->getType());
  p.printOptionalAttrDictWithKeyword(op.getAttrs());
}
//===----------------------------------------------------------------------===//
// flow.tensor.load
//===----------------------------------------------------------------------===//
// Parses: flow.tensor.load %source[%i0, ...] : tensor-type [attributes {...}]
// Indices are optional and are resolved as i32. The result is the tensor's
// element type.
static ParseResult parseTensorLoadOp(OpAsmParser &parser,
                                     OperationState *result) {
  OpAsmParser::OperandType source;
  SmallVector<OpAsmParser::OperandType, 4> indices;
  ShapedType tensorType;

  // Syntax phase.
  if (failed(parser.parseOperand(source)) ||
      failed(parser.parseOperandList(indices,
                                     OpAsmParser::Delimiter::OptionalSquare)) ||
      failed(parser.parseColonType(tensorType)) ||
      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }

  // Resolution phase: source first, then indices.
  Type indexType = parser.getBuilder().getIntegerType(32);
  if (failed(parser.resolveOperand(source, tensorType, result->operands)) ||
      failed(parser.resolveOperands(indices, indexType, result->operands))) {
    return failure();
  }
  result->addTypes({tensorType.getElementType()});
  return success();
}
// Prints: flow.tensor.load %source[%i0, ...] : tensor-type [attributes {...}]
// The index brackets are omitted entirely when there are no indices.
static void printTensorLoadOp(OpAsmPrinter &p, TensorLoadOp &op) {
  auto indices = op.indices();
  p << op.getOperationName() << ' ';
  p.printOperand(op.source());
  if (!indices.empty()) {
    p << '[';
    p.printOperands(indices);
    p << ']';
  }
  p << " : ";
  p.printType(op.source()->getType());
  p.printOptionalAttrDictWithKeyword(op.getAttrs());
}
//===----------------------------------------------------------------------===//
// flow.tensor.store
//===----------------------------------------------------------------------===//
// Parses: flow.tensor.store %value, %target[%i0, ...] : tensor-type
//             [attributes {...}]
// The value takes the tensor's element type, indices are i32, and the result
// has the tensor type.
static ParseResult parseTensorStoreOp(OpAsmParser &parser,
                                      OperationState *result) {
  OpAsmParser::OperandType value;
  OpAsmParser::OperandType target;
  SmallVector<OpAsmParser::OperandType, 4> indices;
  ShapedType targetType;

  // Syntax phase.
  if (failed(parser.parseOperand(value)) || failed(parser.parseComma()) ||
      failed(parser.parseOperand(target)) ||
      failed(parser.parseOperandList(indices,
                                     OpAsmParser::Delimiter::OptionalSquare)) ||
      failed(parser.parseColonType(targetType)) ||
      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }

  // Resolution phase: value, then target, then indices.
  Type indexType = parser.getBuilder().getIntegerType(32);
  if (failed(parser.resolveOperand(value, targetType.getElementType(),
                                   result->operands)) ||
      failed(parser.resolveOperand(target, targetType, result->operands)) ||
      failed(parser.resolveOperands(indices, indexType, result->operands))) {
    return failure();
  }
  result->addTypes({targetType});
  return success();
}
// Prints: flow.tensor.store %value, %target[%i0, ...] : tensor-type
//             [attributes {...}]
// The index brackets are omitted entirely when there are no indices.
static void printTensorStoreOp(OpAsmPrinter &p, TensorStoreOp &op) {
  auto indices = op.indices();
  p << op.getOperationName() << ' ';
  p.printOperand(op.value());
  p << ", ";
  p.printOperand(op.target());
  if (!indices.empty()) {
    p << '[';
    p.printOperands(indices);
    p << ']';
  }
  p << " : ";
  p.printType(op.target()->getType());
  p.printOptionalAttrDictWithKeyword(op.getAttrs());
}
//===----------------------------------------------------------------------===//
// flow.tensor.splat
//===----------------------------------------------------------------------===//
// Parses: flow.tensor.splat %value : tensor-type [attributes {...}]
// The scalar operand is resolved with the tensor's element type.
static ParseResult parseTensorSplatOp(OpAsmParser &parser,
                                      OperationState *result) {
  OpAsmParser::OperandType scalar;
  ShapedType splatType;
  if (failed(parser.parseOperand(scalar))) return failure();
  if (failed(parser.parseColonType(splatType))) return failure();
  if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  if (failed(parser.resolveOperand(scalar, splatType.getElementType(),
                                   result->operands))) {
    return failure();
  }
  result->addTypes({splatType});
  return success();
}
// Prints: flow.tensor.splat %value : tensor-type [attributes {...}]
static void printTensorSplatOp(OpAsmPrinter &p, TensorSplatOp &op) {
  Type splatType = op.result()->getType();
  p << op.getOperationName() << ' ';
  p.printOperand(op.value());
  p << " : ";
  p.printType(splatType);
  p.printOptionalAttrDictWithKeyword(op.getAttrs());
}
//===----------------------------------------------------------------------===//
// flow.tensor.clone
//===----------------------------------------------------------------------===//
// Parses `%operand : <shaped type>` followed by an optional keyword-prefixed
// attribute dictionary. The single printed type serves as both the operand
// type and the result type.
static ParseResult parseTensorCloneOp(OpAsmParser &parser,
                                      OperationState *result) {
  OpAsmParser::OperandType source;
  ShapedType tensorType;

  if (failed(parser.parseOperand(source))) {
    return failure();
  }
  if (failed(parser.parseColonType(tensorType))) {
    return failure();
  }
  if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }
  if (failed(parser.resolveOperand(source, tensorType, result->operands))) {
    return failure();
  }

  result->types.push_back(tensorType);
  return success();
}
// Prints a flow.tensor.clone op as
//   <op-name> %operand : <result type> <keyword attr-dict>
static void printTensorCloneOp(OpAsmPrinter &p, TensorCloneOp &op) {
  auto resultType = op.result()->getType();
  p << op.getOperationName() << ' ';
  p.printOperand(op.operand());
  p << " : " << resultType;
  p.printOptionalAttrDictWithKeyword(op.getAttrs());
}
//===----------------------------------------------------------------------===//
// flow.tensor.slice
//===----------------------------------------------------------------------===//
// Parses
//   %source[%i0, ... for %l0, ...] : <source type> -> <result type>
// followed by an optional keyword-prefixed attribute dictionary.
// Start indices and lengths are both resolved as i32 values.
static ParseResult parseTensorSliceOp(OpAsmParser &parser,
                                      OperationState *result) {
  OpAsmParser::OperandType source;
  SmallVector<OpAsmParser::OperandType, 4> startIndices;
  SmallVector<OpAsmParser::OperandType, 4> lengths;
  ShapedType sourceType;
  ShapedType resultType;

  // Source operand and the opening bracket of the slice range.
  if (failed(parser.parseOperand(source)) || failed(parser.parseLSquare())) {
    return failure();
  }
  // `start-indices for lengths]` with no delimiters around either list.
  if (failed(parser.parseOperandList(startIndices,
                                     OpAsmParser::Delimiter::None)) ||
      failed(parser.parseKeyword("for")) ||
      failed(parser.parseOperandList(lengths,
                                     OpAsmParser::Delimiter::None)) ||
      failed(parser.parseRSquare())) {
    return failure();
  }
  // `: source-type -> result-type` and the trailing attributes.
  if (failed(parser.parseColonType(sourceType)) ||
      failed(parser.parseArrow()) || failed(parser.parseType(resultType)) ||
      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }

  // Resolution order fixes the operand order: source, start indices, lengths.
  auto indexType = parser.getBuilder().getIntegerType(32);
  if (failed(parser.resolveOperand(source, sourceType, result->operands)) ||
      failed(parser.resolveOperands(startIndices, indexType,
                                    result->operands)) ||
      failed(parser.resolveOperands(lengths, indexType, result->operands))) {
    return failure();
  }

  result->types.push_back(resultType);
  return success();
}
// Prints a flow.tensor.slice op as
//   <op-name> %source[%i0, ... for %l0, ...] : <source type> -> <result type>
// followed by the keyword-prefixed attribute dictionary, if any.
static void printTensorSliceOp(OpAsmPrinter &p, TensorSliceOp &op) {
  auto source = op.source();
  p << op.getOperationName() << ' ';
  p.printOperand(source);

  // Slice range: start indices, the `for` keyword, then lengths.
  p << '[';
  p.printOperands(op.start_indices());
  p << " for ";
  p.printOperands(op.lengths());

  p << "] : " << source->getType() << " -> " << op.result()->getType();
  p.printOptionalAttrDictWithKeyword(op.getAttrs());
}
//===----------------------------------------------------------------------===//
// flow.tensor.update
//===----------------------------------------------------------------------===//
// Parses
//   %update, %target[%i0, ...] : <update type> -> <target type>
// followed by an optional keyword-prefixed attribute dictionary.
// The square brackets around the start indices are required. The indices are
// resolved as i32, and the result takes the target type.
static ParseResult parseTensorUpdateOp(OpAsmParser &parser,
                                       OperationState *result) {
  OpAsmParser::OperandType update;
  OpAsmParser::OperandType target;
  SmallVector<OpAsmParser::OperandType, 4> startIndices;
  ShapedType updateType;
  ShapedType targetType;

  // `%update, %target[...]`
  if (failed(parser.parseOperand(update)) || failed(parser.parseComma()) ||
      failed(parser.parseOperand(target)) ||
      failed(parser.parseOperandList(startIndices,
                                     OpAsmParser::Delimiter::Square))) {
    return failure();
  }
  // `: update-type -> target-type` and the trailing attributes.
  if (failed(parser.parseColonType(updateType)) ||
      failed(parser.parseArrow()) || failed(parser.parseType(targetType)) ||
      failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
    return failure();
  }

  // Resolution order fixes the operand order: update, target, start indices.
  auto indexType = parser.getBuilder().getIntegerType(32);
  if (failed(parser.resolveOperand(update, updateType, result->operands)) ||
      failed(parser.resolveOperand(target, targetType, result->operands)) ||
      failed(parser.resolveOperands(startIndices, indexType,
                                    result->operands))) {
    return failure();
  }

  result->types.push_back(targetType);
  return success();
}
// Prints a flow.tensor.update op as
//   <op-name> %update, %target[%i0, ...] : <update type> -> <result type>
// followed by the keyword-prefixed attribute dictionary, if any. The brackets
// are always printed, even when there are no start indices.
static void printTensorUpdateOp(OpAsmPrinter &p, TensorUpdateOp &op) {
  auto update = op.update();
  p << op.getOperationName() << ' ';
  p.printOperand(update);
  p << ", ";
  p.printOperand(op.target());

  p << '[';
  p.printOperands(op.start_indices());
  p << "] : " << update->getType() << " -> " << op.result()->getType();
  p.printOptionalAttrDictWithKeyword(op.getAttrs());
}
//===----------------------------------------------------------------------===//
// flow.ex.stream.fragment
//===----------------------------------------------------------------------===//
// Builds a flow.ex.stream.fragment op from the given operands, result types
// and attributes. The body region is created empty; no block is added here.
// The |builder| argument is not used.
void ExStreamFragmentOp::build(Builder *builder, OperationState &state,
                               ArrayRef<Type> resultTypes, ValueRange operands,
                               ArrayRef<NamedAttribute> attributes) {
  state.addOperands(operands);
  state.addTypes(resultTypes);
  state.addAttributes(attributes);
  state.addRegion();
  // NOTE(review): presumably allows operands to be added/removed after
  // creation (matches the parser, which does the same) -- confirm.
  state.setOperandListToResizable();
}
// Parses
//   (%arg0 = %operand0 : type0, ...) [-> result-types] { body } [attr-dict]
// Each parenthesized entry declares an entry-block argument of the body region
// and the operand that feeds it. `()` is accepted for an empty capture list.
ParseResult parseExStreamFragmentOp(OpAsmParser &parser,
                                    OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 16> regionArgs;
  SmallVector<Type, 16> regionArgTypes;

  if (failed(parser.parseLParen())) {
    return failure();
  }

  // An immediate `)` means there are no captures. Otherwise, parse a
  // comma-separated list of `%arg = %operand : type` entries.
  bool hasCaptures = failed(parser.parseOptionalRParen());
  if (hasCaptures) {
    SmallVector<OpAsmParser::OperandType, 16> capturedOperands;
    auto capturesLoc = parser.getCurrentLocation();
    for (;;) {
      OpAsmParser::OperandType blockArg;
      OpAsmParser::OperandType capturedOperand;
      Type capturedType;
      if (failed(parser.parseRegionArgument(blockArg)) ||
          failed(parser.parseEqual()) ||
          failed(parser.parseOperand(capturedOperand)) ||
          failed(parser.parseColonType(capturedType))) {
        return failure();
      }
      regionArgs.push_back(blockArg);
      capturedOperands.push_back(capturedOperand);
      regionArgTypes.push_back(capturedType);
      if (failed(parser.parseOptionalComma())) {
        break;
      }
    }
    if (failed(parser.parseRParen())) {
      return failure();
    }
    // Each operand takes the type declared for its block argument.
    if (failed(parser.resolveOperands(capturedOperands, regionArgTypes,
                                      capturesLoc, result->operands))) {
      return failure();
    }
  }
  result->setOperandListToResizable();

  // Optional `-> type` or `-> (type, ...)` result list.
  if (failed(parser.parseOptionalArrowTypeList(result->types))) {
    return failure();
  }

  // The body region uses the block arguments declared above, and is followed
  // by an optional attribute dictionary.
  Region *body = result->addRegion();
  if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes))) {
    return failure();
  }
  if (failed(parser.parseOptionalAttrDict(result->attributes))) {
    return failure();
  }
  return success();
}
// Prints a flow.ex.stream.fragment op as
//   <op-name>(%arg0 = %operand0 : type0, ...) [-> result-types] { body }
//   [attr-dict]
// The entry-block arguments are printed inside the capture list, so the region
// is printed without them.
void printExStreamFragmentOp(OpAsmPrinter &p, ExStreamFragmentOp op) {
  p << op.getOperationName();

  // Pair each entry-block argument with the operand that feeds it.
  p << "(";
  interleaveComma(llvm::zip(op.body().front().getArguments(), op.args()), p,
                  [&](std::tuple<BlockArgument, Value> capture) {
                    auto &operand = std::get<1>(capture);
                    p << *std::get<0>(capture) << " = " << *operand << " : "
                      << operand->getType();
                  });
  p << ")";

  // Result types: omitted when there are none; parenthesized when there are
  // more than one.
  unsigned numResults = op.getNumResults();
  if (numResults != 0) {
    bool parenthesize = numResults > 1;
    p << " -> ";
    if (parenthesize) p << "(";
    interleaveComma(op.getResultTypes(), p);
    if (parenthesize) p << ")";
  }

  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{});
}
//===----------------------------------------------------------------------===//
// TableGen definitions (intentionally last)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Flow/IR/FlowOps.cpp.inc"
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir