[Flow][Global Opt] Fold unit dims of `stream.parameter.named` (#17824)
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index efc338f..c9cfaf2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -89,6 +89,7 @@
"//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
+ "//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/Analysis",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index f448835..c69fe6a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -116,6 +116,7 @@
iree::compiler::Dialect::LinalgExt::TransformExtensions::LinalgExtExtensions
iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::LinalgExt::Utils
+ iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::Analysis
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
index f64e419..390e7f6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
@@ -13,7 +13,10 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,6 +25,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#define DEBUG_TYPE "iree-flow-fold-unit-extent-dims"
+
namespace mlir::iree_compiler::IREE::Flow {
#define GEN_PASS_DEF_FOLDUNITEXTENTDIMSPASS
@@ -46,15 +51,24 @@
}
auto newGlobalType = globalType.clone(newShape);
auto initialValue = global.getGlobalInitialValue();
- // TODO: Handle non-uninitialized cases.
- auto uninitializedAttr =
- llvm::dyn_cast_if_present<IREE::Util::UninitializedAttr>(initialValue);
- if (initialValue && !uninitializedAttr)
+ if (!initialValue)
return success();
- TypedAttr newInitialValue;
- if (initialValue) {
- newInitialValue = IREE::Util::UninitializedAttr::get(rewriter.getContext(),
- newGlobalType);
+ // TODO: Handle other cases
+ auto newInitialValue =
+ llvm::TypeSwitch<Attribute, Attribute>(initialValue)
+ .Case<IREE::Util::UninitializedAttr>([&](Attribute) {
+ return IREE::Util::UninitializedAttr::get(rewriter.getContext(),
+ newGlobalType);
+ })
+ .Case<IREE::Stream::NamedParameterAttr>(
+ [&](IREE::Stream::NamedParameterAttr attr) {
+ return IREE::Stream::NamedParameterAttr::get(
+ rewriter.getContext(), newGlobalType, attr.getScope(),
+ attr.getKey(), attr.getConfig());
+ })
+ .Default([&](Attribute) { return nullptr; });
+ if (!newInitialValue) {
+ return success();
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(global);
@@ -101,19 +115,19 @@
} // namespace
void FoldUnitExtentDimsPass::runOnOperation() {
- auto funcOp = getOperation();
+ auto moduleOp = getOperation();
MLIRContext *context = &getContext();
- Explorer explorer(funcOp, TraversalAction::RECURSE);
+ SymbolTable moduleSymbols(moduleOp);
+ Explorer explorer(moduleOp, TraversalAction::RECURSE);
explorer.initialize();
IRRewriter rewriter(context);
- SymbolTable moduleSymbols(funcOp);
// Fold unit dims of GlobalOpInterface ops.
explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) {
IREE::Util::GlobalOpInterface global = globalInfo->op;
auto tensorType = dyn_cast<RankedTensorType>(global.getGlobalType());
- if (!tensorType || !global.isGlobalPrivate() || !global.isGlobalMutable()) {
+ if (!tensorType || !global.isGlobalPrivate()) {
return;
}
if (llvm::none_of(tensorType.getShape(),
@@ -142,7 +156,7 @@
};
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
- if (failed(applyPatternsAndFoldGreedily(funcOp,
+ if (failed(applyPatternsAndFoldGreedily(moduleOp,
std::move(foldUnitDimsPatterns)))) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
index 611c2da..e652c5e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
@@ -91,3 +91,18 @@
// CHECK: util.func public @no_fold_global_unit_dims
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]]
+
+// -----
+
+module @fold_stream_parameter {
+ util.global private mutable @global = #stream.parameter.named<"module"::"global"> : tensor<1x1x10xf32>
+ util.func public @fold_stream_parameter() -> tensor<1x1x10xf32> {
+ %global = util.global.load @global : tensor<1x1x10xf32>
+ util.return %global : tensor<1x1x10xf32>
+ }
+}
+
+// CHECK: module @fold_stream_parameter
+// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32>
+// CHECK: util.func public @fold_stream_parameter
+// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>