Implement test utils for doing comparisons and benchmarks between TF and IREE.
* Also re-organizes some of the python namespace for better ergonomics.
PiperOrigin-RevId: 279344889
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index a85ac9e..c4880ba 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -62,9 +62,11 @@
srcs_version = "PY3",
deps = [
":binding",
+ ":compiler",
"//bindings/python:pathsetup", # build_cleaner: keep
] + select({
"//iree:enable_tensorflow": [
+ "//bindings/python/pyiree/tf_interop:test_utils",
"//bindings/python/pyiree/tf_interop:tf_test_driver",
],
"//conditions:default": [
@@ -72,6 +74,16 @@
}),
)
+py_library(
+ name = "compiler",
+ srcs = ["compiler.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":binding",
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ ],
+)
+
cc_library(
name = "base",
srcs = [
diff --git a/bindings/python/pyiree/__init__.py b/bindings/python/pyiree/__init__.py
index 90f7e50..46739ee 100644
--- a/bindings/python/pyiree/__init__.py
+++ b/bindings/python/pyiree/__init__.py
@@ -18,19 +18,18 @@
# pylint: disable=g-import-not-at-top
# pylint: disable=g-bad-import-order
-# Always make the low-level native bindings accessible.
+# Top-level modules that are imported verbatim.
from . import binding
+from . import compiler
+from .binding import tracing
# Alias public compiler symbols.
+# TODO(laurenzo): Remove these aliases.
from .binding.compiler import CompilerContext
from .binding.compiler import CompilerModule
-# Alias tracing symbols.
-from .binding import tracing
-
-# Alias symbols from the native tf_interop module.
-if hasattr(binding, "tf_interop"):
- from .binding.tf_interop import load_saved_model as tf_load_saved_model
+# Alias specific native functions.
+from .binding.vm import create_module_from_blob
### Load non-native py_library deps here ###
### Order matters because these typically have a back-reference on this
@@ -40,5 +39,6 @@
# in).
try:
from .tf_interop import tf_test_driver
+ from .tf_interop import test_utils as tf_test_utils
except ImportError:
pass
diff --git a/bindings/python/pyiree/compiler.py b/bindings/python/pyiree/compiler.py
new file mode 100644
index 0000000..f7ce73f
--- /dev/null
+++ b/bindings/python/pyiree/compiler.py
@@ -0,0 +1,98 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""High level compiler API.
+
+This imports parts of the native bindings as appropriate.
+"""
+
+from typing import Collection, Optional, Sequence
+
+from . import binding as _binding
+
+# Native aliases.
+Context = _binding.compiler.CompilerContext
+Module = _binding.compiler.CompilerModule
+
+# Conditionally import TensorFlow interop aliases.
+HAS_TENSORFLOW = hasattr(_binding, "tf_interop")
+
+if HAS_TENSORFLOW:
+ # Pass pipeline that should run to lower a TF saved_model to a form suitable
+ # for input to the IREE compiler.
+ TF_IMPORT_PASS_PIPELINE = (
+ "tf-executor-graph-pruning",
+ "tf-standard-pipeline",
+ "canonicalize",
+ "xla-legalize-tf",
+ )
+
+ def tf_load_saved_model(
+ saved_model_dir: str,
+ compiler_context: Optional[Context] = None,
+ exported_names: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE) -> Module:
+ """Loads a TensorFlow saved model from its persistent representation.
+
+ See also tf_compile_saved_model() for a one-shot API to load and compile.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+ exported_names: Optional tuple of strings representing the exported names
+ to keep.
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+
+ Returns:
+ An MLIR Module suitable for compilation by the IREE compiler.
+ This can be further compiled to an IREE blob by calling
+ .compile_to_sequencer_blob.
+ """
+ if not compiler_context:
+ compiler_context = Context()
+ input_module = _binding.tf_interop.load_saved_model(
+ compiler_context, saved_model_dir, exported_names=exported_names)
+ if pass_pipeline:
+ input_module.run_pass_pipeline(pass_pipeline)
+ return input_module
+
+ def tf_compile_saved_model(
+ saved_model_dir: str,
+ compiler_context: Optional[Context] = None,
+ exported_names: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ print_mlir: bool = False,
+ target_backends: Collection[str] = ()
+ ) -> _binding.OpaqueBlob:
+ """Loads and compiles a TensorFlow saved model in one shot.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+ exported_names: Optional tuple of strings representing the exported names
+ to keep.
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+ print_mlir: Whether to print intermediate MLIR after each pass.
+ target_backends: The specific target backends to compile for (defaults to
+ all compiled in targets).
+
+ Returns:
+ An OpaqueBlob representing the compiled module.
+ """
+ input_module = tf_load_saved_model(saved_model_dir, compiler_context,
+ exported_names, pass_pipeline)
+ return input_module.compile_to_sequencer_blob(
+ print_mlir=print_mlir, target_backends=target_backends)
diff --git a/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py
index 7c8f22d..db989dc 100644
--- a/bindings/python/pyiree/compiler_test.py
+++ b/bindings/python/pyiree/compiler_test.py
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+# Lint as: python3
from absl.testing import absltest
import pyiree
@@ -23,12 +21,12 @@
class CompilerTest(absltest.TestCase):
def testParseError(self):
- ctx = pyiree.CompilerContext()
+ ctx = pyiree.compiler.Context()
with self.assertRaisesRegex(ValueError, "custom op 'FOOBAR' is unknown"):
ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""")
def testParseAndCompileToSequencer(self):
- ctx = pyiree.CompilerContext()
+ ctx = pyiree.compiler.Context()
input_module = ctx.parse_asm("""
func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
attributes { iree.module.export } {
diff --git a/bindings/python/pyiree/tf_interop/BUILD b/bindings/python/pyiree/tf_interop/BUILD
index b2180f2..6094507 100644
--- a/bindings/python/pyiree/tf_interop/BUILD
+++ b/bindings/python/pyiree/tf_interop/BUILD
@@ -122,7 +122,17 @@
py_library(
name = "tf_test_driver",
srcs = ["tf_test_driver.py"],
- deps = INTREE_TENSORFLOW_PY_DEPS,
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//bindings/python/pyiree:binding",
+ ],
+)
+
+py_library(
+ name = "test_utils",
+ srcs = ["test_utils.py"],
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//bindings/python/pyiree:binding",
+ ],
)
py_test(
diff --git a/bindings/python/pyiree/tf_interop/saved_model_test.py b/bindings/python/pyiree/tf_interop/saved_model_test.py
index 89a9267..e0c4293 100644
--- a/bindings/python/pyiree/tf_interop/saved_model_test.py
+++ b/bindings/python/pyiree/tf_interop/saved_model_test.py
@@ -71,30 +71,7 @@
tf.saved_model.save(my_module, sm_dir, options=options)
# Load it up.
- ctx = pyiree.CompilerContext()
- input_module = pyiree.tf_load_saved_model(ctx, sm_dir)
- input_asm = input_module.to_asm()
- print("LOADED ASM:\n", input_asm)
- # Should have out exported name and have executor islands.
- self.assertRegex(input_asm,
- r"""tf_saved_model.exported_names = \["add"\]""")
- self.assertRegex(input_asm, r"""tf_executor\.island""")
-
- # Run the necessary lowering passes. Makes sure that these are linked in.
- input_module.run_pass_pipeline([
- "tf-executor-graph-pruning",
- "tf-standard-pipeline",
- "canonicalize",
- ])
- lowered_asm = input_module.to_asm()
- print("LOWERED ASM:\n", lowered_asm)
- # Should have collapsed all executor islands.
- self.assertNotRegex(lowered_asm, r"""tf_executor\.island""")
-
- # And legalize to XLA.
- input_module.run_pass_pipeline([
- "xla-legalize-tf",
- ])
+ input_module = pyiree.compiler.tf_load_saved_model(sm_dir)
xla_asm = input_module.to_asm()
print("XLA ASM:", xla_asm)
self.assertRegex(xla_asm, "xla_hlo.tanh")
diff --git a/bindings/python/pyiree/tf_interop/test_utils.py b/bindings/python/pyiree/tf_interop/test_utils.py
new file mode 100644
index 0000000..0d8f81a
--- /dev/null
+++ b/bindings/python/pyiree/tf_interop/test_utils.py
@@ -0,0 +1,215 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test utilities interop with TensorFlow."""
+
+import os
+import tempfile
+import timeit
+
+from .. import binding
+from .. import compiler
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+
+def save_and_compile_tf_module(tf_module):
+ with tempfile.TemporaryDirectory() as sm_path:
+ options = tf.saved_model.SaveOptions(save_debug_info=True)
+ tf.saved_model.save(tf_module, sm_path, options=options)
+ return compiler.tf_compile_saved_model(sm_path)
+
+
+def dump_iree_module(m):
+ print("Loaded module:", m.name)
+ i = 0
+ while True:
+ f = m.lookup_function_by_ordinal(i)
+ if not f:
+ break
+ print(" Export:", f.name, "-> args(", f.signature.argument_count,
+ "), results(", f.signature.result_count, ")")
+ i += 1
+
+
+def get_default_test_backends():
+ backends_env = os.environ.get("IREE_TEST_BACKENDS")
+ if backends_env:
+ return backends_env.split(",")
+ else:
+ return ("tf", "iree.interpreter")
+
+
+class _TfBackend(object):
+ """Backend for running directly on the TF module."""
+
+ def __init__(self, test_case, backend_name, fn_name):
+ self.backend_name = backend_name
+ self.module_f = getattr(test_case.tf_module, fn_name)
+
+ def __call__(self, *args):
+ return self.module_f(*args)
+
+ def postprocess(self, results):
+ # Handle single result (technically ambiguous with return of a tuple).
+ if not isinstance(results, tuple):
+ results = (results,)
+ # TODO(laurenzo): Handle structure mapping, etc.
+ return [r.numpy() for r in results]
+
+
+class _IreeBackend(object):
+ """Backend for running on an IREE driver."""
+
+ def __init__(self, test_case, backend_name, fn_name):
+ self.backend_name = backend_name
+ driver_name = backend_name.split(".")[-1]
+ self.policy = binding.rt.Policy()
+ instance = binding.rt.Instance(driver_name=driver_name)
+ self.context = binding.rt.Context(instance=instance, policy=self.policy)
+ self.context.register_module(test_case.iree_vm_module)
+ self.f = self.context.resolve_function("module." + fn_name)
+
+ def __call__(self, *args):
+ args = [self.context.wrap_for_input(arg) for arg in args]
+ # Invoke the function and wait for completion.
+ inv = self.context.invoke(self.f, self.policy, args)
+ inv.await_ready()
+ # Get results as a numpy array.
+ results = [np.array(r.map(), copy=False) for r in inv.results]
+ return results
+
+ def postprocess(self, results):
+ return results
+
+
+_ALL_BACKENDS = {
+ "tf": _TfBackend,
+ "iree.interpreter": _IreeBackend,
+ "iree.vulkan": _IreeBackend,
+}
+
+
+def _wrap_per_backend_fn(saved_model_test_case, fn_name, iterations=100):
+ """Generates a wrapper function for a backend fn name."""
+
+ def invoke_fn(*args):
+ """Lambda that invokes the function on all backends."""
+
+ backend_names = saved_model_test_case.BACKENDS
+ if not backend_names:
+ backend_names = get_default_test_backends()
+
+ backends = [
+ _ALL_BACKENDS[b](saved_model_test_case, b, fn_name)
+ for b in backend_names
+ ]
+ test_id = saved_model_test_case.id().split(".")[-1]
+
+ per_backend_results = []
+ binding.tracing.enable_thread()
+ for backend in backends:
+ # pylint: disable=cell-var-from-loop
+ print(":INVOKE %s:%s on %s" % (test_id, fn_name, backend.backend_name))
+ event = binding.tracing.ScopedEvent(
+ "%s_%s#%s" % (test_id, fn_name, backend.backend_name))
+
+ def run_iteration():
+ with event:
+ return backend(*args)
+
+ # Run one for correctness.
+ results = backend.postprocess(run_iteration())
+ per_backend_results.append((backend.backend_name, results))
+ # Then time it.
+ backend_time_ms = timeit.timeit(run_iteration, number=iterations) * 1000
+ iteration_time_ms = backend_time_ms / iterations
+ print(":BENCHMARK %s:%s on %s: time=%rms" %
+ (test_id, fn_name, backend.backend_name, iteration_time_ms))
+ # pylint: enable=cell-var-from-loop
+
+ # Verify results.
+ ref_backend_name, ref_results = per_backend_results[0]
+ print(":REF RESULTS %s:%s %s:" % (test_id, fn_name, ref_backend_name),
+ ref_results)
+ for backend_name, results in per_backend_results[1:]:
+ print(":COMPARE %s:%s %s vs %s" %
+ (test_id, fn_name, ref_backend_name, backend_name))
+ print(" :", results)
+ for ref_result, result in zip(ref_results, results):
+ saved_model_test_case.assertAllClose(
+ ref_result,
+ result,
+ msg="Result mismatch %s vs %s" % (ref_backend_name, backend_name))
+
+ return ref_results
+
+ return invoke_fn
+
+
+def per_backend_test(*fn_names):
+ """Wraps a SavedModelTestCase test method to run per backend tests.
+
+ Args:
+ *fn_names: Names of functions to run tests against. These will be converted
+ to python functions that invoke all of the backends and passed to the test
+ case method.
+
+ Returns:
+ A decorated function.
+ """
+
+ def decorator(f):
+
+ def replacement(self):
+ fns = [_wrap_per_backend_fn(self, fn_name) for fn_name in fn_names]
+ f(self, *fns)
+
+ replacement.__name__ = f.__name__
+ return replacement
+
+ return decorator
+
+
+class SavedModelTestCase(tf.test.TestCase):
+ """Tests against a SavedModel.
+
+ Use this by subclassing and then defining a TF_MODULE_CONSTRUCTOR member.
+ """
+
+ TF_MODULE_CONSTRUCTOR = None
+ TRACE_FILE_NAME = None
+ BACKENDS = None
+
+ @classmethod
+ def tearDownClass(cls):
+ trace_file_name = cls.TRACE_FILE_NAME
+ if not trace_file_name:
+ trace_file_name = cls.__name__ + ".wtf-trace"
+ trace_file = os.path.join(tempfile.gettempdir(), trace_file_name)
+ print("Flushing trace file to:", trace_file)
+ binding.tracing.flush(trace_file)
+ print("Flush complete")
+ super().tearDownClass()
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ if cls.TF_MODULE_CONSTRUCTOR is None:
+ raise ValueError("Expected a class level TF_MODULE_CONSTRUCTOR")
+ # Compile the module. We do this once.
+ cls.tf_module = cls.TF_MODULE_CONSTRUCTOR() # pylint: disable=not-callable
+ cls.iree_blob = save_and_compile_tf_module(cls.tf_module)
+ cls.iree_vm_module = binding.vm.create_module_from_blob(cls.iree_blob)
+ dump_iree_module(cls.iree_vm_module)
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index b9765e0..45a278f 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -16,6 +16,7 @@
"//iree:build_defs.bzl",
"INTREE_TENSORFLOW_PY_DEPS",
"NUMPY_DEPS",
+ "PLATFORM_VULKAN_DEPS",
)
package(
@@ -27,7 +28,8 @@
name = "simple_arithmetic_test",
srcs = ["simple_arithmetic_test.py"],
python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + PLATFORM_VULKAN_DEPS + [
"//bindings/python/pyiree",
+ "//iree/hal/vulkan:vulkan_driver_module",
],
)
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
new file mode 100644
index 0000000..cea7df1
--- /dev/null
+++ b/integrations/tensorflow/e2e/README.md
@@ -0,0 +1,46 @@
+# TensorFlow e2e tests
+
+This is a collection of e2e tests that, in various fashion saves a TensorFlow
+model, compiles it with IREE and runs/evaluates it on all backends.
+
+## Pre-Requisites
+
+You will need a TensorFlow 2.0+ nightly installed in your python environment:
+the python binary in `$PYTHON_BIN` should be able to `import tensorflow` and
+that TensorFlow should be version 2.0+. This can be checked with
+`tensorflow.version`.
+
+See [Install TensorFlow with pip](https://www.tensorflow.org/install/pip) for
+instructions.
+
+## Vulkan setup
+
+By default, tests run on TensorFlow and the IREE CPU interpreter, as it never
+needs additional environment setup. If you have your environment setup to use
+IREE with Vulkan (see [the doc](../../../docs/vulkan_and_spirv.md)), then you
+can enable the backends by setting the environment variable
+`IREE_TEST_BACKENDS=tf,iree.interpreter,iree.vulkan`.
+
+## Running tests
+
+```shell
+# Run all tests with defaults and output on failure.
+bazel test ... --test_output=errors
+
+# Run an individual test interactively.
+bazel test simple_arithmetic_test --test_output=streamed
+
+# Run tests with an altered list of backends.
+bazel test ... --test_env=IREE_TEST_BACKENDS=tf,iree.interpreter,iree.vulkan \
+ --test_output=errors
+```
+
+## Test harnesses
+
+### Simple function tests
+
+See `simple_arithmetic_test.py` for some examples of single function tests.
+These are done by extending a tf_test_utils.SavedModelTestCase and then
+annotating individual test methods with
+`@tf_test_utils.per_backend_test("function_name")` to get a function that will
+run and compare on all backends.
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index ae23ece..8759ccc 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -14,16 +14,8 @@
# limitations under the License.
"""Several baseline e2e simple arithmetic tests."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import tempfile
-import timeit
-
import numpy as np
-import pyiree
+from pyiree import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -44,238 +36,26 @@
return tf.matmul(a, b)
-def save_and_load_tf_module(tf_module):
- with tempfile.TemporaryDirectory() as sm_path:
- options = tf.saved_model.SaveOptions(save_debug_info=True)
- tf.saved_model.save(tf_module, sm_path, options=options)
- ctx = pyiree.CompilerContext()
- input_module = pyiree.tf_load_saved_model(ctx, sm_path)
- return input_module
+class SimpleArithmeticTest(tf_test_utils.SavedModelTestCase):
+ TF_MODULE_CONSTRUCTOR = SimpleArithmeticModule
-def dump_iree_module(m):
- print("Loaded module:", m.name)
- i = 0
- while True:
- f = m.lookup_function_by_ordinal(i)
- if not f:
- break
- print(" Export:", f.name, "-> args(", f.signature.argument_count,
- "), results(", f.signature.result_count, ")")
- i += 1
-
-
-class SimpleArithmeticTest(tf.test.TestCase):
-
- @classmethod
- def tearDownClass(cls):
- super().tearDownClass()
- trace_file = os.path.join(tempfile.gettempdir(),
- "simple_arithmetic_test.wtf-trace")
- print("Flushing trace file to:", trace_file)
- pyiree.tracing.flush(trace_file)
- print("Flush complete")
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- # Compile the module. We do this once.
- cls.tf_module = SimpleArithmeticModule()
- cls.mlir_input_module = save_and_load_tf_module(cls.tf_module)
- print("LOADED ASM:",
- cls.mlir_input_module.to_asm(debug_info=True, pretty=True))
-
- # Canonicalize the TF import.
- cls.mlir_input_module.run_pass_pipeline([
- "tf-executor-graph-pruning",
- "tf-standard-pipeline",
- "canonicalize",
- ])
- print("CANONICAL TF ASM:",
- cls.mlir_input_module.to_asm(debug_info=True, pretty=True))
-
- # Legalize to XLA (high-level).
- cls.mlir_input_module.run_pass_pipeline([
- "xla-legalize-tf",
- ])
- print("XLA ASM:",
- cls.mlir_input_module.to_asm(debug_info=True, pretty=True))
-
- # Compile the module with IREE.
- cls.iree_blob = cls.mlir_input_module.compile_to_sequencer_blob(
- print_mlir=True)
- cls.iree_vm_module = pyiree.binding.vm.create_module_from_blob(
- cls.iree_blob)
- dump_iree_module(cls.iree_vm_module)
-
- def test_simple_matmul(self):
- pyiree.tracing.enable_thread()
- # Initialize the runtime and register the module.
- # Use the CPU interpreter driver (which has the most implementation done):
- driver_name = "interpreter"
-
- # Live on the edge and give the vulkan driver a try:
- # driver_name = "vulkan"
-
- policy = pyiree.binding.rt.Policy()
- instance = pyiree.binding.rt.Instance(driver_name=driver_name)
- context = pyiree.binding.rt.Context(instance=instance, policy=policy)
- context.register_module(self.iree_vm_module)
-
- f = context.resolve_function("module.simple_matmul")
- tf_f = self.tf_module.simple_matmul
+ @tf_test_utils.per_backend_test("simple_matmul")
+ def test_simple_matmul(self, simple_matmul):
np.random.seed(12345)
# Note: scaling by a small value to increase numerical stability.
a = np.random.random((128, 3072)).astype(np.float32) * 1e-3
b = np.random.random((3072, 256)).astype(np.float32) * 1e-3
+ simple_matmul(a, b)
- iree_event = pyiree.tracing.ScopedEvent(
- "SimpleArithmeticTest#simple_matmul")
-
- def invoke_iree():
- with iree_event:
- arg0 = context.wrap_for_input(a)
- arg1 = context.wrap_for_input(b)
-
- # Invoke the function and wait for completion.
- inv = context.invoke(f, policy, [arg0, arg1])
- inv.await_ready()
-
- # Get the result as a numpy array and print.
- results = inv.results
- result = results[0].map()
- result_ary = np.array(result, copy=False)
- return result_ary
-
- def invoke_tf():
- arg0 = a
- arg1 = b
- result = tf_f(arg0, arg1)
- return result.numpy()
-
- # Check that results are equal.
- self.assertAllClose(invoke_iree(), invoke_tf())
- # Quick benchmark.
- iterations = 100
- print("+++BM simple_matmul:")
- iree_time = timeit.timeit(invoke_iree, number=iterations)
- print("IREE -> TIME/ITERATION =", (iree_time / iterations) * 1000, "ms")
- tf_time = timeit.timeit(invoke_tf, number=iterations)
- print("TF -> TIME/ITERATION =", (tf_time / iterations) * 1000, "ms")
- tf_vs_iree_factor = tf_time / iree_time
- print("IREE VS TF SPEEDUP FACTOR =", tf_vs_iree_factor)
-
- def test_simple_scalar_mul(self):
- pyiree.tracing.enable_thread()
- # Initialize the runtime and register the module.
- # Use the CPU interpreter driver (which has the most implementation done):
- driver_name = "interpreter"
-
- # Live on the edge and give the vulkan driver a try:
- # driver_name = "vulkan"
-
- policy = pyiree.binding.rt.Policy()
- instance = pyiree.binding.rt.Instance(driver_name=driver_name)
- context = pyiree.binding.rt.Context(instance=instance, policy=policy)
- context.register_module(self.iree_vm_module)
-
- f = context.resolve_function("module.simple_mul")
- tf_f = self.tf_module.simple_mul
+ @tf_test_utils.per_backend_test("simple_mul")
+ def test_simple_scalar_mul(self, simple_mul):
a = np.array([1., 2., 3., 4.], dtype=np.float32)
b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
- iree_event = pyiree.tracing.ScopedEvent("SimpleArithmeticTest#simple_mul")
-
- def invoke_iree():
- with iree_event:
- arg0 = context.wrap_for_input(a)
- arg1 = context.wrap_for_input(b)
-
- # Invoke the function and wait for completion.
- inv = context.invoke(f, policy, [arg0, arg1])
- inv.await_ready()
-
- # Get the result as a numpy array and print.
- results = inv.results
- result = results[0].map()
- result_ary = np.array(result, copy=False)
- return result_ary
-
- def invoke_tf():
- arg0 = a
- arg1 = b
- result = tf_f(arg0, arg1)
- return result.numpy()
-
- # Check that results are equal.
- self.assertAllEqual(invoke_iree(), invoke_tf())
- # Quick benchmark.
- iterations = 1000
- print("+++BM simple_mul:")
- iree_time = timeit.timeit(invoke_iree, number=iterations)
- print("IREE -> TIME/ITERATION =", (iree_time / iterations) * 1000, "ms")
- tf_time = timeit.timeit(invoke_tf, number=iterations)
- print("TF -> TIME/ITERATION =", (tf_time / iterations) * 1000, "ms")
- tf_vs_iree_factor = tf_time / iree_time
- print("IREE VS TF SPEEDUP FACTOR =", tf_vs_iree_factor)
-
- def test_simple_scalar_mul_streamed(self):
- pyiree.tracing.enable_thread()
- # Initialize the runtime and register the module.
- # Use the CPU interpreter driver (which has the most implementation done):
- driver_name = "interpreter"
-
- # Live on the edge and give the vulkan driver a try:
- # driver_name = "vulkan"
-
- policy = pyiree.binding.rt.Policy()
- instance = pyiree.binding.rt.Instance(driver_name=driver_name)
- context = pyiree.binding.rt.Context(instance=instance, policy=policy)
- context.register_module(self.iree_vm_module)
-
- f = context.resolve_function("module.simple_mul")
- tf_f = self.tf_module.simple_mul
- a = np.array([1., 2., 3., 4.], dtype=np.float32)
- b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
- iree_dispatch_event = pyiree.tracing.ScopedEvent(
- "SimpleArithmeticTest#simple_mul_dispatch")
- iree_await_event = pyiree.tracing.ScopedEvent(
- "SimpleArithmeticTest#simple_mul_await")
-
- invocations = []
-
- def invoke_iree():
- with iree_dispatch_event:
- arg0 = context.wrap_for_input(a)
- arg1 = context.wrap_for_input(b)
-
- # Invoke the function and wait for completion.
- inv = context.invoke(f, policy, [arg0, arg1])
- invocations.append(inv)
-
- def await_all():
- with iree_await_event:
- invocations[-1].await_ready()
-
- def invoke_tf():
- arg0 = a
- arg1 = b
- result = tf_f(arg0, arg1)
- return result.numpy()
-
- # Quick benchmark.
- iterations = 1000
- print("+++BM simple_mul_streamed:")
- iree_time = timeit.timeit(invoke_iree, number=iterations)
- iree_time += timeit.timeit(await_all, number=1)
- print("IREE -> TIME/ITERATION =", (iree_time / iterations) * 1000, "ms")
- tf_time = timeit.timeit(invoke_tf, number=iterations)
- print("TF -> TIME/ITERATION =", (tf_time / iterations) * 1000, "ms")
- tf_vs_iree_factor = tf_time / iree_time
- print("IREE VS TF SPEEDUP FACTOR =", tf_vs_iree_factor)
+ simple_mul(a, b)
if __name__ == "__main__":
- tf.enable_v2_behavior()
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
tf.test.main()