Add support for saving mid level IR to a file. (#6439)
Fixes https://github.com/google/iree/issues/6412
diff --git a/bindings/python/iree/compiler/tf.py b/bindings/python/iree/compiler/tf.py
index 2389c37..e8a3b55 100644
--- a/bindings/python/iree/compiler/tf.py
+++ b/bindings/python/iree/compiler/tf.py
@@ -82,6 +82,7 @@
saved_model_tags: Set[str] = set(),
import_extra_args: Sequence[str] = (),
save_temp_tf_input: Optional[str] = None,
+ save_temp_mid_level_input: Optional[str] = None,
save_temp_iree_input: Optional[str] = None,
**kwargs):
"""Initialize options from keywords.
@@ -101,6 +102,8 @@
import_extra_args: Extra arguments to pass to the iree-import-tf tool.
save_temp_tf_input: Optionally save the IR that is input to the
TensorFlow pipeline.
+ save_temp_mid_level_input: Optionally save the IR that is input to the
+ mid level IR.
save_temp_iree_input: Optionally save the IR that is the result of the
import (ready to be passed to IREE).
"""
@@ -111,6 +114,7 @@
self.saved_model_tags = saved_model_tags
self.import_extra_args = import_extra_args
self.save_temp_tf_input = save_temp_tf_input
+ self.save_temp_mid_level_input = save_temp_mid_level_input
self.save_temp_iree_input = save_temp_iree_input
@@ -151,6 +155,10 @@
export_as=options.save_temp_tf_input)
if save_tf_input:
cl.append(f"--save-temp-tf-input={save_tf_input}")
+ save_mid_level_input = tfs.alloc_optional(
+ "tf-mid-level-input.mlir", export_as=options.save_temp_mid_level_input)
+ if save_mid_level_input:
+ cl.append(f"--save-temp-mid-level-input={save_mid_level_input}")
save_iree_input = tfs.alloc_optional("tf-iree-input.mlir",
export_as=options.save_temp_iree_input)
if save_iree_input:
diff --git a/integrations/tensorflow/bindings/python/iree/tf/support/module_utils.py b/integrations/tensorflow/bindings/python/iree/tf/support/module_utils.py
index 69dee01..a84afe0 100644
--- a/integrations/tensorflow/bindings/python/iree/tf/support/module_utils.py
+++ b/integrations/tensorflow/bindings/python/iree/tf/support/module_utils.py
@@ -66,6 +66,8 @@
kwargs["saved_model_dir"] = os.path.join(artifacts_dir,
"tfmodule.saved_model")
kwargs["save_temp_tf_input"] = os.path.join(artifacts_dir, "tf_input.mlir")
+ kwargs["save_temp_mid_level_input"] = os.path.join(artifacts_dir,
+ "tf_mid_level_input.mlir")
kwargs["save_temp_iree_input"] = os.path.join(artifacts_dir,
"iree_input.mlir")
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index 4e94220..29bbcb5 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -73,6 +73,7 @@
name = "iree-import-tf",
srcs = ["iree-import-tf-main.cpp"],
deps = [
+ "//iree_tf_compiler/MHLO",
"//iree_tf_compiler/TF",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
index 5d94373..b2f9a00 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
-#include "iree_tf_compiler/MHLO/Passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -104,8 +103,6 @@
pm.addPass(createStripModuleMetadataPass());
pm.nest<ModuleOp>().addPass(createStripFunctionMetadataPass());
pm.addPass(createVerifyFullyConvertedPass());
-
- MHLO::buildMHLOImportPassPipeline(pm);
}
void registerTFImportPassPipeline() {
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
index d2efece..feec3ff 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
@@ -12,6 +12,7 @@
// Since none of the TensorFlow imports come from an MLIR text form, it is a bit
// of an odd fit for a *-translate style tool, which is why this diverges.
+#include "iree_tf_compiler/MHLO/Passes.h"
#include "iree_tf_compiler/TF/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
@@ -30,6 +31,7 @@
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/core/platform/errors.h"
+
using namespace llvm;
using namespace mlir;
@@ -148,6 +150,9 @@
"save-temp-tf-input",
llvm::cl::desc("Save the TF pipeline input to this file"),
llvm::cl::init(""));
+ static llvm::cl::opt<std::string> saveTempMidLevelImport(
+ "save-temp-mid-level-input",
+ llvm::cl::desc("Save the mid level IR to this file"), llvm::cl::init(""));
static llvm::cl::opt<std::string> saveTempIreeImport(
"save-temp-iree-input",
llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
@@ -209,18 +214,33 @@
}
// Run passes.
- PassManager pm(&context, PassManager::Nesting::Implicit);
- applyPassManagerCLOptions(pm);
+ {
+ PassManager pm(&context, PassManager::Nesting::Implicit);
+ applyPassManagerCLOptions(pm);
- if (prettifyTfDebugInfo) {
- pm.addPass(iree_integrations::TF::createPrettifyDebugInfoPass());
+ if (prettifyTfDebugInfo) {
+ pm.addPass(iree_integrations::TF::createPrettifyDebugInfoPass());
+ }
+
+ iree_integrations::TF::buildTFImportPassPipeline(pm);
+ if (failed(pm.run(*module))) {
+ llvm::errs()
+ << "Running iree-import-tf pass pipeline failed (see diagnostics)\n";
+ return 2;
+ }
+ if (!saveTempMidLevelImport.empty()) {
+ if (failed(saveToFile(saveTempMidLevelImport))) return 10;
+ }
}
-
- iree_integrations::TF::buildTFImportPassPipeline(pm);
- if (failed(pm.run(*module))) {
- llvm::errs()
- << "Running iree-import-tf pass pipeline failed (see diagnostics)\n";
- return 2;
+ {
+ PassManager pm(&context, PassManager::Nesting::Implicit);
+ applyPassManagerCLOptions(pm);
+ iree_integrations::MHLO::buildMHLOImportPassPipeline(pm);
+ if (failed(pm.run(*module))) {
+ llvm::errs()
+ << "Running iree-import-tf pass pipeline failed (see diagnostics)\n";
+ return 2;
+ }
}
// Save temp output.