Revert "[codegen] more consumer fusion (#21521)" (#21819)
This reverts commit 4d91ffb09a8e48844b6460edce58ccba245bced7.
The above commit causes a compilation failure for the llama 405B fp4 model.
Tracking issue: https://github.com/iree-org/iree/issues/21814
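
For context (not part of the patch): the reverted commit replaced the plain FIFO
queue of consumer-fusion candidates with a list kept sorted by dominance of the
fusable user, so the most-dominating consumer is fused first. A minimal sketch of
that ordering scheme follows, assuming MLIR's DominanceInfo; the names Candidate
and enqueueByDominance are illustrative, not identifiers from the IREE sources.

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  #include "mlir/IR/Dominance.h"
  #include "mlir/IR/Operation.h"

  namespace {
  struct Candidate {
    // Slice ops inside the loop that feed the consumer to be fused.
    llvm::SmallVector<mlir::Operation *> slices;
    // The consumer op that can be fused into the surrounding scf.forall.
    mlir::Operation *fusableUser;
  };
  } // namespace

  // Insert `entry` so that the most-dominating user ends up last and is popped
  // (fused) first; an existing entry for the same user is updated in place.
  static void enqueueByDominance(llvm::SmallVectorImpl<Candidate> &worklist,
                                 Candidate entry,
                                 mlir::DominanceInfo &domInfo) {
    auto comp = [&](const Candidate &lhs, const Candidate &rhs) {
      return domInfo.properlyDominates(rhs.fusableUser, lhs.fusableUser);
    };
    auto *it = llvm::lower_bound(worklist, entry, comp);
    if (it != worklist.end() && it->fusableUser == entry.fusableUser)
      *it = std::move(entry);
    else
      worklist.insert(it, std::move(entry));
  }

This revert restores the previous behavior: a single root tiled op and a FIFO
std::queue of slice lists, as shown in the diff below.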
Signed-off-by: Praveen G <praveen.g2@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
index 64de60c..21a269a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp
@@ -9,7 +9,6 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
@@ -88,51 +87,16 @@
}
}
-namespace {
-// Entry for the pseudo-priority queue of consumer fusion candidates. Contains
-// the consumer (fusableUser) that can be fused and the set of slice operations
-// in the loop to fuse into that feed the consumer.
-struct ConsumerFusionQueueEntry {
- ConsumerFusionQueueEntry(SmallVector<Operation *> &&slices,
- Operation *fusableUser)
- : slices(std::move(slices)), fusableUser(fusableUser) {}
-
- SmallVector<Operation *> slices;
- Operation *fusableUser;
-};
-} // namespace
-
FailureOr<std::queue<Operation *>>
-fuseConsumersIntoForall(RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
+fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
MutableArrayRef<LoopLikeOpInterface> loops,
std::function<bool(Operation *)> filterFn) {
// Collect the candidate slices which can be potential consumers that can be
- // fused. Keep them in a vector reverse-sorted by dominance: the candidate
- // dominating others comes last (so it can be cheaply popped from the vector).
- // The most-dominating candidate is to be fused first since not fusing it may
- // prevent dominated candidates to be fused:
- //
- // A
- // |
- // B
- // / |
- // | D
- // | /
- // C
- //
- // here, B must be fused before both C and D, and D must be fused before C.
- // Candidates are kept in a vector rather than a priority queue since we may
- // update them as fusion happens, in particular, more slices may need to be
- // handled. For example, fusing B with A will create a slice of B that will
- // need to be handled correctly.
- SmallVector<ConsumerFusionQueueEntry> candidates;
+ // fused.
+ std::queue<SmallVector<Operation *>> candidates;
llvm::SmallDenseSet<tensor::ParallelInsertSliceOp> allCandidates;
auto addCandidateSlices = [&candidates, &allCandidates,
&filterFn](Operation *fusedOp) {
- // Dominance info recreated since op creation/movement in the fusion logic
- // invalidates it anyway.
- DominanceInfo dominanceInfo;
-
for (auto *userOp : fusedOp->getResults().getUsers()) {
auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
if (!sliceOp || allCandidates.contains(sliceOp)) {
@@ -171,38 +135,22 @@
allCandidates.insert_range(slices);
}
if (!fusedSlices.empty()) {
- ConsumerFusionQueueEntry entry(std::move(fusedSlices), fusableUser);
-
- // Comparator that puts the dominating user last.
- auto comp = [&](const ConsumerFusionQueueEntry &lhs,
- const ConsumerFusionQueueEntry &rhs) {
- return dominanceInfo.properlyDominates(rhs.fusableUser,
- lhs.fusableUser);
- };
-
- // If the fusable user is already a candidate, update it with the new
- // list of slices to handle. Otherwise, insert it into the right
- // position based on dominance.
- auto *it = llvm::lower_bound(candidates, entry, comp);
- if (it != candidates.end() && it->fusableUser == fusableUser)
- *it = std::move(entry);
- else
- candidates.insert(it, std::move(entry));
+ candidates.emplace(std::move(fusedSlices));
}
}
};
- // Add slices from all tiled ops, not only the "main" one.
- for (Operation *tiledOp : tiledOps)
- addCandidateSlices(tiledOp);
+ addCandidateSlices(tiledOp);
std::queue<Operation *> newFusionOpportunities;
while (!candidates.empty()) {
- // Get the next candidate.
- ConsumerFusionQueueEntry entry = candidates.pop_back_val();
+ // Traverse the slices in BFS fashion.
+ SmallVector<Operation *> candidateSlices = candidates.front();
+ candidates.pop();
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
- mlir::scf::tileAndFuseConsumerOfSlices(rewriter, entry.slices, loops);
+ mlir::scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices,
+ loops);
if (failed(fusedResult)) {
return failure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
index 303f831..690df25 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h
@@ -31,11 +31,12 @@
void collectTiledAndFusedOps(Operation *rootOp,
llvm::SmallDenseSet<Operation *> &result);
-/// Fuse all consumers of the given `tiledOps` into the surrounding `scf.forall`
-/// unless specified otherwise by `filterFn`. Returns a list of new
-/// `tensor.extract_slice` ops with new fusion opportunities.
+/// Fuse all consumers of the given `tiledOp` into the surrounding `scf.forall`.
+/// Returns a list of new `tensor.extract_slice` ops with new fusion
+/// opportunities, as well as the new surrounding `scf.forall` (because consumer
+/// fusion replaces the loop).
FailureOr<std::queue<Operation *>> fuseConsumersIntoForall(
- RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
+ RewriterBase &rewriter, Operation *tiledOp,
MutableArrayRef<LoopLikeOpInterface> loops,
std::function<bool(Operation *)> filterFn = [](Operation *) {
return true;
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index 978cae4..9650f73 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -210,18 +210,6 @@
// Pass implementation.
//===---------------------------------------------------------------------===//
-/// Returns true if any value produced by `producer` is used as an init value
-/// for the DPS `user`. Returns false if the user is not in DPS.
-static bool isUsedAsInit(Operation *producer, Operation *user) {
- auto dpsIface = dyn_cast<DestinationStyleOpInterface>(user);
- if (!dpsIface)
- return false;
- ValueRange results = producer->getResults();
- return llvm::any_of(dpsIface.getDpsInits(), [&](Value operand) {
- return llvm::is_contained(results, operand);
- });
-}
-
void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
auto funcOp = getOperation();
auto *context = &getContext();
@@ -244,14 +232,10 @@
llvm::DenseSet<Operation *> yieldReplacementsFor;
for (auto op : tiledAndFusedOps) {
- // Require replacement for values that are used after the main tilable op or
- // by ops that will definitely not be fused. Note that if a value is used as
- // an init of a DPS op, the user currently cannot be fused. Having a
- // replacement for it would attempt fusion and fail, so avoid such cases.
+ // If tiledAndFused ops doesn't contain the user; add an replacement
+ // for that.
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
- if (isUsedAsInit(op, user))
- return false;
- return dominanceInfo.properlyDominates(tilableOp, user) ||
+ return dominanceInfo.properlyDominates(tilableOp, user) &&
!tiledAndFusedOps.contains(user);
})) {
yieldReplacementsFor.insert(op);
@@ -333,18 +317,16 @@
});
}
std::swap(tileAndFuseResult->loops, tilingLoops);
-
+ Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
FailureOr<std::queue<Operation *>> newFusionOpportunities =
- fuseConsumersIntoForall(
- rewriter, tileAndFuseResult->tiledAndFusedOps.getArrayRef(),
- tilingLoops, [&tiledAndFusedOps](Operation *op) {
- return tiledAndFusedOps.contains(op);
- });
+ fuseConsumersIntoForall(rewriter, rootTiledOp, tilingLoops,
+ [&tiledAndFusedOps](Operation *op) {
+ return tiledAndFusedOps.contains(op);
+ });
if (failed(newFusionOpportunities)) {
// Continue the work if the failure is allowed.
if (!verifyComputeOpsAfterDistribution(funcOp)) {
- tileAndFuseResult->tiledAndFusedOps.front()->emitOpError(
- "failed to fuse consumers");
+ rootTiledOp->emitOpError("failed to fuse consumers");
return signalPassFailure();
}
} else {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
index 5ddb2d4..e4503ff 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
@@ -1083,107 +1083,3 @@
// CHECK: linalg.generic
// CHECK: scf.forall.in_parallel {
// CHECK: linalg.pack
-
-// -----
-
-// Adapted from layer normalization. The graph structure is as follows
-//
-// %14
-// / | \
-// / %15 %17
-// | | / |
-// | [%19] |
-// %21 | %22
-// | | |
-// v v v
-//
-// In particular, %21 and %22 are not users of the "main" tilable
-// operation but we still want them to be fused. %19, %21 and %22
-// all produce results returned from the function.
-//
-// Check that everything is fused and that there are three results
-// from the loop being produced and returned.
-//
-// CHECK-LABEL: @multi_result_consumer_fusion
-// CHECK-NOT: linalg.generic
-// CHECK: %[[LOOP:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (16, 256) shared_outs(%[[OUT0:.+]] = %{{.+}}, %[[OUT1:.+]] = %{{.+}}, %[[OUT2:.+]] = %{{.+}})
-// CHECK: %[[v14:.+]] = linalg.generic
-// CHECK: arith.divf
-// CHECK: %[[v15:.+]] = linalg.generic
-// CHECK: arith.subf
-// CHECK: %[[v17:.+]] = linalg.generic
-// CHECK: arith.divf
-// CHECK: math.rsqrt
-// CHECK: %[[RES0:.+]] = linalg.generic
-// CHECK: arith.mulf
-// CHECK: arith.extf
-// CHECK: arith.mulf
-// CHECK: arith.extf
-// CHECK: arith.addf
-// CHECK: arith.truncf
-// CHECK: %[[RES1:.+]] = linalg.generic {{.*}} ins(%[[v14]] :
-// CHECK: arith.truncf
-// CHECK: %[[RES2:.+]] = linalg.generic {{.*}} ins(%[[v17]] :
-// CHECK: arith.truncf
-// CHECK: scf.forall.in_parallel
-// CHECK: tensor.parallel_insert_slice %[[RES0]] into %[[OUT0]]
-// CHECK: tensor.parallel_insert_slice %[[RES1]] into %[[OUT1]]
-// CHECK: tensor.parallel_insert_slice %[[RES2]] into %[[OUT2]]
-// CHECK-NOT: linalg.generic
-// CHECK: return %[[LOOP]]#0, %[[LOOP]]#1, %[[LOOP]]#2
-func.func @multi_result_consumer_fusion(
- %6: tensor<16x256x2048xbf16>,
- %7: tensor<2048xbf16>,
- %8: tensor<2048xbf16>,
- %10: tensor<16x256x2048xf32>,
- %13: tensor<16x256xf32>
-) -> (
- tensor<16x256x2048xbf16>,
- tensor<16x256xbf16>,
- tensor<16x256xbf16>
-) {
- %cst = arith.constant 0.000000e+00 : f32
- %cst_0 = arith.constant 2.048000e+03 : f32
- %c0 = arith.constant 0 : index
- %9 = tensor.empty() : tensor<16x256x2048xf32>
- %11 = tensor.empty() : tensor<16x256xf32>
- %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : tensor<16x256xf32>) outs(%11 : tensor<16x256xf32>) {
- ^bb0(%in: f32, %out: f32):
- %23 = arith.divf %in, %cst_0 : f32
- linalg.yield %23 : f32
- } -> tensor<16x256xf32>
- %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10, %14 : tensor<16x256x2048xf32>, tensor<16x256xf32>) outs(%9 : tensor<16x256x2048xf32>) {
- ^bb0(%in: f32, %in_1: f32, %out: f32):
- %23 = arith.subf %in, %in_1 : f32
- linalg.yield %23 : f32
- } -> tensor<16x256x2048xf32>
- %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<16x256xf32>) outs(%11 : tensor<16x256xf32>) {
- ^bb0(%in: f32, %out: f32):
- %23 = arith.divf %in, %cst_0 : f32
- %24 = math.rsqrt %23 : f32
- linalg.yield %24 : f32
- } -> tensor<16x256xf32>
- %18 = tensor.empty() : tensor<16x256x2048xbf16>
- %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %17, %7, %8 : tensor<16x256x2048xf32>, tensor<16x256xf32>, tensor<2048xbf16>, tensor<2048xbf16>) outs(%18 : tensor<16x256x2048xbf16>) attrs = {lowering_config = #iree_gpu.lowering_config<{lane_basis = [[1, 1, 64], [0, 1, 2]], reduction = [0, 0, 256], subgroup_basis = [[1, 1, 1], [0, 1, 2]], thread = [0, 0, 4], workgroup = [1, 1, 0]}>} {
- ^bb0(%in: f32, %in_1: f32, %in_2: bf16, %in_3: bf16, %out: bf16):
- %23 = arith.mulf %in, %in_1 : f32
- %24 = arith.extf %in_2 : bf16 to f32
- %25 = arith.mulf %23, %24 : f32
- %26 = arith.extf %in_3 : bf16 to f32
- %27 = arith.addf %25, %26 : f32
- %28 = arith.truncf %27 : f32 to bf16
- linalg.yield %28 : bf16
- } -> tensor<16x256x2048xbf16>
- %20 = tensor.empty() : tensor<16x256xbf16>
- %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<16x256xf32>) outs(%20 : tensor<16x256xbf16>) {
- ^bb0(%in: f32, %out: bf16):
- %23 = arith.truncf %in : f32 to bf16
- linalg.yield %23 : bf16
- } -> tensor<16x256xbf16>
- %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : tensor<16x256xf32>) outs(%20 : tensor<16x256xbf16>) {
- ^bb0(%in: f32, %out: bf16):
- %23 = arith.truncf %in : f32 to bf16
- linalg.yield %23 : bf16
- } -> tensor<16x256xbf16>
- return %19, %21, %22 : tensor<16x256x2048xbf16>, tensor<16x256xbf16>, tensor<16x256xbf16>
-}