Cleaning up parsers/printers and generalizing tied operand/result handling. (#6785)

* Adding support for util.global type/initializer differences.
* Moving size-aware type interfaces and parsers/printers to util.
* Improving tied result printing.
* Allowing size-aware types in tied operand/result parsing.
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors.mlir
index d93da89..1b5ff14 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors.mlir
@@ -6,7 +6,7 @@
 // TODO(silvasean): Verify "type" handling.
 // I think when "type" is a partial type that flow will not model it correctly.
 
-// CHECK:       util.global private mutable [[V:@[a-zA-Z0-9$._-]+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK:       util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
 // CHECK:       func @f() -> (tensor<?xf32> {tf_saved_model.index_path = []})
 // CHECK-NEXT:    [[PTR:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
 // CHECK-NEXT:    [[T:%.+]] = util.global.load.indirect [[PTR]] : !util.ptr<tensor<?xf32>> -> tensor<?xf32>
@@ -26,9 +26,9 @@
 // CHECK-LABEL: module attributes {tf_saved_model.semantics}
 module attributes {tf_saved_model.semantics} {
 
-// CHECK:       util.global private mutable [[V:@[a-zA-Z0-9$._-]+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK:       util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
 // CHECK:       func @f(%arg0: tensor<?xf32> {tf_saved_model.index_path = [0]})
-// CHECK-NEXT:    [[PTR:%.+]] = util.global.address @__iree_flow_v : !util.ptr<tensor<?xf32>>
+// CHECK-NEXT:    [[PTR:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
 // CHECK-NEXT:    util.global.store.indirect %arg0, [[PTR]] : tensor<?xf32> -> !util.ptr<tensor<?xf32>>
 // CHECK-NEXT:    return
 
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors_complex.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors_complex.mlir
index 6e65a12..4758edd 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors_complex.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/lower_global_tensors_complex.mlir
@@ -5,7 +5,7 @@
 // CHECK-LABEL: module attributes {tf_saved_model.semantics}
 module attributes {tf_saved_model.semantics} {
 
-// CHECK:      util.global private mutable [[V:@.+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK:      util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
 // CHECK:      func @f(%arg0: tensor<?xf32> {tf_saved_model.index_path = [0]}) attributes {tf_saved_model.exported_names = ["f"]} {
 // CHECK-NEXT:   [[PTR:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
 // CHECK-NEXT:   br ^bb1([[PTR]] : !util.ptr<tensor<?xf32>>)
@@ -28,8 +28,8 @@
 // CHECK-LABEL: module attributes {tf_saved_model.semantics}
 module attributes {tf_saved_model.semantics} {
 
-// CHECK:      util.global private mutable [[V:@.+]] = dense<1.000000e+00> : tensor<1xf32>
-// CHECK:      util.global private mutable [[V1:@.+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK:      util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
+// CHECK:      util.global private mutable [[V1:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
 // CHECK:      func @f(%arg0: tensor<?xf32> {tf_saved_model.index_path = [0]}) -> (tensor<?xf32> {tf_saved_model.index_path = [0]}) attributes {tf_saved_model.exported_names = ["f"]} {
 // CHECK-NEXT:   [[PTR0:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
 // CHECK-NEXT:   [[PTR1:%.+]] = util.global.address [[V1]] : !util.ptr<tensor<?xf32>>
@@ -55,7 +55,7 @@
 // CHECK-LABEL: module attributes {tf_saved_model.semantics}
 module attributes {tf_saved_model.semantics} {
 
-// CHECK:      util.global private mutable [[V:@.+]] = dense<1.000000e+00> : tensor<1xf32>
+// CHECK:      util.global private mutable [[V:@.+]] : tensor<?xf32> = dense<1.000000e+00> : tensor<1xf32>
 // CHECK:      func @f(%arg0: tensor<?xf32> {tf_saved_model.index_path = [0]}) attributes {tf_saved_model.exported_names = ["f"]} {
 // CHECK-NEXT:   [[PTR:%.+]] = util.global.address [[V]] : !util.ptr<tensor<?xf32>>
 // CHECK-NEXT:   br ^bb1([[PTR]], [[PTR]], [[PTR]] : !util.ptr<tensor<?xf32>>, !util.ptr<tensor<?xf32>>, !util.ptr<tensor<?xf32>>)
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir
index 383624e..1c29bf4 100644
--- a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir
@@ -16,7 +16,7 @@
 //   CHECK-DAG:   %[[C4:.+]] = constant 4
 //   CHECK-DAG:   %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //       CHECK:   %[[UPDATE:.+]] = flow.tensor.update %[[ARG1]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
-//  CHECK-SAME:     : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM0]]}
+//  CHECK-SAME:     : tensor<1x4x48xf32> -> %[[ARG0]] as tensor<?x24x48xf32>{%[[DIM0]]}
 
 // -----
 
@@ -37,7 +37,7 @@
 //   CHECK-DAG:   %[[RESHAPE:.+]] = flow.tensor.reshape %[[ARG1]] : tensor<4x48xf32> -> tensor<1x4x48xf32>
 //   CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //       CHECK:   %[[UPDATE:.+]] = flow.tensor.update %[[RESHAPE]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
-//  CHECK-SAME:     : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM]]}
+//  CHECK-SAME:     : tensor<1x4x48xf32> -> %[[ARG0]] as tensor<?x24x48xf32>{%[[DIM]]}
 
 // -----
 
@@ -50,7 +50,7 @@
 //   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
 //   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
 //       CHECK:   %[[RESHAPE:.+]] = flow.tensor.reshape %{{.+}} : tensor<49x20xf32> -> tensor<1x49x20x1xf32>
-//       CHECK:   flow.tensor.update %[[RESHAPE]], %{{.+}}[%[[C0]], %[[C1]], %[[C0]], %[[C0]]] : tensor<1x49x20x1xf32> -> tensor<1x50x20x1xf32>
+//       CHECK:   flow.tensor.update %[[RESHAPE]], %{{.+}}[%[[C0]], %[[C1]], %[[C0]], %[[C0]]] : tensor<1x49x20x1xf32> -> %{{.+}} as tensor<1x50x20x1xf32>
 
 
 // -----
diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD
index 23d7550..5466c08 100644
--- a/iree/compiler/Dialect/Flow/IR/BUILD
+++ b/iree/compiler/Dialect/Flow/IR/BUILD
@@ -27,6 +27,7 @@
     deps = [
         "//iree/compiler/Dialect/Shape/IR:td_files",
         "//iree/compiler/Dialect/Util/IR:td_files",
+        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
         "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:SideEffectTdFiles",
@@ -42,7 +43,6 @@
         "FlowEnums.cpp.inc",
         "FlowOpFolders.cpp",
         "FlowOpInterfaces.cpp.inc",
-        "FlowOpUtils.cpp",
         "FlowOps.cpp",
         "FlowOps.cpp.inc",
         "FlowTypeInterfaces.cpp.inc",
@@ -53,7 +53,6 @@
         "FlowDialect.h",
         "FlowEnums.h.inc",
         "FlowOpInterfaces.h.inc",
-        "FlowOpUtils.h",
         "FlowOps.h",
         "FlowOps.h.inc",
         "FlowTypeInterfaces.h.inc",
@@ -68,6 +67,7 @@
         "//iree/compiler/Dialect/Shape/IR",
         "//iree/compiler/Dialect/Util/IR",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ControlFlowInterfaces",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:InferTypeOpInterface",
         "@llvm-project//mlir:MemRefDialect",
diff --git a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
index 245efb5..5a8860a 100644
--- a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
@@ -17,7 +17,6 @@
     "FlowDialect.h"
     "FlowEnums.h.inc"
     "FlowOpInterfaces.h.inc"
-    "FlowOpUtils.h"
     "FlowOps.h"
     "FlowOps.h.inc"
     "FlowTypeInterfaces.h.inc"
@@ -28,7 +27,6 @@
     "FlowEnums.cpp.inc"
     "FlowOpFolders.cpp"
     "FlowOpInterfaces.cpp.inc"
-    "FlowOpUtils.cpp"
     "FlowOps.cpp"
     "FlowOps.cpp.inc"
     "FlowTypeInterfaces.cpp.inc"
@@ -40,6 +38,7 @@
     ::FlowOpsGen
     ::FlowTypesGen
     LLVMSupport
+    MLIRControlFlowInterfaces
     MLIRIR
     MLIRInferTypeOpInterface
     MLIRMemRef
diff --git a/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td b/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
index 9fc93f8..11f35e6 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
@@ -10,81 +10,6 @@
 include "iree/compiler/Dialect/Util/IR/UtilBase.td"
 
 //===----------------------------------------------------------------------===//
-// IREE::Flow::ClosureOpInterface
-//===----------------------------------------------------------------------===//
-
-def FLOW_ClosureOpInterface : OpInterface<"ClosureOpInterface"> {
-  let description = [{
-    Interface for ops that follow the Flow dialect closure semantics (explicit
-    captures, dynamic-shape awareness, and normal operand/result SSA behavior).
-
-    Implementing this interface enables optimizations that perform manipulation
-    across the closure capture boundary (outside of the op <-> regions within
-    the op).
-  }];
-
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/[{
-        Returns the body region of the closure (may have multiple blocks).
-      }],
-      /*retTy=*/"Region &",
-      /*methodName=*/"getClosureBodyRegion",
-      /*args=*/(ins),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{
-        return this->getOperation()->getRegion(0);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{Returns all closure operand values.}],
-      /*retTy=*/"Operation::operand_range",
-      /*methodName=*/"getClosureOperands",
-      /*args=*/(ins)
-    >,
-    InterfaceMethod<
-      /*desc=*/[{Returns all closure result values.}],
-      /*retTy=*/"Operation::result_range",
-      /*methodName=*/"getClosureResults",
-      /*args=*/(ins)
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Returns true if the given operation can exist in the closure.
-        Not all operations that a closure can contain are guaranteed to be folded
-        into the closure, such as when the operation may have side-effects.
-      }],
-      /*retTy=*/"bool",
-      /*methodName=*/"canClosureContainOp",
-      /*args=*/(ins "Operation *":$op)
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Clones the op while removing specified operands and results.
-        The body of the op will be transferred to the new op and the entry block
-        will have its arguments removed.
-
-        The returned op will be free standing. Callers must insert it into a block
-        where desired (most often just replacing the current op).
-      }],
-      /*retTy=*/"ClosureOpInterface",
-      /*methodName=*/"cloneReplacementExcludingOperandsAndResults",
-      /*args=*/(ins "ArrayRef<unsigned>":$excludedOperandIndices,
-                    "ArrayRef<unsigned>":$excludedResultIndices,
-                    "PatternRewriter &":$rewriter)
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Returns true if the output is also read within the region.
-      }],
-      /*retTy=*/"bool",
-      /*methodName=*/"isOutputReadWithinRegion",
-      /*args=*/(ins "unsigned":$resultIndex)
-    >
-  ];
-}
-
-//===----------------------------------------------------------------------===//
 // IREE::Flow::StreamableOpInterface
 //===----------------------------------------------------------------------===//
 
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index c4c8b5f..4753b8b 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -8,9 +8,9 @@
 #include <numeric>
 
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Optional.h"
@@ -228,7 +228,8 @@
 
 void ExStreamFragmentOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ClosureOptimizationPattern<ExStreamFragmentOp>>(context);
