Handle supported ImportOptions in tf.py and fix Windows compatibility. (#13287)
This adds back support for `import_only` and `save_temp_iree_input` to
our TensorFlow `compile_saved_model` API. I also removed unsupported
options (`import_extra_args`, `save_temp_tf_input`,
`save_temp_mid_level_input`, and `use_tosa`).
Those flags were dropped in https://github.com/openxla/iree/pull/12758 /
https://github.com/openxla/iree/pull/13025 , but they are still useful
in Colab notebooks and when debugging tests.
---
Progress on https://github.com/openxla/iree/issues/13148, though some
further updates will be needed to our Colab notebooks, such as
```python
# before:
compiler_module = tfc.compile_module(
EdgeDetectionModule(), import_only=True,
import_extra_args=["--output-format=mlir-ir"])
print("Edge Detection MLIR: ", compiler_module.decode('utf-8'))
# after:
compiler_module = tfc.compile_module(
EdgeDetectionModule(), import_only=True)
print("Edge Detection MLIR: ", compiler_module)
```
---
I developed this change on Windows (Yes! Finally, I can use Python that
touches TF on Windows without needing to build TF from source!), where I
found that this pattern is broken:
```python
with tempfile.NamedTemporaryFile(mode="w") as temp_file:
__main__.import_saved_model(output_path=temp_file.name,
```
See https://stackoverflow.com/a/23212515 - `NamedTemporaryFile` _creates
and opens_ the file, and the file _cannot be opened again_... on Windows
(it can be opened again on Unix). I used the trick from
https://stackoverflow.com/a/45803022 to work around this:
```python
with tempfile.TemporaryDirectory() as tmpdir:
# ...
# Not saving the file, so generate a loose temp file without tfs.
tf_iree_input = os.path.join(tmpdir, 'tf-iree-input.mlir')
```diff --git a/compiler/bindings/python/iree/compiler/tools/tf.py b/compiler/bindings/python/iree/compiler/tools/tf.py
index 085659e..6a49e31 100644
--- a/compiler/bindings/python/iree/compiler/tools/tf.py
+++ b/compiler/bindings/python/iree/compiler/tools/tf.py
@@ -9,6 +9,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
+import os
import logging
import tempfile
from typing import List, Optional, Sequence, Set, Union
@@ -89,11 +90,6 @@
import_type: Type of import to perform. See ImportType enum.
saved_model_tags: Set of tags to export (signature def/v1 saved models
only).
- import_extra_args: Extra arguments to pass to the iree-import-tf tool.
- save_temp_tf_input: Optionally save the IR that is input to the
- TensorFlow pipeline.
- save_temp_mid_level_input: Optionally save the IR that is input to the
- mid level IR.
save_temp_iree_input: Optionally save the IR that is the result of the
import (ready to be passed to IREE).
"""
@@ -103,11 +99,7 @@
import_type: ImportType = ImportType.OBJECT_GRAPH
input_type: Union[InputType, str] = InputType.XLA
saved_model_tags: Set[str] = field(default_factory=set)
- import_extra_args: Sequence[str] = ()
- save_temp_tf_input: Optional[str] = None
- save_temp_mid_level_input: Optional[str] = None
save_temp_iree_input: Optional[str] = None
- use_tosa: bool = False
def __post_init__(self):
self.import_type = ImportType.parse(self.import_type)
@@ -124,19 +116,35 @@
was specified.
"""
from iree.tools.tf.scripts.iree_import_tf import __main__
- with TempFileSaver.implicit() as tfs:
+
+ with TempFileSaver.implicit() as tfs, tempfile.TemporaryDirectory() as tmpdir:
options = ImportOptions(**kwargs)
- with tempfile.NamedTemporaryFile(mode="w") as temp_file:
- # Generate MLIR
- __main__.import_saved_model(output_path=temp_file.name,
+ if options.import_only and options.output_file:
+ # Importing to a file and stopping, write to that file directly.
+ tf_iree_input = options.output_file
+ elif options.save_temp_iree_input:
+ # Saving the file, use tfs.
+ tf_iree_input = tfs.alloc_optional("tf-iree-input.mlir",
+ export_as=options.save_temp_iree_input)
+ else:
+ # Not saving the file, so generate a loose temp file without tfs.
+ tf_iree_input = os.path.join(tmpdir, 'tf-iree-input.mlir')
+
+ __main__.import_saved_model(output_path=tf_iree_input,
saved_model_dir=saved_model_dir,
exported_names=",".join(options.exported_names),
import_type=options.import_type.value,
tags=",".join(options.saved_model_tags))
- # Full compilation pipeline.
- compile_cl = build_compile_command_line(temp_file.name, tfs, options)
+ if options.import_only:
+ if options.output_file:
+ return None
+ with open(tf_iree_input, "r") as f:
+ return f.read()
+
+ # Run IREE compilation pipeline
+ compile_cl = build_compile_command_line(tf_iree_input, tfs, options)
result = invoke_pipeline([compile_cl])
if options.output_file:
return None
diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py
index 7f5e24d..925a2e1 100644
--- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py
+++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py
@@ -58,9 +58,6 @@
if needs_temp_saved_model_dir:
kwargs["saved_model_dir"] = os.path.join(artifacts_dir,
"tfmodule.saved_model")
- kwargs["save_temp_tf_input"] = os.path.join(artifacts_dir, "tf_input.mlir")
- kwargs["save_temp_mid_level_input"] = os.path.join(artifacts_dir,
- "tf_mid_level_input.mlir")
kwargs["save_temp_iree_input"] = os.path.join(artifacts_dir,
"iree_input.mlir")