blob: d4b66fb64e08e7d9ba0667407a42795596f16498 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#define DEBUG_TYPE "iree-codegen-expand-strided-metadata"
#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler {
namespace {
/// Carries the offset, sizes and strides that describe the `source` of a
/// `memref.extract_strided_metadata` op. Each entry is an `OpFoldResult`, so
/// it may hold either an attribute or an SSA value.
struct DescriptorInfo {
  /// Offset of the described memref.
  OpFoldResult offset;
  /// Sizes of the described memref, one entry per dimension.
  SmallVector<OpFoldResult> sizes;
  /// Strides of the described memref, one entry per dimension.
  SmallVector<OpFoldResult> strides;
};
} // namespace
/// Returns the zero-dimension, two-symbol affine map
/// `()[s0, s1] -> (s0 + s1)`.
static AffineMap getAddMap(MLIRContext *context) {
  AffineExpr lhs;
  AffineExpr rhs;
  bindSymbols(context, lhs, rhs);
  AffineExpr sum = lhs + rhs;
  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2, sum);
}
/// Returns the zero-dimension, two-symbol affine map
/// `()[s0, s1] -> (s0 * s1)`.
static AffineMap getMulMap(MLIRContext *context) {
  AffineExpr lhs;
  AffineExpr rhs;
  bindSymbols(context, lhs, rhs);
  AffineExpr product = lhs * rhs;
  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2, product);
}
/// Derives the strides of a `memref` from its `sizes`, assuming the default
/// layout (i.e. the memref is not the result of a subview).
///
/// The innermost stride is the constant 1; every other stride is the product
/// of the stride and the size of the next-inner dimension, folded or
/// materialized through `affine::makeComposedFoldedAffineApply` at `loc`.
/// An empty `sizes` yields an empty result.
static SmallVector<OpFoldResult>
getStridesFromSizes(RewriterBase &rewriter, Location loc,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> strides(sizes.size());
  if (strides.empty()) {
    return strides;
  }

  const int innermost = static_cast<int>(sizes.size()) - 1;
  strides[innermost] = rewriter.getIndexAttr(1);

  if (innermost > 0) {
    AffineMap mulMap = getMulMap(rewriter.getContext());
    // Walk outwards: strides[inner - 1] = strides[inner] * sizes[inner].
    for (int inner = innermost; inner > 0; --inner) {
      strides[inner - 1] = affine::makeComposedFoldedAffineApply(
          rewriter, loc, mulMap, {strides[inner], sizes[inner]});
    }
  }
  return strides;
}
/// Builds the `DescriptorInfo` (sizes, strides, offset) for the memref
/// produced by a `hal.interface.binding.subspan` op.
///
/// - Sizes: static dimensions become index attributes; dynamic dimensions are
///   taken, in order, from the dynamic dims of `binding`.
/// - Strides: computed from the sizes by `getStridesFromSizes`.
/// - Offset: the byte offset of `binding` converted with
///   `convertByteOffsetToElementOffset` using the memref element type.
///
/// Any ops needed are created with `rewriter` at `loc`.
static FailureOr<DescriptorInfo> resolveBufferDescriptorForInterfaceBinding(
    IREE::HAL::InterfaceBindingSubspanOp binding, RewriterBase &rewriter,
    Location loc) {
  auto boundType = llvm::cast<MemRefType>(binding.getResult().getType());
  DescriptorInfo descriptor;

  // Sizes: consume one dynamic-dim operand per dynamic dimension.
  auto nextDynamicDim = binding.getDynamicDims().begin();
  const int numDims = boundType.getRank();
  for (int dim = 0; dim != numDims; ++dim) {
    if (!boundType.isDynamicDim(dim)) {
      descriptor.sizes.push_back(
          rewriter.getIndexAttr(boundType.getDimSize(dim)));
      continue;
    }
    descriptor.sizes.push_back(*nextDynamicDim);
    ++nextDynamicDim;
  }

  // Strides follow from the sizes.
  descriptor.strides = getStridesFromSizes(rewriter, loc, descriptor.sizes);

  // Offset: the binding carries a byte offset; the descriptor stores the
  // converted value.
  descriptor.offset = convertByteOffsetToElementOffset(
      rewriter, loc, binding.getByteOffset(), boundType.getElementType());

  return descriptor;
}
/// Replaces all uses of the offset, sizes and strides results of `op` with
/// the values provided by `resultDescriptor`.
///
/// `OpTy` must expose `getOffset()`, `getSizes()`, `getStrides()` and
/// `getLoc()`. Entries of `resultDescriptor` that are attributes are
/// materialized as constant index ops at the location of `op`. The number of
/// sizes and strides in `resultDescriptor` must match those of `op`
/// (checked by assertions).
template <typename OpTy>
static void
replaceOffsetSizesAndStridesWith(RewriterBase &rewriter, OpTy op,
                                 const DescriptorInfo &resultDescriptor) {
  // Kept unsigned so that the comparisons against `size()` below do not mix
  // signed and unsigned operands.
  const size_t rank = resultDescriptor.sizes.size();
  assert(rank == resultDescriptor.strides.size() &&
         "expected number of sizes and strides to match");
  assert(op.getSizes().size() == rank &&
         "expected as many size replacements as the number of sizes in the "
         "original operation");
  assert(op.getStrides().size() == rank &&
         "expected as many strides replacements as the number of strides in "
         "the original operation");
  Location loc = op.getLoc();
  for (size_t i = 0; i < rank; ++i) {
    // Sizes
    Value newSize = getValueOrCreateConstantIndexOp(rewriter, loc,
                                                    resultDescriptor.sizes[i]);
    rewriter.replaceAllUsesWith(op.getSizes()[i], newSize);
    // Strides
    Value newStride = getValueOrCreateConstantIndexOp(
        rewriter, loc, resultDescriptor.strides[i]);
    rewriter.replaceAllUsesWith(op.getStrides()[i], newStride);
  }
  // Offset
  Value newOffset =
      getValueOrCreateConstantIndexOp(rewriter, loc, resultDescriptor.offset);
  rewriter.replaceAllUsesWith(op.getOffset(), newOffset);
}
namespace {
/// Resolves a `memref.extract_strided_metadata` whose source is defined by a
/// `hal.interface.binding.subspan` op.
///
/// The offset, sizes and strides are taken from
/// `resolveBufferDescriptorForInterfaceBinding`. The base buffer, when it has
/// uses, is replaced by a new zero-offset 1-D subspan (reinterpret-cast to
/// the base buffer type of `op` if the two types differ). Bindings whose
/// memref has rank 0 are not handled by this pattern.
struct ResolveExtractMetadataFromHalInterfaceBindingSubspan
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto binding =
        op.getSource()
            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
    if (!binding)
      return failure();
    auto memRefType = llvm::cast<MemRefType>(binding.getResult().getType());
    // Rank-0 memrefs are left untouched by this pattern.
    if (memRefType.getRank() < 1)
      return failure();

