# blob: ae23ece99c6360181b41b901f7e459c856e2e4bc [file] [log] [blame]
# Lint as: python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Several baseline e2e simple arithmetic tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import timeit
import numpy as np
import pyiree
import tensorflow.compat.v2 as tf
class SimpleArithmeticModule(tf.Module):
  """TF module exposing the two arithmetic functions exercised by the tests.

  Both methods are `tf.function`s with fixed float32 input signatures.
  """

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[4], dtype=tf.float32),
      tf.TensorSpec(shape=[4], dtype=tf.float32),
  ])
  def simple_mul(self, a, b):
    """Returns `a * b` for two float32 tensors of shape [4]."""
    product = a * b
    return product

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[128, 3072], dtype=tf.float32),
      tf.TensorSpec(shape=[3072, 256], dtype=tf.float32),
  ])
  def simple_matmul(self, a, b):
    """Returns `tf.matmul(a, b)` for a [128, 3072] and a [3072, 256] tensor."""
    matmul_result = tf.matmul(a, b)
    return matmul_result
def save_and_load_tf_module(tf_module):
  """Exports `tf_module` as a SavedModel and imports it with pyiree.

  The SavedModel is written (with debug info enabled) into a temporary
  directory that only exists for the duration of this call.

  Args:
    tf_module: The module handed to `tf.saved_model.save`.

  Returns:
    Whatever `pyiree.tf_load_saved_model` returns for the exported model.
  """
  save_options = tf.saved_model.SaveOptions(save_debug_info=True)
  compiler_context = pyiree.CompilerContext()
  with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(tf_module, saved_model_dir, options=save_options)
    # Load before leaving the `with` block: the directory is removed on exit.
    return pyiree.tf_load_saved_model(compiler_context, saved_model_dir)
def dump_iree_module(m):
  """Prints the name of module `m` followed by one line per exported function.

  Functions are looked up by increasing ordinal, starting at 0, until
  `m.lookup_function_by_ordinal` returns a falsy value.

  Args:
    m: Object providing `name` and `lookup_function_by_ordinal(ordinal)`.
  """
  print("Loaded module:", m.name)
  ordinal = 0
  export = m.lookup_function_by_ordinal(ordinal)
  while export:
    signature = export.signature
    print(" Export:", export.name, "-> args(", signature.argument_count,
          "), results(", signature.result_count, ")")
    ordinal += 1
    export = m.lookup_function_by_ordinal(ordinal)
