blob: a76eb1392a1fbc931b1ad246d3d676cdb19d6b06 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
namespace {
// Flattens |types| into |newTypes|: every TupleType is expanded recursively
// into its element types (depth-first, left to right) and every non-tuple type
// is appended unchanged, so |newTypes| never contains a TupleType.
void untupleTypes(llvm::ArrayRef<Type> types,
                  llvm::SmallVectorImpl<Type> *newTypes) {
  for (Type candidate : types) {
    if (auto nested = candidate.dyn_cast<TupleType>()) {
      untupleTypes(nested.getTypes(), newTypes);
      continue;
    }
    newTypes->push_back(candidate);
  }
}
// Materializes a value of |type| from new arguments appended to |block|.
// A non-tuple type becomes exactly one block argument. A tuple type is
// expanded recursively into one argument per leaf element, and the elements
// are reassembled with an xla_hlo.tuple op built at |builder|'s current
// insertion point.
ValuePtr processTuple(Type type, Location loc, Block *block,
                      OpBuilder &builder) {
  auto tupleType = type.dyn_cast<TupleType>();
  if (!tupleType) return block->addArgument(type);

  llvm::SmallVector<ValuePtr, 4> elements;
  elements.reserve(tupleType.size());
  for (Type elementType : tupleType.getTypes())
    elements.push_back(processTuple(elementType, loc, block, builder));
  return builder.create<xla_hlo::TupleOp>(loc, tupleType, elements);
}
// Sets every attribute of |oldOp| on |newOp|, overwriting any attribute of the
// same name that |newOp| already carries.
void copyOperationAttrs(Operation *oldOp, Operation *newOp) {
  for (const NamedAttribute &namedAttr : oldOp->getAttrs()) {
    Identifier name = namedAttr.first;
    Attribute value = namedAttr.second;
    newOp->setAttr(name, value);
  }
}
// Appends the flattened form of |value| to |newValues|.
// - A non-tuple value is appended as-is.
// - A tuple-typed value is split with one xla_hlo.get_tuple_element op per
//   element, built at |builder|'s insertion point. Each element is then
//   flattened recursively.
// Returns true on failure. No failure path exists today, but recursive results
// are propagated so that one added later is not silently dropped. |mapping| is
// currently unused here; it is threaded through for the recursive call.
bool recursiveUntuple(ValuePtr value, Location loc, OpBuilder &builder,
                      BlockAndValueMapping *mapping,
                      llvm::SmallVectorImpl<ValuePtr> *newValues) {
  auto tupleType = value->getType().dyn_cast<TupleType>();
  // Non-tuple values are already flat and can be forwarded unchanged.
  if (!tupleType) {
    newValues->push_back(value);
    return false;
  }
  // Use an unsigned index: TupleType::size() is unsigned, and comparing it
  // against a signed int trips -Wsign-compare.
  for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
    auto elementOp = builder.create<xla_hlo::GetTupleElementOp>(
        loc, tupleType.getType(i), value,
        builder.getI32IntegerAttr(static_cast<int32_t>(i)));
    if (recursiveUntuple(elementOp.getResult(), loc, builder, mapping,
                         newValues)) {
      return true;
    }
  }
  return false;
}
// Rebuilds a value of |oldType| from the flat range |*values| by consuming
// elements from its front.
// - A non-tuple type takes exactly one value.
// - A tuple type recursively rebuilds each element type in order and wraps
//   the results in an xla_hlo.tuple op built at |builder|'s insertion point.
// On return, |*values| has been advanced past every consumed element.
ValuePtr recursiveRetuple(Type oldType, Operation::result_range *values,
                          OpBuilder &builder, Location loc) {
  auto tupleType = oldType.dyn_cast<TupleType>();
  if (!tupleType) {
    // Dereferencing begin() of an empty range is undefined behavior. Catch a
    // mismatch between the flattened result count and |oldType| early.
    assert(values->begin() != values->end() &&
           "ran out of flattened values while retupling");
    ValuePtr returnValue = *values->begin();
    *values = values->drop_front();
    return returnValue;
  }
  llvm::SmallVector<ValuePtr, 10> subValues;
  subValues.reserve(tupleType.size());
  for (Type subtype : tupleType.getTypes()) {
    subValues.push_back(recursiveRetuple(subtype, values, builder, loc));
  }
  return builder.create<xla_hlo::TupleOp>(loc, tupleType, subValues)
      .getResult();
}
// Remaps each element of |values| through |mapping| and appends its flattened
// (tuple-free) form to |newValues|. Any get_tuple_element ops that are needed
// are built at |builder|'s insertion point. Returns true as soon as a value
// with no mapping is encountered; otherwise returns false.
template <typename T>
bool untupleAndLookupValues(T values,
                            llvm::SmallVectorImpl<ValuePtr> *newValues,
                            OpBuilder &builder, Location loc,
                            BlockAndValueMapping *mapping) {
  for (auto oldValue : values) {
    auto mapped = mapping->lookupOrNull(oldValue);
    if (!mapped) return true;
    recursiveUntuple(mapped, loc, builder, mapping, newValues);
  }
  return false;
}
// Emits a std.return whose operands are the remapped, flattened operands of
// |op|. Returns true (and emits nothing) if any operand has no mapping.
bool convertReturnOp(ReturnOp *op, OpBuilder &builder,
                     BlockAndValueMapping *mapping) {
  const Location loc = op->getLoc();
  llvm::SmallVector<ValuePtr, 10> flatOperands;
  const bool lookupFailed = untupleAndLookupValues(
      op->getOperands(), &flatOperands, builder, loc, mapping);
  if (!lookupFailed) {
    builder.create<ReturnOp>(loc, flatOperands);
  }
  return lookupFailed;
}
// Replaces a std.call with one whose operands and result types are flattened.
// The new call keeps the original callee and copies the original attributes.
// Each old result is then rebuilt from the new call's flat results (tuples
// via xla_hlo.tuple) and recorded in |mapping|. Returns true if any operand
// has no mapping.
bool convertCallOp(CallOp *oldOp, OpBuilder &builder,
                   BlockAndValueMapping *mapping) {
  llvm::SmallVector<ValuePtr, 4> newArgs;
  if (untupleAndLookupValues(oldOp->getOperands(), &newArgs, builder,
                             oldOp->getLoc(), mapping)) {
    return true;
  }
  SmallVector<Type, 4> originalTypes(oldOp->getOperation()->getResultTypes());
  SmallVector<Type, 4> resultTypes;
  untupleTypes(originalTypes, &resultTypes);
  auto newOp = builder.create<CallOp>(oldOp->getLoc(), oldOp->getCallee(),
                                      resultTypes, newArgs);
  copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
  // recursiveRetuple consumes from the front of |newResults|, so the old
  // results are walked in order against one shared cursor.
  auto newResults = newOp.getResults();
  for (auto oldResult : oldOp->getResults()) {
    auto newResult = recursiveRetuple(oldResult->getType(), &newResults,
                                      builder, oldOp->getLoc());
    mapping->map(oldResult, newResult);
  }
  return false;
}
// Replaces a std.call_indirect with one that has a remapped callee and
// flattened argument operands. Original attributes are copied, and results
// are mapped one-to-one into |mapping|. Returns true if the callee or any
// argument has no mapping.
//
// NOTE(review): results are mapped index-by-index, so this presumably assumes
// the callee's function type has no tuple results. Also confirm how a callee
// type that still has tuple inputs is expected to verify against the
// flattened arguments.
bool convertIndirectCallOp(CallIndirectOp *oldOp, OpBuilder &builder,
                           BlockAndValueMapping *mapping) {
  // The callee is itself an SSA operand. It must be remapped into the new
  // function; reusing oldOp->getCallee() would reference a value in the old
  // body.
  auto newCallee = mapping->lookupOrNull(oldOp->getCallee());
  if (!newCallee) {
    return true;
  }
  // Only untuple the argument operands. getOperands() also contains the
  // callee, which would otherwise be duplicated into the argument list.
  llvm::SmallVector<ValuePtr, 4> newArgs;
  if (untupleAndLookupValues(oldOp->getArgOperands(), &newArgs, builder,
                             oldOp->getLoc(), mapping)) {
    return true;
  }
  auto newOp =
      builder.create<CallIndirectOp>(oldOp->getLoc(), newCallee, newArgs);
  copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
  for (unsigned i = 0, e = newOp.getNumResults(); i < e; ++i) {
    mapping->map(oldOp->getResult(i), newOp.getResult(i));
  }
  return false;
}
// Emits a std.br to the mapped destination block with the remapped, flattened
// operands of |oldOp|, and copies its attributes. Returns true if any operand
// has no mapping.
bool convertBranchOp(BranchOp *oldOp, OpBuilder &builder,
                     BlockAndValueMapping *mapping) {
  Location loc = oldOp->getLoc();
  llvm::SmallVector<ValuePtr, 4> flatOperands;
  bool failed = untupleAndLookupValues(oldOp->getOperands(), &flatOperands,
                                       builder, loc, mapping);
  if (failed) return true;

  Block *target = mapping->lookupOrNull(oldOp->getDest());
  auto branch = builder.create<BranchOp>(loc, target, flatOperands);
  copyOperationAttrs(oldOp->getOperation(), branch.getOperation());
  return false;
}
// Emits a std.cond_br with a remapped condition, the mapped true/false
// destination blocks, and the remapped, flattened true/false operands.
// Attributes are copied from |oldOp|. Returns true if the condition or any
// branch operand has no mapping.
bool convertCondBranchOp(CondBranchOp *oldOp, OpBuilder &builder,
                         BlockAndValueMapping *mapping) {
  // Treat an unmapped condition as a failure, exactly like an unmapped branch
  // operand, instead of building cond_br with a null operand.
  auto newCondition = mapping->lookupOrNull(oldOp->getCondition());
  if (!newCondition) {
    return true;
  }
  llvm::SmallVector<ValuePtr, 4> trueArgs;
  if (untupleAndLookupValues(oldOp->getTrueOperands(), &trueArgs, builder,
                             oldOp->getLoc(), mapping)) {
    return true;
  }
  llvm::SmallVector<ValuePtr, 4> falseArgs;
  if (untupleAndLookupValues(oldOp->getFalseOperands(), &falseArgs, builder,
                             oldOp->getLoc(), mapping)) {
    return true;
  }
  auto newOp = builder.create<CondBranchOp>(
      oldOp->getLoc(), newCondition,
      mapping->lookupOrNull(oldOp->getTrueDest()), trueArgs,
      mapping->lookupOrNull(oldOp->getFalseDest()), falseArgs);
  copyOperationAttrs(oldOp->getOperation(), newOp.getOperation());
  return false;
}
// Converts a single op from the old function body at |builder|'s insertion
// point. Each of return, call, call_indirect, br and cond_br has a dedicated
// converter that flattens tuple operands and results. Every other op is
// copied with builder.clone using |mapping|. Returns true on failure.
bool convertOperation(Operation *op, OpBuilder &builder,
                      BlockAndValueMapping *mapping) {
  if (auto returnOp = dyn_cast<ReturnOp>(op))
    return convertReturnOp(&returnOp, builder, mapping);
  if (auto callOp = dyn_cast<CallOp>(op))
    return convertCallOp(&callOp, builder, mapping);
  if (auto callIndirectOp = dyn_cast<CallIndirectOp>(op))
    return convertIndirectCallOp(&callIndirectOp, builder, mapping);
  if (auto branchOp = dyn_cast<BranchOp>(op))
    return convertBranchOp(&branchOp, builder, mapping);
  if (auto condBranchOp = dyn_cast<CondBranchOp>(op))
    return convertCondBranchOp(&condBranchOp, builder, mapping);

  // No special handling needed: copy the op as-is.
  builder.clone(*op, *mapping);
  return false;
}
// Fills |newFunction|, whose signature is already flattened, with a converted
// copy of |oldFunction|'s body. Returns true on failure.
// - Every attribute except the function type is copied over.
// - Every old block is recreated with one argument per leaf type of each old
//   argument. Tuple-typed arguments are rebuilt with xla_hlo.tuple ops, so
//   the copied body can keep using them.
// - Each op is then converted in block order via convertOperation.
//
// NOTE(review): attributes are copied by name, so any per-argument attributes
// keep their original positions even though tuple arguments were expanded.
// Confirm this is intended.
bool convertFunction(FuncOp oldFunction, FuncOp newFunction) {
  OpBuilder builder(newFunction.getBody());
  BlockAndValueMapping mapping;
  for (const auto &attr : oldFunction.getAttrs()) {
    if (attr.first != oldFunction.getTypeAttrName()) {
      newFunction.setAttr(attr.first, attr.second);
    }
  }
  newFunction.getBlocks().clear();
  // Create every block and its flattened arguments before converting any op,
  // so that branch conversion can resolve successor blocks through |mapping|.
  for (auto &oldBlock : oldFunction.getBlocks()) {
    auto *newBlock = builder.createBlock(&newFunction.getBody());
    for (auto oldArg : oldBlock.getArguments()) {
      ValuePtr newTuple = processTuple(oldArg->getType(), oldFunction.getLoc(),
                                       newBlock, builder);
      if (!newTuple) {
        return true;
      }
      mapping.map(oldArg, newTuple);
    }
    mapping.map(&oldBlock, newBlock);
  }
  // Convert all ops in the blocks.
  for (auto &oldBlock : oldFunction.getBlocks()) {
    builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock));
    for (auto &oldOp : oldBlock.getOperations()) {
      if (convertOperation(&oldOp, builder, &mapping)) {
        return true;
      }
    }
  }
  return false;
}
// Module pass that rewrites every FuncOp so that function signatures, block
// arguments, returns, calls and branches no longer carry tuple types. Tuples
// used inside a body are rebuilt from the flattened values.
class FlattenTuplesInCFGPass : public ModulePass<FlattenTuplesInCFGPass> {
 public:
  void runOnModule() override {
    auto module = getModule();
    Builder builder(module.getContext());
    // Build an (oldFunction, newFunction) pair for every function and convert
    // its body. The new functions are created detached from the module, and
    // the module is only modified after every body has converted
    // successfully. A failure therefore leaves the module untouched.
    std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions;
    for (auto oldFunction : module.getOps<FuncOp>()) {
      auto oldFunctionType = oldFunction.getType();
      llvm::SmallVector<Type, 10> newInputTypes;
      untupleTypes(oldFunctionType.getInputs(), &newInputTypes);
      llvm::SmallVector<Type, 10> newResultTypes;
      untupleTypes(oldFunctionType.getResults(), &newResultTypes);
      auto newFunctionType =
          builder.getFunctionType(newInputTypes, newResultTypes);
      auto newFunction =
          FuncOp::create(oldFunction.getLoc(), oldFunction.getName(),
                         newFunctionType, oldFunction.getDialectAttrs());
      convertedFunctions.push_back({oldFunction, newFunction});
      // Perform the actual body conversion now that we have proper signatures.
      if (convertFunction(oldFunction, newFunction)) {
        // None of the new functions were inserted into the module, so nothing
        // else owns them. Erase them instead of leaking the detached ops.
        for (auto &pair : convertedFunctions) {
          pair.second.erase();
        }
        return signalPassFailure();
      }
    }
    // Replace functions in the module.
    for (auto &pair : convertedFunctions) {
      pair.first.erase();
      module.push_back(pair.second);
    }
  }
};
} // namespace
// Returns a new instance of the pass. The pass removes tuple types from
// function signatures, block arguments, returns, calls and branches.
std::unique_ptr<OpPassBase<ModuleOp>> createFlattenTuplesInCFGPass() {
  std::unique_ptr<OpPassBase<ModuleOp>> pass =
      std::make_unique<FlattenTuplesInCFGPass>();
  return pass;
}
// Registers the pass under its command-line name so that pass pipelines and
// tools can refer to it by name.
static PassRegistration<FlattenTuplesInCFGPass> flattenTuplesInCFGRegistration{
    "iree-flow-flatten-tuples-in-cfg",
    "Convert functions to remove tuples from method signatures and blocks"};
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir