blob: 4bb18028326ba6366d6b5a09070d4a18069c156e [file] [log] [blame]
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import asyncio
import timeit
import unittest
from iree.runtime import (
get_device,
HalDeviceLoopBridge,
)
class HalDeviceLoopBridgeTest(unittest.TestCase):
def testBridge(self):
loop = asyncio.new_event_loop()
bridge = HalDeviceLoopBridge(self.device, loop)
sem1 = None
sem2 = None
report = None
async def main():
def done_1(x):
report("PYTHON: sem2.signal(1)")
sem2.signal(1)
def done_2(x):
report("PYTHON: sem2.signal(2)")
sem2.signal(2)
f1 = bridge.on_semaphore(sem1, 1, "Semaphore 1 Signaled")
f1.add_done_callback(done_1)
f2 = bridge.on_semaphore(sem2, 1, "Semaphore 2 Signaled")
f2.add_done_callback(done_2)
f2_again = bridge.on_semaphore(sem2, 2, "Semaphore 2 Signaled Again")
sem1.signal(1)
f1_result = await f1
report("PYTHON: await f1 =", f1_result)
f2_result = await f2
report("PYTHON: await f2 =", f2_result)
f2_again_result = await f2_again
report("PYTHON: await f2_again =", f2_again_result)
self.assertEqual(f1_result, "Semaphore 1 Signaled")
self.assertEqual(f2_result, "Semaphore 2 Signaled")
self.assertEqual(f2_again_result, "Semaphore 2 Signaled Again")
report("PYTHON: ASYNC MAIN() COMPLETE")
def run_iter(with_report):
nonlocal sem1
nonlocal sem2
nonlocal report
sem1 = self.device.create_semaphore(0)
sem2 = self.device.create_semaphore(0)
if with_report:
report = lambda *args: print(*args)
else:
report = lambda *args: None
loop.run_until_complete(main())
try:
run_iter(True)
iter_time = timeit.timeit("run_iter(False)", globals=locals(), number=10)
print(f"Time/iter = {iter_time}s")
finally:
bridge.stop()
def setUp(self):
super().setUp()
self.device = get_device("local-task")
self.allocator = self.device.allocator
if __name__ == "__main__":
unittest.main()