Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 1 | // Copyright 2020 Google LLC |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | // Main entry function for the iree-tf-import tool (and derived binaries). |
| 16 | // Note that this is not an e2e tool: it is purely the first stage of the |
| 17 | // pipeline intended to lower TensorFlow GraphDefs and SavedModels to a form |
| 18 | // suitable for input to the IREE compiler. |
| 19 | // |
| 20 | // Since none of the TensorFlow imports come from an MLIR text form, it is a bit |
| 21 | // of an odd fit for a *-translate style tool, which is why this diverges. |
| 22 | |
| 23 | #include "integrations/tensorflow/compiler/Passes.h" |
| 24 | #include "llvm/Support/CommandLine.h" |
| 25 | #include "llvm/Support/ErrorHandling.h" |
| 26 | #include "llvm/Support/InitLLVM.h" |
| 27 | #include "llvm/Support/ToolOutputFile.h" |
| 28 | #include "mlir/IR/AsmState.h" |
| 29 | #include "mlir/IR/BuiltinOps.h" |
| 30 | #include "mlir/IR/Dialect.h" |
| 31 | #include "mlir/IR/MLIRContext.h" |
| 32 | #include "mlir/IR/OperationSupport.h" |
| 33 | #include "mlir/Pass/PassManager.h" |
| 34 | #include "mlir/Support/FileUtilities.h" |
| 35 | #include "tensorflow/cc/saved_model/loader.h" |
| 36 | #include "tensorflow/compiler/mlir/init_mlir.h" |
| 37 | #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" |
| 38 | #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" |
| 39 | #include "tensorflow/core/platform/errors.h" |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 40 | using namespace llvm; |
| 41 | using namespace mlir; |
| 42 | |
namespace {

// The kind of TensorFlow model to import, selected by --tf-import-type.
//
// The enumerator spellings double as the accepted flag values (they are
// registered via clEnumVal in main). The flag declares no cl::init, so its
// default is the zero value; savedmodel_v2 must therefore stay at 0.
enum ImportType {
  savedmodel_v2 = 0,
  savedmodel_v1 = 1,
};

}  // namespace
| 51 | |
| 52 | static OwningModuleRef importSavedModelV2( |
| 53 | MLIRContext &context, const std::string &inputPath, |
| 54 | const std::string &savedModelExportedNames) { |
| 55 | tensorflow::SavedModelV2Bundle bundle; |
| 56 | auto loadStatus = tensorflow::SavedModelV2Bundle::Load(inputPath, &bundle); |
| 57 | if (!loadStatus.ok()) { |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 58 | llvm::errs() << "TensorFlow reported error loading saved model:\n " |
| 59 | << loadStatus.ToString() << "\n\n"; |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 60 | if (!tensorflow::errors::IsNotFound(loadStatus)) { |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 61 | llvm::errs() |
| 62 | << "Note: Attempted to load V2 SavedModel. Double check that " |
| 63 | "this is correct " |
| 64 | << "and adjust via the flag " |
| 65 | "--tf-import-type=savedmodel_v1|savedmodel_v2\n"; |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 66 | } |
| 67 | return nullptr; |
| 68 | } |
| 69 | |
| 70 | std::vector<std::string> exportedNamesVector = |
| 71 | absl::StrSplit(savedModelExportedNames, ',', absl::SkipEmpty()); |
| 72 | auto loadedModule = tensorflow::ConvertSavedModelToMlir( |
| 73 | &bundle, &context, absl::MakeSpan(exportedNamesVector)); |
| 74 | if (!loadedModule.ok()) { |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 75 | llvm::errs() << "Error performing initial import from SavedModel to MLIR. " |
| 76 | << "Reported error below (and see diagnostics):\n" |
| 77 | << " " << loadedModule.status().ToString() << "\n"; |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 78 | return nullptr; |
| 79 | } |
| 80 | |
| 81 | return loadedModule.ConsumeValueOrDie(); |
| 82 | } |
| 83 | |
| 84 | static OwningModuleRef importSavedModelV1( |
| 85 | MLIRContext &context, const std::string &inputPath, |
| 86 | const std::string &savedModelExportedNames, |
| 87 | const std::string &savedModelTags) { |
| 88 | tensorflow::SavedModelBundle bundle; |
| 89 | tensorflow::SessionOptions session_options; |
| 90 | // Force saved model states to be restored to CPU. |
| 91 | (*session_options.config.mutable_device_count())["GPU"] = 0; |
| 92 | |
| 93 | std::unordered_set<std::string> tags = |
| 94 | absl::StrSplit(savedModelTags, ',', absl::SkipEmpty()); |
| 95 | auto loadStatus = |
| 96 | tensorflow::LoadSavedModel(session_options, |
| 97 | /*run_options=*/{}, inputPath, tags, &bundle); |
| 98 | if (!loadStatus.ok()) { |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 99 | llvm::errs() << "TensorFlow reported error loading saved model:\n " |
| 100 | << loadStatus.ToString() << "\n\n"; |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 101 | if (!tensorflow::errors::IsNotFound(loadStatus)) { |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 102 | llvm::errs() |
| 103 | << "Note: Attempted to load V1 SavedModel. Double check that " |
| 104 | "this is correct " |
| 105 | << "and adjust via the flag " |
| 106 | "--tf-import-type=savedmodel_v1|savedmodel_v2\n"; |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 107 | } |
| 108 | return nullptr; |
| 109 | } |
| 110 | |
| 111 | std::vector<std::string> exportedNamesVector = |
| 112 | absl::StrSplit(savedModelExportedNames, ',', absl::SkipEmpty()); |
| 113 | |
| 114 | auto loadedModule = ConvertSavedModelV1ToMlir( |
| 115 | bundle, absl::MakeSpan(exportedNamesVector), &context, |
| 116 | /*upgrade_legacy=*/false); |
| 117 | |
| 118 | if (!loadedModule.ok()) { |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 119 | llvm::errs() << "Error performing initial import from SavedModel to MLIR. " |
| 120 | << "Reported error below (and see diagnostics):\n" |
| 121 | << " " << loadedModule.status().ToString() << "\n"; |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 122 | return nullptr; |
| 123 | } |
| 124 | |
| 125 | return loadedModule.ConsumeValueOrDie(); |
| 126 | } |
| 127 | |
| 128 | int main(int argc, char **argv) { |
| 129 | tensorflow::InitMlir y(&argc, &argv); |
| 130 | |
| 131 | static cl::opt<std::string> inputPath( |
| 132 | cl::Positional, cl::desc("<saved model directory>"), cl::Required); |
| 133 | static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"), |
| 134 | cl::value_desc("filename"), |
| 135 | cl::init("-")); |
| 136 | static cl::opt<ImportType> importType( |
| 137 | "tf-import-type", cl::desc("The type of TensorFlow model to import"), |
| 138 | cl::values(clEnumVal(savedmodel_v2, |
| 139 | "Import a TensorFlow SavedModel V2 (directory)"), |
| 140 | clEnumVal(savedmodel_v1, |
| 141 | "Import a TensorFlow SavedModel V1 (directory)"))); |
| 142 | |
| 143 | static llvm::cl::opt<std::string> savedModelExportedNames( |
| 144 | "tf-savedmodel-exported-names", |
| 145 | llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty " |
| 146 | "(the default) means export all."), |
| 147 | llvm::cl::init("")); |
| 148 | |
| 149 | static llvm::cl::opt<std::string> savedModelTags( |
| 150 | "tf-savedmodel-tags", |
| 151 | llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, " |
| 152 | "separated by ','"), |
| 153 | llvm::cl::init("serve")); |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 154 | static llvm::cl::opt<std::string> saveTempTfInput( |
| 155 | "save-temp-tf-input", |
| 156 | llvm::cl::desc("Save the TF pipeline input to this file"), |
| 157 | llvm::cl::init("")); |
| 158 | static llvm::cl::opt<std::string> saveTempIreeImport( |
| 159 | "save-temp-iree-input", |
| 160 | llvm::cl::desc("Save the resultant IR to this file (useful for saving an " |
| 161 | "intermediate in a pipeline)"), |
| 162 | llvm::cl::init("")); |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 163 | |
| 164 | // Register any command line options. |
| 165 | registerAsmPrinterCLOptions(); |
| 166 | registerMLIRContextCLOptions(); |
| 167 | registerPassManagerCLOptions(); |
| 168 | cl::ParseCommandLineOptions(argc, argv); |
| 169 | |
| 170 | DialectRegistry registry; |
| 171 | RegisterAllTensorFlowDialects(registry); |
| 172 | |
| 173 | MLIRContext context; |
| 174 | OwningModuleRef module; |
| 175 | registry.loadAll(&context); |
| 176 | |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 177 | auto saveToFile = [&](llvm::StringRef savePath) -> LogicalResult { |
| 178 | auto outputFile = openOutputFile(savePath); |
| 179 | if (!outputFile) { |
| 180 | llvm::errs() << "Could not open output file: " << savePath << "\n"; |
| 181 | return failure(); |
| 182 | } |
| 183 | OpPrintingFlags printFlags; |
| 184 | printFlags.enableDebugInfo(); |
| 185 | module->print(outputFile->os(), printFlags); |
| 186 | outputFile->os() << "\n"; |
| 187 | outputFile->keep(); |
| 188 | return success(); |
| 189 | }; |
| 190 | |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 191 | // First stage import. |
| 192 | switch (importType) { |
| 193 | case savedmodel_v2: |
| 194 | module = importSavedModelV2(context, inputPath, savedModelExportedNames); |
| 195 | break; |
| 196 | case savedmodel_v1: |
| 197 | module = importSavedModelV1(context, inputPath, savedModelExportedNames, |
| 198 | savedModelTags); |
| 199 | break; |
| 200 | default: |
| 201 | llvm_unreachable("unsupported import type enum"); |
| 202 | } |
| 203 | if (!module) return 1; |
| 204 | |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 205 | // Save temp output. |
| 206 | if (!saveTempTfInput.empty()) { |
| 207 | if (failed(saveToFile(saveTempTfInput))) return 10; |
| 208 | } |
| 209 | |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 210 | // Run passes. |
| 211 | PassManager pm(&context, PassManager::Nesting::Implicit); |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 212 | applyPassManagerCLOptions(pm); |
| 213 | |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 214 | iree_compiler::TF::buildTFImportPassPipeline(pm); |
| 215 | if (failed(pm.run(*module))) { |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 216 | llvm::errs() |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 217 | << "Running iree-tf-import pass pipeline failed (see diagnostics)\n"; |
| 218 | return 2; |
| 219 | } |
| 220 | |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 221 | // Save temp output. |
| 222 | if (!saveTempIreeImport.empty()) { |
| 223 | if (failed(saveToFile(saveTempIreeImport))) return 10; |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 224 | } |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 225 | |
Stella Laurenzo | f468273 | 2020-12-08 17:31:11 -0800 | [diff] [blame^] | 226 | // Save output. |
| 227 | if (failed(saveToFile(outputFilename))) return 3; |
Stella Laurenzo | b0b03ee | 2020-12-03 12:32:44 -0800 | [diff] [blame] | 228 | return 0; |
| 229 | } |