blob: 3923e96a6435c49eb499440df104edf25f5ed661 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
class FuncOpSignatureConversion : public OpConversionPattern<mlir::FuncOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mlir::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto &typeConverter = *getTypeConverter();
// Convert the input signature types.
// TODO(benvanik): dynamic shapes by passing in tensor dynamic dims.
auto originalType = funcOp.getType();
TypeConverter::SignatureConversion newSignature(
originalType.getNumInputs());
for (auto argType : llvm::enumerate(originalType.getInputs())) {
if (failed(typeConverter.convertSignatureArg(
argType.index(), argType.value(), newSignature))) {
return failure();
}
}
SmallVector<Type, 4> newResultTypes;
if (failed(typeConverter.convertTypes(originalType.getResults(),
newResultTypes))) {
return failure();
}
// Replace function.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
newFuncOp.getBlocks().clear();
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(rewriter.getFunctionType(newSignature.getConvertedTypes(),
newResultTypes));
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
&newSignature))) {
return failure();
}
rewriter.eraseOp(funcOp);
return success();
}
};
class CallOpConversion : public OpConversionPattern<mlir::CallOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mlir::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type, 4> resultTypes;
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
resultTypes))) {
return rewriter.notifyMatchFailure(op, "unable to convert result types");
}
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, resultTypes, op.getCallee(),
adaptor.operands());
return success();
}
};
class BranchOpConversion : public OpConversionPattern<mlir::BranchOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mlir::BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::BranchOp>(op, op.getDest(),
adaptor.getDestOperands());
return success();
}
};
class CondBranchOpConversion : public OpConversionPattern<mlir::CondBranchOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mlir::CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::CondBranchOp>(
op, adaptor.getCondition(), op.getTrueDest(),
adaptor.getTrueDestOperands(), op.getFalseDest(),
adaptor.getFalseDestOperands());
return success();
}
};
class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mlir::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(returnOp, adaptor.operands());
return success();
}
};
class SelectOpConversion : public OpConversionPattern<mlir::SelectOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mlir::SelectOp selectOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::SelectOp>(
selectOp, adaptor.getCondition(), adaptor.getTrueValue(),
adaptor.getFalseValue());
return success();
}
};
struct ConvertIfOp : public OpConversionPattern<scf::IfOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
scf::IfOp ifOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTypes = llvm::to_vector<4>(llvm::map_range(
ifOp.getResultTypes(),
[&](Type type) { return getTypeConverter()->convertType(type); }));
auto newOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes,
adaptor.getCondition(),
ifOp.elseBlock() != nullptr);
rewriter.inlineRegionBefore(ifOp.getThenRegion(), newOp.getThenRegion(),
newOp.getThenRegion().end());
rewriter.eraseBlock(&newOp.getThenRegion().front());
if (ifOp.elseBlock()) {
rewriter.inlineRegionBefore(ifOp.getElseRegion(), newOp.getElseRegion(),
newOp.getElseRegion().end());
rewriter.eraseBlock(&newOp.getElseRegion().front());
}
rewriter.replaceOp(ifOp, newOp.getResults());
return success();
}
};
struct ConvertYieldOp : public OpConversionPattern<scf::YieldOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
scf::YieldOp yieldOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, adaptor.getResults());
return success();
}
};
} // namespace
void populateStandardStructuralToHALPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
OwningRewritePatternList &patterns,
TypeConverter &typeConverter) {
conversionTarget.addLegalOp<mlir::ModuleOp>();
// We need to rewrite certain types on operands/results so use the default
// dynamic legality checker to force any ops using such types to run through
// our patterns.
conversionTarget.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) {
return typeConverter.isSignatureLegal(op.getType()) &&
typeConverter.isLegal(&op.getBody());
});
patterns
.insert<FuncOpSignatureConversion, CallOpConversion, BranchOpConversion,
CondBranchOpConversion, ReturnOpConversion, SelectOpConversion>(
typeConverter, context);
// TODO(benvanik): move to general utils conversion.
addGenericLegalOp<scf::IfOp>(conversionTarget, typeConverter);
addGenericLegalOp<scf::YieldOp>(conversionTarget, typeConverter);
patterns.insert<ConvertIfOp, ConvertYieldOp>(typeConverter, context);
}
} // namespace iree_compiler
} // namespace mlir