    auto loc = op.getLoc();
    // All new ops are created right before the original binding; the guard
    // restores the rewriter's insertion point when this function returns.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(binding);
    FailureOr<DescriptorInfo> resultDescriptor =
        resolveBufferDescriptorForInterfaceBinding(binding, rewriter, loc);
    if (failed(resultDescriptor)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve descriptor with source being binding op");
    }

    // For the base buffer of the `hal.interface.binding.subspan` create a 1D
    // buffer with zero offset. For example, if the
    // `hal.interface.binding.subspan` is
    //
    // ```mlir
    //  hal.interface.binding.subspan layout(#pipeline_layout) set(0)
    //  binding(1) offset(%offset)
    //      : memref<?x?xf32, strided<[?, 1], offset: 64>>{%s0, %s1}
    // ```
    //
    // convert it to
    //
    // ```mlir
    //  #map = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>
    //  %linearSize = affine.apply #map()[%offset, %s0, %s1]
    //  %c0 = arith.constant 0 : index
    //  hal.interface.binding.subspan layout(#pipeline_layout) set(0)
    //  binding(1) offset(%c0)
    //      : memref<?xf32>{%linearSize}
    // ```
    //
    // Only the base pointer of this subspan is needed, so creating a
    // subspan with zero offset (with the original offset folded into the
    // size) is a realistic representation of what the IR needs.

    // linearizedMemrefSize = (product of all sizes) + offset.
    AffineMap mulMap = getMulMap(rewriter.getContext());
    OpFoldResult linearizedMemrefSize = rewriter.getIndexAttr(1);
    for (auto size : resultDescriptor->sizes) {
      linearizedMemrefSize = affine::makeComposedFoldedAffineApply(
          rewriter, loc, mulMap, {linearizedMemrefSize, size});
    }
    AffineMap addMap = getAddMap(rewriter.getContext());
    linearizedMemrefSize = affine::makeComposedFoldedAffineApply(
        rewriter, loc, addMap,
        {linearizedMemrefSize, resultDescriptor->offset});

