| #!/usr/bin/env python3 |
| ## Copyright 2022 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import textwrap |
| import unittest |
| import cmake_builder.rules |
| |
| |
class RulesTest(unittest.TestCase):
  """Golden-output tests for the CMake rule builders in `cmake_builder.rules`.

  Each test calls one builder with fixed arguments and compares the generated
  CMake text against an inline, indented template.
  """

  def _assert_rule_equals(self, generated, expected_template):
    """Asserts `generated` equals `expected_template` once it is dedented.

    Templates are written indented to line up with the test body; the common
    leading whitespace is removed with `textwrap.dedent` before comparing.
    """
    self.assertEqual(generated, textwrap.dedent(expected_template))

  def test_build_iree_bytecode_module(self):
    """Every optional argument shows up in the generated rule."""
    generated = cmake_builder.rules.build_iree_bytecode_module(
        target_name="abcd",
        src="abcd.mlir",
        module_name="abcd.vmfb",
        flags=["--backend=cpu", "--opt=3"],
        compile_tool_target="iree_iree-compile2",
        c_identifier="abcd.c",
        static_lib_path="libx.a",
        deps=["iree_libx", "iree_liby"],
        testonly=True,
        public=False)

    self._assert_rule_equals(
        generated, """\
        iree_bytecode_module(
          NAME "abcd"
          SRC "abcd.mlir"
          MODULE_FILE_NAME "abcd.vmfb"
          C_IDENTIFIER "abcd.c"
          COMPILE_TOOL "iree_iree-compile2"
          STATIC_LIB_PATH "libx.a"
          FLAGS
            "--backend=cpu"
            "--opt=3"
          DEPS
            "iree_libx"
            "iree_liby"
          TESTONLY
        )
        """)

  def test_build_iree_bytecode_module_with_defaults(self):
    """Only required arguments are given; PUBLIC is emitted by default."""
    generated = cmake_builder.rules.build_iree_bytecode_module(
        target_name="abcd",
        src="abcd.mlir",
        module_name="abcd.vmfb",
        flags=["--backend=cpu", "--opt=3"])

    self._assert_rule_equals(
        generated, """\
        iree_bytecode_module(
          NAME "abcd"
          SRC "abcd.mlir"
          MODULE_FILE_NAME "abcd.vmfb"
          FLAGS
            "--backend=cpu"
            "--opt=3"
          PUBLIC
        )
        """)

  def test_build_iree_fetch_artifact(self):
    """The fetch rule carries its URL and output, and UNPACK when requested."""
    generated = cmake_builder.rules.build_iree_fetch_artifact(
        target_name="abcd",
        source_url="https://example.com/abcd.tflite",
        output="./abcd.tflite",
        unpack=True)

    self._assert_rule_equals(
        generated, """\
        iree_fetch_artifact(
          NAME "abcd"
          SOURCE_URL "https://example.com/abcd.tflite"
          OUTPUT "./abcd.tflite"
          UNPACK
        )
        """)

  def test_build_iree_import_tf_model(self):
    """TF import flags are emitted as a quoted list under IMPORT_FLAGS."""
    generated = cmake_builder.rules.build_iree_import_tf_model(
        target_path="pkg_abcd",
        source="abcd/model",
        import_flags=[
            "--tf-savedmodel-exported-names=main",
            "--tf-import-type=savedmodel_v1"
        ],
        output_mlir_file="abcd.mlir")

    self._assert_rule_equals(
        generated, """\
        iree_import_tf_model(
          TARGET_NAME "pkg_abcd"
          SOURCE "abcd/model"
          IMPORT_FLAGS
            "--tf-savedmodel-exported-names=main"
            "--tf-import-type=savedmodel_v1"
          OUTPUT_MLIR_FILE "abcd.mlir"
        )
        """)

  def test_build_iree_import_tflite_model(self):
    """A TFLite import with a single import flag."""
    generated = cmake_builder.rules.build_iree_import_tflite_model(
        target_path="pkg_abcd",
        source="abcd.tflite",
        import_flags=["--fake-flag=abcd"],
        output_mlir_file="abcd.mlir")

    self._assert_rule_equals(
        generated, """\
        iree_import_tflite_model(
          TARGET_NAME "pkg_abcd"
          SOURCE "abcd.tflite"
          IMPORT_FLAGS
            "--fake-flag=abcd"
          OUTPUT_MLIR_FILE "abcd.mlir"
        )
        """)

  def test_build_iree_benchmark_suite_module_test(self):
    """Platform/module pairs are rendered as `platform=module` entries."""
    generated = cmake_builder.rules.build_iree_benchmark_suite_module_test(
        target_name="model_test",
        driver="LOCAL_TASK",
        expected_output="xyz",
        platform_module_map={
            "x86_64": "a.vmfb",
            "arm": "b.vmfb"
        },
        runner_args=["--x=0", "--y=1"],
        timeout_secs=10,
        labels=["defaults", "e2e"],
        xfail_platforms=["arm_64-Android", "riscv_32-Linux"])

    self._assert_rule_equals(
        generated, """\
        iree_benchmark_suite_module_test(
          NAME "model_test"
          DRIVER "LOCAL_TASK"
          EXPECTED_OUTPUT "xyz"
          TIMEOUT "10"
          MODULES
            "x86_64=a.vmfb"
            "arm=b.vmfb"
          RUNNER_ARGS
            "--x=0"
            "--y=1"
          LABELS
            "defaults"
            "e2e"
          XFAIL_PLATFORMS
            "arm_64-Android"
            "riscv_32-Linux"
        )
        """)

  def test_build_add_dependencies(self):
    """Dependencies are listed unquoted, one per line."""
    generated = cmake_builder.rules.build_add_dependencies(
        target="iree_mlir_suites", deps=["pkg_abcd", "pkg_efgh"])

    self._assert_rule_equals(
        generated, """\
        add_dependencies(iree_mlir_suites
          pkg_abcd
          pkg_efgh
        )
        """)

  def test_build_set(self):
    """A `set()` call with the value on its own unquoted line."""
    generated = cmake_builder.rules.build_set(variable_name="_ABC",
                                              value="123")

    self._assert_rule_equals(
        generated, """\
        set(_ABC
          123
        )
        """)
| |
| |
if __name__ == "__main__":
  # Discover and run every TestCase in this module when run as a script.
  unittest.main()