blob: fa837d41c992fe04b41458d3d6865a8f6f1c302e [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Parser/Parser.h"
namespace mlir::iree_compiler {
// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h.
// There may be some special insertion order arrangement required based on the
// nested vm.module here.
//
// Clones every vm.import found (via walk) under |importModuleOp| into
// |targetModuleOp|. Each clone is renamed to "<import module name>.<import
// name>" and marked private. Any symbol already present in the target with that
// full name is erased and replaced by the clone. Clones are inserted before the
// last op of the target module body, or at the end if the body is empty.
LogicalResult appendImportModule(IREE::VM::ModuleOp importModuleOp,
                                 ModuleOp targetModuleOp) {
  SymbolTable symbolTable(targetModuleOp);
  Block *targetBody = targetModuleOp.getBody();
  OpBuilder targetBuilder(targetModuleOp);
  // back() is undefined on an empty block, so fall back to appending.
  if (targetBody->empty()) {
    targetBuilder.setInsertionPointToEnd(targetBody);
  } else {
    targetBuilder.setInsertionPoint(&targetBody->back());
  }
  importModuleOp.walk([&](IREE::VM::ImportOp importOp) {
    std::string fullName =
        (importModuleOp.getName() + "." + importOp.getName()).str();
    if (Operation *existingOp = symbolTable.lookup(fullName)) {
      // The builder may be anchored on the op we are about to delete (e.g. it
      // is the last op in the module); move the anchor past it first so the
      // insertion iterator stays valid and the clone takes its place.
      if (targetBuilder.getInsertionPoint() != targetBody->end() &&
          &*targetBuilder.getInsertionPoint() == existingOp) {
        targetBuilder.setInsertionPointAfter(existingOp);
      }
      // Erase through the symbol table so it does not keep a stale entry.
      symbolTable.erase(existingOp);
    }
    auto clonedOp = cast<IREE::VM::ImportOp>(targetBuilder.clone(*importOp));
    mlir::StringAttr fullNameAttr =
        mlir::StringAttr::get(clonedOp.getContext(), fullName);
    clonedOp.setName(fullNameAttr);
    clonedOp.setPrivate();
  });
  return success();
}
// Parses |importModuleSrc| as a builtin module and appends the imports of every
// top-level vm.module it contains into |targetModuleOp|.
//
// Returns failure (after emitting a diagnostic) if the source fails to parse or
// if appending any of the contained vm.modules fails. All contained vm.modules
// are attempted even after one fails so that every error is reported.
LogicalResult appendImportModule(StringRef importModuleSrc,
                                 ModuleOp targetModuleOp) {
  auto importModuleRef = mlir::parseSourceString<mlir::ModuleOp>(
      importModuleSrc, targetModuleOp.getContext());
  if (!importModuleRef) {
    return targetModuleOp.emitError()
           << "unable to append import module; import module failed to parse";
  }
  bool anyFailed = false;
  for (auto importModuleOp : importModuleRef->getOps<IREE::VM::ModuleOp>()) {
    if (failed(appendImportModule(importModuleOp, targetModuleOp))) {
      importModuleOp.emitError() << "failed to import module";
      // Previously the error was emitted but success() was still returned,
      // hiding the failure from callers.
      anyFailed = true;
    }
  }
  return failure(anyFailed);
}
// Converts |value| to |targetType| for passing to a VM import.
//
// Supported conversions:
//  - float -> integer of the same bit width: arith.bitcast
//  - integer -> wider integer: arith.extui for unsigned targets, otherwise
//    arith.extsi
//  - integer -> integer of equal or smaller width: arith.trunci
// Any other combination (including identical types) returns |value| as-is.
Value castToImportType(Value value, Type targetType, OpBuilder &builder) {
  Type sourceType = value.getType();
  if (sourceType == targetType)
    return value;
  auto loc = value.getLoc();

  // Floats may only be reinterpreted as an integer of identical width; this is
  // how values are marshalled to "untyped" VM interfaces, which take integers.
  // Mismatched float -> integer widths pass through unchanged.
  if (llvm::isa<FloatType>(sourceType) && llvm::isa<IntegerType>(targetType)) {
    if (sourceType.getIntOrFloatBitWidth() !=
        targetType.getIntOrFloatBitWidth())
      return value;
    return builder.create<mlir::arith::BitcastOp>(loc, targetType, value);
  }

  // The remaining conversions are integer -> integer resizes only.
  if (!llvm::isa<IntegerType>(sourceType))
    return value;
  bool toUnsigned = targetType.isUnsignedInteger();
  bool toSignedOrSignless =
      targetType.isSignedInteger() || targetType.isSignlessInteger();
  if (!toUnsigned && !toSignedOrSignless)
    return value;

  unsigned fromWidth = sourceType.getIntOrFloatBitWidth();
  unsigned toWidth = targetType.getIntOrFloatBitWidth();
  if (toWidth <= fromWidth)
    return builder.create<mlir::arith::TruncIOp>(loc, targetType, value);
  // Widening: zero-extend into unsigned targets, sign-extend otherwise.
  if (toUnsigned)
    return builder.create<mlir::arith::ExtUIOp>(loc, targetType, value);
  return builder.create<mlir::arith::ExtSIOp>(loc, targetType, value);
}
// Converts a value returned from a VM import to |targetType|.
//
// Results are converted with exactly the same rules as arguments, so this
// simply forwards to castToImportType. Note that castToImportType only bitcasts
// in the float -> integer direction; an integer value requested as a float type
// is returned unconverted.
Value castFromImportType(Value value, Type targetType, OpBuilder &builder) {
  Value converted = castToImportType(value, targetType, builder);
  return converted;
}
// Propagates call-relevant attributes from an import declaration onto a call
// op. Currently only the `nosideeffects` marker is copied; it is always set on
// |callOp| as a UnitAttr when present on |importOp|.
void copyImportAttrs(IREE::VM::ImportOp importOp, Operation *callOp) {
  static const char kNoSideEffectsAttrName[] = "nosideeffects";
  if (!importOp->hasAttr(kNoSideEffectsAttrName))
    return;
  callOp->setAttr(kNoSideEffectsAttrName,
                  UnitAttr::get(importOp.getContext()));
}
namespace detail {
// Returns how many slots |spanType| occupies in a segment: the element count
// for a tuple type, and exactly one for any other type.
size_t getSegmentSpanSize(Type spanType) {
  auto tupleType = llvm::dyn_cast<TupleType>(spanType);
  return tupleType ? tupleType.size() : size_t{1};
}
// Materializes |attrValue| as a flat list of SSA values suitable for passing as
// import operands of |inputType|.
//
// Handled encodings:
//  - IntegerAttr: one arith.constant of |inputType| (sign-extended/truncated).
//  - DenseIntElementsAttr: one arith.constant per element, using the elements'
//    own element type.
//  - ArrayAttr: each element rewritten recursively against |inputType| and the
//    results concatenated.
//  - StringAttr: one vm.rodata.inline op.
//  - Attributes from dialects registering a VMConversionDialectInterface: the
//    dialect walks the attribute storage and each storage element is rewritten
//    recursively. If |inputType| is a tuple, storage elements are matched 1:1
//    with the tuple element types; otherwise each is rewritten against
//    |inputType|.
//
// Returns std::nullopt on failure. A diagnostic is emitted for unsupported
// encodings and for tuple arity overflow; failures reported by the dialect's
// walk are propagated as-is.
std::optional<SmallVector<Value>> rewriteAttrToOperands(Location loc,
                                                        Attribute attrValue,
                                                        Type inputType,
                                                        OpBuilder &builder) {
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attrValue)) {
    // NOTE: we intentionally go to std.constant ops so that the standard
    // conversions can do their job. If we want to remove the dependency
    // from standard ops in the future we could instead go directly to
    // one of the vm constant ops.
    //
    // sextOrTrunc performs the same sign-extend-or-truncate that the previous
    // APInt(width, getSExtValue()) round trip intended, but it does not require
    // the value to fit in 64 bits. It also does not rely on APInt's uint64_t
    // constructor silently truncating negative values.
    APInt resizedValue =
        intAttr.getValue().sextOrTrunc(inputType.getIntOrFloatBitWidth());
    auto constValue = builder.create<mlir::arith::ConstantOp>(
        loc, inputType, IntegerAttr::get(inputType, resizedValue));
    return {{constValue}};
  }

  if (auto elementsAttr = llvm::dyn_cast<DenseIntElementsAttr>(attrValue)) {
    Type elementType = elementsAttr.getType().getElementType();
    SmallVector<Value> elementValues;
    elementValues.reserve(elementsAttr.getNumElements());
    for (Attribute elementAttr : elementsAttr.getValues<Attribute>()) {
      elementValues.push_back(builder.create<mlir::arith::ConstantOp>(
          loc, elementType, cast<TypedAttr>(elementAttr)));
    }
    return elementValues;
  }

  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attrValue)) {
    SmallVector<Value> allValues;
    for (Attribute elementAttr : arrayAttr) {
      auto flattenedValues =
          rewriteAttrToOperands(loc, elementAttr, inputType, builder);
      if (!flattenedValues)
        return std::nullopt;
      allValues.append(flattenedValues->begin(), flattenedValues->end());
    }
    return allValues;
  }

  if (auto strAttr = llvm::dyn_cast<StringAttr>(attrValue)) {
    return {{builder.create<IREE::VM::RodataInlineOp>(loc, strAttr)}};
  }

  // This may be a custom dialect type. As we can't trivially access the storage
  // of these we need to ask the dialect to do it for us.
  auto *conversionInterface =
      attrValue.getDialect()
          .getRegisteredInterface<VMConversionDialectInterface>();
  if (!conversionInterface) {
    emitError(loc) << "unsupported attribute encoding: " << attrValue;
    return std::nullopt;
  }

  // A tuple input type maps storage elements 1:1 onto tuple element types;
  // any other input type is shared by all storage elements (ala arrays).
  auto tupleType = llvm::dyn_cast<TupleType>(inputType);
  ArrayRef<Type> tupleTypes =
      tupleType ? tupleType.getTypes() : ArrayRef<Type>();
  size_t ordinal = 0;
  bool anyFailed = false;
  SmallVector<Value> allValues;
  LogicalResult walkStatus = conversionInterface->walkAttributeStorage(
      attrValue, [&](Attribute elementAttr) {
        if (anyFailed)
          return;
        Type elementType = inputType;
        if (tupleType) {
          // Guard against the dialect producing more storage elements than
          // the tuple has; indexing past the end would be out of bounds.
          if (ordinal >= tupleTypes.size()) {
            emitError(loc) << "attribute storage of " << attrValue
                           << " has more elements than the operand type "
                           << inputType;
            anyFailed = true;
            return;
          }
          elementType = tupleTypes[ordinal++];
        }
        auto flattenedValues =
            rewriteAttrToOperands(loc, elementAttr, elementType, builder);
        if (!flattenedValues) {
          anyFailed = true;
          return;
        }
        allValues.append(flattenedValues->begin(), flattenedValues->end());
      });
  if (failed(walkStatus) || anyFailed)
    return std::nullopt;
  return allValues;
}
} // namespace detail
} // namespace mlir::iree_compiler