Revert adding unit dim folding to GlobalOps (#16708)
This reverts commit c07d1102dc9f25315f2c9b517325c97bc8bff10b and the dependent
commit a86b8bfa9dc6b077b7882fcf88055ef197d1cc8a because they caused
regressions in llama2.
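
For reference, the reverted pass collapsed unit dims on private mutable
tensor globals: each load was rewritten to the collapsed type followed by a
tensor.expand_shape back to the original shape, and each store went through a
tensor.collapse_shape, with the reshapes expected to cancel against the
surrounding unit-dim folding patterns. A minimal sketch of that rewrite,
distilled from the deleted fold_unit_dims.mlir test (the global name is
illustrative):

    // Before:
    util.global private mutable @global = #util.uninitialized : tensor<1x32x1x1x64xf32>
    %0 = util.global.load @global : tensor<1x32x1x1x64xf32>

    // After the (now reverted) folding; the expand_shape typically folds
    // away against an adjacent collapse_shape:
    util.global private mutable @global = #util.uninitialized : tensor<32x64xf32>
    %0 = util.global.load @global : tensor<32x64xf32>
    %1 = tensor.expand_shape %0 [[0, 1], [2, 3, 4]]
        : tensor<32x64xf32> into tensor<1x32x1x1x64xf32>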
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
index 83b0ebb..6b49c75 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FoldUnitExtentDims.cpp
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -25,71 +24,6 @@
namespace mlir::iree_compiler::IREE::Flow {
-//===----------------------------------------------------------------------===//
-// Pass helpers
-//===----------------------------------------------------------------------===//
-
-static LogicalResult
-foldUnitDimsOnGlobal(IRRewriter &rewriter, IREE::Util::GlobalOpInterface global,
- SmallVector<IREE::Util::GlobalLoadOpInterface> loadOps,
- SmallVector<IREE::Util::GlobalStoreOpInterface> storeOps,
- SymbolTable moduleSymbols) {
- // Create a new transformed GlobalOp.
- SmallVector<int64_t> newShape;
- auto globalType = cast<RankedTensorType>(global.getGlobalType());
- for (auto size : globalType.getShape()) {
- if (size != 1) {
- newShape.push_back(size);
- }
- }
- auto newGlobalType = globalType.clone(newShape);
- auto initialValue = global.getGlobalInitialValue();
- // TODO: Handle non-uninitialized cases.
- auto uninitializedAttr =
- llvm::dyn_cast_if_present<IREE::Util::UninitializedAttr>(initialValue);
- if (initialValue && !uninitializedAttr)
- return success();
- TypedAttr newInitialValue;
- if (initialValue) {
- newInitialValue = IREE::Util::UninitializedAttr::get(rewriter.getContext(),
- newGlobalType);
- }
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(global);
- auto newGlobal =
- clone(rewriter, global, global->getResultTypes(), global->getOperands());
- newGlobal.setGlobalType(newGlobalType);
- newGlobal.setGlobalInitialValue(newInitialValue);
-
- // Rewrite loads and stores to use the new global.
- auto expandShapeReInds =
- getReassociationIndicesForReshape(globalType, newGlobalType);
- if (!expandShapeReInds) {
- return failure();
- }
-
- for (auto load : loadOps) {
- rewriter.setInsertionPoint(load);
- auto newLoad = clone(rewriter, load, {newGlobalType}, load->getOperands());
- newLoad.setGlobalAttr(FlatSymbolRefAttr::get(newGlobal.getGlobalName()));
- rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
- load, globalType, newLoad->getResult(0), expandShapeReInds.value());
- }
- for (auto store : storeOps) {
- rewriter.setInsertionPoint(store);
- Value collapse = rewriter.create<tensor::CollapseShapeOp>(
- store.getLoc(), newGlobalType, store->getOperand(0),
- expandShapeReInds.value());
- auto newStore =
- clone(rewriter, store, store->getResultTypes(), store->getOperands());
- newStore.setGlobalAttr(FlatSymbolRefAttr::get(newGlobal.getGlobalName()));
- newStore.setStoredGlobalValue(collapse);
- rewriter.eraseOp(store);
- }
- rewriter.eraseOp(global);
- return success();
-}
-
namespace {
struct FoldUnitExtentDimsPass
: public FoldUnitExtentDimsBase<FoldUnitExtentDimsPass> {
@@ -104,35 +38,8 @@
} // namespace
void FoldUnitExtentDimsPass::runOnOperation() {
- auto moduleOp = getOperation();
+ Operation *funcOp = getOperation();
MLIRContext *context = &getContext();
- Explorer explorer(moduleOp, TraversalAction::RECURSE);
- explorer.initialize();
- IRRewriter rewriter(context);
- SymbolTable moduleSymbols(moduleOp);
-
- // Fold unit dims of GlobalOpInterface ops.
- explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) {
- IREE::Util::GlobalOpInterface global = globalInfo->op;
- auto tensorType = dyn_cast<RankedTensorType>(global.getGlobalType());
- if (!tensorType || !global.isGlobalPrivate() || !global.isGlobalMutable()) {
- return;
- }
- if (llvm::none_of(tensorType.getShape(),
- [](int64_t size) { return size == 1; })) {
- return;
- }
- SmallVector<IREE::Util::GlobalLoadOpInterface> loadOps =
- llvm::to_vector(globalInfo->getLoads());
- SmallVector<IREE::Util::GlobalStoreOpInterface> storeOps =
- llvm::to_vector(globalInfo->getStores());
- if (failed(foldUnitDimsOnGlobal(rewriter, global, loadOps, storeOps,
- moduleSymbols))) {
- return signalPassFailure();
- }
- });
-
- // Fold unit dims on other operations.
RewritePatternSet foldUnitDimsPatterns(context);
linalg::ControlDropUnitDims options;
auto defaultFn = options.controlFn;
@@ -145,13 +52,14 @@
};
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
- if (failed(applyPatternsAndFoldGreedily(moduleOp,
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(foldUnitDimsPatterns)))) {
return signalPassFailure();
}
}
-std::unique_ptr<OperationPass<ModuleOp>> createFoldUnitExtentDimsPass() {
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createFoldUnitExtentDimsPass() {
return std::make_unique<FoldUnitExtentDimsPass>();
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index d0e8a76..60cf678 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -59,7 +59,8 @@
// Create a pass that imports upstream patterns to fold unit extent dims
// but with IREE control.
-std::unique_ptr<OperationPass<ModuleOp>> createFoldUnitExtentDimsPass();
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createFoldUnitExtentDimsPass();
// Creates a pass to fuse Linalg operations on tensors.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index dee6a41..743b19e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -58,7 +58,7 @@
}
def FoldUnitExtentDims :
- Pass<"iree-flow-fold-unit-extent-dims", "mlir::ModuleOp"> {
+ InterfacePass<"iree-flow-fold-unit-extent-dims", "mlir::FunctionOpInterface"> {
let summary = "Fold unit extent dimension of operations";
let constructor = "mlir::iree_compiler::IREE::Flow::createFoldUnitExtentDimsPass()";
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
index 9078c86..cc9f684 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fold_unit_dims.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(iree-flow-fold-unit-extent-dims)" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" %s | FileCheck %s
util.func public @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) -> tensor<1x1x10xf32> {
%0 = tensor.empty() : tensor<1x1x10xf32>
@@ -21,73 +21,3 @@
// CHECK-SAME: ins(%[[ARG0]] : tensor<1x1x10xf32>)
// CHECK: flow.return %[[GENERIC]]
// CHECK: util.return %[[DISPATCH]]
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (0, 0)>
-module @fold_unit_dims {
- util.global private mutable @global {inlining_policy = #util.inline.never} = #util.uninitialized : tensor<1x32x1x1x64xf32>
- util.global private mutable @unit_global = #util.uninitialized : tensor<1x1xf32>
- util.func public @fold_global_unit_dims() -> tensor<32x64xf32> {
- %global = util.global.load @global : tensor<1x32x1x1x64xf32>
- %unit_global = util.global.load @unit_global : tensor<1x1xf32>
- %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32>
- %0 = tensor.empty() : tensor<32x64xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%collapsed, %unit_global : tensor<32x64xf32>, tensor<1x1xf32>) outs(%0 : tensor<32x64xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.addf %in, %in_0 : f32
- linalg.yield %2 : f32
- } -> tensor<32x64xf32>
- %expanded = tensor.expand_shape %1 [[0, 1], [2, 3, 4]] : tensor<32x64xf32> into tensor<1x32x1x1x64xf32>
- util.global.store %expanded, @global : tensor<1x32x1x1x64xf32>
- util.return %1 : tensor<32x64xf32>
- }
-}
-
-// CHECK: module @fold_unit_dims
-// CHECK: util.global private mutable @[[GLOBAL:.+]] {inlining_policy = #util.inline.never} = #util.uninitialized : tensor<32x64xf32>
-// CHECK: util.global private mutable @[[UNIT_GLOBAL:.+]] = #util.uninitialized : tensor<f32>
-// CHECK: util.func public @fold_global_unit_dims
-// CHECK: %[[LOAD0:.+]] = util.global.load @[[GLOBAL]] : tensor<32x64xf32>
-// CHECK: %[[LOAD1:.+]] = util.global.load @[[UNIT_GLOBAL]] : tensor<f32>
-// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[LOAD0]], %[[LOAD1]]
-// CHECK: util.global.store %[[GENERIC]], @[[GLOBAL]] : tensor<32x64xf32>
-// CHECK: util.return %[[GENERIC]]
-
-// -----
-
-module @no_fold_immutable {
- util.global private @global : tensor<1x32x1x1x64xf32>
- util.func public @no_fold_global_unit_dims() -> tensor<32x64xf32> {
- %global = util.global.load @global : tensor<1x32x1x1x64xf32>
- %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32>
- util.return %collapsed : tensor<32x64xf32>
- }
-}
-
-// CHECK: module @no_fold_immutable
-// CHECK: util.global private @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32>
-// CHECK: util.func public @no_fold_global_unit_dims
-// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32>
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]]
-// CHECK: util.return %[[COLLAPSE]]
-
-// -----
-
-module @no_fold_public {
- util.global public mutable @global : tensor<1x32x1x1x64xf32>
- util.func public @no_fold_global_unit_dims() -> tensor<32x64xf32> {
- %global = util.global.load @global : tensor<1x32x1x1x64xf32>
- %collapsed = tensor.collapse_shape %global [[0, 1], [2, 3, 4]] : tensor<1x32x1x1x64xf32> into tensor<32x64xf32>
- util.return %collapsed : tensor<32x64xf32>
- }
-}
-
-// CHECK: module @no_fold_public
-// CHECK: util.global public mutable @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32>
-// CHECK: util.func public @no_fold_global_unit_dims
-// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32>
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]]
-// CHECK: util.return %[[COLLAPSE]]
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index ac27f61..8ac8818 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -109,9 +109,8 @@
// dims as the unit dim folding pass updates indexing maps and is better
// at working with generics. By this point we have already done any
// specialized raising and the op names are no longer useful.
- .addPass(createGeneralizeLinalgNamedOpsPass);
- mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
- FunctionLikeNest(mainPassManager)
+ .addPass(createGeneralizeLinalgNamedOpsPass)
+ .addPass(IREE::Flow::createFoldUnitExtentDimsPass)
.addPredicatedPass(clEnableFuseSiluHorizontalMatmul,
createFuseSiluHorizontalMatmulPass)
.addPredicatedPass(clEnableDemoteContractionInputsToBF16,
diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
index 0521a63..3d96476 100644
--- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -92,8 +92,8 @@
.addPass(GlobalOptimization::createDetachElementwiseFromNamedOpsPass)
.addPass(mlir::createLinalgNamedOpConversionPass)
.addPass(GlobalOptimization::createConvert1X1FilterConv2DToMatmulPass)
- .addPass(createConvertConvToChannelsLastPass);
- passManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
+ .addPass(createConvertConvToChannelsLastPass)
+ .addPass(IREE::Flow::createFoldUnitExtentDimsPass);
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
}