blob: e65015ddc8627ca25c4647b2ebbe324f21edccec [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Rewrites a `func` so its signature uses the types produced by the pattern's
// type converter.
//
// A region-less clone of the function is created, the original body is moved
// into it, and the function type is reset to the converted argument/result
// types. The moved body's block argument types are then converted, with the
// entry block using the argument signature conversion computed here. Finally
// the original function is erased.
//
// Every failure path reports a reason via notifyMatchFailure, matching the
// other patterns in this file.
class FuncOpSignatureConversion : public OpConversionPattern<mlir::FuncOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::FuncOp funcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto &typeConverter = *getTypeConverter();

    // Convert the input signature types.
    // TODO(benvanik): dynamic shapes by passing in tensor dynamic dims.
    auto originalType = funcOp.getType();
    TypeConverter::SignatureConversion newSignature(
        originalType.getNumInputs());
    for (auto argType : llvm::enumerate(originalType.getInputs())) {
      if (failed(typeConverter.convertSignatureArg(
              argType.index(), argType.value(), newSignature))) {
        return rewriter.notifyMatchFailure(funcOp,
                                           "unable to convert argument types");
      }
    }

    // Convert the result types.
    SmallVector<Type, 4> newResultTypes;
    if (failed(typeConverter.convertTypes(originalType.getResults(),
                                          newResultTypes))) {
      return rewriter.notifyMatchFailure(funcOp,
                                         "unable to convert result types");
    }

    // Replace function: clone the op (attributes only, no regions), move the
    // original body into the clone, then give it the converted function type.
    auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
    newFuncOp.getBlocks().clear();
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    newFuncOp.setType(rewriter.getFunctionType(newSignature.getConvertedTypes(),
                                               newResultTypes));

    // Bring the moved body's block argument types in line with the new
    // signature; the entry block uses `newSignature`.
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                           &newSignature))) {
      return rewriter.notifyMatchFailure(funcOp,
                                         "unable to convert region types");
    }
    rewriter.eraseOp(funcOp);
    return success();
  }
};
// Rebuilds a `call` whose result types have been run through the type
// converter and whose operands are the remapped values from the adaptor.
// The callee symbol is taken unchanged from the original op. Only the callee,
// result types, and operands are passed to the new op's builder.
class CallOpConversion : public OpConversionPattern<mlir::CallOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::CallOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    SmallVector<Type, 4> convertedResultTypes;
    LogicalResult resultConversion = getTypeConverter()->convertTypes(
        op.getResultTypes(), convertedResultTypes);
    if (failed(resultConversion))
      return rewriter.notifyMatchFailure(op, "unable to convert result types");

    auto callee = op.callee();
    auto remappedOperands = adaptor.operands();
    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, convertedResultTypes, callee,
                                              remappedOperands);
    return success();
  }
};
// Recreates an unconditional branch to the same destination block, forwarding
// the adaptor's (remapped) successor operands.
class BranchOpConversion : public OpConversionPattern<mlir::BranchOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::BranchOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto *successor = op.dest();
    auto successorOperands = adaptor.destOperands();
    rewriter.replaceOpWithNewOp<mlir::BranchOp>(op, successor,
                                                successorOperands);
    return success();
  }
};
// Recreates a conditional branch with the same true/false destination blocks.
// The condition and both successor operand lists come from the adaptor, so
// they are the remapped values.
class CondBranchOpConversion : public OpConversionPattern<mlir::CondBranchOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::CondBranchOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto predicate = adaptor.condition();
    auto *thenBlock = op.trueDest();
    auto *elseBlock = op.falseDest();
    rewriter.replaceOpWithNewOp<mlir::CondBranchOp>(
        op, predicate, thenBlock, adaptor.trueDestOperands(), elseBlock,
        adaptor.falseDestOperands());
    return success();
  }
};
// Replaces a `return` with a new `return` over the adaptor's (remapped)
// operands.
class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::ReturnOp returnOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto returnedValues = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<mlir::ReturnOp>(returnOp, returnedValues);
    return success();
  }
};
// Rebuilds a `select` from the adaptor's (remapped) condition, true value,
// and false value.
class SelectOpConversion : public OpConversionPattern<mlir::SelectOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::SelectOp selectOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto predicate = adaptor.condition();
    auto whenTrue = adaptor.true_value();
    auto whenFalse = adaptor.false_value();
    rewriter.replaceOpWithNewOp<mlir::SelectOp>(selectOp, predicate, whenTrue,
                                                whenFalse);
    return success();
  }
};
} // namespace
// Adds the structural standard-dialect conversion patterns to `patterns`.
// These cover function signatures, calls, branches, conditional branches,
// returns, and selects. Each pattern is constructed with `converter` and
// `context`.
void populateStandardStructuralToHALPatterns(MLIRContext *context,
                                             OwningRewritePatternList &patterns,
                                             TypeConverter &converter) {
  // Function boundary: signatures and call sites.
  patterns.insert<FuncOpSignatureConversion, CallOpConversion>(converter,
                                                               context);
  // Control flow between blocks.
  patterns.insert<BranchOpConversion, CondBranchOpConversion>(converter,
                                                              context);
  // Function exit and value selection.
  patterns.insert<ReturnOpConversion, SelectOpConversion>(converter, context);
}
} // namespace iree_compiler
} // namespace mlir