Cleaning up parsers/printers and generalizing tied operand/result handling. (#6785)
* Adding support for util.global type/initializer differences.
* Moving size-aware type interfaces and parsers/printers to util.
* Improving tied result printing.
* Allowing size-aware types in tied operand/result parsing.
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors.mlir
index d93da89..1b5ff14 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors.mlir
@@ -6,7 +6,7 @@
// TODO(silvasean): Verify "type" handling.
// I think when "type" is a partial type that flow will not model it correctly.
-// CHECK: util.global private mutable [[V:@[a-zA-Z0-9$._-]+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK: util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
// CHECK: func @f() -> (tensor<?xf32> {tf_saved_model.index_path = []})
// CHECK-NEXT: [[PTR:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
// CHECK-NEXT: [[T:%.+]] = util.global.load.indirect [[PTR]] : !util.ptr<tensor<?xf32>> -> tensor<?xf32>
@@ -26,9 +26,9 @@
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
-// CHECK: util.global private mutable [[V:@[a-zA-Z0-9$._-]+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK: util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
// CHECK: func @f(%arg0: tensor<?xf32> {tf_saved_model.index_path = [0]})
-// CHECK-NEXT: [[PTR:%.+]] = util.global.address @__iree_flow_v : !util.ptr<tensor<?xf32>>
+// CHECK-NEXT: [[PTR:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
// CHECK-NEXT: util.global.store.indirect %arg0, [[PTR]] : tensor<?xf32> -> !util.ptr<tensor<?xf32>>
// CHECK-NEXT: return
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors_complex.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors_complex.mlir
index 6e65a12..4758edd 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors_complex.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors_complex.mlir
@@ -5,7 +5,7 @@
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
-// CHECK: util.global private mutable [[V:@.+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK: util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
// CHECK: func @f(%arg0: tensor<?xf32> {tf_saved_model.index_path = [0]}) attributes {tf_saved_model.exported_names = ["f"]} {
// CHECK-NEXT: [[PTR:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
// CHECK-NEXT: br ^bb1([[PTR]] : !util.ptr<tensor<?xf32>>)
@@ -28,8 +28,8 @@
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
-// CHECK: util.global private mutable [[V:@.+]] = dense<1.000000e+00> : tensor<1xf32>
-// CHECK: util.global private mutable [[V1:@.+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK: util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
+// CHECK: util.global private mutable [[V1:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
// CHECK: func @f(%arg0: tensor<?xf32> {tf_saved_model.index_path = [0]}) -> (tensor<?xf32> {tf_saved_model.index_path = [0]}) attributes {tf_saved_model.exported_names = ["f"]} {
// CHECK-NEXT: [[PTR0:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
// CHECK-NEXT: [[PTR1:%.+]] = util.global.address [[V1]] : !util.ptr<tensor<?xf32>>
@@ -55,7 +55,7 @@
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
-// CHECK: util.global private mutable [[V:@.+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK: util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
// CHECK: func @f(%arg0: tensor<?xf32> {tf_saved_model.index_path = [0]}) attributes {tf_saved_model.exported_names = ["f"]} {
// CHECK-NEXT: [[PTR:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
// CHECK-NEXT: br ^bb1([[PTR]], [[PTR]], [[PTR]] : !util.ptr<tensor<?xf32>>, !util.ptr<tensor<?xf32>>, !util.ptr<tensor<?xf32>>)
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir
index 383624e..1c29bf4 100644
--- a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir
@@ -16,7 +16,7 @@
// CHECK-DAG: %[[C4:.+]] = constant 4
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[UPDATE:.+]] = flow.tensor.update %[[ARG1]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
-// CHECK-SAME: : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM0]]}
+// CHECK-SAME: : tensor<1x4x48xf32> -> %[[ARG0]] as tensor<?x24x48xf32>{%[[DIM0]]}
// -----
@@ -37,7 +37,7 @@
// CHECK-DAG: %[[RESHAPE:.+]] = flow.tensor.reshape %[[ARG1]] : tensor<4x48xf32> -> tensor<1x4x48xf32>
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[UPDATE:.+]] = flow.tensor.update %[[RESHAPE]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
-// CHECK-SAME: : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM]]}
+// CHECK-SAME: : tensor<1x4x48xf32> -> %[[ARG0]] as tensor<?x24x48xf32>{%[[DIM]]}
// -----
@@ -50,7 +50,7 @@
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK: %[[RESHAPE:.+]] = flow.tensor.reshape %{{.+}} : tensor<49x20xf32> -> tensor<1x49x20x1xf32>
-// CHECK: flow.tensor.update %[[RESHAPE]], %{{.+}}[%[[C0]], %[[C1]], %[[C0]], %[[C0]]] : tensor<1x49x20x1xf32> -> tensor<1x50x20x1xf32>
+// CHECK: flow.tensor.update %[[RESHAPE]], %{{.+}}[%[[C0]], %[[C1]], %[[C0]], %[[C0]]] : tensor<1x49x20x1xf32> -> %{{.+}} as tensor<1x50x20x1xf32>
// -----
diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD
index 23d7550..5466c08 100644
--- a/iree/compiler/Dialect/Flow/IR/BUILD
+++ b/iree/compiler/Dialect/Flow/IR/BUILD
@@ -27,6 +27,7 @@
deps = [
"//iree/compiler/Dialect/Shape/IR:td_files",
"//iree/compiler/Dialect/Util/IR:td_files",
+ "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
@@ -42,7 +43,6 @@
"FlowEnums.cpp.inc",
"FlowOpFolders.cpp",
"FlowOpInterfaces.cpp.inc",
- "FlowOpUtils.cpp",
"FlowOps.cpp",
"FlowOps.cpp.inc",
"FlowTypeInterfaces.cpp.inc",
@@ -53,7 +53,6 @@
"FlowDialect.h",
"FlowEnums.h.inc",
"FlowOpInterfaces.h.inc",
- "FlowOpUtils.h",
"FlowOps.h",
"FlowOps.h.inc",
"FlowTypeInterfaces.h.inc",
@@ -68,6 +67,7 @@
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:MemRefDialect",
diff --git a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
index 245efb5..5a8860a 100644
--- a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
@@ -17,7 +17,6 @@
"FlowDialect.h"
"FlowEnums.h.inc"
"FlowOpInterfaces.h.inc"
- "FlowOpUtils.h"
"FlowOps.h"
"FlowOps.h.inc"
"FlowTypeInterfaces.h.inc"
@@ -28,7 +27,6 @@
"FlowEnums.cpp.inc"
"FlowOpFolders.cpp"
"FlowOpInterfaces.cpp.inc"
- "FlowOpUtils.cpp"
"FlowOps.cpp"
"FlowOps.cpp.inc"
"FlowTypeInterfaces.cpp.inc"
@@ -40,6 +38,7 @@
::FlowOpsGen
::FlowTypesGen
LLVMSupport
+ MLIRControlFlowInterfaces
MLIRIR
MLIRInferTypeOpInterface
MLIRMemRef
diff --git a/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td b/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
index 9fc93f8..11f35e6 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
@@ -10,81 +10,6 @@
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
//===----------------------------------------------------------------------===//
-// IREE::Flow::ClosureOpInterface
-//===----------------------------------------------------------------------===//
-
-def FLOW_ClosureOpInterface : OpInterface<"ClosureOpInterface"> {
- let description = [{
- Interface for ops that follow the Flow dialect closure semantics (explicit
- captures, dynamic-shape awareness, and normal operand/result SSA behavior).
-
- Implementing this interface enables optimizations that perform manipulation
- across the closure capture boundary (outside of the op <-> regions within
- the op).
- }];
-
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Returns the body region of the closure (may have multiple blocks).
- }],
- /*retTy=*/"Region &",
- /*methodName=*/"getClosureBodyRegion",
- /*args=*/(ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- return this->getOperation()->getRegion(0);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{Returns all closure operand values.}],
- /*retTy=*/"Operation::operand_range",
- /*methodName=*/"getClosureOperands",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*desc=*/[{Returns all closure result values.}],
- /*retTy=*/"Operation::result_range",
- /*methodName=*/"getClosureResults",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns true if the given operation can exist in the closure.
- Not all operations that a closure can contain are guaranteed to be folded
- into the closure, such as when the operation may have side-effects.
- }],
- /*retTy=*/"bool",
- /*methodName=*/"canClosureContainOp",
- /*args=*/(ins "Operation *":$op)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Clones the op while removing specified operands and results.
- The body of the op will be transferred to the new op and the entry block
- will have its arguments removed.
-
- The returned op will be free standing. Callers must insert it into a block
- where desired (most often just replacing the current op).
- }],
- /*retTy=*/"ClosureOpInterface",
- /*methodName=*/"cloneReplacementExcludingOperandsAndResults",
- /*args=*/(ins "ArrayRef<unsigned>":$excludedOperandIndices,
- "ArrayRef<unsigned>":$excludedResultIndices,
- "PatternRewriter &":$rewriter)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns true if the output is also read within the region.
- }],
- /*retTy=*/"bool",
- /*methodName=*/"isOutputReadWithinRegion",
- /*args=*/(ins "unsigned":$resultIndex)
- >
- ];
-}
-
-//===----------------------------------------------------------------------===//
// IREE::Flow::StreamableOpInterface
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index c4c8b5f..4753b8b 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -8,9 +8,9 @@
#include <numeric>
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Optional.h"
@@ -228,7 +228,8 @@
void ExStreamFragmentOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ClosureOptimizationPattern<ExStreamFragmentOp>>(context);
+ results.insert<IREE::Util::ClosureOptimizationPattern<ExStreamFragmentOp>>(
+ context);
results.insert<InsertImmutabilityPreservingStreamClones>(context);
// TODO(#6420): fix HAL lowering of this (or wait until streams are gone).
// results.insert<TieStreamResults>(context);
@@ -240,7 +241,8 @@
void DispatchWorkgroupsOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ClosureOptimizationPattern<DispatchWorkgroupsOp>>(context);
+ results.insert<IREE::Util::ClosureOptimizationPattern<DispatchWorkgroupsOp>>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 501ad88..1d8a835 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -6,8 +6,8 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
+#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringExtras.h"
@@ -41,13 +41,6 @@
// Op utilities used within the Flow dialect
//===----------------------------------------------------------------------===//
-// Returns true if the given |accessType| is compatible with the |variableType|.
-// For example, this will return true if the variable type is a tensor<?xf32>
-// and the access is tensor<4xf32>.
-static bool isVariableTypeCompatible(Type variableType, Type accessType) {
- return succeeded(mlir::verifyCompatibleShape(variableType, accessType));
-}
-
// Verifies that |dynamicDims| contains the appropriate number of dims for all
// of the dynamic dimensions in |values|.
static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
@@ -68,241 +61,6 @@
}
//===----------------------------------------------------------------------===//
-// custom<TiedResult>
-//===----------------------------------------------------------------------===//
-// type{%dim0, %dim1}
-// %arg0
-
-static ParseResult parseTiedResult(
- OpAsmParser &parser, Type &resultType,
- SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
- ArrayAttr &tiedOperands) {
- if (failed(parser.parseType(resultType))) return failure();
- if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
- if (failed(parser.parseLBrace()) ||
- failed(parser.parseOperandList(dynamicDims,
- shapedType.getNumDynamicDims(),
- OpAsmParser::Delimiter::None)) ||
- failed(parser.parseRBrace())) {
- return failure();
- }
- resultDims.append(dynamicDims);
- }
- }
- tiedOperands = parser.getBuilder().getIndexArrayAttr({0});
- return success();
-}
-
-static void printTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
- ValueRange resultDims, ArrayAttr tiedOperands) {
- p.printType(resultType);
- if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- if (resultDims.empty()) {
- p << "{<<INVALID>>}";
- return;
- }
- p << "{";
- llvm::interleaveComma(
- resultDims.take_front(shapedType.getNumDynamicDims()), p,
- [&](Value value) { p.printOperand(value); });
- p << "}";
- }
- }
-}
-
-//===----------------------------------------------------------------------===//
-// custom<ShapedFunctionType>
-//===----------------------------------------------------------------------===//
-// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)
-
-static ParseResult parseShapedOperandList(
- OpAsmParser &parser, SmallVectorImpl<Type> &types,
- SmallVectorImpl<OpAsmParser::OperandType> &dims) {
- do {
- Type type;
- if (failed(parser.parseType(type))) return failure();
- if (auto shapedType = type.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
- if (failed(parser.parseLBrace()) ||
- failed(parser.parseOperandList(dynamicDims,
- shapedType.getNumDynamicDims(),
- OpAsmParser::Delimiter::None)) ||
- failed(parser.parseRBrace())) {
- return failure();
- }
- dims.append(dynamicDims);
- }
- }
- types.push_back(type);
- } while (succeeded(parser.parseOptionalComma()));
- return success();
-}
-
-// Finds the operand index in |operands| that |tiedResult| references.
-// Returns TiedOpInterface::kUntiedIndex if no operand is found.
-static int64_t findTiedOperand(OpAsmParser::OperandType tiedResult,
- ArrayRef<OpAsmParser::OperandType> operands) {
- int64_t operandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
- for (int64_t i = 0; i < operands.size(); ++i) {
- if (operands[i].name == tiedResult.name) {
- operandIndex = i;
- break;
- }
- }
- return operandIndex;
-}
-
-static ParseResult parseShapedResultList(
- OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
- TypeRange operandTypes, ArrayRef<OpAsmParser::OperandType> operandDims,
- SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
- ArrayAttr &tiedOperands) {
- SmallVector<int64_t, 4> tiedOperandIndices;
- do {
- OpAsmParser::OperandType tiedResult;
- auto res = parser.parseOptionalOperand(tiedResult);
- Type type;
- int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
- if (res.hasValue() && succeeded(res.getValue())) {
- tiedOperandIndex = findTiedOperand(tiedResult, operands);
- if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) {
- return parser.emitError(tiedResult.location,
- "tied operand not found for result reference ")
- << tiedResult.name;
- }
- if (succeeded(parser.parseOptionalKeyword("as"))) {
- // Type _may_ differ from the operand.
- if (failed(parser.parseType(type))) return failure();
- } else {
- // Use the operands type.
- type = operandTypes[tiedOperandIndex];
- }
- } else if (failed(parser.parseType(type))) {
- return failure();
- }
- if (auto shapedType = type.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
- if (failed(parser.parseLBrace()) ||
- failed(parser.parseOperandList(dynamicDims,
- shapedType.getNumDynamicDims(),
- OpAsmParser::Delimiter::None)) ||
- failed(parser.parseRBrace())) {
- return failure();
- }
- resultDims.append(dynamicDims);
- }
- }
- resultTypes.push_back(type);
- tiedOperandIndices.push_back(tiedOperandIndex);
- } while (succeeded(parser.parseOptionalComma()));
- if (!tiedOperandIndices.empty()) {
- tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices);
- }
- return success();
-}
-
-static ParseResult parseShapedFunctionType(
- OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
- SmallVectorImpl<Type> &operandTypes,
- SmallVectorImpl<OpAsmParser::OperandType> &operandDims,
- SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
- ArrayAttr &tiedOperands) {
- if (failed(parser.parseLParen())) return failure();
- if (failed(parser.parseOptionalRParen())) {
- if (failed(parseShapedOperandList(parser, operandTypes, operandDims)) ||
- failed(parser.parseRParen())) {
- return failure();
- }
- }
- if (failed(parser.parseArrow())) return failure();
- if (succeeded(parser.parseOptionalLParen())) {
- if (failed(parseShapedResultList(parser, operands, operandTypes,
- operandDims, resultTypes, resultDims,
- tiedOperands)) ||
- failed(parser.parseRParen())) {
- return failure();
- }
- } else {
- if (failed(parseShapedResultList(parser, operands, operandTypes,
- operandDims, resultTypes, resultDims,
- tiedOperands))) {
- return failure();
- }
- }
- return success();
-}
-
-static void printShapedFunctionType(OpAsmPrinter &p, Operation *op,
- ValueRange operands, TypeRange operandTypes,
- OperandRange operandDims,
- TypeRange resultTypes,
- OperandRange resultDims,
- ArrayAttr tiedOperands) {
- p << "(";
- llvm::interleaveComma(operandTypes, p, [&](Type type) {
- p.printType(type);
- if (auto shapedType = type.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- if (operandDims.empty()) {
- p << "{<<INVALID>>}";
- return;
- }
- p << "{";
- llvm::interleaveComma(
- operandDims.take_front(shapedType.getNumDynamicDims()), p,
- [&](Value value) { p.printOperand(value); });
- p << "}";
- operandDims = operandDims.drop_front(shapedType.getNumDynamicDims());
- }
- }
- });
- p << ") -> ";
- if (resultTypes.size() != 1) p << "(";
- auto tiedOp = cast<IREE::Util::TiedOpInterface>(op);
- for (unsigned i = 0; i < resultTypes.size(); ++i) {
- auto resultType = resultTypes[i];
- auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i);
- bool printType = true;
- if (tiedOperandIndex.hasValue()) {
- auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
- p.printOperand(tiedOperand);
- if (tiedOperand.getType() != resultType) {
- p << " as ";
- } else {
- // Type elided as it matches the operand.
- printType = false;
- }
- }
- if (printType) {
- p.printType(resultType);
- }
- if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- if (resultDims.empty()) {
- p << "{<<INVALID>>}";
- return;
- }
- p << "{";
- llvm::interleaveComma(
- resultDims.take_front(shapedType.getNumDynamicDims()), p,
- [&](Value value) { p.printOperand(value); });
- p << "}";
- resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
- }
- }
- if (i < resultTypes.size() - 1) p << ", ";
- }
- if (resultTypes.size() != 1) p << ")";
-}
-
-//===----------------------------------------------------------------------===//
// flow.dispatch.tensor.load
//===----------------------------------------------------------------------===//
@@ -565,27 +323,66 @@
return canDispatchRegionContainOp(op);
}
-bool DispatchWorkgroupsOp::isOutputReadWithinRegion(unsigned resultIndex) {
- unsigned startIndex = getBody()->getNumArguments() - getNumResults();
- BlockArgument arg = body().front().getArgument(startIndex + resultIndex);
- // If argument is of `writeonly` access, then it is not read by construction.
- if (arg.getType().cast<DispatchTensorType>().getAccess() ==
- TensorAccess::WriteOnly) {
- return false;
- }
- // If the argument is a result with `readwrite` access, return false if the
- // value is only written to. Check this by looking at the uses of the argument
- // being only the `target` of `flow.dispatch.tensor.store` ops.
- for (OpOperand &uses : arg.getUses()) {
- auto storeOp = dyn_cast<DispatchTensorStoreOp>(uses.getOwner());
- if (!(storeOp && storeOp.target() == uses.get())) {
- return true;
+// Refines the tensor access from what is declared on |type| based on actual
+// usage. We expect that the access was set correctly to begin with but today
+// we sometimes specify things too wide.
+static TensorAccess refineTensorAccess(Value value, DispatchTensorType type) {
+ auto tensorAccess = type.getAccess();
+ if (tensorAccess == TensorAccess::ReadWrite) {
+ // If the argument is a result with `readwrite` access, return false if the
+ // value is only written to. Check this by looking at the uses of the
+ // argument being only the `target` of `flow.dispatch.tensor.store` ops.
+ bool onlyWrites = true;
+ for (OpOperand &uses : value.getUses()) {
+ auto storeOp = dyn_cast<DispatchTensorStoreOp>(uses.getOwner());
+ if (!(storeOp && storeOp.target() == uses.get())) {
+ onlyWrites = false;
+ break;
+ }
}
+ if (onlyWrites) tensorAccess = TensorAccess::WriteOnly;
}
- return false;
+ return tensorAccess;
}
-ClosureOpInterface
+IREE::Util::ValueAccess DispatchWorkgroupsOp::getOperandAccess(
+ unsigned operandIndex) {
+ BlockArgument arg = body().front().getArgument(operandIndex);
+ if (auto tensorType = arg.getType().dyn_cast<DispatchTensorType>()) {
+ auto tensorAccess = refineTensorAccess(arg, tensorType);
+ return IREE::Util::ValueAccess(
+ /*isRead=*/(tensorAccess == TensorAccess::ReadOnly) ||
+ (tensorAccess == TensorAccess::ReadWrite),
+ /*isWrite=*/(tensorAccess == TensorAccess::ReadWrite) ||
+ (tensorAccess == TensorAccess::WriteOnly),
+ /*isDiscard=*/(tensorAccess == TensorAccess::WriteOnly));
+ } else {
+ return IREE::Util::ValueAccess(/*isRead=*/!arg.use_empty(),
+ /*isWrite=*/false,
+ /*isDiscard=*/false);
+ }
+}
+
+IREE::Util::ValueAccess DispatchWorkgroupsOp::getResultAccess(
+ unsigned resultIndex) {
+ unsigned startIndex = getBody()->getNumArguments() - getNumResults();
+ BlockArgument arg = body().front().getArgument(startIndex + resultIndex);
+ if (auto tensorType = arg.getType().dyn_cast<DispatchTensorType>()) {
+ auto tensorAccess = refineTensorAccess(arg, tensorType);
+ return IREE::Util::ValueAccess(
+ /*isRead=*/(tensorAccess == TensorAccess::ReadOnly) ||
+ (tensorAccess == TensorAccess::ReadWrite),
+ /*isWrite=*/(tensorAccess == TensorAccess::ReadWrite) ||
+ (tensorAccess == TensorAccess::WriteOnly),
+ /*isDiscard=*/(tensorAccess == TensorAccess::WriteOnly));
+ } else {
+ return IREE::Util::ValueAccess(/*isRead=*/!arg.use_empty(),
+ /*isWrite=*/false,
+ /*isDiscard=*/false);
+ }
+}
+
+IREE::Util::ClosureOpInterface
DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults(
ArrayRef<unsigned> excludedOperandIndices,
ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
@@ -593,9 +390,9 @@
SmallVector<Value, 4> newResultDims = llvm::to_vector<4>(result_dims());
SmallVector<Value, 4> newOperandsValues = llvm::to_vector<4>(operands());
SmallVector<Value, 4> newOperandDims = llvm::to_vector<4>(operand_dims());
- excludeClosureOperandsAndResults(newOperandsValues, newOperandDims,
- excludedOperandIndices, newResultTypes,
- newResultDims, excludedResultIndices);
+ IREE::Util::excludeClosureOperandsAndResults(
+ newOperandsValues, newOperandDims, excludedOperandIndices, newResultTypes,
+ newResultDims, excludedResultIndices);
auto newTiedOperandIndices =
llvm::to_vector<4>(getTiedResultOperandIndices());
@@ -1151,11 +948,20 @@
return false;
}
-bool ExStreamFragmentOp::isOutputReadWithinRegion(unsigned resultIndex) {
- return false;
+IREE::Util::ValueAccess ExStreamFragmentOp::getOperandAccess(
+ unsigned operandIndex) {
+ return !isOperandTied(operandIndex) ? IREE::Util::ValueAccess::ReadOnly()
+ : IREE::Util::ValueAccess::ReadWrite();
}
-ClosureOpInterface
+IREE::Util::ValueAccess ExStreamFragmentOp::getResultAccess(
+ unsigned resultIndex) {
+ return getTiedResultOperandIndex(resultIndex).hasValue()
+ ? IREE::Util::ValueAccess::ReadWrite()
+ : IREE::Util::ValueAccess::DiscardWrite();
+}
+
+IREE::Util::ClosureOpInterface
ExStreamFragmentOp::cloneReplacementExcludingOperandsAndResults(
ArrayRef<unsigned> excludedOperandIndices,
ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
@@ -1163,9 +969,9 @@
SmallVector<Value, 4> newResultDims = llvm::to_vector<4>(result_dims());
SmallVector<Value, 4> newOperandsValues = llvm::to_vector<4>(operands());
SmallVector<Value, 4> newOperandDims = llvm::to_vector<4>(operand_dims());
- excludeClosureOperandsAndResults(newOperandsValues, newOperandDims,
- excludedOperandIndices, newResultTypes,
- newResultDims, excludedResultIndices);
+ IREE::Util::excludeClosureOperandsAndResults(
+ newOperandsValues, newOperandDims, excludedOperandIndices, newResultTypes,
+ newResultDims, excludedResultIndices);
auto newTiedOperandIndices =
llvm::to_vector<4>(getTiedResultOperandIndices());
@@ -1179,7 +985,7 @@
newOperandDims, newTiedOperandIndices, getOperation()->getAttrs());
auto &newBody = newOp.getClosureBodyRegion();
newBody.takeBody(getClosureBodyRegion());
- eraseRegionResults(newBody, excludedResultIndices);
+ IREE::Util::eraseRegionResults(newBody, excludedResultIndices);
newBody.front().eraseArguments(excludedOperandIndices);
return newOp;
}
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.h b/iree/compiler/Dialect/Flow/IR/FlowOps.h
index e1d1230..a455cad 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.h
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.h
@@ -23,6 +23,7 @@
#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 3f5e724..ee24714 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -13,6 +13,7 @@
include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -28,7 +29,7 @@
IsolatedFromAbove,
AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">,
- DeclareOpInterfaceMethods<FLOW_ClosureOpInterface>,
+ DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedOperandsIndexAndLength",
]>,
@@ -421,7 +422,7 @@
let hasCanonicalizer = 1;
}
-def FLOW_ReturnOp : FLOW_Op<"return", [Terminator]> {
+def FLOW_ReturnOp : FLOW_Op<"return", [NoSideEffect, ReturnLike, Terminator]> {
let summary = [{return from a flow.dispatch_region}];
let description = [{
Returns the given values from the region and back to the host code.
@@ -874,7 +875,7 @@
let assemblyFormat = [{
$update `,` $target `[` $start_indices `]` `:`
type($update) (`{` $update_dims^ `}`)? `->`
- custom<TiedResult>(type($result), $target_dims, $tied_operands)
+ custom<ShapedTiedResult>(type($result), $target_dims, $tied_operands)
attr-dict-with-keyword
}];
@@ -921,7 +922,7 @@
def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [
IsolatedFromAbove,
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<FLOW_ClosureOpInterface>,
+ DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
DeclareOpInterfaceMethods<Util_TiedOpInterface>,
DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
]> {
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
index 4b177ac..93e5259 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
@@ -47,13 +47,32 @@
// -----
-// CHECK-LABEL: func @removeUnusedResult
+// CHECK-LABEL: func @removeUnusedProducedResult
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+func @removeUnusedProducedResult(%arg0: index) -> index {
+ // CHECK: flow.ex.stream.fragment(%[[ARG0]]) : (index) -> index =
+ %0:2 = flow.ex.stream.fragment(%arg0) : (index) -> (index, index) =
+ (%arg0_in: index) -> (index, index) {
+ // CHECK: %[[T:.+]] = addi
+ %t = addi %arg0_in, %arg0_in : index
+ %unused = muli %arg0_in, %arg0_in : index
+ // CHECK: flow.return %[[T]] : index
+ flow.return %t, %unused : index, index
+ }
+ return %0#0 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @removeUnusedPassThroughResult
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
-func @removeUnusedResult(%arg0: index, %arg1: index) -> index {
+func @removeUnusedPassThroughResult(%arg0: index, %arg1: index) -> index {
// CHECK: flow.ex.stream.fragment(%[[ARG1]])
%0:2 = flow.ex.stream.fragment(%arg0, %arg1) : (index, index) -> (index, index) =
- (%unused: index, %arg1: index) -> (index, index) {
- %t = addi %arg1, %arg1 : index
+ (%unused: index, %arg1_in: index) -> (index, index) {
+ // CHECK: %[[T:.+]] = addi
+ %t = addi %arg1_in, %arg1_in : index
+ // CHECK: flow.return %[[T]] : index
flow.return %t, %unused : index, index
}
return %0#0 : index
@@ -69,7 +88,7 @@
// CHECK: flow.ex.stream.fragment(%[[ARG1]]) :
%0:2 = flow.ex.stream.fragment(%arg0, %arg1) :
// CHECK-SAME: (tensor<8x?xf32>{%[[DIM1]]}) -> %[[ARG1]]{%[[DIM1]]} =
- (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0{%dim0}, %arg1{%dim1}) =
+ (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (tensor<4x?xf32>{%dim0}, %arg1{%dim1}) =
// CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<8x?xf32>) -> tensor<8x?xf32>
(%unused: tensor<4x?xf32>, %arg1: tensor<8x?xf32>) -> (tensor<4x?xf32>, tensor<8x?xf32>) {
// CHECK-NEXT: flow.return %[[INNER_ARG]] : tensor<8x?xf32>
@@ -96,7 +115,7 @@
%workload = constant 8 : index
// CHECK: %[[TARGET_CLONE:.+]] = flow.tensor.clone %[[TARGET]] : tensor<2x4xi32>
// CHECK: %[[UPDATED:.+]] = flow.tensor.update %[[UPDATE]], %[[TARGET]]
- %t0 = flow.tensor.update %stream_update, %stream_target[%start0, %start1] : tensor<1x1xi32> -> tensor<2x4xi32>
+ %t0 = flow.tensor.update %stream_update, %stream_target[%start0, %start1] : tensor<1x1xi32> -> %stream_target as tensor<2x4xi32>
// CHECK-NEXT: %[[RETURN:.+]] = flow.dispatch @ex::@entry[%c8](%[[TARGET_CLONE]], %[[UPDATED]])
%t1 = flow.dispatch @ex::@entry[%workload](%stream_target, %t0) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
// CHECK-NEXT: flow.return %[[RETURN]]
@@ -165,7 +184,7 @@
%5 = util.global.load @_large_const : tensor<7xi32>
// CHECK: %[[CLONE:.+]] = flow.tensor.clone %[[LOAD]]
// CHECK: flow.tensor.update %{{.+}}, %[[CLONE]]
- %6 = flow.tensor.update %arg0, %5[%c3] : tensor<2xi32> -> tensor<7xi32>
+ %6 = flow.tensor.update %arg0, %5[%c3] : tensor<2xi32> -> %5 as tensor<7xi32>
flow.return %6 : tensor<7xi32>
}
return %4 : tensor<7xi32>
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
index 84fef08..12f2481 100644
--- a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
@@ -143,8 +143,8 @@
// CHECK-LABEL: @tensorUpdate
func @tensorUpdate(%arg0 : tensor<2x2xf32>, %arg1 : tensor<4x4xf32>, %arg2 : index, %arg3 : index) -> tensor<4x4xf32> {
- // CHECK-NEXT: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> tensor<4x4xf32>
- %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> tensor<4x4xf32>
+ // CHECK-NEXT: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> %arg1 as tensor<4x4xf32>
+ %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> %arg1 as tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
@@ -153,7 +153,7 @@
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
- // CHECK: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> tensor<?x4xf32>{%c3}
- %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> tensor<?x4xf32>{%c3}
+ // CHECK: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> %arg1 as tensor<?x4xf32>{%c3}
+ %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> %arg1 as tensor<?x4xf32>{%c3}
return %0 : tensor<?x4xf32>
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
index e3e964e..a7716d2 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
@@ -10,7 +10,7 @@
%4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5]
%5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
%6 = linalg.fill(%0, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
- %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> tensor<?x?xf32>{%3, %4}
+ %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> %6 as tensor<?x?xf32>{%3, %4}
return %7 : tensor<?x?xf32>
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index afe63b1..5d846cd 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -334,7 +334,7 @@
%4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5]
%5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
%6 = linalg.fill(%0, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
- %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> tensor<?x?xf32>{%3, %4}
+ %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> %6 as tensor<?x?xf32>{%3, %4}
return %7 : tensor<?x?xf32>
}
@@ -639,7 +639,7 @@
%c5_i32 = constant 5 : i32
%c0_i32 = constant 0 : i32
%c9_i32 = constant 9 : i32
- %245 = flow.tensor.update %240, %244[%c9] : tensor<9xi32> -> tensor<18xi32>
+ %245 = flow.tensor.update %240, %244[%c9] : tensor<9xi32> -> %244 as tensor<18xi32>
%248 = tensor.extract %247[] : tensor<i32>
%249 = cmpi slt, %248, %c9_i32 : i32
%250 = select %249, %248, %c9_i32 : i32
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
index ccfaf18..8a52b41 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
@@ -284,7 +284,7 @@
// CHECK-SAME: source(%[[UBUF]] : !hal.buffer)[%c0]
// CHECK-SAME: target(%[[RET_BUF]] : !hal.buffer)[%c204]
// CHECK-SAME: length(%c40)
- %1 = flow.tensor.update %arg2, %clone[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> tensor<5x1x10xf32>
+ %1 = flow.tensor.update %arg2, %clone[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> %clone as tensor<5x1x10xf32>
flow.return %1 : tensor<5x1x10xf32>
}
// CHECK: hal.command_buffer.end<%[[CMD]]
@@ -571,7 +571,7 @@
// CHECK: %[[CSTBUF:.+]] = util.global.load @_const_pool_splats : !hal.buffer
// CHECK: hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer> source(%[[CSTBUF]] : !hal.buffer)[%[[C0]]] target(%[[DSTBUF]] : !hal.buffer)[%[[C0]]] length(%[[C28]])
%2 = flow.tensor.clone %const_span : tensor<7xi32>
- %3 = flow.tensor.update %arg0, %2[%c3] : tensor<2xi32> -> tensor<7xi32>
+ %3 = flow.tensor.update %arg0, %2[%c3] : tensor<2xi32> -> %2 as tensor<7xi32>
flow.return %3 : tensor<7xi32>
}
return %1 : tensor<7xi32>
diff --git a/iree/compiler/Dialect/HAL/IR/HALInterfaces.td b/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
index ac0d32c..cdf89ae 100644
--- a/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
+++ b/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
@@ -9,68 +9,6 @@
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
-//===----------------------------------------------------------------------===//
-// IREE::HAL::SizeAwareOpInterface
-//===----------------------------------------------------------------------===//
-
-def HAL_InferTypeSize : TypeInterface<"InferTypeSizeInterface"> {
- let description = [{
- Allows types to be queried for their size by inserting the required logic
- when required.
- }];
-
- let methods = [
- InterfaceMethod<
- [{Builds an expression computing the size of the value.}],
- "Value", "inferSizeFromValue", (ins "Location":$loc,
- "Value":$value,
- "OpBuilder &":$builder)
- >,
- ];
-}
-
-def HAL_SizeAwareType : TypeInterface<"SizeAwareTypeInterface"> {
- let description = [{
- Denotes that a type is size-aware and must always have a size value
- associated with it in the IR. See `SizeAwareOp` for more information.
- }];
-
- let methods = [
- InterfaceMethod<
- [{Returns a size for the given sized value.}],
- "Value", "getSize", (ins "Value":$value)
- >,
- ];
-}
-
-def HAL_SizeAwareOp : OpInterface<"SizeAwareOpInterface"> {
- let description = [{
- An operation that is able to provide size values for all size-aware operands
- and results.
- }];
-
- let methods = [
- InterfaceMethod<
- [{Returns a size for the given sized operand index.}],
- "Value", "getOperandSize", (ins "unsigned":$idx)
- >,
- InterfaceMethod<
- [{Returns a size for the given sized result index.}],
- "Value", "getResultSize", (ins "unsigned":$idx)
- >,
- InterfaceMethod<
- [{Returns a size for the given sized result value.}],
- "Value", "getResultSizeFromValue", (ins "Value":$value),
- /*defaultImplementation=*/[{
- for (unsigned i = 0; i < $_self->getNumResults(); ++i) {
- if ($_self->getResult(i) == value) return $_self.getResultSize(i);
- }
- return {};
- }]
- >,
- ];
-}
-
def HAL_MatchAttrInterface :
AttrInterface<"MatchAttrInterface"> {
let description = [{
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index edb6f90..e5f26b8 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -51,65 +51,6 @@
}
//===----------------------------------------------------------------------===//
-// custom<SizeAwareType>
-//===----------------------------------------------------------------------===//
-// type{%size}
-
-static ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type,
- OpAsmParser::OperandType &size) {
- if (failed(parser.parseType(type)) || failed(parser.parseLBrace()) ||
- failed(parser.parseOperand(size)) || failed(parser.parseRBrace())) {
- return failure();
- }
- return success();
-}
-
-static void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type,
- Value size) {
- p.printType(type);
- p << "{";
- p.printOperand(size);
- p << "}";
-}
-
-//===----------------------------------------------------------------------===//
-// custom<SizeAwareTypeList>
-//===----------------------------------------------------------------------===//
-// (type{%size0}, type, type{%size1})
-
-static ParseResult parseSizeAwareTypeList(
- OpAsmParser &parser, SmallVectorImpl<Type> &types,
- SmallVectorImpl<OpAsmParser::OperandType> &sizes) {
- do {
- Type type;
- if (failed(parser.parseType(type))) return failure();
- if (type.isa<SizeAwareTypeInterface>()) {
- OpAsmParser::OperandType size;
- if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) ||
- failed(parser.parseRBrace())) {
- return failure();
- }
- sizes.push_back(size);
- }
- types.push_back(type);
- } while (succeeded(parser.parseOptionalComma()));
- return success();
-}
-
-static void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op,
- TypeRange types, OperandRange sizes) {
- int sizeIndex = 0;
- llvm::interleaveComma(types, p, [&](Type type) {
- p.printType(type);
- if (type.isa<SizeAwareTypeInterface>()) {
- p << "{";
- p.printOperand(sizes[sizeIndex++]);
- p << "}";
- }
- });
-}
-
-//===----------------------------------------------------------------------===//
// custom<DescriptorSetBindings>($binding_ordinals,
// $binding_buffers,
// type($binding_buffers),
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 2d87fe6..6879e69 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -267,7 +267,7 @@
def HAL_AllocatorAllocateOp : HAL_Op<"allocator.allocate", [
DeclareOpInterfaceMethods<OpAsmOpInterface>,
- DeclareOpInterfaceMethods<HAL_SizeAwareOp>,
+ DeclareOpInterfaceMethods<Util_SizeAwareOp>,
]> {
let summary = [{empty buffer allocation operation}];
let description = [{
@@ -331,7 +331,7 @@
def HAL_AllocatorMapOp : HAL_Op<"allocator.map", [
DeclareOpInterfaceMethods<OpAsmOpInterface>,
- DeclareOpInterfaceMethods<HAL_SizeAwareOp>,
+ DeclareOpInterfaceMethods<Util_SizeAwareOp>,
]> {
let summary = [{allocator-supported host buffer wrapping operation}];
let description = [{
@@ -498,7 +498,7 @@
def HAL_BufferSubspanOp : HAL_PureOp<"buffer.subspan", [
DeclareOpInterfaceMethods<OpAsmOpInterface>,
- DeclareOpInterfaceMethods<HAL_SizeAwareOp>,
+ DeclareOpInterfaceMethods<Util_SizeAwareOp>,
]> {
let summary = [{buffer subspan operation}];
let description = [{
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index 2f9c7b6..e0c413e 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -380,7 +380,7 @@
// Returns the SSA value containing the size of the given |value|.
static Value lookupValueSize(Value value) {
- assert(value.getType().isa<SizeAwareTypeInterface>());
+ assert(value.getType().isa<IREE::Util::SizeAwareTypeInterface>());
auto definingOp = value.getDefiningOp();
if (!definingOp) {
@@ -402,7 +402,7 @@
}
}
assert(resultIndex != -1 && "result not in results");
- auto sizeAwareOp = dyn_cast<SizeAwareOpInterface>(definingOp);
+ auto sizeAwareOp = dyn_cast<IREE::Util::SizeAwareOpInterface>(definingOp);
if (!sizeAwareOp) return {};
return sizeAwareOp.getResultSize(resultIndex);
}
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h
index a0abc18..ac22bc1 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -79,16 +79,18 @@
using Base::Base;
};
-class BufferType : public Type::TypeBase<BufferType, Type, TypeStorage,
- InferTypeSizeInterface::Trait> {
+class BufferType
+ : public Type::TypeBase<BufferType, Type, TypeStorage,
+ IREE::Util::InferTypeSizeInterface::Trait> {
public:
using Base::Base;
Value inferSizeFromValue(Location loc, Value value, OpBuilder &builder) const;
};
-class BufferViewType : public Type::TypeBase<BufferViewType, Type, TypeStorage,
- InferTypeSizeInterface::Trait> {
+class BufferViewType
+ : public Type::TypeBase<BufferViewType, Type, TypeStorage,
+ IREE::Util::InferTypeSizeInterface::Trait> {
public:
using Base::Base;
diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
index ebe1e99..20296da 100644
--- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
+++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
@@ -67,16 +67,17 @@
loc, builder.getIndexType(), value);
}
- if (auto awareOp = dyn_cast_or_null<SizeAwareOpInterface>(definingOp)) {
+ if (auto awareOp =
+ dyn_cast_or_null<IREE::Util::SizeAwareOpInterface>(definingOp)) {
return awareOp.getResultSizeFromValue(value);
}
auto type = value.getType();
- if (auto awareType = type.dyn_cast<SizeAwareTypeInterface>()) {
+ if (auto awareType = type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
auto sizeValue = awareType.getSize(value);
if (sizeValue) return sizeValue;
}
- if (auto inferType = type.dyn_cast<InferTypeSizeInterface>()) {
+ if (auto inferType = type.dyn_cast<IREE::Util::InferTypeSizeInterface>()) {
return inferType.inferSizeFromValue(loc, value, builder);
}
diff --git a/iree/compiler/Dialect/Util/IR/BUILD b/iree/compiler/Dialect/Util/IR/BUILD
index cf34616..b5ae87d 100644
--- a/iree/compiler/Dialect/Util/IR/BUILD
+++ b/iree/compiler/Dialect/Util/IR/BUILD
@@ -35,21 +35,29 @@
cc_library(
name = "IR",
srcs = [
+ "ClosureOpUtils.cpp",
"UtilDialect.cpp",
"UtilOpFolders.cpp",
- "UtilOpInterfaces.cpp.inc",
"UtilOps.cpp",
"UtilOps.cpp.inc",
"UtilTypes.cpp",
],
hdrs = [
+ "ClosureOpUtils.h",
"UtilDialect.h",
- "UtilOpInterfaces.h.inc",
"UtilOps.h",
"UtilOps.h.inc",
"UtilTraits.h",
"UtilTypes.h",
],
+ textual_hdrs = [
+ "UtilAttrInterfaces.cpp.inc",
+ "UtilAttrInterfaces.h.inc",
+ "UtilOpInterfaces.cpp.inc",
+ "UtilOpInterfaces.h.inc",
+ "UtilTypeInterfaces.cpp.inc",
+ "UtilTypeInterfaces.h.inc",
+ ],
deps = [
":UtilInterfacesGen",
":UtilOpsGen",
@@ -69,6 +77,14 @@
name = "UtilInterfacesGen",
tbl_outs = [
(
+ ["-gen-attr-interface-decls"],
+ "UtilAttrInterfaces.h.inc",
+ ),
+ (
+ ["-gen-attr-interface-defs"],
+ "UtilAttrInterfaces.cpp.inc",
+ ),
+ (
["-gen-op-interface-decls"],
"UtilOpInterfaces.h.inc",
),
@@ -76,6 +92,14 @@
["-gen-op-interface-defs"],
"UtilOpInterfaces.cpp.inc",
),
+ (
+ ["-gen-type-interface-decls"],
+ "UtilTypeInterfaces.h.inc",
+ ),
+ (
+ ["-gen-type-interface-defs"],
+ "UtilTypeInterfaces.cpp.inc",
+ ),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "UtilInterfaces.td",
diff --git a/iree/compiler/Dialect/Util/IR/CMakeLists.txt b/iree/compiler/Dialect/Util/IR/CMakeLists.txt
index 72d4d1f..412793e 100644
--- a/iree/compiler/Dialect/Util/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Util/IR/CMakeLists.txt
@@ -14,16 +14,23 @@
NAME
IR
HDRS
+ "ClosureOpUtils.h"
"UtilDialect.h"
- "UtilOpInterfaces.h.inc"
"UtilOps.h"
"UtilOps.h.inc"
"UtilTraits.h"
"UtilTypes.h"
+ TEXTUAL_HDRS
+ "UtilAttrInterfaces.cpp.inc"
+ "UtilAttrInterfaces.h.inc"
+ "UtilOpInterfaces.cpp.inc"
+ "UtilOpInterfaces.h.inc"
+ "UtilTypeInterfaces.cpp.inc"
+ "UtilTypeInterfaces.h.inc"
SRCS
+ "ClosureOpUtils.cpp"
"UtilDialect.cpp"
"UtilOpFolders.cpp"
- "UtilOpInterfaces.cpp.inc"
"UtilOps.cpp"
"UtilOps.cpp.inc"
"UtilTypes.cpp"
@@ -47,8 +54,12 @@
TD_FILE
"UtilInterfaces.td"
OUTS
+ -gen-attr-interface-decls UtilAttrInterfaces.h.inc
+ -gen-attr-interface-defs UtilAttrInterfaces.cpp.inc
-gen-op-interface-decls UtilOpInterfaces.h.inc
-gen-op-interface-defs UtilOpInterfaces.cpp.inc
+ -gen-type-interface-decls UtilTypeInterfaces.h.inc
+ -gen-type-interface-defs UtilTypeInterfaces.cpp.inc
)
iree_tablegen_library(
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp b/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
similarity index 89%
rename from iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp
rename to iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
index 16dd347..7febab2 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp
+++ b/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
@@ -4,14 +4,15 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
+#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
-namespace Flow {
+namespace Util {
//------------------------------------------------------------------------------
// Closure optimization
@@ -53,6 +54,8 @@
auto type = it.value().getType();
if (auto shapedType = type.dyn_cast<ShapedType>()) {
numDynamicDims = shapedType.getNumDynamicDims();
+ } else if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+ numDynamicDims = 1;
}
if (!llvm::count(excludedOperandIndices, it.index())) {
operandValues.push_back(it.value());
@@ -73,6 +76,8 @@
auto type = it.value();
if (auto shapedType = type.dyn_cast<ShapedType>()) {
numDynamicDims = shapedType.getNumDynamicDims();
+ } else if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+ numDynamicDims = 1;
}
if (!llvm::count(excludedResultIndices, it.index())) {
resultTypes.push_back(type);
@@ -86,15 +91,18 @@
void eraseRegionResults(Region ®ion,
ArrayRef<unsigned> excludedResultIndices) {
- region.walk([&](IREE::Flow::ReturnOp terminator) {
- llvm::SmallVector<Value, 4> newReturns;
- for (auto it : llvm::enumerate(terminator.getOperands())) {
- if (!llvm::count(excludedResultIndices, it.index())) {
- newReturns.push_back(it.value());
+ for (auto &block : region.getBlocks()) {
+ auto *terminatorOp = block.getTerminator();
+ if (terminatorOp && terminatorOp->hasTrait<OpTrait::ReturnLike>()) {
+ llvm::SmallVector<Value, 4> newReturns;
+ for (auto it : llvm::enumerate(terminatorOp->getOperands())) {
+ if (!llvm::count(excludedResultIndices, it.index())) {
+ newReturns.push_back(it.value());
+ }
}
+ terminatorOp->setOperands(newReturns);
}
- terminator.getOperation()->setOperands(newReturns);
- });
+ }
}
// Returns true if |constantOp| represents a (logically) small constant value.
@@ -173,14 +181,11 @@
auto *sourceOp = outerValue.getDefiningOp();
if (!sourceOp) continue; // can't clone block arguments into closures
- BlockArgument blockArg = entryBlock.getArgument(opArg.index());
- if (auto type = blockArg.getType().dyn_cast<DispatchTensorType>()) {
- // We cannot just simply inline and replace all users if this is an
- // argument that can be written; for example, the region might perform
- // work after loading a initial constant from the argument and then
- // write back.
- if (type.getAccess() != TensorAccess::ReadOnly) continue;
- }
+ // We cannot just simply inline and replace all users if this is an
+ // argument that can be written; for example, the region might perform
+ // work after loading a initial constant from the argument and then
+ // write back.
+ if (!closureOp.getOperandAccess(opArg.index()).isReadOnly()) continue;
if (closureOp.canClosureContainOp(sourceOp) &&
shouldInlineIntoClosure(outerValue)) {
@@ -197,6 +202,7 @@
auto newValue = clonedOp->getResult(resultIndex);
// Replace all of the uses inside of the closure.
+ BlockArgument blockArg = entryBlock.getArgument(opArg.index());
blockArg.replaceAllUsesWith(newValue);
}
}
@@ -245,7 +251,7 @@
// You can drop a result if the use is empty, and that it is only written to
// within the dispatch region.
if (result.value().use_empty() &&
- !closureOp.isOutputReadWithinRegion(result.index())) {
+ !closureOp.getResultAccess(result.index()).isRead) {
elidedResults.push_back(result.index());
} else {
preservedResults.push_back(result.value());
@@ -293,7 +299,7 @@
return success();
}
-} // namespace Flow
+} // namespace Util
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h b/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
similarity index 86%
rename from iree/compiler/Dialect/Flow/IR/FlowOpUtils.h
rename to iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
index 7f15490..c3a900f 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h
+++ b/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
@@ -4,8 +4,12 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_CLOSUREOPUTILS_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_CLOSUREOPUTILS_H_
+
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
@@ -15,7 +19,7 @@
namespace mlir {
namespace iree_compiler {
namespace IREE {
-namespace Flow {
+namespace Util {
//------------------------------------------------------------------------------
// Closure optimization
@@ -50,7 +54,7 @@
// Duplicate operands will be combined and unused operands and results will be
// removed.
//
-// T must implement the IREE::Flow::ClosureOpInterface.
+// T must implement the IREE::Util::ClosureOpInterface.
template <typename T>
struct ClosureOptimizationPattern : public OpRewritePattern<T> {
using OpRewritePattern<T>::OpRewritePattern;
@@ -62,7 +66,9 @@
}
};
-} // namespace Flow
+} // namespace Util
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_UTIL_IR_CLOSUREOPUTILS_H_
diff --git a/iree/compiler/Dialect/Util/IR/UtilDialect.h b/iree/compiler/Dialect/Util/IR/UtilDialect.h
index 6d2e99f..8224469 100644
--- a/iree/compiler/Dialect/Util/IR/UtilDialect.h
+++ b/iree/compiler/Dialect/Util/IR/UtilDialect.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_DIALECT_IREE_IR_IREEDIALECT_H_
-#define IREE_COMPILER_DIALECT_IREE_IR_IREEDIALECT_H_
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_UTILDIALECT_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_UTILDIALECT_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
@@ -36,4 +36,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_DIALECT_IREE_IR_IREEDIALECT_H_
+#endif // IREE_COMPILER_DIALECT_UTIL_IR_UTILDIALECT_H_
diff --git a/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index 6deb9a1..ea48549 100644
--- a/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -10,6 +10,91 @@
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
+// IREE::Util::ClosureOpInterface
+//===----------------------------------------------------------------------===//
+
+def Util_ClosureOpInterface : OpInterface<"ClosureOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+ let description = [{
+ Interface for ops that follow the util dialect closure semantics (explicit
+ captures, dynamic-shape awareness, and normal operand/result SSA behavior).
+
+ Implementing this interface enables optimizations that perform manipulation
+ across the closure capture boundary (outside of the op <-> regions within
+ the op).
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the body region of the closure (may have multiple blocks).
+ }],
+ /*retTy=*/"Region &",
+ /*methodName=*/"getClosureBodyRegion",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return this->getOperation()->getRegion(0);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{Returns all closure operand values.}],
+ /*retTy=*/"Operation::operand_range",
+ /*methodName=*/"getClosureOperands",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{Returns all closure result values.}],
+ /*retTy=*/"Operation::result_range",
+ /*methodName=*/"getClosureResults",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if the given operation can exist in the closure.
+ Not all operations that a closure can contain are guaranteed to be folded
+ into the closure, such as when the operation may have side-effects.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"canClosureContainOp",
+ /*args=*/(ins "Operation *":$op)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Clones the op while removing specified operands and results.
+ The body of the op will be transferred to the new op and the entry block
+ will have its arguments removed.
+
+ The returned op will be free standing. Callers must insert it into a block
+ where desired (most often just replacing the current op).
+ }],
+ /*retTy=*/"IREE::Util::ClosureOpInterface",
+ /*methodName=*/"cloneReplacementExcludingOperandsAndResults",
+ /*args=*/(ins "ArrayRef<unsigned>":$excludedOperandIndices,
+ "ArrayRef<unsigned>":$excludedResultIndices,
+ "PatternRewriter &":$rewriter)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a bitfield indicating how an operand is used within the closure.
+ }],
+ /*retTy=*/"IREE::Util::ValueAccess",
+ /*methodName=*/"getOperandAccess",
+ /*args=*/(ins "unsigned":$operandIndex)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a bitfield indicating how a result is used within the closure.
+ }],
+ /*retTy=*/"IREE::Util::ValueAccess",
+ /*methodName=*/"getResultAccess",
+ /*args=*/(ins "unsigned":$resultIndex)
+ >
+ ];
+}
+
+//===----------------------------------------------------------------------===//
// IREE::Util::TiedOpInterface
//===----------------------------------------------------------------------===//
@@ -151,6 +236,19 @@
return IREE::Util::detail::getTiedResultOperandIndices($_op);
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if the given flattened operand index is tied to one or more
+ results.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isOperandTied",
+ /*args=*/(ins "unsigned":$operandIndex),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return IREE::Util::detail::isOperandTied($_op, operandIndex);
+ }]
+ >,
];
let extraClassDeclaration = [{
@@ -171,6 +269,74 @@
}
//===----------------------------------------------------------------------===//
+// IREE::Util::SizeAware* interfaces
+//===----------------------------------------------------------------------===//
+
+def Util_InferTypeSize : TypeInterface<"InferTypeSizeInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+ let description = [{
+ Allows types to be queried for their size by inserting the required logic
+ when required.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ [{Builds an expression computing the size of the value.}],
+ "Value", "inferSizeFromValue", (ins "Location":$loc,
+ "Value":$value,
+ "OpBuilder &":$builder)
+ >,
+ ];
+}
+
+def Util_SizeAwareType : TypeInterface<"SizeAwareTypeInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+ let description = [{
+ Denotes that a type is size-aware and must always have a size value
+ associated with it in the IR. See `SizeAwareOp` for more information.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ [{Returns a size for the given sized value.}],
+ "Value", "getSize", (ins "Value":$value)
+ >,
+ ];
+}
+
+def Util_SizeAwareOp : OpInterface<"SizeAwareOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+ let description = [{
+ An operation that is able to provide size values for all size-aware operands
+ and results.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ [{Returns a size for the given sized operand index.}],
+ "Value", "getOperandSize", (ins "unsigned":$idx)
+ >,
+ InterfaceMethod<
+ [{Returns a size for the given sized result index.}],
+ "Value", "getResultSize", (ins "unsigned":$idx)
+ >,
+ InterfaceMethod<
+ [{Returns a size for the given sized result value.}],
+ "Value", "getResultSizeFromValue", (ins "Value":$value),
+ /*defaultImplementation=*/[{
+ for (unsigned i = 0; i < $_self->getNumResults(); ++i) {
+ if ($_self->getResult(i) == value) return $_self.getResultSize(i);
+ }
+ return {};
+ }]
+ >,
+ ];
+}
+
+//===----------------------------------------------------------------------===//
// IREE::Util::GlobalTypeInterface
//===----------------------------------------------------------------------===//
@@ -195,7 +361,7 @@
/*defaultImplementation=*/[{
// If one is a shaped type, then they both must be and have compatible
// shapes.
- if ($_type.isa<ShapedType>() || accessType.isa<ShapedType>()) {
+ if ($_type.template isa<ShapedType>() || accessType.isa<ShapedType>()) {
return succeeded(mlir::verifyCompatibleShape($_type, accessType));
}
// Otherwise, the types must be the same.
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 5cab132..11a6c11 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -24,8 +24,6 @@
namespace mlir {
namespace iree_compiler {
-namespace IREE {
-namespace Util {
//===----------------------------------------------------------------------===//
// custom<SymbolVisibility>($sym_visibility)
@@ -35,8 +33,8 @@
// some.op @foo
// some.op private @foo
-static ParseResult parseSymbolVisibility(OpAsmParser &parser,
- StringAttr &symVisibilityAttr) {
+ParseResult parseSymbolVisibility(OpAsmParser &parser,
+ StringAttr &symVisibilityAttr) {
StringRef symVisibility;
parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"});
if (!symVisibility.empty()) {
@@ -45,8 +43,8 @@
return success();
}
-static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
- StringAttr symVisibilityAttr) {
+void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
+ StringAttr symVisibilityAttr) {
if (!symVisibilityAttr) {
p << "public";
} else {
@@ -61,37 +59,396 @@
// ->
// some.op : i32
// some.op = 42 : i32
+// some.op : i32 = 42 : index
-static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
- Attribute &attr) {
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &attr) {
if (succeeded(parser.parseOptionalEqual())) {
if (failed(parser.parseAttribute(attr))) {
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
typeAttr = TypeAttr::get(attr.getType());
- } else {
- Type type;
- if (failed(parser.parseColonType(type))) {
- return parser.emitError(parser.getCurrentLocation()) << "expected type";
+ return success();
+ }
+
+ Type type;
+ if (failed(parser.parseColonType(type))) {
+ return parser.emitError(parser.getCurrentLocation()) << "expected type";
+ }
+ typeAttr = TypeAttr::get(type);
+
+ if (succeeded(parser.parseOptionalEqual())) {
+ if (failed(parser.parseAttribute(attr))) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected attribute";
}
- typeAttr = TypeAttr::get(type);
+ }
+
+ return success();
+}
+
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+ Attribute attr) {
+ if (!attr || attr.getType() != type.getValue()) {
+ p << " : ";
+ p.printAttribute(type);
+ }
+ if (attr) {
+ p << " = ";
+ p.printAttribute(attr);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// custom<SizeAwareType>
+//===----------------------------------------------------------------------===//
+// type{%size}
+
+ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type,
+ OpAsmParser::OperandType &size) {
+ if (failed(parser.parseType(type)) || failed(parser.parseLBrace()) ||
+ failed(parser.parseOperand(size)) || failed(parser.parseRBrace())) {
+ return failure();
}
return success();
}
-static void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
- Attribute attr) {
- if (attr) {
- p << " = ";
- p.printAttribute(attr);
- } else {
- p << " : ";
- p.printAttribute(type);
+void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, Value size) {
+ p.printType(type);
+ p << "{";
+ p.printOperand(size);
+ p << "}";
+}
+
+//===----------------------------------------------------------------------===//
+// custom<SizeAwareTypeList>
+//===----------------------------------------------------------------------===//
+// (type{%size0}, type, type{%size1})
+
+ParseResult parseSizeAwareTypeList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &types,
+ SmallVectorImpl<OpAsmParser::OperandType> &sizes) {
+ do {
+ Type type;
+ if (failed(parser.parseType(type))) return failure();
+ if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+ OpAsmParser::OperandType size;
+ if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) ||
+ failed(parser.parseRBrace())) {
+ return failure();
+ }
+ sizes.push_back(size);
+ }
+ types.push_back(type);
+ } while (succeeded(parser.parseOptionalComma()));
+ return success();
+}
+
+void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op, TypeRange types,
+ OperandRange sizes) {
+ int sizeIndex = 0;
+ llvm::interleaveComma(types, p, [&](Type type) {
+ p.printType(type);
+ if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+ p << "{";
+ p.printOperand(sizes[sizeIndex++]);
+ p << "}";
+ }
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ShapedTiedResult>
+//===----------------------------------------------------------------------===//
+// type{%dim0, %dim1}
+// %arg0 as type{%dim0}
+
+ParseResult parseShapedTiedResult(
+ OpAsmParser &parser, Type &resultType,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+ ArrayAttr &tiedOperands) {
+ OpAsmParser::OperandType tiedResult;
+ auto res = parser.parseOptionalOperand(tiedResult);
+ int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
+ if (res.hasValue() && succeeded(res.getValue())) {
+ tiedOperandIndex = 0;
+ if (failed(parser.parseKeyword("as"))) return failure();
+ }
+ if (failed(parser.parseType(resultType))) return failure();
+ if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+ if (!shapedType.hasStaticShape()) {
+ SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
+ if (failed(parser.parseLBrace()) ||
+ failed(parser.parseOperandList(dynamicDims,
+ shapedType.getNumDynamicDims(),
+ OpAsmParser::Delimiter::None)) ||
+ failed(parser.parseRBrace())) {
+ return failure();
+ }
+ resultDims.append(dynamicDims);
+ }
+ } else if (auto sizedType =
+ resultType.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+ OpAsmParser::OperandType size;
+ if (failed(parser.parseOperand(size))) {
+ return failure();
+ }
+ resultDims.push_back(size);
+ }
+ tiedOperands = parser.getBuilder().getIndexArrayAttr({tiedOperandIndex});
+ return success();
+}
+
+void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
+ ValueRange resultDims, ArrayAttr tiedOperands) {
+ auto tiedOp = cast<IREE::Util::TiedOpInterface>(op);
+ auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(0);
+ if (tiedOperandIndex.hasValue()) {
+ auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
+ p.printOperand(tiedOperand);
+ p << " as ";
+ }
+ p.printType(resultType);
+ if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+ if (!shapedType.hasStaticShape()) {
+ if (resultDims.empty()) {
+ p << "{<<INVALID>>}";
+ return;
+ }
+ p << "{";
+ llvm::interleaveComma(
+ resultDims.take_front(shapedType.getNumDynamicDims()), p,
+ [&](Value value) { p.printOperand(value); });
+ p << "}";
+ resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
+ }
+ } else if (auto sizedType =
+ resultType.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+ p << "{";
+ p.printOperand(resultDims.front());
+ p << "}";
+ resultDims = resultDims.drop_front(1);
}
}
//===----------------------------------------------------------------------===//
+// custom<ShapedFunctionType>
+//===----------------------------------------------------------------------===//
+// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)
+
+static ParseResult parseShapedOperandList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &types,
+ SmallVectorImpl<OpAsmParser::OperandType> &dims) {
+ do {
+ Type type;
+ if (failed(parser.parseType(type))) return failure();
+ if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ if (!shapedType.hasStaticShape()) {
+ SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
+ if (failed(parser.parseLBrace()) ||
+ failed(parser.parseOperandList(dynamicDims,
+ shapedType.getNumDynamicDims(),
+ OpAsmParser::Delimiter::None)) ||
+ failed(parser.parseRBrace())) {
+ return failure();
+ }
+ dims.append(dynamicDims);
+ }
+ } else if (auto sizedType =
+ type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+ OpAsmParser::OperandType size;
+ if (failed(parser.parseOperand(size))) {
+ return failure();
+ }
+ dims.push_back(size);
+ }
+ types.push_back(type);
+ } while (succeeded(parser.parseOptionalComma()));
+ return success();
+}
+
+// Finds the operand index in |operands| that |tiedResult| references.
+// Returns TiedOpInterface::kUntiedIndex if no operand is found.
+static int64_t findTiedOperand(OpAsmParser::OperandType tiedResult,
+ ArrayRef<OpAsmParser::OperandType> operands) {
+ int64_t operandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
+ for (int64_t i = 0; i < operands.size(); ++i) {
+ if (operands[i].name == tiedResult.name) {
+ operandIndex = i;
+ break;
+ }
+ }
+ return operandIndex;
+}
+
+static ParseResult parseShapedResultList(
+ OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
+ TypeRange operandTypes, ArrayRef<OpAsmParser::OperandType> operandDims,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+ ArrayAttr &tiedOperands) {
+ SmallVector<int64_t, 4> tiedOperandIndices;
+ do {
+ OpAsmParser::OperandType tiedResult;
+ auto res = parser.parseOptionalOperand(tiedResult);
+ Type type;
+ int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
+ if (res.hasValue() && succeeded(res.getValue())) {
+ tiedOperandIndex = findTiedOperand(tiedResult, operands);
+ if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) {
+ return parser.emitError(tiedResult.location,
+ "tied operand not found for result reference ")
+ << tiedResult.name;
+ }
+ if (succeeded(parser.parseOptionalKeyword("as"))) {
+ // Type _may_ differ from the operand.
+ if (failed(parser.parseType(type))) return failure();
+ } else {
+ // Use the operands type.
+ type = operandTypes[tiedOperandIndex];
+ }
+ } else if (failed(parser.parseType(type))) {
+ return failure();
+ }
+ if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ if (!shapedType.hasStaticShape()) {
+ SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
+ if (failed(parser.parseLBrace()) ||
+ failed(parser.parseOperandList(dynamicDims,
+ shapedType.getNumDynamicDims(),
+ OpAsmParser::Delimiter::None)) ||
+ failed(parser.parseRBrace())) {
+ return failure();
+ }
+ resultDims.append(dynamicDims);
+ }
+ } else if (auto sizedType =
+ type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+ OpAsmParser::OperandType size;
+ if (failed(parser.parseOperand(size))) {
+ return failure();
+ }
+ resultDims.push_back(size);
+ }
+ resultTypes.push_back(type);
+ tiedOperandIndices.push_back(tiedOperandIndex);
+ } while (succeeded(parser.parseOptionalComma()));
+ if (!tiedOperandIndices.empty()) {
+ tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices);
+ }
+ return success();
+}
+
+ParseResult parseShapedFunctionType(
+ OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
+ SmallVectorImpl<Type> &operandTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &operandDims,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+ ArrayAttr &tiedOperands) {
+ if (failed(parser.parseLParen())) return failure();
+ if (failed(parser.parseOptionalRParen())) {
+ if (failed(parseShapedOperandList(parser, operandTypes, operandDims)) ||
+ failed(parser.parseRParen())) {
+ return failure();
+ }
+ }
+ if (failed(parser.parseArrow())) return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (failed(parseShapedResultList(parser, operands, operandTypes,
+ operandDims, resultTypes, resultDims,
+ tiedOperands)) ||
+ failed(parser.parseRParen())) {
+ return failure();
+ }
+ } else {
+ if (failed(parseShapedResultList(parser, operands, operandTypes,
+ operandDims, resultTypes, resultDims,
+ tiedOperands))) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+void printShapedFunctionType(OpAsmPrinter &p, Operation *op,
+ ValueRange operands, TypeRange operandTypes,
+ OperandRange operandDims, TypeRange resultTypes,
+ OperandRange resultDims, ArrayAttr tiedOperands) {
+ p << "(";
+ llvm::interleaveComma(operandTypes, p, [&](Type type) {
+ p.printType(type);
+ if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ if (!shapedType.hasStaticShape()) {
+ if (operandDims.empty()) {
+ p << "{<<INVALID>>}";
+ return;
+ }
+ p << "{";
+ llvm::interleaveComma(
+ operandDims.take_front(shapedType.getNumDynamicDims()), p,
+ [&](Value value) { p.printOperand(value); });
+ p << "}";
+ operandDims = operandDims.drop_front(shapedType.getNumDynamicDims());
+ }
+ } else if (auto sizedType =
+ type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+ p << "{";
+ p.printOperand(operandDims.front());
+ p << "}";
+ operandDims = operandDims.drop_front(1);
+ }
+ });
+ p << ") -> ";
+ if (resultTypes.size() != 1) p << "(";
+ auto tiedOp = cast<IREE::Util::TiedOpInterface>(op);
+ for (unsigned i = 0; i < resultTypes.size(); ++i) {
+ auto resultType = resultTypes[i];
+ auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i);
+ bool printType = true;
+ if (tiedOperandIndex.hasValue()) {
+ auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
+ p.printOperand(tiedOperand);
+ if (tiedOperand.getType() != resultType) {
+ p << " as ";
+ } else {
+ // Type elided as it matches the operand.
+ printType = false;
+ }
+ }
+ if (printType) {
+ p.printType(resultType);
+ }
+ if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+ if (!shapedType.hasStaticShape()) {
+ if (resultDims.empty()) {
+ p << "{<<INVALID>>}";
+ return;
+ }
+ p << "{";
+ llvm::interleaveComma(
+ resultDims.take_front(shapedType.getNumDynamicDims()), p,
+ [&](Value value) { p.printOperand(value); });
+ p << "}";
+ resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
+ }
+ } else if (auto sizedType =
+ resultType.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+ p << "{";
+ p.printOperand(resultDims.front());
+ p << "}";
+ resultDims = resultDims.drop_front(1);
+ }
+ if (i < resultTypes.size() - 1) p << ", ";
+ }
+ if (resultTypes.size() != 1) p << ")";
+}
+
+namespace IREE {
+namespace Util {
+
+//===----------------------------------------------------------------------===//
// util.do_not_optimize
//===----------------------------------------------------------------------===//
@@ -222,7 +579,9 @@
return succeeded(mlir::verifyCompatibleShape(globalType, accessType));
}
- // TODO(benvanik): use GlobalOpInterface.
+ if (auto knownType = globalType.dyn_cast<GlobalTypeInterface>()) {
+ return knownType.isAccessStorageCompatible(accessType);
+ }
// Otherwise, the types must be the same.
return globalType == accessType;
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.h b/iree/compiler/Dialect/Util/IR/UtilOps.h
index f0453fa..55cdde8 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOps.h
+++ b/iree/compiler/Dialect/Util/IR/UtilOps.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_DIALECT_IREE_IR_IREEOPS_H_
-#define IREE_COMPILER_DIALECT_IREE_IR_IREEOPS_H_
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_UTILOPS_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_UTILOPS_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -20,4 +20,87 @@
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Util/IR/UtilOps.h.inc" // IWYU pragma: export
-#endif // IREE_COMPILER_DIALECT_IREE_IR_IREEOPS_H_
+namespace mlir {
+namespace iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// custom<SymbolVisibility>($sym_visibility)
+//===----------------------------------------------------------------------===//
+// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
+// ->
+// some.op @foo
+// some.op private @foo
+
+ParseResult parseSymbolVisibility(OpAsmParser &parser,
+ StringAttr &symVisibilityAttr);
+void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
+ StringAttr symVisibilityAttr);
+
+//===----------------------------------------------------------------------===//
+// custom<TypeOrAttr>($type, $attr)
+//===----------------------------------------------------------------------===//
+// some.op custom<TypeOrAttr>($type, $attr)
+// ->
+// some.op : i32
+// some.op = 42 : i32
+// some.op : i32 = 42 : index
+
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &attr);
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+ Attribute attr);
+
+//===----------------------------------------------------------------------===//
+// custom<SizeAwareType>
+//===----------------------------------------------------------------------===//
+// type{%size}
+
+ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type,
+ OpAsmParser::OperandType &size);
+void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, Value size);
+
+//===----------------------------------------------------------------------===//
+// custom<SizeAwareTypeList>
+//===----------------------------------------------------------------------===//
+// (type{%size0}, type, type{%size1})
+
+ParseResult parseSizeAwareTypeList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &types,
+ SmallVectorImpl<OpAsmParser::OperandType> &sizes);
+void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op, TypeRange types,
+ OperandRange sizes);
+
+//===----------------------------------------------------------------------===//
+// custom<ShapedTiedResult>
+//===----------------------------------------------------------------------===//
+// type{%dim0, %dim1}
+// %arg0 as type{%dim0}
+
+ParseResult parseShapedTiedResult(
+ OpAsmParser &parser, Type &resultType,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+ ArrayAttr &tiedOperands);
+void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
+ ValueRange resultDims, ArrayAttr tiedOperands);
+
+//===----------------------------------------------------------------------===//
+// custom<ShapedFunctionType>
+//===----------------------------------------------------------------------===//
+// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)
+
+ParseResult parseShapedFunctionType(
+ OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
+ SmallVectorImpl<Type> &operandTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &operandDims,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+ ArrayAttr &tiedOperands);
+void printShapedFunctionType(OpAsmPrinter &p, Operation *op,
+ ValueRange operands, TypeRange operandTypes,
+ OperandRange operandDims, TypeRange resultTypes,
+ OperandRange resultDims, ArrayAttr tiedOperands);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_UTIL_IR_UTILOPS_H_
diff --git a/iree/compiler/Dialect/Util/IR/UtilTraits.h b/iree/compiler/Dialect/Util/IR/UtilTraits.h
index 3695e4e..e610220 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTraits.h
+++ b/iree/compiler/Dialect/Util/IR/UtilTraits.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_DIALECT_IREE_IR_IREETRAITS_H_
-#define IREE_COMPILER_DIALECT_IREE_IR_IREETRAITS_H_
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_UTILTRAITS_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_UTILTRAITS_H_
#include "mlir/IR/OpDefinition.h"
@@ -37,4 +37,4 @@
} // namespace OpTrait
} // namespace mlir
-#endif // IREE_COMPILER_DIALECT_IREE_IR_IREETRAITS_H_
+#endif // IREE_COMPILER_DIALECT_UTIL_IR_UTILTRAITS_H_
diff --git a/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
index cd2534b..60ad827 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
@@ -181,6 +181,24 @@
return baseValue;
}
+bool detail::isOperandTied(Operation *op, unsigned operandIndex) {
+ auto storageAttr =
+ op->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName());
+ if (!storageAttr) return false;
+ auto valueAttrs = storageAttr.getValue();
+ if (valueAttrs.empty()) return false;
+ auto tiedOp = cast<TiedOpInterface>(op);
+ unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first;
+ for (unsigned i = 0; i < valueAttrs.size(); ++i) {
+ int64_t index = valueAttrs[i].cast<IntegerAttr>().getInt();
+ index = index != TiedOpInterface::kUntiedIndex
+ ? tiedOperandsOffset + index
+ : TiedOpInterface::kUntiedIndex;
+ if (index == operandIndex) return true;
+ }
+ return false;
+}
+
LogicalResult detail::verifyTiedOp(TiedOpInterface tiedOp) {
auto storageAttr =
tiedOp->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName());
@@ -242,13 +260,15 @@
}
}
-// At the end so it can use functions above:
-#include "iree/compiler/Dialect/Util/IR/UtilOpInterfaces.cpp.inc"
-
//===----------------------------------------------------------------------===//
// IREE::Util::UtilDialect
//===----------------------------------------------------------------------===//
+// At the end so it can use functions above:
+#include "iree/compiler/Dialect/Util/IR/UtilAttrInterfaces.cpp.inc"
+#include "iree/compiler/Dialect/Util/IR/UtilOpInterfaces.cpp.inc"
+#include "iree/compiler/Dialect/Util/IR/UtilTypeInterfaces.cpp.inc"
+
void UtilDialect::registerTypes() {
addTypes<IREE::Util::ByteBufferType, IREE::Util::ListType,
IREE::Util::MutableByteBufferType, IREE::Util::PtrType,
diff --git a/iree/compiler/Dialect/Util/IR/UtilTypes.h b/iree/compiler/Dialect/Util/IR/UtilTypes.h
index c051a65..6d2cac3 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTypes.h
+++ b/iree/compiler/Dialect/Util/IR/UtilTypes.h
@@ -4,15 +4,18 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_DIALECT_IREE_IR_IREETYPES_H_
-#define IREE_COMPILER_DIALECT_IREE_IR_IREETYPES_H_
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_UTILTYPES_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_UTILTYPES_H_
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
namespace mlir {
@@ -52,6 +55,22 @@
DoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20
};
+struct ValueAccess {
+ bool isRead : 1;
+ bool isWrite : 1;
+ bool isDiscard : 1;
+ bool isNone() const { return !isRead && !isWrite && !isDiscard; }
+ bool isReadOnly() const { return isRead && !isWrite && !isDiscard; }
+ ValueAccess() : isRead(false), isWrite(false), isDiscard(false) {}
+ ValueAccess(bool isRead, bool isWrite, bool isDiscard)
+ : isRead(isRead), isWrite(isWrite), isDiscard(isDiscard) {}
+ static ValueAccess None() { return ValueAccess(false, false, false); }
+ static ValueAccess ReadOnly() { return ValueAccess(true, false, false); }
+ static ValueAccess ReadWrite() { return ValueAccess(true, true, false); }
+ static ValueAccess WriteOnly() { return ValueAccess(false, true, false); }
+ static ValueAccess DiscardWrite() { return ValueAccess(false, true, true); }
+};
+
/// Placeholder for a variant type (`?`).
class VariantType : public Type::TypeBase<VariantType, Type, TypeStorage> {
public:
@@ -133,6 +152,7 @@
void setTiedResultOperandIndex(Operation *op, unsigned resultIndex,
llvm::Optional<unsigned> operandIndex);
SmallVector<int64_t, 4> getTiedResultOperandIndices(Operation *op);
+bool isOperandTied(Operation *tiedOp, unsigned operandIndex);
LogicalResult verifyTiedOp(TiedOpInterface tiedOp);
} // namespace detail
@@ -148,6 +168,8 @@
} // namespace iree_compiler
} // namespace mlir
+#include "iree/compiler/Dialect/Util/IR/UtilAttrInterfaces.h.inc" // IWYU pragma: export
#include "iree/compiler/Dialect/Util/IR/UtilOpInterfaces.h.inc" // IWYU pragma: export
+#include "iree/compiler/Dialect/Util/IR/UtilTypeInterfaces.h.inc" // IWYU pragma: export
-#endif // IREE_COMPILER_DIALECT_IREE_IR_IREETYPES_H_
+#endif // IREE_COMPILER_DIALECT_UTIL_IR_UTILTYPES_H_
diff --git a/iree/compiler/Dialect/Util/IR/test/global_ops.mlir b/iree/compiler/Dialect/Util/IR/test/global_ops.mlir
index b38d3ff..822c131 100644
--- a/iree/compiler/Dialect/Util/IR/test/global_ops.mlir
+++ b/iree/compiler/Dialect/Util/IR/test/global_ops.mlir
@@ -19,6 +19,12 @@
// CHECK: util.global public @v_initialized_const3 = dense<4> : tensor<4xi32>
util.global public @v_initialized_const3 = dense<4> : tensor<4xi32>
+// CHECK: util.global public @v_initialized_const4 = dense<4> : tensor<4xi32>
+util.global public @v_initialized_const4 : tensor<4xi32> = dense<4> : tensor<4xi32>
+
+// CHECK: util.global public @v_initialized_const5 : tensor<4xf32> = dense<4> : tensor<4xi32>
+util.global public @v_initialized_const5 : tensor<4xf32> = dense<4> : tensor<4xi32>
+
// -----
// CHECK: util.global private @v_initialized initializer(@initializer) : tensor<4xi32>
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 9b064d8..92c5ae6 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
@@ -27,70 +28,6 @@
namespace VM {
//===----------------------------------------------------------------------===//
-// custom<SymbolVisibility>($sym_visibility)
-//===----------------------------------------------------------------------===//
-// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
-// ->
-// some.op @foo
-// some.op private @foo
-
-static ParseResult parseSymbolVisibility(OpAsmParser &parser,
- StringAttr &symVisibilityAttr) {
- StringRef symVisibility;
- parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"});
- if (!symVisibility.empty()) {
- symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
- }
- return success();
-}
-
-static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
- StringAttr symVisibilityAttr) {
- if (!symVisibilityAttr) {
- p << "public";
- } else {
- p << symVisibilityAttr.getValue();
- }
-}
-
-//===----------------------------------------------------------------------===//
-// custom<TypeOrAttr>($type, $attr)
-//===----------------------------------------------------------------------===//
-// some.op custom<TypeOrAttr>($type, $attr)
-// ->
-// some.op : i32
-// some.op = 42 : i32
-
-static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
- Attribute &attr) {
- if (succeeded(parser.parseOptionalEqual())) {
- if (failed(parser.parseAttribute(attr))) {
- return parser.emitError(parser.getCurrentLocation())
- << "expected attribute";
- }
- typeAttr = TypeAttr::get(attr.getType());
- } else {
- Type type;
- if (failed(parser.parseColonType(type))) {
- return parser.emitError(parser.getCurrentLocation()) << "expected type";
- }
- typeAttr = TypeAttr::get(type);
- }
- return success();
-}
-
-static void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
- Attribute attr) {
- if (attr) {
- p << " = ";
- p.printAttribute(attr);
- } else {
- p << " : ";
- p.printAttribute(type);
- }
-}
-
-//===----------------------------------------------------------------------===//
// Structural ops
//===----------------------------------------------------------------------===//