blob: 577ed25193cdc7a9a0e945a796d6a68891f2ad5a [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#define DEBUG_TYPE "iree-stream-schedule-concurrency"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Stream {
namespace {
// TODO(benvanik): deduplicate this with ScheduleExecution - almost all of this
// is identical.
// Incremental builder for a partitioned region of executable work.
// Must be constructed in a topological order of all partitions.
struct WavePartitionBuilder {
explicit WavePartitionBuilder(Block *parentBlock, size_t ordinal,
Partition *partition,
BlockAndValueMapping &parentMapping,
MLIRContext *context)
: ordinal(ordinal), partition(partition), builder(context) {
// Fuse the location of all ops we'll be putting in the partition.
SmallVector<Location> locs;
for (auto *op : partition->ops) {
locs.push_back(op->getLoc());
}
auto fusedLoc = FusedLoc::get(context, locs);
// Find the insertion point in the parent block.
// This is at the last op defining an input as all inputs must be available.
Operation *insertionPt = nullptr;
for (auto in : partition->ins) {
auto *definingOp = in.getDefiningOp();
if (!definingOp) continue;
if (definingOp->getBlock() != parentBlock) continue;
if (!insertionPt) {
insertionPt = definingOp; // first defining op
} else if (insertionPt->isBeforeInBlock(definingOp)) {
insertionPt = definingOp; // moving insertion point down
}
}
OpBuilder parentBuilder(context);
if (insertionPt) {
parentBuilder.setInsertionPointAfter(insertionPt);
} else {
parentBuilder.setInsertionPointToStart(parentBlock);
}
// Gather operands and result types from the declared partition I/O.
// These are values from the original block. Note that because we are
// constructing in order we know that any results of prior partitions are
// in the |parentMapping|.
SmallVector<Type> resultTypes;
SmallVector<Value> resultSizes;
resultTypes.reserve(partition->outs.size());
resultSizes.reserve(partition->outs.size());
for (auto out : partition->outs) {
resultTypes.push_back(out.getType());
auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize(
fusedLoc, out, parentBuilder);
if (resultSize) resultSizes.push_back(resultSize);
}
SmallVector<Value> operands;
SmallVector<Type> operandTypes;
SmallVector<Value> operandSizes;
operands.reserve(partition->ins.size());
operandTypes.reserve(partition->ins.size());
operandSizes.reserve(partition->ins.size());
for (auto in : partition->ins) {
if (!in.getType().isa<IREE::Stream::ResourceType>()) continue;
operands.push_back(in);
operandTypes.push_back(in.getType());
auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize(
fusedLoc, in, parentBuilder);
if (operandSize) operandSizes.push_back(operandSize);
}
// TODO(benvanik): tie operands, or leave to canonicalization.
SmallVector<int64_t> tiedOperands;
concurrentOp = parentBuilder.create<IREE::Stream::AsyncConcurrentOp>(
fusedLoc, resultTypes, resultSizes, operands, operandSizes,
tiedOperands);
// Add entry block and arguments.
auto &entryBlock = concurrentOp.body().emplaceBlock();
for (auto args :
llvm::zip(operands, entryBlock.addArguments(operandTypes))) {
mapping.map(std::get<0>(args), std::get<1>(args));
}
builder = OpBuilder::atBlockBegin(&entryBlock);
// Remap results for escaping outputs.
for (auto results : llvm::zip(partition->outs, concurrentOp.results())) {
parentMapping.map(std::get<0>(results), std::get<1>(results));
}
}
// Visits a block operation and clones it into the partition, if desired.
//
// Slightly suboptimal to be calling this on each op for each partition,
// however we only walk the block once and constructing a multimap would be
// way worse.
//
// Returns true if the operation was cloned into the partition.
bool visit(Operation *op) {
if (!partition->ops.contains(op)) return false;
// Clone the op into the partition and remap it.
auto *clonedOp = builder.clone(*op, mapping);
(void)clonedOp;
LLVM_DEBUG({
llvm::dbgs() << "Cloned op into partition " << ordinal << ": ";
clonedOp->dump();
});
return true;
}
void finish() {
// Gather results mapped into the SSA values we've cloned.
SmallVector<Value> results;
SmallVector<Value> resultSizes;
results.reserve(partition->outs.size());
resultSizes.reserve(partition->outs.size());
for (auto oldResult : partition->outs) {
auto newResult = mapping.lookup(oldResult);
results.push_back(newResult);
auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize(
concurrentOp.getLoc(), newResult, builder);
if (resultSize) resultSizes.push_back(resultSize);
}
builder.create<IREE::Stream::YieldOp>(concurrentOp.getLoc(), results,
resultSizes);
}
size_t ordinal = -1;
Partition *partition = nullptr;
IREE::Stream::AsyncConcurrentOp concurrentOp;
OpBuilder builder;
BlockAndValueMapping mapping;
};
class ScheduleConcurrencyPass
: public ScheduleConcurrencyBase<ScheduleConcurrencyPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Stream::StreamDialect>();
registry.insert<IREE::Util::UtilDialect>();
}
void runOnOperation() override {
auto parentOp = dyn_cast<CallableOpInterface>(getOperation());
if (!parentOp || !parentOp.getCallableRegion() ||
parentOp.getCallableRegion()->empty()) {
return;
}
for (auto executeOp :
parentOp.getCallableRegion()->getOps<IREE::Stream::AsyncExecuteOp>()) {
if (failed(runOnRegion(executeOp))) return signalPassFailure();
}
}
LogicalResult runOnRegion(IREE::Stream::AsyncExecuteOp parentOp) {
if (parentOp.body().empty()) {
return success();
}
auto *block = &parentOp.body().front();
// Lookup the optional config used to control partitioning.
auto configAttr = IREE::Stream::PartitioningConfigAttr::lookup(parentOp);
// Compute a set of partitions covering all of the streamable ops in the
// execution region.
auto waveSet = partitionRegionConcurrency(configAttr, block);
if (waveSet.empty()) return success();
if (failed(waveSet.verify(parentOp.getLoc()))) return failure();
// Create partition builders for each partition.
// We'll clone ops into each and insert them into the block at the
// appropriate position (first use... probably).
BlockAndValueMapping mapping;
SmallVector<WavePartitionBuilder> partitionBuilders;
partitionBuilders.reserve(waveSet.size());
for (auto partition : llvm::enumerate(waveSet.partitions)) {
if (partition.value().ops.size() == 1) continue;
partitionBuilders.push_back(WavePartitionBuilder(block, partition.index(),
&partition.value(),
mapping, &getContext()));
}
// Walk over each op in the original block and find those that need to be
// partitioned. Each partition builder may clone the op into itself. The
// op will always be left in the original block and we'll rely on DCE to
// remove the ones no longer required. This is not a good approach as it
// creates a lot of new IR (up to O(op*partitions)).
SetVector<Operation *> deadOps;
for (auto &op : *block) {
if (op.hasTrait<OpTrait::IsTerminator>()) continue;
bool handled = false;
for (auto &partitionBuilder : partitionBuilders) {
handled = partitionBuilder.visit(&op) || handled;
}
if (handled) {
deadOps.insert(&op);
}
}
// Apply remapping for values captured/escaping partitions.
// We must do this per block as we'll be updating dominated block values.
for (auto &partitionBuilder : partitionBuilders) {
for (auto resultPair :
llvm::zip(partitionBuilder.partition->outs,
partitionBuilder.concurrentOp.results())) {
auto oldResult = std::get<0>(resultPair);
auto newResult = std::get<1>(resultPair);
oldResult.replaceAllUsesWith(newResult);
deadOps.insert(oldResult.getDefiningOp());
}
partitionBuilder.finish();
// Extremely shady reordering of ops we know (should) be safe to move
// after the partition - otherwise, we shouldn't have moved the source
// ops into the partition.
auto concurrentOp = partitionBuilder.concurrentOp;
for (auto user : concurrentOp->getUsers()) {
if (user->getBlock() == concurrentOp->getBlock() &&
user->isBeforeInBlock(partitionBuilder.concurrentOp)) {
LLVM_DEBUG({
llvm::dbgs() << "Shady move of op to after partition: ";
user->dump();
});
user->moveAfter(concurrentOp);
}
}
}
for (auto *deadOp : llvm::reverse(deadOps)) {
deadOp->erase();
}
LLVM_DEBUG({
llvm::dbgs() << "\nWaves constructed:\n";
block->dump();
});
return success();
}
};
} // namespace
std::unique_ptr<OperationPass<>> createScheduleConcurrencyPass() {
return std::make_unique<ScheduleConcurrencyPass>();
}
} // namespace Stream
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir