| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" |
| |
| #include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Parser/Parser.h" |
| |
| namespace mlir::iree_compiler { |
| |
| LogicalResult ImportTable::build(Operation *rootOp, |
| const TypeConverter &typeConverter) { |
| for (auto funcOp : rootOp->getRegion(0).getOps<FunctionOpInterface>()) { |
| if (!funcOp.isExternal()) { |
| continue; // only external functions are imports |
| } |
| |
| ImportTable::Import import; |
| import.name = funcOp.getNameAttr(); |
| import.fallback = funcOp->getAttrOfType<SymbolRefAttr>("vm.fallback"); |
| |
| // Try to use an assigned signature or fall back to converting the input. |
| if (auto importOp = dyn_cast<IREE::VM::ImportOp>(funcOp.getOperation())) { |
| // Import ops have their signature used directly. |
| import.signature = importOp.getFunctionType(); |
| } else if (auto signatureAttr = |
| funcOp->getAttrOfType<TypeAttr>("vm.signature")) { |
| // Directly use the specified signature. |
| import.signature = |
| dyn_cast_if_present<FunctionType>(signatureAttr.getValue()); |
| } |
| if (!import.signature) { |
| // Convert the signature using the type converter. |
| SmallVector<Type> argumentTypes; |
| if (failed(typeConverter.convertTypes(funcOp.getArgumentTypes(), |
| argumentTypes))) { |
| return funcOp.emitError() << "unable to convert import argument types"; |
| } |
| SmallVector<Type> resultTypes; |
| if (failed(typeConverter.convertTypes(funcOp.getResultTypes(), |
| resultTypes))) { |
| return funcOp.emitError() << "unable to convert import result types"; |
| } |
| import.signature = |
| FunctionType::get(rootOp->getContext(), argumentTypes, resultTypes); |
| } |
| |
| symbols[import.name.getValue()] = std::move(import); |
| } |
| |
| return success(); |
| } |
| |
| std::optional<ImportTable::Import> ImportTable::find(StringRef symbolName) { |
| auto it = symbols.find(symbolName); |
| if (it == symbols.end()) |
| return std::nullopt; |
| return it->second; |
| } |
| |
| // TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h. |
| // There may be some special insertion order arrangement required based on the |
| // nested vm.module here. |
| |
| LogicalResult appendImportModule(IREE::VM::ModuleOp importModuleOp, |
| ModuleOp targetModuleOp) { |
| SymbolTable symbolTable(targetModuleOp); |
| OpBuilder targetBuilder(targetModuleOp); |
| targetBuilder.setInsertionPoint(&targetModuleOp.getBody()->back()); |
| importModuleOp.walk([&](IREE::VM::ImportOp importOp) { |
| std::string fullName = |
| (importModuleOp.getName() + "." + importOp.getName()).str(); |
| if (auto *existingOp = symbolTable.lookup(fullName)) { |
| existingOp->erase(); |
| } |
| auto clonedOp = cast<IREE::VM::ImportOp>(targetBuilder.clone(*importOp)); |
| mlir::StringAttr fullNameAttr = |
| mlir::StringAttr::get(clonedOp.getContext(), fullName); |
| clonedOp.setName(fullNameAttr); |
| clonedOp.setPrivate(); |
| }); |
| return success(); |
| } |
| |
| LogicalResult appendImportModule(StringRef importModuleSrc, |
| ModuleOp targetModuleOp) { |
| auto importModuleRef = mlir::parseSourceString<mlir::ModuleOp>( |
| importModuleSrc, targetModuleOp.getContext()); |
| if (!importModuleRef) { |
| return targetModuleOp.emitError() |
| << "unable to append import module; import module failed to parse"; |
| } |
| for (auto importModuleOp : importModuleRef->getOps<IREE::VM::ModuleOp>()) { |
| if (failed(appendImportModule(importModuleOp, targetModuleOp))) { |
| importModuleOp.emitError() << "failed to import module"; |
| } |
| } |
| return success(); |
| } |
| |
| Value castToImportType(Value value, Type targetType, OpBuilder &builder) { |
| auto sourceType = value.getType(); |
| if (sourceType == targetType) |
| return value; |
| bool sourceIsInteger = isa<IntegerType>(sourceType); |
| |
| // Allow bitcast between same width float/int types. This is used for |
| // marshalling to "untyped" VM interfaces, which will have an integer type. |
| if (isa<FloatType>(sourceType) && isa<IntegerType>(targetType) && |
| sourceType.getIntOrFloatBitWidth() == |
| targetType.getIntOrFloatBitWidth()) { |
| return mlir::arith::BitcastOp::create(builder, value.getLoc(), targetType, |
| value); |
| } else if (sourceIsInteger && |
| (targetType.isSignedInteger() || targetType.isSignlessInteger())) { |
| if (targetType.getIntOrFloatBitWidth() > |
| sourceType.getIntOrFloatBitWidth()) { |
| return mlir::arith::ExtSIOp::create(builder, value.getLoc(), targetType, |
| value); |
| } else { |
| return mlir::arith::TruncIOp::create(builder, value.getLoc(), targetType, |
| value); |
| } |
| } else if (sourceIsInteger && targetType.isUnsignedInteger()) { |
| if (targetType.getIntOrFloatBitWidth() > |
| sourceType.getIntOrFloatBitWidth()) { |
| return mlir::arith::ExtUIOp::create(builder, value.getLoc(), targetType, |
| value); |
| } else { |
| return mlir::arith::TruncIOp::create(builder, value.getLoc(), targetType, |
| value); |
| } |
| } else { |
| return value; |
| } |
| } |
| |
| Value castFromImportType(Value value, Type targetType, OpBuilder &builder) { |
| // Right now the to-import and from-import types are the same. |
| return castToImportType(value, targetType, builder); |
| } |
| |
| void copyImportAttrs(IREE::VM::ImportOp importOp, Operation *callOp) { |
| if (importOp->hasAttr("nosideeffects")) { |
| callOp->setAttr("nosideeffects", UnitAttr::get(importOp.getContext())); |
| } |
| } |
| |
| namespace detail { |
| |
| size_t getSegmentSpanSize(Type spanType) { |
| if (auto tupleType = dyn_cast<TupleType>(spanType)) { |
| return tupleType.size(); |
| } else { |
| return 1; |
| } |
| } |
| |
| std::optional<SmallVector<Value>> rewriteAttrToOperands(Location loc, |
| Attribute attrValue, |
| Type inputType, |
| OpBuilder &builder) { |
| if (auto intAttr = dyn_cast<IntegerAttr>(attrValue)) { |
| // NOTE: we intentionally go to std.constant ops so that the standard |
| // conversions can do their job. If we want to remove the dependency |
| // from standard ops in the future we could instead go directly to |
| // one of the vm constant ops. |
| auto constValue = mlir::arith::ConstantOp::create( |
| builder, loc, inputType, |
| IntegerAttr::get(inputType, APInt(inputType.getIntOrFloatBitWidth(), |
| intAttr.getValue().getSExtValue()))); |
| return {{constValue}}; |
| } else if (auto floatAttr = dyn_cast<FloatAttr>(attrValue)) { |
| bool lossy = false; |
| APFloat value = floatAttr.getValue(); |
| value.convert(cast<FloatType>(inputType).getFloatSemantics(), |
| llvm::RoundingMode::NearestTiesToEven, &lossy); |
| auto constValue = mlir::arith::ConstantOp::create( |
| builder, loc, inputType, FloatAttr::get(inputType, value)); |
| return {{constValue}}; |
| } else if (auto elementsAttr = dyn_cast<DenseIntElementsAttr>(attrValue)) { |
| SmallVector<Value> elementValues; |
| elementValues.reserve(elementsAttr.getNumElements()); |
| for (auto intAttr : elementsAttr.getValues<Attribute>()) { |
| elementValues.push_back(mlir::arith::ConstantOp::create( |
| builder, loc, elementsAttr.getType().getElementType(), |
| cast<TypedAttr>(intAttr))); |
| } |
| return elementValues; |
| } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attrValue)) { |
| SmallVector<Value> allValues; |
| for (auto elementAttr : arrayAttr) { |
| auto flattenedValues = |
| rewriteAttrToOperands(loc, elementAttr, inputType, builder); |
| if (!flattenedValues) |
| return std::nullopt; |
| allValues.append(flattenedValues->begin(), flattenedValues->end()); |
| } |
| return allValues; |
| } else if (auto strAttr = dyn_cast<StringAttr>(attrValue)) { |
| return {{IREE::VM::RodataInlineOp::create(builder, loc, strAttr)}}; |
| } |
| |
| // This may be a custom dialect type. As we can't trivially access the storage |
| // of these we need to ask the dialect to do it for us. |
| auto *conversionInterface = |
| attrValue.getDialect() |
| .getRegisteredInterface<VMConversionDialectInterface>(); |
| if (conversionInterface) { |
| bool anyFailed = false; |
| SmallVector<Value> allValues; |
| if (auto tupleType = dyn_cast<TupleType>(inputType)) { |
| // Custom dialect type maps into a tuple; we expect 1:1 tuple elements to |
| // attribute storage elements. |
| auto tupleTypes = llvm::to_vector(tupleType.getTypes()); |
| int ordinal = 0; |
| LogicalResult walkStatus = conversionInterface->walkAttributeStorage( |
| attrValue, [&](Attribute elementAttr) { |
| if (anyFailed) |
| return; |
| auto elementType = tupleTypes[ordinal++]; |
| auto flattenedValues = |
| rewriteAttrToOperands(loc, elementAttr, elementType, builder); |
| if (!flattenedValues) { |
| anyFailed = true; |
| return; |
| } |
| allValues.append(flattenedValues->begin(), flattenedValues->end()); |
| }); |
| if (failed(walkStatus)) |
| return std::nullopt; |
| } else { |
| // Custom dialect type maps into zero or more input types (ala arrays). |
| LogicalResult walkStatus = conversionInterface->walkAttributeStorage( |
| attrValue, [&](Attribute elementAttr) { |
| if (anyFailed) |
| return; |
| auto flattenedValues = |
| rewriteAttrToOperands(loc, elementAttr, inputType, builder); |
| if (!flattenedValues) { |
| anyFailed = true; |
| return; |
| } |
| allValues.append(flattenedValues->begin(), flattenedValues->end()); |
| }); |
| if (failed(walkStatus)) |
| return std::nullopt; |
| } |
| if (anyFailed) |
| return std::nullopt; |
| return allValues; |
| } |
| |
| emitError(loc) << "unsupported attribute encoding: " << attrValue; |
| return std::nullopt; |
| } |
| |
| } // namespace detail |
| |
| } // namespace mlir::iree_compiler |