Use Black to format Python files (#14161)
Switch from yapf to Black to better align with the LLVM and broader
Python communities. I decided not to go with Pyink, as it seems much
less popular and its formatting style diverges from Black's beyond
just indentation.
- Reformat all Python files outside of `third_party` with Black.
- Update the lint workflow to use Black. It only checks files
modified by the PR (see the sketch after this list).
- Delete the old yapf dotfiles.
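A minimal sketch of that modified-files-only lint check (hypothetical:
the actual workflow step may differ, and diffing against `origin/main`
is an assumption):
```shell
# Run Black in check mode on only the Python files touched by the PR,
# excluding third_party (mirrors the full-reformat command below).
git diff --name-only --diff-filter=d origin/main -- '*.py' \
  | grep -v '^third_party/' \
  | xargs --no-run-if-empty black --check --diff
```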
The command used to reformat all files at once (`fd` here is
https://github.com/sharkdp/fd):
```shell
fd -e py --exclude third_party | xargs black
```
To learn more about Black, see: https://black.readthedocs.io/en/stable/
and https://github.com/psf/black.
In the next PR, once the commit SHA of this PR is finalized, I plan to
add this commit to `.git-blame-ignore-revs` to keep the blame history
clean.
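As a sketch of how that works (the SHA below is a placeholder, not a
real commit):
```shell
# Hypothetical: append the final commit SHA of this PR to the ignore list.
echo "<sha-of-this-commit>" >> .git-blame-ignore-revs
# Readers can then tell git blame to skip formatting-only commits:
git config blame.ignoreRevsFile .git-blame-ignore-revs
```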
Issue: https://github.com/openxla/iree/issues/14135
diff --git a/samples/colab/test_notebooks.py b/samples/colab/test_notebooks.py
index a714bae..8dbbba2 100755
--- a/samples/colab/test_notebooks.py
+++ b/samples/colab/test_notebooks.py
@@ -24,36 +24,37 @@
 class ColabNotebookTests(unittest.TestCase):
-  """Tests running all Colab notebooks in this directory."""
+    """Tests running all Colab notebooks in this directory."""
-  @classmethod
-  def generateTests(cls):
-    repo_root = os.path.dirname(
-        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-    script_path = os.path.join(repo_root,
-                               "build_tools/testing/run_python_notebook.sh")
+    @classmethod
+    def generateTests(cls):
+        repo_root = os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        )
+        script_path = os.path.join(
+            repo_root, "build_tools/testing/run_python_notebook.sh"
+        )
-    # Create a test case for each notebook in this folder.
-    notebooks_path = os.path.join(repo_root, "samples/colab/")
-    for notebook_path in glob.glob(notebooks_path + "*.ipynb"):
-      notebook_name = os.path.basename(notebook_path)
+        # Create a test case for each notebook in this folder.
+        notebooks_path = os.path.join(repo_root, "samples/colab/")
+        for notebook_path in glob.glob(notebooks_path + "*.ipynb"):
+            notebook_name = os.path.basename(notebook_path)
-      def unit_test(self, notebook_path=notebook_path):
+            def unit_test(self, notebook_path=notebook_path):
+                completed_process = subprocess.run([script_path, notebook_path])
+                self.assertEqual(completed_process.returncode, 0)
-        completed_process = subprocess.run([script_path, notebook_path])
-        self.assertEqual(completed_process.returncode, 0)
+            if notebook_name in NOTEBOOKS_TO_SKIP:
+                unit_test = unittest.skip("Skip requested")(unit_test)
+            elif notebook_name in NOTEBOOKS_EXPECTED_TO_FAIL:
+                unit_test = unittest.expectedFailure(unit_test)
-      if notebook_name in NOTEBOOKS_TO_SKIP:
-        unit_test = unittest.skip("Skip requested")(unit_test)
-      elif notebook_name in NOTEBOOKS_EXPECTED_TO_FAIL:
-        unit_test = unittest.expectedFailure(unit_test)
-
-      # Add 'unit_test' to this class, so the test runner runs it.
-      unit_test.__name__ = f"test_{notebook_name}"
-      setattr(cls, unit_test.__name__, unit_test)
+            # Add 'unit_test' to this class, so the test runner runs it.
+            unit_test.__name__ = f"test_{notebook_name}"
+            setattr(cls, unit_test.__name__, unit_test)
 if __name__ == "__main__":
-  ColabNotebookTests.generateTests()
-  logging.basicConfig(level=logging.DEBUG)
-  unittest.main()
+    ColabNotebookTests.generateTests()
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
diff --git a/samples/compiler_plugins/simple_io_sample/test/run_mock.py b/samples/compiler_plugins/simple_io_sample/test/run_mock.py
index db9454a..942c051 100644
--- a/samples/compiler_plugins/simple_io_sample/test/run_mock.py
+++ b/samples/compiler_plugins/simple_io_sample/test/run_mock.py
@@ -17,22 +17,20 @@
print(f"--- Loading {input_file}")
with open(input_file, "rb") as f:
- vmfb_contents = f.read()
+ vmfb_contents = f.read()
def create_simple_io_module():
+ class SimpleIO:
+ def __init__(self, iface):
+ ...
- class SimpleIO:
+ def print_impl(self):
+ print("+++ HELLO FROM SIMPLE_IO")
- def __init__(self, iface):
- ...
-
- def print_impl(self):
- print("+++ HELLO FROM SIMPLE_IO")
-
- iface = rt.PyModuleInterface("simple_io", SimpleIO)
- iface.export("print", "0v_v", SimpleIO.print_impl)
- return iface.create()
+ iface = rt.PyModuleInterface("simple_io", SimpleIO)
+ iface.export("print", "0v_v", SimpleIO.print_impl)
+ return iface.create()
config = rt.Config("local-sync")
diff --git a/samples/lit.cfg.py b/samples/lit.cfg.py
index 77a0498..cc344cb 100644
--- a/samples/lit.cfg.py
+++ b/samples/lit.cfg.py
@@ -20,13 +20,17 @@
 config.test_format = lit.formats.ShTest(execute_external=True)
 # Forward all IREE environment variables
 passthrough_env_vars = ["VK_ICD_FILENAMES"]
-config.environment.update({
-    k: v
-    for k, v in os.environ.items()
-    if k.startswith("IREE_") or k in passthrough_env_vars
-})
+config.environment.update(
+    {
+        k: v
+        for k, v in os.environ.items()
+        if k.startswith("IREE_") or k in passthrough_env_vars
+    }
+)
 # Use the most preferred temp directory.
-config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or
-                         os.environ.get("TEST_TMPDIR") or
-                         os.path.join(tempfile.gettempdir(), "lit"))
+config.test_exec_root = (
+    os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
+    or os.environ.get("TEST_TMPDIR")
+    or os.path.join(tempfile.gettempdir(), "lit")
+)
diff --git a/samples/py_custom_module/decode_secret_message.py b/samples/py_custom_module/decode_secret_message.py
index 42682b4..e28d3e3 100644
--- a/samples/py_custom_module/decode_secret_message.py
+++ b/samples/py_custom_module/decode_secret_message.py
@@ -29,113 +29,113 @@
 def create_tokenizer_module():
-  """Creates a module which defines some custom methods for decoding."""
+    """Creates a module which defines some custom methods for decoding."""
-  class Detokenizer:
+    class Detokenizer:
+        def __init__(self, iface):
+            # Any class state here is maintained per-context.
+            self.start_of_text = True
+            self.start_of_sentence = True
-    def __init__(self, iface):
-      # Any class state here is maintained per-context.
-      self.start_of_text = True
-      self.start_of_sentence = True
+        def reset(self):
+            self.start_of_text = True
+            self.start_of_sentence = True
-    def reset(self):
-      self.start_of_text = True
-      self.start_of_sentence = True
+        def accumtokens(self, ids_tensor_ref, token_list_ref):
+            # TODO: This little dance to turn BufferView refs into real arrays... is not good.
+            ids_bv = ids_tensor_ref.deref(rt.HalBufferView)
+            ids_array = ids_bv.map().asarray(
+                ids_bv.shape, rt.HalElementType.map_to_dtype(ids_bv.element_type)
+            )
+            token_list = token_list_ref.deref(rt.VmVariantList)
+            for index in range(ids_array.shape[0]):
+                token_id = ids_array[index]
+                token = TOKEN_TABLE[token_id]
-    def accumtokens(self, ids_tensor_ref, token_list_ref):
-      # TODO: This little dance to turn BufferView refs into real arrays... is not good.
-      ids_bv = ids_tensor_ref.deref(rt.HalBufferView)
-      ids_array = ids_bv.map().asarray(
-          ids_bv.shape, rt.HalElementType.map_to_dtype(ids_bv.element_type))
-      token_list = token_list_ref.deref(rt.VmVariantList)
-      for index in range(ids_array.shape[0]):
-        token_id = ids_array[index]
-        token = TOKEN_TABLE[token_id]
+                # And this dance to make a buffer... is also not good.
+                # A real implementation would just map the constant memory, etc.
+                buffer = rt.VmBuffer(len(token))
+                buffer_view = memoryview(buffer)
+                buffer_view[:] = token
+                token_list.push_ref(buffer)
+            return ids_array.shape[0]
-        # And this dance to make a buffer... is also not good.
-        # A real implementation would just map the constant memory, etc.
-        buffer = rt.VmBuffer(len(token))
-        buffer_view = memoryview(buffer)
-        buffer_view[:] = token
-        token_list.push_ref(buffer)
-      return ids_array.shape[0]
+        def jointokens(self, token_list_ref):
+            # The world's dumbest detokenizer. Ideally, the state tracking
+            # would be in a module private type that got retained and passed
+            # back through.
+            token_list = token_list_ref.deref(rt.VmVariantList)
+            text = bytearray()
+            for i in range(len(token_list)):
+                item = bytes(token_list.get_as_object(i, rt.VmBuffer))
+                if item == b".":
+                    text.extend(b".")
+                    self.start_of_sentence = True
+                else:
+                    if not self.start_of_text:
+                        text.extend(b" ")
+                    else:
+                        self.start_of_text = False
+                    if self.start_of_sentence:
+                        text.extend(item[0:1].decode("utf-8").upper().encode("utf-8"))
+                        text.extend(item[1:])
+                        self.start_of_sentence = False
+                    else:
+                        text.extend(item)
-    def jointokens(self, token_list_ref):
-      # The world's dumbest detokenizer. Ideally, the state tracking
-      # would be in a module private type that got retained and passed
-      # back through.
-      token_list = token_list_ref.deref(rt.VmVariantList)
-      text = bytearray()
-      for i in range(len(token_list)):
-        item = bytes(token_list.get_as_object(i, rt.VmBuffer))
-        if item == b".":
-          text.extend(b".")
-          self.start_of_sentence = True
-        else:
-          if not self.start_of_text:
-            text.extend(b" ")
-          else:
-            self.start_of_text = False
-          if self.start_of_sentence:
-            text.extend(item[0:1].decode("utf-8").upper().encode("utf-8"))
-            text.extend(item[1:])
-            self.start_of_sentence = False
-          else:
-            text.extend(item)
+            # TODO: This dance to make a buffer is still bad.
+            results = rt.VmBuffer(len(text))
+            memoryview(results)[:] = text
+            return results.ref
-      # TODO: This dance to make a buffer is still bad.
-      results = rt.VmBuffer(len(text))
-      memoryview(results)[:] = text
-      return results.ref
-
-  iface = rt.PyModuleInterface("detokenizer", Detokenizer)
-  iface.export("accumtokens", "0rr_i", Detokenizer.accumtokens)
-  iface.export("jointokens", "0r_r", Detokenizer.jointokens)
-  iface.export("reset", "0v_v", Detokenizer.reset)
-  return iface.create()
+    iface = rt.PyModuleInterface("detokenizer", Detokenizer)
+    iface.export("accumtokens", "0rr_i", Detokenizer.accumtokens)
+    iface.export("jointokens", "0r_r", Detokenizer.jointokens)
+    iface.export("reset", "0v_v", Detokenizer.reset)
+    return iface.create()
 def compile():
-  return compiler.tools.compile_file(os.path.join(os.path.dirname(__file__),
-                                                  "main.mlir"),
-                                     target_backends=["vmvx"])
+    return compiler.tools.compile_file(
+        os.path.join(os.path.dirname(__file__), "main.mlir"), target_backends=["vmvx"]
+    )
 def main():
-  print("Compiling...")
-  vmfb_contents = compile()
-  print("Decoding secret message...")
-  config = rt.Config("local-sync")
-  main_module = rt.VmModule.from_flatbuffer(config.vm_instance, vmfb_contents)
-  modules = config.default_vm_modules + (
-      create_tokenizer_module(),
-      main_module,
-  )
-  context = rt.SystemContext(vm_modules=modules, config=config)
+    print("Compiling...")
+    vmfb_contents = compile()
+    print("Decoding secret message...")
+    config = rt.Config("local-sync")
+    main_module = rt.VmModule.from_flatbuffer(config.vm_instance, vmfb_contents)
+    modules = config.default_vm_modules + (
+        create_tokenizer_module(),
+        main_module,
+    )
+    context = rt.SystemContext(vm_modules=modules, config=config)
-  # First message.
-  count = context.modules.main.add_tokens(
-      np.asarray([5, 10, 11, 1, 3, 4, 5, 7, 12], dtype=np.int32))
-  print(f"ADDED {count} tokens")
+    # First message.
+    count = context.modules.main.add_tokens(
+        np.asarray([5, 10, 11, 1, 3, 4, 5, 7, 12], dtype=np.int32)
+    )
+    print(f"ADDED {count} tokens")
-  # Second message.
-  count = context.modules.main.add_tokens(np.asarray([2, 13], dtype=np.int32))
-  print(f"ADDED {count} tokens")
+    # Second message.
+    count = context.modules.main.add_tokens(np.asarray([2, 13], dtype=np.int32))
+    print(f"ADDED {count} tokens")
-  text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
-  print(f"RESULTS: {text}")
+    text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
+    print(f"RESULTS: {text}")
-  assert text == b"So long and thanks for all so fish. Bye now"
+    assert text == b"So long and thanks for all so fish. Bye now"
-  # Reset and decode some more.
-  context.modules.main.reset()
-  count = context.modules.main.add_tokens(
-      np.asarray([0, 14, 12], dtype=np.int32))
-  print(f"ADDED {count} tokens")
-  text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
-  print(f"RESULTS: {text}")
-  assert text == b"Hi there."
+    # Reset and decode some more.
+    context.modules.main.reset()
+    count = context.modules.main.add_tokens(np.asarray([0, 14, 12], dtype=np.int32))
+    print(f"ADDED {count} tokens")
+    text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
+    print(f"RESULTS: {text}")
+    assert text == b"Hi there."
 if __name__ == "__main__":
-  main()
+    main()
diff --git a/samples/vision_inference/convert_image.py b/samples/vision_inference/convert_image.py
index 6253cca..ee1ef44 100644
--- a/samples/vision_inference/convert_image.py
+++ b/samples/vision_inference/convert_image.py
@@ -16,12 +16,12 @@
 # Read image from stdin (in any format supported by PIL).
 with Image.open(sys.stdin.buffer) as color_img:
-  # Resize to 28x28, matching what the program expects.
-  resized_color_img = color_img.resize((28, 28))
-  # Convert to grayscale.
-  grayscale_img = resized_color_img.convert('L')
-  # Rescale to a float32 in range [0.0, 1.0].
-  grayscale_arr = np.array(grayscale_img)
-  grayscale_arr_f32 = grayscale_arr.astype(np.float32) / 255.0
-  # Write bytes back out to stdout.
-  sys.stdout.buffer.write(grayscale_arr_f32.tobytes())
+    # Resize to 28x28, matching what the program expects.
+    resized_color_img = color_img.resize((28, 28))
+    # Convert to grayscale.
+    grayscale_img = resized_color_img.convert("L")
+    # Rescale to a float32 in range [0.0, 1.0].
+    grayscale_arr = np.array(grayscale_img)
+    grayscale_arr_f32 = grayscale_arr.astype(np.float32) / 255.0
+    # Write bytes back out to stdout.
+    sys.stdout.buffer.write(grayscale_arr_f32.tobytes())