blob: 1ed4b620f35e3a0a396da3c716cfcf3c9c5480be [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/IR/Dialect.h"
#include "iree/compiler/IR/Interpreter/HLDialect.h"
#include "iree/compiler/IR/Interpreter/HLOps.h"
#include "iree/compiler/IR/Interpreter/LLDialect.h"
#include "iree/compiler/IR/Ops.h"
#include "iree/compiler/Transforms/ConversionUtils.h"
#include "iree/compiler/Utils/MemRefUtils.h"
#include "iree/compiler/Utils/OpCreationUtils.h"
#include "iree/compiler/Utils/TypeConversionUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Allocator.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
struct CallOpLowering : public OpConversionPattern<CallOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
CallOp op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto callOp = cast<CallOp>(op);
auto calleeType = callOp.getCalleeType();
rewriter.replaceOpWithNewOp<IREEInterp::HL::CallOp>(
op, callOp.getCallee(), calleeType.getResults(), operands);
return matchSuccess();
}
};
struct CallIndirectOpLowering : public OpConversionPattern<CallIndirectOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
CallIndirectOp op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto callOp = cast<CallIndirectOp>(op);
rewriter.replaceOpWithNewOp<IREEInterp::HL::CallIndirectOp>(
op, callOp.getCallee(), operands);
return matchSuccess();
}
};
struct ReturnOpLowering : public OpConversionPattern<ReturnOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
ReturnOp op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREEInterp::HL::ReturnOp>(op, operands);
return matchSuccess();
}
};
struct BranchOpLowering : public OpConversionPattern<BranchOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
BranchOp op, ArrayRef<ValuePtr> properOperands,
ArrayRef<Block *> destinations, ArrayRef<ArrayRef<ValuePtr>> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREEInterp::HL::BranchOp>(op, destinations[0],
operands[0]);
return this->matchSuccess();
}
};
struct CondBranchOpLowering : public OpConversionPattern<CondBranchOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
CondBranchOp op, ArrayRef<ValuePtr> properOperands,
ArrayRef<Block *> destinations, ArrayRef<ArrayRef<ValuePtr>> operands,
ConversionPatternRewriter &rewriter) const override {
auto condValue = loadAccessValue(op.getLoc(), properOperands[0], rewriter);
rewriter.replaceOpWithNewOp<IREEInterp::HL::CondBranchOp>(
op, condValue, destinations[IREEInterp::HL::CondBranchOp::trueIndex],
operands[IREEInterp::HL::CondBranchOp::trueIndex],
destinations[IREEInterp::HL::CondBranchOp::falseIndex],
operands[IREEInterp::HL::CondBranchOp::falseIndex]);
return this->matchSuccess();
}
};
template <typename SrcOp, typename DstOp>
struct CompareOpLowering : public OpConversionPattern<SrcOp> {
using OpConversionPattern<SrcOp>::OpConversionPattern;
PatternMatchResult matchAndRewrite(
SrcOp op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
auto lhValue = loadAccessValue(op.getLoc(), operands[0], rewriter);
auto rhValue = loadAccessValue(op.getLoc(), operands[1], rewriter);
lhValue = wrapAsMemRef(lhValue, op, rewriter);
rhValue = wrapAsMemRef(rhValue, op, rewriter);
// TODO(benvanik): map predicate to stable value.
auto predicate =
rewriter.getI32IntegerAttr(static_cast<int32_t>(op.getPredicate()));
auto dstType = convertTypeToMemRef(op.getResult());
auto midOp = rewriter.create<DstOp>(op.getLoc(), dstType, predicate,
lhValue, rhValue);
auto result = wrapAsTensor(midOp.getResult(), op, rewriter);
rewriter.replaceOp(
op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)});
return this->matchSuccess();
}
};
struct CmpIOpLowering
: public CompareOpLowering<CmpIOp, IREEInterp::HL::CmpIOp> {
using CompareOpLowering::CompareOpLowering;
};
struct CmpFOpLowering
: public CompareOpLowering<CmpFOp, IREEInterp::HL::CmpFOp> {
using CompareOpLowering::CompareOpLowering;
};
struct AllocOpLowering : public OpConversionPattern<AllocOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
AllocOp op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): replace with length computation.
rewriter.replaceOpWithNewOp<IREEInterp::HL::AllocHeapOp>(op, op.getType(),
operands);
return matchSuccess();
}
};
struct DeallocOpLowering : public OpConversionPattern<DeallocOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
DeallocOp op, ArrayRef<ValuePtr> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREEInterp::HL::DiscardOp>(op, operands[0]);
return matchSuccess();
}
};
struct LoadOpLowering : public OpRewritePattern<LoadOp> {
using OpRewritePattern::OpRewritePattern;
PatternMatchResult matchAndRewrite(LoadOp loadOp,
PatternRewriter &rewriter) const override {
if (loadOp.getMemRefType().getRank() != 0) {
loadOp.emitError() << "Cannot lower load of non-scalar";
return matchFailure();
}
ArrayRef<ValuePtr> dimPieces;
auto dst =
rewriter
.create<AllocOp>(loadOp.getLoc(), loadOp.getMemRefType(), dimPieces)
.getResult();
auto emptyArrayMemref = createArrayConstant(rewriter, loadOp.getLoc(), {});
rewriter.create<IREEInterp::HL::CopyOp>(
loadOp.getLoc(), loadOp.getMemRef(),
/*srcIndices=*/emptyArrayMemref, dst,
/*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref);
rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, dst);
return matchSuccess();
}
};
struct StoreOpLowering : public OpRewritePattern<StoreOp> {
using OpRewritePattern::OpRewritePattern;
PatternMatchResult matchAndRewrite(StoreOp storeOp,
PatternRewriter &rewriter) const override {
if (storeOp.getMemRefType().getRank() != 0) {
storeOp.emitError() << "Cannot lower store of non-scalar";
return matchFailure();
}
auto src = rewriter.create<IREE::ScalarToMemRefOp>(
storeOp.getLoc(), storeOp.getValueToStore());
auto emptyArrayMemref = createArrayConstant(rewriter, storeOp.getLoc(), {});
rewriter.replaceOpWithNewOp<IREEInterp::HL::CopyOp>(
storeOp, src, /*srcIndices=*/emptyArrayMemref, storeOp.getMemRef(),
/*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref);
return matchSuccess();
}
};
#define UNARY_OP_LOWERING(StdOpType, IREEOpType) \
struct StdOpType##Lowering : public UnaryOpLowering<StdOpType, IREEOpType> { \
using UnaryOpLowering::UnaryOpLowering; \
};
#define BINARY_OP_LOWERING(StdOpType, IREEOpType) \
struct StdOpType##Lowering \
: public BinaryOpLowering<StdOpType, IREEOpType> { \
using BinaryOpLowering::BinaryOpLowering; \
};
#define TERNARY_OP_LOWERING(StdOpType, IREEOpType) \
struct StdOpType##Lowering \
: public TernaryOpLowering<StdOpType, IREEOpType> { \
using TernaryOpLowering::TernaryOpLowering; \
};
// UNARY_OP_LOWERING(RankOp, IREEInterp::HL::RankOp);
UNARY_OP_LOWERING(DimOp, IREEInterp::HL::DimOp);
// UNARY_OP_LOWERING(ShapeOp, IREEInterp::HL::ShapeOp);
// UNARY_OP_LOWERING(LengthOp, IREEInterp::HL::LengthOp);
// UNARY_OP_LOWERING(NotOp, IREEInterp::HL::NotOp);
BINARY_OP_LOWERING(AndOp, IREEInterp::HL::AndOp);
BINARY_OP_LOWERING(OrOp, IREEInterp::HL::OrOp);
// BINARY_OP_LOWERING(XorOp, IREEInterp::HL::XorOp);
// BINARY_OP_LOWERING(ShiftLeftOp, IREEInterp::HL::ShiftLeftOp);
// BINARY_OP_LOWERING(ShiftRightLogicalOp, IREEInterp::HL::ShiftRightLogicalOp);
// BINARY_OP_LOWERING(ShiftRightArithmeticOp,
// IREEInterp::HL::ShiftRightArithmeticOp);
BINARY_OP_LOWERING(AddIOp, IREEInterp::HL::AddIOp);
BINARY_OP_LOWERING(AddFOp, IREEInterp::HL::AddFOp);
BINARY_OP_LOWERING(SubIOp, IREEInterp::HL::SubIOp);
BINARY_OP_LOWERING(SubFOp, IREEInterp::HL::SubFOp);
// UNARY_OP_LOWERING(AbsIOp, IREEInterp::HL::AbsIOp);
// UNARY_OP_LOWERING(AbsFOp, IREEInterp::HL::AbsFOp);
BINARY_OP_LOWERING(MulIOp, IREEInterp::HL::MulIOp);
BINARY_OP_LOWERING(MulFOp, IREEInterp::HL::MulFOp);
BINARY_OP_LOWERING(SignedDivIOp, IREEInterp::HL::DivISOp);
BINARY_OP_LOWERING(UnsignedDivIOp, IREEInterp::HL::DivIUOp);
BINARY_OP_LOWERING(DivFOp, IREEInterp::HL::DivFOp);
BINARY_OP_LOWERING(SignedRemIOp, IREEInterp::HL::RemISOp);
BINARY_OP_LOWERING(UnsignedRemIOp, IREEInterp::HL::RemIUOp);
BINARY_OP_LOWERING(RemFOp, IREEInterp::HL::RemFOp);
// BINARY_OP_LOWERING(MulAddIOp, IREEInterp::HL::MulAddIOp);
// BINARY_OP_LOWERING(MulAddFOp, IREEInterp::HL::MulAddFOp);
// UNARY_OP_LOWERING(ExpFOp, IREEInterp::HL::ExpFOp);
// UNARY_OP_LOWERING(LogFOp, IREEInterp::HL::LogFOp);
// UNARY_OP_LOWERING(RsqrtFOp, IREEInterp::HL::RsqrtFOp);
// UNARY_OP_LOWERING(SqrtFOp, IREEInterp::HL::SqrtFOp);
// UNARY_OP_LOWERING(CosFOp, IREEInterp::HL::CosFOp);
// UNARY_OP_LOWERING(SinFOp, IREEInterp::HL::SinFOp);
// UNARY_OP_LOWERING(TanhFOp, IREEInterp::HL::TanhFOp);
// UNARY_OP_LOWERING(Atan2FOp, IREEInterp::HL::Atan2FOp);
// BINARY_OP_LOWERING(MinISOp, IREEInterp::HL::MinISOp);
// BINARY_OP_LOWERING(MinIUOp, IREEInterp::HL::MinIUOp);
// BINARY_OP_LOWERING(MinFOp, IREEInterp::HL::MinFOp);
// BINARY_OP_LOWERING(MaxISOp, IREEInterp::HL::MaxISOp);
// BINARY_OP_LOWERING(MaxIUOp, IREEInterp::HL::MaxIUOp);
// BINARY_OP_LOWERING(MaxFOp, IREEInterp::HL::MaxFOp);
// TERNARY_OP_LOWERING(ClampFOp, IREEInterp::HL::ClampFOp);
// UNARY_OP_LOWERING(FloorFOp, IREEInterp::HL::FloorFOp);
// UNARY_OP_LOWERING(CeilFOp, IREEInterp::HL::CeilFOp);
} // namespace
void populateLowerStdToInterpreterPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<
// Control flow.
CallOpLowering, CallIndirectOpLowering, ReturnOpLowering,
BranchOpLowering, CondBranchOpLowering, CmpIOpLowering, CmpFOpLowering,
// Memory management.
AllocOpLowering, DeallocOpLowering, LoadOpLowering, StoreOpLowering,
// Shape operations.
DimOpLowering,
// Logical ops.
AndOpLowering, OrOpLowering,
// Arithmetic ops.
AddIOpLowering, AddFOpLowering, SubIOpLowering, SubFOpLowering,
MulIOpLowering, MulFOpLowering, SignedDivIOpLowering,
UnsignedDivIOpLowering, DivFOpLowering, RemFOpLowering,
SignedRemIOpLowering, UnsignedRemIOpLowering>(ctx);
}
} // namespace iree_compiler
} // namespace mlir