+  results.insert<IREE::Util::ClosureOptimizationPattern<ExStreamFragmentOp>>(
+      context);
   results.insert<InsertImmutabilityPreservingStreamClones>(context);
   // TODO(#6420): fix HAL lowering of this (or wait until streams are gone).
   // results.insert<TieStreamResults>(context);
@@ -240,7 +241,8 @@
 
 void DispatchWorkgroupsOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ClosureOptimizationPattern<DispatchWorkgroupsOp>>(context);
+  results.insert<IREE::Util::ClosureOptimizationPattern<DispatchWorkgroupsOp>>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 501ad88..1d8a835 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -6,8 +6,8 @@
 
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 
-#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
 #include "iree/compiler/Dialect/Shape/IR/Builders.h"
+#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/StringExtras.h"
@@ -41,13 +41,6 @@
 // Op utilities used within the Flow dialect
 //===----------------------------------------------------------------------===//
 
-// Returns true if the given |accessType| is compatible with the |variableType|.
-// For example, this will return true if the variable type is a tensor<?xf32>
-// and the access is tensor<4xf32>.
-static bool isVariableTypeCompatible(Type variableType, Type accessType) {
-  return succeeded(mlir::verifyCompatibleShape(variableType, accessType));
-}
-
 // Verifies that |dynamicDims| contains the appropriate number of dims for all
 // of the dynamic dimensions in |values|.
 static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
@@ -68,241 +61,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// custom<TiedResult>
-//===----------------------------------------------------------------------===//
-// type{%dim0, %dim1}
-// %arg0
-
-static ParseResult parseTiedResult(
-    OpAsmParser &parser, Type &resultType,
-    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
-    ArrayAttr &tiedOperands) {
-  if (failed(parser.parseType(resultType))) return failure();
-  if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
-    if (!shapedType.hasStaticShape()) {
-      SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
-      if (failed(parser.parseLBrace()) ||
-          failed(parser.parseOperandList(dynamicDims,
-                                         shapedType.getNumDynamicDims(),
-                                         OpAsmParser::Delimiter::None)) ||
-          failed(parser.parseRBrace())) {
-        return failure();
-      }
-      resultDims.append(dynamicDims);
-    }
-  }
-  tiedOperands = parser.getBuilder().getIndexArrayAttr({0});
-  return success();
-}
-
-static void printTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
-                            ValueRange resultDims, ArrayAttr tiedOperands) {
-  p.printType(resultType);
-  if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
-    if (!shapedType.hasStaticShape()) {
-      if (resultDims.empty()) {
-        p << "{<<INVALID>>}";
-        return;
-      }
-      p << "{";
-      llvm::interleaveComma(
-          resultDims.take_front(shapedType.getNumDynamicDims()), p,
-          [&](Value value) { p.printOperand(value); });
-      p << "}";
-    }
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// custom<ShapedFunctionType>
-//===----------------------------------------------------------------------===//
-// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)
-
-static ParseResult parseShapedOperandList(
-    OpAsmParser &parser, SmallVectorImpl<Type> &types,
-    SmallVectorImpl<OpAsmParser::OperandType> &dims) {
-  do {
-    Type type;
-    if (failed(parser.parseType(type))) return failure();
-    if (auto shapedType = type.dyn_cast<ShapedType>()) {
-      if (!shapedType.hasStaticShape()) {
-        SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
-        if (failed(parser.parseLBrace()) ||
-            failed(parser.parseOperandList(dynamicDims,
-                                           shapedType.getNumDynamicDims(),
-                                           OpAsmParser::Delimiter::None)) ||
-            failed(parser.parseRBrace())) {
-          return failure();
-        }
-        dims.append(dynamicDims);
-      }
-    }
-    types.push_back(type);
-  } while (succeeded(parser.parseOptionalComma()));
-  return success();
-}
-
-// Finds the operand index in |operands| that |tiedResult| references.
-// Returns TiedOpInterface::kUntiedIndex if no operand is found.
-static int64_t findTiedOperand(OpAsmParser::OperandType tiedResult,
-                               ArrayRef<OpAsmParser::OperandType> operands) {
-  int64_t operandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
-  for (int64_t i = 0; i < operands.size(); ++i) {
-    if (operands[i].name == tiedResult.name) {
-      operandIndex = i;
-      break;
-    }
-  }
-  return operandIndex;
-}
-
-static ParseResult parseShapedResultList(
-    OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
-    TypeRange operandTypes, ArrayRef<OpAsmParser::OperandType> operandDims,
-    SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
-    ArrayAttr &tiedOperands) {
-  SmallVector<int64_t, 4> tiedOperandIndices;
-  do {
-    OpAsmParser::OperandType tiedResult;
-    auto res = parser.parseOptionalOperand(tiedResult);
-    Type type;
-    int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
-    if (res.hasValue() && succeeded(res.getValue())) {
-      tiedOperandIndex = findTiedOperand(tiedResult, operands);
-      if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) {
-        return parser.emitError(tiedResult.location,
-                                "tied operand not found for result reference ")
-               << tiedResult.name;
-      }
-      if (succeeded(parser.parseOptionalKeyword("as"))) {
-        // Type _may_ differ from the operand.
-        if (failed(parser.parseType(type))) return failure();
-      } else {
-        // Use the operands type.
-        type = operandTypes[tiedOperandIndex];
-      }
-    } else if (failed(parser.parseType(type))) {
-      return failure();
-    }
-    if (auto shapedType = type.dyn_cast<ShapedType>()) {
-      if (!shapedType.hasStaticShape()) {
-        SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
-        if (failed(parser.parseLBrace()) ||
-            failed(parser.parseOperandList(dynamicDims,
-                                           shapedType.getNumDynamicDims(),
-                                           OpAsmParser::Delimiter::None)) ||
-            failed(parser.parseRBrace())) {
-          return failure();
-        }
-        resultDims.append(dynamicDims);
-      }
-    }
-    resultTypes.push_back(type);
-    tiedOperandIndices.push_back(tiedOperandIndex);
-  } while (succeeded(parser.parseOptionalComma()));
-  if (!tiedOperandIndices.empty()) {
-    tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices);
-  }
-  return success();
-}
-
-static ParseResult parseShapedFunctionType(
-    OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
-    SmallVectorImpl<Type> &operandTypes,
-    SmallVectorImpl<OpAsmParser::OperandType> &operandDims,
-    SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
-    ArrayAttr &tiedOperands) {
-  if (failed(parser.parseLParen())) return failure();
-  if (failed(parser.parseOptionalRParen())) {
-    if (failed(parseShapedOperandList(parser, operandTypes, operandDims)) ||
-        failed(parser.parseRParen())) {
-      return failure();
-    }
-  }
-  if (failed(parser.parseArrow())) return failure();
-  if (succeeded(parser.parseOptionalLParen())) {
-    if (failed(parseShapedResultList(parser, operands, operandTypes,
-                                     operandDims, resultTypes, resultDims,
-                                     tiedOperands)) ||
-        failed(parser.parseRParen())) {
-      return failure();
-    }
-  } else {
-    if (failed(parseShapedResultList(parser, operands, operandTypes,
-                                     operandDims, resultTypes, resultDims,
-                                     tiedOperands))) {
-      return failure();
-    }
-  }
-  return success();
-}
-
-static void printShapedFunctionType(OpAsmPrinter &p, Operation *op,
-                                    ValueRange operands, TypeRange operandTypes,
-                                    OperandRange operandDims,
-                                    TypeRange resultTypes,
-                                    OperandRange resultDims,
-                                    ArrayAttr tiedOperands) {
-  p << "(";
-  llvm::interleaveComma(operandTypes, p, [&](Type type) {
-    p.printType(type);
-    if (auto shapedType = type.dyn_cast<ShapedType>()) {
-      if (!shapedType.hasStaticShape()) {
-        if (operandDims.empty()) {
-          p << "{<<INVALID>>}";
-          return;
-        }
-        p << "{";
-        llvm::interleaveComma(
-            operandDims.take_front(shapedType.getNumDynamicDims()), p,
-            [&](Value value) { p.printOperand(value); });
-        p << "}";
-        operandDims = operandDims.drop_front(shapedType.getNumDynamicDims());
-      }
-    }
-  });
-  p << ") -> ";
-  if (resultTypes.size() != 1) p << "(";
-  auto tiedOp = cast<IREE::Util::TiedOpInterface>(op);
-  for (unsigned i = 0; i < resultTypes.size(); ++i) {
-    auto resultType = resultTypes[i];
-    auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i);
-    bool printType = true;
-    if (tiedOperandIndex.hasValue()) {
-      auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
-      p.printOperand(tiedOperand);
-      if (tiedOperand.getType() != resultType) {
-        p << " as ";
-      } else {
-        // Type elided as it matches the operand.
-        printType = false;
-      }
-    }
-    if (printType) {
-      p.printType(resultType);
-    }
-    if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
-      if (!shapedType.hasStaticShape()) {
-        if (resultDims.empty()) {
-          p << "{<<INVALID>>}";
-          return;
-        }
-        p << "{";
-        llvm::interleaveComma(
-            resultDims.take_front(shapedType.getNumDynamicDims()), p,
-            [&](Value value) { p.printOperand(value); });
-        p << "}";
-        resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
-      }
-    }
-    if (i < resultTypes.size() - 1) p << ", ";
-  }
-  if (resultTypes.size() != 1) p << ")";
-}
-
-//===----------------------------------------------------------------------===//
 // flow.dispatch.tensor.load
 //===----------------------------------------------------------------------===//
 
@@ -565,27 +323,66 @@
   return canDispatchRegionContainOp(op);
 }
 
