[Flow][Global Opt] Fold unit dims of `stream.parameter.named` (#17824)

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index efc338f..c9cfaf2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -89,6 +89,7 @@
         "//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
+        "//compiler/src/iree/compiler/Dialect/Stream/IR",
         "//compiler/src/iree/compiler/Dialect/Util/Analysis",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "//compiler/src/iree/compiler/Dialect/Util/Transforms",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index f448835..c69fe6a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -116,6 +116,7 @@
     iree::compiler::Dialect::LinalgExt::TransformExtensions::LinalgExtExtensions
     iree::compiler::Dialect::LinalgExt::Transforms
     iree::compiler::Dialect::LinalgExt::Utils
+    iree::compiler::Dialect::Stream::IR
     iree::compiler::Dialect::Util::Analysis
     iree::compiler::Dialect::Util::IR
     iree::compiler::Dialect::Util::Transforms
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
index f64e419..390e7f6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
@@ -13,7 +13,10 @@
 
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
 #include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,6 +25,8 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#define DEBUG_TYPE "iree-flow-fold-unit-extent-dims"
+
 namespace mlir::iree_compiler::IREE::Flow {
 
 #define GEN_PASS_DEF_FOLDUNITEXTENTDIMSPASS
@@ -46,15 +51,24 @@
   }
   auto newGlobalType = globalType.clone(newShape);
   auto initialValue = global.getGlobalInitialValue();
-  // TODO: Handle non-uninitialized cases.
-  auto uninitializedAttr =
-      llvm::dyn_cast_if_present<IREE::Util::UninitializedAttr>(initialValue);
-  if (initialValue && !uninitializedAttr)
+  if (!initialValue)
     return success();
-  TypedAttr newInitialValue;
-  if (initialValue) {
-    newInitialValue = IREE::Util::UninitializedAttr::get(rewriter.getContext(),
-                                                         newGlobalType);
+  // TODO: Handle other cases
+  auto newInitialValue =
+      llvm::TypeSwitch<Attribute, Attribute>(initialValue)
+          .Case<IREE::Util::UninitializedAttr>([&](Attribute) {
+            return IREE::Util::UninitializedAttr::get(rewriter.getContext(),
+                                                      newGlobalType);
+          })
+          .Case<IREE::Stream::NamedParameterAttr>(
+              [&](IREE::Stream::NamedParameterAttr attr) {
+                return IREE::Stream::NamedParameterAttr::get(
+                    rewriter.getContext(), newGlobalType, attr.getScope(),
+                    attr.getKey(), attr.getConfig());
+              })
+          .Default([&](Attribute) { return nullptr; });
+  if (!newInitialValue) {
+    return success();
   }
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(global);
@@ -101,19 +115,19 @@
 } // namespace
 
 void FoldUnitExtentDimsPass::runOnOperation() {
-  auto funcOp = getOperation();
+  auto moduleOp = getOperation();
   MLIRContext *context = &getContext();
 
-  Explorer explorer(funcOp, TraversalAction::RECURSE);
+  SymbolTable moduleSymbols(moduleOp);
+  Explorer explorer(moduleOp, TraversalAction::RECURSE);
   explorer.initialize();
   IRRewriter rewriter(context);
-  SymbolTable moduleSymbols(funcOp);
 
   // Fold unit dims of GlobalOpInterface ops.
   explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) {
     IREE::Util::GlobalOpInterface global = globalInfo->op;
     auto tensorType = dyn_cast<RankedTensorType>(global.getGlobalType());
-    if (!tensorType || !global.isGlobalPrivate() || !global.isGlobalMutable()) {
+    if (!tensorType || !global.isGlobalPrivate()) {
       return;
     }
     if (llvm::none_of(tensorType.getShape(),
@@ -142,7 +156,7 @@
   };
   linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
   linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
-  if (failed(applyPatternsAndFoldGreedily(funcOp,
+  if (failed(applyPatternsAndFoldGreedily(moduleOp,
                                           std::move(foldUnitDimsPatterns)))) {
     return signalPassFailure();
   }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
index 611c2da..e652c5e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
@@ -91,3 +91,18 @@
 //      CHECK:   util.func public @no_fold_global_unit_dims
 //      CHECK:     %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32>
 //      CHECK:     %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]]
+
+// -----
+
+module @fold_stream_parameter {
+  util.global private mutable @global = #stream.parameter.named<"module"::"global"> : tensor<1x1x10xf32>
+  util.func public @fold_stream_parameter() -> tensor<1x1x10xf32> {
+    %global = util.global.load @global : tensor<1x1x10xf32>
+    util.return %global : tensor<1x1x10xf32>
+  }
+}
+
+//      CHECK: module @fold_stream_parameter
+//      CHECK:   util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32>
+//      CHECK:   util.func public @fold_stream_parameter
+//      CHECK:     %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>