blob: 9f4c0e5ce7f943147e1029242a48d72b0aa08312 [file] [log] [blame] [edit]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h"
#include "iree/compiler/Dialect/Stream/Analysis/ResourceHazards.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/PatternMatch.h"
#define DEBUG_TYPE "iree-stream-partitioning"
namespace mlir::iree_compiler::IREE::Stream {
// Returns an AsmState rooted at the nearest ancestor of |block| that is
// isolated from above so that debug-printed SSA value names are stable.
// Returns nullptr when debug dumps of partitioning are disabled (e.g. NDEBUG
// builds or the debug type being off), as the state is only used for dumps.
static std::unique_ptr<AsmState> getRootAsmState(Block *block) {
  LLVM_DEBUG({
    Operation *rootOp = block->getParentOp();
    for (Operation *parentOp = rootOp->getParentOp(); parentOp;
         parentOp = parentOp->getParentOp()) {
      rootOp = parentOp;
      // Stop at the first isolated-from-above ancestor that is not itself a
      // timeline op; that is the op we want names printed relative to.
      if (!isa<IREE::Stream::TimelineOpInterface>(parentOp) &&
          parentOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
        break;
      }
    }
    return std::make_unique<AsmState>(rootOp);
  });
  return nullptr;
}
// Per-operation analysis state tracked while walking a block in reverse
// order. All bit vectors are indexed by partition ordinal.
struct OpInfo {
  // Which partitions the op is contained within.
  llvm::BitVector membership;
  // Which partitions transitively depend on this operation.
  llvm::BitVector hazards;
  // Hazards specifically from nested region captures creating circular
  // dependencies. These must always be respected, even for operations with
  // preferCloneToConsumers.
  llvm::BitVector nestedRegionHazards;
};
// In-progress state for a single partition being formed during the reverse
// walk of a block.
struct PartitionBuilder {
  // Ordinal of this partition; also its bit index in membership/hazard sets.
  unsigned ordinal;
  // Affinity of the partition.
  IREE::Stream::AffinityAttr affinity;
  // Ops present in the partition; ops may be present in multiple partitions.
  SetVector<Operation *> ops;
  // Ops that were cloned and are known not to have their values escape.
  DenseSet<Operation *> clonedOps;
  // Which partitions transitively depend on this partition.
  llvm::BitVector hazards;
  // Adds |op| to this partition, updating both the partition-level state and
  // the op's own |opInfo| membership/hazard tracking.
  void insert(Operation *op, OpInfo &opInfo) {
    if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
      // Narrow the partition affinity so it stays compatible with every
      // member op (joinAND of all member affinities).
      affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr())
                          : affinityOp.getAffinityAttr();
    }
    opInfo.membership.set(ordinal);
    // The op no longer hazards with the partition that now contains it. The
    // hazard vector may be shorter than |ordinal| when this partition was
    // created after the op's info was initialized, hence the size guard.
    if (opInfo.hazards.size() > ordinal) {
      opInfo.hazards.reset(ordinal);
    }
    ops.insert(op);
    // Inherit all of the op's hazards so partition-level circular dependency
    // checks observe them.
    hazards |= opInfo.hazards;
    hazards |= opInfo.nestedRegionHazards;
  }
};
// Resolves |user| to the OpInfo that should drive hazard propagation when
// partitioning |block|.
//
// Users directly in |block| map to their own OpInfo. Users buried inside
// nested regions (e.g. the body of an scf.for) are represented by the closest
// ancestor op living in |block|: a streamable ancestor's OpInfo is returned
// as a proxy, while a non-streamable ancestor instead records circular
// dependency risks on |opInfo|.nestedRegionHazards.
//
// Returns nullptr when the user should be skipped (hazards already recorded
// or no info is available).
static OpInfo *getEffectiveUserInfo(Operation *user, Block *block,
                                    DenseMap<Operation *, OpInfo> &opInfos,
                                    OpInfo &opInfo) {
  if (user->getBlock() == block) {
    // Common case: the user is a sibling in the block being partitioned.
    LLVM_DEBUG(llvm::dbgs() << " User in same block\n");
    auto userInfoIt = opInfos.find(user);
    return userInfoIt != opInfos.end() ? &userInfoIt->second : nullptr;
  }
  LLVM_DEBUG(llvm::dbgs() << " User in different block (nested region)\n");

  // Walk up to the operation in |block| that (transitively) contains |user|.
  // This handles cases like scf.for where operations inside the loop body use
  // values from the parent block.
  Operation *ancestorOp = user->getParentOp();
  while (ancestorOp && ancestorOp->getBlock() != block) {
    ancestorOp = ancestorOp->getParentOp();
  }
  if (!ancestorOp) {
    return nullptr;
  }

  // A streamable containing op acts as a proxy for all of its nested users.
  auto ancestorInfoIt = opInfos.find(ancestorOp);
  if (ancestorInfoIt != opInfos.end() &&
      isa<IREE::Stream::StreamableOpInterface>(ancestorOp)) {
    return &ancestorInfoIt->second;
  }

  // Non-streamable containing ops (scf.for, scf.if, etc.) cannot join a
  // partition themselves; values captured into them create circular
  // dependency hazards with any partition that uses the containing
  // operation's results. Record those hazards so the op is never grouped
  // with consumers of the containing op's results.
  for (auto result : ancestorOp->getResults()) {
    for (auto *resultUser : result.getUsers()) {
      if (resultUser->getBlock() != block) {
        continue;
      }
      auto resultUserInfoIt = opInfos.find(resultUser);
      if (resultUserInfoIt == opInfos.end()) {
        continue;
      }
      opInfo.nestedRegionHazards |= resultUserInfoIt->second.membership;
      opInfo.nestedRegionHazards |= resultUserInfoIt->second.hazards;
      LLVM_DEBUG({
        llvm::dbgs() << "Setting nestedRegionHazards for captured value, "
                     << "hazard count: "
                     << opInfo.nestedRegionHazards.count() << "\n";
      });
    }
  }
  return nullptr;
}
// This is terrible. See Stream/Analysis/Partition.h for a description of what
// a real implementation would do. We want cost modeling for tie breakers when
// an op could be in multiple partitions, cloning for ops that are not worth
// spanning partitions (like splats), etc.
PartitionSet
partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
                                Block *block) {
  PartitionSet partitionSet;

  // Per-op analysis state built incrementally while walking in reverse.
  DenseMap<Operation *, OpInfo> opInfos;
  // One builder per partition created so far, indexed by partition ordinal.
  SmallVector<std::unique_ptr<PartitionBuilder>> builders;
  // Bit per partition that may still accept ops; bits are cleared when a
  // side-effecting non-streamable op freezes all existing partitions.
  llvm::BitVector usableBuilders;

  // Returns true if making the partition with ordinal |targetOrdinal| depend
  // on the partition with ordinal |sourceOrdinal| would create a circular
  // dependency (i.e. |sourceOrdinal| already transitively depends on
  // |targetOrdinal|).
  auto willCreateCircularDependencyBetweenPartitions =
      [&](unsigned sourceOrdinal, unsigned targetOrdinal) -> bool {
    if (sourceOrdinal == targetOrdinal) {
      return false;
    }
    return builders[sourceOrdinal]->hazards.size() > targetOrdinal &&
           builders[sourceOrdinal]->hazards[targetOrdinal];
  };

  // Returns true if |op| (with analysis state |opInfo|) can join the
  // partition |partitionOrdinal| without violating affinity, partition
  // ordering, or circular dependency constraints. |checkForClones| selects
  // the relaxed hazard check used for ops that prefer cloning into their
  // consumers.
  auto canAddOpToPartition = [&](Operation &op, OpInfo &opInfo,
                                 unsigned partitionOrdinal,
                                 bool checkForClones = true) {
    auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
    if (!streamableOp) {
      return false;
    }
    // Most ops should have affinity at this point. If they do not then we allow
    // them to be placed anywhere (and whatever performance implications that
    // has is on the higher layers for not explicitly saying).
    IREE::Stream::AffinityAttr affinityAttr;
    if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
      affinityAttr = affinityOp.getAffinityAttr();
    }
    if (!IREE::Stream::AffinityAttr::canExecuteTogether(
            affinityAttr, builders[partitionOrdinal]->affinity)) {
      return false;
    }
    bool preferCloneToConsumers =
        checkForClones && streamableOp.preferCloneToConsumers();
    llvm::BitVector *opHazards = nullptr;
    llvm::BitVector opHazardsInCandidatePartition;
    if (preferCloneToConsumers) {
      // If we are cloning we care only about users that are a part of the
      // candidate partition.
      // Here we would need to walk further down the users if a user is also
      // cloned into the partition. This will be useful if we have a block of
      // cloneable ops. If left like that, other than the inefficiency,
      // it should not produce invalid partitioning.
      opHazards = &opHazardsInCandidatePartition;
      for (auto user : op.getUsers()) {
        if (builders[partitionOrdinal]->ops.contains(user)) {
          opHazardsInCandidatePartition |= opInfos[user].hazards;
        }
      }
      // For preferCloneToConsumers ops, we DON'T include nestedRegionHazards
      // because cloning into multiple partitions solves the circular
      // dependency. The op will be duplicated in each consumer partition, so
      // there's no shared state causing circular dependencies.
    } else {
      // For non-clonable ops, nestedRegionHazards must be respected to prevent
      // circular dependencies from nested region captures. Unlike normal
      // hazards which propagate through use-def chains, nestedRegionHazards may
      // not be fully captured in opInfo.hazards when the containing operation
      // (e.g., scf.for) is not streamable and returns nullptr from
      // getEffectiveUserInfo.
      opHazardsInCandidatePartition = opInfo.hazards;
      opHazardsInCandidatePartition |= opInfo.nestedRegionHazards;
      opHazards = &opHazardsInCandidatePartition;
    }
    // Check nestedRegionHazards separately first. Unlike normal hazards
    // (which indicate dependencies), nestedRegionHazards indicate partitions
    // the op CANNOT join due to circular dependency constraints.
    for (auto nestedHazardOrdinal : opInfo.nestedRegionHazards.set_bits()) {
      if (nestedHazardOrdinal == partitionOrdinal) {
        // This op cannot be in this partition due to nested region capture
        // circular dependency constraints.
        return false;
      }
    }
    for (auto opHazardOrdinal : opHazards->set_bits()) {
      if (partitionOrdinal < opHazardOrdinal) {
        // Reject partition ordering that would require partition sorting.
        // TODO: It is probably more optimal to reorder the partitions after
        // their formation based on their dependency graph instead of rejecting
        // here. Since this is considered not a good partitioning algorithm
        // and will probably get removed, we leave it like that.
        return false;
      }
      // Check for formation of circular dependency between partitions.
      if (willCreateCircularDependencyBetweenPartitions(opHazardOrdinal,
                                                        partitionOrdinal)) {
        return false;
      }
    }
    return true;
  };

  auto asmState = getRootAsmState(block);

  // Build a map from values to operations that consume them in nested regions.
  // This allows us to detect when an operation with nested regions (like
  // scf.for) consumes values produced by other operations that may create
  // hazards even though there's no direct use-def edge.
  DenseMap<Value, SmallVector<Operation *>> nestedConsumers;
  for (auto &op : *block) {
    // Only check operations with regions that are not streamable.
    // Streamable ops are handled normally but non-streamable ops with regions
    // (like scf.for, scf.if, scf.while) need special handling.
    if (op.getNumRegions() == 0 ||
        isa<IREE::Stream::StreamableOpInterface>(op)) {
      continue;
    }
    // Collect all values consumed by this operation (including in nested
    // regions).
    SetVector<Value> consumedValues;
    collectConsumedValues(&op, consumedValues);
    // For each consumed value record this op as a nested consumer.
    for (auto value : consumedValues) {
      // Only track values defined in the same block (not block arguments).
      auto definingOp = value.getDefiningOp();
      if (definingOp && definingOp->getBlock() == block) {
        nestedConsumers[value].push_back(&op);
      }
    }
  }

  // Synchronizing ops (barriers/transfers) deferred so they can be placed in
  // the same partition as their streamable producer; keyed by producer.
  llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> syncOps;

  for (auto &op : llvm::reverse(*block)) {
    LLVM_DEBUG({
      llvm::dbgs() << "====\nPartitioning op:\n";
      op.print(llvm::dbgs(), *asmState);
      llvm::dbgs() << "\n";
    });

    // Skip constants; they just add noise (and since they are heavily CSE'd
    // they have lots of users to test).
    if (op.hasTrait<OpTrait::ConstantLike>()) {
      LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n");
      continue;
    } else if (isa<IREE::Util::GlobalStoreOpInterface>(op)) {
      // We ignore global stores as they are unobservable within an execution
      // region - we must still block on loads though.
      LLVM_DEBUG(llvm::dbgs() << "(ignoring global store)\n");
      continue;
    } else if (!isa<IREE::Stream::StreamableOpInterface>(op)) {
      // Not a streamable op. If it has side-effects then we force a hazard on
      // all builders so that we don't move ops across it.
      if (!mlir::wouldOpBeTriviallyDead(&op)) {
        LLVM_DEBUG({
          llvm::dbgs() << "Side-effecting op forcing flush and freeze:\n";
          op.print(llvm::dbgs(), *asmState);
          llvm::dbgs() << "\n";
        });
        usableBuilders.reset();
      }
      // Even though not a streamable op we still want to track it below.
    }

    // Synchronizing operations should join with their producers if the producer
    // is streamable.
    if (isa<IREE::Stream::AsyncBarrierOp, IREE::Stream::AsyncTransferOp>(op)) {
      auto producer = op.getOperand(0).getDefiningOp();
      auto streamable =
          dyn_cast_if_present<IREE::Stream::StreamableOpInterface>(producer);
      if (streamable) {
        // operator[] default-constructs the vector on first use.
        syncOps[producer].push_back(&op);
        LLVM_DEBUG({
          llvm::dbgs() << "Skipping sync op for now \n";
          op.print(llvm::dbgs(), *asmState);
          llvm::dbgs() << "\n";
        });
        continue;
      }
    }

    // Initialize op info for this op - whether streamable or not. We track
    // transitive hazards on each op. Note that thanks to the ordering of ops
    // in SSA form (_reversed here!_) we know that once we visit this op no
    // partition created after it can ever depend on it if it doesn't here. This
    // lets us keep the bitvectors small.
    auto &opInfo = opInfos[&op];
    opInfo.hazards.reserve(builders.size() + 1);
    opInfo.hazards.resize(builders.size(), /*t=*/false);

    // Set bits for each partition this op may be able to be placed into.
    // We prune the set based on whether the users are part of a transitive
    // dependency chain down the use-def chain to a partition.
    llvm::BitVector consumers(builders.size(), /*t=*/false);

    // Process direct users (common case).
    for (auto user : op.getUsers()) {
      OpInfo *effectiveUserInfo =
          getEffectiveUserInfo(user, block, opInfos, opInfo);
      if (!effectiveUserInfo) {
        continue;
      }
      auto &userInfo = *effectiveUserInfo;
      LLVM_DEBUG({
        llvm::dbgs() << "Testing user:\n";
        user->print(llvm::dbgs(), *asmState);
        llvm::dbgs() << "\n";
        for (auto membershipOrdinal : userInfo.membership.set_bits()) {
          llvm::dbgs() << " member of partition " << membershipOrdinal << "\n";
        }
        for (auto hazardOrdinal : userInfo.hazards.set_bits()) {
          llvm::dbgs() << " hazard w/ partition " << hazardOrdinal << "\n";
        }
      });
      consumers |= userInfo.membership;
      opInfo.hazards |= userInfo.membership;
      opInfo.hazards |= userInfo.hazards;
    }

    // Process indirect users (operations that consume this op's results in
    // their nested regions e.g. scf.for using a value in its body).
    for (auto result : op.getResults()) {
      auto nestedConsumersIt = nestedConsumers.find(result);
      if (nestedConsumersIt != nestedConsumers.end()) {
        for (auto *nestedUser : nestedConsumersIt->second) {
          auto userInfoIt = opInfos.find(nestedUser);
          if (userInfoIt == opInfos.end()) {
            continue;
          }
          auto &userInfo = userInfoIt->second;
          LLVM_DEBUG({
            llvm::dbgs() << "Testing nested region user:\n";
            nestedUser->print(llvm::dbgs(), *asmState);
            llvm::dbgs() << "\n";
            for (auto membershipOrdinal : userInfo.membership.set_bits()) {
              llvm::dbgs() << " member of partition " << membershipOrdinal
                           << "\n";
            }
            for (auto hazardOrdinal : userInfo.hazards.set_bits()) {
              llvm::dbgs() << " hazard w/ partition " << hazardOrdinal << "\n";
            }
          });
          // For nested consumers we only propagate hazards and not membership.
          // The nested consumer itself is not streamable and won't be in a
          // partition but it creates a dependency barrier.
          opInfo.hazards |= userInfo.membership;
          opInfo.hazards |= userInfo.hazards;
        }
      }
    }

    // Deferred sync ops act as if their users were users of this producer;
    // clearing |consumers| forces the sync chain into a fresh partition.
    for (auto syncOp : syncOps[&op]) {
      for (auto user : syncOp->getUsers()) {
        OpInfo *effectiveUserInfo =
            getEffectiveUserInfo(user, block, opInfos, opInfo);
        if (!effectiveUserInfo) {
          continue;
        }
        auto &userInfo = *effectiveUserInfo;
        LLVM_DEBUG({
          llvm::dbgs() << "Testing sync user:\n";
          user->print(llvm::dbgs(), *asmState);
          llvm::dbgs() << "\n";
          for (auto membershipOrdinal : userInfo.membership.set_bits()) {
            llvm::dbgs() << " member of partition " << membershipOrdinal
                         << "\n";
          }
          for (auto hazardOrdinal : userInfo.hazards.set_bits()) {
            llvm::dbgs() << " hazard w/ partition " << hazardOrdinal << "\n";
          }
        });
        opInfo.hazards |= userInfo.membership;
        opInfo.hazards |= userInfo.hazards;
        consumers.reset();
      }
    }

    // Compute the candidate partitions this op could be placed into: usable
    // partitions without hazards, plus partitions that already contain
    // consumers of this op.
    llvm::BitVector candidates(builders.size(), /*t=*/true);
    candidates ^= opInfo.hazards;
    candidates |= consumers;
    candidates &= usableBuilders;

    // Prune candidates that do not have a compatible affinity.
    for (auto ordinal : candidates.set_bits()) {
      if (!canAddOpToPartition(op, opInfo, ordinal)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Candidate partition " << ordinal << " incompatible\n");
        candidates.reset(ordinal);
      }
    }
    for (auto syncOp : syncOps[&op]) {
      for (auto ordinal : candidates.set_bits()) {
        if (!canAddOpToPartition(*syncOp, opInfo, ordinal)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Candidate partition " << ordinal << " incompatible\n");
          candidates.reset(ordinal);
        }
      }
    }

    // If this op is not streamable then bail here; we've still setup the hazard
    // map for following iteration.
    auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
    if (!streamableOp) {
      LLVM_DEBUG(llvm::dbgs() << "Not streamable (skip)\n");
      continue;
    }

    // If we prefer to clone to our consumers, but we are
    // only cloning to a subset, we have to re-check our
    // partitions as they may generate cycles.
    if (streamableOp.preferCloneToConsumers()) {
      auto tempCandidates = candidates;
      tempCandidates &= consumers;
      if (tempCandidates.count() != consumers.count()) {
        // Prune candidates that do not have a compatible affinity.
        for (auto ordinal : candidates.set_bits()) {
          if (!canAddOpToPartition(op, opInfo, ordinal, false)) {
            LLVM_DEBUG(llvm::dbgs() << "Candidate partition " << ordinal
                                    << " incompatible for clone\n");
            candidates.reset(ordinal);
          }
        }
        for (auto syncOp : syncOps[&op]) {
          for (auto ordinal : candidates.set_bits()) {
            if (!canAddOpToPartition(*syncOp, opInfo, ordinal, false)) {
              LLVM_DEBUG(llvm::dbgs() << "Candidate partition " << ordinal
                                      << " incompatible for clone\n");
              candidates.reset(ordinal);
            }
          }
        }
      }
    }

    // First see which partitions are consuming this that we can also safely
    // move in to.
    consumers &= candidates;
    if (consumers.any()) {
      candidates = consumers;
    }

    opInfo.membership.reserve(builders.size() + 1);
    opInfo.membership.resize(builders.size(), /*t=*/false);

    // No consumers - if there's any candidate then we'll go into that.
    int firstCandidateOrdinal = candidates.find_first();
    if (firstCandidateOrdinal == -1) {
      // Mark the op as having hazards against all other partitions.
      // It is better to be safe than incorrect, especially with our current
      // minimal test coverage. It's not always safe to reorder things - if
      // anything we are unlikely to be conservative enough here - for example,
      // if there's a stream.resource.load of a resource or a global we can't
      // move anything that may affect that resource or global. This
      // partitioning was designed to be conservative because debugging such
      // issues is really difficult.
      if (!builders.empty()) {
        // NOTE: BitVector::set(I, E) takes a half-open range [I, E); using
        // builders.size() - 1 would leave the most recent partition unmarked
        // and contradict the "all other partitions" intent above (and the
        // equivalent code in partitionRegionConcurrencyReference).
        opInfo.hazards.set(0, builders.size());
      }
      // Create a new partition just for this op.
      opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
      auto builder = std::make_unique<PartitionBuilder>();
      builder->ordinal = builders.size();
      LLVM_DEBUG(llvm::dbgs()
                 << "Created partition " << builder->ordinal << "\n");
      builders.push_back(std::move(builder));
      usableBuilders.resize(builders.size(), /*t=*/true);
      firstCandidateOrdinal = builders.size() - 1;
    }
    auto &builder = builders[firstCandidateOrdinal];

    // If we have synchronization operations we can place in the last block:
    for (auto syncOp : syncOps[&op]) {
      LLVM_DEBUG(llvm::dbgs() << "Moving sync to candidate partition "
                              << firstCandidateOrdinal << ":\n  ");
      LLVM_DEBUG(syncOp->print(llvm::dbgs(), *asmState));
      LLVM_DEBUG(llvm::dbgs() << "\n");
      builder->insert(syncOp, opInfo);
    }
    LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition "
                            << firstCandidateOrdinal << " (continue)\n");

    // If we are a clonable op (like splat) clone us into every partition.
    // We also clone if there are nested region hazards, indicating consumers
    // in nested regions that will need their own copy.
    // Otherwise we just pick the first we find (probably a bad heuristic).
    bool shouldClone =
        streamableOp.preferCloneToConsumers() &&
        (consumers.count() > 1 || opInfo.nestedRegionHazards.any());
    if (shouldClone && consumers.any()) {
      for (auto consumerOrdinal : consumers.set_bits()) {
        LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition "
                                << consumerOrdinal << "\n");
        auto &consumerBuilder = builders[consumerOrdinal];
        consumerBuilder->insert(&op, opInfo);
        consumerBuilder->clonedOps.insert(&op);
      }
    } else if (shouldClone && !consumers.any()) {
      // We want to clone but have no streamable consumers. Check if the op
      // has any uses in the same block - if so, it needs to be in a partition.
      // If all uses are in nested regions, leave it for re-materialization.
      bool hasUsesInSameBlock = false;
      for (auto result : op.getResults()) {
        for (auto user : result.getUsers()) {
          if (user->getBlock() == block) {
            hasUsesInSameBlock = true;
            break;
          }
        }
        if (hasUsesInSameBlock) {
          break;
        }
      }
      if (hasUsesInSameBlock) {
        // Has uses in same block, insert into first candidate partition.
        builder->insert(&op, opInfo);
      }
      // Else: all uses are in nested regions, leave for re-materialization.
    } else {
      // Not a clonable op, insert into first candidate partition.
      builder->insert(&op, opInfo);
    }
  }

  // Ops cloned into multiple partitions may still escape if there are
  // non-streamable consumers. We need to make sure we only let one result
  // escape.
  DenseSet<Operation *> clonedEscapingOps;

  // Emit partitions in forward order (as they are topologically sorted in
  // reverse order from our bottom-up walk).
  for (auto &builder : llvm::reverse(builders)) {
    Partition partition;
    SetVector<Value> consumedValues;
    SetVector<Value> producedValues;
    SetVector<Value> escapingValues;
    for (auto *op : llvm::reverse(builder->ops)) {
      // Collect direct operands only. Do NOT recursively collect values from
      // nested regions of operations that define our operands - that would
      // incorrectly add values consumed inside control flow ops (scf.for, etc)
      // to partition inputs, creating spurious circular dependencies.
      // The circular dependency check in Partition::verify() will correctly
      // detect actual circular dependencies using collectConsumedValues.
      bool didCloneEscape = false;
      for (auto operand : op->getOperands()) {
        consumedValues.insert(operand);
      }
      for (auto result : op->getResults()) {
        producedValues.insert(result);
        // Cloned ops default to local usage but may still have users outside
        // of any partition and need to escape.
        if (builder->clonedOps.contains(op)) {
          // We only want to have one partition produce the value and track ones
          // we've already produced via clonedEscapingOps.
          if (!clonedEscapingOps.contains(op)) {
            for (auto user : result.getUsers()) {
              if (!isa<IREE::Stream::StreamableOpInterface>(user)) {
                escapingValues.insert(result);
                didCloneEscape = true;
                break;
              }
            }
          }
        } else {
          // TODO(benvanik): optimize this - creates n^2/nlogn behavior.
          for (auto user : result.getUsers()) {
            if (!builder->ops.contains(user)) {
              escapingValues.insert(result);
            }
          }
        }
      }
      if (didCloneEscape) {
        clonedEscapingOps.insert(op);
      }
    }
    consumedValues.set_subtract(producedValues);
    partition.affinity = builder->affinity;
    partition.ins = consumedValues;
    partition.outs = escapingValues;
    partition.ops = std::move(builder->ops);
    partitionSet.partitions.push_back(std::move(partition));
  }

  LLVM_DEBUG(partitionSet.dump(*asmState));
  return partitionSet;
}
// This looks to extract a single level of concurrency; we should be recursively
// dividing the block to identify both serial and concurrent regions.
PartitionSet
partitionRegionConcurrencyReference(IREE::Stream::PartitioningConfigAttr config,
                                    Block *block) {
  PartitionSet waveSet;

  auto favor = config.getFavor().getValue();
  if (favor == IREE::Stream::Favor::Debug) {
    // Disable partitioning when favoring debuggability.
    return waveSet;
  }

  // NOTE: this local struct intentionally shadows the file-level
  // PartitionBuilder; waves only need op membership (no affinity or clone
  // tracking).
  struct PartitionBuilder {
    unsigned ordinal;
    // Ops present in the wave; ops may be present in multiple waves.
    SetVector<Operation *> ops;
  };
  SmallVector<std::unique_ptr<PartitionBuilder>> builders;

  // NOTE: this local struct intentionally shadows the file-level OpInfo;
  // waves have no nested-region hazard tracking.
  struct OpInfo {
    // Which waves the op is contained within.
    llvm::BitVector membership;
    // Which waves transitively depend on this operation.
    llvm::BitVector hazards;
  };
  DenseMap<Operation *, OpInfo> opInfos;

  auto asmState = getRootAsmState(block);

  // Run analysis - if it fails then we'll just be conservative.
  IREE::Stream::ResourceHazardAnalysis hazardAnalysis(block->getParentOp());
  if (failed(hazardAnalysis.run())) {
    LLVM_DEBUG(llvm::dbgs() << "WARNING: resource hazard analysis failed; "
                               "conservatively scheduling\n");
  }

  for (auto &op : llvm::reverse(*block)) {
    // Skip constants; they just add noise (and since they are heavily CSE'd
    // they have lots of users to test).
    if (op.hasTrait<OpTrait::ConstantLike>()) {
      LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n");
      continue;
    }

    // NOTE: it's ok if this op is not streamable as we still need to track the
    // hazards for other ops that it may use/may use it.
    auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);

    // Initialize op info for this op - whether streamable or not. We track
    // transitive hazards on each op. Note that thanks to the ordering of ops
    // in SSA form (_reversed here!_) we know that once we visit this op no
    // wave created after it can ever depend on it if it doesn't here. This
    // lets us keep the bitvectors small.
    auto &opInfo = opInfos[&op];
    opInfo.hazards.reserve(builders.size() + 1);
    opInfo.hazards.resize(builders.size(), /*t=*/false);

    LLVM_DEBUG({
      llvm::dbgs() << "====\nPartitioning op:\n";
      op.print(llvm::dbgs(), *asmState);
      llvm::dbgs() << "\n";
    });

    // Set bits for each wave this op may be able to be placed into.
    // We prune the set based on whether the users are part of a transitive
    // dependency chain down the use-def chain to a wave.
    for (auto user : op.getUsers()) {
      auto userInfoIt = opInfos.find(user);
      if (userInfoIt == opInfos.end()) {
        continue;
      }
      auto &userInfo = userInfoIt->second;
      LLVM_DEBUG({
        llvm::dbgs() << "Testing user:\n";
        user->print(llvm::dbgs(), *asmState);
        llvm::dbgs() << "\n";
        for (auto membershipOrdinal : userInfo.membership.set_bits()) {
          llvm::dbgs() << " member of wave " << membershipOrdinal << "\n";
        }
        int lastHazardOrdinal = userInfo.hazards.find_last();
        if (lastHazardOrdinal != -1) {
          llvm::dbgs() << " hazard w/ waves 0-" << lastHazardOrdinal << "\n";
        }
      });
      bool hazardPresent = hazardAnalysis.hasHazard(streamableOp, user);
      if (hazardPresent) {
        // Hazard with existing op usage - prevent concurrent scheduling.
        opInfo.hazards |= userInfo.membership;
      } else {
        LLVM_DEBUG(llvm::dbgs() << " $ hazard analysis says ok to schedule\n");
      }
      // Always inherit hazards whether merging or not.
      opInfo.hazards |= userInfo.hazards;
    }

    // Additional exhaustive testing for users of tied operands.
    // For each resource operand of this op we scan back through previously
    // created waves to see if there are any partitioned ops that have a hazard.
    for (auto operand : op.getOperands()) {
      if (!isa<IREE::Stream::ResourceType>(operand.getType())) {
        continue;
      }
      for (auto user : operand.getUsers()) {
        // Only consider other ops in this block that come after this op and
        // tie the shared operand (in-place mutation aliasing it).
        if (user == &op || user->getBlock() != block ||
            user->isBeforeInBlock(&op)) {
          continue;
        }
        auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(user);
        if (!tiedOp || !tiedOp.hasAnyTiedUses(operand)) {
          continue;
        }
        auto userInfoIt = opInfos.find(user);
        if (userInfoIt == opInfos.end()) {
          continue;
        }
        auto &userInfo = userInfoIt->second;
        LLVM_DEBUG({
          llvm::dbgs() << "Testing tied user:\n";
          user->print(llvm::dbgs(), *asmState);
          llvm::dbgs() << "\n";
          for (auto membershipOrdinal : userInfo.membership.set_bits()) {
            llvm::dbgs() << " member of wave " << membershipOrdinal << "\n";
          }
          int lastHazardOrdinal = userInfo.hazards.find_last();
          if (lastHazardOrdinal != -1) {
            llvm::dbgs() << " hazard w/ waves 0-" << lastHazardOrdinal << "\n";
          }
        });
        bool hazardPresent = hazardAnalysis.hasHazard(streamableOp, user);
        if (hazardPresent) {
          // Hazard with existing op usage - prevent concurrent scheduling.
          opInfo.hazards |= userInfo.membership;
        } else {
          LLVM_DEBUG(llvm::dbgs()
                     << " $ hazard analysis says ok to schedule with tied\n");
        }
        // Always inherit hazards whether merging or not.
        opInfo.hazards |= userInfo.hazards;
      }
    }

    // Candidate waves are all those without hazards against this op.
    llvm::BitVector candidates(builders.size(), /*t=*/true);
    candidates ^= opInfo.hazards;

    // If this op is not streamable then bail here; we've still setup the hazard
    // map for following iteration.
    if (!streamableOp || streamableOp.isMetadata()) {
      LLVM_DEBUG(llvm::dbgs() << "Not streamable/is subview (skip)\n");
      continue;
    }

    opInfo.membership.reserve(builders.size() + 1);
    opInfo.membership.resize(builders.size(), /*t=*/false);

    // No consumers - if there's any candidate then we'll go into that.
    // MaxConcurrency packs into the lowest-ordinal candidate wave; otherwise
    // the highest ordinal is taken (note the debug message below always says
    // "last" regardless of which is used).
    int firstCandidateOrdinal = favor == IREE::Stream::Favor::MaxConcurrency
                                    ? candidates.find_first()
                                    : candidates.find_last();
    if (firstCandidateOrdinal != -1) {
      LLVM_DEBUG(llvm::dbgs() << "Moving to last candidate wave "
                              << firstCandidateOrdinal << " (continue)\n");
      builders[firstCandidateOrdinal]->ops.insert(&op);
      opInfo.membership.set(firstCandidateOrdinal);
      // Hazard with all earlier waves but not the one the op just joined.
      opInfo.hazards.set(0, firstCandidateOrdinal);
      opInfo.hazards.reset(firstCandidateOrdinal);
      continue;
    }

    // Mark the op as having hazards against all other waves.
    opInfo.hazards.set(0, builders.size());

    // Create a new wave just for this op.
    opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
    auto builder = std::make_unique<PartitionBuilder>();
    builder->ordinal = builders.size();
    builder->ops.insert(&op);
    LLVM_DEBUG(llvm::dbgs() << "Created wave " << builder->ordinal << "\n");
    builders.push_back(std::move(builder));
  }

  // Emit waves in forward order (as they are topologically sorted in
  // reverse order from our bottom-up walk).
  for (auto &builder : llvm::reverse(builders)) {
    Partition wave;
    SetVector<Value> consumedValues;
    SetVector<Value> producedValues;
    SetVector<Value> escapingValues;
    for (auto *op : llvm::reverse(builder->ops)) {
      for (auto operand : op->getOperands()) {
        consumedValues.insert(operand);
      }
      for (auto result : op->getResults()) {
        producedValues.insert(result);
        // TODO(benvanik): optimize this - creates n^2/nlogn behavior.
        for (auto user : result.getUsers()) {
          if (!builder->ops.contains(user)) {
            escapingValues.insert(result);
          }
        }
      }
    }
    // Values produced within the wave are not inputs to it.
    consumedValues.set_subtract(producedValues);
    wave.ins = consumedValues;
    wave.outs = escapingValues;
    wave.ops = std::move(builder->ops);
    waveSet.partitions.push_back(std::move(wave));
  }

  LLVM_DEBUG(waveSet.dump(*asmState));
  return waveSet;
}
} // namespace mlir::iree_compiler::IREE::Stream