| // Copyright 2019 Google LLC | 
 | // | 
 | // Licensed under the Apache License, Version 2.0 (the "License"); | 
 | // you may not use this file except in compliance with the License. | 
 | // You may obtain a copy of the License at | 
 | // | 
 | //      https://www.apache.org/licenses/LICENSE-2.0 | 
 | // | 
 | // Unless required by applicable law or agreed to in writing, software | 
 | // distributed under the License is distributed on an "AS IS" BASIS, | 
 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
 | // See the License for the specific language governing permissions and | 
 | // limitations under the License. | 
 |  | 
 | #include "integrations/tensorflow/compiler/Passes.h" | 
 | #include "iree/base/signature_mangle.h" | 
 | #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" | 
 | #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" | 
 | #include "iree/compiler/Dialect/IREE/IR/IREEDialect.h" | 
 | #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" | 
 | #include "llvm/ADT/PostOrderIterator.h" | 
 | #include "llvm/ADT/STLExtras.h" | 
 | #include "mlir/IR/Attributes.h" | 
 | #include "mlir/IR/Dialect.h" | 
 | #include "mlir/IR/MLIRContext.h" | 
 | #include "mlir/IR/RegionGraphTraits.h" | 
 | #include "mlir/IR/SymbolTable.h" | 
 | #include "mlir/Pass/Pass.h" | 
 | #include "mlir/Pass/PassRegistry.h" | 
 | #include "mlir/Support/LLVM.h" | 
 | #include "mlir/Support/LogicalResult.h" | 
 | #include "mlir/Transforms/Utils.h" | 
 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" | 
 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" | 
 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" | 
 |  | 
 | namespace mlir { | 
 | namespace iree_compiler { | 
 |  | 
 | static LogicalResult rewriteTfResourceOpToFlowOp(Operation &op, Value flowPtr) { | 
 |   if (auto readVariable = dyn_cast<TF::ReadVariableOp>(op)) { | 
 |     auto load = | 
 |         OpBuilder(readVariable) | 
 |             .create<IREE::Flow::VariableLoadIndirectOp>( | 
 |                 readVariable.getLoc(), readVariable.value().getType(), flowPtr); | 
 |     readVariable.value().replaceAllUsesWith(load.result()); | 
 |     readVariable.erase(); | 
 |   } else if (auto assignVariable = dyn_cast<TF::AssignVariableOp>(op)) { | 
 |     OpBuilder(assignVariable) | 
 |         .create<IREE::Flow::VariableStoreIndirectOp>( | 
 |             assignVariable.getLoc(), assignVariable.value(), flowPtr); | 
 |     assignVariable.erase(); | 
 |   } else { | 
 |     return op.emitError() << "could not lower resource op to flow: " | 
 |                           << op.getName(); | 
 |   } | 
 |   return success(); | 
 | } | 
 |  | 
 | static LogicalResult importTfSavedModelGlobalTensorsToIREEFlow( | 
 |     ModuleOp module) { | 
 |   OpBuilder globalBuilder(module.getBodyRegion()); | 
 |   SymbolTable symbolTable(module); | 
 |  | 
 |   if (auto sessionInitializer = tf_saved_model::GetSessionInitializerOp(module)) | 
 |     return sessionInitializer.emitError() | 
 |            << "Session initializer is not supported yet"; | 
 |  | 
 |   DenseMap<StringRef, std::string> symNameToFlowSymName; | 
 |   for (auto globalTensor : module.getOps<tf_saved_model::GlobalTensorOp>()) { | 
 |     auto exportedNames = tf_saved_model::GetExportedNames(globalTensor); | 
 |     std::string flowSymName; | 
 |     if (exportedNames.empty()) { | 
 |       flowSymName = "__iree_flow_" + globalTensor.sym_name().str(); | 
 |     } else if (exportedNames.size() == 1) { | 
 |       flowSymName = exportedNames[0].str(); | 
 |     } else { | 
 |       return globalTensor.emitError() | 
 |              << "Multiple exported names for global tensor not supported yet"; | 
 |     } | 
 |     symNameToFlowSymName[globalTensor.sym_name()] = flowSymName; | 
 |     auto variableOp = globalBuilder.create<IREE::Flow::VariableOp>( | 
 |         globalTensor.getLoc(), flowSymName, globalTensor.is_mutable(), | 
 |         globalTensor.type(), globalTensor.value()); | 
 |     SymbolTable::setSymbolVisibility(variableOp, | 
 |                                      SymbolTable::Visibility::Private); | 
 |   } | 
 |  | 
 |   // TODO(silvasean): Make this conversion interprocedural. | 
 |   for (auto func : module.getOps<FuncOp>()) { | 
 |     if (!tf_saved_model::IsExported(func)) { | 
 |       continue; | 
 |     } | 
 |     SmallVector<unsigned, 4> argsToErase; | 
 |     OpBuilder builder(func.getBody()); | 
 |     SmallVector<Value, 8> typeConversionWorklist; | 
 |     for (int i = 0, e = func.getNumArguments(); i < e; i++) { | 
 |       auto globalTensor = tf_saved_model::LookupBoundInputOfType< | 
 |           tf_saved_model::GlobalTensorOp>(func, i, symbolTable); | 
 |       if (!globalTensor) { | 
 |         continue; | 
 |       } | 
 |       auto variableAddressOp = builder.create<IREE::Flow::VariableAddressOp>( | 
 |           globalTensor.getLoc(), IREE::PtrType::get(globalTensor.type()), | 
 |           builder.getSymbolRefAttr( | 
 |               symNameToFlowSymName[globalTensor.sym_name()])); | 
 |       typeConversionWorklist.push_back(variableAddressOp.getResult()); | 
 |       func.getArgument(i).replaceAllUsesWith(variableAddressOp.getResult()); | 
 |       argsToErase.push_back(i); | 
 |     } | 
 |     func.eraseArguments(argsToErase); | 
 |     Dialect *ireeFlowDialect = | 
 |         func.getContext()->getLoadedDialect<IREE::Flow::FlowDialect>(); | 
 |     while (!typeConversionWorklist.empty()) { | 
 |       Value v = typeConversionWorklist.pop_back_val(); | 
 |       Type desiredType = v.getType(); | 
 |       for (OpOperand &use : llvm::make_early_inc_range(v.getUses())) { | 
 |         Operation *owner = use.getOwner(); | 
 |         // If the user is already in the flow dialect, then everything is ok. | 
 |         if (owner->getDialect() == ireeFlowDialect) { | 
 |           continue; | 
 |         } | 
 |         // If a user is just a terminator passing the value through a successor | 
 |         // operand, propagate through the successor operand. | 
 |         // TODO(silvasean): Handle case of different types in preds. | 
 |         // This would require calculating a common type. | 
 |         // This won't be a problem unless we see IR that effectively phi's | 
 |         // together different resources, which I don't think tensorflow does. | 
 |         if (BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(owner)) { | 
 |           if (auto arg = | 
 |                   branchOp.getSuccessorBlockArgument(use.getOperandNumber())) { | 
 |             if (arg->getType() != desiredType) { | 
 |               arg->setType(desiredType); | 
 |               typeConversionWorklist.push_back(*arg); | 
 |             } | 
 |             continue; | 
 |           } | 
 |         } | 
 |         // Resource types can have subtypes (or lack thereof) and casting | 
 |         // between them is allowed. Here we just pass through. | 
 |         if (auto castOp = dyn_cast<TF::CastOp>(owner)) { | 
 |           assert(v == castOp.x()); | 
 |           castOp.y().replaceAllUsesWith(castOp.x()); | 
 |           castOp.erase(); | 
 |           // The RAUW could have added more uses of `v`, so put it back on the | 
 |           // worklist and process it again. | 
 |           typeConversionWorklist.push_back(v); | 
 |           break; | 
 |         } | 
 |         if (failed(rewriteTfResourceOpToFlowOp(*owner, v))) { | 
 |           return failure(); | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 |  | 
 |   // Erase all the global tensors. | 
 |   for (auto globalTensor : llvm::make_early_inc_range( | 
 |            module.getOps<tf_saved_model::GlobalTensorOp>())) { | 
 |     globalTensor.erase(); | 
 |   } | 
 |   return success(); | 
 | } | 
 |  | 
 | class TFSavedModelLowerGlobalTensors | 
 |     : public PassWrapper<TFSavedModelLowerGlobalTensors, | 
 |                          OperationPass<ModuleOp>> { | 
 |  public: | 
 |   void getDependentDialects(DialectRegistry ®istry) const override { | 
 |     registry.insert<IREE::Flow::FlowDialect, IREEDialect>(); | 
 |   } | 
 |  | 
 |   void runOnOperation() override { | 
 |     if (failed(importTfSavedModelGlobalTensorsToIREEFlow(getOperation()))) { | 
 |       signalPassFailure(); | 
 |     } | 
 |   } | 
 | }; | 
 |  | 
 | std::unique_ptr<OperationPass<ModuleOp>> | 
 | createTFSavedModelLowerGlobalTensors() { | 
 |   return std::make_unique<TFSavedModelLowerGlobalTensors>(); | 
 | } | 
 |  | 
 | static PassRegistration<TFSavedModelLowerGlobalTensors> pass( | 
 |     "iree-tf-saved-model-lower-global-tensors", | 
 |     "Lowers tf_saved_model global tensors to flow dialect."); | 
 |  | 
 | }  // namespace iree_compiler | 
 | }  // namespace mlir |