// blob: 2a27ad68ddb7acfbd228b896190705c7c272d81c [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/VM/Conversion/TargetOptions.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Conversion pattern for operations that become no-ops at this level.
///
/// The matched op is erased and each of its results is replaced by the
/// corresponding already-converted operand handed to the pattern.
template <typename OpTy>
struct FoldAsNoOp final : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Forward the converted operands unchanged as the replacement values.
    ValueRange replacements(operands);
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
/// Returns true only when `type` is a MemRefType that has a rank and that rank
/// is 0 or 1. Any other type yields false.
static bool isRankZeroOrOneMemRef(Type type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType) {
    // Not a memref at all.
    return false;
  }
  return memrefType.hasRank() && memrefType.getRank() <= 1;
}
// Returns how many whole bytes one element of `type` occupies after
// conversion, i.e. its integer/float bit width rounded up to a multiple of
// 8 bits. For example, i1 occupies 1 byte.
static int32_t getRoundedElementByteWidth(Type type) {
  constexpr unsigned kBitsPerByte = 8;
  const unsigned bitWidth = type.getIntOrFloatBitWidth();
  // Ceiling division: (bits + 7) / 8.
  return (bitWidth + kBitsPerByte - 1) / kBitsPerByte;
}
// Computes the byte offset of an element within a linearized dense buffer.
//
// |memrefValue| must already be linearized (rank 0 or rank 1):
//   * rank 0: the only valid offset is 0, returned as a constant built with
//     arith::ConstantIndexOp (this path does not go through the cast to
//     |indiceType| used below).
//   * rank 1: the first entry of |indices| is multiplied by the rounded
//     element byte width and the result is cast to |indiceType|.
static Value getBufferOffset(Location loc, Value memrefValue,
                             ValueRange indices, Type indiceType,
                             ConversionPatternRewriter &rewriter) {
  auto memrefType = memrefValue.getType().cast<ShapedType>();
  const bool isScalarBuffer = memrefType.getRank() == 0;
  if (isScalarBuffer) {
    // Rank 0 buffers (like memref<i32>) hold a single element at offset 0.
    return rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  }
  assert(memrefType.getRank() == 1 && "memrefs should have been flattened");

  // Build the affine expression `s0 * <element byte width>`.
  MLIRContext *ctx = rewriter.getContext();
  auto elementType = memrefType.getElementType();
  auto indexSymbol = getAffineSymbolExpr(0, ctx);
  auto byteWidthExpr =
      getAffineConstantExpr(getRoundedElementByteWidth(elementType), ctx);
  auto scalingExpr =
      getAffineBinaryOpExpr(AffineExprKind::Mul, indexSymbol, byteWidthExpr);

  // Apply the expression to the single index of the rank-1 access, then cast
  // the result to the requested offset type.
  Value scaledIndex = rewriter.createOrFold<AffineApplyOp>(
      loc, scalingExpr, ArrayRef<Value>{indices.front()});
  return rewriter.create<arith::IndexCastOp>(loc, scaledIndex, indiceType);
}
/// Lowers a constant `memref::GlobalOp` to a private `IREE::VM::RodataOp`
/// that carries the same symbol name and initial value.
///
/// The match fails, leaving the op untouched, when:
///   * the global's type is not a rank-0 or rank-1 memref,
///   * the global is not marked constant (mutable globals are unsupported), or
///   * the global has no ElementsAttr initial value to store as rodata.
class ConvertMemRefGlobalOp : public OpConversionPattern<memref::GlobalOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      memref::GlobalOp globalOp, ArrayRef<Value> rawOperands,
      ConversionPatternRewriter &rewriter) const override {
    if (!isRankZeroOrOneMemRef(globalOp.type())) {
      return rewriter.notifyMatchFailure(
          globalOp,
          "only rank-0 and rank-1 memrefs are supported; flatten first");
    }

    // For mutable values we'd want to either have a RwdataOp or a global
    // !vm.buffer that we initialized with rodata.
    if (!globalOp.constant()) {
      return rewriter.notifyMatchFailure(
          globalOp, "mutable global memrefs not yet implemented");
    }

    // The initial value attribute is optional and is not guaranteed to be an
    // ElementsAttr (see the memref.global op documentation), so check instead
    // of using an asserting cast.
    auto initialValue =
        globalOp.initial_valueAttr().dyn_cast_or_null<ElementsAttr>();
    if (!initialValue) {
      return rewriter.notifyMatchFailure(
          globalOp,
          "constant global memrefs without an elements initial value are not "
          "supported");
    }

    auto rodataOp = rewriter.replaceOpWithNewOp<IREE::VM::RodataOp>(
        globalOp, globalOp.sym_name(), initialValue);
    rodataOp.setPrivate();
    return success();
  }
};
/// Lowers `memref::GetGlobalOp` to an `IREE::VM::ConstRefRodataOp` that
/// references the global by the same name.
///
/// The match fails when the result type is not a rank-0 or rank-1 memref.
class ConvertMemRefGetGlobalOp
    : public OpConversionPattern<memref::GetGlobalOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      memref::GetGlobalOp getOp, ArrayRef<Value> rawOperands,
      ConversionPatternRewriter &rewriter) const override {
    // The converted operands are not consulted here; only the op's own result
    // type and referenced name are needed.
    if (!isRankZeroOrOneMemRef(getOp.result().getType())) {
      return rewriter.notifyMatchFailure(
          getOp, "only rank-0 and rank-1 memrefs are supported; flatten first");
    }
    rewriter.replaceOpWithNewOp<IREE::VM::ConstRefRodataOp>(getOp,
                                                            getOp.name());
    return success();
  }
};
/// Lowers `memref::LoadOp` on rank-0/rank-1 memrefs to the matching
/// `IREE::VM::BufferLoad*Op`, addressed by a byte offset computed with
/// getBufferOffset.
///
/// Op selection by loaded type:
///   * i1 / i8:  BufferLoadI8SOp for signed or signless integers,
///               BufferLoadI8UOp otherwise.
///   * i16:      BufferLoadI16SOp for signed or signless integers,
///               BufferLoadI16UOp otherwise.
///   * i32/i64:  BufferLoadI32Op / BufferLoadI64Op.
///   * f32/f64:  BufferLoadF32Op / BufferLoadF64Op.
///
/// The match fails for memrefs of rank > 1, for element types that are not
/// integers or floats, for types the type converter cannot convert, and for
/// integer/float widths not listed above.
class ConvertMemRefLoadOp : public OpConversionPattern<memref::LoadOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      memref::LoadOp loadOp, ArrayRef<Value> rawOperands,
      ConversionPatternRewriter &rewriter) const override {
    memref::LoadOpAdaptor operands(rawOperands);
    if (!isRankZeroOrOneMemRef(loadOp.memref().getType())) {
      return rewriter.notifyMatchFailure(
          loadOp,
          "only rank-0 and rank-1 memrefs are supported; flatten first");
    }

    auto oldType = loadOp.result().getType();
    // getBufferOffset scales the index by the element bit width, which is only
    // defined for integer and float types. Reject anything else up front so
    // unsupported element types fail the match instead of tripping the
    // assertion inside getIntOrFloatBitWidth.
    if (!oldType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(
          loadOp,
          "only integer and floating-point buffer element types are "
          "supported");
    }
    auto newType = getTypeConverter()->convertType(oldType);
    if (!newType) {
      return rewriter.notifyMatchFailure(
          loadOp, "unable to convert the loaded element type");
    }

    auto byteOffset = getBufferOffset(
        loadOp.getLoc(), loadOp.memref(), loadOp.indices(),
        getTypeConverter()->convertType(rewriter.getIndexType()), rewriter);
    Value buffer = operands.memref();

    if (auto integerType = oldType.dyn_cast<IntegerType>()) {
      // Signed and signless integers use the `S` load variants; explicitly
      // unsigned integers use the `U` variants.
      const bool useSignedLoad =
          integerType.isSigned() || integerType.isSignless();
      if (integerType.isInteger(1) || integerType.isInteger(8)) {
        if (useSignedLoad) {
          rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI8SOp>(
              loadOp, newType, buffer, byteOffset);
        } else {
          rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI8UOp>(
              loadOp, newType, buffer, byteOffset);
        }
      } else if (integerType.isInteger(16)) {
        if (useSignedLoad) {
          rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI16SOp>(
              loadOp, newType, buffer, byteOffset);
        } else {
          rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI16UOp>(
              loadOp, newType, buffer, byteOffset);
        }
      } else if (integerType.isInteger(32)) {
        rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI32Op>(
            loadOp, newType, buffer, byteOffset);
      } else if (integerType.isInteger(64)) {
        rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI64Op>(
            loadOp, newType, buffer, byteOffset);
      } else {
        return rewriter.notifyMatchFailure(
            loadOp, "invalid integer buffer element type");
      }
    } else if (oldType.isF32()) {
      rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadF32Op>(
          loadOp, newType, buffer, byteOffset);
    } else if (oldType.isF64()) {
      rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadF64Op>(
          loadOp, newType, buffer, byteOffset);
    } else {
      // Only float types other than f32/f64 can reach this point because of
      // the isIntOrFloat() guard above.
      return rewriter.notifyMatchFailure(loadOp,
                                         "invalid float buffer element type");
    }
    return success();
  }
};
/// Lowers `memref::StoreOp` on rank-0/rank-1 memrefs to the matching
/// `IREE::VM::BufferStore*Op`, addressed by a byte offset computed with
/// getBufferOffset.
///
/// Op selection by stored value type: i1/i8 -> BufferStoreI8Op,
/// i16 -> BufferStoreI16Op, i32 -> BufferStoreI32Op, i64 -> BufferStoreI64Op,
/// f32 -> BufferStoreF32Op, f64 -> BufferStoreF64Op.
///
/// The match fails for memrefs of rank > 1, for value types that are not
/// integers or floats, and for integer/float widths not listed above.
class ConvertMemRefStoreOp : public OpConversionPattern<memref::StoreOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      memref::StoreOp storeOp, ArrayRef<Value> rawOperands,
      ConversionPatternRewriter &rewriter) const override {
    memref::StoreOpAdaptor operands(rawOperands);
    if (!isRankZeroOrOneMemRef(storeOp.memref().getType())) {
      return rewriter.notifyMatchFailure(
          storeOp,
          "only rank-0 and rank-1 memrefs are supported; flatten first");
    }

    auto oldType = storeOp.value().getType();
    // getBufferOffset scales the index by the element bit width, which is only
    // defined for integer and float types. Reject anything else up front so
    // unsupported element types fail the match instead of tripping the
    // assertion inside getIntOrFloatBitWidth.
    if (!oldType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(
          storeOp,
          "only integer and floating-point buffer element types are "
          "supported");
    }

    auto byteOffset = getBufferOffset(
        storeOp.getLoc(), storeOp.memref(), storeOp.indices(),
        getTypeConverter()->convertType(rewriter.getIndexType()), rewriter);
    Value buffer = operands.memref();
    Value storedValue = operands.value();

    if (oldType.isInteger(1) || oldType.isInteger(8)) {
      rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI8Op>(
          storeOp, buffer, byteOffset, storedValue);
    } else if (oldType.isInteger(16)) {
      rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI16Op>(
          storeOp, buffer, byteOffset, storedValue);
    } else if (oldType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI32Op>(
          storeOp, buffer, byteOffset, storedValue);
    } else if (oldType.isInteger(64)) {
      rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI64Op>(
          storeOp, buffer, byteOffset, storedValue);
    } else if (oldType.isF32()) {
      rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreF32Op>(
          storeOp, buffer, byteOffset, storedValue);
    } else if (oldType.isF64()) {
      rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreF64Op>(
          storeOp, buffer, byteOffset, storedValue);
    } else {
      return rewriter.notifyMatchFailure(storeOp,
                                         "invalid buffer element type");
    }
    return success();
  }
};
} // namespace
/// Registers everything needed to lower the MemRef dialect to the VM dialect:
///   * marks the whole MemRef dialect illegal on |conversionTarget|,
///   * adds a type conversion mapping rank-0/rank-1 memrefs to a VM ref of
///     `IREE::VM::BufferType` (other memrefs are left unconverted), and
///   * inserts the conversion patterns defined in this file into |patterns|.
void populateMemRefToVMPatterns(MLIRContext *context,
                                ConversionTarget &conversionTarget,
                                TypeConverter &typeConverter,
                                OwningRewritePatternList &patterns) {
  conversionTarget.addIllegalDialect<memref::MemRefDialect>();

  // The callback is stored inside |typeConverter| and so outlives this
  // function. It needs no local state, so it captures nothing; a by-reference
  // default capture here would invite dangling references later.
  typeConverter.addConversion([](MemRefType type) -> llvm::Optional<Type> {
    if (!isRankZeroOrOneMemRef(type)) {
      return llvm::None;
    }
    return IREE::VM::RefType::get(
        IREE::VM::BufferType::get(type.getContext()));
  });

  patterns.insert<FoldAsNoOp<memref::BufferCastOp>>(typeConverter, context);
  patterns.insert<ConvertMemRefGlobalOp, ConvertMemRefGetGlobalOp,
                  ConvertMemRefLoadOp, ConvertMemRefStoreOp>(typeConverter,
                                                             context);
}
} // namespace iree_compiler
} // namespace mlir