    // Split the linear size into its static shape entry and, if it is not a
    // constant, the dynamic value passed to the new subspan.
    SmallVector<int64_t> staticLinearShape;
    SmallVector<Value> dynamicLinearShape;
    dispatchIndexOpFoldResult(linearizedMemrefSize, dynamicLinearShape,
                              staticLinearShape);

    auto newBufferType = MemRefType::get(
        staticLinearShape, memRefType.getElementType(),
        MemRefLayoutAttrInterface(), memRefType.getMemorySpace());

    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto linearInterfaceBinding =
        rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
            loc, newBufferType, binding.getLayoutAttr(), binding.getSetAttr(),
            binding.getBindingAttr(), zero, dynamicLinearShape,
            binding.getAlignmentAttr(), binding.getDescriptorFlagsAttr());

    // Replacement values, in result order: base buffer, offset, `rank` sizes
    // and `rank` strides.
    SmallVector<Value> results;
    results.reserve(2 * memRefType.getRank() + 2);

    auto baseBufferType = llvm::cast<MemRefType>(op.getBaseBuffer().getType());
    if (!op.getBaseBuffer().use_empty()) {
      if (newBufferType == baseBufferType) {
        results.push_back(linearInterfaceBinding);
      } else {
        // Cast the linear subspan to the base buffer type expected by users
        // of `op`.
        Value reinterpretCast = rewriter.create<memref::ReinterpretCastOp>(
            loc, baseBufferType, linearInterfaceBinding, /*offset=*/0,
            /*sizes=*/ArrayRef<int64_t>(),
            /*strides=*/ArrayRef<int64_t>());
        results.push_back(reinterpretCast);
      }
    } else {
      // The base buffer has no uses; a null placeholder keeps the remaining
      // replacement values in their result positions.
      results.push_back(nullptr);
    }

    results.push_back(getValueOrCreateConstantIndexOp(
        rewriter, loc, resultDescriptor->offset));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   resultDescriptor->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   resultDescriptor->strides));

    rewriter.replaceOp(op, results);
    return success();
  }
};
/// Rewrites every `IREE::Codegen::ExtractStridedMetadataOp` into a
/// `memref::ExtractStridedMetadataOp` on the same source. The pattern always
/// succeeds.
struct ConvertCodegenIREEExtractMetadataToMemRef
    : public OpRewritePattern<IREE::Codegen::ExtractStridedMetadataOp> {
  using OpRewritePattern<
      IREE::Codegen::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IREE::Codegen::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto source = op.getSource();
    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(op, source);
    return success();
  }
};
/// Pass that expands/resolves strided metadata ops; the work is done in
/// `runOnOperation`, which is defined out of line.
struct IREEExpandStridedMetadataPass
    : public IREEExpandStridedMetadataBase<IREEExpandStridedMetadataPass> {
  /// Registers the dialects this pass declares a dependency on.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect>();
    registry.insert<arith::ArithDialect>();
    registry.insert<IREE::Codegen::IREECodegenDialect>();
    registry.insert<memref::MemRefDialect>();
  }

  void runOnOperation() override;
};
} // namespace
/// Adds to `patterns` the upstream memref patterns that resolve
/// `extract_strided_metadata`, followed by the two IREE-specific patterns of
/// this file: resolution through `hal.interface.binding.subspan`, and
/// conversion of the IREE codegen extract-metadata op to the memref one.
void populateIREEResolveExtractStridedMetadataPatterns(
    MLIRContext *context, RewritePatternSet &patterns) {
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
  patterns.insert<ResolveExtractMetadataFromHalInterfaceBindingSubspan,
                  ConvertCodegenIREEExtractMetadataToMemRef>(context);
}
void IREEExpandStridedMetadataPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
populateIREEResolveExtractStridedMetadataPatterns(context, patterns);
populateRemoveDeadMemAllocPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
if (!allowUnresolved) {
SmallVector<Operation *> remaining;
getOperation()->walk([&](Operation *op) {
if (isa<memref::ExtractStridedMetadataOp>(op)) {
remaining.push_back(op);
}
});
if (!remaining.empty()) {
auto diag = getOperation()->emitError()
<< "Unable to resolve all strided buffer descriptors:";
for (auto *op : remaining) {
diag.attachNote(op->getLoc()) << "remaining live use";
}
signalPassFailure();
}
}
}
std::unique_ptr<Pass> createIREEExpandStridedMetadataPass() {
return std::make_unique<IREEExpandStridedMetadataPass>();
}
} // namespace mlir::iree_compiler