blob: 0be2c806c15f3c1b2534d381e0d92027242a5e37 [file] [log] [blame] [edit]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/VMVX/IR/VMVXDialect.h"
#include "iree/compiler/Dialect/VMVX/IR/VMVXOps.h"
#include "iree/compiler/Dialect/VMVX/Transforms/Passes.h"
#include "iree/compiler/Utils/IntegerSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-vmvx-resolve-buffer-descriptor"
namespace mlir::iree_compiler::IREE::VMVX {
#define GEN_PASS_DEF_RESOLVEBUFFERDESCRIPTORSPASS
#include "iree/compiler/Dialect/VMVX/Transforms/Passes.h.inc"
namespace {
/// Helper struct to return the offset, sizes and strides
/// of a `source` of a `memref.extract_strided_metadata` op.
struct DescriptorInfo {
OpFoldResult offset;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
};
} // namespace
/// Returns an AffineMap for an add or a mul.
static AffineMap getAddMap(MLIRContext *context) {
AffineExpr s0, s1;
bindSymbols(context, s0, s1);
return AffineMap::get(0, 2, s0 + s1);
}
static AffineMap getMulMap(MLIRContext *context) {
AffineExpr s0, s1;
bindSymbols(context, s0, s1);
return AffineMap::get(0, 2, s0 * s1);
}
static FailureOr<DescriptorInfo> resolveBufferDescriptorForSubview(
memref::SubViewOp subview, RewriterBase &rewriter, Location loc,
Value sourceOffset, ValueRange sourceSizes, ValueRange sourceStrides) {
DescriptorInfo resultDescriptor;
// For sizes, we just use the new ones.
resultDescriptor.sizes = subview.getMixedSizes();
// Apply stride multipliers.
AffineMap mulMap = getMulMap(rewriter.getContext());
for (auto [index, stride] : llvm::enumerate(subview.getMixedStrides())) {
OpFoldResult currentStride = affine::makeComposedFoldedAffineApply(
rewriter, loc, mulMap, {sourceStrides[index], stride});
resultDescriptor.strides.push_back(currentStride);
}
// Offsets.
resultDescriptor.offset = sourceOffset;
AffineMap addMap = getAddMap(rewriter.getContext());
for (auto [index, offset] : llvm::enumerate(subview.getMixedOffsets())) {
OpFoldResult physicalOffset = affine::makeComposedFoldedAffineApply(
rewriter, loc, mulMap, {offset, resultDescriptor.strides[index]});
resultDescriptor.offset = affine::makeComposedFoldedAffineApply(
rewriter, loc, addMap, {resultDescriptor.offset, physicalOffset});
}
return resultDescriptor;
}
/// Returns the strides based on the sizes assuming that the `memref`
/// has default layout, i.e. it is not a result of a subview.
static SmallVector<OpFoldResult>
getStridesFromSizes(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> sizes) {
if (sizes.size() == 0) {
return {};
}
SmallVector<OpFoldResult> strides(sizes.size());
strides.back() = rewriter.getIndexAttr(1);
if (sizes.size() == 1) {
return strides;
}
AffineMap mulMap = getMulMap(rewriter.getContext());
for (int i = sizes.size() - 2; i >= 0; --i) {
strides[i] = affine::makeComposedFoldedAffineApply(
rewriter, loc, mulMap, {strides[i + 1], sizes[i + 1]});
}
return strides;
}
static FailureOr<DescriptorInfo> resolveBufferDescriptorForInterfaceBinding(
IREE::HAL::InterfaceBindingSubspanOp binding, RewriterBase &rewriter,
Location loc) {
auto memRefType = cast<MemRefType>(binding.getResult().getType());
int rank = memRefType.getRank();
DescriptorInfo resultDescriptor;
// Compute sizes.
auto dynamicDimIt = binding.getDynamicDims().begin();
for (int i = 0; i < rank; ++i) {
if (memRefType.isDynamicDim(i)) {
resultDescriptor.sizes.push_back(*dynamicDimIt);
dynamicDimIt++;
} else {
resultDescriptor.sizes.push_back(
rewriter.getIndexAttr(memRefType.getDimSize(i)));
}
}
// Strides.
resultDescriptor.strides =
getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
// Offset.
resultDescriptor.offset = convertByteOffsetToElementOffset(
rewriter, loc, binding.getByteOffset(), memRefType.getElementType());
return resultDescriptor;
}
static FailureOr<DescriptorInfo>
resolveBufferDescriptorForAllocation(memref::AllocaOp alloca,
RewriterBase &rewriter, Location loc) {
DescriptorInfo resultDescriptor;
// Replace the op with values:
// base_buffer: The subspan result
// offset: byte offset from subspan divided by element type size
// sizes: static and dynamic sizes from the subspan
// strides: identity strides
auto memRefType = cast<MemRefType>(alloca.getResult().getType());
int rank = memRefType.getRank();
// Compute sizes.
auto dynamicDimIt = alloca.getDynamicSizes().begin();
for (int i = 0; i < rank; ++i) {
if (memRefType.isDynamicDim(i)) {
resultDescriptor.sizes.push_back(*dynamicDimIt);
dynamicDimIt++;
} else {
resultDescriptor.sizes.push_back(
rewriter.getIndexAttr(memRefType.getDimSize(i)));
}
}
// Strides (just creates identity strides).
resultDescriptor.strides =
getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
resultDescriptor.offset = rewriter.getIndexAttr(0);
return resultDescriptor;
}
static FailureOr<DescriptorInfo>
resolveBufferDescriptorForGetGlobalOp(memref::GetGlobalOp global,
RewriterBase &rewriter, Location loc) {
IndexSet indexSet(loc, rewriter);
DescriptorInfo resultDescriptor;
// Replace the op with values:
// base_buffer: The subspan result
// offset: byte offset from subspan divided by element type size
// sizes: static and dynamic sizes from the subspan
// strides: identity strides
auto memRefType = cast<MemRefType>(global.getResult().getType());
int rank = memRefType.getRank();
// Compute sizes.
for (int i = 0; i < rank; ++i) {
if (memRefType.isDynamicDim(i)) {
return rewriter.notifyMatchFailure(
global, "memref.get_global does not support dynamic dims");
}
resultDescriptor.sizes.push_back(
rewriter.getIndexAttr(memRefType.getDimSize(i)));
}
// Strides (just creates identity strides).
resultDescriptor.strides =
getStridesFromSizes(rewriter, loc, resultDescriptor.sizes);
// Offset.
resultDescriptor.offset = rewriter.getIndexAttr(0);
return resultDescriptor;
}
/// Replaces the offsets, sizes and strides based on values provided
/// by `DescriptorInfo` object.
static void
replaceOffsetSizesAndStridesWith(RewriterBase &rewriter,
GetBufferDescriptorOp op,
const DescriptorInfo &resultDescriptor) {
int rank = resultDescriptor.sizes.size();
assert(rank == resultDescriptor.strides.size() &&
"expected number of sizes and strides to match");
assert(op.getSizes().size() == rank &&
"expected as many size replacements as the number of sizes in the "
"original operation");
assert(op.getStrides().size() == rank &&
"expected as many strides replacements as the number of strides in "
"the original operation");
Location loc = op.getLoc();
for (int i = 0; i < rank; ++i) {
// Sizes
rewriter.replaceAllUsesWith(op.getSizes()[i],
getValueOrCreateConstantIndexOp(
rewriter, loc, resultDescriptor.sizes[i]));
// Strides
rewriter.replaceAllUsesWith(
op.getStrides()[i], getValueOrCreateConstantIndexOp(
rewriter, loc, resultDescriptor.strides[i]));
}
// Offset
rewriter.replaceAllUsesWith(
op.getOffset(),
getValueOrCreateConstantIndexOp(rewriter, loc, resultDescriptor.offset));
}
namespace {
struct FromMemRefSubView : public OpRewritePattern<GetBufferDescriptorOp> {
using Base::Base;
LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
PatternRewriter &rewriter) const override {
auto subview = op.getSource().template getDefiningOp<memref::SubViewOp>();
if (!subview)
return failure();
auto loc = op.getLoc();
IndexSet indexSet(loc, rewriter);
// Get types.
auto subType = cast<MemRefType>(subview.getResult().getType());
Value source = subview.getSource();
auto sourceType = cast<MemRefType>(source.getType());
int sourceRank = sourceType.getRank();
int subRank = subType.getRank();
(void)subRank;
// Create a descriptor for the source.
IndexType indexType = rewriter.getIndexType();
SmallVector<Type> sizeStrideTypes;
for (int i = 0; i < sourceRank; i++) {
sizeStrideTypes.push_back(indexType);
}
auto sourceDesc = GetBufferDescriptorOp::create(
rewriter, loc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
sizeStrideTypes, source);
FailureOr<DescriptorInfo> resultDescriptor =
resolveBufferDescriptorForSubview(
subview, rewriter, loc, sourceDesc.getOffset(),
sourceDesc.getSizes(), sourceDesc.getStrides());
if (failed(resultDescriptor)) {
return rewriter.notifyMatchFailure(
op, "failed to resolve descriptor with source being a subview op");
}
llvm::SmallBitVector droppedDims = subview.getDroppedDims();
int targetIndex = 0;
for (int i = 0; i < sourceRank; ++i) {
if (droppedDims.test(i))
continue;
rewriter.replaceAllUsesWith(
op.getSizes()[targetIndex],
getValueOrCreateConstantIndexOp(rewriter, loc,
resultDescriptor->sizes[i]));
rewriter.replaceAllUsesWith(
op.getStrides()[targetIndex],
getValueOrCreateConstantIndexOp(rewriter, loc,
resultDescriptor->strides[i]));
targetIndex++;
}
rewriter.replaceAllUsesWith(op.getOffset(),
getValueOrCreateConstantIndexOp(
rewriter, loc, resultDescriptor->offset));
// Base.
rewriter.replaceAllUsesWith(op.getBaseBuffer(), sourceDesc.getBaseBuffer());
rewriter.eraseOp(op);
return success();
}
};
struct FromHalInterfaceBindingSubspan
: public OpRewritePattern<GetBufferDescriptorOp> {
using Base::Base;
LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
PatternRewriter &rewriter) const override {
auto binding =
op.getSource()
.template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
if (!binding)
return failure();
auto loc = op.getLoc();
FailureOr<DescriptorInfo> resultDescriptor =
resolveBufferDescriptorForInterfaceBinding(binding, rewriter, loc);
if (failed(resultDescriptor)) {
return rewriter.notifyMatchFailure(
op, "failed to resolve descriptor with source being binding op");
}
replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
// Base buffer.
rewriter.replaceAllUsesWith(
op.getBaseBuffer(), IREE::VMVX::GetRawInterfaceBindingBufferOp::create(
rewriter, loc, op.getBaseBuffer().getType(),
binding.getLayout(), binding.getBindingAttr())
.getResult());
rewriter.eraseOp(op);
return success();
}
};
/// Function to handle replacement of base pointer of buffer
/// descriptors.
static Value
getBaseBufferReplacementForDescriptor(GetBufferDescriptorOp descriptorOp,
RewriterBase &rewriter, Location loc,
Value source) {
return UnrealizedConversionCastOp::create(
rewriter, loc, descriptorOp.getBaseBuffer().getType(), source)
.getResult(0);
}
struct FromMemRefAssumeAlignment
: public OpRewritePattern<GetBufferDescriptorOp> {
using Base::Base;
LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
PatternRewriter &rewriter) const override {
auto assumeOp = op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
if (!assumeOp) {
return failure();
}
auto binding = assumeOp.getMemref()
.getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
if (!binding) {
return failure();
}
Location loc = op.getLoc();
// TODO(hanchung): Refactor resolverBufferDescriptor* out, so we don't have
// to track the SSA chain above.
FailureOr<DescriptorInfo> resultDescriptor =
resolveBufferDescriptorForInterfaceBinding(binding, rewriter, loc);
if (failed(resultDescriptor)) {
return rewriter.notifyMatchFailure(
op, "failed to resolve descriptor with source being binding op");
}
replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
// Base buffer.
rewriter.replaceAllUsesWith(
op.getBaseBuffer(), IREE::VMVX::GetRawInterfaceBindingBufferOp::create(
rewriter, loc, op.getBaseBuffer().getType(),
binding.getLayout(), binding.getBindingAttr())
.getResult());
rewriter.eraseOp(op);
return success();
}
};
// Allocations always return a non-offset memref and are matched by this
// pattern.
struct FromAllocation : public OpRewritePattern<GetBufferDescriptorOp> {
using Base::Base;
LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
PatternRewriter &rewriter) const override {
auto alloca = op.getSource().template getDefiningOp<memref::AllocaOp>();
if (!alloca)
return failure();
auto memRefType = cast<MemRefType>(alloca.getResult().getType());
if (!memRefType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(op, "not identity allocation");
}
auto loc = op.getLoc();
FailureOr<DescriptorInfo> resultDescriptor =
resolveBufferDescriptorForAllocation(alloca, rewriter, loc);
if (failed(resultDescriptor)) {
return rewriter.notifyMatchFailure(
op, "failed to resolve descriptor for memref.alloca op");
}
replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
// Base buffer.
Value replacement = getBaseBufferReplacementForDescriptor(
op, rewriter, loc, alloca.getResult());
rewriter.replaceAllUsesWith(op.getBaseBuffer(), replacement);
rewriter.eraseOp(op);
return success();
}
};
// MemRef globals are always static shaped and reference a non-offset
// buffer.
struct FromGlobal : public OpRewritePattern<GetBufferDescriptorOp> {
using Base::Base;
LogicalResult matchAndRewrite(GetBufferDescriptorOp op,
PatternRewriter &rewriter) const override {
auto global = op.getSource().template getDefiningOp<memref::GetGlobalOp>();
if (!global)
return failure();
auto memRefType = cast<MemRefType>(global.getResult().getType());
if (!memRefType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(op, "not identity allocation");
}
auto loc = op.getLoc();
FailureOr<DescriptorInfo> resultDescriptor =
resolveBufferDescriptorForGetGlobalOp(global, rewriter, loc);
if (failed(resultDescriptor)) {
return rewriter.notifyMatchFailure(
op, "failed to resolve descriptor for memref.get_global source");
}
replaceOffsetSizesAndStridesWith(rewriter, op, resultDescriptor.value());
// Base buffer.
Value replacement = getBaseBufferReplacementForDescriptor(
op, rewriter, loc, global.getResult());
rewriter.replaceAllUsesWith(op.getBaseBuffer(), replacement);
rewriter.eraseOp(op);
return success();
}
};
//===---------------------------------------------------------------------===//
// Pass To resovle descriptors.
//===---------------------------------------------------------------------===//
class ResolveBufferDescriptorsPass final
: public impl::ResolveBufferDescriptorsPassBase<
ResolveBufferDescriptorsPass> {
public:
ResolveBufferDescriptorsPass() = default;
ResolveBufferDescriptorsPass(const ResolveBufferDescriptorsPass &) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, IREE::VMVX::VMVXDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.insert<FromAllocation, FromGlobal, FromHalInterfaceBindingSubspan,
FromMemRefSubView, FromMemRefAssumeAlignment>(
&getContext());
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
// If any get_buffer_descriptor patterns remain, we fail.
if (!allowUnresolved) {
SmallVector<Operation *> remaining;
getOperation()->walk([&](Operation *op) {
if (isa<GetBufferDescriptorOp>(op)) {
remaining.push_back(op);
}
});
if (!remaining.empty()) {
auto diag = getOperation()->emitError()
<< "Unable to resolve all strided buffer descriptors:";
for (auto *op : remaining) {
diag.attachNote(op->getLoc()) << "remaining live use";
}
signalPassFailure();
}
}
}
Option<bool> allowUnresolved{
*this, "allow-unresolved",
llvm::cl::desc("Allow unresolved descriptors (for testing)"),
llvm::cl::init(false)};
};
} // namespace
} // namespace mlir::iree_compiler::IREE::VMVX