-bool DispatchWorkgroupsOp::isOutputReadWithinRegion(unsigned resultIndex) {
-  unsigned startIndex = getBody()->getNumArguments() - getNumResults();
-  BlockArgument arg = body().front().getArgument(startIndex + resultIndex);
-  // If argument is of `writeonly` access, then it is not read by construction.
-  if (arg.getType().cast<DispatchTensorType>().getAccess() ==
-      TensorAccess::WriteOnly) {
-    return false;
-  }
-  // If the argument is a result with `readwrite` access, return false if the
-  // value is only written to. Check this by looking at the uses of the argument
-  // being only the `target` of `flow.dispatch.tensor.store` ops.
-  for (OpOperand &uses : arg.getUses()) {
-    auto storeOp = dyn_cast<DispatchTensorStoreOp>(uses.getOwner());
-    if (!(storeOp && storeOp.target() == uses.get())) {
-      return true;
+// Narrows the access declared on |type| to what the uses of |value| actually
+// require. The declared access is expected to be correct already, but today it
+// is sometimes wider than needed; only `readwrite` is ever refined here.
+static TensorAccess refineTensorAccess(Value value, DispatchTensorType type) {
+  auto tensorAccess = type.getAccess();
+  if (tensorAccess == TensorAccess::ReadWrite) {
+    // A `readwrite` value is downgraded to `writeonly` when it is never read.
+    // That holds when every use of the value is as the `target` of a
+    // `flow.dispatch.tensor.store` op; any other use keeps `readwrite`.
+    bool onlyWrites = true;
+    for (OpOperand &uses : value.getUses()) {
+      auto storeOp = dyn_cast<DispatchTensorStoreOp>(uses.getOwner());
+      if (!(storeOp && storeOp.target() == uses.get())) {
+        onlyWrites = false;
+        break;
+      }
     }
+    if (onlyWrites) tensorAccess = TensorAccess::WriteOnly;
   }
-  return false;
+  return tensorAccess;
 }
 
-ClosureOpInterface
+IREE::Util::ValueAccess DispatchWorkgroupsOp::getOperandAccess(
+    unsigned operandIndex) {  // Indexes the first body block's arguments.
+  BlockArgument arg = body().front().getArgument(operandIndex);
+  if (auto tensorType = arg.getType().dyn_cast<DispatchTensorType>()) {
+    auto tensorAccess = refineTensorAccess(arg, tensorType);
+    return IREE::Util::ValueAccess(
+        /*isRead=*/(tensorAccess == TensorAccess::ReadOnly) ||
+            (tensorAccess == TensorAccess::ReadWrite),
+        /*isWrite=*/(tensorAccess == TensorAccess::ReadWrite) ||
+            (tensorAccess == TensorAccess::WriteOnly),
+        /*isDiscard=*/(tensorAccess == TensorAccess::WriteOnly));
+  } else {  // Not a dispatch tensor: read iff it has uses, never written.
+    return IREE::Util::ValueAccess(/*isRead=*/!arg.use_empty(),
+                                   /*isWrite=*/false,
+                                   /*isDiscard=*/false);
+  }
+}
+
+IREE::Util::ValueAccess DispatchWorkgroupsOp::getResultAccess(
+    unsigned resultIndex) {  // Results use the trailing block arguments.
+  unsigned startIndex = getBody()->getNumArguments() - getNumResults();
+  BlockArgument arg = body().front().getArgument(startIndex + resultIndex);
+  if (auto tensorType = arg.getType().dyn_cast<DispatchTensorType>()) {
+    auto tensorAccess = refineTensorAccess(arg, tensorType);
+    return IREE::Util::ValueAccess(
+        /*isRead=*/(tensorAccess == TensorAccess::ReadOnly) ||
+            (tensorAccess == TensorAccess::ReadWrite),
+        /*isWrite=*/(tensorAccess == TensorAccess::ReadWrite) ||
+            (tensorAccess == TensorAccess::WriteOnly),
+        /*isDiscard=*/(tensorAccess == TensorAccess::WriteOnly));
+  } else {  // Not a dispatch tensor: read iff it has uses, never written.
+    return IREE::Util::ValueAccess(/*isRead=*/!arg.use_empty(),
+                                   /*isWrite=*/false,
+                                   /*isDiscard=*/false);
+  }
+}
+
+IREE::Util::ClosureOpInterface
 DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults(
     ArrayRef<unsigned> excludedOperandIndices,
     ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
@@ -593,9 +390,9 @@
   SmallVector<Value, 4> newResultDims = llvm::to_vector<4>(result_dims());
   SmallVector<Value, 4> newOperandsValues = llvm::to_vector<4>(operands());
   SmallVector<Value, 4> newOperandDims = llvm::to_vector<4>(operand_dims());
-  excludeClosureOperandsAndResults(newOperandsValues, newOperandDims,
-                                   excludedOperandIndices, newResultTypes,
-                                   newResultDims, excludedResultIndices);
+  IREE::Util::excludeClosureOperandsAndResults(
+      newOperandsValues, newOperandDims, excludedOperandIndices, newResultTypes,
+      newResultDims, excludedResultIndices);
 
   auto newTiedOperandIndices =
       llvm::to_vector<4>(getTiedResultOperandIndices());
@@ -1151,11 +948,20 @@
   return false;
 }
 
-bool ExStreamFragmentOp::isOutputReadWithinRegion(unsigned resultIndex) {
-  return false;
+IREE::Util::ValueAccess ExStreamFragmentOp::getOperandAccess(
+    unsigned operandIndex) {  // Tied => ReadWrite, untied => ReadOnly.
+  return !isOperandTied(operandIndex) ? IREE::Util::ValueAccess::ReadOnly()
+                                      : IREE::Util::ValueAccess::ReadWrite();
 }
 
-ClosureOpInterface
+IREE::Util::ValueAccess ExStreamFragmentOp::getResultAccess(
+    unsigned resultIndex) {  // Tied => ReadWrite, untied => DiscardWrite.
+  return getTiedResultOperandIndex(resultIndex).hasValue()
+             ? IREE::Util::ValueAccess::ReadWrite()
+             : IREE::Util::ValueAccess::DiscardWrite();
+}
+
+IREE::Util::ClosureOpInterface
 ExStreamFragmentOp::cloneReplacementExcludingOperandsAndResults(
     ArrayRef<unsigned> excludedOperandIndices,
     ArrayRef<unsigned> excludedResultIndices, PatternRewriter &rewriter) {
@@ -1163,9 +969,9 @@
   SmallVector<Value, 4> newResultDims = llvm::to_vector<4>(result_dims());
   SmallVector<Value, 4> newOperandsValues = llvm::to_vector<4>(operands());
   SmallVector<Value, 4> newOperandDims = llvm::to_vector<4>(operand_dims());
-  excludeClosureOperandsAndResults(newOperandsValues, newOperandDims,
-                                   excludedOperandIndices, newResultTypes,
-                                   newResultDims, excludedResultIndices);
+  IREE::Util::excludeClosureOperandsAndResults(
+      newOperandsValues, newOperandDims, excludedOperandIndices, newResultTypes,
+      newResultDims, excludedResultIndices);
 
   auto newTiedOperandIndices =
       llvm::to_vector<4>(getTiedResultOperandIndices());
@@ -1179,7 +985,7 @@
       newOperandDims, newTiedOperandIndices, getOperation()->getAttrs());
   auto &newBody = newOp.getClosureBodyRegion();
   newBody.takeBody(getClosureBodyRegion());
-  eraseRegionResults(newBody, excludedResultIndices);
+  IREE::Util::eraseRegionResults(newBody, excludedResultIndices);
   newBody.front().eraseArguments(excludedOperandIndices);
   return newOp;
 }
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.h b/iree/compiler/Dialect/Flow/IR/FlowOps.h
index e1d1230..a455cad 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.h
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.h
@@ -23,6 +23,7 @@
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 3f5e724..ee24714 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -13,6 +13,7 @@
 include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -28,7 +29,7 @@
   IsolatedFromAbove,
   AttrSizedOperandSegments,
   SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">,
-  DeclareOpInterfaceMethods<FLOW_ClosureOpInterface>,
+  DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
   DeclareOpInterfaceMethods<Util_TiedOpInterface, [
     "getTiedOperandsIndexAndLength",
   ]>,
@@ -421,7 +422,7 @@
   let hasCanonicalizer = 1;
 }
 
-def FLOW_ReturnOp : FLOW_Op<"return", [Terminator]> {
+def FLOW_ReturnOp : FLOW_Op<"return", [NoSideEffect, ReturnLike, Terminator]> {
   let summary = [{return from a flow.dispatch_region}];
   let description = [{
     Returns the given values from the region and back to the host code.
@@ -874,7 +875,7 @@
   let assemblyFormat = [{
     $update `,` $target `[` $start_indices `]` `:`
     type($update) (`{` $update_dims^ `}`)? `->`
-    custom<TiedResult>(type($result), $target_dims, $tied_operands)
+    custom<ShapedTiedResult>(type($result), $target_dims, $tied_operands)
     attr-dict-with-keyword
   }];
 
@@ -921,7 +922,7 @@
 def FLOW_ExStreamFragmentOp : FLOW_PureOp<"ex.stream.fragment", [
   IsolatedFromAbove,
   AttrSizedOperandSegments,
-  DeclareOpInterfaceMethods<FLOW_ClosureOpInterface>,
+  DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
   DeclareOpInterfaceMethods<Util_TiedOpInterface>,
   DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
 ]> {
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
index 4b177ac..93e5259 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
@@ -47,13 +47,32 @@
 
 // -----
 
-// CHECK-LABEL: func @removeUnusedResult
+// CHECK-LABEL: func @removeUnusedProducedResult
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+func @removeUnusedProducedResult(%arg0: index) -> index {
+  // CHECK: flow.ex.stream.fragment(%[[ARG0]]) : (index) -> index =
+  %0:2 = flow.ex.stream.fragment(%arg0) : (index) -> (index, index) =
+      (%arg0_in: index) -> (index, index) {
+    // CHECK: %[[T:.+]] = addi
+    %t = addi %arg0_in, %arg0_in : index
+    %unused = muli %arg0_in, %arg0_in : index
+    // CHECK: flow.return %[[T]] : index
+    flow.return %t, %unused : index, index
+  }
+  return %0#0 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @removeUnusedPassThroughResult
 // CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
-func @removeUnusedResult(%arg0: index, %arg1: index) -> index {
+func @removeUnusedPassThroughResult(%arg0: index, %arg1: index) -> index {
   // CHECK: flow.ex.stream.fragment(%[[ARG1]])
   %0:2 = flow.ex.stream.fragment(%arg0, %arg1) : (index, index) -> (index, index) =
-      (%unused: index, %arg1: index) -> (index, index) {
-    %t = addi %arg1, %arg1 : index
+      (%unused: index, %arg1_in: index) -> (index, index) {
+    // CHECK: %[[T:.+]] = addi
+    %t = addi %arg1_in, %arg1_in : index
+    // CHECK: flow.return %[[T]] : index
     flow.return %t, %unused : index, index
   }
   return %0#0 : index
@@ -69,7 +88,7 @@
   // CHECK: flow.ex.stream.fragment(%[[ARG1]]) :
   %0:2 = flow.ex.stream.fragment(%arg0, %arg1) :
       // CHECK-SAME: (tensor<8x?xf32>{%[[DIM1]]}) -> %[[ARG1]]{%[[DIM1]]} =
-      (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (%arg0{%dim0}, %arg1{%dim1}) =
+      (tensor<4x?xf32>{%dim0}, tensor<8x?xf32>{%dim1}) -> (tensor<4x?xf32>{%dim0}, %arg1{%dim1}) =
       // CHECK-NEXT: (%[[INNER_ARG:.+]]: tensor<8x?xf32>) -> tensor<8x?xf32>
       (%unused: tensor<4x?xf32>, %arg1: tensor<8x?xf32>) -> (tensor<4x?xf32>, tensor<8x?xf32>) {
     // CHECK-NEXT: flow.return %[[INNER_ARG]] : tensor<8x?xf32>
@@ -96,7 +115,7 @@
     %workload = constant 8 : index
     //      CHECK: %[[TARGET_CLONE:.+]] = flow.tensor.clone %[[TARGET]] : tensor<2x4xi32>
     //      CHECK: %[[UPDATED:.+]] = flow.tensor.update %[[UPDATE]], %[[TARGET]]
-    %t0 = flow.tensor.update %stream_update, %stream_target[%start0, %start1] : tensor<1x1xi32> -> tensor<2x4xi32>
+    %t0 = flow.tensor.update %stream_update, %stream_target[%start0, %start1] : tensor<1x1xi32> -> %stream_target as tensor<2x4xi32>
     // CHECK-NEXT: %[[RETURN:.+]] = flow.dispatch @ex::@entry[%c8](%[[TARGET_CLONE]], %[[UPDATED]])
     %t1 = flow.dispatch @ex::@entry[%workload](%stream_target, %t0) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
     // CHECK-NEXT: flow.return %[[RETURN]]
@@ -165,7 +184,7 @@
     %5 = util.global.load @_large_const : tensor<7xi32>
     // CHECK: %[[CLONE:.+]] = flow.tensor.clone %[[LOAD]]
     // CHECK: flow.tensor.update %{{.+}}, %[[CLONE]]
-    %6 = flow.tensor.update %arg0, %5[%c3] : tensor<2xi32> -> tensor<7xi32>
+    %6 = flow.tensor.update %arg0, %5[%c3] : tensor<2xi32> -> %5 as tensor<7xi32>
     flow.return %6 : tensor<7xi32>
   }
   return %4 : tensor<7xi32>
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
index 84fef08..12f2481 100644
--- a/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
@@ -143,8 +143,8 @@
 
 // CHECK-LABEL: @tensorUpdate
 func @tensorUpdate(%arg0 : tensor<2x2xf32>, %arg1 : tensor<4x4xf32>, %arg2 : index, %arg3 : index) -> tensor<4x4xf32> {
-  // CHECK-NEXT: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> tensor<4x4xf32>
-  %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> tensor<4x4xf32>
+  // CHECK-NEXT: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> %arg1 as tensor<4x4xf32>
+  %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<2x2xf32> -> %arg1 as tensor<4x4xf32>
   return %0 : tensor<4x4xf32>
 }
 
@@ -153,7 +153,7 @@
   %c1 = constant 1 : index
   %c2 = constant 2 : index
   %c3 = constant 3 : index
-  // CHECK: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> tensor<?x4xf32>{%c3}
-  %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> tensor<?x4xf32>{%c3}
+  // CHECK: %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> %arg1 as tensor<?x4xf32>{%c3}
+  %0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> %arg1 as tensor<?x4xf32>{%c3}
   return %0 : tensor<?x4xf32>
 }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
index e3e964e..a7716d2 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
@@ -10,7 +10,7 @@
   %4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5]
   %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
   %6 = linalg.fill(%0, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-  %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> tensor<?x?xf32>{%3, %4}
+  %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> %6 as tensor<?x?xf32>{%3, %4}
   return %7 : tensor<?x?xf32>
 }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index afe63b1..5d846cd 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -334,7 +334,7 @@
   %4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5]
   %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
   %6 = linalg.fill(%0, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-  %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> tensor<?x?xf32>{%3, %4}
+  %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> %6 as tensor<?x?xf32>{%3, %4}
   return %7 : tensor<?x?xf32>
 }
 
@@ -639,7 +639,7 @@
   %c5_i32 = constant 5 : i32
   %c0_i32 = constant 0 : i32
   %c9_i32 = constant 9 : i32
-  %245 = flow.tensor.update %240, %244[%c9] : tensor<9xi32> -> tensor<18xi32>
+  %245 = flow.tensor.update %240, %244[%c9] : tensor<9xi32> -> %244 as tensor<18xi32>
   %248 = tensor.extract %247[] : tensor<i32>
   %249 = cmpi slt, %248, %c9_i32 : i32
   %250 = select %249, %248, %c9_i32 : i32
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
index ccfaf18..8a52b41 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
@@ -284,7 +284,7 @@
     // CHECK-SAME:   source(%[[UBUF]] : !hal.buffer)[%c0]
     // CHECK-SAME:   target(%[[RET_BUF]] : !hal.buffer)[%c204]
     // CHECK-SAME:   length(%c40)
-    %1 = flow.tensor.update %arg2, %clone[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> tensor<5x1x10xf32>
+    %1 = flow.tensor.update %arg2, %clone[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> %clone as tensor<5x1x10xf32>
     flow.return %1 : tensor<5x1x10xf32>
   }
   // CHECK: hal.command_buffer.end<%[[CMD]]
@@ -571,7 +571,7 @@
     // CHECK: %[[CSTBUF:.+]] = util.global.load @_const_pool_splats : !hal.buffer
     // CHECK: hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer> source(%[[CSTBUF]] : !hal.buffer)[%[[C0]]] target(%[[DSTBUF]] : !hal.buffer)[%[[C0]]] length(%[[C28]])
     %2 = flow.tensor.clone %const_span : tensor<7xi32>
-    %3 = flow.tensor.update %arg0, %2[%c3] : tensor<2xi32> -> tensor<7xi32>
+    %3 = flow.tensor.update %arg0, %2[%c3] : tensor<2xi32> -> %2 as tensor<7xi32>
     flow.return %3 : tensor<7xi32>
   }
   return %1 : tensor<7xi32>
diff --git a/iree/compiler/Dialect/HAL/IR/HALInterfaces.td b/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
index ac0d32c..cdf89ae 100644
--- a/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
+++ b/iree/compiler/Dialect/HAL/IR/HALInterfaces.td
@@ -9,68 +9,6 @@
 
 include "iree/compiler/Dialect/Util/IR/UtilBase.td"
 
-//===----------------------------------------------------------------------===//
-// IREE::HAL::SizeAwareOpInterface
-//===----------------------------------------------------------------------===//
-
-def HAL_InferTypeSize : TypeInterface<"InferTypeSizeInterface"> {
-  let description = [{
-    Allows types to be queried for their size by inserting the required logic
-    when required.
-  }];
-
-  let methods = [
-    InterfaceMethod<
-      [{Builds an expression computing the size of the value.}],
-      "Value", "inferSizeFromValue", (ins "Location":$loc,
-                                          "Value":$value,
-                                          "OpBuilder &":$builder)
-    >,
-  ];
-}
-
-def HAL_SizeAwareType : TypeInterface<"SizeAwareTypeInterface"> {
-  let description = [{
-    Denotes that a type is size-aware and must always have a size value
-    associated with it in the IR. See `SizeAwareOp` for more information.
-  }];
-
-  let methods = [
-    InterfaceMethod<
-      [{Returns a size for the given sized value.}],
-      "Value", "getSize", (ins "Value":$value)
-    >,
-  ];
-}
-
-def HAL_SizeAwareOp : OpInterface<"SizeAwareOpInterface"> {
-  let description = [{
-    An operation that is able to provide size values for all size-aware operands
-    and results.
-  }];
-
-  let methods = [
-    InterfaceMethod<
-      [{Returns a size for the given sized operand index.}],
-      "Value", "getOperandSize", (ins "unsigned":$idx)
-    >,
-    InterfaceMethod<
-      [{Returns a size for the given sized result index.}],
-      "Value", "getResultSize", (ins "unsigned":$idx)
-    >,
-    InterfaceMethod<
-      [{Returns a size for the given sized result value.}],
-      "Value", "getResultSizeFromValue", (ins "Value":$value),
-      /*defaultImplementation=*/[{
-        for (unsigned i = 0; i < $_self->getNumResults(); ++i) {
-          if ($_self->getResult(i) == value) return $_self.getResultSize(i);
-        }
-        return {};
-      }]
-    >,
-  ];
-}
-
 def HAL_MatchAttrInterface :
     AttrInterface<"MatchAttrInterface"> {
   let description = [{
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index edb6f90..e5f26b8 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -51,65 +51,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// custom<SizeAwareType>
-//===----------------------------------------------------------------------===//
-// type{%size}
-
-static ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type,
-                                      OpAsmParser::OperandType &size) {
-  if (failed(parser.parseType(type)) || failed(parser.parseLBrace()) ||
-      failed(parser.parseOperand(size)) || failed(parser.parseRBrace())) {
-    return failure();
-  }
-  return success();
-}
-
-static void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type,
-                               Value size) {
-  p.printType(type);
-  p << "{";
-  p.printOperand(size);
-  p << "}";
-}
-
-//===----------------------------------------------------------------------===//
-// custom<SizeAwareTypeList>
-//===----------------------------------------------------------------------===//
-// (type{%size0}, type, type{%size1})
-
-static ParseResult parseSizeAwareTypeList(
-    OpAsmParser &parser, SmallVectorImpl<Type> &types,
-    SmallVectorImpl<OpAsmParser::OperandType> &sizes) {
-  do {
-    Type type;
-    if (failed(parser.parseType(type))) return failure();
-    if (type.isa<SizeAwareTypeInterface>()) {
-      OpAsmParser::OperandType size;
-      if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) ||
-          failed(parser.parseRBrace())) {
-        return failure();
-      }
-      sizes.push_back(size);
-    }
-    types.push_back(type);
-  } while (succeeded(parser.parseOptionalComma()));
-  return success();
-}
-
-static void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op,
-                                   TypeRange types, OperandRange sizes) {
-  int sizeIndex = 0;
-  llvm::interleaveComma(types, p, [&](Type type) {
-    p.printType(type);
-    if (type.isa<SizeAwareTypeInterface>()) {
-      p << "{";
-      p.printOperand(sizes[sizeIndex++]);
-      p << "}";
-    }
-  });
-}
-
-//===----------------------------------------------------------------------===//
 // custom<DescriptorSetBindings>($binding_ordinals,
 //                               $binding_buffers,
 //                               type($binding_buffers),
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 2d87fe6..6879e69 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -267,7 +267,7 @@
 
 def HAL_AllocatorAllocateOp : HAL_Op<"allocator.allocate", [
     DeclareOpInterfaceMethods<OpAsmOpInterface>,
-    DeclareOpInterfaceMethods<HAL_SizeAwareOp>,
+    DeclareOpInterfaceMethods<Util_SizeAwareOp>,
   ]> {
   let summary = [{empty buffer allocation operation}];
   let description = [{
@@ -331,7 +331,7 @@
 
 def HAL_AllocatorMapOp : HAL_Op<"allocator.map", [
     DeclareOpInterfaceMethods<OpAsmOpInterface>,
-    DeclareOpInterfaceMethods<HAL_SizeAwareOp>,
+    DeclareOpInterfaceMethods<Util_SizeAwareOp>,
   ]> {
   let summary = [{allocator-supported host buffer wrapping operation}];
   let description = [{
@@ -498,7 +498,7 @@
 
 def HAL_BufferSubspanOp : HAL_PureOp<"buffer.subspan", [
     DeclareOpInterfaceMethods<OpAsmOpInterface>,
-    DeclareOpInterfaceMethods<HAL_SizeAwareOp>,
+    DeclareOpInterfaceMethods<Util_SizeAwareOp>,
   ]> {
   let summary = [{buffer subspan operation}];
   let description = [{
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index 2f9c7b6..e0c413e 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -380,7 +380,7 @@
 
 // Returns the SSA value containing the size of the given |value|.
 static Value lookupValueSize(Value value) {
-  assert(value.getType().isa<SizeAwareTypeInterface>());
+  assert(value.getType().isa<IREE::Util::SizeAwareTypeInterface>());
 
   auto definingOp = value.getDefiningOp();
   if (!definingOp) {
@@ -402,7 +402,7 @@
     }
   }
   assert(resultIndex != -1 && "result not in results");
-  auto sizeAwareOp = dyn_cast<SizeAwareOpInterface>(definingOp);
+  auto sizeAwareOp = dyn_cast<IREE::Util::SizeAwareOpInterface>(definingOp);
   if (!sizeAwareOp) return {};
   return sizeAwareOp.getResultSize(resultIndex);
 }
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h
index a0abc18..ac22bc1 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -79,16 +79,18 @@
   using Base::Base;
 };
 
-class BufferType : public Type::TypeBase<BufferType, Type, TypeStorage,
-                                         InferTypeSizeInterface::Trait> {
+class BufferType
+    : public Type::TypeBase<BufferType, Type, TypeStorage,
+                            IREE::Util::InferTypeSizeInterface::Trait> {
  public:
   using Base::Base;
 
   Value inferSizeFromValue(Location loc, Value value, OpBuilder &builder) const;
 };
 
-class BufferViewType : public Type::TypeBase<BufferViewType, Type, TypeStorage,
-                                             InferTypeSizeInterface::Trait> {
+class BufferViewType
+    : public Type::TypeBase<BufferViewType, Type, TypeStorage,
+                            IREE::Util::InferTypeSizeInterface::Trait> {
  public:
   using Base::Base;
 
diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
index ebe1e99..20296da 100644
--- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
+++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
@@ -67,16 +67,17 @@
         loc, builder.getIndexType(), value);
   }
 
-  if (auto awareOp = dyn_cast_or_null<SizeAwareOpInterface>(definingOp)) {
+  if (auto awareOp =
+          dyn_cast_or_null<IREE::Util::SizeAwareOpInterface>(definingOp)) {
     return awareOp.getResultSizeFromValue(value);
   }
 
   auto type = value.getType();
-  if (auto awareType = type.dyn_cast<SizeAwareTypeInterface>()) {
+  if (auto awareType = type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
     auto sizeValue = awareType.getSize(value);
     if (sizeValue) return sizeValue;
   }
-  if (auto inferType = type.dyn_cast<InferTypeSizeInterface>()) {
+  if (auto inferType = type.dyn_cast<IREE::Util::InferTypeSizeInterface>()) {
     return inferType.inferSizeFromValue(loc, value, builder);
   }
 
diff --git a/iree/compiler/Dialect/Util/IR/BUILD b/iree/compiler/Dialect/Util/IR/BUILD
index cf34616..b5ae87d 100644
--- a/iree/compiler/Dialect/Util/IR/BUILD
+++ b/iree/compiler/Dialect/Util/IR/BUILD
@@ -35,21 +35,29 @@
 cc_library(
     name = "IR",
     srcs = [
+        "ClosureOpUtils.cpp",
         "UtilDialect.cpp",
         "UtilOpFolders.cpp",
-        "UtilOpInterfaces.cpp.inc",
         "UtilOps.cpp",
         "UtilOps.cpp.inc",
         "UtilTypes.cpp",
     ],
     hdrs = [
+        "ClosureOpUtils.h",
         "UtilDialect.h",
-        "UtilOpInterfaces.h.inc",
         "UtilOps.h",
         "UtilOps.h.inc",
         "UtilTraits.h",
         "UtilTypes.h",
     ],
+    textual_hdrs = [
+        "UtilAttrInterfaces.cpp.inc",
+        "UtilAttrInterfaces.h.inc",
+        "UtilOpInterfaces.cpp.inc",
+        "UtilOpInterfaces.h.inc",
+        "UtilTypeInterfaces.cpp.inc",
+        "UtilTypeInterfaces.h.inc",
+    ],
     deps = [
         ":UtilInterfacesGen",
         ":UtilOpsGen",
@@ -69,6 +77,14 @@
     name = "UtilInterfacesGen",
     tbl_outs = [
         (
+            ["-gen-attr-interface-decls"],
+            "UtilAttrInterfaces.h.inc",
+        ),
+        (
+            ["-gen-attr-interface-defs"],
+            "UtilAttrInterfaces.cpp.inc",
+        ),
+        (
             ["-gen-op-interface-decls"],
             "UtilOpInterfaces.h.inc",
         ),
@@ -76,6 +92,14 @@
             ["-gen-op-interface-defs"],
             "UtilOpInterfaces.cpp.inc",
         ),
+        (
+            ["-gen-type-interface-decls"],
+            "UtilTypeInterfaces.h.inc",
+        ),
+        (
+            ["-gen-type-interface-defs"],
+            "UtilTypeInterfaces.cpp.inc",
+        ),
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "UtilInterfaces.td",
diff --git a/iree/compiler/Dialect/Util/IR/CMakeLists.txt b/iree/compiler/Dialect/Util/IR/CMakeLists.txt
index 72d4d1f..412793e 100644
--- a/iree/compiler/Dialect/Util/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Util/IR/CMakeLists.txt
@@ -14,16 +14,23 @@
   NAME
     IR
   HDRS
+    "ClosureOpUtils.h"
     "UtilDialect.h"
-    "UtilOpInterfaces.h.inc"
     "UtilOps.h"
     "UtilOps.h.inc"
     "UtilTraits.h"
     "UtilTypes.h"
+  TEXTUAL_HDRS
+    "UtilAttrInterfaces.cpp.inc"
+    "UtilAttrInterfaces.h.inc"
+    "UtilOpInterfaces.cpp.inc"
+    "UtilOpInterfaces.h.inc"
+    "UtilTypeInterfaces.cpp.inc"
+    "UtilTypeInterfaces.h.inc"
   SRCS
+    "ClosureOpUtils.cpp"
     "UtilDialect.cpp"
     "UtilOpFolders.cpp"
-    "UtilOpInterfaces.cpp.inc"
     "UtilOps.cpp"
     "UtilOps.cpp.inc"
     "UtilTypes.cpp"
@@ -47,8 +54,12 @@
   TD_FILE
     "UtilInterfaces.td"
   OUTS
+    -gen-attr-interface-decls UtilAttrInterfaces.h.inc
+    -gen-attr-interface-defs UtilAttrInterfaces.cpp.inc
     -gen-op-interface-decls UtilOpInterfaces.h.inc
     -gen-op-interface-defs UtilOpInterfaces.cpp.inc
+    -gen-type-interface-decls UtilTypeInterfaces.h.inc
+    -gen-type-interface-defs UtilTypeInterfaces.cpp.inc
 )
 
 iree_tablegen_library(
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp b/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
similarity index 89%
rename from iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp
rename to iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
index 16dd347..7febab2 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.cpp
+++ b/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp
@@ -4,14 +4,15 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
+#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
 
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 namespace mlir {
 namespace iree_compiler {
 namespace IREE {
-namespace Flow {
+namespace Util {
 
 //------------------------------------------------------------------------------
 // Closure optimization
@@ -53,6 +54,8 @@
     auto type = it.value().getType();
     if (auto shapedType = type.dyn_cast<ShapedType>()) {
       numDynamicDims = shapedType.getNumDynamicDims();
+    } else if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+      numDynamicDims = 1;
     }
     if (!llvm::count(excludedOperandIndices, it.index())) {
       operandValues.push_back(it.value());
@@ -73,6 +76,8 @@
     auto type = it.value();
     if (auto shapedType = type.dyn_cast<ShapedType>()) {
       numDynamicDims = shapedType.getNumDynamicDims();
+    } else if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+      numDynamicDims = 1;
     }
     if (!llvm::count(excludedResultIndices, it.index())) {
       resultTypes.push_back(type);
@@ -86,15 +91,18 @@
 
 void eraseRegionResults(Region &region,
                         ArrayRef<unsigned> excludedResultIndices) {
-  region.walk([&](IREE::Flow::ReturnOp terminator) {
-    llvm::SmallVector<Value, 4> newReturns;
-    for (auto it : llvm::enumerate(terminator.getOperands())) {
-      if (!llvm::count(excludedResultIndices, it.index())) {
-        newReturns.push_back(it.value());
+  for (auto &block : region.getBlocks()) {
+    auto *terminatorOp = block.getTerminator();
+    if (terminatorOp && terminatorOp->hasTrait<OpTrait::ReturnLike>()) {
+      llvm::SmallVector<Value, 4> newReturns;
+      for (auto it : llvm::enumerate(terminatorOp->getOperands())) {
+        if (!llvm::count(excludedResultIndices, it.index())) {
+          newReturns.push_back(it.value());
+        }
       }
+      terminatorOp->setOperands(newReturns);
     }
-    terminator.getOperation()->setOperands(newReturns);
-  });
+  }
 }
 
 // Returns true if |constantOp| represents a (logically) small constant value.
@@ -173,14 +181,11 @@
     auto *sourceOp = outerValue.getDefiningOp();
     if (!sourceOp) continue;  // can't clone block arguments into closures
 
-    BlockArgument blockArg = entryBlock.getArgument(opArg.index());
-    if (auto type = blockArg.getType().dyn_cast<DispatchTensorType>()) {
-      // We cannot just simply inline and replace all users if this is an
-      // argument that can be written; for example, the region might perform
-      // work after loading a initial constant from the argument and then
-      // write back.
-      if (type.getAccess() != TensorAccess::ReadOnly) continue;
-    }
+    // We cannot just simply inline and replace all users if this is an
+    // argument that can be written; for example, the region might perform
+    // work after loading an initial constant from the argument and then
+    // write back.
+    if (!closureOp.getOperandAccess(opArg.index()).isReadOnly()) continue;
 
     if (closureOp.canClosureContainOp(sourceOp) &&
         shouldInlineIntoClosure(outerValue)) {
@@ -197,6 +202,7 @@
       auto newValue = clonedOp->getResult(resultIndex);
 
       // Replace all of the uses inside of the closure.
+      BlockArgument blockArg = entryBlock.getArgument(opArg.index());
       blockArg.replaceAllUsesWith(newValue);
     }
   }
@@ -245,7 +251,7 @@
     // You can drop a result if the use is empty, and that it is only written to
     // within the dispatch region.
     if (result.value().use_empty() &&
-        !closureOp.isOutputReadWithinRegion(result.index())) {
+        !closureOp.getResultAccess(result.index()).isRead) {
       elidedResults.push_back(result.index());
     } else {
       preservedResults.push_back(result.value());
@@ -293,7 +299,7 @@
   return success();
 }
 
-}  // namespace Flow
+}  // namespace Util
 }  // namespace IREE
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h b/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
similarity index 86%
rename from iree/compiler/Dialect/Flow/IR/FlowOpUtils.h
rename to iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
index 7f15490..c3a900f 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpUtils.h
+++ b/iree/compiler/Dialect/Util/IR/ClosureOpUtils.h
@@ -4,8 +4,12 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_CLOSUREOPUTILS_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_CLOSUREOPUTILS_H_
+
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/IR/Builders.h"
@@ -15,7 +19,7 @@
 namespace mlir {
 namespace iree_compiler {
 namespace IREE {
-namespace Flow {
+namespace Util {
 
 //------------------------------------------------------------------------------
 // Closure optimization
@@ -50,7 +54,7 @@
 // Duplicate operands will be combined and unused operands and results will be
 // removed.
 //
-// T must implement the IREE::Flow::ClosureOpInterface.
+// T must implement the IREE::Util::ClosureOpInterface.
 template <typename T>
 struct ClosureOptimizationPattern : public OpRewritePattern<T> {
   using OpRewritePattern<T>::OpRewritePattern;
@@ -62,7 +66,9 @@
   }
 };
 
-}  // namespace Flow
+}  // namespace Util
 }  // namespace IREE
 }  // namespace iree_compiler
 }  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_UTIL_IR_CLOSUREOPUTILS_H_
diff --git a/iree/compiler/Dialect/Util/IR/UtilDialect.h b/iree/compiler/Dialect/Util/IR/UtilDialect.h
index 6d2e99f..8224469 100644
--- a/iree/compiler/Dialect/Util/IR/UtilDialect.h
+++ b/iree/compiler/Dialect/Util/IR/UtilDialect.h
@@ -4,8 +4,8 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#ifndef IREE_COMPILER_DIALECT_IREE_IR_IREEDIALECT_H_
-#define IREE_COMPILER_DIALECT_IREE_IR_IREEDIALECT_H_
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_UTILDIALECT_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_UTILDIALECT_H_
 
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
@@ -36,4 +36,4 @@
 }  // namespace iree_compiler
 }  // namespace mlir
 
-#endif  // IREE_COMPILER_DIALECT_IREE_IR_IREEDIALECT_H_
+#endif  // IREE_COMPILER_DIALECT_UTIL_IR_UTILDIALECT_H_
diff --git a/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index 6deb9a1..ea48549 100644
--- a/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -10,6 +10,91 @@
 include "mlir/IR/OpBase.td"
 
 //===----------------------------------------------------------------------===//
+// IREE::Util::ClosureOpInterface
+//===----------------------------------------------------------------------===//
+
+def Util_ClosureOpInterface : OpInterface<"ClosureOpInterface"> {
+  let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+  let description = [{
+    Interface for ops that follow the util dialect closure semantics (explicit
+    captures, dynamic-shape awareness, and normal operand/result SSA behavior).
+
+    Implementing this interface enables optimizations that perform manipulation
+    across the closure capture boundary (outside of the op <-> regions within
+    the op).
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the body region of the closure (may have multiple blocks).
+      }],
+      /*retTy=*/"Region &",
+      /*methodName=*/"getClosureBodyRegion",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return this->getOperation()->getRegion(0);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{Returns all closure operand values.}],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getClosureOperands",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{Returns all closure result values.}],
+      /*retTy=*/"Operation::result_range",
+      /*methodName=*/"getClosureResults",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if the given operation can exist in the closure.
+        Not all operations that a closure can contain are guaranteed to be folded
+        into the closure, such as when the operation may have side-effects.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"canClosureContainOp",
+      /*args=*/(ins "Operation *":$op)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Clones the op while removing specified operands and results.
+        The body of the op will be transferred to the new op and the entry block
+        will have the arguments for the excluded operands removed.
+
+        The returned op will be free standing. Callers must insert it into a block
+        where desired (most often just replacing the current op).
+      }],
+      /*retTy=*/"IREE::Util::ClosureOpInterface",
+      /*methodName=*/"cloneReplacementExcludingOperandsAndResults",
+      /*args=*/(ins "ArrayRef<unsigned>":$excludedOperandIndices,
+                    "ArrayRef<unsigned>":$excludedResultIndices,
+                    "PatternRewriter &":$rewriter)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns a bitfield indicating how an operand is used within the closure.
+      }],
+      /*retTy=*/"IREE::Util::ValueAccess",
+      /*methodName=*/"getOperandAccess",
+      /*args=*/(ins "unsigned":$operandIndex)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns a bitfield indicating how a result is used within the closure.
+      }],
+      /*retTy=*/"IREE::Util::ValueAccess",
+      /*methodName=*/"getResultAccess",
+      /*args=*/(ins "unsigned":$resultIndex)
+    >
+  ];
+}
+
+//===----------------------------------------------------------------------===//
 // IREE::Util::TiedOpInterface
 //===----------------------------------------------------------------------===//
 
@@ -151,6 +236,19 @@
         return IREE::Util::detail::getTiedResultOperandIndices($_op);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if the given flattened operand index is tied to one or more
+        results.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isOperandTied",
+      /*args=*/(ins "unsigned":$operandIndex),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return IREE::Util::detail::isOperandTied($_op, operandIndex);
+      }]
+    >,
   ];
 
   let extraClassDeclaration = [{
@@ -171,6 +269,74 @@
 }
 
 //===----------------------------------------------------------------------===//
+// IREE::Util::SizeAware* interfaces
+//===----------------------------------------------------------------------===//
+
+def Util_InferTypeSize : TypeInterface<"InferTypeSizeInterface"> {
+  let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+  let description = [{
+    Allows types to be queried for their size by inserting the required logic
+    when required.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Builds an expression computing the size of the value.}],
+      "Value", "inferSizeFromValue", (ins "Location":$loc,
+                                          "Value":$value,
+                                          "OpBuilder &":$builder)
+    >,
+  ];
+}
+
+def Util_SizeAwareType : TypeInterface<"SizeAwareTypeInterface"> {
+  let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+  let description = [{
+    Denotes that a type is size-aware and must always have a size value
+    associated with it in the IR. See `SizeAwareOp` for more information.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns a size for the given sized value.}],
+      "Value", "getSize", (ins "Value":$value)
+    >,
+  ];
+}
+
+def Util_SizeAwareOp : OpInterface<"SizeAwareOpInterface"> {
+  let cppNamespace = "::mlir::iree_compiler::IREE::Util";
+
+  let description = [{
+    An operation that is able to provide size values for all size-aware operands
+    and results.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns a size for the given sized operand index.}],
+      "Value", "getOperandSize", (ins "unsigned":$idx)
+    >,
+    InterfaceMethod<
+      [{Returns a size for the given sized result index.}],
+      "Value", "getResultSize", (ins "unsigned":$idx)
+    >,
+    InterfaceMethod<
+      [{Returns a size for the given sized result value.}],
+      "Value", "getResultSizeFromValue", (ins "Value":$value),
+      /*methodBody=*/[{
+        for (unsigned i = 0; i < $_self->getNumResults(); ++i) {
+          if ($_self->getResult(i) == value) return $_self.getResultSize(i);
+        }
+        return {};
+      }]
+    >,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
 // IREE::Util::GlobalTypeInterface
 //===----------------------------------------------------------------------===//
 
@@ -195,7 +361,7 @@
       /*defaultImplementation=*/[{
         // If one is a shaped type, then they both must be and have compatible
         // shapes.
-        if ($_type.isa<ShapedType>() || accessType.isa<ShapedType>()) {
+        if ($_type.template isa<ShapedType>() || accessType.isa<ShapedType>()) {
           return succeeded(mlir::verifyCompatibleShape($_type, accessType));
         }
         // Otherwise, the types must be the same.
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 5cab132..11a6c11 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -24,8 +24,6 @@
 
 namespace mlir {
 namespace iree_compiler {
-namespace IREE {
-namespace Util {
 
 //===----------------------------------------------------------------------===//
 // custom<SymbolVisibility>($sym_visibility)
@@ -35,8 +33,8 @@
 // some.op @foo
 // some.op private @foo
 
-static ParseResult parseSymbolVisibility(OpAsmParser &parser,
-                                         StringAttr &symVisibilityAttr) {
+ParseResult parseSymbolVisibility(OpAsmParser &parser,
+                                  StringAttr &symVisibilityAttr) {
   StringRef symVisibility;
   parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"});
   if (!symVisibility.empty()) {
@@ -45,8 +43,8 @@
   return success();
 }
 
-static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
-                                  StringAttr symVisibilityAttr) {
+void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
+                           StringAttr symVisibilityAttr) {
   if (!symVisibilityAttr) {
     p << "public";
   } else {
@@ -61,37 +59,396 @@
 // ->
 // some.op : i32
 // some.op = 42 : i32
+// some.op : i32 = 42 : index
 
-static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
-                                   Attribute &attr) {
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                            Attribute &attr) {
   if (succeeded(parser.parseOptionalEqual())) {
     if (failed(parser.parseAttribute(attr))) {
       return parser.emitError(parser.getCurrentLocation())
              << "expected attribute";
     }
     typeAttr = TypeAttr::get(attr.getType());
-  } else {
-    Type type;
-    if (failed(parser.parseColonType(type))) {
-      return parser.emitError(parser.getCurrentLocation()) << "expected type";
+    return success();
+  }
+
+  Type type;
+  if (failed(parser.parseColonType(type))) {
+    return parser.emitError(parser.getCurrentLocation()) << "expected type";
+  }
+  typeAttr = TypeAttr::get(type);
+
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (failed(parser.parseAttribute(attr))) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected attribute";
     }
-    typeAttr = TypeAttr::get(type);
+  }
+
+  return success();
+}
+
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                     Attribute attr) {
+  if (!attr || attr.getType() != type.getValue()) {
+    p << " : ";
+    p.printAttribute(type);
+  }
+  if (attr) {
+    p << " = ";
+    p.printAttribute(attr);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// custom<SizeAwareType>
+//===----------------------------------------------------------------------===//
+// type{%size}
+
+ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type,
+                               OpAsmParser::OperandType &size) {
+  if (failed(parser.parseType(type)) || failed(parser.parseLBrace()) ||
+      failed(parser.parseOperand(size)) || failed(parser.parseRBrace())) {
+    return failure();
   }
   return success();
 }
 
-static void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
-                            Attribute attr) {
-  if (attr) {
-    p << " = ";
-    p.printAttribute(attr);
-  } else {
-    p << " : ";
-    p.printAttribute(type);
+void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, Value size) {
+  p.printType(type);
+  p << "{";
+  p.printOperand(size);
+  p << "}";
+}
+
+//===----------------------------------------------------------------------===//
+// custom<SizeAwareTypeList>
+//===----------------------------------------------------------------------===//
+// (type{%size0}, type, type{%size1})
+
+ParseResult parseSizeAwareTypeList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<OpAsmParser::OperandType> &sizes) {
+  do {
+    Type type;
+    if (failed(parser.parseType(type))) return failure();
+    if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+      OpAsmParser::OperandType size;
+      if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) ||
+          failed(parser.parseRBrace())) {
+        return failure();
+      }
+      sizes.push_back(size);
+    }
+    types.push_back(type);
+  } while (succeeded(parser.parseOptionalComma()));
+  return success();
+}
+
+void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op, TypeRange types,
+                            OperandRange sizes) {
+  int sizeIndex = 0;
+  llvm::interleaveComma(types, p, [&](Type type) {
+    p.printType(type);
+    if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+      p << "{";
+      p.printOperand(sizes[sizeIndex++]);
+      p << "}";
+    }
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// custom<ShapedTiedResult>
+//===----------------------------------------------------------------------===//
+// type{%dim0, %dim1}
+// %arg0 as type{%dim0}
+
+ParseResult parseShapedTiedResult(
+    OpAsmParser &parser, Type &resultType,
+    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+    ArrayAttr &tiedOperands) {
+  OpAsmParser::OperandType tiedResult;
+  auto res = parser.parseOptionalOperand(tiedResult);
+  int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
+  if (res.hasValue() && succeeded(res.getValue())) {
+    tiedOperandIndex = 0;
+    if (failed(parser.parseKeyword("as"))) return failure();
+  }
+  if (failed(parser.parseType(resultType))) return failure();
+  if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+    if (!shapedType.hasStaticShape()) {
+      SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
+      if (failed(parser.parseLBrace()) ||
+          failed(parser.parseOperandList(dynamicDims,
+                                         shapedType.getNumDynamicDims(),
+                                         OpAsmParser::Delimiter::None)) ||
+          failed(parser.parseRBrace())) {
+        return failure();
+      }
+      resultDims.append(dynamicDims);
+    }
+  } else if (resultType.isa<IREE::Util::SizeAwareTypeInterface>()) {
+    OpAsmParser::OperandType size;  // Written as `{%size}` after the type.
+    if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) ||
+        failed(parser.parseRBrace())) {
+      return failure();
+    }
+    resultDims.push_back(size);
+  }
+  tiedOperands = parser.getBuilder().getIndexArrayAttr({tiedOperandIndex});
+  return success();
+}
+
+void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
+                           ValueRange resultDims, ArrayAttr tiedOperands) {
+  auto tiedOp = cast<IREE::Util::TiedOpInterface>(op);
+  auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(0);
+  if (tiedOperandIndex.hasValue()) {
+    auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
+    p.printOperand(tiedOperand);
+    p << " as ";
+  }
+  p.printType(resultType);
+  if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+    if (!shapedType.hasStaticShape()) {
+      if (resultDims.empty()) {
+        p << "{<<INVALID>>}";
+        return;
+      }
+      p << "{";
+      llvm::interleaveComma(
+          resultDims.take_front(shapedType.getNumDynamicDims()), p,
+          [&](Value value) { p.printOperand(value); });
+      p << "}";
+      resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
+    }
+  } else if (auto sizedType =
+                 resultType.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+    p << "{";
+    p.printOperand(resultDims.front());
+    p << "}";
+    resultDims = resultDims.drop_front(1);
   }
 }
 
 //===----------------------------------------------------------------------===//
+// custom<ShapedFunctionType>
+//===----------------------------------------------------------------------===//
+// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)
+
+static ParseResult parseShapedOperandList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<OpAsmParser::OperandType> &dims) {
+  do {
+    Type type;
+    if (failed(parser.parseType(type))) return failure();
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (!shapedType.hasStaticShape()) {
+        SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
+        if (failed(parser.parseLBrace()) ||
+            failed(parser.parseOperandList(dynamicDims,
+                                           shapedType.getNumDynamicDims(),
+                                           OpAsmParser::Delimiter::None)) ||
+            failed(parser.parseRBrace())) {
+          return failure();
+        }
+        dims.append(dynamicDims);
+      }
+    } else if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+      OpAsmParser::OperandType size;  // Written as `{%size}` after the type.
+      if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) ||
+          failed(parser.parseRBrace())) {
+        return failure();
+      }
+      dims.push_back(size);
+    }
+    types.push_back(type);
+  } while (succeeded(parser.parseOptionalComma()));
+  return success();
+}
+
+// Finds the operand index in |operands| that |tiedResult| references, matching
+// on both SSA name and result number so `%0#1` is not mistaken for `%0#0`.
+// Returns TiedOpInterface::kUntiedIndex if no operand is found.
+static int64_t findTiedOperand(OpAsmParser::OperandType tiedResult,
+                               ArrayRef<OpAsmParser::OperandType> operands) {
+  for (size_t i = 0, e = operands.size(); i < e; ++i) {
+    if (operands[i].name == tiedResult.name &&
+        operands[i].number == tiedResult.number) {
+      return static_cast<int64_t>(i);
+    }
+  }
+  return IREE::Util::TiedOpInterface::kUntiedIndex;
+}
+
+static ParseResult parseShapedResultList(
+    OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
+    TypeRange operandTypes, ArrayRef<OpAsmParser::OperandType> operandDims,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+    ArrayAttr &tiedOperands) {
+  SmallVector<int64_t, 4> tiedOperandIndices;
+  do {
+    OpAsmParser::OperandType tiedResult;
+    auto res = parser.parseOptionalOperand(tiedResult);
+    Type type;
+    int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex;
+    if (res.hasValue() && succeeded(res.getValue())) {
+      tiedOperandIndex = findTiedOperand(tiedResult, operands);
+      if (tiedOperandIndex == IREE::Util::TiedOpInterface::kUntiedIndex) {
+        return parser.emitError(tiedResult.location,
+                                "tied operand not found for result reference ")
+               << tiedResult.name;
+      }
+      if (succeeded(parser.parseOptionalKeyword("as"))) {
+        // Type _may_ differ from the operand.
+        if (failed(parser.parseType(type))) return failure();
+      } else {
+        // Use the operands type.
+        type = operandTypes[tiedOperandIndex];
+      }
+    } else if (failed(parser.parseType(type))) {
+      return failure();
+    }
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (!shapedType.hasStaticShape()) {
+        SmallVector<OpAsmParser::OperandType, 4> dynamicDims;
+        if (failed(parser.parseLBrace()) ||
+            failed(parser.parseOperandList(dynamicDims,
+                                           shapedType.getNumDynamicDims(),
+                                           OpAsmParser::Delimiter::None)) ||
+            failed(parser.parseRBrace())) {
+          return failure();
+        }
+        resultDims.append(dynamicDims);
+      }
+    } else if (type.isa<IREE::Util::SizeAwareTypeInterface>()) {
+      OpAsmParser::OperandType size;  // Written as `{%size}` after the type.
+      if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) ||
+          failed(parser.parseRBrace())) {
+        return failure();
+      }
+      resultDims.push_back(size);
+    }
+    resultTypes.push_back(type);
+    tiedOperandIndices.push_back(tiedOperandIndex);
+  } while (succeeded(parser.parseOptionalComma()));
+  if (!tiedOperandIndices.empty()) {
+    tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices);
+  }
+  return success();
+}
+
+ParseResult parseShapedFunctionType(
+    OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
+    SmallVectorImpl<Type> &operandTypes,
+    SmallVectorImpl<OpAsmParser::OperandType> &operandDims,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+    ArrayAttr &tiedOperands) {
+  if (failed(parser.parseLParen())) return failure();
+  if (failed(parser.parseOptionalRParen())) {
+    if (failed(parseShapedOperandList(parser, operandTypes, operandDims)) ||
+        failed(parser.parseRParen())) {
+      return failure();
+    }
+  }
+  if (failed(parser.parseArrow())) return failure();
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (failed(parseShapedResultList(parser, operands, operandTypes,
+                                     operandDims, resultTypes, resultDims,
+                                     tiedOperands)) ||
+        failed(parser.parseRParen())) {
+      return failure();
+    }
+  } else {
+    if (failed(parseShapedResultList(parser, operands, operandTypes,
+                                     operandDims, resultTypes, resultDims,
+                                     tiedOperands))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+void printShapedFunctionType(OpAsmPrinter &p, Operation *op,
+                             ValueRange operands, TypeRange operandTypes,
+                             OperandRange operandDims, TypeRange resultTypes,
+                             OperandRange resultDims, ArrayAttr tiedOperands) {
+  p << "(";
+  llvm::interleaveComma(operandTypes, p, [&](Type type) {
+    p.printType(type);
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (!shapedType.hasStaticShape()) {
+        if (operandDims.empty()) {
+          p << "{<<INVALID>>}";
+          return;
+        }
+        p << "{";
+        llvm::interleaveComma(
+            operandDims.take_front(shapedType.getNumDynamicDims()), p,
+            [&](Value value) { p.printOperand(value); });
+        p << "}";
+        operandDims = operandDims.drop_front(shapedType.getNumDynamicDims());
+      }
+    } else if (auto sizedType =
+                   type.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+      p << "{";
+      p.printOperand(operandDims.front());
+      p << "}";
+      operandDims = operandDims.drop_front(1);
+    }
+  });
+  p << ") -> ";
+  if (resultTypes.size() != 1) p << "(";
+  auto tiedOp = cast<IREE::Util::TiedOpInterface>(op);
+  for (unsigned i = 0; i < resultTypes.size(); ++i) {
+    auto resultType = resultTypes[i];
+    auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(i);
+    bool printType = true;
+    if (tiedOperandIndex.hasValue()) {
+      auto tiedOperand = op->getOperand(tiedOperandIndex.getValue());
+      p.printOperand(tiedOperand);
+      if (tiedOperand.getType() != resultType) {
+        p << " as ";
+      } else {
+        // Type elided as it matches the operand.
+        printType = false;
+      }
+    }
+    if (printType) {
+      p.printType(resultType);
+    }
+    if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+      if (!shapedType.hasStaticShape()) {
+        if (resultDims.empty()) {
+          p << "{<<INVALID>>}";
+          return;
+        }
+        p << "{";
+        llvm::interleaveComma(
+            resultDims.take_front(shapedType.getNumDynamicDims()), p,
+            [&](Value value) { p.printOperand(value); });
+        p << "}";
+        resultDims = resultDims.drop_front(shapedType.getNumDynamicDims());
+      }
+    } else if (auto sizedType =
+                   resultType.dyn_cast<IREE::Util::SizeAwareTypeInterface>()) {
+      p << "{";
+      p.printOperand(resultDims.front());
+      p << "}";
+      resultDims = resultDims.drop_front(1);
+    }
+    if (i < resultTypes.size() - 1) p << ", ";
+  }
+  if (resultTypes.size() != 1) p << ")";
+}
+
+namespace IREE {
+namespace Util {
+
+//===----------------------------------------------------------------------===//
 // util.do_not_optimize
 //===----------------------------------------------------------------------===//
 
@@ -222,7 +579,9 @@
     return succeeded(mlir::verifyCompatibleShape(globalType, accessType));
   }
 
-  // TODO(benvanik): use GlobalOpInterface.
+  if (auto knownType = globalType.dyn_cast<GlobalTypeInterface>()) {
+    return knownType.isAccessStorageCompatible(accessType);
+  }
 
   // Otherwise, the types must be the same.
   return globalType == accessType;
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.h b/iree/compiler/Dialect/Util/IR/UtilOps.h
index f0453fa..55cdde8 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOps.h
+++ b/iree/compiler/Dialect/Util/IR/UtilOps.h
@@ -4,8 +4,8 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#ifndef IREE_COMPILER_DIALECT_IREE_IR_IREEOPS_H_
-#define IREE_COMPILER_DIALECT_IREE_IR_IREEOPS_H_
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_UTILOPS_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_UTILOPS_H_
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -20,4 +20,87 @@
 #define GET_OP_CLASSES
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h.inc"  // IWYU pragma: export
 
-#endif  // IREE_COMPILER_DIALECT_IREE_IR_IREEOPS_H_
+namespace mlir {
+namespace iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// custom<SymbolVisibility>($sym_visibility)
+//===----------------------------------------------------------------------===//
+// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
+// ->
+// some.op @foo
+// some.op private @foo
+
+ParseResult parseSymbolVisibility(OpAsmParser &parser,
+                                  StringAttr &symVisibilityAttr);
+void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
+                           StringAttr symVisibilityAttr);
+
+//===----------------------------------------------------------------------===//
+// custom<TypeOrAttr>($type, $attr)
+//===----------------------------------------------------------------------===//
+// some.op custom<TypeOrAttr>($type, $attr)
+// ->
+// some.op : i32
+// some.op = 42 : i32
+// some.op : i32 = 42 : index
+
+ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                            Attribute &attr);
+void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                     Attribute attr);
+
+//===----------------------------------------------------------------------===//
+// custom<SizeAwareType>
+//===----------------------------------------------------------------------===//
+// type{%size}
+
+ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type,
+                               OpAsmParser::OperandType &size);
+void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, Value size);
+
+//===----------------------------------------------------------------------===//
+// custom<SizeAwareTypeList>
+//===----------------------------------------------------------------------===//
+// (type{%size0}, type, type{%size1})
+
+ParseResult parseSizeAwareTypeList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<OpAsmParser::OperandType> &sizes);
+void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op, TypeRange types,
+                            OperandRange sizes);
+
+//===----------------------------------------------------------------------===//
+// custom<ShapedTiedResult>
+//===----------------------------------------------------------------------===//
+// type{%dim0, %dim1}
+// %arg0 as type{%dim0}
+
+ParseResult parseShapedTiedResult(
+    OpAsmParser &parser, Type &resultType,
+    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+    ArrayAttr &tiedOperands);
+void printShapedTiedResult(OpAsmPrinter &p, Operation *op, Type resultType,
+                           ValueRange resultDims, ArrayAttr tiedOperands);
+
+//===----------------------------------------------------------------------===//
+// custom<ShapedFunctionType>
+//===----------------------------------------------------------------------===//
+// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)
+
+ParseResult parseShapedFunctionType(
+    OpAsmParser &parser, ArrayRef<OpAsmParser::OperandType> operands,
+    SmallVectorImpl<Type> &operandTypes,
+    SmallVectorImpl<OpAsmParser::OperandType> &operandDims,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<OpAsmParser::OperandType> &resultDims,
+    ArrayAttr &tiedOperands);
+void printShapedFunctionType(OpAsmPrinter &p, Operation *op,
+                             ValueRange operands, TypeRange operandTypes,
+                             OperandRange operandDims, TypeRange resultTypes,
+                             OperandRange resultDims, ArrayAttr tiedOperands);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_DIALECT_UTIL_IR_UTILOPS_H_
diff --git a/iree/compiler/Dialect/Util/IR/UtilTraits.h b/iree/compiler/Dialect/Util/IR/UtilTraits.h
index 3695e4e..e610220 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTraits.h
+++ b/iree/compiler/Dialect/Util/IR/UtilTraits.h
@@ -4,8 +4,8 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#ifndef IREE_COMPILER_DIALECT_IREE_IR_IREETRAITS_H_
-#define IREE_COMPILER_DIALECT_IREE_IR_IREETRAITS_H_
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_UTILTRAITS_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_UTILTRAITS_H_
 
 #include "mlir/IR/OpDefinition.h"
 
