| # Copyright 2022 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import logging |
| import unittest |
| |
| from iree.runtime import system_setup as ss |
| |
| |
class DeviceSetupTest(unittest.TestCase):
    """Exercises driver/device discovery and creation through ``system_setup``."""

    def testQueryDriversDevices(self):
        """The local drivers are available and can enumerate their devices."""
        driver_names = ss.query_available_drivers()
        print(f"Drivers: {driver_names}")
        self.assertIn("local-sync", driver_names)
        self.assertIn("local-task", driver_names)

        for driver_name in ["local-sync", "local-task"]:
            driver = ss.get_driver(driver_name)
            print(f"Driver {driver_name}: {driver}")
            device_infos = driver.query_available_devices()
            print(f"DeviceInfos: {device_infos}")
            if driver_name == "local-sync":
                # We happen to know that this should have one device_info
                self.assertEqual(
                    device_infos, [{"device_id": 0, "path": "", "name": "default"}]
                )

    def testCreateBadDeviceId(self):
        """Creating a device with an unknown id raises ValueError."""
        driver = ss.get_driver("local-sync")
        # assertRaises(msg=...) only customizes the failure text shown when no
        # exception is raised; it never inspects the exception. Use
        # assertRaisesRegex so the message is actually checked.
        with self.assertRaisesRegex(ValueError, "Device id 5555 not found"):
            driver.create_device(5555)

    def testCreateDevice(self):
        """A device can be created from a device id or from its info dict."""
        driver = ss.get_driver("local-sync")
        infos = driver.query_available_devices()
        # Fail with a clear message instead of an IndexError below.
        self.assertTrue(infos, "local-sync reported no available devices")
        # Each record is a dict:
        #   {"device_id": obj, "path": str, "name": str}.
        device1 = driver.create_device(infos[0]["device_id"])
        # Should also take the info dict directly for convenience.
        device2 = driver.create_device(infos[0])
        self.assertIsNotNone(device1)
        self.assertIsNotNone(device2)

    def testCreateDeviceByName(self):
        """get_device caches by name unless cache=False; unknown URLs raise."""
        device1 = ss.get_device("local-task")
        device2 = ss.get_device("local-sync")
        device3 = ss.get_device("local-sync")
        device4 = ss.get_device("local-sync", cache=False)
        self.assertIsNot(device1, device2)
        self.assertIsNot(device3, device4)
        self.assertIs(device2, device3)

        # As above: assertRaisesRegex (not msg=) is what verifies the message.
        with self.assertRaisesRegex(ValueError, "not found"):
            ss.get_device("local-sync://1")

    def testCreateDeviceWithAllocators(self):
        """create_device accepts an empty or non-empty allocator list."""
        driver = ss.get_driver("local-sync")
        infos = driver.query_available_devices()
        self.assertTrue(infos, "local-sync reported no available devices")
        device1 = driver.create_device(infos[0]["device_id"], allocators=[])
        device2 = driver.create_device(
            infos[0]["device_id"], allocators=["caching", "debug"]
        )
        self.assertIsNotNone(device1)
        self.assertIsNotNone(device2)
| |
| |
if __name__ == "__main__":
    # When run as a script, show INFO-level records before handing off to
    # the unittest runner.
    logging.basicConfig(level="INFO")
    unittest.main()