blob: cc97379c20f96f6a2e9945d371234cdc1ef28017 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/llvm/llvm/include/llvm/ADT/DenseSet.h"
#include "third_party/llvm/llvm/include/llvm/Support/Allocator.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Dialect/StandardOps/Ops.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Module.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/mlir_edge/iree/compiler/IR/Dialect.h"
#include "third_party/mlir_edge/iree/compiler/IR/Interpreter/HLDialect.h"
#include "third_party/mlir_edge/iree/compiler/IR/Interpreter/HLOps.h"
#include "third_party/mlir_edge/iree/compiler/IR/Interpreter/LLDialect.h"
#include "third_party/mlir_edge/iree/compiler/IR/Ops.h"
#include "third_party/mlir_edge/iree/compiler/Transforms/ConversionUtils.h"
#include "third_party/mlir_edge/iree/compiler/Utils/MemRefUtils.h"
namespace mlir {
namespace iree_compiler {
/// Lowers std.constant to iree.constant.
///
/// The IREE constant's result is wrapped as a tensor (wrapAsTensor) and then
/// handed to loadResultValue together with the original op's result type. That
/// value replaces the std.constant.
struct ConstantOpLowering : public ConversionPattern {
  explicit ConstantOpLowering(MLIRContext *context)
      : ConversionPattern(ConstantOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Attribute constantValue = cast<ConstantOp>(op).getValue();
    auto ireeConstant = rewriter.create<IREE::ConstantOp>(loc, constantValue);
    auto asTensor = wrapAsTensor(ireeConstant.getResult(), op, rewriter);
    Type originalType = op->getResult(0)->getType();
    rewriter.replaceOp(
        op, {loadResultValue(loc, originalType, asTensor, rewriter)});
    return matchSuccess();
  }
};
/// Lowers std.call to the interpreter HL call op.
///
/// The new call keeps the callee symbol and the result types declared by the
/// callee's function type. The converted operands are forwarded as its
/// arguments.
class CallOpLowering : public ConversionPattern {
 public:
  explicit CallOpLowering(MLIRContext *context)
      : ConversionPattern(CallOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto stdCall = cast<CallOp>(op);
    rewriter.replaceOpWithNewOp<IREEInterp::HL::CallOp>(
        op, stdCall.getCallee(), stdCall.getCalleeType().getResults(),
        operands);
    return matchSuccess();
  }
};
/// Lowers std.call_indirect to the interpreter HL call_indirect op.
///
/// The std op's operand list is (callee, args...). The converted operands are
/// split the same way: operands[0] is the remapped callee and the remainder
/// are the call arguments. Previously the whole operand list, callee included,
/// was passed as the arguments. The unconverted callee value was also read from
/// the original op instead of the remapped one.
/// NOTE(review): this assumes the HL builder takes (callee, args) separately,
/// as its (Value *, ArrayRef<Value *>) signature suggests. Confirm against
/// HLOps.
class CallIndirectOpLowering : public ConversionPattern {
 public:
  explicit CallIndirectOpLowering(MLIRContext *context)
      : ConversionPattern(CallIndirectOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    // std.call_indirect always has its callee as operand 0.
    Value *callee = operands.front();
    ArrayRef<Value *> callArgs = operands.drop_front();
    rewriter.replaceOpWithNewOp<IREEInterp::HL::CallIndirectOp>(op, callee,
                                                                callArgs);
    return matchSuccess();
  }
};
/// Lowers std.return to the interpreter HL return op, forwarding the
/// converted return values unchanged.
struct ReturnOpLowering : public ConversionPattern {
  explicit ReturnOpLowering(MLIRContext *context)
      : ConversionPattern(ReturnOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    ArrayRef<Value *> returnedValues = operands;
    rewriter.replaceOpWithNewOp<IREEInterp::HL::ReturnOp>(op, returnedValues);
    return matchSuccess();
  }
};
/// Lowers std.br to the interpreter HL branch op.
///
/// std.br has a single successor, so only the first destination and its
/// converted block-argument list are used.
struct BranchOpLowering : public ConversionPattern {
  explicit BranchOpLowering(MLIRContext *context)
      : ConversionPattern(BranchOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> properOperands,
      ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
      ConversionPatternRewriter &rewriter) const override {
    Block *successor = destinations.front();
    ArrayRef<Value *> successorArgs = operands.front();
    rewriter.replaceOpWithNewOp<IREEInterp::HL::BranchOp>(op, successor,
                                                          successorArgs);
    return matchSuccess();
  }
};
/// Lowers std.cond_br to the interpreter HL cond_br op.
///
/// The condition (the first proper operand) is passed through loadAccessValue
/// before use. The true/false successors and their block arguments are picked
/// using the HL op's trueIndex/falseIndex constants.
struct CondBranchOpLowering : public ConversionPattern {
  explicit CondBranchOpLowering(MLIRContext *context)
      : ConversionPattern(CondBranchOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> properOperands,
      ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
      ConversionPatternRewriter &rewriter) const override {
    using HLCondBranch = IREEInterp::HL::CondBranchOp;

    auto *condition =
        loadAccessValue(op->getLoc(), properOperands.front(), rewriter);
    rewriter.replaceOpWithNewOp<HLCondBranch>(
        op, condition,
        destinations[HLCondBranch::trueIndex],
        operands[HLCondBranch::trueIndex],
        destinations[HLCondBranch::falseIndex],
        operands[HLCondBranch::falseIndex]);
    return matchSuccess();
  }
};
/// Lowers a std comparison op (SrcOp) to the matching interpreter HL
/// comparison op (DstOp).
///
/// Both converted operands go through loadAccessValue and are then wrapped as
/// memrefs. The std predicate is forwarded as an i32 attribute. DstOp produces
/// a value of the memref type computed from the original result. That value is
/// wrapped as a tensor and passed to loadResultValue with the original result
/// type to produce the replacement value.
template <typename SrcOp, typename DstOp>
struct CompareOpLowering : public ConversionPattern {
  explicit CompareOpLowering(MLIRContext *context)
      : ConversionPattern(SrcOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    // The pattern is rooted on SrcOp's operation name, so this cast always
    // holds. cast<> asserts that invariant; the previous unchecked dyn_cast<>
    // would have dereferenced a null op had it ever failed.
    auto srcOp = cast<SrcOp>(op);
    auto loc = op->getLoc();

    auto lhValue = loadAccessValue(loc, operands[0], rewriter);
    auto rhValue = loadAccessValue(loc, operands[1], rewriter);
    lhValue = wrapAsMemRef(lhValue, op, rewriter);
    rhValue = wrapAsMemRef(rhValue, op, rewriter);

    // TODO(benvanik): map predicate to stable value.
    auto predicate = rewriter.getI32IntegerAttr(
        static_cast<int32_t>(srcOp.getPredicate()));
    auto dstType = getMemRefType(op->getResult(0), rewriter);
    auto midOp =
        rewriter.create<DstOp>(loc, dstType, predicate, lhValue, rhValue);

    auto result = wrapAsTensor(midOp.getResult(), op, rewriter);
    rewriter.replaceOp(
        op, {loadResultValue(loc, op->getResult(0)->getType(), result,
                             rewriter)});
    return this->matchSuccess();
  }
};
/// Integer comparison: std.cmpi -> interpreter HL cmpi.
struct CmpIOpLowering
    : public CompareOpLowering<CmpIOp, IREEInterp::HL::CmpIOp> {
  using CompareOpLowering<CmpIOp, IREEInterp::HL::CmpIOp>::CompareOpLowering;
};
/// Floating-point comparison: std.cmpf -> interpreter HL cmpf.
struct CmpFOpLowering
    : public CompareOpLowering<CmpFOp, IREEInterp::HL::CmpFOp> {
  using CompareOpLowering<CmpFOp, IREEInterp::HL::CmpFOp>::CompareOpLowering;
};
/// Lowers std.alloc to an interpreter HL heap allocation.
///
/// The new op produces the same type as the std.alloc result and receives all
/// converted operands unchanged. NOTE(review): presumably these are the dynamic
/// dimension sizes; confirm.
struct AllocOpLowering : public ConversionPattern {
  explicit AllocOpLowering(MLIRContext *context)
      : ConversionPattern(AllocOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    Type allocatedType = op->getResult(0)->getType();
    // TODO(benvanik): replace with length computation.
    rewriter.replaceOpWithNewOp<IREEInterp::HL::AllocHeapOp>(op, allocatedType,
                                                             operands);
    return matchSuccess();
  }
};
/// Lowers std.dealloc to an interpreter HL discard of the (converted) memref
/// operand.
struct DeallocOpLowering : public ConversionPattern {
  explicit DeallocOpLowering(MLIRContext *context)
      : ConversionPattern(DeallocOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value *releasedMemRef = operands.front();
    rewriter.replaceOpWithNewOp<IREEInterp::HL::DiscardOp>(op, releasedMemRef);
    return matchSuccess();
  }
};
/// Lowers std.extract_element to a std.load from the aggregate viewed as a
/// memref.
///
/// std.extract_element's operand list is (aggregate, indices...). The
/// aggregate goes through loadAccessValue and is wrapped as a memref. The
/// indices are forwarded unchanged. Both are taken from the remapped `operands`
/// the conversion framework provides, matching the other patterns in this
/// file. The original op's operands may still refer to values that are being
/// replaced.
struct ExtractElementOpLowering : public ConversionPattern {
  explicit ExtractElementOpLowering(MLIRContext *context)
      : ConversionPattern(ExtractElementOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value *aggregate = operands.front();
    ArrayRef<Value *> indices = operands.drop_front();
    Value *memRefInput = wrapAsMemRef(
        loadAccessValue(op->getLoc(), aggregate, rewriter), op, rewriter);
    rewriter.replaceOpWithNewOp<LoadOp>(op, memRefInput, indices);
    return matchSuccess();
  }
};
// Helpers that stamp out a `<StdOpType>Lowering` pattern for a std op with a
// direct IREE interpreter counterpart. Each macro expands to a struct deriving
// from the matching Unary/Binary/TernaryOpLowering template (declared outside
// this file) and inheriting that base's constructors.
#define UNARY_OP_LOWERING(StdOpType, IREEOpType)                   \
  struct StdOpType##Lowering                                       \
      : public UnaryOpLowering<StdOpType, IREEOpType> {            \
    using UnaryOpLowering<StdOpType, IREEOpType>::UnaryOpLowering; \
  };
#define BINARY_OP_LOWERING(StdOpType, IREEOpType)                    \
  struct StdOpType##Lowering                                         \
      : public BinaryOpLowering<StdOpType, IREEOpType> {             \
    using BinaryOpLowering<StdOpType, IREEOpType>::BinaryOpLowering; \
  };
#define TERNARY_OP_LOWERING(StdOpType, IREEOpType)                     \
  struct StdOpType##Lowering                                           \
      : public TernaryOpLowering<StdOpType, IREEOpType> {              \
    using TernaryOpLowering<StdOpType, IREEOpType>::TernaryOpLowering; \
  };
// One-to-one std -> interpreter HL op lowerings. Each uncommented invocation
// defines a `<StdOp>Lowering` pattern. A pattern only takes effect once it is
// also listed in LowerStdToInterpreterDialectPass::runOnFunction.
// Commented-out entries are not instantiated. NOTE(review): presumably the
// corresponding std and/or HL ops are not available yet; confirm before
// enabling. Also check the arity labels: e.g. Atan2FOp is listed as UNARY and
// the MulAdd ops as BINARY.
//
// Shape operations.
// UNARY_OP_LOWERING(RankOp, IREEInterp::HL::RankOp);
UNARY_OP_LOWERING(DimOp, IREEInterp::HL::DimOp);
// UNARY_OP_LOWERING(ShapeOp, IREEInterp::HL::ShapeOp);
// UNARY_OP_LOWERING(LengthOp, IREEInterp::HL::LengthOp);
// Logical ops.
// UNARY_OP_LOWERING(NotOp, IREEInterp::HL::NotOp);
BINARY_OP_LOWERING(AndOp, IREEInterp::HL::AndOp);
BINARY_OP_LOWERING(OrOp, IREEInterp::HL::OrOp);
// BINARY_OP_LOWERING(XorOp, IREEInterp::HL::XorOp);
// BINARY_OP_LOWERING(ShiftLeftOp, IREEInterp::HL::ShiftLeftOp);
// BINARY_OP_LOWERING(ShiftRightLogicalOp, IREEInterp::HL::ShiftRightLogicalOp);
// BINARY_OP_LOWERING(ShiftRightArithmeticOp,
// IREEInterp::HL::ShiftRightArithmeticOp);
// Arithmetic ops.
BINARY_OP_LOWERING(AddIOp, IREEInterp::HL::AddIOp);
BINARY_OP_LOWERING(AddFOp, IREEInterp::HL::AddFOp);
BINARY_OP_LOWERING(SubIOp, IREEInterp::HL::SubIOp);
BINARY_OP_LOWERING(SubFOp, IREEInterp::HL::SubFOp);
// UNARY_OP_LOWERING(AbsIOp, IREEInterp::HL::AbsIOp);
// UNARY_OP_LOWERING(AbsFOp, IREEInterp::HL::AbsFOp);
BINARY_OP_LOWERING(MulIOp, IREEInterp::HL::MulIOp);
BINARY_OP_LOWERING(MulFOp, IREEInterp::HL::MulFOp);
BINARY_OP_LOWERING(DivISOp, IREEInterp::HL::DivISOp);
BINARY_OP_LOWERING(DivIUOp, IREEInterp::HL::DivIUOp);
BINARY_OP_LOWERING(DivFOp, IREEInterp::HL::DivFOp);
// BINARY_OP_LOWERING(MulAddIOp, IREEInterp::HL::MulAddIOp);
// BINARY_OP_LOWERING(MulAddFOp, IREEInterp::HL::MulAddFOp);
// UNARY_OP_LOWERING(ExpFOp, IREEInterp::HL::ExpFOp);
// UNARY_OP_LOWERING(LogFOp, IREEInterp::HL::LogFOp);
// UNARY_OP_LOWERING(RsqrtFOp, IREEInterp::HL::RsqrtFOp);
// UNARY_OP_LOWERING(CosFOp, IREEInterp::HL::CosFOp);
// UNARY_OP_LOWERING(SinFOp, IREEInterp::HL::SinFOp);
// UNARY_OP_LOWERING(TanhFOp, IREEInterp::HL::TanhFOp);
// UNARY_OP_LOWERING(Atan2FOp, IREEInterp::HL::Atan2FOp);
// BINARY_OP_LOWERING(MinISOp, IREEInterp::HL::MinISOp);
// BINARY_OP_LOWERING(MinIUOp, IREEInterp::HL::MinIUOp);
// BINARY_OP_LOWERING(MinFOp, IREEInterp::HL::MinFOp);
// BINARY_OP_LOWERING(MaxISOp, IREEInterp::HL::MaxISOp);
// BINARY_OP_LOWERING(MaxIUOp, IREEInterp::HL::MaxIUOp);
// BINARY_OP_LOWERING(MaxFOp, IREEInterp::HL::MaxFOp);
// TERNARY_OP_LOWERING(ClampFOp, IREEInterp::HL::ClampFOp);
// UNARY_OP_LOWERING(FloorFOp, IREEInterp::HL::FloorFOp);
// UNARY_OP_LOWERING(CeilFOp, IREEInterp::HL::CeilFOp);
/// Function pass that lowers std ops to the IREE interpreter dialects using a
/// full dialect conversion.
///
/// After conversion the only legal ops are:
/// - ops of the IREE HL/LL interpreter dialects and the IREE dialect;
/// - std load/store, func and module ops;
/// - std.constant ops whose value is an IntegerAttr of index type.
///
/// Any other op left in the function makes the pass fail.
class LowerStdToInterpreterDialectPass
    : public FunctionPass<LowerStdToInterpreterDialectPass> {
 public:
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();

    // Each insert<> call appends, so splitting the list by category keeps the
    // registration order identical to a single combined insert<>.
    OwningRewritePatternList patterns;
    patterns.insert<ConstantOpLowering>(ctx);
    // Control flow.
    patterns.insert<CallOpLowering, CallIndirectOpLowering, ReturnOpLowering,
                    BranchOpLowering, CondBranchOpLowering, CmpIOpLowering,
                    CmpFOpLowering>(ctx);
    // Memory management.
    patterns.insert<AllocOpLowering, DeallocOpLowering,
                    ExtractElementOpLowering>(ctx);
    // Shape operations.
    patterns.insert<DimOpLowering>(ctx);
    // Logical ops.
    patterns.insert<AndOpLowering, OrOpLowering>(ctx);
    // Arithmetic ops.
    patterns.insert<AddIOpLowering, AddFOpLowering, SubIOpLowering,
                    SubFOpLowering, MulIOpLowering, MulFOpLowering,
                    DivISOpLowering, DivIUOpLowering, DivFOpLowering>(ctx);

    ConversionTarget target(*ctx);
    target.addLegalDialect<IREEHLInterpreterDialect, IREELLInterpreterDialect,
                           IREEDialect>();
    target.addLegalOp<LoadOp, StoreOp, FuncOp, ModuleOp>();
    // std.constant stays legal only for index-typed integers; every other
    // constant has to be rewritten by ConstantOpLowering.
    target.addDynamicallyLegalOp<ConstantOp>([](ConstantOp constantOp) {
      if (!constantOp.getValue().isa<IntegerAttr>()) return false;
      return constantOp.getType().isIndex();
    });

    // A full conversion is only possible because this lowering runs last.
    // Earlier conversions (e.g. XLA->IREE) must stay partial since std ops may
    // legitimately remain at that point.
    // TODO(b/139012931) Move fully to the conversion framework.
    if (failed(applyFullConversion(getFunction(), target, patterns)))
      signalPassFailure();
  }
};
/// Creates a new instance of the pass that lowers std ops to the IREE
/// interpreter dialects. Ownership passes to the caller.
std::unique_ptr<OpPassBase<FuncOp>> createLowerStdToInterpreterDialectPass() {
  std::unique_ptr<OpPassBase<FuncOp>> lowering =
      std::make_unique<LowerStdToInterpreterDialectPass>();
  return lowering;
}
// Registers LowerStdToInterpreterDialectPass with MLIR's static pass registry
// under the argument "lower-std-to-iree-interpreter".
// NOTE(review): the description says "std functions", but the pass rewrites
// std *ops* inside each function into the IREE interpreter dialects. Consider
// rewording the help text.
static PassRegistration<LowerStdToInterpreterDialectPass> pass(
"lower-std-to-iree-interpreter",
"Convert all std functions to the IREE dialect");
} // namespace iree_compiler
} // namespace mlir