blob: e2a81d87ad9f6524dbecb85ee5f9accc192ab725 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/llvm/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm/include/llvm/ADT/DenseMap.h"
#include "third_party/llvm/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Dialect/StandardOps/Ops.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Support/LogicalResult.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Utils.h"
#include "third_party/mlir_edge/iree/compiler/IR/Dialect.h"
#include "third_party/mlir_edge/iree/compiler/IR/Ops.h"
#include "third_party/mlir_edge/iree/compiler/IR/Sequencer/HLDialect.h"
#include "third_party/mlir_edge/iree/compiler/IR/Sequencer/HLOps.h"
#include "third_party/mlir_edge/iree/compiler/IR/StructureOps.h"
#include "third_party/mlir_edge/iree/compiler/Utils/MemRefUtils.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Common base for the std -> sequencer lowering patterns in this file.
//
// It is a ConversionPattern that additionally keeps a reference to the
// MemRefTypeConverter so derived patterns can convert types while rewriting.
// Only a reference is stored, so the converter must outlive every pattern
// constructed from it.
class SequencerConversionPattern : public ConversionPattern {
public:
// |operationName| is the name of the op this pattern is rooted on and
// |benefit| is forwarded unchanged to ConversionPattern.
SequencerConversionPattern(StringRef operationName, int benefit,
MLIRContext *context,
MemRefTypeConverter &typeConverter)
: ConversionPattern(operationName, benefit, context),
typeConverter_(typeConverter) {}
protected:
// Type converter shared with the conversion driver; not owned.
MemRefTypeConverter &typeConverter_;
};
// Rewrites a std ConstantOp into an IREE::ConstantOp carrying the same value
// attribute. The new constant's result is wrapped as a tensor and then loaded
// back as a value of the original op's result type, which replaces the
// original result.
struct ConstantOpLowering : public SequencerConversionPattern {
  ConstantOpLowering(MLIRContext *context, MemRefTypeConverter &typeConverter)
      : SequencerConversionPattern(ConstantOp::getOperationName(), 1, context,
                                   typeConverter) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto constantValue = cast<ConstantOp>(op).getValue();

    // Materialize the constant in the IREE dialect and view it as a tensor.
    auto ireeConstant = rewriter.create<IREE::ConstantOp>(loc, constantValue);
    auto tensorValue = wrapAsTensor(ireeConstant.getResult(), op, rewriter);

    // Load it back with the type the original result had so existing users
    // remain type-correct.
    auto originalType = op->getResult(0)->getType();
    auto loadedValue = loadResultValue(loc, originalType, tensorValue, rewriter);
    rewriter.replaceOp(op, {loadedValue});
    return matchSuccess();
  }
};
// Rewrites a std CallOp into an IREESeq::HL::CallOp with the same callee, the
// already-remapped operands, and result types produced by the type converter.
class CallOpLowering : public SequencerConversionPattern {
 public:
  CallOpLowering(MLIRContext *context, MemRefTypeConverter &typeConverter)
      : SequencerConversionPattern(CallOp::getOperationName(), 1, context,
                                   typeConverter) {}

  // Returns matchFailure() (leaving |op| untouched) when the callee's result
  // types cannot be converted, and matchSuccess() after replacing |op|
  // otherwise.
  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto callOp = cast<CallOp>(op);

    // Convert the callee's result types up front. A failure here must not be
    // checked with an assert only: in NDEBUG builds it would be ignored and the
    // replacement call would be built from an incomplete type list.
    SmallVector<Type, 4> convertedResults;
    if (failed(typeConverter_.convertTypes(callOp.getCalleeType().getResults(),
                                           convertedResults))) {
      return matchFailure();
    }

