| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "iree/compiler/Dialect/VM/IR/VMTypes.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir::iree_compiler { |
| |
| namespace { |
| |
| struct GlobalOpConversion : public OpConversionPattern<IREE::Util::GlobalOp> { |
| TypeConverter &typeConverter; |
| GlobalOpConversion(MLIRContext *context, TypeConverter &typeConverter) |
| : OpConversionPattern(context), typeConverter(typeConverter) {} |
| LogicalResult |
| matchAndRewrite(IREE::Util::GlobalOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Operation *newOp = nullptr; |
| auto convertedType = typeConverter.convertType(op.getType()); |
| const bool isInitialized = |
| op.getInitialValueAttr() && |
| !isa<IREE::Util::UninitializedAttr>(op.getInitialValueAttr()); |
| if (isa<IREE::VM::RefType>(convertedType) || |
| IREE::VM::RefType::isCompatible(convertedType)) { |
| newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalRefOp>( |
| op, op.getSymName(), op.getIsMutable(), convertedType, |
| llvm::to_vector(op->getDialectAttrs())); |
| } else if (convertedType.isInteger(32)) { |
| std::optional<TypedAttr> convertedValue = std::nullopt; |
| if (isInitialized) { |
| convertedValue = rewriter.getI32IntegerAttr(static_cast<int32_t>( |
| cast<IntegerAttr>(op.getInitialValue().value()).getInt())); |
| } |
| newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI32Op>( |
| op, op.getSymName(), op.getIsMutable(), convertedType, convertedValue, |
| llvm::to_vector(op->getDialectAttrs())); |
| } else if (convertedType.isInteger(64)) { |
| std::optional<TypedAttr> convertedValue = std::nullopt; |
| if (isInitialized) { |
| convertedValue = rewriter.getI64IntegerAttr( |
| cast<IntegerAttr>(op.getInitialValue().value()).getInt()); |
| } |
| newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI64Op>( |
| op, op.getSymName(), op.getIsMutable(), convertedType, convertedValue, |
| llvm::to_vector(op->getDialectAttrs())); |
| } else if (convertedType.isF32()) { |
| std::optional<TypedAttr> convertedValue = std::nullopt; |
| if (isInitialized) { |
| convertedValue = rewriter.getF32FloatAttr(static_cast<float>( |
| cast<FloatAttr>(op.getInitialValue().value()).getValueAsDouble())); |
| } |
| newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF32Op>( |
| op, op.getSymName(), op.getIsMutable(), convertedType, convertedValue, |
| llvm::to_vector(op->getDialectAttrs())); |
| } else if (convertedType.isF64()) { |
| std::optional<TypedAttr> convertedValue = std::nullopt; |
| if (isInitialized) { |
| convertedValue = rewriter.getF64FloatAttr( |
| cast<FloatAttr>(op.getInitialValue().value()).getValueAsDouble()); |
| } |
| newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF64Op>( |
| op, op.getSymName(), op.getIsMutable(), convertedType, convertedValue, |
| llvm::to_vector(op->getDialectAttrs())); |
| } else { |
| return op.emitOpError("unsupported global type"); |
| } |
| |
| // New global carries the same visibility as the original. |
| cast<SymbolOpInterface>(newOp).setVisibility(op.getVisibility()); |
| |
| return success(); |
| } |
| }; |
| |
| struct GlobalAddressOpConversion |
| : public OpConversionPattern<IREE::Util::GlobalAddressOp> { |
| TypeConverter &typeConverter; |
| GlobalAddressOpConversion(MLIRContext *context, TypeConverter &typeConverter) |
| : OpConversionPattern(context), typeConverter(typeConverter) {} |
| LogicalResult |
| matchAndRewrite(IREE::Util::GlobalAddressOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalAddressOp>( |
| op, typeConverter.convertType(op.getType()), op.getGlobalAttr(), |
| op.getIsImmutableAttr()); |
| return success(); |
| } |
| }; |
| |
| struct GlobalLoadOpConversion |
| : public OpConversionPattern<IREE::Util::GlobalLoadOp> { |
| TypeConverter &typeConverter; |
| GlobalLoadOpConversion(MLIRContext *context, TypeConverter &typeConverter) |
| : OpConversionPattern(context), typeConverter(typeConverter) {} |
| LogicalResult |
| matchAndRewrite(IREE::Util::GlobalLoadOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto operandType = op.getType(); |
| auto convertedType = typeConverter.convertType(operandType); |
| if (IREE::VM::RefType::isCompatible(operandType)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadRefOp>( |
| op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr()); |
| } else if (convertedType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI32Op>( |
| op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr()); |
| } else if (convertedType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI64Op>( |
| op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr()); |
| } else if (convertedType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadF32Op>( |
| op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr()); |
| } else if (convertedType.isF64()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadF64Op>( |
| op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr()); |
| } else { |
| return rewriter.notifyMatchFailure(op, "unhandled global type"); |
| } |
| return success(); |
| } |
| }; |
| |
| struct GlobalLoadIndirectOpConversion |
| : public OpConversionPattern<IREE::Util::GlobalLoadIndirectOp> { |
| TypeConverter &typeConverter; |
| GlobalLoadIndirectOpConversion(MLIRContext *context, |
| TypeConverter &typeConverter) |
| : OpConversionPattern(context), typeConverter(typeConverter) {} |
| LogicalResult |
| matchAndRewrite(IREE::Util::GlobalLoadIndirectOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto operandType = op.getType(); |
| auto convertedType = typeConverter.convertType(operandType); |
| if (IREE::VM::RefType::isCompatible(operandType)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectRefOp>( |
| op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr()); |
| } else if (convertedType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectI32Op>( |
| op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr()); |
| } else if (convertedType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectI64Op>( |
| op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr()); |
| } else if (convertedType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectF32Op>( |
| op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr()); |
| } else if (convertedType.isF64()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectF64Op>( |
| op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr()); |
| } else { |
| return rewriter.notifyMatchFailure(op, "unhandled global type"); |
| } |
| return success(); |
| } |
| }; |
| |
| struct GlobalStoreOpConversion |
| : public OpConversionPattern<IREE::Util::GlobalStoreOp> { |
| GlobalStoreOpConversion(MLIRContext *context, TypeConverter &typeConverter) |
| : OpConversionPattern(context) {} |
| LogicalResult |
| matchAndRewrite(IREE::Util::GlobalStoreOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto operandType = adaptor.getValue().getType(); |
| if (isa<IREE::VM::RefType>(operandType)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreRefOp>( |
| op, adaptor.getValue(), op.getGlobal()); |
| } else if (operandType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreI32Op>( |
| op, adaptor.getValue(), op.getGlobal()); |
| } else if (operandType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreI64Op>( |
| op, adaptor.getValue(), op.getGlobal()); |
| } else if (operandType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreF32Op>( |
| op, adaptor.getValue(), op.getGlobal()); |
| } else if (operandType.isF64()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreF64Op>( |
| op, adaptor.getValue(), op.getGlobal()); |
| } else { |
| return rewriter.notifyMatchFailure(op, "unhandled global type"); |
| } |
| return success(); |
| } |
| }; |
| |
| struct GlobalStoreIndirectOpConversion |
| : public OpConversionPattern<IREE::Util::GlobalStoreIndirectOp> { |
| GlobalStoreIndirectOpConversion(MLIRContext *context, |
| TypeConverter &typeConverter) |
| : OpConversionPattern(context) {} |
| LogicalResult |
| matchAndRewrite(IREE::Util::GlobalStoreIndirectOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto operandType = adaptor.getValue().getType(); |
| if (isa<IREE::VM::RefType>(operandType)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectRefOp>( |
| op, adaptor.getValue(), adaptor.getGlobal()); |
| } else if (operandType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectI32Op>( |
| op, adaptor.getValue(), adaptor.getGlobal()); |
| } else if (operandType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectI64Op>( |
| op, adaptor.getValue(), adaptor.getGlobal()); |
| } else if (operandType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectF32Op>( |
| op, adaptor.getValue(), adaptor.getGlobal()); |
| } else if (operandType.isF64()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::GlobalStoreIndirectF64Op>( |
| op, adaptor.getValue(), adaptor.getGlobal()); |
| } else { |
| return rewriter.notifyMatchFailure(op, "unhandled global type"); |
| } |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void populateUtilGlobalToVMPatterns(MLIRContext *context, |
| ConversionTarget &conversionTarget, |
| TypeConverter &typeConverter, |
| RewritePatternSet &patterns) { |
| conversionTarget.addIllegalOp< |
| IREE::Util::GlobalOp, IREE::Util::GlobalAddressOp, |
| IREE::Util::GlobalLoadOp, IREE::Util::GlobalLoadIndirectOp, |
| IREE::Util::GlobalStoreOp, IREE::Util::GlobalStoreIndirectOp>(); |
| patterns.insert<GlobalOpConversion, GlobalAddressOpConversion, |
| GlobalLoadOpConversion, GlobalLoadIndirectOpConversion, |
| GlobalStoreOpConversion, GlobalStoreIndirectOpConversion>( |
| context, typeConverter); |
| } |
| |
| } // namespace mlir::iree_compiler |