class SimpleArithmeticTest(tf.test.TestCase):
  """Runs SimpleArithmeticModule through IREE and benchmarks it against TF.

  The module is imported and compiled once in `setUpClass`. Each test creates
  its own runtime context, registers the compiled module, and times the IREE
  invocation against the equivalent `tf.function` call.
  """

  @classmethod
  def setUpClass(cls):
    """Imports the TF module, lowers it and compiles it with IREE (once)."""
    super().setUpClass()
    # Compile the module. We do this once.
    cls.tf_module = SimpleArithmeticModule()
    cls.mlir_input_module = save_and_load_tf_module(cls.tf_module)
    cls._print_asm("LOADED ASM:")

    # Canonicalize the TF import.
    cls.mlir_input_module.run_pass_pipeline([
        "tf-executor-graph-pruning",
        "tf-standard-pipeline",
        "canonicalize",
    ])
    cls._print_asm("CANONICAL TF ASM:")

    # Legalize to XLA (high-level).
    cls.mlir_input_module.run_pass_pipeline([
        "xla-legalize-tf",
    ])
    cls._print_asm("XLA ASM:")

    # Compile the module with IREE.
    cls.iree_blob = cls.mlir_input_module.compile_to_sequencer_blob(
        print_mlir=True)
    cls.iree_vm_module = pyiree.binding.vm.create_module_from_blob(
        cls.iree_blob)
    dump_iree_module(cls.iree_vm_module)

  @classmethod
  def tearDownClass(cls):
    """Flushes the pyiree trace to a file in the system temp directory."""
    super().tearDownClass()
    trace_file = os.path.join(tempfile.gettempdir(),
                              "simple_arithmetic_test.wtf-trace")
    print("Flushing trace file to:", trace_file)
    pyiree.tracing.flush(trace_file)
    print("Flush complete")

  @classmethod
  def _print_asm(cls, label):
    """Prints `label` followed by the current assembly of the input module."""
    print(label, cls.mlir_input_module.to_asm(debug_info=True, pretty=True))

  def _create_context(self):
    """Creates a runtime context with the compiled module registered.

    Returns:
      A `(context, policy)` tuple. The policy is returned as well because
      `context.invoke` takes it as an argument.
    """
    # Use the CPU interpreter driver (which has the most implementation done):
    driver_name = "interpreter"
    # Live on the edge and give the vulkan driver a try:
    # driver_name = "vulkan"
    policy = pyiree.binding.rt.Policy()
    instance = pyiree.binding.rt.Instance(driver_name=driver_name)
    context = pyiree.binding.rt.Context(instance=instance, policy=policy)
    context.register_module(self.iree_vm_module)
    # The pre-refactoring tests held the instance in a local for the whole
    # test; keep a reference on the test case so its lifetime is not shortened.
    self._iree_instance = instance
    return context, policy

  @staticmethod
  def _make_sync_invoker(context, policy, f, event, a, b):
    """Returns a zero-arg callable that runs `f(a, b)` on IREE and waits.

    Args:
      context: Runtime context used to wrap inputs and invoke `f`.
      policy: Policy passed to `context.invoke`.
      f: Function resolved from `context`.
      event: Tracing scope entered around every invocation.
      a: First numpy input.
      b: Second numpy input.

    Returns:
      A callable returning the first result of the invocation as a numpy
      array.
    """

    def invoke_iree():
      with event:
        arg0 = context.wrap_for_input(a)
        arg1 = context.wrap_for_input(b)

        # Invoke the function and wait for completion.
        inv = context.invoke(f, policy, [arg0, arg1])
        inv.await_ready()

        # Expose the first result as a numpy array. `np.asarray` copies only
        # when it has to, whereas `np.array(..., copy=False)` raises under
        # NumPy 2 if a copy turns out to be required.
        mapped = inv.results[0].map()
        return np.asarray(mapped)

    return invoke_iree

  @staticmethod
  def _make_tf_invoker(tf_f, a, b):
    """Returns a zero-arg callable computing `tf_f(a, b)` as a numpy array."""

    def invoke_tf():
      return tf_f(a, b).numpy()

    return invoke_tf

  @staticmethod
  def _benchmark(label, invoke_iree, invoke_tf, iterations,
                 finalize_iree=None):
    """Times both callables with `timeit` and prints a short report.

    Args:
      label: Benchmark name printed in the "+++BM" header line.
      invoke_iree: Zero-arg callable timed `iterations` times for IREE.
      invoke_tf: Zero-arg callable timed `iterations` times for TF.
      iterations: Number of timed calls of each callable.
      finalize_iree: Optional zero-arg callable timed once after the IREE
        loop; its duration is added to the IREE total.
    """
    print("+++BM " + label + ":")
    iree_time = timeit.timeit(invoke_iree, number=iterations)
    if finalize_iree is not None:
      iree_time += timeit.timeit(finalize_iree, number=1)
    print("IREE -> TIME/ITERATION =", (iree_time / iterations) * 1000, "ms")
    tf_time = timeit.timeit(invoke_tf, number=iterations)
    print("TF -> TIME/ITERATION =", (tf_time / iterations) * 1000, "ms")
    tf_vs_iree_factor = tf_time / iree_time
    print("IREE VS TF SPEEDUP FACTOR =", tf_vs_iree_factor)

  def test_simple_matmul(self):
    """Compares the IREE and TF matmul results, then benchmarks both."""
    pyiree.tracing.enable_thread()
    # Initialize the runtime and register the module.
    context, policy = self._create_context()
    f = context.resolve_function("module.simple_matmul")
    tf_f = self.tf_module.simple_matmul

    np.random.seed(12345)
    # Note: scaling by a small value to increase numerical stability.
    a = np.random.random((128, 3072)).astype(np.float32) * 1e-3
    b = np.random.random((3072, 256)).astype(np.float32) * 1e-3
    iree_event = pyiree.tracing.ScopedEvent(
        "SimpleArithmeticTest#simple_matmul")

    invoke_iree = self._make_sync_invoker(context, policy, f, iree_event, a,
                                          b)
    invoke_tf = self._make_tf_invoker(tf_f, a, b)

    # Check that results are equal.
    self.assertAllClose(invoke_iree(), invoke_tf())

    # Quick benchmark.
    self._benchmark("simple_matmul", invoke_iree, invoke_tf, iterations=100)

  def test_simple_scalar_mul(self):
    """Compares the IREE and TF multiply results, then benchmarks both."""
    pyiree.tracing.enable_thread()
    # Initialize the runtime and register the module.
    context, policy = self._create_context()
    f = context.resolve_function("module.simple_mul")
    tf_f = self.tf_module.simple_mul

    a = np.array([1., 2., 3., 4.], dtype=np.float32)
    b = np.array([400., 5., 6., 7.], dtype=np.float32)
    iree_event = pyiree.tracing.ScopedEvent("SimpleArithmeticTest#simple_mul")

    invoke_iree = self._make_sync_invoker(context, policy, f, iree_event, a,
                                          b)
    invoke_tf = self._make_tf_invoker(tf_f, a, b)

    # Check that results are equal.
    self.assertAllEqual(invoke_iree(), invoke_tf())

    # Quick benchmark.
    self._benchmark("simple_mul", invoke_iree, invoke_tf, iterations=1000)

  def test_simple_scalar_mul_streamed(self):
    """Benchmarks back-to-back IREE dispatches awaited once at the end.

    Unlike the other tests this one only measures time; results are not
    compared against TF.
    """
    pyiree.tracing.enable_thread()
    # Initialize the runtime and register the module.
    context, policy = self._create_context()
    f = context.resolve_function("module.simple_mul")
    tf_f = self.tf_module.simple_mul

    a = np.array([1., 2., 3., 4.], dtype=np.float32)
    b = np.array([400., 5., 6., 7.], dtype=np.float32)
    iree_dispatch_event = pyiree.tracing.ScopedEvent(
        "SimpleArithmeticTest#simple_mul_dispatch")
    iree_await_event = pyiree.tracing.ScopedEvent(
        "SimpleArithmeticTest#simple_mul_await")

    invocations = []

    def invoke_iree():
      with iree_dispatch_event:
        arg0 = context.wrap_for_input(a)
        arg1 = context.wrap_for_input(b)

        # Dispatch only: completion is awaited once, after the timed loop.
        inv = context.invoke(f, policy, [arg0, arg1])
        invocations.append(inv)

    def await_all():
      with iree_await_event:
        # Only the most recent invocation is awaited.
        # NOTE(review): presumably invocations on one context complete in
        # submission order, so this covers all of them - confirm.
        invocations[-1].await_ready()

    invoke_tf = self._make_tf_invoker(tf_f, a, b)

    # Quick benchmark.
    self._benchmark("simple_mul_streamed", invoke_iree, invoke_tf,
                    iterations=1000, finalize_iree=await_all)
# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
  # Enable TF2 behavior before the test runner starts.
  tf.enable_v2_behavior()
  tf.test.main()