blob: 1aefdfab5a8018790a74c0c48c8a633198af8ce0 [file] [log] [blame] [edit]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir::iree_compiler {
namespace {
// Produces an i32 view of |value|. i32 inputs are forwarded untouched; any
// other input is narrowed with IREE::VM::TruncI64I32Op (so non-i32 inputs are
// expected to be i64 here — presumably indices/sizes already lowered by the VM
// type converter; confirm at call sites).
static Value castToI32(Value value, OpBuilder &builder) {
  Type valueType = value.getType();
  if (!valueType.isInteger(32)) {
    return builder.createOrFold<IREE::VM::TruncI64I32Op>(
        value.getLoc(), builder.getI32Type(), value);
  }
  return value;
}
// Produces an index-typed view of |value|: index inputs are forwarded as-is,
// everything else goes through an arith::IndexCastOp (folded when possible).
static Value castToIndex(Value value, OpBuilder &builder) {
  Location loc = value.getLoc();
  Type indexType = builder.getIndexType();
  return value.getType().isIndex()
             ? value
             : builder.createOrFold<arith::IndexCastOp>(loc, indexType, value);
}
// Lowers IREE::Util::ListCreateOp to IREE::VM::ListAllocOp.
// The optional initial capacity is narrowed to i32 via castToI32; when it is
// absent a constant capacity of 0 is used.
class ListCreateOpConversion
    : public OpConversionPattern<IREE::Util::ListCreateOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Util::ListCreateOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the result type before creating any IR: a failed conversion must
    // not leave stray ops behind nor build an alloc op with a null type.
    Type listType = typeConverter->convertType(srcOp.getResult().getType());
    if (!listType) {
      return rewriter.notifyMatchFailure(srcOp, "unable to convert list type");
    }
    Value initialCapacity = adaptor.getInitialCapacity();
    if (initialCapacity) {
      initialCapacity = castToI32(initialCapacity, rewriter);
    } else {
      initialCapacity = IREE::VM::ConstI32Op::create(
          rewriter, srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
    }
    rewriter.replaceOpWithNewOp<IREE::VM::ListAllocOp>(srcOp, listType,
                                                       initialCapacity);
    return success();
  }
};
// Lowers IREE::Util::ListConstructOp to a VM list allocated with exactly one
// slot per operand, resized to that count (when non-zero), and populated with
// one typed IREE::VM::ListSet*Op per operand.
class ListConstructOpConversion
    : public OpConversionPattern<IREE::Util::ListConstructOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Util::ListConstructOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = srcOp.getLoc();
    Type listType = typeConverter->convertType(srcOp.getResult().getType());
    if (!listType) {
      return rewriter.notifyMatchFailure(srcOp, "unable to convert list type");
    }

    // Validate every element type up front so that a failed match returns
    // before any IR has been created.
    for (Value value : adaptor.getValues()) {
      Type type = value.getType();
      if (!type.isInteger(32) && !type.isInteger(64) && !type.isF32() &&
          !type.isF64() && !isa<IREE::VM::RefType>(type)) {
        return rewriter.notifyMatchFailure(srcOp, "invalid list element type");
      }
    }

    // Allocate list with exact capacity.
    size_t valueCount = srcOp.getValues().size();
    Value size = IREE::VM::ConstI32Op::create(
        rewriter, loc,
        rewriter.getI32IntegerAttr(static_cast<int32_t>(valueCount)));
    Value list = IREE::VM::ListAllocOp::create(rewriter, loc, listType, size);

    // Resize to the count (if needed) so the set ops below are in bounds.
    if (valueCount > 0) {
      IREE::VM::ListResizeOp::create(rewriter, loc, list, size);
    }

    // Add all entries; element types were validated above.
    for (auto [i, value] : llvm::enumerate(adaptor.getValues())) {
      Value index = IREE::VM::ConstI32Op::create(
          rewriter, loc, rewriter.getI32IntegerAttr(static_cast<int32_t>(i)));
      Type type = value.getType();
      if (type.isInteger(32)) {
        IREE::VM::ListSetI32Op::create(rewriter, loc, list, index, value);
      } else if (type.isInteger(64)) {
        IREE::VM::ListSetI64Op::create(rewriter, loc, list, index, value);
      } else if (type.isF32()) {
        IREE::VM::ListSetF32Op::create(rewriter, loc, list, index, value);
      } else if (type.isF64()) {
        IREE::VM::ListSetF64Op::create(rewriter, loc, list, index, value);
      } else {
        // Only IREE::VM::RefType remains after validation.
        IREE::VM::ListSetRefOp::create(rewriter, loc, list, index, value);
      }
    }
    rewriter.replaceOp(srcOp, list);
    return success();
  }
};
// Lowers IREE::Util::ListSizeOp to IREE::VM::ListSizeOp. The VM op is built
// with an i32 result, and the replacement value is cast back to index.
class ListSizeOpConversion
    : public OpConversionPattern<IREE::Util::ListSizeOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Util::ListSizeOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type i32Type = rewriter.getI32Type();
    Value sizeI32 = IREE::VM::ListSizeOp::create(rewriter, srcOp.getLoc(),
                                                 i32Type, adaptor.getList());
    Value sizeIndex = castToIndex(sizeI32, rewriter);
    rewriter.replaceOp(srcOp, sizeIndex);
    return success();
  }
};
// Lowers IREE::Util::ListResizeOp to IREE::VM::ListResizeOp, narrowing the
// requested size to i32 first.
class ListResizeOpConversion
    : public OpConversionPattern<IREE::Util::ListResizeOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Util::ListResizeOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value vmList = adaptor.getList();
    Value newSizeI32 = castToI32(adaptor.getNewSize(), rewriter);
    rewriter.replaceOpWithNewOp<IREE::VM::ListResizeOp>(srcOp, vmList,
                                                        newSizeI32);
    return success();
  }
};
// Lowers IREE::Util::ListGetOp to the typed IREE::VM::ListGet*Op matching the
// converted result type: i32/i64/f32/f64 map to the primitive getters and any
// non-int/index/float type is read with ListGetRefOp. Other int/float types
// (e.g. i8, f16) are rejected.
class ListGetOpConversion : public OpConversionPattern<IREE::Util::ListGetOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Util::ListGetOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // convertType returns a null Type on failure; querying a null Type would
    // crash, so bail out first.
    Type resultType = typeConverter->convertType(srcOp.getResult().getType());
    if (!resultType) {
      return rewriter.notifyMatchFailure(srcOp,
                                         "unable to convert result type");
    }
    // Reject unsupported primitive types before creating any IR (the index
    // cast below would otherwise be left behind by a failed match). Report via
    // notifyMatchFailure rather than a hard diagnostic so the driver can
    // decide how to surface the failure.
    const bool isRefResult = !resultType.isIntOrIndexOrFloat();
    if (!isRefResult && !resultType.isInteger(32) &&
        !resultType.isInteger(64) && !resultType.isF32() &&
        !resultType.isF64()) {
      return rewriter.notifyMatchFailure(
          srcOp, "unsupported list element type in the VM");
    }

    Value list = adaptor.getList();
    Value index = castToI32(adaptor.getIndex(), rewriter);
    if (resultType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ListGetI32Op>(srcOp, resultType,
                                                          list, index);
    } else if (resultType.isInteger(64)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ListGetI64Op>(srcOp, resultType,
                                                          list, index);
    } else if (resultType.isF32()) {
      rewriter.replaceOpWithNewOp<IREE::VM::ListGetF32Op>(srcOp, resultType,
                                                          list, index);
    } else if (resultType.isF64()) {
      rewriter.replaceOpWithNewOp<IREE::VM::ListGetF64Op>(srcOp, resultType,
                                                          list, index);
    } else {
      rewriter.replaceOpWithNewOp<IREE::VM::ListGetRefOp>(srcOp, resultType,
                                                          list, index);
    }
    return success();
  }
};
// Lowers IREE::Util::ListSetOp to the typed IREE::VM::ListSet*Op matching the
// (already converted) stored value type: i32/i64/f32/f64 map to the primitive
// setters and any non-int/index/float type is stored with ListSetRefOp. Other
// int/float types (e.g. i8, f16) are rejected.
class ListSetOpConversion : public OpConversionPattern<IREE::Util::ListSetOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Util::ListSetOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value value = adaptor.getValue();
    Type valueType = value.getType();
    // Reject unsupported primitive types before creating any IR, and report
    // through notifyMatchFailure (consistent with the other patterns here)
    // instead of emitting a hard diagnostic from inside the pattern.
    const bool isRefValue = !valueType.isIntOrIndexOrFloat();
    if (!isRefValue && !valueType.isInteger(32) && !valueType.isInteger(64) &&
        !valueType.isF32() && !valueType.isF64()) {
      return rewriter.notifyMatchFailure(
          srcOp, "unsupported list element type in the VM");
    }

    Value list = adaptor.getList();
    Value index = castToI32(adaptor.getIndex(), rewriter);
    if (valueType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ListSetI32Op>(srcOp, list, index,
                                                          value);
    } else if (valueType.isInteger(64)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ListSetI64Op>(srcOp, list, index,
                                                          value);
    } else if (valueType.isF32()) {
      rewriter.replaceOpWithNewOp<IREE::VM::ListSetF32Op>(srcOp, list, index,
                                                          value);
    } else if (valueType.isF64()) {
      rewriter.replaceOpWithNewOp<IREE::VM::ListSetF64Op>(srcOp, list, index,
                                                          value);
    } else {
      rewriter.replaceOpWithNewOp<IREE::VM::ListSetRefOp>(srcOp, list, index,
                                                          value);
    }
    return success();
  }
};
} // namespace
// Registers everything needed to lower Util list ops to the VM dialect:
//  * a type conversion from IREE::Util::ListType to a VM ref of
//    IREE::VM::ListType. Object/variant element types become
//    IREE::VM::OpaqueType, and all other element types are converted with
//    |typeConverter| itself (std::nullopt if that fails);
//  * marks the six Util list ops as illegal;
//  * adds the conversion patterns defined above.
void populateUtilListToVMPatterns(MLIRContext *context,
                                  ConversionTarget &conversionTarget,
                                  TypeConverter &typeConverter,
                                  RewritePatternSet &patterns) {
  typeConverter.addConversion(
      [&typeConverter](IREE::Util::ListType listType) -> std::optional<Type> {
        Type utilElementType = listType.getElementType();
        Type vmElementType =
            isa<IREE::Util::ObjectType, IREE::Util::VariantType>(
                utilElementType)
                ? Type(IREE::VM::OpaqueType::get(listType.getContext()))
                : typeConverter.convertType(utilElementType);
        if (!vmElementType) {
          return std::nullopt;
        }
        return IREE::VM::RefType::get(IREE::VM::ListType::get(vmElementType));
      });

  conversionTarget
      .addIllegalOp<IREE::Util::ListCreateOp, IREE::Util::ListConstructOp,
                    IREE::Util::ListSizeOp, IREE::Util::ListResizeOp,
                    IREE::Util::ListGetOp, IREE::Util::ListSetOp>();

  patterns.insert<ListCreateOpConversion, ListConstructOpConversion,
                  ListSizeOpConversion, ListResizeOpConversion,
                  ListGetOpConversion, ListSetOpConversion>(typeConverter,
                                                            context);
}
} // namespace mlir::iree_compiler