Change tied operand formatting to allow type changes. (#6631)
The ops have always held the types/dims but they were elided when printing.
Now type and dimension changes can be fully specified and round-tripped.
Reverts #6188.
Fixes #6075.
Fixes #6185.
Fixes #6420.
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index ce2f8cb..474bb07 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -230,8 +230,7 @@
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ClosureOptimizationPattern<ExStreamFragmentOp>>(context);
results.insert<InsertImmutabilityPreservingStreamClones>(context);
- // TODO(#6185): fix stream ties when types/shapes change.
- // results.insert<TieStreamResults>(context);
+ results.insert<TieStreamResults>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index ee5fd47..0dc1f1d 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -142,16 +142,10 @@
return success();
}
-// Ties the |tiedResult| parsed operand back to a previously parsed operand.
-// The type and any dynamic dimensions of the operand will be used for the
-// result values and the operand index will be appended to |tiedOperandIndices|.
-static ParseResult tieOperand(
- OpAsmParser::OperandType tiedResult, OpAsmParser &parser,
- ArrayRef<OpAsmParser::OperandType> operands, TypeRange operandTypes,
- ArrayRef<OpAsmParser::OperandType> operandDims,
- SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
- SmallVectorImpl<int64_t> &tiedOperandIndices) {
+// Finds the operand index in |operands| that |tiedResult| references.
+// Returns TiedOpInterface::kUntiedIndex if no operand is found.
+static int64_t findTiedOperand(OpAsmParser::OperandType tiedResult,
+ ArrayRef<OpAsmParser::OperandType> operands) {
int64_t operandIndex = TiedOpInterface::kUntiedIndex;
for (int64_t i = 0; i < operands.size(); ++i) {
if (operands[i].name == tiedResult.name) {
@@ -159,29 +153,7 @@
break;
}
}
- if (operandIndex == TiedOpInterface::kUntiedIndex) {
- return parser.emitError(tiedResult.location,
- "tied operand not found for result reference ")
- << tiedResult.name;
- }
-
- auto resultType = operandTypes[operandIndex];
- resultTypes.push_back(resultType);
- tiedOperandIndices.push_back(operandIndex);
-
- auto shapedType = resultType.dyn_cast<ShapedType>();
- if (shapedType) {
- unsigned dimsIndex = 0;
- for (unsigned i = 0; i < operandIndex; ++i) {
- if (auto shapedType = operandTypes[i].dyn_cast<ShapedType>()) {
- dimsIndex += shapedType.getNumDynamicDims();
- }
- }
- resultDims.append(llvm::to_vector<4>(
- operandDims.slice(dimsIndex, shapedType.getNumDynamicDims())));
- }
-
- return success();
+ return operandIndex;
}
static ParseResult parseShapedResultList(
@@ -194,31 +166,40 @@
do {
OpAsmParser::OperandType tiedResult;
auto res = parser.parseOptionalOperand(tiedResult);
+ Type type;
+ int64_t tiedOperandIndex = TiedOpInterface::kUntiedIndex;
if (res.hasValue() && succeeded(res.getValue())) {
- if (failed(tieOperand(tiedResult, parser, operands, operandTypes,
- operandDims, resultTypes, resultDims,
- tiedOperandIndices))) {
- return failure();
+ tiedOperandIndex = findTiedOperand(tiedResult, operands);
+ if (tiedOperandIndex == TiedOpInterface::kUntiedIndex) {
+ return parser.emitError(tiedResult.location,
+ "tied operand not found for result reference ")
+ << tiedResult.name;
}
- } else {
- Type type;
- if (failed(parser.parseType(type))) return failure();
- if (auto shapedType = type.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
- if (failed(parser.parseLBrace()) ||
- failed(parser.parseOperandList(dynamicDims,
- shapedType.getNumDynamicDims(),
- OpAsmParser::Delimiter::None)) ||
- failed(parser.parseRBrace())) {
- return failure();
- }
- resultDims.append(dynamicDims);
- }
+ if (succeeded(parser.parseOptionalKeyword("as"))) {
+ // Type _may_ differ from the operand.
+ if (failed(parser.parseType(type))) return failure();
+ } else {
+      // Use the operand's type.
+ type = operandTypes[tiedOperandIndex];
}
- resultTypes.push_back(type);
- tiedOperandIndices.push_back(TiedOpInterface::kUntiedIndex);
+ } else if (failed(parser.parseType(type))) {
+ return failure();
}
+ if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ if (!shapedType.hasStaticShape()) {
+ SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
+ if (failed(parser.parseLBrace()) ||
+ failed(parser.parseOperandList(dynamicDims,
+ shapedType.getNumDynamicDims(),
+ OpAsmParser::Delimiter::None)) ||
+ failed(parser.parseRBrace())) {
+ return failure();
+ }
+ resultDims.append(dynamicDims);
+ }
+ }
+ resultTypes.push_back(type);
+ tiedOperandIndices.push_back(tiedOperandIndex);
} while (succeeded(parser.parseOptionalComma()));
if (!tiedOperandIndices.empty()) {
tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices);
@@ -286,25 +267,34 @@
if (resultTypes.size() != 1) p << "(";
auto tiedOp = cast<TiedOpInterface>(op);
for (unsigned i = 0; i < resultTypes.size(); ++i) {
- auto tiedOperand = tiedOp.getTiedResultOperandIndex(i);
- if (tiedOperand.hasValue()) {
- p.printOperand(op->getOperand(tiedOperand.getValue()));
- } else {
- auto type = resultTypes[i];
- p.printType(type);
- if (auto shapedType = type.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) {
- if (resultDims.empty()) {
- p << "{<<INVALID>>}";
- return;
- }
- p << "{";
- llvm::interleaveComma(
- resultDims.take_front(shapedType.getNumDynamicDims()), p,
- [&](Value value) { p.printOperand(value); });
- p << "}";
- resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
+ auto resultType = resultTypes[i];
+ auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i);
+ bool printType = true;
+ if (tiedOperandIndex.hasValue()) {
+ auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
+ p.printOperand(tiedOperand);
+ if (tiedOperand.getType() != resultType) {
+ p << " as ";
+ } else {
+ // Type elided as it matches the operand.
+ printType = false;
+ }
+ }
+ if (printType) {
+ p.printType(resultType);
+ }
+ if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+ if (!shapedType.hasStaticShape()) {
+ if (resultDims.empty()) {
+ p << "{<<INVALID>>}";
+ return;
}
+ p << "{";
+ llvm::interleaveComma(
+ resultDims.take_front(shapedType.getNumDynamicDims()), p,
+ [&](Value value) { p.printOperand(value); });
+ p << "}";
+ resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
}
}
if (i < resultTypes.size() - 1) p << ", ";
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
index 7cb22d0..298bb75 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
@@ -39,7 +39,21 @@
%dim0 = constant 100 : index
// CHECK-DAG: %[[DIM1:.+]] = constant 200
%dim1 = constant 200 : index
- // CHECK: %0:2 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0, %arg1) : (index, tensor<4x?xf32>{%[[DIM0]]}, tensor<8x?xf32>{%[[DIM1]]}) -> (%arg0, %arg1)
- %0, %1 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0, %arg1) : (index, tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0, %arg1)
+ // CHECK: %0:2 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0, %arg1) : (index, tensor<4x?xf32>{%[[DIM0]]}, tensor<8x?xf32>{%[[DIM1]]}) -> (%arg0{%[[DIM1]]}, %arg1{%[[DIM0]]})
+ %0, %1 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0, %arg1) : (index, tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0{%dim1}, %arg1{%dim0})
return %0, %1 : tensor<4x?xf32>, tensor<8x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: @inplaceTypeChange
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x?xf32>)
+func @inplaceTypeChange(%arg0: tensor<4x?xf32>) -> tensor<?x4xf32> {
+ // CHECK-DAG: %[[CST:.+]] = constant 4
+ %cst = constant 4 : index
+ // CHECK-DAG: %[[DIM0:.+]] = constant 100
+ %dim0 = constant 100 : index
+ // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[ARG0]]) : (tensor<4x?xf32>{%[[DIM0]]}) -> %arg0 as tensor<?x4xf32>{%[[DIM0]]}
+ %0 = flow.dispatch @ex0::@dispatch_fn[%cst](%arg0) : (tensor<4x?xf32>{%dim0}) -> %arg0 as tensor<?x4xf32>{%dim0}
+ return %0 : tensor<?x4xf32>
+}
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
index f5b997d..9231857 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
@@ -99,8 +99,8 @@
// CHECK: %[[OUTER_RET0:.+]] = flow.dispatch.workgroups[
// CHECK-SAME: %[[WORKGROUP_COUNT_X]], %[[WORKGROUP_COUNT_Y]]
// CHECK-SAME: ](%[[ARG0]], %[[ARG1]])
- // CHECK-SAME: : (tensor<?x4xf32>{%c128}, index) -> %arg0 =
- %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor<?x4xf32>{%c128}, index) -> %arg0 =
+ // CHECK-SAME: : (tensor<?x4xf32>{%c128}, index) -> %arg0{%c128} =
+ %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor<?x4xf32>{%c128}, index) -> %arg0{%c128} =
// CHECK-NEXT: (%[[INNER_ARG0:.+]]: !flow.dispatch.tensor<readwrite:?x4xf32>
// CHECK-SAME: %[[INNER_ARG1:.+]]: index) {
(%arg0_capture: !flow.dispatch.tensor<readwrite:?x4xf32>, %arg1_capture: index) {
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
index bac3cac..b357d60 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
@@ -68,8 +68,8 @@
%arg1: tensor<8x?xf32>, %dim1: index) -> tensor<8x?xf32> {
// CHECK: flow.ex.stream.fragment(%[[ARG1]]) :
%0:2 = flow.ex.stream.fragment(%arg0, %arg1) :
- // CHECK-SAME: (tensor<8x?xf32>{%[[DIM1]]}) -> %[[ARG1]] =
- (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0, %arg1) =
+ // CHECK-SAME: (tensor<8x?xf32>{%[[DIM1]]}) -> %[[ARG1]]{%[[DIM1]]} =
+ (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0{%dim0}, %arg1{%dim1}) =
// CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<8x?xf32>) -> tensor<8x?xf32>
(%unused: tensor<4x?xf32>, %arg1: tensor<8x?xf32>) -> (tensor<4x?xf32>, tensor<8x?xf32>) {
// CHECK-NEXT: flow.return %[[INNER_ARG]] : tensor<8x?xf32>
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
index 1dab1a9..f7323f5 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
@@ -1,5 +1,3 @@
-// Tests printing and parsing of stream ops.
-
// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
flow.executable @dispatch_0 {
@@ -29,3 +27,23 @@
// CHECK-NEXT: return
return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @typeChange
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+func @typeChange(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index) -> (tensor<4x?xf32>) {
+ // CHECK: %[[RET:.+]] = flow.ex.stream.fragment(%[[ARG0]], %[[DIM0]], %[[DIM1]]) :
+ // CHECK-SAME: (tensor<?x?xf32>{%[[DIM0]], %[[DIM1]]}, index, index) -> %[[ARG0]] as tensor<4x?xf32>{%[[DIM1]]} =
+ // CHECK-NEXT: (%[[STREAM_ARG0:.+]]: tensor<?x?xf32>, %[[STREAM_DIM0:.+]]: index, %[[STREAM_DIM1:.+]]: index) -> tensor<4x?xf32> {
+ %0 = flow.ex.stream.fragment(%arg0, %dim0, %dim1) : (tensor<?x?xf32>{%dim0, %dim1}, index, index) -> %arg0 as tensor<4x?xf32>{%dim1} =
+ (%stream_arg0: tensor<?x?xf32>, %stream_dim0: index, %stream_dim1: index) -> tensor<4x?xf32> {
+ // CHECK-NEXT: %[[STREAM_RET:.+]] = flow.tensor.reshape %[[STREAM_ARG0:.+]] : tensor<?x?xf32>{%[[STREAM_DIM0]], %[[STREAM_DIM1]]} -> tensor<4x?xf32>{%[[STREAM_DIM1]]}
+ %1 = flow.tensor.reshape %stream_arg0 : tensor<?x?xf32>{%stream_dim0, %stream_dim1} -> tensor<4x?xf32>{%stream_dim1}
+ // CHECK-NEXT: flow.return %[[STREAM_RET]] : tensor<4x?xf32>
+ flow.return %1 : tensor<4x?xf32>
+ // CHECK-NEXT: }
+ }
+ // CHECK-NEXT: return %[[RET]] : tensor<4x?xf32>
+ return %0 : tensor<4x?xf32>
+}