@@ -37,4 +37,4 @@
 }  // namespace OpTrait
 }  // namespace mlir
 
-#endif  // IREE_COMPILER_DIALECT_IREE_IR_IREETRAITS_H_
+#endif  // IREE_COMPILER_DIALECT_UTIL_IR_UTILTRAITS_H_
diff --git a/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
index cd2534b..60ad827 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
@@ -181,6 +181,24 @@
   return baseValue;
 }
 
+bool detail::isOperandTied(Operation *op, unsigned operandIndex) {
+  auto storageAttr =
+      op->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName());
+  if (!storageAttr) return false;
+  auto valueAttrs = storageAttr.getValue();
+  if (valueAttrs.empty()) return false;
+  auto tiedOp = cast<TiedOpInterface>(op);
+  unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first;
+  for (unsigned i = 0; i < valueAttrs.size(); ++i) {
+    int64_t index = valueAttrs[i].cast<IntegerAttr>().getInt();
+    index = index != TiedOpInterface::kUntiedIndex
+                ? tiedOperandsOffset + index
+                : TiedOpInterface::kUntiedIndex;
+    if (index == operandIndex) return true;
+  }
+  return false;
+}
+
 LogicalResult detail::verifyTiedOp(TiedOpInterface tiedOp) {
   auto storageAttr =
       tiedOp->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName());