    rewriter.replaceOpWithNewOp<IREESeq::HL::CallOp>(
        op, callOp.getCallee(), convertedResults, operands);
    return matchSuccess();
  }
};
// Rewrites a std CallIndirectOp into an IREESeq::HL::CallIndirectOp built from
// the original callee value and the remapped operand list.
class CallIndirectOpLowering : public SequencerConversionPattern {
 public:
  CallIndirectOpLowering(MLIRContext *context,
                         MemRefTypeConverter &typeConverter)
      : SequencerConversionPattern(CallIndirectOp::getOperationName(), 1,
                                   context, typeConverter) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    // NOTE(review): the callee is taken from the original op while |operands|
    // holds the remapped values; |operands| presumably also contains the
    // callee as its first entry - confirm this matches what the HL builder
    // expects.
    auto callee = cast<CallIndirectOp>(op).getCallee();
    rewriter.replaceOpWithNewOp<IREESeq::HL::CallIndirectOp>(op, callee,
                                                             operands);
    return matchSuccess();
  }
};
// Rewrites a std ReturnOp into an IREESeq::HL::ReturnOp whose operands are the
// remapped operands, each passed through wrapAsMemRef.
struct ReturnOpLowering : public SequencerConversionPattern {
  ReturnOpLowering(MLIRContext *context, MemRefTypeConverter &typeConverter)
      : SequencerConversionPattern(ReturnOp::getOperationName(), 1, context,
                                   typeConverter) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    // One wrapped value per incoming operand, in the same order.
    SmallVector<Value *, 4> memRefOperands;
    memRefOperands.reserve(operands.size());
    for (Value *returned : operands) {
      Value *wrapped = wrapAsMemRef(returned, op, rewriter);
      memRefOperands.push_back(wrapped);
    }

    rewriter.replaceOpWithNewOp<IREESeq::HL::ReturnOp>(op, memRefOperands);
    return matchSuccess();
  }
};
// Rewrites a std BranchOp into an IREESeq::HL::BranchOp that targets the first
// destination block with that destination's operand list.
struct BranchOpLowering : public SequencerConversionPattern {
  BranchOpLowering(MLIRContext *context, MemRefTypeConverter &typeConverter)
      : SequencerConversionPattern(BranchOp::getOperationName(), 1, context,
                                   typeConverter) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> properOperands,
      ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only the first successor and its arguments are used.
    Block *target = destinations.front();
    ArrayRef<Value *> targetOperands = operands.front();
    rewriter.replaceOpWithNewOp<IREESeq::HL::BranchOp>(op, target,
                                                       targetOperands);
    return matchSuccess();
  }
};
// Rewrites a std CondBranchOp into an IREESeq::HL::CondBranchOp. The condition
// is obtained by passing the first proper operand through loadAccessValue; the
// true/false successors and their operands are selected with the HL op's
// trueIndex/falseIndex constants.
struct CondBranchOpLowering : public SequencerConversionPattern {
  CondBranchOpLowering(MLIRContext *context, MemRefTypeConverter &typeConverter)
      : SequencerConversionPattern(CondBranchOp::getOperationName(), 1, context,
                                   typeConverter) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> properOperands,
      ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto *condition =
        loadAccessValue(op->getLoc(), properOperands[0], rewriter);

    // Pick out each successor together with the operands forwarded to it.
    Block *trueDest = destinations[IREESeq::HL::CondBranchOp::trueIndex];
    ArrayRef<Value *> trueOperands =
        operands[IREESeq::HL::CondBranchOp::trueIndex];
    Block *falseDest = destinations[IREESeq::HL::CondBranchOp::falseIndex];
    ArrayRef<Value *> falseOperands =
        operands[IREESeq::HL::CondBranchOp::falseIndex];

    rewriter.replaceOpWithNewOp<IREESeq::HL::CondBranchOp>(
        op, condition, trueDest, trueOperands, falseDest, falseOperands);
    return matchSuccess();
  }
};
// Rewrites a std AllocOp into an IREESeq::HL::AllocHeapOp that has the original
// op's first result type and receives the remapped operands.
class AllocOpLowering : public SequencerConversionPattern {
 public:
  AllocOpLowering(MLIRContext *context, MemRefTypeConverter &typeConverter)
      : SequencerConversionPattern(AllocOp::getOperationName(), 1, context,
                                   typeConverter) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): replace with length computation.
    Type allocatedType = op->getResult(0)->getType();
    rewriter.replaceOpWithNewOp<IREESeq::HL::AllocHeapOp>(op, allocatedType,
                                                          operands);
    return matchSuccess();
  }
};
// Rewrites a std DeallocOp into an IREESeq::HL::DiscardOp applied to the first
// remapped operand.
class DeallocOpLowering : public SequencerConversionPattern {
 public:
  DeallocOpLowering(MLIRContext *context, MemRefTypeConverter &typeConverter)
      : SequencerConversionPattern(DeallocOp::getOperationName(), 1, context,
                                   typeConverter) {}

