blob: f36a06f33af5f8d450d8248e38dad99dbddbc9d1 [file] [log] [blame]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-codegen-gpu-vector-distribution"
using namespace mlir::iree_compiler::IREE::VectorExt;
namespace mlir::iree_compiler {
/// Shorthand for an SSA value whose type is statically known to be a vector.
using VectorValue = TypedValue<VectorType>;
/// Name of the attribute that stores an operation's distribution signature.
/// In this file it holds a two-element ArrayAttr: the per-operand entries
/// followed by the per-result entries.
constexpr StringLiteral kVectorLayoutFetcherStorageAttrName =
"__vector_layout_fetcher_storage";
/// Name of the unit attribute used to flag an operation for another round of
/// distribution. It is set together with the signature attribute and removed
/// again by the rewriter listener once the operation has been queued.
constexpr StringLiteral kVectorLayoutRedistributeAttrName =
"__vector_layout_redistribute";
/// Computes the distribution signature of `op` from the layout analysis and
/// attaches it to the operation.
///
/// Every operand and result contributes one entry: vector values contribute
/// their layout (taken from `analysis`, falling back to the default layout
/// provided by `options`), all other values contribute a unit attribute.
/// Returns failure, without touching `op`, as soon as a vector value is found
/// for which no layout is available.
LogicalResult setOpSignature(Operation *op, VectorLayoutAnalysis &analysis,
                             const VectorLayoutOptions &options) {
  MLIRContext *context = op->getContext();

  // Produces the signature entry for a single value. A null attribute signals
  // a vector value that has no layout.
  auto resolveEntry = [&](Value value) -> Attribute {
    auto vectorValue = dyn_cast<VectorValue>(value);
    if (!vectorValue) {
      return UnitAttr::get(context);
    }
    if (auto inferred =
            analysis.getLayout<VectorLayoutInterface>(vectorValue)) {
      return inferred;
    }
    if (auto fallback = options.getDefaultLayout(vectorValue.getType())) {
      return fallback;
    }
    return Attribute();
  };

  // Appends the entries for all `values` to `entries`, stopping at the first
  // value that cannot be resolved.
  auto collectEntries = [&](ValueRange values,
                            SmallVectorImpl<Attribute> &entries) {
    for (Value value : values) {
      Attribute entry = resolveEntry(value);
      if (!entry) {
        return failure();
      }
      entries.push_back(entry);
    }
    return success();
  };

  SmallVector<Attribute> operandEntries;
  if (failed(collectEntries(op->getOperands(), operandEntries))) {
    return failure();
  }
  SmallVector<Attribute> resultEntries;
  if (failed(collectEntries(op->getResults(), resultEntries))) {
    return failure();
  }

  // The stored signature is a pair: [operand entries, result entries].
  Attribute signature[] = {ArrayAttr::get(context, operandEntries),
                           ArrayAttr::get(context, resultEntries)};
  op->setAttr(kVectorLayoutFetcherStorageAttrName,
              ArrayAttr::get(context, signature));
  return success();
}
/// Returns true if `op` carries a distribution signature, i.e. an ArrayAttr
/// stored under the signature attribute name.
static bool hasOpSignature(Operation *op) {
  auto storedSignature =
      op->getAttrOfType<ArrayAttr>(kVectorLayoutFetcherStorageAttrName);
  return static_cast<bool>(storedSignature);
}
/// Decodes the signature attribute stored on `op` into a
/// DistributionSignature that maps each vector operand/result to its layout.
///
/// The op must carry a well-formed signature attribute (see hasOpSignature);
/// this is only checked through assertions.
static DistributionSignature getOpSignature(Operation *op) {
  auto storedSignature =
      op->getAttrOfType<ArrayAttr>(kVectorLayoutFetcherStorageAttrName);
  assert(storedSignature && "Op should have a signature attribute.");
  assert(storedSignature.size() == 2 && "Malformed signature attribute.");

  auto operandEntries = dyn_cast<ArrayAttr>(storedSignature[0]);
  auto resultEntries = dyn_cast<ArrayAttr>(storedSignature[1]);
  assert(operandEntries && resultEntries && "Malformed signature attribute.");
  assert(operandEntries.size() == op->getNumOperands() &&
         "Malformed signature attribute.");
  assert(resultEntries.size() == op->getNumResults() &&
         "Malformed signature attribute.");

  DistributionSignature signature;

  // Records the layout of one value. Unit attributes mark non-vector values
  // and contribute nothing to the signature.
  auto recordEntry = [&signature](Value value, Attribute entry) {
    const bool isVector = isa<VectorValue>(value);
    (void)isVector; // Only read by the assertions below.
    if (isa<UnitAttr>(entry)) {
      assert(!isVector &&
             "Malformed signature attribute: unit attribute for vector value.");
      return;
    }
    assert(isVector &&
           "Malformed signature attribute: non-unit attribute for non-vector "
           "value.");
    auto vectorLayout = cast<VectorLayoutInterface>(entry);
    assert(vectorLayout && "Malformed signature attribute.");
    signature[cast<VectorValue>(value)] = vectorLayout;
  };

  for (auto [operand, entry] :
       llvm::zip_equal(op->getOperands(), operandEntries)) {
    recordEntry(operand, entry);
  }
  for (auto [result, entry] :
       llvm::zip_equal(op->getResults(), resultEntries)) {
    recordEntry(result, entry);
  }
  return signature;
}
/// Returns the distributed counterpart of `value` for the given `layout`.
///
/// When `value` is produced by a "to_simd" op, the input of that op is reused
/// directly. Otherwise a "to_simt" op is created at the location of `value`,
/// whose result type has the distributed shape reported by `layout` and the
/// element type of `value`.
VectorValue
DistributionPattern::getDistributed(RewriterBase &rewriter, VectorValue value,
                                    VectorLayoutInterface layout) const {
  // Short-circuit: reuse the operand of an existing "to_simd" conversion.
  auto definingToSIMD = value.getDefiningOp<IREE::VectorExt::ToSIMDOp>();
  if (definingToSIMD) {
    return cast<VectorValue>(definingToSIMD.getInput());
  }

  // Otherwise materialize a "to_simt" conversion to the distributed type.
  SmallVector<int64_t> simtShape = layout.getDistributedShape();
  Type elementType = value.getType().getElementType();
  VectorType simtType = VectorType::get(simtShape, elementType);
  auto conversion = rewriter.create<IREE::VectorExt::ToSIMTOp>(
      value.getLoc(), simtType, value);
  return conversion.getResult();
}
/// Replaces `op` with the given distributed `values`, one per op result.
///
/// Replacement values of vector type are first wrapped in a "to_simd" op that
/// yields the original result type; other values are used as they are. Note
/// that the rewriter's insertion point is moved for each wrapped value and is
/// not restored afterwards.
void DistributionPattern::replaceOpWithDistributedValues(
    RewriterBase &rewriter, Operation *op, ValueRange values) const {
  SmallVector<Value> finalValues;
  for (auto [originalResult, distributed] :
       llvm::zip_equal(op->getOpResults(), values)) {
    // Non-vector replacements need no conversion.
    if (!isa<VectorType>(distributed.getType())) {
      finalValues.push_back(distributed);
      continue;
    }

    // Vector replacements are converted back to simd right after the result
    // they stand in for.
    auto simdResult = cast<VectorValue>(originalResult);
    rewriter.setInsertionPointAfterValue(simdResult);
    Value backToSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
        simdResult.getLoc(), simdResult.getType(), distributed);
    finalValues.push_back(backToSIMD);
  }
  rewriter.replaceOp(op, finalValues);
}
/// Returns the distribution signature stored on `op`, or std::nullopt when
/// the op has no signature attribute.
std::optional<DistributionSignature>
DistributionPattern::getOpSignature(Operation *op) const {
  if (hasOpSignature(op)) {
    // Qualified to reach the free function rather than this member.
    return ::mlir::iree_compiler::getOpSignature(op);
  }
  return std::nullopt;
}
/// Stores a new signature, built from `inputLayoutsAttr` and
/// `outputLayoutsAttr`, on `op` and flags the op for redistribution.
///
/// Both attributes are set inside a single in-place modification so that the
/// rewriter (and any listener attached to it) is notified of the change.
void DistributionPattern::setSignatureForRedistribution(
    PatternRewriter &rewriter, Operation *op, Attribute inputLayoutsAttr,
    Attribute outputLayoutsAttr) const {
  MLIRContext *context = rewriter.getContext();
  Attribute layoutPair[] = {inputLayoutsAttr, outputLayoutsAttr};
  ArrayAttr signatureAttr = ArrayAttr::get(context, layoutPair);
  UnitAttr redistributeMarker = UnitAttr::get(context);

  rewriter.modifyOpInPlace(op, [&]() {
    op->setAttr(kVectorLayoutFetcherStorageAttrName, signatureAttr);
    op->setAttr(kVectorLayoutRedistributeAttrName, redistributeMarker);
  });
}
/// Prints the distinct operation names found in `worklist` to the debug
/// stream, one bullet per name. The names are written unconditionally, so
/// callers are expected to guard the call with LLVM_DEBUG themselves.
static void
debugPrintUniqueOperationNames(const std::deque<Operation *> &worklist) {
  // Deduplicate first; the print order follows the set's iteration order.
  DenseSet<StringRef> seenNames;
  for (Operation *queuedOp : worklist) {
    StringRef opName = queuedOp->getName().getStringRef();
    seenNames.insert(opName);
  }

  for (StringRef opName : seenNames) {
    llvm::raw_ostream &bullet = llvm::dbgs().indent(2);
    bullet << "* " << opName << "\n";
  }
  LLVM_DEBUG(llvm::dbgs() << "\n");
}
/// A rewriter for the pattern rewriting driver.
///
/// Thin wrapper that only forwards the context to PatternRewriter. The
/// constructor is explicit so that an MLIRContext pointer is never silently
/// converted into a rewriter.
struct VectorDistributionRewriter : PatternRewriter {
  explicit VectorDistributionRewriter(MLIRContext *ctx)
      : PatternRewriter(ctx) {}
};
/// Custom listener to store emitted ops that needs to be distributed.
struct VectorDistributionListener : public RewriterBase::Listener {
bool hasOpsToBeDistributed() { return !toBeDistributed.empty(); }
void clearOpsToBeDistributed() { return toBeDistributed.clear(); }
const std::deque<Operation *> &getOpsToBeDistributed() const {
return toBeDistributed;
}
void notifyOperationModified(Operation *op) override {
if (op->hasAttr(kVectorLayoutRedistributeAttrName) &&
op->hasAttrOfType<ArrayAttr>(kVectorLayoutFetcherStorageAttrName)) {
op->removeAttr(kVectorLayoutRedistributeAttrName);
toBeDistributed.push_back(op);
}
}
private:
std::deque<Operation *> toBeDistributed;
};
/// Applies `patterns` once to every operation under `root` that carries a
/// distribution signature.
///
/// The initial worklist is collected by walking `root`. Operations that a
/// successful rewrite flags for redistribution (reported through the
/// VectorDistributionListener) are appended to the end of the worklist.
/// Operations for which no pattern applies are skipped; this is not treated
/// as an error here.
static void applyVectorDistribution(Operation *root,
                                    const FrozenRewritePatternSet &patterns) {
  VectorDistributionRewriter rewriter(root->getContext());
  VectorDistributionListener listener;
  rewriter.setListener(&listener);
  PatternApplicator applicator(patterns);
  applicator.applyDefaultCostModel();

  // Collect all the operations to be distributed.
  std::deque<Operation *> worklist;
  LLVM_DEBUG(llvm::dbgs() << "Collecting operations to be distributed\n");
  root->walk([&](Operation *op) {
    if (hasOpSignature(op)) {
      worklist.push_back(op);
    }
  });
  LLVM_DEBUG(llvm::dbgs() << "Operations to be distributed:\n");
  LLVM_DEBUG(debugPrintUniqueOperationNames(worklist));

  // Note that the pattern application here never runs on a newly created
  // operation. It always runs on an existing operation. This ensures that no
  // invalidated state of the analysis is ever used.
  while (!worklist.empty()) {
    Operation *op = worklist.front();
    worklist.pop_front();
    if (op == nullptr)
      continue;

    LLVM_DEBUG(llvm::dbgs() << "Distributing: ");
    LLVM_DEBUG(op->print(llvm::dbgs(), OpPrintingFlags().skipRegions()));
    LLVM_DEBUG(llvm::dbgs() << "\n");
    if (failed(applicator.matchAndRewrite(op, rewriter))) {
      LLVM_DEBUG(llvm::dbgs().indent(2)
                 << ": Failed to distribute operation:\n");
      continue;
    }

    // Move recently emitted operations that needs to be distributed
    // from the local/rewriter worklist into the "global" worklist.
    if (listener.hasOpsToBeDistributed()) {
      // Bind by reference: the listener's queue is only read here, and it is
      // cleared only after its contents have been appended to the worklist,
      // so there is no need to copy the whole deque first.
      const std::deque<Operation *> &opsToBeDistributed =
          listener.getOpsToBeDistributed();
      LLVM_DEBUG(llvm::dbgs()
                 << "Recently emitted operations to be distributed:\n");
      LLVM_DEBUG(debugPrintUniqueOperationNames(opsToBeDistributed));
      worklist.insert(worklist.end(), opsToBeDistributed.begin(),
                      opsToBeDistributed.end());
      listener.clearOpsToBeDistributed();
    }
    LLVM_DEBUG(llvm::dbgs().indent(2)
               << ": Successfully distributed operation:\n");
  }
}
/// Entry point of vector distribution for everything nested under `root`.
///
/// Steps, in order:
///   1. Run the vector layout analysis; fail if it fails.
///   2. Attach a distribution signature to every operation for which one can
///      be computed (operations without a full set of layouts are skipped).
///   3. Apply `distributionPatterns` (consumed by this call) to the signed
///      operations.
///   4. Greedily apply the to_simd / to_simt canonicalization patterns; fail
///      if that application fails.
///   5. Strip the signature attribute from all operations.
///   6. If `options` requests verification, fail when a to_simd / to_simt op
///      still has a user that is not itself one of these conversion ops.
LogicalResult distributeVectorOps(Operation *root,
                                  RewritePatternSet &distributionPatterns,
                                  VectorLayoutOptions &options) {
  MLIRContext *context = root->getContext();

  // Run the analysis and determine the layouts.
  LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n");
  VectorLayoutAnalysis analysis(root);
  if (failed(analysis.run())) {
    return failure();
  }
  LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeded\n");
  LLVM_DEBUG(llvm::dbgs() << "\n\n");

  // Go to each operation, and set its distribution signature.
  LLVM_DEBUG(
      llvm::dbgs() << "Setting distribution signatures for operations\n");
  root->walk([&](Operation *op) {
    if (succeeded(setOpSignature(op, analysis, options))) {
      return;
    }
    LLVM_DEBUG({
      llvm::dbgs() << "Skipping operation because not all vector "
                      "operands/results have a layout:\n";
      op->print(llvm::dbgs());
    });
  });
  LLVM_DEBUG(llvm::dbgs() << "Distribution signatures set\n");
  LLVM_DEBUG(root->print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n\n");

  // Distribute the signed operations.
  FrozenRewritePatternSet frozenPatterns(std::move(distributionPatterns));
  applyVectorDistribution(root, frozenPatterns);

  // Clean up the conversion ops introduced during distribution.
  RewritePatternSet cleanupPatterns(context);
  IREE::VectorExt::ToSIMDOp::getCanonicalizationPatterns(cleanupPatterns,
                                                         context);
  IREE::VectorExt::ToSIMTOp::getCanonicalizationPatterns(cleanupPatterns,
                                                         context);
  if (failed(applyPatternsGreedily(root, std::move(cleanupPatterns)))) {
    return failure();
  }

  // Remove signature after distribution.
  root->walk([](Operation *op) {
    op->removeDiscardableAttr(kVectorLayoutFetcherStorageAttrName);
  });

  if (!options.verifyConversion()) {
    return success();
  }

  // A conversion op that still feeds anything other than another conversion
  // op means distribution was left incomplete.
  auto isConversionOp = [](Operation *candidate) {
    return isa<IREE::VectorExt::ToSIMDOp, IREE::VectorExt::ToSIMTOp>(
        candidate);
  };
  WalkResult liveCastSearch = root->walk([&](Operation *op) {
    if (!isConversionOp(op)) {
      return WalkResult::advance();
    }
    for (Operation *user : op->getUsers()) {
      if (isConversionOp(user)) {
        continue;
      }
      LLVM_DEBUG({
        llvm::dbgs() << "Found live cast op: " << *op << "\n";
        llvm::dbgs() << "With live user: " << *user << "\n";
      });
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (liveCastSearch.wasInterrupted()) {
    return failure();
  }
  return success();
}
} // namespace mlir::iree_compiler