blob: 51d058bd4ebaef09ee3929c8f402a1ff08a2d015 [file]
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "Passes.h"
namespace mlir {
#define GEN_PASS_DEF_DECOMPOSEMEMREFSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"
} // namespace mlir
using namespace mlir;
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
if (auto *parentOp = val.getDefiningOp()) {
builder.setInsertionPointAfter(parentOp);
} else {
builder.setInsertionPointToStart(val.getParentBlock());
}
}
/// This is copied from static function affine::mlir::computeProduct.
/// TODO: enable this function in AffineOps.h
/// Return the product of `terms`, creating an `affine.apply` if any of them are
/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
ArrayRef<OpFoldResult> terms) {
int64_t nDynamic = 0;
SmallVector<Value> dynamicPart;
AffineExpr result = builder.getAffineConstantExpr(1);
for (OpFoldResult term : terms) {
if (!term) {
return term;
}
std::optional<int64_t> maybeConst = getConstantIntValue(term);
if (maybeConst) {
result = result * builder.getAffineConstantExpr(*maybeConst);
} else {
dynamicPart.push_back(cast<Value>(term));
result = result * builder.getAffineSymbolExpr(nDynamic++);
}
}
if (auto constant = dyn_cast<AffineConstantExpr>(result)) {
return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
}
return affine::AffineApplyOp::create(builder, loc, result, dynamicPart)
.getResult();
}
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
OpFoldResult>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
ArrayRef<OpFoldResult> subOffsets,
ArrayRef<OpFoldResult> subStrides = {}) {
auto sourceType = cast<MemRefType>(source.getType());
auto sourceRank = static_cast<unsigned>(sourceType.getRank());
memref::ExtractStridedMetadataOp newExtractStridedMetadata;
{
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointToStart(rewriter, source);
newExtractStridedMetadata =
memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
}
auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
: rewriter.getIndexAttr(dim);
};
OpFoldResult origOffset =
getDim(sourceOffset, newExtractStridedMetadata.getOffset());
ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
OpFoldResult outmostDim =
getDim(sourceType.getShape().front(),
newExtractStridedMetadata.getSizes().front());
SmallVector<OpFoldResult> origStrides;
origStrides.reserve(sourceRank);
SmallVector<OpFoldResult> strides;
strides.reserve(sourceRank);
AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
for (auto i : llvm::seq(0u, sourceRank)) {
OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
if (!subStrides.empty()) {
strides.push_back(affine::makeComposedFoldedAffineApply(
rewriter, loc, s0 * s1, {subStrides[i], origStride}));
}
origStrides.emplace_back(origStride);
}
// Compute linearized index:
auto &&[expr, values] =
computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
OpFoldResult linearizedIndex =
affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
// Compute collapsed size: (the outmost stride * outmost dimension).
SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
OpFoldResult collapsedSize = computeProduct(loc, rewriter, ops);
return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
origStrides, origOffset, collapsedSize};
}
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
OpFoldResult in) {
if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
return arith::ConstantIndexOp::create(
rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
}
return cast<Value>(in);
}
/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
Location loc,
Value source,
ValueRange indices) {
auto &&[base, index, strides, offset, collapsedShape] =
getFlatOffsetAndStrides(rewriter, loc, source,
getAsOpFoldResult(indices));
return std::make_pair(
memref::ReinterpretCastOp::create(
rewriter, loc, source,
/* offset = */ offset,
/* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
/* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
getValueFromOpFoldResult(rewriter, loc, index));
}
static bool needFlattening(Value val) {
auto type = cast<MemRefType>(val.getType());
return type.getRank() > 1;
}
static bool checkLayout(Value val) {
auto type = cast<MemRefType>(val.getType());
return type.getLayout().isIdentity() ||
isa<StridedLayoutAttr>(type.getLayout());
}
namespace {
template <typename T>
static Value getTargetMemref(T op) {
if constexpr (std::is_same_v<T, memref::LoadOp>) {
return op.getMemref();
} else if constexpr (std::is_same_v<T, vector::LoadOp>) {
return op.getBase();
} else if constexpr (std::is_same_v<T, memref::StoreOp>) {
return op.getMemref();
} else if constexpr (std::is_same_v<T, vector::StoreOp>) {
return op.getBase();
} else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
return op.getBase();
} else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
return op.getBase();
} else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
return op.getBase();
} else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
return op.getBase();
}
return {};
}
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
Value offset) {
if constexpr (std::is_same_v<T, memref::LoadOp>) {
auto newLoad =
memref::LoadOp::create(rewriter, op->getLoc(), op->getResultTypes(),
flatMemref, ValueRange{offset});
newLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newLoad.getResult());
} else if constexpr (std::is_same_v<T, vector::LoadOp>) {
auto newLoad =
vector::LoadOp::create(rewriter, op->getLoc(), op->getResultTypes(),
flatMemref, ValueRange{offset});
newLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newLoad.getResult());
} else if constexpr (std::is_same_v<T, memref::StoreOp>) {
auto newStore = memref::StoreOp::create(rewriter, op->getLoc(),
op->getOperands().front(),
flatMemref, ValueRange{offset});
newStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newStore);
} else if constexpr (std::is_same_v<T, vector::StoreOp>) {
auto newStore = vector::StoreOp::create(rewriter, op->getLoc(),
op->getOperands().front(),
flatMemref, ValueRange{offset});
newStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newStore);
} else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
auto newTransferRead = vector::TransferReadOp::create(
rewriter, op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
op.getPadding());
rewriter.replaceOp(op, newTransferRead.getResult());
} else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
auto newTransferWrite = vector::TransferWriteOp::create(
rewriter, op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
rewriter.replaceOp(op, newTransferWrite);
} else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
auto newMaskedLoad = vector::MaskedLoadOp::create(
rewriter, op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
op.getMask(), op.getPassThru());
newMaskedLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newMaskedLoad.getResult());
} else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
auto newMaskedStore = vector::MaskedStoreOp::create(
rewriter, op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
op.getValueToStore());
newMaskedStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newMaskedStore);
} else {
op.emitOpError("unimplemented: do not know how to replace op.");
}
}
template <typename T>
struct MemRefRewritePatternBase : public OpRewritePattern<T> {
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter) const override {
Value memref = getTargetMemref<T>(op);
if (!needFlattening(memref) || !checkLayout(memref)) {
return rewriter.notifyMatchFailure(op,
"nothing to do or unsupported layout");
}
auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
rewriter, op->getLoc(), memref, op.getIndices());
replaceOp<T>(op, rewriter, flatMemref, offset);
return success();
}
};
struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> {
using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase;
};
struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> {
using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase;
};
struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> {
using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase;
};
struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> {
using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase;
};
struct FlattenVectorMaskedLoad
: public MemRefRewritePatternBase<vector::MaskedLoadOp> {
using MemRefRewritePatternBase<
vector::MaskedLoadOp>::MemRefRewritePatternBase;
};
struct FlattenVectorMaskedStore
: public MemRefRewritePatternBase<vector::MaskedStoreOp> {
using MemRefRewritePatternBase<
vector::MaskedStoreOp>::MemRefRewritePatternBase;
};
struct FlattenVectorTransferRead
: public MemRefRewritePatternBase<vector::TransferReadOp> {
using MemRefRewritePatternBase<
vector::TransferReadOp>::MemRefRewritePatternBase;
};
struct FlattenVectorTransferWrite
: public MemRefRewritePatternBase<vector::TransferWriteOp> {
using MemRefRewritePatternBase<
vector::TransferWriteOp>::MemRefRewritePatternBase;
};
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
using Base::Base;
LogicalResult matchAndRewrite(memref::SubViewOp op,
PatternRewriter &rewriter) const override {
Value memref = op.getSource();
if (!needFlattening(memref)) {
return rewriter.notifyMatchFailure(op, "nothing to do");
}
if (!checkLayout(memref)) {
return rewriter.notifyMatchFailure(op, "unsupported layout");
}
Location loc = op.getLoc();
SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
auto &&[base, finalOffset, strides, _, __] =
getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
auto srcType = cast<MemRefType>(memref.getType());
auto resultType = cast<MemRefType>(op.getType());
unsigned subRank = static_cast<unsigned>(resultType.getRank());
llvm::SmallBitVector droppedDims = op.getDroppedDims();
SmallVector<OpFoldResult> finalSizes;
finalSizes.reserve(subRank);
SmallVector<OpFoldResult> finalStrides;
finalStrides.reserve(subRank);
for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
if (droppedDims.test(i)) {
continue;
}
finalSizes.push_back(subSizes[i]);
finalStrides.push_back(strides[i]);
}
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
op, resultType, base, finalOffset, finalSizes, finalStrides);
return success();
}
};
struct DecomposeMemrefsPass
: public impl::DecomposeMemrefsPassBase<DecomposeMemrefsPass> {
using Base::Base;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, arith::ArithDialect,
memref::MemRefDialect, vector::VectorDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
mlir::iree_compiler::populateDecomposeMemrefsPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
namespace mlir::iree_compiler {
void populateDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
FlattenVectorLoad, FlattenVectorStore,
FlattenVectorTransferRead, FlattenVectorTransferWrite>(
patterns.getContext());
}
std::unique_ptr<Pass> createDecomposeMemrefsPass() {
return std::make_unique<DecomposeMemrefsPass>();
}
} // namespace mlir::iree_compiler