blob: b364fe0323f2c0b3786afc4b4aa9a01188c387f1 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "integrations/tensorflow/compiler/Passes.h"
#include "iree/base/signature_mangle.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
namespace mlir {
namespace iree_compiler {
// Converts every tf_saved_model.global_tensor in `module` into a
// flow.variable, rewrites each function argument bound to a global tensor
// (via tf_saved_model.bound_input) into flow.variable.load /
// flow.variable.store ops, erases those arguments, and finally erases the
// original global tensor ops.
//
// Name of the generated flow.variable:
//   - no exported names:        "__iree_flow_" + <global tensor sym_name>
//   - exactly one exported name: that exported name
//   - several exported names:   unsupported, an error is emitted.
//
// Returns failure (with a diagnostic on the offending op) when:
//   - a global tensor has more than one exported name,
//   - a function without a body has a bound input argument,
//   - a mutable (resource) bound argument is used by anything other than
//     tf.ReadVariableOp / tf.AssignVariableOp.
// On failure the module may be left partially rewritten.
static LogicalResult importTfSavedModelGlobalTensorsToIREEFlow(
    ModuleOp module) {
  OpBuilder globalBuilder(module.getBodyRegion());
  // Built before any flow.variable is inserted; used to resolve the
  // bound_input symbol references on function arguments.
  SymbolTable symbolTable(module);
  // Keys view the sym_name of each global tensor; they are only used before
  // the global tensors are erased at the end of this function.
  DenseMap<StringRef, std::string> symNameToFlowSymName;

  // Create one flow.variable per global tensor.
  for (auto globalTensor : module.getOps<tf_saved_model::GlobalTensorOp>()) {
    auto exportedNames = tf_saved_model::GetExportedNames(globalTensor);
    std::string flowSymName;
    if (exportedNames.empty()) {
      flowSymName = "__iree_flow_" + globalTensor.sym_name().str();
    } else if (exportedNames.size() == 1) {
      flowSymName = exportedNames[0].str();
    } else {
      return globalTensor.emitError()
             << "Multiple exported names for global tensor not supported yet";
    }
    symNameToFlowSymName[globalTensor.sym_name()] = flowSymName;
    globalBuilder.create<IREE::Flow::VariableOp>(
        globalTensor.getLoc(), flowSymName, globalTensor.is_mutable(),
        globalTensor.type(), globalTensor.value());
  }

  // Rewrite bound arguments into flow.variable accesses.
  for (auto func : module.getOps<FuncOp>()) {
    SmallVector<unsigned, 4> argsToErase;
    for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
      tf_saved_model::GlobalTensorOp globalTensor =
          tf_saved_model::LookupBoundInput(func, i, symbolTable);
      if (!globalTensor) {
        continue;
      }
      // A declaration has no entry block, so there is no argument Value to
      // rewrite; func.getArgument() would be invalid here.
      if (func.isExternal()) {
        return func.emitError()
               << "bound input on function without a body not supported";
      }
      // Every global tensor was registered above. Use find() instead of
      // operator[] so a lookup can never default-insert an empty name.
      auto flowSymNameIt = symNameToFlowSymName.find(globalTensor.sym_name());
      assert(flowSymNameIt != symNameToFlowSymName.end() &&
             "flow.variable must exist for every global tensor");
      argsToErase.push_back(i);
      auto flowSymRef = globalBuilder.getSymbolRefAttr(flowSymNameIt->second);
      Value arg = func.getArgument(i);
      if (globalTensor.is_mutable()) {
        // The value is a tensor<*x!tf.resource> type, which flows into
        // tf.ReadVariableOp/tf.AssignVariableOp.
        // XLA resource functionalization should have canonicalized everything
        // to uses of those two ops in the body of the tf_saved_model exported
        // function.
        // Early-increment iteration: each rewritten op is erased, which
        // removes its use of `arg` from the use list being walked.
        for (OpOperand &operand : llvm::make_early_inc_range(arg.getUses())) {
          if (auto readVariable =
                  dyn_cast<TF::ReadVariableOp>(operand.getOwner())) {
            auto load = OpBuilder(readVariable)
                            .create<IREE::Flow::VariableLoadOp>(
                                readVariable.getLoc(),
                                readVariable.value().getType(), flowSymRef);
            readVariable.value().replaceAllUsesWith(load.result());
            readVariable.erase();
            continue;
          }
          if (auto assignVariable =
                  dyn_cast<TF::AssignVariableOp>(operand.getOwner())) {
            OpBuilder(assignVariable)
                .create<IREE::Flow::VariableStoreOp>(assignVariable.getLoc(),
                                                     assignVariable.value(),
                                                     flowSymRef);
            assignVariable.erase();
            continue;
          }
          return operand.getOwner()->emitError()
                 << "unknown op operating on resource for global tensor : "
                 << operand.getOwner()->getName();
        }
      } else {
        // The value is already a tensor value type. Just RAUW it with a
        // `flow.variable.load` at the start of the function body.
        auto load = OpBuilder(func.getBody())
                        .create<IREE::Flow::VariableLoadOp>(
                            globalTensor.getLoc(), arg.getType(), flowSymRef);
        arg.replaceAllUsesWith(load.result());
      }
    }
    // All uses of the bound arguments were rewritten above.
    func.eraseArguments(argsToErase);
  }

  // Erase all the global tensors; nothing references them anymore.
  for (auto globalTensor : llvm::make_early_inc_range(
           module.getOps<tf_saved_model::GlobalTensorOp>())) {
    globalTensor.erase();
  }
  return success();
}
// Module pass that lowers tf_saved_model global tensors (and the function
// arguments bound to them) to flow.variable ops. It delegates all of the work
// to importTfSavedModelGlobalTensorsToIREEFlow and reports pass failure if
// that conversion fails.
class TFSavedModelLowerGlobalTensors
    : public ModulePass<TFSavedModelLowerGlobalTensors> {
 public:
  void runOnModule() override {
    ModuleOp moduleOp = getModule();
    if (succeeded(importTfSavedModelGlobalTensorsToIREEFlow(moduleOp))) {
      return;
    }
    signalPassFailure();
  }
};
// Creates a new instance of the tf_saved_model global tensor lowering pass.
std::unique_ptr<OpPassBase<ModuleOp>> createTFSavedModelLowerGlobalTensors() {
  std::unique_ptr<OpPassBase<ModuleOp>> lowerPass =
      std::make_unique<TFSavedModelLowerGlobalTensors>();
  return lowerPass;
}
// Registers the pass with the global pass registry under its command-line
// name so it can be invoked from pass pipelines and opt-like tools.
static PassRegistration<TFSavedModelLowerGlobalTensors>
    lowerGlobalTensorsPassRegistration{
        "iree-tf-saved-model-lower-global-tensors",
        "Lowers tf_saved_model global tensors to flow dialect."};
} // namespace iree_compiler
} // namespace mlir