Save MHLO input in iree-import-xla (#6550)
This can help debugging MHLO import failures. We already do this in the
TF import, but weren't doing it here, probably because MHLO input used
to be synonymous with IREE input.
diff --git a/bindings/python/iree/compiler/xla.py b/bindings/python/iree/compiler/xla.py
index eb04231..a0b0578 100644
--- a/bindings/python/iree/compiler/xla.py
+++ b/bindings/python/iree/compiler/xla.py
@@ -74,6 +74,7 @@
import_format: Union[ImportFormat,
str] = ImportFormat.BINARY_PROTO,
import_extra_args: Sequence[str] = (),
+ save_temp_mhlo_input: Optional[str] = None,
save_temp_iree_input: Optional[str] = None,
**kwargs):
"""Initialize options from keywords.
@@ -87,6 +88,7 @@
self.import_only = import_only
self.import_format = ImportFormat.parse(import_format)
self.import_extra_args = import_extra_args
+ self.save_temp_mhlo_input = save_temp_mhlo_input
self.save_temp_iree_input = save_temp_iree_input
@@ -121,6 +123,10 @@
cl.append("--mlir-print-op-generic")
# Save temps flags.
+ save_mhlo_input = tfs.alloc_optional("tf-mhlo.mlir",
+ export_as=options.save_temp_mhlo_input)
+ if save_mhlo_input:
+ cl.append(f"--save-temp-mhlo-input={save_mhlo_input}")
iree_input = tfs.alloc_optional("xla-iree-input.mlir",
export_as=options.save_temp_iree_input)
if iree_input:
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
index feec3ff..ee95f1e 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
@@ -224,8 +224,8 @@
iree_integrations::TF::buildTFImportPassPipeline(pm);
if (failed(pm.run(*module))) {
- llvm::errs()
- << "Running iree-import-tf pass pipeline failed (see diagnostics)\n";
+ llvm::errs() << "Running iree-import-tf TF import pass pipeline failed "
+ "(see diagnostics)\n";
return 2;
}
if (!saveTempMidLevelImport.empty()) {
@@ -237,8 +237,8 @@
applyPassManagerCLOptions(pm);
iree_integrations::MHLO::buildMHLOImportPassPipeline(pm);
if (failed(pm.run(*module))) {
- llvm::errs()
- << "Running iree-import-tf pass pipeline failed (see diagnostics)\n";
+ llvm::errs() << "Running iree-import-tf MHLO Import pass pipeline failed "
+ "(see diagnostics)\n";
return 2;
}
}
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 2e11d7f..08f270a 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -100,6 +100,10 @@
static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"));
+ static llvm::cl::opt<std::string> saveTempMhloInput(
+ "save-temp-mhlo-input",
+ llvm::cl::desc("Save the MHLO pipeline input IR to this file"),
+ llvm::cl::init(""));
static llvm::cl::opt<std::string> saveTempIreeImport(
"save-temp-iree-input",
llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
@@ -251,6 +255,11 @@
return success();
};
+ // Save temp output.
+ if (!saveTempMhloInput.empty()) {
+ if (failed(saveToFile(saveTempMhloInput))) return 10;
+ }
+
// Run passes.
PassManager pm(&context, PassManager::Nesting::Implicit);
applyPassManagerCLOptions(pm);
@@ -264,8 +273,8 @@
iree_integrations::MHLO::createEmitDefaultIREEABIPass());
if (failed(pm.run(*module))) {
- llvm::errs()
- << "Running iree-xla-import pass pipeline failed (see diagnostics)\n";
+ llvm::errs() << "Running iree-xla-import MHLO import pass pipeline failed "
+ "(see diagnostics)\n";
return 2;
}