Adding flow.tensor.* ops to handle stream-level work living outside of dispatch regions.
PiperOrigin-RevId: 283447150
diff --git a/.gitignore b/.gitignore
index 990e1ad..f2abb10 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,4 +42,3 @@
# Emacs autosaves
*~
-\#*\#
diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD
index ccbfd21..0c6568b 100644
--- a/iree/compiler/Dialect/Flow/IR/BUILD
+++ b/iree/compiler/Dialect/Flow/IR/BUILD
@@ -30,6 +30,7 @@
"FlowDialect.cpp",
"FlowEnums.cpp.inc",
"FlowOpFolders.cpp",
+ "FlowOpInterface.cpp.inc",
"FlowOps.cpp",
"FlowOps.cpp.inc",
"FlowTypes.cpp",
@@ -37,12 +38,14 @@
hdrs = [
"FlowDialect.h",
"FlowEnums.h.inc",
+ "FlowOpInterface.h.inc",
"FlowOps.h",
"FlowOps.h.inc",
"FlowTypes.h",
],
deps = [
":FlowEnumsGen",
+ ":FlowOpInterfaceGen",
":FlowOpsGen",
"//iree/compiler/Dialect",
"@llvm//:support",
@@ -70,6 +73,21 @@
)
gentbl(
+ name = "FlowOpInterfaceGen",
+ tbl_outs = [
+ ("-gen-op-interface-decls", "FlowOpInterface.h.inc"),
+ ("-gen-op-interface-defs", "FlowOpInterface.cpp.inc"),
+ ],
+ tblgen = "@local_config_mlir//:mlir-tblgen",
+ td_file = "FlowBase.td",
+ td_srcs = [
+ ":td_files",
+ "//iree/compiler/Dialect:td_files",
+ "@local_config_mlir//:StdOpsTdFiles",
+ ],
+)
+
+gentbl(
name = "FlowOpsGen",
tbl_outs = [
("-gen-op-decls", "FlowOps.h.inc"),
diff --git a/iree/compiler/Dialect/Flow/IR/FlowBase.td b/iree/compiler/Dialect/Flow/IR/FlowBase.td
index 09b0e53..741df6a 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowBase.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowBase.td
@@ -81,10 +81,51 @@
class FLOW_PureOp<string mnemonic, list<OpTrait> traits = []> :
FLOW_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
+def FLOW_StreamableOp : OpInterface<"StreamableOpInterface"> {
+ let description = [{
+ Interface for ops that can be used within a stream.
+
+ Some ops can exist both within a stream and outside of a stream. This allows
+ optimizations to place certain ops such that they are performed in a
+ synchronous (outside of a stream) or asynchronous (inside of a stream)
+ fashion.
+
+  The goal of the stream forming process is to move as many stream-compatible
+  operations as possible into streams, using non-streamed ops only as a last
+  resort. Ops that are isStreamOnly may force the creation of single-op
+  command buffers and synchronous dispatches.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+      [{Returns true if the op is a transfer operation (as defined by the HAL).}],
+ "bool", "isTransfer", (ins)
+ >,
+ InterfaceMethod<
+ [{Returns true if the op *can* be used within a stream.}],
+ "bool", "isUsableInStream", (ins)
+ >,
+ InterfaceMethod<
+ [{Returns true if the op *must* be used within a stream.}],
+ "bool", "isStreamOnly", (ins)
+ >,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// Flow dialect types
//===----------------------------------------------------------------------===//
+// TODO(benvanik): move to base?
+class Optional<Type type> : Variadic<type>;
+
+def FLOW_PrimitiveType : AnyTypeOf<[Index, AnyInteger, AnyFloat]>;
+
+// TODO(benvanik): use index here instead? need to wait for si dialect.
+def FLOW_Dim : I<32>;
+
+def FLOW_Tensor : TypeAlias<AnyRankedTensor>;
+
def FLOW_ExecutableRefAttr : AliasedSymbolRefAttr;
def FLOW_VariableRefAttr : AliasedSymbolRefAttr;
diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
index b6a7fe3..d7e5a5c 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
@@ -27,6 +27,8 @@
namespace IREE {
namespace Flow {
+#include "iree/compiler/Dialect/Flow/IR/FlowOpInterface.cpp.inc"
+
static DialectRegistration<FlowDialect> flow_dialect;
namespace {
diff --git a/iree/compiler/Dialect/Flow/IR/FlowDialect.h b/iree/compiler/Dialect/Flow/IR/FlowDialect.h
index c9760c3..f804a93 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowDialect.h
+++ b/iree/compiler/Dialect/Flow/IR/FlowDialect.h
@@ -24,6 +24,8 @@
namespace IREE {
namespace Flow {
+#include "iree/compiler/Dialect/Flow/IR/FlowOpInterface.h.inc"
+
class FlowDialect : public Dialect {
public:
explicit FlowDialect(MLIRContext *context);
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 748f0f6..80b1b2d 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
+#include <numeric>
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
@@ -126,6 +127,133 @@
results.insert<EraseUnusedVariableStoreOp>(context);
}
+//===----------------------------------------------------------------------===//
+// Tensor ops
+//===----------------------------------------------------------------------===//
+
+/// Reduces the provided multidimensional index into a flattened 1D row-major
+/// index. The |type| is expected to be statically shaped (as all constants
+/// are).
+static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index) {
+ assert(type.hasStaticShape() && "for use on statically shaped types only");
+ auto rank = type.getRank();
+ auto shape = type.getShape();
+ uint64_t valueIndex = 0;
+ uint64_t dimMultiplier = 1;
+ for (int i = rank - 1; i >= 0; --i) {
+ valueIndex += index[i] * dimMultiplier;
+ dimMultiplier *= shape[i];
+ }
+ return valueIndex;
+}
+
+OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
+ auto sourceType = source()->getType().cast<ShapedType>();
+ auto resultType = result()->getType().cast<ShapedType>();
+ if (sourceType.hasStaticShape() && sourceType == resultType) {
+ // No-op.
+ return source();
+ }
+
+ // Skip intermediate reshapes.
+ if (auto definingOp =
+ dyn_cast_or_null<TensorReshapeOp>(source()->getDefiningOp())) {
+ setOperand(definingOp.getOperand());
+ return result();
+ }
+
+ return {};
+}
+
+OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute> operands) {
+ if (auto source = operands[0].dyn_cast_or_null<ElementsAttr>()) {
+ // Load directly from the constant source tensor.
+ auto indices = operands.drop_front();
+ if (llvm::count(indices, nullptr) == 0) {
+ return source.getValue(
+ llvm::to_vector<4>(llvm::map_range(indices, [](Attribute value) {
+ return value.cast<IntegerAttr>().getValue().getZExtValue();
+ })));
+ }
+ }
+ return {};
+}
+
+OpFoldResult TensorStoreOp::fold(ArrayRef<Attribute> operands) {
+ if (!operands[0]) return {};
+ auto &value = operands[0];
+ if (auto target = operands[1].dyn_cast_or_null<ElementsAttr>()) {
+ // Store into the constant target tensor.
+ if (target.getType().getRank() == 0) {
+ return DenseElementsAttr::get(target.getType(), {value});
+ }
+ auto indices = operands.drop_front(2);
+ if (llvm::count(indices, nullptr) == 0) {
+ uint64_t offset = getFlattenedIndex(
+ target.getType(),
+ llvm::to_vector<4>(llvm::map_range(indices, [](Attribute value) {
+ return value.cast<IntegerAttr>().getValue().getZExtValue();
+ })));
+ SmallVector<Attribute, 16> newContents(target.getValues<Attribute>());
+ newContents[offset] = value;
+ return DenseElementsAttr::get(target.getType(), newContents);
+ }
+ }
+ return {};
+}
+
+OpFoldResult TensorSplatOp::fold(ArrayRef<Attribute> operands) {
+ // TODO(benvanik): only fold when shape is constant.
+ if (operands[0]) {
+ // Splat value is constant and we can fold the operation.
+ return SplatElementsAttr::get(result()->getType().cast<ShapedType>(),
+ operands[0]);
+ }
+ return {};
+}
+
+OpFoldResult TensorCloneOp::fold(ArrayRef<Attribute> operands) {
+ if (operands[0]) {
+ return operands[0];
+ }
+ // TODO(benvanik): fold if clone device placements differ.
+ return operand();
+}
+
+OpFoldResult TensorSliceOp::fold(ArrayRef<Attribute> operands) {
+ if (operands[0] && operands[1] && operands[2]) {
+ // Fully constant arguments so we can perform the slice here.
+ // TODO(benvanik): constant slice.
+ return {};
+ }
+ return {};
+}
+
+static ElementsAttr tensorUpdate(ElementsAttr update, ElementsAttr target,
+ ArrayRef<Attribute> startIndicesAttrs) {
+ // TODO(benvanik): tensor update constant folding.
+ return {};
+}
+
+OpFoldResult TensorUpdateOp::fold(ArrayRef<Attribute> operands) {
+ auto indices = operands.drop_front(2);
+ bool allIndicesConstant = llvm::count(indices, nullptr) == 0;
+ if (operands[0] && operands[1] && allIndicesConstant) {
+ // Fully constant arguments so we can perform the update here.
+ return tensorUpdate(operands[0].cast<ElementsAttr>(),
+ operands[1].cast<ElementsAttr>(), indices);
+ } else {
+ // Replace the entire tensor when the sizes match.
+ auto updateType = update()->getType().cast<ShapedType>();
+ auto targetType = target()->getType().cast<ShapedType>();
+ if (updateType.hasStaticShape() && targetType.hasStaticShape() &&
+ updateType == targetType) {
+ return update();
+ }
+ }
+ return {};
+}
+
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 455a685..f7dc689 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
@@ -627,82 +628,6 @@
}
//===----------------------------------------------------------------------===//
-// flow.dispatch
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseDispatchOp(OpAsmParser &parser,
- OperationState *result) {
- auto executableLoc = parser.getNameLoc();
-
- // TODO(benvanik): replace with SymbolRefAttr.
- StringAttr executableAttr;
- StringAttr entryPointAttr;
- if (failed(parser.parseSymbolName(executableAttr, "executable",
- result->attributes)) ||
- failed(parser.parseColon()) || failed(parser.parseColon()) ||
- failed(parser.parseSymbolName(entryPointAttr, "entry_point",
- result->attributes))) {
- return failure();
- }
- result->attributes[0].second =
- parser.getBuilder().getSymbolRefAttr(executableAttr.getValue());
- result->attributes[1].second =
- parser.getBuilder().getSymbolRefAttr(entryPointAttr.getValue());
-
- OpAsmParser::OperandType workloadArg;
- Type workloadArgType;
- if (failed(parser.parseLSquare()) ||
- failed(parser.parseOperand(workloadArg)) ||
- failed(parser.parseColonType(workloadArgType)) ||
- failed(parser.parseRSquare()) ||
- failed(parser.resolveOperand(workloadArg, workloadArgType,
- result->operands))) {
- return failure();
- }
-
- SmallVector<OpAsmParser::OperandType, 4> operands;
- FunctionType entryPointType;
- if (failed(
- parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) ||
- failed(parser.parseOptionalAttrDict(result->attributes)) ||
- failed(parser.parseColonType(entryPointType)) ||
- failed(
- parser.addTypesToList(entryPointType.getResults(), result->types)) ||
- failed(parser.resolveOperands(operands, entryPointType.getInputs(),
- executableLoc, result->operands))) {
- return failure();
- }
- return success();
-}
-
-static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) {
- p << op.getOperationName() << ' ';
- // TODO(benvanik): replace with SymbolRefAttr.
- p.printSymbolName(op.executable());
- p << "::";
- p.printSymbolName(op.entry_point());
- p << "[";
- p.printOperand(op.workload());
- p << " : ";
- p.printType(op.workload()->getType());
- p << "](";
- p.printOperands(op.operands());
- p << ')';
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{
- "executable",
- "entry_point",
- });
- p << " : ";
- p.printType(op.getEntryPointType());
-}
-
-FunctionType DispatchOp::getEntryPointType() {
- SmallVector<Type, 4> resultTypes(getResultTypes());
- SmallVector<Type, 8> argTypes(operand_type_range{operands()});
- return FunctionType::get(argTypes, resultTypes, getContext());
-}
-
-//===----------------------------------------------------------------------===//
// flow.executable
//===----------------------------------------------------------------------===//
@@ -871,6 +796,347 @@
}
//===----------------------------------------------------------------------===//
+// flow.dispatch
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDispatchOp(OpAsmParser &parser,
+ OperationState *result) {
+ auto executableLoc = parser.getNameLoc();
+
+ // TODO(benvanik): replace with SymbolRefAttr.
+ StringAttr executableAttr;
+ StringAttr entryPointAttr;
+ if (failed(parser.parseSymbolName(executableAttr, "executable",
+ result->attributes)) ||
+ failed(parser.parseColon()) || failed(parser.parseColon()) ||
+ failed(parser.parseSymbolName(entryPointAttr, "entry_point",
+ result->attributes))) {
+ return failure();
+ }
+ result->attributes[0].second =
+ parser.getBuilder().getSymbolRefAttr(executableAttr.getValue());
+ result->attributes[1].second =
+ parser.getBuilder().getSymbolRefAttr(entryPointAttr.getValue());
+
+ OpAsmParser::OperandType workloadArg;
+ Type workloadArgType;
+ if (failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(workloadArg)) ||
+ failed(parser.parseColonType(workloadArgType)) ||
+ failed(parser.parseRSquare()) ||
+ failed(parser.resolveOperand(workloadArg, workloadArgType,
+ result->operands))) {
+ return failure();
+ }
+
+ SmallVector<OpAsmParser::OperandType, 4> operands;
+ FunctionType entryPointType;
+ if (failed(
+ parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) ||
+ failed(parser.parseOptionalAttrDict(result->attributes)) ||
+ failed(parser.parseColonType(entryPointType)) ||
+ failed(
+ parser.addTypesToList(entryPointType.getResults(), result->types)) ||
+ failed(parser.resolveOperands(operands, entryPointType.getInputs(),
+ executableLoc, result->operands))) {
+ return failure();
+ }
+ return success();
+}
+
+static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) {
+ p << op.getOperationName() << ' ';
+ // TODO(benvanik): replace with SymbolRefAttr.
+ p.printSymbolName(op.executable());
+ p << "::";
+ p.printSymbolName(op.entry_point());
+ p << "[";
+ p.printOperand(op.workload());
+ p << " : ";
+ p.printType(op.workload()->getType());
+ p << "](";
+ p.printOperands(op.operands());
+ p << ')';
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{
+ "executable",
+ "entry_point",
+ });
+ p << " : ";
+ p.printType(op.getEntryPointType());
+}
+
+FunctionType DispatchOp::getEntryPointType() {
+ SmallVector<Type, 4> resultTypes(getResultTypes());
+ SmallVector<Type, 8> argTypes(operand_type_range{operands()});
+ return FunctionType::get(argTypes, resultTypes, getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.reshape
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTensorReshapeOp(OpAsmParser &parser,
+ OperationState *result) {
+ OpAsmParser::OperandType sourceOperand;
+ ShapedType sourceType;
+ ShapedType resultType;
+ if (failed(parser.parseOperand(sourceOperand)) ||
+ failed(parser.parseColonType(sourceType)) ||
+ failed(parser.parseArrow()) || failed(parser.parseType(resultType)) ||
+ failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
+ return failure();
+ }
+ if (failed(
+ parser.resolveOperand(sourceOperand, sourceType, result->operands))) {
+ return failure();
+ }
+ result->addTypes({resultType});
+ return success();
+}
+
+static void printTensorReshapeOp(OpAsmPrinter &p, TensorReshapeOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printOperand(op.source());
+ p << " : ";
+ p.printType(op.source()->getType());
+ p << " -> ";
+ p.printType(op.result()->getType());
+ p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.load
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTensorLoadOp(OpAsmParser &parser,
+ OperationState *result) {
+ OpAsmParser::OperandType sourceOperand;
+ SmallVector<OpAsmParser::OperandType, 4> indexOperands;
+ ShapedType sourceType;
+ if (failed(parser.parseOperand(sourceOperand)) ||
+ failed(parser.parseOperandList(indexOperands,
+ OpAsmParser::Delimiter::OptionalSquare)) ||
+ failed(parser.parseColonType(sourceType)) ||
+ failed(parser.parseOptionalAttrDictWithKeyword(result->attributes)) ||
+ failed(
+ parser.resolveOperand(sourceOperand, sourceType, result->operands)) ||
+ failed(parser.resolveOperands(indexOperands,
+ parser.getBuilder().getIntegerType(32),
+ result->operands))) {
+ return failure();
+ }
+ result->addTypes({sourceType.getElementType()});
+ return success();
+}
+
+static void printTensorLoadOp(OpAsmPrinter &p, TensorLoadOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printOperand(op.source());
+ if (!op.indices().empty()) {
+ p << '[';
+ p.printOperands(op.indices());
+ p << ']';
+ }
+ p << " : ";
+ p.printType(op.source()->getType());
+ p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.store
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTensorStoreOp(OpAsmParser &parser,
+ OperationState *result) {
+ OpAsmParser::OperandType valueOperand;
+ OpAsmParser::OperandType targetOperand;
+ SmallVector<OpAsmParser::OperandType, 4> indexOperands;
+ ShapedType targetType;
+ if (failed(parser.parseOperand(valueOperand)) ||
+ failed(parser.parseComma()) ||
+ failed(parser.parseOperand(targetOperand)) ||
+ failed(parser.parseOperandList(indexOperands,
+ OpAsmParser::Delimiter::OptionalSquare)) ||
+ failed(parser.parseColonType(targetType)) ||
+ failed(parser.parseOptionalAttrDictWithKeyword(result->attributes)) ||
+ failed(parser.resolveOperand(valueOperand, targetType.getElementType(),
+ result->operands)) ||
+ failed(
+ parser.resolveOperand(targetOperand, targetType, result->operands)) ||
+ failed(parser.resolveOperands(indexOperands,
+ parser.getBuilder().getIntegerType(32),
+ result->operands))) {
+ return failure();
+ }
+ result->addTypes({targetType});
+ return success();
+}
+
+static void printTensorStoreOp(OpAsmPrinter &p, TensorStoreOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printOperand(op.value());
+ p << ", ";
+ p.printOperand(op.target());
+ if (!op.indices().empty()) {
+ p << '[';
+ p.printOperands(op.indices());
+ p << ']';
+ }
+ p << " : ";
+ p.printType(op.target()->getType());
+ p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.splat
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTensorSplatOp(OpAsmParser &parser,
+ OperationState *result) {
+ OpAsmParser::OperandType valueOperand;
+ ShapedType targetType;
+ if (failed(parser.parseOperand(valueOperand)) ||
+ failed(parser.parseColonType(targetType)) ||
+ failed(parser.parseOptionalAttrDictWithKeyword(result->attributes)) ||
+ failed(parser.resolveOperand(valueOperand, targetType.getElementType(),
+ result->operands))) {
+ return failure();
+ }
+ result->addTypes({targetType});
+ return success();
+}
+
+static void printTensorSplatOp(OpAsmPrinter &p, TensorSplatOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printOperand(op.value());
+ p << " : ";
+ p.printType(op.result()->getType());
+ p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.clone
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTensorCloneOp(OpAsmParser &parser,
+ OperationState *result) {
+ OpAsmParser::OperandType operand;
+ ShapedType type;
+ if (failed(parser.parseOperand(operand)) ||
+ failed(parser.parseColonType(type)) ||
+ failed(parser.parseOptionalAttrDictWithKeyword(result->attributes)) ||
+ failed(parser.resolveOperand(operand, type, result->operands))) {
+ return failure();
+ }
+ result->addTypes({type});
+ return success();
+}
+
+static void printTensorCloneOp(OpAsmPrinter &p, TensorCloneOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printOperand(op.operand());
+ p << " : ";
+ p.printType(op.result()->getType());
+ p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.slice
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTensorSliceOp(OpAsmParser &parser,
+ OperationState *result) {
+ OpAsmParser::OperandType sourceOperand;
+ SmallVector<OpAsmParser::OperandType, 4> indexOperands;
+ SmallVector<OpAsmParser::OperandType, 4> lengthOperands;
+ ShapedType sourceType;
+ ShapedType resultType;
+ if (failed(parser.parseOperand(sourceOperand)) ||
+ failed(parser.parseLSquare()) ||
+ failed(parser.parseOperandList(indexOperands,
+ OpAsmParser::Delimiter::None)) ||
+ failed(parser.parseKeyword("for")) ||
+ failed(parser.parseOperandList(lengthOperands,
+ OpAsmParser::Delimiter::None)) ||
+ failed(parser.parseRSquare()) ||
+ failed(parser.parseColonType(sourceType)) ||
+ failed(parser.parseArrow()) || failed(parser.parseType(resultType)) ||
+ failed(parser.parseOptionalAttrDictWithKeyword(result->attributes)) ||
+ failed(
+ parser.resolveOperand(sourceOperand, sourceType, result->operands)) ||
+ failed(parser.resolveOperands(indexOperands,
+ parser.getBuilder().getIntegerType(32),
+ result->operands)) ||
+ failed(parser.resolveOperands(lengthOperands,
+ parser.getBuilder().getIntegerType(32),
+ result->operands))) {
+ return failure();
+ }
+ result->addTypes({resultType});
+ return success();
+}
+
+static void printTensorSliceOp(OpAsmPrinter &p, TensorSliceOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printOperand(op.source());
+ p << '[';
+ p.printOperands(op.start_indices());
+ p << " for ";
+ p.printOperands(op.lengths());
+ p << "] : ";
+ p.printType(op.source()->getType());
+ p << " -> ";
+ p.printType(op.result()->getType());
+ p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// flow.tensor.update
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseTensorUpdateOp(OpAsmParser &parser,
+ OperationState *result) {
+ OpAsmParser::OperandType updateOperand;
+ OpAsmParser::OperandType targetOperand;
+ SmallVector<OpAsmParser::OperandType, 4> indexOperands;
+ ShapedType updateType;
+ ShapedType targetType;
+ if (failed(parser.parseOperand(updateOperand)) ||
+ failed(parser.parseComma()) ||
+ failed(parser.parseOperand(targetOperand)) ||
+ failed(parser.parseOperandList(indexOperands,
+ OpAsmParser::Delimiter::Square)) ||
+ failed(parser.parseColonType(updateType)) ||
+ failed(parser.parseArrow()) || failed(parser.parseType(targetType)) ||
+ failed(parser.parseOptionalAttrDictWithKeyword(result->attributes)) ||
+ failed(
+ parser.resolveOperand(updateOperand, updateType, result->operands)) ||
+ failed(
+ parser.resolveOperand(targetOperand, targetType, result->operands)) ||
+ failed(parser.resolveOperands(indexOperands,
+ parser.getBuilder().getIntegerType(32),
+ result->operands))) {
+ return failure();
+ }
+ result->addTypes({targetType});
+ return success();
+}
+
+static void printTensorUpdateOp(OpAsmPrinter &p, TensorUpdateOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printOperand(op.update());
+ p << ", ";
+ p.printOperand(op.target());
+ p << '[';
+ p.printOperands(op.start_indices());
+ p << "] : ";
+ p.printType(op.update()->getType());
+ p << " -> ";
+ p.printType(op.result()->getType());
+ p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
// flow.ex.stream.fragment
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 3973c7d..cf38705 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -210,8 +210,8 @@
let arguments = (ins
FLOW_Workload:$workload,
- Variadic<AnyType>:$operands,
- Variadic<AnyType>:$initial_values,
+ Variadic<FLOW_Tensor>:$operands,
+ Variadic<FLOW_Tensor>:$initial_values,
// TODO(benvanik): use index types instead of i32.
I32ElementsAttr:$window_dimensions,
I32ElementsAttr:$window_strides,
@@ -220,7 +220,7 @@
FLOW_PaddingModeAttr:$padding_mode
);
let results = (outs
- Variadic<AnyType>:$results
+ Variadic<FLOW_Tensor>:$results
);
let regions = (region AnyRegion:$body);
@@ -264,48 +264,6 @@
}
//===----------------------------------------------------------------------===//
-// Dispatch ops
-//===----------------------------------------------------------------------===//
-
-def FLOW_DispatchOp : FLOW_PureOp<"dispatch"> {
- let summary = [{a dispatch to an outlined dispatch region}];
- let description = [{
- Dispatches a workload to the specified executable function.
- }];
-
- let arguments = (ins
- // TODO(benvanik): replace with SymbolRefAttr.
- // TODO(benvanik): validate target is an executable.
- FlatSymbolRefAttr:$executable,
- FlatSymbolRefAttr:$entry_point,
- FLOW_Workload:$workload,
- Variadic<AnyType>:$operands
- );
- let results = (outs
- Variadic<AnyType>:$results
- );
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<[{
- Builder *builder, OperationState &result, StringRef executable,
- StringRef entryPoint, Value *workload,
- ArrayRef<Type> results, ArrayRef<Value *> operands = {}
- }], [{
- result.addOperands({workload});
- result.addOperands(operands);
- result.addAttribute("executable", builder->getSymbolRefAttr(executable));
- result.addAttribute("entry_point", builder->getSymbolRefAttr(entryPoint));
- result.addTypes(results);
- }]>,
- ];
-
- let extraClassDeclaration = [{
- FunctionType getEntryPointType();
- }];
-}
-
-//===----------------------------------------------------------------------===//
// Executables for outlined regions
//===----------------------------------------------------------------------===//
@@ -432,12 +390,244 @@
}
//===----------------------------------------------------------------------===//
+// Dispatch ops
+//===----------------------------------------------------------------------===//
+
+def FLOW_DispatchOp : FLOW_PureOp<"dispatch", [
+ FLOW_StreamableOp,
+ ]> {
+ let summary = [{a dispatch to an outlined dispatch region}];
+ let description = [{
+ Dispatches a workload to the specified executable function.
+ }];
+
+ let arguments = (ins
+ // TODO(benvanik): replace with SymbolRefAttr.
+ // TODO(benvanik): validate target is an executable.
+ FlatSymbolRefAttr:$executable,
+ FlatSymbolRefAttr:$entry_point,
+ FLOW_Workload:$workload,
+ Variadic<AnyType>:$operands
+ );
+ let results = (outs
+ Variadic<AnyType>:$results
+ );
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<[{
+ Builder *builder, OperationState &result, StringRef executable,
+ StringRef entryPoint, Value *workload,
+ ArrayRef<Type> results, ArrayRef<Value *> operands = {}
+ }], [{
+ result.addOperands({workload});
+ result.addOperands(operands);
+ result.addAttribute("executable", builder->getSymbolRefAttr(executable));
+ result.addAttribute("entry_point", builder->getSymbolRefAttr(entryPoint));
+ result.addTypes(results);
+ }]>,
+ ];
+
+ let extraClassDeclaration = [{
+ FunctionType getEntryPointType();
+
+ // StreamableOpInterface:
+ bool isTransfer() { return false; }
+ bool isUsableInStream() { return true; }
+ bool isStreamOnly() { return true; }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// Tensor ops
//===----------------------------------------------------------------------===//
-// TODO(benvanik): tensor casts for widening/narrowing? or rely on std?
-// TODO(benvanik): DynamicUpdateSlice-equivalent?
-// TODO(benvanik): structured control flow (if we want it here).
+def FLOW_TensorReshapeOp : FLOW_PureOp<"tensor.reshape", [
+ FLOW_StreamableOp,
+ AllElementTypesMatch<["source", "result"]>,
+ ]> {
+ let summary = [{reshapes a tensor}];
+ let description = [{
+ Reshapes a tensor to a new shape without modifying the contents.
+ }];
+
+ let arguments = (ins
+ FLOW_Tensor:$source
+ // TODO(benvanik): FLOW_Shape:$shape when supporting dynamic shapes.
+ );
+ let results = (outs
+ FLOW_Tensor:$result
+ );
+
+ let extraClassDeclaration = [{
+ // StreamableOpInterface:
+ bool isTransfer() { return true; }
+ bool isUsableInStream() { return true; }
+ // TODO(benvanik): allow out of stream to act as a shape manipulation.
+ bool isStreamOnly() { return true; }
+ }];
+
+ // TODO(benvanik): canonicalize away if resulting ops don't care.
+ let hasFolder = 1;
+}
+
+def FLOW_TensorLoadOp : FLOW_PureOp<"tensor.load"> {
+ let summary = [{loads a value from a tensor element}];
+ let description = [{
+ Returns the element at the given location from within the tensor.
+ }];
+
+ let arguments = (ins
+ FLOW_Tensor:$source,
+ Variadic<FLOW_Dim>:$indices
+ );
+ let results = (outs
+ AnyTypeOf<[FLOW_PrimitiveType, AnyVector]>:$result
+ );
+
+ // TODO(benvanik): canonicalize to slice+load if dims are known.
+ let hasFolder = 1;
+}
+
+def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store"> {
+ let summary = [{stores a value into a tensor element}];
+ let description = [{
+ Returns a tensor with the element at the given index set to the given value.
+ }];
+
+ let arguments = (ins
+ AnyTypeOf<[FLOW_PrimitiveType, AnyVector]>:$value,
+ FLOW_Tensor:$target,
+ Variadic<FLOW_Dim>:$indices
+ );
+ let results = (outs
+ FLOW_Tensor:$result
+ );
+
+ let hasFolder = 1;
+}
+
+def FLOW_TensorSplatOp : FLOW_PureOp<"tensor.splat", [
+ FLOW_StreamableOp,
+ ]> {
+ let summary = [{splats a value into a shaped tensor}];
+ let description = [{
+ Returns a tensor initialized to the given primitive value.
+ }];
+
+ let arguments = (ins
+ FLOW_PrimitiveType:$value
+ // TODO(benvanik): FLOW_Shape:$shape when supporting dynamic shapes.
+ );
+ let results = (outs
+ FLOW_Tensor:$result
+ );
+
+ let extraClassDeclaration = [{
+ // StreamableOpInterface:
+ bool isTransfer() { return true; }
+ bool isUsableInStream() { return true; }
+ // TODO(benvanik): allow out of stream to act as a hal.buffer.fill.
+ bool isStreamOnly() { return true; }
+ }];
+
+ // TODO(benvanik): canonicalize splat+slice to smaller splat.
+ let hasFolder = 1;
+}
+
+def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [
+ FLOW_StreamableOp,
+ SameOperandsAndResultType,
+ ]> {
+ let summary = [{performs a full tensor clone operation}];
+ let description = [{
+ Clones the input tensor into an identical output tensor.
+ }];
+
+ let arguments = (ins
+ FLOW_Tensor:$operand
+ );
+ let results = (outs
+ FLOW_Tensor:$result
+ );
+
+ let extraClassDeclaration = [{
+ // StreamableOpInterface:
+ bool isTransfer() { return true; }
+ bool isUsableInStream() { return true; }
+ // TODO(benvanik): allow out of stream to act as a hal.buffer.copy.
+ bool isStreamOnly() { return true; }
+ }];
+
+ // TODO(benvanik): canonicalize away entirely in most cases.
+ let hasFolder = 1;
+}
+
+def FLOW_TensorSliceOp : FLOW_PureOp<"tensor.slice", [
+ FLOW_StreamableOp,
+ AllRanksMatch<["source", "result"]>,
+ AllElementTypesMatch<["source", "result"]>,
+ SameVariadicOperandSize,
+ ]> {
+ let summary = [{slices out a subregion of a tensor}];
+ let description = [{
+ Clones a subregion of a tensor.
+ }];
+
+ let arguments = (ins
+ FLOW_Tensor:$source,
+ Variadic<FLOW_Dim>:$start_indices,
+ Variadic<FLOW_Dim>:$lengths
+ // TODO(benvanik): strides.
+ );
+ let results = (outs
+ FLOW_Tensor:$result
+ );
+
+ let extraClassDeclaration = [{
+ // StreamableOpInterface:
+ bool isTransfer() { return true; }
+ bool isUsableInStream() { return true; }
+ // TODO(benvanik): allow out of stream to act as a hal.buffer.slice.
+ bool isStreamOnly() { return true; }
+ }];
+
+ // TODO(benvanik): canonicalize multiple slices (traverse upward through ssa).
+ let hasFolder = 1;
+}
+
+def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [
+ FLOW_StreamableOp,
+ AllRanksMatch<["update", "target", "result"]>,
+ AllShapesMatch<["target", "result"]>,
+ AllElementTypesMatch<["update", "target", "result"]>,
+ ]> {
+ let summary = [{updates a tensor with the contents of another tensor}];
+ let description = [{
+ Updates the target tensor with the contents of the update tensor at the
+ given offset indices.
+ }];
+
+ let arguments = (ins
+ FLOW_Tensor:$update,
+ FLOW_Tensor:$target,
+ Variadic<FLOW_Dim>:$start_indices
+ );
+ let results = (outs
+ FLOW_Tensor:$result
+ );
+
+ let extraClassDeclaration = [{
+ // StreamableOpInterface:
+ bool isTransfer() { return true; }
+ bool isUsableInStream() { return true; }
+ // TODO(benvanik): allow out of stream to act as a hal.buffer.copy.
+ bool isStreamOnly() { return true; }
+ }];
+
+ // TODO(benvanik): canonicalize contiguous updates/across slices.
+ let hasFolder = 1;
+}
//===----------------------------------------------------------------------===//
// Streams
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
new file mode 100644
index 0000000..681dffe
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -0,0 +1,142 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests folding and canonicalization of tensor ops.
+
+// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// Reshape to the identical type is a no-op and folds away to the operand.
+// CHECK-LABEL: @reshapeNoOp
+func @reshapeNoOp(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+ // CHECK-NEXT: return %arg0 : tensor<4x4xf32>
+ %0 = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<4x4xf32>
+ return %0 : tensor<4x4xf32>
+}
+
+// The same identity folding applies to rank-0 (scalar) tensors.
+// CHECK-LABEL: @reshapeNoOpScalar
+func @reshapeNoOpScalar(%arg0 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: return %arg0 : tensor<f32>
+ %0 = flow.tensor.reshape %arg0 : tensor<f32> -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// Back-to-back reshapes collapse into one reshape to the final shape.
+// CHECK-LABEL: @reshapeTransitive
+func @reshapeTransitive(%arg0 : tensor<4x4xf32>) -> tensor<8x2xf32> {
+ %0 = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<2x8xf32>
+ // CHECK-NEXT: [[T:%.+]] = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<8x2xf32>
+ %1 = flow.tensor.reshape %0 : tensor<2x8xf32> -> tensor<8x2xf32>
+ // CHECK-NEXT: return [[T]] : tensor<8x2xf32>
+ return %1 : tensor<8x2xf32>
+}
+
+// -----
+
+// Loading element [1, 0] of dense<[[0, 1], [2, 3]]> folds to the constant 2.
+// CHECK-LABEL: @loadConst
+func @loadConst() -> i32 {
+ %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+ %c0 = constant 0 : i32
+ %c1 = constant 1 : i32
+ // CHECK-NEXT: [[C2:%.+]] = constant 2 : i32
+ %2 = flow.tensor.load %0[%c1, %c0] : tensor<2x2xi32>
+ // CHECK-NEXT: return [[C2]]
+ return %2 : i32
+}
+
+// Loading a constant scalar tensor (no indices) folds to its splat value.
+// CHECK-LABEL: @loadConstScalar
+func @loadConstScalar() -> i32 {
+ %0 = constant dense<4> : tensor<i32>
+ // CHECK-NEXT: [[C4:%.+]] = constant 4 : i32
+ %1 = flow.tensor.load %0 : tensor<i32>
+ // CHECK-NEXT: return [[C4]]
+ return %1 : i32
+}
+
+// -----
+
+// Storing a constant into a constant tensor folds to an updated dense attr:
+// writing 4 at [1, 0] turns [[0, 1], [2, 3]] into [[0, 1], [4, 3]].
+// CHECK-LABEL: @storeConst
+func @storeConst() -> tensor<2x2xi32> {
+ %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+ %c0 = constant 0 : i32
+ %c1 = constant 1 : i32
+ %c4 = constant 4 : i32
+ // CHECK-NEXT: [[C:%.+]] = constant dense<[
+ // CHECK-SAME: [0, 1], [4, 3]
+ // CHECK-SAME: ]> : tensor<2x2xi32>
+ %1 = flow.tensor.store %c4, %0[%c1, %c0] : tensor<2x2xi32>
+ // CHECK-NEXT: return [[C]]
+ return %1 : tensor<2x2xi32>
+}
+
+// Scalar store replaces the entire (single-element) constant tensor.
+// CHECK-LABEL: @storeConstScalar
+func @storeConstScalar() -> tensor<i32> {
+ %0 = constant dense<0> : tensor<i32>
+ %1 = constant 4 : i32
+ // CHECK-NEXT: [[C:%.+]] = constant dense<4> : tensor<i32>
+ %2 = flow.tensor.store %1, %0 : tensor<i32>
+ // CHECK-NEXT: return [[C]]
+ return %2 : tensor<i32>
+}
+
+// -----
+
+// Splatting a constant scalar folds to a dense splat attribute.
+// CHECK-LABEL: @splatConst
+func @splatConst() -> tensor<4xi32> {
+ %0 = constant 4 : i32
+ // CHECK-NEXT: [[C:%.+]] = constant dense<4> : tensor<4xi32>
+ %1 = flow.tensor.splat %0 : tensor<4xi32>
+ // CHECK-NEXT: return [[C]]
+ return %1 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @splatConstScalar
+func @splatConstScalar() -> tensor<i32> {
+ %0 = constant 4 : i32
+ // CHECK-NEXT: [[C:%.+]] = constant dense<4> : tensor<i32>
+ %1 = flow.tensor.splat %0 : tensor<i32>
+ // CHECK-NEXT: return [[C]]
+ return %1 : tensor<i32>
+}
+
+// -----
+
+// Cloning a constant folds to the constant itself (tensors have value
+// semantics at this level, so the copy is unobservable).
+// CHECK-LABEL: @cloneConst
+func @cloneConst() -> tensor<4xi32> {
+ %0 = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
+ // CHECK-NEXT: [[C:%.+]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
+ %1 = flow.tensor.clone %0 : tensor<4xi32>
+ // CHECK-NEXT: return [[C]]
+ return %1 : tensor<4xi32>
+}
+
+// Cloning a non-constant value also folds away to its operand.
+// CHECK-LABEL: @cloneDynamic
+func @cloneDynamic(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
+ %0 = flow.tensor.clone %arg0 : tensor<4xi32>
+ // CHECK-NEXT: return %arg0
+ return %0 : tensor<4xi32>
+}
+
+// -----
+
+// TODO(benvanik): const folder for slice.
+
+// -----
+
+// TODO(benvanik): const folder for update.
+
+// An update covering the target's full extent folds to the update operand.
+// CHECK-LABEL: @updateReplace
+func @updateReplace(%arg0 : tensor<4xi32>, %arg1 : tensor<4xi32>) -> tensor<4xi32> {
+ %c0 = constant 0 : i32
+ %0 = flow.tensor.update %arg0, %arg1[%c0] : tensor<4xi32> -> tensor<4xi32>
+ // CHECK-NEXT: return %arg0
+ return %0 : tensor<4xi32>
+}
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
new file mode 100644
index 0000000..07added
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
@@ -0,0 +1,113 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests printing and parsing of tensor ops.
+
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// Round-trips the custom assembly of flow.tensor.reshape (rank change).
+// CHECK-LABEL: @tensorReshape
+func @tensorReshape(%arg0 : tensor<4x4xf32>) -> tensor<16xf32> {
+ // CHECK-NEXT: %0 = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<16xf32>
+ %0 = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<16xf32>
+ return %0 : tensor<16xf32>
+}
+
+// Rank-0 form: same syntax, scalar tensor types on both sides.
+// CHECK-LABEL: @tensorReshapeScalar
+func @tensorReshapeScalar(%arg0 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: %0 = flow.tensor.reshape %arg0 : tensor<f32> -> tensor<f32>
+ %0 = flow.tensor.reshape %arg0 : tensor<f32> -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// Load with bracketed per-dimension indices.
+// CHECK-LABEL: @tensorLoad
+func @tensorLoad(%arg0 : tensor<4x4xf32>, %arg1 : i32, %arg2 : i32) -> f32 {
+ // CHECK-NEXT: %0 = flow.tensor.load %arg0[%arg1, %arg2] : tensor<4x4xf32>
+ %0 = flow.tensor.load %arg0[%arg1, %arg2] : tensor<4x4xf32>
+ return %0 : f32
+}
+
+// Scalar load omits the index brackets entirely.
+// CHECK-LABEL: @tensorLoadScalar
+func @tensorLoadScalar(%arg0 : tensor<f32>) -> f32 {
+ // CHECK-NEXT: %0 = flow.tensor.load %arg0 : tensor<f32>
+ %0 = flow.tensor.load %arg0 : tensor<f32>
+ return %0 : f32
+}
+
+// -----
+
+// Store syntax: value first, then target with bracketed indices.
+// CHECK-LABEL: @tensorStore
+func @tensorStore(%arg0 : tensor<4x4xf32>, %arg1 : i32, %arg2 : i32, %arg3 : f32) -> tensor<4x4xf32> {
+ // CHECK-NEXT: %0 = flow.tensor.store %arg3, %arg0[%arg1, %arg2] : tensor<4x4xf32>
+ %0 = flow.tensor.store %arg3, %arg0[%arg1, %arg2] : tensor<4x4xf32>
+ return %0 : tensor<4x4xf32>
+}
+
+// Scalar store omits the index brackets.
+// CHECK-LABEL: @tensorStoreScalar
+func @tensorStoreScalar(%arg0 : f32, %arg1 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: %0 = flow.tensor.store %arg0, %arg1 : tensor<f32>
+ %0 = flow.tensor.store %arg0, %arg1 : tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// Splat takes a scalar primitive and only the result tensor type.
+// CHECK-LABEL: @tensorSplat
+func @tensorSplat(%arg0 : f32) -> tensor<4x4xf32> {
+ // CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor<4x4xf32>
+ %0 = flow.tensor.splat %arg0 : tensor<4x4xf32>
+ return %0 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: @tensorSplatScalar
+func @tensorSplatScalar(%arg0 : f32) -> tensor<f32> {
+ // CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor<f32>
+ %0 = flow.tensor.splat %arg0 : tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// Clone round-trips with a single operand and type.
+// CHECK-LABEL: @tensorClone
+func @tensorClone(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+ // CHECK-NEXT: %0 = flow.tensor.clone %arg0 : tensor<4x4xf32>
+ %0 = flow.tensor.clone %arg0 : tensor<4x4xf32>
+ return %0 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: @tensorCloneScalar
+func @tensorCloneScalar(%arg0 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: %0 = flow.tensor.clone %arg0 : tensor<f32>
+ %0 = flow.tensor.clone %arg0 : tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// Slice syntax: [start_indices for lengths] inside one bracket pair.
+// CHECK-LABEL: @tensorSlice
+func @tensorSlice(%arg0 : tensor<4x4xf32>, %arg1 : i32, %arg2 : i32) -> tensor<2x2xf32> {
+ // CHECK-NEXT: %0 = flow.tensor.slice %arg0[%arg1, %arg2 for %arg2, %arg1] : tensor<4x4xf32> -> tensor<2x2xf32>
+ %0 = flow.tensor.slice %arg0[%arg1, %arg2 for %arg2, %arg1] : tensor<4x4xf32> -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// Update syntax: update value first, then target with bracketed offsets;
+// types read update-type -> target/result-type.
+// CHECK-LABEL: @tensorUpdate
+func @tensorUpdate(%arg0 : tensor<2x2xf32>, %arg1 : tensor<4x4xf32>, %arg2 : i32, %arg3 : i32) -> tensor<4x4xf32> {
+ // CHECK-NEXT: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> tensor<4x4xf32>
+ %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> tensor<4x4xf32>
+ return %0 : tensor<4x4xf32>
+}