Change tied operand formatting to allow type changes. (#6631)

The ops have always held the types/dims but they were elided when printing.
Now type and dimension changes can be fully specified and round-tripped.

Reverts #6188.
Fixes #6075.
Fixes #6185.
Fixes #6420.
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index ce2f8cb..474bb07 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -230,8 +230,7 @@
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<ClosureOptimizationPattern<ExStreamFragmentOp>>(context);
   results.insert<InsertImmutabilityPreservingStreamClones>(context);
-  // TODO(#6185): fix stream ties when types/shapes change.
-  // results.insert<TieStreamResults>(context);
+  results.insert<TieStreamResults>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index ee5fd47..0dc1f1d 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -142,16 +142,10 @@
   return success();
 }
 
-// Ties the |tiedResult| parsed operand back to a previously parsed operand.
-// The type and any dynamic dimensions of the operand will be used for the
-// result values and the operand index will be appended to |tiedOperandIndices|.
-static ParseResult tieOperand(
-    OpAsmParser::OperandType tiedResult, OpAsmParser &parser,
-    ArrayRef<OpAsmParser::OperandType> operands, TypeRange operandTypes,
-    ArrayRef<OpAsmParser::OperandType> operandDims,
-    SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
-    SmallVectorImpl<int64_t> &tiedOperandIndices) {
+// Finds the operand index in |operands| that |tiedResult| references.
+// Returns TiedOpInterface::kUntiedIndex if no operand is found.
+static int64_t findTiedOperand(OpAsmParser::OperandType tiedResult,
+                               ArrayRef<OpAsmParser::OperandType> operands) {
   int64_t operandIndex = TiedOpInterface::kUntiedIndex;
   for (int64_t i = 0; i < operands.size(); ++i) {
     if (operands[i].name == tiedResult.name) {
@@ -159,29 +153,7 @@
       break;
     }
   }
-  if (operandIndex == TiedOpInterface::kUntiedIndex) {
-    return parser.emitError(tiedResult.location,
-                            "tied operand not found for result reference ")
-           << tiedResult.name;
-  }
-
-  auto resultType = operandTypes[operandIndex];
-  resultTypes.push_back(resultType);
-  tiedOperandIndices.push_back(operandIndex);
-
-  auto shapedType = resultType.dyn_cast<ShapedType>();
-  if (shapedType) {
-    unsigned dimsIndex = 0;
-    for (unsigned i = 0; i < operandIndex; ++i) {
-      if (auto shapedType = operandTypes[i].dyn_cast<ShapedType>()) {
-        dimsIndex += shapedType.getNumDynamicDims();
-      }
-    }
-    resultDims.append(llvm::to_vector<4>(
-        operandDims.slice(dimsIndex, shapedType.getNumDynamicDims())));
-  }
-
-  return success();
+  return operandIndex;
 }
 
 static ParseResult parseShapedResultList(
@@ -194,31 +166,40 @@
   do {
     OpAsmParser::OperandType tiedResult;
     auto res = parser.parseOptionalOperand(tiedResult);
+    Type type;
+    int64_t tiedOperandIndex = TiedOpInterface::kUntiedIndex;
     if (res.hasValue() && succeeded(res.getValue())) {
-      if (failed(tieOperand(tiedResult, parser, operands, operandTypes,
-                            operandDims, resultTypes, resultDims,
-                            tiedOperandIndices))) {
-        return failure();
+      tiedOperandIndex = findTiedOperand(tiedResult, operands);
+      if (tiedOperandIndex == TiedOpInterface::kUntiedIndex) {
+        return parser.emitError(tiedResult.location,
+                                "tied operand not found for result reference ")
+               << tiedResult.name;
       }
-    } else {
-      Type type;
-      if (failed(parser.parseType(type))) return failure();
-      if (auto shapedType = type.dyn_cast<ShapedType>()) {
-        if (!shapedType.hasStaticShape()) {
-          SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
-          if (failed(parser.parseLBrace()) ||
-              failed(parser.parseOperandList(dynamicDims,
-                                             shapedType.getNumDynamicDims(),
-                                             OpAsmParser::Delimiter::None)) ||
-              failed(parser.parseRBrace())) {
-            return failure();
-          }
-          resultDims.append(dynamicDims);
-        }
+      if (succeeded(parser.parseOptionalKeyword("as"))) {
+        // Type _may_ differ from the operand.
+        if (failed(parser.parseType(type))) return failure();
+      } else {
+        // Use the operand's type.
+        type = operandTypes[tiedOperandIndex];
       }
-      resultTypes.push_back(type);
-      tiedOperandIndices.push_back(TiedOpInterface::kUntiedIndex);
+    } else if (failed(parser.parseType(type))) {
+      return failure();
     }
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (!shapedType.hasStaticShape()) {
+        SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
+        if (failed(parser.parseLBrace()) ||
+            failed(parser.parseOperandList(dynamicDims,
+                                           shapedType.getNumDynamicDims(),
+                                           OpAsmParser::Delimiter::None)) ||
+            failed(parser.parseRBrace())) {
+          return failure();
+        }
+        resultDims.append(dynamicDims);
+      }
+    }
+    resultTypes.push_back(type);
+    tiedOperandIndices.push_back(tiedOperandIndex);
   } while (succeeded(parser.parseOptionalComma()));
   if (!tiedOperandIndices.empty()) {
     tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices);
@@ -286,25 +267,34 @@
   if (resultTypes.size() != 1) p << "(";
   auto tiedOp = cast<TiedOpInterface>(op);
   for (unsigned i = 0; i < resultTypes.size(); ++i) {
-    auto tiedOperand = tiedOp.getTiedResultOperandIndex(i);
-    if (tiedOperand.hasValue()) {
-      p.printOperand(op->getOperand(tiedOperand.getValue()));
-    } else {
-      auto type = resultTypes[i];
-      p.printType(type);
-      if (auto shapedType = type.dyn_cast<ShapedType>()) {
-        if (!shapedType.hasStaticShape()) {
-          if (resultDims.empty()) {
-            p << "{<<INVALID>>}";
-            return;
-          }
-          p << "{";
-          llvm::interleaveComma(
-              resultDims.take_front(shapedType.getNumDynamicDims()), p,
-              [&](Value value) { p.printOperand(value); });
-          p << "}";
-          resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
+    auto resultType = resultTypes[i];
+    auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i);
+    bool printType = true;
+    if (tiedOperandIndex.hasValue()) {
+      auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
+      p.printOperand(tiedOperand);
+      if (tiedOperand.getType() != resultType) {
+        p << " as ";
+      } else {
+        // Type elided as it matches the operand.
+        printType = false;
+      }
+    }
+    if (printType) {
+      p.printType(resultType);
+    }
+    if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+      if (!shapedType.hasStaticShape()) {
+        if (resultDims.empty()) {
+          p << "{<<INVALID>>}";
+          return;
         }
+        p << "{";
+        llvm::interleaveComma(
+            resultDims.take_front(shapedType.getNumDynamicDims()), p,
+            [&](Value value) { p.printOperand(value); });
+        p << "}";
+        resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
       }
     }
     if (i < resultTypes.size() - 1) p << ", ";
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
index 7cb22d0..298bb75 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_ops.mlir
@@ -39,7 +39,21 @@
   %dim0 = constant 100 : index
   // CHECK-DAG: %[[DIM1:.+]] = constant 200
   %dim1 = constant 200 : index
-  // CHECK: %0:2 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0, %arg1) : (index, tensor<4x?xf32>{%[[DIM0]]}, tensor<8x?xf32>{%[[DIM1]]}) -> (%arg0, %arg1)
-  %0, %1 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0, %arg1) : (index, tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0, %arg1)
+  // CHECK: %0:2 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[CST]], %arg0, %arg1) : (index, tensor<4x?xf32>{%[[DIM0]]}, tensor<8x?xf32>{%[[DIM1]]}) -> (%arg0{%[[DIM1]]}, %arg1{%[[DIM0]]})
+  %0, %1 = flow.dispatch @ex0::@dispatch_fn[%cst](%cst, %arg0, %arg1) : (index, tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0{%dim1}, %arg1{%dim0})
   return %0, %1 : tensor<4x?xf32>, tensor<8x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @inplaceTypeChange
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x?xf32>)
+func @inplaceTypeChange(%arg0: tensor<4x?xf32>) -> tensor<?x4xf32> {
+  // CHECK-DAG: %[[CST:.+]] = constant 4
+  %cst = constant 4 : index
+  // CHECK-DAG: %[[DIM0:.+]] = constant 100
+  %dim0 = constant 100 : index
+  // CHECK: %0 = flow.dispatch @ex0::@dispatch_fn[%[[CST]]](%[[ARG0]]) : (tensor<4x?xf32>{%[[DIM0]]}) -> %arg0 as tensor<?x4xf32>{%[[DIM0]]}
+  %0 = flow.dispatch @ex0::@dispatch_fn[%cst](%arg0) : (tensor<4x?xf32>{%dim0}) -> %arg0 as tensor<?x4xf32>{%dim0}
+  return %0 : tensor<?x4xf32>
+}
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
index f5b997d..9231857 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups.mlir
@@ -99,8 +99,8 @@
   // CHECK: %[[OUTER_RET0:.+]] = flow.dispatch.workgroups[
   // CHECK-SAME: %[[WORKGROUP_COUNT_X]], %[[WORKGROUP_COUNT_Y]]
   // CHECK-SAME: ](%[[ARG0]], %[[ARG1]])
-  // CHECK-SAME: : (tensor<?x4xf32>{%c128}, index) -> %arg0 =
-  %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor<?x4xf32>{%c128}, index) -> %arg0 =
+  // CHECK-SAME: : (tensor<?x4xf32>{%c128}, index) -> %arg0{%c128} =
+  %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor<?x4xf32>{%c128}, index) -> %arg0{%c128} =
   // CHECK-NEXT: (%[[INNER_ARG0:.+]]: !flow.dispatch.tensor<readwrite:?x4xf32>
   // CHECK-SAME:  %[[INNER_ARG1:.+]]: index) {
   (%arg0_capture: !flow.dispatch.tensor<readwrite:?x4xf32>, %arg1_capture: index) {
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
index bac3cac..b357d60 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
@@ -68,8 +68,8 @@
                                 %arg1: tensor<8x?xf32>, %dim1: index) -> tensor<8x?xf32> {
   // CHECK: flow.ex.stream.fragment(%[[ARG1]]) :
   %0:2 = flow.ex.stream.fragment(%arg0, %arg1) :
-      // CHECK-SAME: (tensor<8x?xf32>{%[[DIM1]]}) -> %[[ARG1]] =
-      (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0, %arg1) =
+      // CHECK-SAME: (tensor<8x?xf32>{%[[DIM1]]}) -> %[[ARG1]]{%[[DIM1]]} =
+      (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0{%dim0}, %arg1{%dim1}) =
       // CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<8x?xf32>) -> tensor<8x?xf32>
       (%unused: tensor<4x?xf32>, %arg1: tensor<8x?xf32>) -> (tensor<4x?xf32>, tensor<8x?xf32>) {
     // CHECK-NEXT: flow.return %[[INNER_ARG]] : tensor<8x?xf32>
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
index 1dab1a9..f7323f5 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_ops.mlir
@@ -1,5 +1,3 @@
-// Tests printing and parsing of stream ops.
-
 // RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
 
 flow.executable @dispatch_0 {
@@ -29,3 +27,23 @@
   // CHECK-NEXT: return
   return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @typeChange
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+func @typeChange(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index) -> (tensor<4x?xf32>) {
+  //      CHECK: %[[RET:.+]] = flow.ex.stream.fragment(%[[ARG0]], %[[DIM0]], %[[DIM1]]) :
+  // CHECK-SAME:     (tensor<?x?xf32>{%[[DIM0]], %[[DIM1]]}, index, index) -> %[[ARG0]] as tensor<4x?xf32>{%[[DIM1]]} =
+  // CHECK-NEXT: (%[[STREAM_ARG0:.+]]: tensor<?x?xf32>, %[[STREAM_DIM0:.+]]: index, %[[STREAM_DIM1:.+]]: index) -> tensor<4x?xf32> {
+  %0 = flow.ex.stream.fragment(%arg0, %dim0, %dim1) : (tensor<?x?xf32>{%dim0, %dim1}, index, index) -> %arg0 as tensor<4x?xf32>{%dim1} =
+      (%stream_arg0: tensor<?x?xf32>, %stream_dim0: index, %stream_dim1: index) -> tensor<4x?xf32> {
+    // CHECK-NEXT: %[[STREAM_RET:.+]] = flow.tensor.reshape %[[STREAM_ARG0:.+]] : tensor<?x?xf32>{%[[STREAM_DIM0]], %[[STREAM_DIM1]]} -> tensor<4x?xf32>{%[[STREAM_DIM1]]}
+    %1 = flow.tensor.reshape %stream_arg0 : tensor<?x?xf32>{%stream_dim0, %stream_dim1} -> tensor<4x?xf32>{%stream_dim1}
+    // CHECK-NEXT: flow.return %[[STREAM_RET]] : tensor<4x?xf32>
+    flow.return %1 : tensor<4x?xf32>
+    // CHECK-NEXT: }
+  }
+  // CHECK-NEXT: return %[[RET]] : tensor<4x?xf32>
+  return %0 : tensor<4x?xf32>
+}