// blob: 1b65ede2d249faf690c02f5f8f7cd73bae0d161c [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/IR/Ops.h"
#include "compiler/IR/StructureOps.h"
#include "compiler/Utils/MemRefUtils.h"
#include "compiler/Utils/TypeConversionUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Utils.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Sets every attribute found on |oldOp| on |newOp| under the same name.
void copyOperationAttrs(Operation *oldOp, Operation *newOp) {
  for (const auto &namedAttr : oldOp->getAttrs()) {
    const auto &attrName = namedAttr.first;
    const auto &attrValue = namedAttr.second;
    newOp->setAttr(attrName, attrValue);
  }
}
// Builds the function type obtained by running every input and result type of
// |type| through convertTypeToMemRef.
//
// Returns a null FunctionType if any input or result type fails to convert.
FunctionType getMemRefFunctionType(FunctionType type) {
  // Converts each entry of |sourceTypes| and appends it to |convertedTypes|.
  // Stops and returns false at the first entry that cannot be converted.
  auto convertAll = [](ArrayRef<Type> sourceTypes,
                       llvm::SmallVectorImpl<Type> &convertedTypes) {
    for (Type sourceType : sourceTypes) {
      auto memRefType = convertTypeToMemRef(sourceType);
      if (!memRefType) return false;
      convertedTypes.push_back(memRefType);
    }
    return true;
  };

  llvm::SmallVector<Type, 8> replacementInputs;
  if (!convertAll(type.getInputs(), replacementInputs)) {
    return nullptr;
  }

  llvm::SmallVector<Type, 8> replacementResults;
  if (!convertAll(type.getResults(), replacementResults)) {
    return nullptr;
  }

  Builder builder(type.getContext());
  return builder.getFunctionType(replacementInputs, replacementResults);
}
// Records in |mapping| the value that should stand in for |oldArg| inside the
// converted function, derived from the replacement argument |newArg|:
//   - memref typed |oldArg|: mapped straight to |newArg|;
//   - tensor typed |oldArg|: mapped to the result of a MemRefToTensorOp;
//   - anything else: mapped to a LoadOp (with no indices) of |newArg|.
// New ops are created through |builder| at the location of the region that
// owns |oldArg|'s block.
//
// Always returns false; the callers in this file treat true as failure.
bool insertLoad(BlockArgument *oldArg, BlockArgument *newArg,
                OpBuilder &builder, BlockAndValueMapping *mapping) {
  auto loc = oldArg->getOwner()->getParent()->getLoc();
  auto oldType = oldArg->getType();

  if (oldType.isa<MemRefType>()) {
    // Nothing to unwrap, but the remap is still required so that use lists
    // match through conversion.
    mapping->map(oldArg, newArg);
  } else if (oldType.isa<TensorType>()) {
    auto castOp = builder.create<IREE::MemRefToTensorOp>(loc, newArg);
    mapping->map(oldArg, castOp.getResult());
  } else {
    // Unbox the value with a load.
    auto loadedValue =
        builder.create<LoadOp>(loc, newArg, ArrayRef<Value *>{});
    mapping->map(oldArg, loadedValue);
  }
  return false;
}
// Records in |mapping| the value that should stand in for |oldValue|, derived
// from its replacement |newValue|:
//   - memref typed |oldValue|: mapped straight to |newValue|;
//   - tensor typed |oldValue|: mapped to the result of a MemRefToTensorOp;
//   - anything else: |newValue| is asserted to be a memref and |oldValue| is
//     mapped to a LoadOp (with no indices) of it.
// New ops are created through |builder| at |oldOp|'s location.
//
// Always returns false; the callers in this file treat true as failure.
bool insertLoad(Operation *oldOp, Value *oldValue, Value *newValue,
                OpBuilder &builder, BlockAndValueMapping *mapping) {
  auto oldType = oldValue->getType();

  if (oldType.isa<MemRefType>()) {
    // Already a memref: nothing to unwrap.
    mapping->map(oldValue, newValue);
  } else if (oldType.isa<TensorType>()) {
    auto castOp =
        builder.create<IREE::MemRefToTensorOp>(oldOp->getLoc(), newValue);
    mapping->map(oldValue, castOp.getResult());
  } else {
    assert(newValue->getType().isa<MemRefType>());
    // Unbox the value with a load.
    auto loadedValue = builder.create<LoadOp>(oldOp->getLoc(), newValue,
                                              ArrayRef<Value *>{});
    mapping->map(oldValue, loadedValue);
  }
  return false;
}
// Produces the memref form of |oldValue| for use as an operand of the op that
// replaces |oldOp|.
//
// The value |oldValue| maps to in |mapping| is looked up first; if there is
// none, nullptr is returned. Otherwise:
//   - memref typed |oldValue|: the mapped value is returned unchanged;
//   - tensor typed |oldValue|: the result of a TensorToMemRefOp is returned;
//   - if resolveValueToSourceMemRef finds a source memref for |oldValue|, the
//     mapping of that memref is returned (which may itself be null);
//   - otherwise a new AllocOp is created, the mapped value is stored into it
//     with a StoreOp (no indices), and the allocation is returned.
Value *insertStore(Operation *oldOp, Value *oldValue, OpBuilder &builder,
                   BlockAndValueMapping *mapping) {
  auto *mappedValue = mapping->lookupOrNull(oldValue);
  if (!mappedValue) return nullptr;

  auto oldType = oldValue->getType();

  // Values that were already memrefs pass through untouched.
  // TODO(benvanik): ensure indices make sense.
  if (oldType.isa<MemRefType>()) return mappedValue;

  if (oldType.isa<TensorType>()) {
    return builder
        .create<IREE::TensorToMemRefOp>(oldOp->getLoc(), mappedValue)
        .getResult();
  }

  // Prefer reusing the memref the value was originally loaded from.
  if (auto *sourceMemRef = resolveValueToSourceMemRef(oldValue, oldOp)) {
    return mapping->lookupOrNull(sourceMemRef);
  }

  // No existing storage: allocate a memref and box the value into it.
  auto boxedStorage =
      builder.create<AllocOp>(oldOp->getLoc(), convertTypeToMemRef(oldType));
  builder.create<StoreOp>(oldOp->getLoc(), mappedValue, boxedStorage,
                          ArrayRef<Value *>{});
  return boxedStorage;
}
// Recreates the call |oldOp| through |builder| using the memref calling
// convention:
//   1. every operand is boxed with insertStore;
//   2. every result type is converted with convertTypeToMemRef;
//   3. a new CallOp to the same callee is created and |oldOp|'s attributes are
//      copied onto it;
//   4. every result is unboxed with insertLoad, which records the replacement
//      for the old result in |mapping|.
//
// Returns true on failure and false on success.
bool convertCallOp(CallOp *oldOp, OpBuilder &builder,
                   BlockAndValueMapping *mapping) {
  llvm::SmallVector<Value *, 4> newArgs;
  for (auto *oldArg : oldOp->getOperands()) {
    auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
    if (!newArg) {
      return true;
    }
    newArgs.push_back(newArg);
  }

  SmallVector<Type, 4> resultTypes;
  for (auto oldType : oldOp->getOperation()->getResultTypes()) {
    // A failed conversion yields a null type; it must not be handed to the
    // CallOp builder, matching the checks done for function signatures and
    // block arguments elsewhere in this file.
    auto memRefType = convertTypeToMemRef(oldType);
    if (!memRefType) {
      oldOp->emitError("Unable to convert call result type to memref");
      return true;
    }
    resultTypes.push_back(memRefType);
  }

  auto newOp = builder.create<CallOp>(oldOp->getLoc(), oldOp->getCallee(),
                                      resultTypes, newArgs);
  copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());

  for (unsigned i = 0, e = newOp.getNumResults(); i < e; ++i) {
    auto *oldResult = oldOp->getResult(i);
    auto *newResult = newOp.getResult(i);
    if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder,
                   mapping)) {
      return true;
    }
  }
  return false;
}
// Conversion hook for CallIndirectOp. Not implemented yet: it always emits an
// error on |oldOp| and returns true, which the callers in this file treat as
// failure. |builder| and |mapping| are not used by the live code path.
bool convertCallIndirectOp(CallIndirectOp *oldOp, OpBuilder &builder,
BlockAndValueMapping *mapping) {
// TODO(benvanik): support wrapping callee values.
oldOp->emitError("CallIndirectOp not yet supported");
return true;
// Everything between #if 0 and #endif is compiled out. It follows the same
// shape as convertCallOp: box the operands with insertStore, recreate the
// call, copy attributes, then unbox each result with insertLoad.
#if 0
llvm::SmallVector<Value *, 4> newArgs;
for (auto *oldArg : oldOp->getArgOperands()) {
auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
if (!newArg) {
return true;
}
newArgs.push_back(newArg);
}
auto newOp = builder.create<CallIndirectOp>(oldOp->getLoc(),
oldOp->getCallee(), newArgs);
copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
for (int i = 0; i < newOp.getNumResults(); ++i) {
auto *oldResult = oldOp->getResult(i);
auto *newResult = newOp.getResult(i);
if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder,
mapping)) {
return true;
}
}
return false;
#endif // 0
}
// Clones the return-like op |oldOp| through |builder| with each operand
// replaced by its boxed form from insertStore.
//
// The clone uses a mapping local to this function that holds only the
// operand replacements; |mapping| itself is only consulted by insertStore.
// Returns true if any operand could not be boxed, false otherwise.
bool convertReturnOp(Operation *oldOp, OpBuilder &builder,
                     BlockAndValueMapping *mapping) {
  BlockAndValueMapping operandMapping;
  for (auto *operand : oldOp->getOperands()) {
    if (auto *boxedOperand = insertStore(oldOp, operand, builder, mapping)) {
      operandMapping.map(operand, boxedOperand);
      continue;
    }
    return true;
  }
  builder.clone(*oldOp, operandMapping);
  return false;
}
// Recreates the unconditional branch |oldOp| through |builder|: operands are
// boxed with insertStore, the destination block is looked up in |mapping|,
// and |oldOp|'s attributes are copied onto the new BranchOp.
//
// Returns true if an operand could not be boxed or if the destination block
// has no mapping (an error is emitted on |oldOp| in the latter case); returns
// false on success.
bool convertBranchOp(BranchOp *oldOp, OpBuilder &builder,
                     BlockAndValueMapping *mapping) {
  Operation *rawOldOp = oldOp->getOperation();

  llvm::SmallVector<Value *, 4> boxedOperands;
  for (auto *operand : oldOp->getOperands()) {
    auto *boxedOperand = insertStore(rawOldOp, operand, builder, mapping);
    if (!boxedOperand) return true;
    boxedOperands.push_back(boxedOperand);
  }

  auto *targetBlock = mapping->lookupOrNull(oldOp->getDest());
  if (!targetBlock) {
    oldOp->emitError("Destination block mapping not found");
    return true;
  }

  auto newBranch =
      builder.create<BranchOp>(oldOp->getLoc(), targetBlock, boxedOperands);
  copyOperationAttrs(rawOldOp, newBranch.getOperation());
  return false;
}
// Recreates the conditional branch |oldOp| through |builder|:
//   - the true and false operand lists are boxed with insertStore;
//   - both destination blocks and the condition value are looked up in
//     |mapping|;
//   - a new CondBranchOp is created and |oldOp|'s attributes are copied onto
//     it.
//
// Returns true on failure and false on success. When a destination block or
// the condition has no mapping an error is emitted on |oldOp|.
bool convertCondBranchOp(CondBranchOp *oldOp, OpBuilder &builder,
                         BlockAndValueMapping *mapping) {
  llvm::SmallVector<Value *, 4> trueArgs;
  for (auto *oldArg : oldOp->getTrueOperands()) {
    auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
    if (!newArg) {
      return true;
    }
    trueArgs.push_back(newArg);
  }

  llvm::SmallVector<Value *, 4> falseArgs;
  for (auto *oldArg : oldOp->getFalseOperands()) {
    auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping);
    if (!newArg) {
      return true;
    }
    falseArgs.push_back(newArg);
  }

  auto *trueDest = mapping->lookupOrNull(oldOp->getTrueDest());
  if (!trueDest) {
    oldOp->emitError("True destination block mapping not found");
    return true;
  }
  auto *falseDest = mapping->lookupOrNull(oldOp->getFalseDest());
  if (!falseDest) {
    oldOp->emitError("False destination block mapping not found");
    return true;
  }

  // Lowering will take care of the condition store.
  auto *newCondition = mapping->lookupOrNull(oldOp->getCondition());
  if (!newCondition) {
    oldOp->emitError("Condition value mapping not found");
    // This is a failure like the ones above: report it as true so the caller
    // stops instead of continuing without the branch having been created.
    return true;
  }

  auto newOp = builder.create<CondBranchOp>(
      oldOp->getLoc(), newCondition, trueDest, trueArgs, falseDest, falseArgs);
  copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
  return false;
}
// Converts a single op into the new function body via |builder|.
//
// Calls, indirect calls, returns (standard and IREE) and branches are routed
// to their dedicated converters. Every other op, constants included, is
// cloned with its operands remapped through |mapping|.
//
// Returns true on failure and false on success.
bool convertOperation(Operation *oldOp, OpBuilder &builder,
                      BlockAndValueMapping *mapping) {
  if (auto callOp = dyn_cast<CallOp>(oldOp)) {
    return convertCallOp(&callOp, builder, mapping);
  }
  if (auto callIndirectOp = dyn_cast<CallIndirectOp>(oldOp)) {
    return convertCallIndirectOp(&callIndirectOp, builder, mapping);
  }
  if (isa<ReturnOp>(oldOp) || isa<IREE::ReturnOp>(oldOp)) {
    return convertReturnOp(oldOp, builder, mapping);
  }
  if (auto branchOp = dyn_cast<BranchOp>(oldOp)) {
    return convertBranchOp(&branchOp, builder, mapping);
  }
  if (auto condBranchOp = dyn_cast<CondBranchOp>(oldOp)) {
    return convertCondBranchOp(&condBranchOp, builder, mapping);
  }

  // No special handling needed: clone as-is.
  builder.clone(*oldOp, *mapping);
  return false;
}
// Fills the body of |newFunc| with a converted copy of |oldFunc|'s body.
//
// Works in two phases:
//   1. For every block of |oldFunc| a block is created in |newFunc| whose
//      arguments are the memref-converted argument types; insertLoad records
//      what replaces each old argument, and the old block is mapped to the
//      new one.
//   2. Every op of every old block is converted with convertOperation,
//      appending to the end of the corresponding new block.
//
// Returns true on failure and false on success.
bool convertFunction(FuncOp oldFunc, FuncOp newFunc) {
  OpBuilder builder(newFunc.getBody());
  BlockAndValueMapping mapping;

  // Phase 1: set up all blocks first so that branches can refer to blocks
  // that appear later in the function.
  newFunc.getBlocks().clear();
  for (auto &srcBlock : oldFunc.getBlocks()) {
    auto *dstBlock = builder.createBlock(&newFunc.getBody());
    for (auto *srcArg : srcBlock.getArguments()) {
      // Each block argument becomes a memref.
      auto boxedType = convertTypeToMemRef(srcArg->getType());
      if (!boxedType) return true;
      auto *dstArg = dstBlock->addArgument(boxedType);

      // Map the old argument to an unwrapped value where needed so the block
      // contents can keep using the original (non-memref) type.
      const bool loadFailed = insertLoad(srcArg, dstArg, builder, &mapping);
      if (loadFailed) return true;
    }
    mapping.map(&srcBlock, dstBlock);
  }

  // Phase 2: convert the ops block by block.
  for (auto &srcBlock : oldFunc.getBlocks()) {
    builder.setInsertionPointToEnd(mapping.lookupOrNull(&srcBlock));
    for (auto &srcOp : srcBlock.getOperations()) {
      const bool opFailed = convertOperation(&srcOp, builder, &mapping);
      if (opFailed) return true;
    }
  }
  return false;
}
} // namespace
// Module pass that replaces every function in the module with a version whose
// signature has been run through getMemRefFunctionType and whose body has
// been converted by convertFunction.
class ConvertToMemRefCallingConventionPass
    : public ModulePass<ConvertToMemRefCallingConventionPass> {
 public:
  // Converts all functions, then swaps them into the module. If any function
  // fails to convert the pass is marked as failed and the module keeps its
  // original functions.
  void runOnModule() override {
    auto module = getModule();

    // Build a list of (oldFunc, newFunc) for all functions we need to
    // replace. This will ensure that when we go to convert function bodies we
    // have only new functions defined.
    std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions;

    // The replacement functions stay detached from the module until the final
    // swap below. On failure nothing else owns them, so they have to be
    // erased here or they would leak (and any partially converted body would
    // keep referring to values of the old functions).
    auto failAndDiscard = [&]() {
      for (auto &pair : convertedFunctions) {
        pair.second.erase();
      }
      convertedFunctions.clear();
      signalPassFailure();
    };

    for (auto oldFunc : module.getOps<FuncOp>()) {
      // Create the replacement function, ensuring that we copy attributes.
      auto functionType = getMemRefFunctionType(oldFunc.getType());
      if (!functionType) {
        return failAndDiscard();
      }
      auto newFunc = FuncOp::create(oldFunc.getLoc(), oldFunc.getName(),
                                    functionType, oldFunc.getDialectAttrs());
      convertedFunctions.push_back({oldFunc, newFunc});

      // Perform the actual body conversion now.
      if (convertFunction(oldFunc, newFunc)) {
        return failAndDiscard();
      }
    }

    // Replace functions in the module.
    for (auto &pair : convertedFunctions) {
      pair.first.erase();
      module.push_back(pair.second);
    }
  }
};
// Factory returning a newly allocated ConvertToMemRefCallingConventionPass as
// a module-level pass.
std::unique_ptr<OpPassBase<ModuleOp>>
createConvertToMemRefCallingConventionPass() {
  using PassType = ConvertToMemRefCallingConventionPass;
  return std::make_unique<PassType>();
}
// Static registration object for ConvertToMemRefCallingConventionPass. It
// associates the pass with the argument string
// "convert-to-memref-calling-convention" and the description that follows.
static PassRegistration<ConvertToMemRefCallingConventionPass> pass(
"convert-to-memref-calling-convention",
"Convert functions to use a memref-based calling convention.");
} // namespace iree_compiler
} // namespace mlir