@@ -242,13 +260,15 @@
   }
 }
 
-// At the end so it can use functions above:
-#include "iree/compiler/Dialect/Util/IR/UtilOpInterfaces.cpp.inc"
-
 //===----------------------------------------------------------------------===//
 // IREE::Util::UtilDialect
 //===----------------------------------------------------------------------===//
 
+// At the end so it can use functions above:
+#include "iree/compiler/Dialect/Util/IR/UtilAttrInterfaces.cpp.inc"
+#include "iree/compiler/Dialect/Util/IR/UtilOpInterfaces.cpp.inc"
+#include "iree/compiler/Dialect/Util/IR/UtilTypeInterfaces.cpp.inc"
+
 void UtilDialect::registerTypes() {
   addTypes<IREE::Util::ByteBufferType, IREE::Util::ListType,
            IREE::Util::MutableByteBufferType, IREE::Util::PtrType,
diff --git a/iree/compiler/Dialect/Util/IR/UtilTypes.h b/iree/compiler/Dialect/Util/IR/UtilTypes.h
index c051a65..6d2cac3 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTypes.h
+++ b/iree/compiler/Dialect/Util/IR/UtilTypes.h
@@ -4,15 +4,18 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#ifndef IREE_COMPILER_DIALECT_IREE_IR_IREETYPES_H_
-#define IREE_COMPILER_DIALECT_IREE_IR_IREETYPES_H_
+#ifndef IREE_COMPILER_DIALECT_UTIL_IR_UTILTYPES_H_
+#define IREE_COMPILER_DIALECT_UTIL_IR_UTILTYPES_H_
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SubElementInterfaces.h"
 #include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 
 namespace mlir {
@@ -52,6 +55,22 @@
   DoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20
 };
 
+struct ValueAccess {
+  bool isRead : 1;
+  bool isWrite : 1;
+  bool isDiscard : 1;
+  bool isNone() const { return !isRead && !isWrite && !isDiscard; }
+  bool isReadOnly() const { return isRead && !isWrite && !isDiscard; }
+  ValueAccess() : isRead(false), isWrite(false), isDiscard(false) {}
+  ValueAccess(bool isRead, bool isWrite, bool isDiscard)
+      : isRead(isRead), isWrite(isWrite), isDiscard(isDiscard) {}
+  static ValueAccess None() { return ValueAccess(false, false, false); }
+  static ValueAccess ReadOnly() { return ValueAccess(true, false, false); }
+  static ValueAccess ReadWrite() { return ValueAccess(true, true, false); }
+  static ValueAccess WriteOnly() { return ValueAccess(false, true, false); }
+  static ValueAccess DiscardWrite() { return ValueAccess(false, true, true); }
+};
+
 /// Placeholder for a variant type (`?`).
 class VariantType : public Type::TypeBase<VariantType, Type, TypeStorage> {
  public:
@@ -133,6 +152,7 @@
 void setTiedResultOperandIndex(Operation *op, unsigned resultIndex,
                                llvm::Optional<unsigned> operandIndex);
 SmallVector<int64_t, 4> getTiedResultOperandIndices(Operation *op);
+bool isOperandTied(Operation *tiedOp, unsigned operandIndex);
 LogicalResult verifyTiedOp(TiedOpInterface tiedOp);
 }  // namespace detail
 
@@ -148,6 +168,8 @@
 }  // namespace iree_compiler
 }  // namespace mlir
 