  PatternMatchResult matchAndRewrite(
      Operation *op, ArrayRef<Value *> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value *discarded = operands.front();
    rewriter.replaceOpWithNewOp<IREESeq::HL::DiscardOp>(op, discarded);
    return matchSuccess();
  }
};
// Appends every std -> sequencer lowering pattern defined in this file to
// |patterns|. Each pattern is constructed with |context| and |converter|.
void populateStdToSequencerConversionPatterns(
    MLIRContext *context, MemRefTypeConverter &converter,
    OwningRewritePatternList &patterns) {
  // Constants.
  patterns.insert<ConstantOpLowering>(context, converter);
  // Control flow.
  patterns.insert<CallOpLowering, CallIndirectOpLowering, ReturnOpLowering,
                  BranchOpLowering, CondBranchOpLowering>(context, converter);
  // Memory management.
  patterns.insert<AllocOpLowering, DeallocOpLowering>(context, converter);
}
} // namespace
// Lowers functions using std.* ops to the IREE HL sequencer dialect and buffer
// view types.
// FuncOp signatures will be updated to use the buffer view type and
// dispatch regions will get iree.bind_input where needed.
//
// Beyond bindings there will be no other changes within dispatchable regions.
// It is up to the downstream dialects to properly use the bindings to map their
// I/O to expected values.
//
// Note that output buffer allocation is required following this pass to either
// elide dispatch results entirely and provide output params or provide both
// while ensuring that the returned value is always sliced from an input. This
// should happen prior to outlining.
class LowerStdToSequencerDialectPass
    : public ModulePass<LowerStdToSequencerDialectPass> {
 public:
  // Runs a partial conversion over the module's top-level functions and
  // signals pass failure if that conversion fails.
  void runOnModule() override {
    auto &context = getContext();

    // Only convert top-level functions, not ones nested in executables.
    std::vector<Operation *> topLevelFuncs;
    for (auto funcOp : getModule().getOps<FuncOp>()) {
      topLevelFuncs.push_back(funcOp);
    }

    // Convert the signature and body of all sequencer functions.
    MemRefTypeConverter converter(&context);
    ConversionTarget target(context);
    // Ops from the IREE dialects plus std load/store are always legal.
    target.addLegalDialect<IREEHLSequencerDialect, IREEDialect>();
    target.addLegalOp<LoadOp, StoreOp>();
    // A function is legal once the converter accepts its signature.
    target.addDynamicallyLegalOp<FuncOp>([&converter](FuncOp funcOp) {
      return converter.isSignatureLegal(funcOp.getType());
    });

    OwningRewritePatternList patterns;
    populateStdToSequencerConversionPatterns(&context, converter, patterns);

    auto status =
        applyPartialConversion(topLevelFuncs, target, patterns, &converter);
    if (failed(status)) {
      signalPassFailure();
    }
  }
};
// Creates a new instance of the std -> IREE HL sequencer lowering pass; the
// caller owns the returned pass.
std::unique_ptr<OpPassBase<ModuleOp>> createLowerStdToSequencerDialectPass() {
  std::unique_ptr<OpPassBase<ModuleOp>> pass =
      std::make_unique<LowerStdToSequencerDialectPass>();
  return pass;
}
// Static registration object for the pass: the first string is the pass
// argument and the second is its description.
static PassRegistration<LowerStdToSequencerDialectPass> pass(
"iree-lower-std-to-sequencer-dialect",
"Lowers std ops to the IREE HL sequencer dialect.");
} // namespace iree_compiler
} // namespace mlir