// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h"
#include <algorithm>
#include <map>
#include <utility>
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
namespace mlir {
namespace iree_compiler {
static Attribute getStrArrayAttr(Builder &builder,
ArrayRef<std::string> values) {
return builder.getStrArrayAttr(llvm::to_vector<8>(llvm::map_range(
values, [](const std::string &value) { return StringRef(value); })));
}
// static
LogicalResult RegisterAllocation::annotateIR(IREE::VM::FuncOp funcOp) {
RegisterAllocation registerAllocation;
if (failed(registerAllocation.recalculate(funcOp))) {
funcOp.emitOpError() << "failed to allocate registers for function";
return failure();
}
Builder builder(funcOp.getContext());
for (auto &block : funcOp.getBlocks()) {
SmallVector<std::string, 8> blockRegStrs;
blockRegStrs.reserve(block.getNumArguments());
for (auto blockArg : block.getArguments()) {
uint8_t reg = registerAllocation.map_[blockArg];
blockRegStrs.push_back(std::to_string(reg));
}
block.front().setAttr("block_registers",
getStrArrayAttr(builder, blockRegStrs));
for (auto &op : block.getOperations()) {
if (op.getNumResults() == 0) continue;
SmallVector<std::string, 8> regStrs;
regStrs.reserve(op.getNumResults());
for (auto result : op.getResults()) {
uint8_t reg = registerAllocation.map_[result];
regStrs.push_back(std::to_string(reg));
}
op.setAttr("result_registers", getStrArrayAttr(builder, regStrs));
}
Operation *terminatorOp = block.getTerminator();
if (terminatorOp->getNumSuccessors() > 0) {
SmallVector<Attribute, 2> successorAttrs;
successorAttrs.reserve(terminatorOp->getNumSuccessors());
for (int i = 0; i < terminatorOp->getNumSuccessors(); ++i) {
auto srcDstRegs =
registerAllocation.remapSuccessorRegisters(terminatorOp, i);
SmallVector<std::string, 8> remappingStrs;
for (auto &srcDstReg : srcDstRegs) {
remappingStrs.push_back(
llvm::formatv("{0}->{1}", srcDstReg.first, srcDstReg.second)
.str());
}
successorAttrs.push_back(getStrArrayAttr(builder, remappingStrs));
}
terminatorOp->setAttr("remap_registers",
builder.getArrayAttr(successorAttrs));
}
}
return success();
}
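// For reference, an annotated function body dumped after this analysis might
// look roughly like this (illustrative only; the exact ordinals depend on the
// allocation computed above):
//
//   ^bb1(%arg1: i32):  // first op carries block_registers = ["0"]
//     %0 = vm.add.i32 %arg1, %arg1 {result_registers = ["1"]} : i32
//     vm.br ^bb2(%0 : i32) {remap_registers = [["1->0"]]}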
// Forms a register reference byte as interpreted by the VM.
// Assumes the ordinal is within the valid range for its register bank.
static uint8_t makeRegisterByte(Type type, int ordinal, bool isMove) {
if (type.isSignlessIntOrIndexOrFloat()) {
assert(ordinal < kIntRegisterCount);
return ordinal;
} else {
assert(ordinal < kRefRegisterCount);
return (ordinal | kRefRegisterTypeBit) | (isMove ? kRefRegisterMoveBit : 0);
}
}
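// Example (assuming the bit layout declared in RegisterAllocation.h, with the
// ref type flag in the high bit and the move flag just below it):
//   makeRegisterByte(i32Type, /*ordinal=*/5, /*isMove=*/false)
//       -> 5
//   makeRegisterByte(refType, /*ordinal=*/5, /*isMove=*/true)
//       -> kRefRegisterTypeBit | kRefRegisterMoveBit | 5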
// Bitmaps indicating which registers of each bank are currently in use.
struct RegisterUsage {
llvm::BitVector intRegisters{kIntRegisterCount};
llvm::BitVector refRegisters{kRefRegisterCount};
int maxI32RegisterOrdinal = -1;
int maxRefRegisterOrdinal = -1;
void reset() {
intRegisters.reset();
refRegisters.reset();
maxI32RegisterOrdinal = -1;
maxRefRegisterOrdinal = -1;
}
  // Attempts to allocate a register for a value of |type|, returning None if
  // the matching bank is exhausted.
  Optional<uint8_t> allocateRegister(Type type) {
    if (type.isSignlessIntOrIndexOrFloat()) {
      // find_first_unset() returns -1 when every bit is already set.
      int ordinal = intRegisters.find_first_unset();
      if (ordinal < 0 || ordinal >= kIntRegisterCount) {
        return {};
      }
      intRegisters.set(ordinal);
      maxI32RegisterOrdinal = std::max(ordinal, maxI32RegisterOrdinal);
      return makeRegisterByte(type, ordinal, /*isMove=*/false);
    } else {
      int ordinal = refRegisters.find_first_unset();
      if (ordinal < 0 || ordinal >= kRefRegisterCount) {
        return {};
      }
      refRegisters.set(ordinal);
      maxRefRegisterOrdinal = std::max(ordinal, maxRefRegisterOrdinal);
      return makeRegisterByte(type, ordinal, /*isMove=*/false);
    }
  }
void markRegisterUsed(uint8_t reg) {
int ordinal = getRegisterOrdinal(reg);
if (isRefRegister(reg)) {
refRegisters.set(ordinal);
maxRefRegisterOrdinal = std::max(ordinal, maxRefRegisterOrdinal);
} else {
intRegisters.set(ordinal);
maxI32RegisterOrdinal = std::max(ordinal, maxI32RegisterOrdinal);
}
}
  // Returns a register to its bank. The masks strip the type/move bits from
  // the register byte to recover the ordinal (0x3F for the 6-bit ref ordinal,
  // 0x7F for the 7-bit integer ordinal).
  void releaseRegister(uint8_t reg) {
    if (isRefRegister(reg)) {
      refRegisters.reset(reg & 0x3F);
    } else {
      intRegisters.reset(reg & 0x7F);
    }
  }
};
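// Minimal usage sketch (ordinals assume initially empty banks; |i32Type| is a
// placeholder for any signless integer type):
//   RegisterUsage usage;
//   auto r0 = usage.allocateRegister(i32Type);  // ordinal 0 of the int bank
//   auto r1 = usage.allocateRegister(i32Type);  // ordinal 1 of the int bank
//   usage.releaseRegister(r0.getValue());       // ordinal 0 is free again
//   auto r2 = usage.allocateRegister(i32Type);  // reuses ordinal 0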
// Sorts blocks in dominance order such that the entry block is first and
// all of the following blocks are dominated only by blocks that have come
// before them in the list. This ensures that we always know all registers for
// block live-in values as we walk the blocks.
static SmallVector<Block *, 8> sortBlocksInDominanceOrder(
IREE::VM::FuncOp funcOp) {
DominanceInfo dominanceInfo(funcOp);
llvm::SmallSetVector<Block *, 8> unmarkedBlocks;
for (auto &block : funcOp.getBlocks()) {
unmarkedBlocks.insert(&block);
}
llvm::SmallSetVector<Block *, 8> markedBlocks;
std::function<void(Block *)> visit = [&](Block *block) {
if (markedBlocks.count(block) > 0) return;
for (auto *childBlock : dominanceInfo.getNode(block)->getChildren()) {
visit(childBlock->getBlock());
}
markedBlocks.insert(block);
};
while (!unmarkedBlocks.empty()) {
visit(unmarkedBlocks.pop_back_val());
}
auto orderedBlocks = markedBlocks.takeVector();
std::reverse(orderedBlocks.begin(), orderedBlocks.end());
return orderedBlocks;
}
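// For example, in a diamond CFG entry -> {a, b} -> exit the entry block is
// guaranteed to come first while a, b, and exit may appear in any relative
// order, as none of the three dominates another.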
// NOTE: this is not a good algorithm, nor is it a good allocator. If you're
// looking at this and have ideas of how to do this for real please feel
// free to rip it all apart :)
//
// Because I'm lazy we really only look at individual blocks at a time. It'd
// be much better to use dominance info to track values across blocks and
// ensure we are avoiding as many moves as possible. The special case we need to
// handle is when values are not defined within the current block (as values in
// dominators are allowed to cross block boundaries outside of arguments).
LogicalResult RegisterAllocation::recalculate(IREE::VM::FuncOp funcOp) {
map_.clear();
if (failed(liveness_.recalculate(funcOp))) {
    return funcOp.emitError()
           << "failed to calculate required liveness information";
}
  // Walk the blocks in dominance order and build their register usage tables.
  // We accumulate value->register mappings in |map_| as we go, and since we
  // traverse in dominance order we know that by the time we reach each block
  // |map_| already contains entries for all of its implicitly captured values.
auto orderedBlocks = sortBlocksInDominanceOrder(funcOp);
for (auto *block : orderedBlocks) {
    // Use the block live-in info to populate the register usage info at block
    // entry. This way if the block is dominated by multiple blocks or the
    // live-out of the dominator is a superset of this block's live-in we are
    // only working with the minimal set.
RegisterUsage registerUsage;
for (auto liveInValue : liveness_.getBlockLiveIns(block)) {
registerUsage.markRegisterUsed(mapToRegister(liveInValue));
}
// Allocate arguments first from left-to-right.
for (auto blockArg : block->getArguments()) {
auto reg = registerUsage.allocateRegister(blockArg.getType());
if (!reg.hasValue()) {
return funcOp.emitError() << "register allocation failed for block arg "
<< blockArg.getArgNumber();
}
map_[blockArg] = reg.getValue();
}
// Cleanup any block arguments that were unused. We do this after the
// initial allocation above so that block arguments can never alias as that
// makes things really hard to read. Ideally an optimization pass that
// removes unused block arguments would prevent this from happening.
for (auto blockArg : block->getArguments()) {
if (blockArg.use_empty()) {
registerUsage.releaseRegister(map_[blockArg]);
}
}
for (auto &op : block->getOperations()) {
for (auto &operand : op.getOpOperands()) {
if (liveness_.isLastValueUse(operand.get(), &op)) {
registerUsage.releaseRegister(map_[operand.get()]);
}
}
for (auto result : op.getResults()) {
auto reg = registerUsage.allocateRegister(result.getType());
if (!reg.hasValue()) {
return op.emitError() << "register allocation failed for result "
<< result.cast<OpResult>().getResultNumber();
}
map_[result] = reg.getValue();
if (result.use_empty()) {
registerUsage.releaseRegister(reg.getValue());
}
}
}
// Track the maximum register of each type used.
maxI32RegisterOrdinal_ =
std::max(maxI32RegisterOrdinal_, registerUsage.maxI32RegisterOrdinal);
maxRefRegisterOrdinal_ =
std::max(maxRefRegisterOrdinal_, registerUsage.maxRefRegisterOrdinal);
}
  // Always allocate one extra register of each type for scratch space.
  // These scratch registers are used when remapping registers across branches
  // that may have hazards (such as a remap set of 0->1 and 1->0). If we
  // precomputed whether remappings were required here then we could avoid this
  // but it doesn't seem worth it for a single register (yet).
if (maxI32RegisterOrdinal_ > 0) {
++maxI32RegisterOrdinal_;
}
if (maxRefRegisterOrdinal_ > 0) {
++maxRefRegisterOrdinal_;
}
  // Spilling is not yet implemented, so the most we can do is detect overflow
  // here; ordinals are 0-based, so the highest valid ordinal in each bank is
  // the register count minus one. If we implement spilling we could use this
  // max information to reserve space for spilling.
  if (maxI32RegisterOrdinal_ >= kIntRegisterCount ||
      maxRefRegisterOrdinal_ >= kRefRegisterCount) {
    return funcOp.emitError() << "function overflows stack register banks; "
                                 "spilling to memory not yet implemented";
  }
return success();
}
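// As a small worked example (illustrative VM assembly), given a block such as
//   ^bb0(%a: i32):
//     %b = vm.add.i32 %a, %a : i32  // last use of %a
//     vm.return %b : i32
// %a is assigned ordinal 0 as a block argument, released at its last use, and
// %b then reuses ordinal 0 since operand registers are released before result
// registers are allocated.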
uint8_t RegisterAllocation::mapToRegister(Value value) {
auto it = map_.find(value);
assert(it != map_.end());
return it->getSecond();
}
uint8_t RegisterAllocation::mapUseToRegister(Value value, Operation *useOp,
int operandIndex) {
uint8_t reg = mapToRegister(value);
if (isRefRegister(reg) &&
liveness_.isLastValueUse(value, useOp, operandIndex)) {
reg |= kRefRegisterMoveBit;
}
return reg;
}
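// For example, if a ref value lives in ordinal 3 and |useOp| is its last use,
// the returned byte is kRefRegisterTypeBit | kRefRegisterMoveBit | 3,
// signaling that the VM may move the ref out of the register instead of
// retaining it.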
// A feedback arc set: the cycle-causing edges of a directed graph that, once
// removed, leave the remaining graph acyclic. Finding a minimum feedback arc
// set is NP-hard, so this uses a greedy degree-based heuristic (in the spirit
// of Eades, Lin & Smyth) that is plenty for the tiny remapping graphs here.
// https://en.wikipedia.org/wiki/Feedback_arc_set
struct FeedbackArcSet {
using NodeID = uint8_t;
using Edge = std::pair<NodeID, NodeID>;
// Edges making up a DAG (inputEdges - feedbackEdges).
SmallVector<Edge, 8> acyclicEdges;
// Edges of the feedback arc set that, if added to acyclicEdges, would cause
// cycles.
SmallVector<Edge, 8> feedbackEdges;
// Computes the FAS of a given directed graph that may contain cycles.
static FeedbackArcSet compute(ArrayRef<Edge> inputEdges) {
FeedbackArcSet result;
if (inputEdges.empty()) {
return result;
} else if (inputEdges.size() == 1) {
result.acyclicEdges.push_back(inputEdges.front());
return result;
}
struct FASNode {
NodeID id;
int indegree = 0;
int outdegree = 0;
};
    SmallVector<FASNode, 8> nodeStorage;
    // Pointers into |nodeStorage| are cached in |nodes| below, so the vector
    // must never reallocate; reserve worst-case capacity (two endpoints per
    // edge) up front.
    nodeStorage.reserve(inputEdges.size() * 2);
    llvm::SmallDenseMap<NodeID, FASNode *> nodes;
for (auto &edge : inputEdges) {
NodeID sourceID = getBaseRegister(edge.first);
NodeID sinkID = getBaseRegister(edge.second);
assert(sourceID != sinkID && "self-cycles not supported");
if (nodes.count(sourceID) == 0) {
nodeStorage.push_back({sourceID, 0, 0});
nodes.insert({sourceID, &nodeStorage.back()});
}
if (nodes.count(sinkID) == 0) {
nodeStorage.push_back({sinkID, 0, 0});
nodes.insert({sinkID, &nodeStorage.back()});
}
}
struct FASEdge {
FASNode *source;
FASNode *sink;
};
int maxOutdegree = 0;
int maxIndegree = 0;
SmallVector<FASEdge, 8> edges;
for (auto &edge : inputEdges) {
NodeID sourceID = getBaseRegister(edge.first);
NodeID sinkID = getBaseRegister(edge.second);
auto *sourceNode = nodes[sourceID];
++sourceNode->outdegree;
maxOutdegree = std::max(maxOutdegree, sourceNode->outdegree);
auto *sinkNode = nodes[sinkID];
++sinkNode->indegree;
maxIndegree = std::max(maxIndegree, sinkNode->indegree);
edges.push_back({sourceNode, sinkNode});
}
    // Bucket nodes so that the next removal candidate can be found cheaply:
    // pure sinks and sources (which can never be part of a cycle) go in the
    // last bucket and everything else is keyed by the absolute delta between
    // its out- and in-degree.
    std::vector<SmallVector<FASNode *, 2>> buckets;
    buckets.resize(std::max(maxOutdegree, maxIndegree) + 2);
    auto nodeToBucketIndex = [&](FASNode *node) {
      return node->indegree == 0 || node->outdegree == 0
                 ? buckets.size() - 1
                 : std::abs(node->outdegree - node->indegree);
    };
auto assignBucket = [&](FASNode *node) {
buckets[nodeToBucketIndex(node)].push_back(node);
};
auto removeBucket = [&](FASNode *node) {
int index = nodeToBucketIndex(node);
auto it = std::find(buckets[index].begin(), buckets[index].end(), node);
if (it != buckets[index].end()) {
buckets[index].erase(it);
}
};
for (auto &nodeEntry : nodes) {
assignBucket(nodeEntry.second);
}
auto removeNode = [&](FASNode *node) {
SmallVector<FASEdge, 4> inEdges;
inEdges.reserve(node->indegree);
SmallVector<FASEdge, 4> outEdges;
outEdges.reserve(node->outdegree);
for (auto &edge : edges) {
if (edge.sink == node) inEdges.push_back(edge);
if (edge.source == node) outEdges.push_back(edge);
}
bool collectInEdges = node->indegree <= node->outdegree;
bool collectOutEdges = !collectInEdges;
SmallVector<Edge, 4> results;
for (auto &edge : inEdges) {
if (edge.source == node) continue;
if (collectInEdges) {
results.push_back({edge.source->id, edge.sink->id});
}
removeBucket(edge.source);
--edge.source->outdegree;
assignBucket(edge.source);
}
for (auto &edge : outEdges) {
if (edge.sink == node) continue;
if (collectOutEdges) {
results.push_back({edge.source->id, edge.sink->id});
}
removeBucket(edge.sink);
--edge.sink->indegree;
assignBucket(edge.sink);
}
nodes.erase(node->id);
edges.erase(std::remove_if(edges.begin(), edges.end(),
[&](const FASEdge &edge) {
return edge.source == node ||
edge.sink == node;
}),
edges.end());
return results;
};
auto ends = buckets.back();
while (!nodes.empty()) {
while (!ends.empty()) {
auto *node = ends.front();
ends.erase(ends.begin());
removeNode(node);
}
if (nodes.empty()) break;
for (int i = buckets.size() - 1; i >= 0; --i) {
if (buckets[i].empty()) continue;
auto *bucket = buckets[i].front();
buckets[i].erase(buckets[i].begin());
auto feedbackEdges = removeNode(bucket);
result.feedbackEdges.append(feedbackEdges.begin(), feedbackEdges.end());
break;
}
}
// Build the DAG of the remaining edges now that we've isolated the ones
// that cause cycles.
llvm::SmallSetVector<NodeID, 8> acyclicNodes;
SmallVector<Edge, 8> acyclicEdges;
for (auto &inputEdge : inputEdges) {
auto it = std::find_if(
result.feedbackEdges.begin(), result.feedbackEdges.end(),
[&](const Edge &edge) {
return compareRegistersEqual(edge.first, inputEdge.first) &&
compareRegistersEqual(edge.second, inputEdge.second);
});
if (it == result.feedbackEdges.end()) {
acyclicEdges.push_back(inputEdge);
acyclicNodes.insert(getBaseRegister(inputEdge.first));
acyclicNodes.insert(getBaseRegister(inputEdge.second));
}
}
// Topologically sort the DAG so that we don't overwrite anything.
llvm::SmallSetVector<NodeID, 8> unmarkedNodes = acyclicNodes;
llvm::SmallSetVector<NodeID, 8> markedNodes;
std::function<void(NodeID)> visit = [&](NodeID node) {
if (markedNodes.count(node) > 0) return;
for (auto &edge : acyclicEdges) {
if (edge.first != node) continue;
visit(edge.second);
}
markedNodes.insert(node);
};
while (!unmarkedNodes.empty()) {
visit(unmarkedNodes.pop_back_val());
}
for (auto node : markedNodes.takeVector()) {
for (auto &edge : acyclicEdges) {
if (edge.first != node) continue;
result.acyclicEdges.push_back({edge.first, edge.second});
}
}
return result;
}
};
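// For example, the remapping set {0->1, 1->0} forms a 2-cycle: compute() will
// place one edge (say 0->1) in |acyclicEdges| and the other (1->0) in
// |feedbackEdges|. Which edge becomes the feedback edge depends on the node
// removal order of the heuristic above.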
SmallVector<std::pair<uint8_t, uint8_t>, 8>
RegisterAllocation::remapSuccessorRegisters(Operation *op, int successorIndex) {
// Compute the initial directed graph of register movements.
// This may contain cycles ([reg 0->1], [reg 1->0], ...) that would not be
// possible to evaluate as a direct remapping.
SmallVector<std::pair<uint8_t, uint8_t>, 8> srcDstRegs;
auto *targetBlock = op->getSuccessor(successorIndex);
auto operands = op->getSuccessorOperands(successorIndex);
for (auto it : llvm::enumerate(operands)) {
uint8_t srcReg = mapToRegister(it.value());
BlockArgument targetArg = targetBlock->getArgument(it.index());
uint8_t dstReg = mapToRegister(targetArg);
if (!compareRegistersEqual(srcReg, dstReg)) {
srcDstRegs.push_back({srcReg, dstReg});
}
}
  // Compute the feedback arc set to determine which edges, if any, induce
  // cycles. This also provides us with a topologically ordered DAG that we can
  // remap directly without worrying about clobbering values.
auto feedbackArcSet = FeedbackArcSet::compute(srcDstRegs);
assert(feedbackArcSet.acyclicEdges.size() +
feedbackArcSet.feedbackEdges.size() ==
srcDstRegs.size() &&
"lost an edge during feedback arc set computation");
// If there's no cycles we can simply use the sorted DAG produced.
if (feedbackArcSet.feedbackEdges.empty()) {
return feedbackArcSet.acyclicEdges;
}
assert(feedbackArcSet.feedbackEdges.size() == 1 &&
"liveness tracking of scratch registers not yet implemented");
// The last register in each bank is reserved for swapping, when required.
uint8_t scratchI32Reg = maxI32RegisterOrdinal_;
uint8_t scratchRefReg = kRefRegisterTypeBit | maxRefRegisterOrdinal_;
for (auto feedbackEdge : feedbackArcSet.feedbackEdges) {
uint8_t scratchReg =
isRefRegister(feedbackEdge.first) ? scratchRefReg : scratchI32Reg;
feedbackArcSet.acyclicEdges.insert(feedbackArcSet.acyclicEdges.begin(),
{feedbackEdge.first, scratchReg});
feedbackArcSet.acyclicEdges.push_back({scratchReg, feedbackEdge.second});
}
return feedbackArcSet.acyclicEdges;
}
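// For example, a register swap {0->1, 1->0} with scratch register 2 yields the
// hazard-free sequence [1->2, 0->1, 2->0] (assuming 1->0 was chosen as the
// feedback edge), which the VM applies in order to perform the swap through
// the scratch register.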
} // namespace iree_compiler
} // namespace mlir