// blob: 22f1d598cc8c6253729ee790f99734b26533f9e1
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h"
#include "iree/compiler/Dialect/Stream/Analysis/ResourceHazards.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/PatternMatch.h"
#define DEBUG_TYPE "iree-stream-partitioning"
namespace mlir::iree_compiler::IREE::Stream {
// Builds an AsmState for printing ops in debug dumps, rooted at an ancestor of
// |block|. Starting from the op that owns |block| we climb the parent chain and
// stop at the first ancestor that is isolated from above and is not a Stream
// timeline op; if there is no such ancestor the outermost op is used.
// Returns nullptr when the LLVM_DEBUG body does not run (debug dumps of
// partitioning are disabled).
static std::unique_ptr<AsmState> getRootAsmState(Block *block) {
  LLVM_DEBUG({
    Operation *rootOp = block->getParentOp();
    for (Operation *ancestorOp = rootOp->getParentOp(); ancestorOp != nullptr;
         ancestorOp = rootOp->getParentOp()) {
      // Every visited ancestor becomes the new root candidate.
      rootOp = ancestorOp;
      const bool isTimelineOp =
          isa<IREE::Stream::TimelineOpInterface>(ancestorOp);
      if (!isTimelineOp &&
          ancestorOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
        break;
      }
    }
    return std::make_unique<AsmState>(rootOp);
  });
  return nullptr;
}
// Reference partitioning of the ops in |block| into a PartitionSet.
//
// This is terrible. See Stream/Analysis/Partition.h for a description of what
// a real implementation would do. We want cost modeling for tie breakers when
// an op could be in multiple partitions, cloning for ops that are not worth
// spanning partitions (like splats), etc.
//
// How it works:
//  1. Ops are visited bottom-up (reverse block order). For each op we track
//     which partitions it is a member of and which partitions transitively
//     depend on it (its hazards).
//  2. Constant-like ops and global stores are skipped. A non-streamable op that
//     is not trivially dead marks every existing partition as unusable so that
//     nothing is moved across it.
//  3. Barrier/transfer ops whose first operand is produced by a streamable op
//     are deferred and inserted into whatever partition that producer lands in.
//  4. A streamable op goes into the first compatible candidate partition
//     (preferring partitions that consume it), is cloned into every consuming
//     partition if it prefers cloning, or gets a new partition of its own.
//  5. Partitions are emitted in reverse creation order with their consumed
//     (ins) and escaping (outs) values computed.
//
// |config| is not read by this implementation.
PartitionSet
partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
                                Block *block) {
  PartitionSet partitionSet;

  struct OpInfo {
    // Which partitions the op is contained within.
    llvm::BitVector membership;
    // Which partitions transitively depend on this operation.
    llvm::BitVector hazards;
  };
  DenseMap<Operation *, OpInfo> opInfos;

  struct PartitionBuilder {
    unsigned ordinal;
    // Affinity of the partition.
    IREE::Stream::AffinityAttr affinity;
    // Ops present in the partition; ops may be present in multiple partitions.
    SetVector<Operation *> ops;
    // Ops that were cloned and are known not to have their values escape.
    DenseSet<Operation *> clonedOps;
    // Which partitions transitively depend on this partition.
    llvm::BitVector hazards;

    // Adds |op| to this partition: narrows the partition affinity, records the
    // membership in |opInfo|, clears any self-hazard bit, and folds the op's
    // hazards into the partition's hazards.
    void insert(Operation *op, OpInfo &opInfo) {
      if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
        affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr())
                            : affinityOp.getAffinityAttr();
      }
      opInfo.membership.set(ordinal);
      if (opInfo.hazards.size() > ordinal)
        opInfo.hazards.reset(ordinal);
      ops.insert(op);
      hazards |= opInfo.hazards;
    }
  };
  SmallVector<std::unique_ptr<PartitionBuilder>> builders;
  // Partitions that ops may still be placed into; cleared when a
  // side-effecting non-streamable op is encountered.
  llvm::BitVector usableBuilders;

  auto willCreateCircularDependencyBetweenPartitions =
      [&](unsigned sourceOrdinal, unsigned targetOrdinal) -> bool {
    // Returns:
    // If we are to make partition with ordinal targetOrdinal to
    // depend on partition with ordinal sourceOrdinal,
    // will this create a circular dependency.
    if (sourceOrdinal == targetOrdinal) {
      return false;
    }
    return builders[sourceOrdinal]->hazards.size() > targetOrdinal &&
           builders[sourceOrdinal]->hazards[targetOrdinal];
  };

  // Returns true if |op| may be placed into the partition |partitionOrdinal|:
  // it must be streamable, affinity-compatible with the partition, and placing
  // it must neither require reordering partitions nor create a cycle.
  auto canAddOpToPartition = [&](Operation &op, OpInfo &opInfo,
                                 unsigned partitionOrdinal) {
    auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
    if (!streamableOp) {
      return false;
    }

    // Most ops should have affinity at this point. If they do not then we allow
    // them to be placed anywhere (and whatever performance implications that
    // has is on the higher layers for not explicitly saying).
    IREE::Stream::AffinityAttr affinityAttr;
    if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
      affinityAttr = affinityOp.getAffinityAttr();
    }
    if (!IREE::Stream::AffinityAttr::canExecuteTogether(
            affinityAttr, builders[partitionOrdinal]->affinity)) {
      return false;
    }

    bool preferCloneToConsumers = streamableOp.preferCloneToConsumers();
    llvm::BitVector *opHazards = nullptr;
    llvm::BitVector opHazardsInCandidatePartition;
    if (preferCloneToConsumers) {
      // If we are cloning we care only about users that are a part of the
      // candidate partition.
      // Here we would need to walk further down the users if a user is also
      // cloned into the partition. This will be useful if we have a block of
      // cloneable ops. If left like that, other than the inefficiency,
      // it should not produce invalid partitioning.
      opHazards = &opHazardsInCandidatePartition;
      for (auto user : op.getUsers()) {
        if (!builders[partitionOrdinal]->ops.contains(user)) {
          continue;
        }
        // Use find() rather than operator[]: operator[] would insert a default
        // entry for an untracked user, which can rehash |opInfos| and
        // invalidate the |opInfo| reference our caller is still holding. An
        // untracked user contributes no hazards either way.
        auto userInfoIt = opInfos.find(user);
        if (userInfoIt != opInfos.end()) {
          opHazardsInCandidatePartition |= userInfoIt->second.hazards;
        }
      }
    } else {
      opHazards = &opInfo.hazards;
    }

    for (auto opHazardOrdinal : opHazards->set_bits()) {
      if (partitionOrdinal < opHazardOrdinal) {
        // Reject partition ordering that would require partition sorting.
        // TODO: It is probably more optimal to reorder the partitions after
        // their formation based on their dependency graph instead of rejecting
        // here. Since this is considered not a good partitioning algorithm
        // and will probably get removed, we leave it like that.
        return false;
      }
      // Check for formation of circular dependency between partitions.
      if (willCreateCircularDependencyBetweenPartitions(opHazardOrdinal,
                                                        partitionOrdinal)) {
        return false;
      }
    }
    return true;
  };

  auto asmState = getRootAsmState(block);

  // Maps a streamable producer to the barrier/transfer ops consuming its
  // result that should be placed alongside it.
  llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> syncOps;

  for (auto &op : llvm::reverse(*block)) {
    // Skip constants; they just add noise (and since they are heavily CSE'd
    // they have lots of users to test).
    if (op.hasTrait<OpTrait::ConstantLike>()) {
      LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n");
      continue;
    } else if (isa<IREE::Util::GlobalStoreOpInterface>(op)) {
      // We ignore global stores as they are unobservable within an execution
      // region - we must still block on loads though.
      LLVM_DEBUG(llvm::dbgs() << "(ignoring global store)\n");
      continue;
    } else if (!isa<IREE::Stream::StreamableOpInterface>(op)) {
      // Not a streamable op. If it has side-effects then we force a hazard on
      // all builders so that we don't move ops across it.
      if (!mlir::wouldOpBeTriviallyDead(&op)) {
        LLVM_DEBUG({
          llvm::dbgs() << "Side-effecting op forcing flush and freeze:\n";
          op.print(llvm::dbgs(), *asmState);
          llvm::dbgs() << "\n";
        });
        usableBuilders.reset();
      }
      // Even though not a streamable op we still want to track it below.
    }

    // Synchronizing operations should join with their producers if the producer
    // is streamable.
    if (isa<IREE::Stream::AsyncBarrierOp, IREE::Stream::AsyncTransferOp>(op)) {
      auto producer = op.getOperand(0).getDefiningOp();
      auto streamable =
          dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(producer);
      if (streamable) {
        // operator[] default-constructs the vector on first use.
        syncOps[producer].push_back(&op);
        continue;
      }
    }

    // Initialize op info for this op - whether streamable or not. We track
    // transitive hazards on each op. Note that thanks to the ordering of ops
    // in SSA form (_reversed here!_) we know that once we visit this op no
    // partition created after it can ever depend on it if it doesn't here. This
    // lets us keep the bitvectors small.
    auto &opInfo = opInfos[&op];
    opInfo.hazards.reserve(builders.size() + 1);
    opInfo.hazards.resize(builders.size(), /*t=*/false);

    LLVM_DEBUG({
      llvm::dbgs() << "====\nPartitioning op:\n";
      op.print(llvm::dbgs(), *asmState);
      llvm::dbgs() << "\n";
    });

    // Sync ops deferred onto this op (empty if there are none). Looked up once
    // without inserting so that we don't grow |syncOps| with an empty entry
    // for every op in the block.
    const llvm::SmallVector<Operation *> opSyncOps = syncOps.lookup(&op);

    // Set bits for each partition this op may be able to be placed into.
    // We prune the set based on whether the users are part of a transitive
    // dependency chain down the use-def chain to a partition.
    llvm::BitVector consumers(builders.size(), /*t=*/false);
    for (auto user : op.getUsers()) {
      auto userInfoIt = opInfos.find(user);
      if (userInfoIt == opInfos.end()) {
        continue;
      }
      auto &userInfo = userInfoIt->second;
      LLVM_DEBUG({
        llvm::dbgs() << "Testing user:\n";
        user->print(llvm::dbgs(), *asmState);
        llvm::dbgs() << "\n";
        for (auto membershipOrdinal : userInfo.membership.set_bits()) {
          llvm::dbgs() << " member of partition " << membershipOrdinal << "\n";
        }
        for (auto hazardOrdinal : userInfo.hazards.set_bits()) {
          llvm::dbgs() << " hazard w/ partition " << hazardOrdinal << "\n";
        }
      });
      consumers |= userInfo.membership;
      opInfo.hazards |= userInfo.membership;
      opInfo.hazards |= userInfo.hazards;
    }

    // Users of the deferred sync ops also depend on this op. Once any tracked
    // sync-op user is seen the consumer set is cleared.
    for (auto syncOp : opSyncOps) {
      for (auto user : syncOp->getUsers()) {
        auto userInfoIt = opInfos.find(user);
        if (userInfoIt == opInfos.end()) {
          continue;
        }
        auto &userInfo = userInfoIt->second;
        opInfo.hazards |= userInfo.membership;
        opInfo.hazards |= userInfo.hazards;
        consumers.reset();
      }
    }

    // For any sync ops not use this ops results we need to put in a
    // non-consumer block:
    llvm::BitVector candidates(builders.size(), /*t=*/true);
    candidates ^= opInfo.hazards;
    candidates |= consumers;
    candidates &= usableBuilders;

    // Prune candidates that do not have a compatible affinity.
    for (auto ordinal : candidates.set_bits()) {
      if (!canAddOpToPartition(op, opInfo, ordinal)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Candidate partition " << ordinal << " incompatible\n");
        candidates.reset(ordinal);
      }
    }

    // The deferred sync ops must be placeable in the same partition too.
    for (auto syncOp : opSyncOps) {
      for (auto ordinal : candidates.set_bits()) {
        if (!canAddOpToPartition(*syncOp, opInfo, ordinal)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Candidate partition " << ordinal << " incompatible\n");
          candidates.reset(ordinal);
        }
      }
    }

    // If this op is not streamable then bail here; we've still setup the hazard
    // map for following iteration.
    auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
    if (!streamableOp) {
      LLVM_DEBUG(llvm::dbgs() << "Not streamable (skip)\n");
      continue;
    }

    // First see which partitions are consuming this that we can also safely
    // move in to.
    consumers &= candidates;
    if (consumers.any())
      candidates = consumers;

    opInfo.membership.reserve(builders.size() + 1);
    opInfo.membership.resize(builders.size(), /*t=*/false);

    // No consumers - if there's any candidate then we'll go into that.
    int firstCandidateOrdinal = candidates.find_first();
    if (firstCandidateOrdinal == -1) {
      // Mark the op as having hazards against all other partitions.
      // It is better to be safe than incorrect, especially with our current
      // minimal test coverage. It's not always safe to reorder things - if
      // anything we are unlikely to be conservative enough here - for example,
      // if there's a stream.resource.load of a resource or a global we can't
      // move anything that may affect that resource or global. This
      // partitioning was designed to be conservative because debugging such
      // issues is really difficult.
      //
      // NOTE(review): if BitVector::set(I, E) covers the half-open range
      // [I, E) then this leaves the most recently created partition unmarked,
      // unlike the `set(0, builders.size())` used by the concurrency
      // partitioner - confirm whether that is intended before changing it.
      if (!builders.empty()) {
        opInfo.hazards.set(0, builders.size() - 1);
      }

      // Create a new partition just for this op.
      opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
      auto builder = std::make_unique<PartitionBuilder>();
      builder->ordinal = builders.size();
      LLVM_DEBUG(llvm::dbgs()
                 << "Created partition " << builder->ordinal << "\n");
      builders.push_back(std::move(builder));
      usableBuilders.resize(builders.size(), /*t=*/true);
      firstCandidateOrdinal = builders.size() - 1;
    }
    auto &builder = builders[firstCandidateOrdinal];

    // If we have synchronization operations we can place in the last block:
    for (auto syncOp : opSyncOps) {
      builder->insert(syncOp, opInfo);
    }

    LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition "
                            << firstCandidateOrdinal << " (continue)\n");
    // If we are a clonable op (like splat) clone us into every partition.
    // Otherwise we just pick the first we find (probably a bad heuristic).
    if (consumers.count() > 1 && streamableOp.preferCloneToConsumers()) {
      for (auto consumerOrdinal : consumers.set_bits()) {
        LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition "
                                << consumerOrdinal << "\n");
        auto &consumerBuilder = builders[consumerOrdinal];
        consumerBuilder->insert(&op, opInfo);
        consumerBuilder->clonedOps.insert(&op);
      }
    } else {
      builder->insert(&op, opInfo);
    }
  }

  // Ops cloned into multiple partitions may still escape if there are
  // non-streamable consumers. We need to make sure we only let one result
  // escape.
  DenseSet<Operation *> clonedEscapingOps;

  // Emit partitions in forward order (as they are topologically sorted in
  // reverse order from our bottom-up walk).
  for (auto &builder : llvm::reverse(builders)) {
    Partition partition;

    SetVector<Value> consumedValues;
    SetVector<Value> producedValues;
    SetVector<Value> escapingValues;
    for (auto *op : llvm::reverse(builder->ops)) {
      bool didCloneEscape = false;
      for (auto operand : op->getOperands()) {
        consumedValues.insert(operand);
      }
      for (auto result : op->getResults()) {
        producedValues.insert(result);
        // Cloned ops default to local usage but may still have users outside
        // of any partition and need to escape.
        if (builder->clonedOps.contains(op)) {
          // We only want to have one partition produce the value and track ones
          // we've already produced via clonedEscapingOps.
          if (!clonedEscapingOps.contains(op)) {
            for (auto user : result.getUsers()) {
              if (!isa<IREE::Stream::StreamableOpInterface>(user)) {
                escapingValues.insert(result);
                didCloneEscape = true;
                break;
              }
            }
          }
        } else {
          // TODO(benvanik): optimize this - creates n^2/nlogn behavior.
          for (auto user : result.getUsers()) {
            if (!builder->ops.contains(user)) {
              escapingValues.insert(result);
            }
          }
        }
      }
      if (didCloneEscape) {
        clonedEscapingOps.insert(op);
      }
    }
    // Values both produced and consumed inside the partition are internal.
    consumedValues.set_subtract(producedValues);
    partition.affinity = builder->affinity;
    partition.ins = consumedValues;
    partition.outs = escapingValues;

    partition.ops = std::move(builder->ops);
    partitionSet.partitions.push_back(std::move(partition));
  }

  LLVM_DEBUG(partitionSet.dump(*asmState));

  return partitionSet;
}
// Groups the streamable ops of |block| into "waves" and returns them as a
// PartitionSet, emitted in forward order. Returns an empty set when |config|
// favors debuggability.
//
// This looks to extract a single level of concurrency; we should be recursively
// dividing the block to identify both serial and concurrent regions.
PartitionSet
partitionRegionConcurrencyReference(IREE::Stream::PartitioningConfigAttr config,
                                    Block *block) {
  PartitionSet waveSet;

  const auto favor = config.getFavor().getValue();
  if (favor == IREE::Stream::Favor::Debug) {
    // Disable partitioning when favoring debuggability.
    return waveSet;
  }

  struct WaveBuilder {
    unsigned ordinal;
    // Ops present in the wave; ops may be present in multiple waves.
    SetVector<Operation *> ops;
  };
  SmallVector<std::unique_ptr<WaveBuilder>> waveBuilders;

  struct WaveOpInfo {
    // Which waves the op is contained within.
    llvm::BitVector membership;
    // Which waves transitively depend on this operation.
    llvm::BitVector hazards;
  };
  DenseMap<Operation *, WaveOpInfo> waveOpInfos;

  auto asmState = getRootAsmState(block);

  // Run analysis - if it fails then we'll just be conservative.
  IREE::Stream::ResourceHazardAnalysis hazardAnalysis(block->getParentOp());
  if (failed(hazardAnalysis.run())) {
    LLVM_DEBUG(llvm::dbgs() << "WARNING: resource hazard analysis failed; "
                               "conservatively scheduling\n");
  }

  // Folds the hazard state of an already-visited |user| into |info|. Users that
  // have not been visited are ignored. If the analysis reports a hazard between
  // |producer| and |user| then the waves containing |user| become hazards;
  // the user's own hazards are always inherited. |viaTiedOperand| only selects
  // which debug messages are printed.
  auto inheritUserHazards =
      [&](WaveOpInfo &info, IREE::Stream::StreamableOpInterface producer,
          Operation *user, [[maybe_unused]] bool viaTiedOperand) {
        auto visitedIt = waveOpInfos.find(user);
        if (visitedIt == waveOpInfos.end()) {
          return;
        }
        auto &userInfo = visitedIt->second;
        LLVM_DEBUG({
          llvm::dbgs() << (viaTiedOperand ? "Testing tied user:\n"
                                          : "Testing user:\n");
          user->print(llvm::dbgs(), *asmState);
          llvm::dbgs() << "\n";
          for (auto membershipOrdinal : userInfo.membership.set_bits()) {
            llvm::dbgs() << " member of wave " << membershipOrdinal << "\n";
          }
          int lastHazardOrdinal = userInfo.hazards.find_last();
          if (lastHazardOrdinal != -1) {
            llvm::dbgs() << " hazard w/ waves 0-" << lastHazardOrdinal << "\n";
          }
        });
        if (hazardAnalysis.hasHazard(producer, user)) {
          // Hazard with existing op usage - prevent concurrent scheduling.
          info.hazards |= userInfo.membership;
        } else {
          LLVM_DEBUG(llvm::dbgs()
                     << (viaTiedOperand
                             ? " $ hazard analysis says ok to schedule with tied\n"
                             : " $ hazard analysis says ok to schedule\n"));
        }
        // Always inherit hazards whether merging or not.
        info.hazards |= userInfo.hazards;
      };

  for (Operation &op : llvm::reverse(*block)) {
    // Skip constants; they just add noise (and since they are heavily CSE'd
    // they have lots of users to test).
    if (op.hasTrait<OpTrait::ConstantLike>()) {
      LLVM_DEBUG(llvm::dbgs() << "(ignoring constant)\n");
      continue;
    }

    // NOTE: it's ok if this op is not streamable as we still need to track the
    // hazards for other ops that it may use/may use it.
    auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);

    // Initialize op info for this op - whether streamable or not. We track
    // transitive hazards on each op. Note that thanks to the ordering of ops
    // in SSA form (_reversed here!_) we know that once we visit this op no
    // wave created after it can ever depend on it if it doesn't here. This
    // lets us keep the bitvectors small.
    auto &opInfo = waveOpInfos[&op];
    opInfo.hazards.reserve(waveBuilders.size() + 1);
    opInfo.hazards.resize(waveBuilders.size(), /*t=*/false);

    LLVM_DEBUG({
      llvm::dbgs() << "====\nPartitioning op:\n";
      op.print(llvm::dbgs(), *asmState);
      llvm::dbgs() << "\n";
    });

    // Set bits for each wave this op may be able to be placed into.
    // We prune the set based on whether the users are part of a transitive
    // dependency chain down the use-def chain to a wave.
    for (Operation *user : op.getUsers()) {
      inheritUserHazards(opInfo, streamableOp, user, /*viaTiedOperand=*/false);
    }

    // Additional exhaustive testing for users of tied operands.
    // For each resource operand of this op we scan back through previously
    // created waves to see if there are any partitioned ops that have a hazard.
    for (auto operand : op.getOperands()) {
      if (!isa<IREE::Stream::ResourceType>(operand.getType())) {
        continue;
      }
      for (Operation *user : operand.getUsers()) {
        // Only other ops in this block that do not come before |op| matter.
        const bool isOtherOpNotBefore = user != &op &&
                                        user->getBlock() == block &&
                                        !user->isBeforeInBlock(&op);
        if (!isOtherOpNotBefore) {
          continue;
        }
        auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(user);
        if (!(tiedOp && tiedOp.hasAnyTiedUses(operand))) {
          continue;
        }
        inheritUserHazards(opInfo, streamableOp, user, /*viaTiedOperand=*/true);
      }
    }

    // Every existing wave without a hazard is a candidate.
    llvm::BitVector candidates(waveBuilders.size(), /*t=*/true);
    candidates ^= opInfo.hazards;

    // If this op is not streamable then bail here; we've still setup the hazard
    // map for following iteration.
    if (!streamableOp || streamableOp.isMetadata()) {
      LLVM_DEBUG(llvm::dbgs() << "Not streamable/is subview (skip)\n");
      continue;
    }

    opInfo.membership.reserve(waveBuilders.size() + 1);
    opInfo.membership.resize(waveBuilders.size(), /*t=*/false);

    // No consumers - if there's any candidate then we'll go into that.
    int chosenOrdinal = -1;
    if (favor == IREE::Stream::Favor::MaxConcurrency) {
      chosenOrdinal = candidates.find_first();
    } else {
      chosenOrdinal = candidates.find_last();
    }

    if (chosenOrdinal == -1) {
      // Mark the op as having hazards against all other waves.
      opInfo.hazards.set(0, waveBuilders.size());

      // Create a new wave just for this op.
      opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
      auto newWave = std::make_unique<WaveBuilder>();
      newWave->ordinal = waveBuilders.size();
      newWave->ops.insert(&op);
      LLVM_DEBUG(llvm::dbgs() << "Created wave " << newWave->ordinal << "\n");
      waveBuilders.push_back(std::move(newWave));
      continue;
    }

    // Join the chosen existing wave; every lower-ordinal wave becomes a hazard.
    LLVM_DEBUG(llvm::dbgs() << "Moving to last candidate wave "
                            << chosenOrdinal << " (continue)\n");
    waveBuilders[chosenOrdinal]->ops.insert(&op);
    opInfo.membership.set(chosenOrdinal);
    opInfo.hazards.set(0, chosenOrdinal);
    opInfo.hazards.reset(chosenOrdinal);
  }

  // Emit waves in forward order (as they are topologically sorted in
  // reverse order from our bottom-up walk).
  for (auto &waveBuilder : llvm::reverse(waveBuilders)) {
    SetVector<Value> consumedValues;
    SetVector<Value> producedValues;
    SetVector<Value> escapingValues;
    for (Operation *waveOp : llvm::reverse(waveBuilder->ops)) {
      for (auto operand : waveOp->getOperands()) {
        consumedValues.insert(operand);
      }
      for (auto result : waveOp->getResults()) {
        producedValues.insert(result);
        // TODO(benvanik): optimize this - creates n^2/nlogn behavior.
        const bool hasUserOutsideWave =
            llvm::any_of(result.getUsers(), [&](Operation *user) {
              return !waveBuilder->ops.contains(user);
            });
        if (hasUserOutsideWave) {
          escapingValues.insert(result);
        }
      }
    }
    // Values both produced and consumed inside the wave are internal.
    consumedValues.set_subtract(producedValues);

    Partition wave;
    wave.ins = consumedValues;
    wave.outs = escapingValues;
    wave.ops = std::move(waveBuilder->ops);
    waveSet.partitions.push_back(std::move(wave));
  }

  LLVM_DEBUG(waveSet.dump(*asmState));

  return waveSet;
}
} // namespace mlir::iree_compiler::IREE::Stream