blob: 62631b8004b55157ebec149ab85ff38b87406715 [file] [log] [blame]
Stella Laurenzob0b03ee2020-12-03 12:32:44 -08001// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Main entry function for the iree-tf-import tool (and derived binaries).
16// Note that this is not an e2e tool: it is purely the first stage of the
17// pipeline intended to lower TensorFlow GraphDefs and SavedModels to a form
18// suitable for input to the IREE compiler.
19//
20// Since none of the TensorFlow imports come from an MLIR text form, it is a bit
21// of an odd fit for a *-translate style tool, which is why this diverges.
22
23#include "integrations/tensorflow/compiler/Passes.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/ErrorHandling.h"
26#include "llvm/Support/InitLLVM.h"
27#include "llvm/Support/ToolOutputFile.h"
28#include "mlir/IR/AsmState.h"
29#include "mlir/IR/BuiltinOps.h"
30#include "mlir/IR/Dialect.h"
31#include "mlir/IR/MLIRContext.h"
32#include "mlir/IR/OperationSupport.h"
33#include "mlir/Pass/PassManager.h"
34#include "mlir/Support/FileUtilities.h"
35#include "tensorflow/cc/saved_model/loader.h"
36#include "tensorflow/compiler/mlir/init_mlir.h"
37#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
38#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
39#include "tensorflow/core/platform/errors.h"
Stella Laurenzob0b03ee2020-12-03 12:32:44 -080040using namespace llvm;
41using namespace mlir;
42
namespace {

// Kinds of TensorFlow input this tool can import. The enumerator identifiers
// are also the literal values accepted by the --tf-import-type flag (they are
// stringified by clEnumVal), so they must not be renamed.
enum ImportType {
  // TensorFlow SavedModel V2 (directory).
  savedmodel_v2 = 0,
  // TensorFlow SavedModel V1 (directory).
  savedmodel_v1 = 1,
};

}  // namespace
51
52static OwningModuleRef importSavedModelV2(
53 MLIRContext &context, const std::string &inputPath,
54 const std::string &savedModelExportedNames) {
55 tensorflow::SavedModelV2Bundle bundle;
56 auto loadStatus = tensorflow::SavedModelV2Bundle::Load(inputPath, &bundle);
57 if (!loadStatus.ok()) {
Stella Laurenzof4682732020-12-08 17:31:11 -080058 llvm::errs() << "TensorFlow reported error loading saved model:\n "
59 << loadStatus.ToString() << "\n\n";
Stella Laurenzob0b03ee2020-12-03 12:32:44 -080060 if (!tensorflow::errors::IsNotFound(loadStatus)) {
Stella Laurenzof4682732020-12-08 17:31:11 -080061 llvm::errs()
62 << "Note: Attempted to load V2 SavedModel. Double check that "
63 "this is correct "
64 << "and adjust via the flag "
65 "--tf-import-type=savedmodel_v1|savedmodel_v2\n";
Stella Laurenzob0b03ee2020-12-03 12:32:44 -080066 }
67 return nullptr;
68 }
69
70 std::vector<std::string> exportedNamesVector =
71 absl::StrSplit(savedModelExportedNames, ',', absl::SkipEmpty());
72 auto loadedModule = tensorflow::ConvertSavedModelToMlir(
73 &bundle, &context, absl::MakeSpan(exportedNamesVector));
74 if (!loadedModule.ok()) {
Stella Laurenzof4682732020-12-08 17:31:11 -080075 llvm::errs() << "Error performing initial import from SavedModel to MLIR. "
76 << "Reported error below (and see diagnostics):\n"
77 << " " << loadedModule.status().ToString() << "\n";
Stella Laurenzob0b03ee2020-12-03 12:32:44 -080078 return nullptr;
79 }
80
81 return loadedModule.ConsumeValueOrDie();
82}
83
84static OwningModuleRef importSavedModelV1(
85 MLIRContext &context, const std::string &inputPath,
86 const std::string &savedModelExportedNames,
87 const std::string &savedModelTags) {
88 tensorflow::SavedModelBundle bundle;
89 tensorflow::SessionOptions session_options;
90 // Force saved model states to be restored to CPU.
91 (*session_options.config.mutable_device_count())["GPU"] = 0;
92
93 std::unordered_set<std::string> tags =
94 absl::StrSplit(savedModelTags, ',', absl::SkipEmpty());
95 auto loadStatus =
96 tensorflow::LoadSavedModel(session_options,
97 /*run_options=*/{}, inputPath, tags, &bundle);
98 if (!loadStatus.ok()) {
Stella Laurenzof4682732020-12-08 17:31:11 -080099 llvm::errs() << "TensorFlow reported error loading saved model:\n "
100 << loadStatus.ToString() << "\n\n";
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800101 if (!tensorflow::errors::IsNotFound(loadStatus)) {
Stella Laurenzof4682732020-12-08 17:31:11 -0800102 llvm::errs()
103 << "Note: Attempted to load V1 SavedModel. Double check that "
104 "this is correct "
105 << "and adjust via the flag "
106 "--tf-import-type=savedmodel_v1|savedmodel_v2\n";
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800107 }
108 return nullptr;
109 }
110
111 std::vector<std::string> exportedNamesVector =
112 absl::StrSplit(savedModelExportedNames, ',', absl::SkipEmpty());
113
114 auto loadedModule = ConvertSavedModelV1ToMlir(
115 bundle, absl::MakeSpan(exportedNamesVector), &context,
116 /*upgrade_legacy=*/false);
117
118 if (!loadedModule.ok()) {
Stella Laurenzof4682732020-12-08 17:31:11 -0800119 llvm::errs() << "Error performing initial import from SavedModel to MLIR. "
120 << "Reported error below (and see diagnostics):\n"
121 << " " << loadedModule.status().ToString() << "\n";
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800122 return nullptr;
123 }
124
125 return loadedModule.ConsumeValueOrDie();
126}
127
128int main(int argc, char **argv) {
129 tensorflow::InitMlir y(&argc, &argv);
130
131 static cl::opt<std::string> inputPath(
132 cl::Positional, cl::desc("<saved model directory>"), cl::Required);
133 static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
134 cl::value_desc("filename"),
135 cl::init("-"));
136 static cl::opt<ImportType> importType(
137 "tf-import-type", cl::desc("The type of TensorFlow model to import"),
138 cl::values(clEnumVal(savedmodel_v2,
139 "Import a TensorFlow SavedModel V2 (directory)"),
140 clEnumVal(savedmodel_v1,
141 "Import a TensorFlow SavedModel V1 (directory)")));
142
143 static llvm::cl::opt<std::string> savedModelExportedNames(
144 "tf-savedmodel-exported-names",
145 llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
146 "(the default) means export all."),
147 llvm::cl::init(""));
148
149 static llvm::cl::opt<std::string> savedModelTags(
150 "tf-savedmodel-tags",
151 llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
152 "separated by ','"),
153 llvm::cl::init("serve"));
Stella Laurenzof4682732020-12-08 17:31:11 -0800154 static llvm::cl::opt<std::string> saveTempTfInput(
155 "save-temp-tf-input",
156 llvm::cl::desc("Save the TF pipeline input to this file"),
157 llvm::cl::init(""));
158 static llvm::cl::opt<std::string> saveTempIreeImport(
159 "save-temp-iree-input",
160 llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
161 "intermediate in a pipeline)"),
162 llvm::cl::init(""));
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800163
164 // Register any command line options.
165 registerAsmPrinterCLOptions();
166 registerMLIRContextCLOptions();
167 registerPassManagerCLOptions();
168 cl::ParseCommandLineOptions(argc, argv);
169
170 DialectRegistry registry;
171 RegisterAllTensorFlowDialects(registry);
172
173 MLIRContext context;
174 OwningModuleRef module;
175 registry.loadAll(&context);
176
Stella Laurenzof4682732020-12-08 17:31:11 -0800177 auto saveToFile = [&](llvm::StringRef savePath) -> LogicalResult {
178 auto outputFile = openOutputFile(savePath);
179 if (!outputFile) {
180 llvm::errs() << "Could not open output file: " << savePath << "\n";
181 return failure();
182 }
183 OpPrintingFlags printFlags;
184 printFlags.enableDebugInfo();
185 module->print(outputFile->os(), printFlags);
186 outputFile->os() << "\n";
187 outputFile->keep();
188 return success();
189 };
190
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800191 // First stage import.
192 switch (importType) {
193 case savedmodel_v2:
194 module = importSavedModelV2(context, inputPath, savedModelExportedNames);
195 break;
196 case savedmodel_v1:
197 module = importSavedModelV1(context, inputPath, savedModelExportedNames,
198 savedModelTags);
199 break;
200 default:
201 llvm_unreachable("unsupported import type enum");
202 }
203 if (!module) return 1;
204
Stella Laurenzof4682732020-12-08 17:31:11 -0800205 // Save temp output.
206 if (!saveTempTfInput.empty()) {
207 if (failed(saveToFile(saveTempTfInput))) return 10;
208 }
209
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800210 // Run passes.
211 PassManager pm(&context, PassManager::Nesting::Implicit);
Stella Laurenzof4682732020-12-08 17:31:11 -0800212 applyPassManagerCLOptions(pm);
213
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800214 iree_compiler::TF::buildTFImportPassPipeline(pm);
215 if (failed(pm.run(*module))) {
Stella Laurenzof4682732020-12-08 17:31:11 -0800216 llvm::errs()
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800217 << "Running iree-tf-import pass pipeline failed (see diagnostics)\n";
218 return 2;
219 }
220
Stella Laurenzof4682732020-12-08 17:31:11 -0800221 // Save temp output.
222 if (!saveTempIreeImport.empty()) {
223 if (failed(saveToFile(saveTempIreeImport))) return 10;
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800224 }
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800225
Stella Laurenzof4682732020-12-08 17:31:11 -0800226 // Save output.
227 if (failed(saveToFile(outputFilename))) return 3;
Stella Laurenzob0b03ee2020-12-03 12:32:44 -0800228 return 0;
229}