+#include "iree/compiler/Dialect/Util/IR/UtilAttrInterfaces.h.inc"  // IWYU pragma: export
 #include "iree/compiler/Dialect/Util/IR/UtilOpInterfaces.h.inc"  // IWYU pragma: export
+#include "iree/compiler/Dialect/Util/IR/UtilTypeInterfaces.h.inc"  // IWYU pragma: export
 
-#endif  // IREE_COMPILER_DIALECT_IREE_IR_IREETYPES_H_
+#endif  // IREE_COMPILER_DIALECT_UTIL_IR_UTILTYPES_H_
diff --git a/iree/compiler/Dialect/Util/IR/test/global_ops.mlir b/iree/compiler/Dialect/Util/IR/test/global_ops.mlir
index b38d3ff..822c131 100644
--- a/iree/compiler/Dialect/Util/IR/test/global_ops.mlir
+++ b/iree/compiler/Dialect/Util/IR/test/global_ops.mlir
@@ -19,6 +19,12 @@
 // CHECK: util.global public @v_initialized_const3 = dense<4> : tensor<4xi32>
 util.global public @v_initialized_const3 = dense<4> : tensor<4xi32>
 
+// CHECK: util.global public @v_initialized_const4 = dense<4> : tensor<4xi32>
+util.global public @v_initialized_const4 : tensor<4xi32> = dense<4> : tensor<4xi32>
+
+// CHECK: util.global public @v_initialized_const5 : tensor<4xf32> = dense<4> : tensor<4xi32>
+util.global public @v_initialized_const5 : tensor<4xf32> = dense<4> : tensor<4xi32>
+
 // -----
 
 // CHECK: util.global private @v_initialized initializer(@initializer) : tensor<4xi32>
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 9b064d8..92c5ae6 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Dialect/VM/IR/VMOps.h"
 
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
@@ -27,70 +28,6 @@
 namespace VM {
 
 //===----------------------------------------------------------------------===//
-// custom<SymbolVisibility>($sym_visibility)
-//===----------------------------------------------------------------------===//
-// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
-// ->
-// some.op @foo
-// some.op private @foo
-
-static ParseResult parseSymbolVisibility(OpAsmParser &parser,
-                                         StringAttr &symVisibilityAttr) {
-  StringRef symVisibility;
-  parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"});
-  if (!symVisibility.empty()) {
-    symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
-  }
-  return success();
-}
-
-static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
-                                  StringAttr symVisibilityAttr) {
-  if (!symVisibilityAttr) {
-    p << "public";
-  } else {
-    p << symVisibilityAttr.getValue();
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// custom<TypeOrAttr>($type, $attr)
-//===----------------------------------------------------------------------===//
-// some.op custom<TypeOrAttr>($type, $attr)
-// ->
-// some.op : i32
-// some.op = 42 : i32
-
-static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
-                                   Attribute &attr) {
-  if (succeeded(parser.parseOptionalEqual())) {
-    if (failed(parser.parseAttribute(attr))) {
-      return parser.emitError(parser.getCurrentLocation())
-             << "expected attribute";
-    }
-    typeAttr = TypeAttr::get(attr.getType());
-  } else {
-    Type type;
-    if (failed(parser.parseColonType(type))) {
-      return parser.emitError(parser.getCurrentLocation()) << "expected type";
-    }
-    typeAttr = TypeAttr::get(type);
-  }
-  return success();
-}
-
-static void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
-                            Attribute attr) {
-  if (attr) {
-    p << " = ";
-    p.printAttribute(attr);
-  } else {
-    p << " : ";
-    p.printAttribute(type);
-  }
-}
-
-//===----------------------------------------------------------------------===//
 // Structural ops
 //===----------------------------------------------------------------------===//