| # Copyright 2019 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import iree.runtime |
| |
| import gc |
| import numpy as np |
| import threading |
| import time |
| import unittest |
| |
| |
class NonDeviceHalTest(unittest.TestCase):
    """Checks HAL enum bindings that can be used without creating a device."""

    def testMemoryEnums(self):
        # Short aliases for the flag enums under test.
        compat = iree.runtime.BufferCompatibility
        usage = iree.runtime.BufferUsage
        access = iree.runtime.MemoryAccess
        mem_type = iree.runtime.MemoryType

        print("MemoryType:", mem_type)
        print("HOST_VISIBLE:", int(mem_type.HOST_VISIBLE))

        # BufferCompatibility: `|` and `&` on the enum values must compare
        # equal to the same operation performed on their integer values.
        self.assertEqual(
            compat.IMPORTABLE | compat.EXPORTABLE,
            int(compat.IMPORTABLE) | int(compat.EXPORTABLE),
        )
        self.assertEqual(
            compat.EXPORTABLE & compat.EXPORTABLE,
            int(compat.EXPORTABLE),
        )

        # BufferUsage.
        self.assertEqual(
            usage.TRANSFER | usage.MAPPING,
            int(usage.TRANSFER) | int(usage.MAPPING),
        )
        self.assertEqual(
            usage.TRANSFER & usage.TRANSFER,
            int(usage.TRANSFER),
        )

        # MemoryAccess.
        self.assertEqual(
            access.READ | access.WRITE,
            int(access.READ) | int(access.WRITE),
        )
        self.assertEqual(
            access.ALL & access.READ,
            int(access.READ),
        )

        # MemoryType.
        self.assertEqual(
            mem_type.DEVICE_LOCAL | mem_type.HOST_VISIBLE,
            int(mem_type.DEVICE_LOCAL) | int(mem_type.HOST_VISIBLE),
        )
        self.assertEqual(
            mem_type.OPTIMAL & mem_type.OPTIMAL,
            int(mem_type.OPTIMAL),
        )

    def testElementTypeEnums(self):
        element_type = iree.runtime.HalElementType
        byte_sized = element_type.INT_8
        sub_byte = element_type.INT_4
        # INT_8 is reported as byte aligned, INT_4 is not.
        self.assertTrue(element_type.is_byte_aligned(byte_sized))
        self.assertFalse(element_type.is_byte_aligned(sub_byte))
        self.assertEqual(1, element_type.dense_byte_count(byte_sized))
| |
| |
class DeviceHalTest(unittest.TestCase):
    """Exercises allocator, semaphore, fence, queue and command buffer
    bindings against a "local-task" device created fresh for each test."""

    def setUp(self):
        super().setUp()
        device = iree.runtime.get_device("local-task")
        self.device = device
        self.allocator = device.allocator
        gc.collect()

    def _alloc_device_local(self, nbytes):
        # Allocates a DEVICE_LOCAL buffer of `nbytes` with DEFAULT usage.
        return self.allocator.allocate_buffer(
            memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
            allowed_usage=iree.runtime.BufferUsage.DEFAULT,
            allocation_size=nbytes,
        )

    def _collect_async_failures(self, fail, wait):
        # Runs `wait()` on one thread while a second thread sleeps briefly and
        # then calls `fail("TEST FAILURE")`. Returns the list of RuntimeErrors
        # raised by `wait()`.
        caught = []

        def signaller():
            print("SIGNALLING ASYNC FAILURE")
            time.sleep(0.2)
            fail("TEST FAILURE")
            print("SIGNALLED")

        def blocked_waiter():
            print("WAITING")
            try:
                wait()
            except RuntimeError as err:
                caught.append(err)

        # The waiter is started first and joined first.
        workers = [
            threading.Thread(target=blocked_waiter),
            threading.Thread(target=signaller),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        return caught

    def testTrim(self):
        # Passing means the call did not raise.
        self.allocator.trim()

    def testProfilingDefaults(self):
        # Passing means none of the calls raised.
        device = self.device
        device.begin_profiling()
        device.flush_profiling()
        device.end_profiling()

    def testProfilingOptions(self):
        # Passing means none of the calls raised.
        options = {"mode": "queue", "file_path": "foo.rdc"}
        self.device.begin_profiling(**options)
        self.device.end_profiling()

    def testProfilingInvalidOptions(self):
        with self.assertRaisesRegex(ValueError, "unrecognized profiling mode"):
            self.device.begin_profiling(mode="SOMETHING THAT DOESN'T EXIST")

    def testStatistics(self):
        allocator = self.allocator
        # Both properties are read even when statistics are unavailable.
        stats = allocator.statistics
        formatted = allocator.formatted_statistics
        if not allocator.has_statistics:
            return
        expected_keys = (
            "host_bytes_peak",
            "host_bytes_allocated",
            "host_bytes_freed",
            "device_bytes_peak",
            "device_bytes_allocated",
            "device_bytes_freed",
        )
        for key in expected_keys:
            self.assertIn(key, stats)
        self.assertIn("HOST_LOCAL", formatted)

    def testQueryCompatibility(self):
        default_usage = iree.runtime.BufferUsage.DEFAULT
        compat = self.allocator.query_buffer_compatibility(
            memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
            allowed_usage=default_usage,
            intended_usage=default_usage,
            allocation_size=1024,
        )
        print("COMPAT:", compat)
        flags = iree.runtime.BufferCompatibility
        expectations = (
            (flags.ALLOCATABLE, "should be allocatable"),
            (flags.IMPORTABLE, "should be importable"),
            (flags.EXPORTABLE, "should be exportable"),
        )
        for flag, message in expectations:
            self.assertTrue(bool(compat & int(flag)), message)

    def testAllocateBuffer(self):
        allocated = self._alloc_device_local(13)
        print("BUFFER:", allocated)

    def testBufferViewConstructor(self):
        backing = self._alloc_device_local(13)
        view = iree.runtime.HalBufferView(
            backing, (1, 2), iree.runtime.HalElementType.INT_16
        )
        # NOTE: the exact bits set on type/usage/etc is implementation defined.
        expected_repr = "<HalBufferView (1, 2), element_type=0x10000010, 13 bytes (at offset 0 into 13), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|DISPATCH_STORAGE|MAPPING|MAPPING_PERSISTENT>"
        self.assertEqual(repr(view), expected_repr)
        self.assertEqual(4, view.byte_length)

    def testBufferMap(self):
        mapping = self._alloc_device_local(13).map()
        self.assertIsInstance(mapping, iree.runtime.MappedMemory)

    def testAllocateBufferCopy(self):
        twos = np.full((3, 4), 2, dtype=np.int32)
        copied = self.allocator.allocate_buffer_copy(
            memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
            allowed_usage=iree.runtime.BufferUsage.DEFAULT,
            device=self.device,
            buffer=twos,
        )
        # NOTE: the exact bits set on type/usage/etc is implementation defined.
        expected_repr = "<HalBuffer 48 bytes (at offset 0 into 48), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|DISPATCH_STORAGE|MAPPING|MAPPING_PERSISTENT>"
        self.assertEqual(repr(copied), expected_repr)

    def testAllocateBufferViewCopy(self):
        twos = np.full((3, 4), 2, dtype=np.int32)
        # Supplying an element type yields a buffer view (see the repr below).
        copied = self.allocator.allocate_buffer_copy(
            memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
            allowed_usage=iree.runtime.BufferUsage.DEFAULT,
            device=self.device,
            buffer=twos,
            element_type=iree.runtime.HalElementType.SINT_32,
        )
        # NOTE: the exact bits set on type/usage/etc is implementation defined.
        expected_repr = "<HalBufferView (3, 4), element_type=0x20000011, 48 bytes (at offset 0 into 48), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|DISPATCH_STORAGE|MAPPING|MAPPING_PERSISTENT>"
        self.assertEqual(repr(copied), expected_repr)

    def testAllocateHostStagingBufferCopy(self):
        staged = self.allocator.allocate_host_staging_buffer_copy(
            self.device, np.int32(0)
        )
        # NOTE: the exact bits set on type/usage/etc is implementation defined.
        expected_repr = "<HalBuffer 4 bytes (at offset 0 into 4), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|MAPPING|MAPPING_PERSISTENT>"
        self.assertEqual(repr(staged), expected_repr)

    def testSemaphore(self):
        from_zero = self.device.create_semaphore(0)
        self.assertEqual(from_zero.query(), 0)
        from_one = self.device.create_semaphore(1)
        self.assertEqual(from_one.query(), 1)
        from_one.signal(2)
        self.assertEqual(from_one.query(), 2)

    def testSemaphoreSignal(self):
        timeline = self.device.create_semaphore(0)
        # With deadline=0 the wait returns False until the value is signalled.
        self.assertFalse(timeline.wait(1, deadline=0))
        timeline.signal(1)
        self.assertTrue(timeline.wait(1, deadline=0))

    def testSynchronousSemaphoreFailed(self):
        timeline = self.device.create_semaphore(0)
        timeline.fail("TEST FAILURE")
        with self.assertRaisesRegex(
            RuntimeError, "^synchronous semaphore failure.*TEST FAILURE"
        ):
            timeline.wait(1, deadline=0)

    def testAsynchronousSemaphoreFailed(self):
        timeline = self.device.create_semaphore(0)
        caught = self._collect_async_failures(
            timeline.fail, lambda: timeline.wait(1)
        )
        self.assertTrue(caught)
        print(caught)
        # The sleep in the signalling thread makes it likely, but cannot
        # guarantee, that the failure is observed asynchronously. To avoid a
        # flaky race, the "asynchronous" vs "synchronous" message prefix is
        # deliberately not checked.
        self.assertIn("TEST FAILURE", str(caught[0]))

    def testTrivialQueueAlloc(self):
        timeline = self.device.create_semaphore(0)
        queued = self.device.queue_alloca(
            1024,
            wait_semaphores=[(timeline, 0)],
            signal_semaphores=[(timeline, 1)],
        )
        self.assertIsInstance(queued, iree.runtime.HalBuffer)
        self.device.queue_dealloca(
            queued, wait_semaphores=[(timeline, 1)], signal_semaphores=[]
        )

    def testAllocAcceptsFences(self):
        # Also tests HalFence, HalFence.insert, HalFence.wait (infinite)
        timeline = self.device.create_semaphore(0)
        before_alloc = iree.runtime.HalFence(1)
        before_alloc.insert(timeline, 0)
        after_alloc = iree.runtime.HalFence(1)
        after_alloc.insert(timeline, 1)
        after_dealloc = iree.runtime.HalFence(2)
        after_dealloc.insert(timeline, 2)
        queued = self.device.queue_alloca(
            1024, wait_semaphores=before_alloc, signal_semaphores=after_alloc
        )
        self.assertIsInstance(queued, iree.runtime.HalBuffer)
        self.device.queue_dealloca(
            queued, wait_semaphores=after_alloc, signal_semaphores=after_dealloc
        )
        self.assertTrue(after_dealloc.wait())
        self.assertEqual(timeline.query(), 2)

    def testFenceCreateAt(self):
        timeline = self.device.create_semaphore(0)
        at_one = iree.runtime.HalFence.create_at(timeline, 1)
        self.assertFalse(at_one.wait(deadline=0))
        # Signalling the semaphore directly satisfies the fence.
        timeline.signal(1)
        self.assertTrue(at_one.wait(deadline=0))

    def testFenceSignal(self):
        timeline = self.device.create_semaphore(0)
        at_one = iree.runtime.HalFence.create_at(timeline, 1)
        self.assertFalse(at_one.wait(deadline=0))
        # Signalling through the fence itself satisfies it as well.
        at_one.signal()
        self.assertTrue(at_one.wait(deadline=0))

    def testSynchronousFenceFailed(self):
        timeline = self.device.create_semaphore(0)
        at_one = iree.runtime.HalFence.create_at(timeline, 1)
        at_one.fail("TEST FAILURE")
        with self.assertRaisesRegex(
            RuntimeError, "^synchronous fence failure.*TEST FAILURE"
        ):
            at_one.wait(deadline=0)

    def testAsynchronousFenceFailed(self):
        timeline = self.device.create_semaphore(0)
        at_one = iree.runtime.HalFence.create_at(timeline, 1)
        caught = self._collect_async_failures(at_one.fail, lambda: at_one.wait())
        self.assertTrue(caught)
        print(caught)
        # The sleep in the signalling thread makes it likely, but cannot
        # guarantee, that the failure is observed asynchronously. To avoid a
        # flaky race, the "asynchronous" vs "synchronous" message prefix is
        # deliberately not checked.
        self.assertIn("TEST FAILURE", str(caught[0]))

    def testFenceJoin(self):
        first_sem = self.device.create_semaphore(0)
        second_sem = self.device.create_semaphore(0)
        parts = [
            iree.runtime.HalFence.create_at(first_sem, 1),
            iree.runtime.HalFence.create_at(second_sem, 1),
        ]
        joined = iree.runtime.HalFence.join(parts)
        self.assertEqual(joined.timepoint_count, 2)

    def testFenceInsert(self):
        first_sem = self.device.create_semaphore(0)
        second_sem = self.device.create_semaphore(0)
        fence = iree.runtime.HalFence(2)
        fence.insert(first_sem, 1)
        self.assertEqual(fence.timepoint_count, 1)
        # Inserting the same semaphore again leaves the count unchanged.
        fence.insert(first_sem, 2)
        self.assertEqual(fence.timepoint_count, 1)
        fence.insert(second_sem, 2)
        self.assertEqual(fence.timepoint_count, 2)

    def testFenceExtend(self):
        first_sem = self.device.create_semaphore(0)
        second_sem = self.device.create_semaphore(0)
        fence = iree.runtime.HalFence(2)
        fence.insert(first_sem, 1)
        self.assertEqual(fence.timepoint_count, 1)
        extra = iree.runtime.HalFence.create_at(second_sem, 2)
        fence.extend(extra)
        self.assertEqual(fence.timepoint_count, 2)

    def testRoundTripQueueCopy(self):
        expected = np.full((3, 4), 2, dtype=np.int32)
        source_view = self.allocator.allocate_buffer_copy(
            memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
            allowed_usage=iree.runtime.BufferUsage.DEFAULT,
            device=self.device,
            buffer=expected,
            element_type=iree.runtime.HalElementType.SINT_32,
        )
        source = source_view.get_buffer()
        target = self._alloc_device_local(source.byte_length())
        timeline = self.device.create_semaphore(0)
        fence_at = iree.runtime.HalFence.create_at
        self.device.queue_copy(
            source,
            target,
            wait_semaphores=fence_at(timeline, 0),
            signal_semaphores=fence_at(timeline, 1),
        )
        # Block until the copy has signalled the semaphore to 1.
        fence_at(timeline, 1).wait()
        actual = target.map().asarray(expected.shape, expected.dtype)
        np.testing.assert_array_equal(expected, actual)

    def testIncompatibleSizeQueueCopy(self):
        # Source is one byte larger than the target.
        source = self._alloc_device_local(13)
        target = self._alloc_device_local(12)
        timeline = self.device.create_semaphore(0)
        fence_at = iree.runtime.HalFence.create_at
        with self.assertRaisesRegex(ValueError, "length must be less than"):
            self.device.queue_copy(
                source,
                target,
                wait_semaphores=fence_at(timeline, 0),
                signal_semaphores=fence_at(timeline, 1),
            )

    def testCommandBufferStartsByDefault(self):
        # A default-constructed command buffer rejects a second begin().
        recording = iree.runtime.HalCommandBuffer(self.device)
        with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):
            recording.begin()
        # With begin=False an explicit begin() is accepted.
        idle = iree.runtime.HalCommandBuffer(self.device, begin=False)
        idle.begin()

    def testCommandBufferCopy(self):
        # Doesn't test much but that calls succeed.
        commands = iree.runtime.HalCommandBuffer(self.device)
        source = self._alloc_device_local(13)
        target = self._alloc_device_local(13)
        commands.copy(source, target, end=True)
        # end=True already ended recording, so a second end() must fail.
        with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):
            commands.end()

    def testCommandBufferFill(self):
        # Doesn't test much but that calls succeed.
        commands = iree.runtime.HalCommandBuffer(self.device)
        target = self._alloc_device_local(12)
        commands.fill(target, np.int32(1), 0, 12, end=True)
        # end=True already ended recording, so a second end() must fail.
        with self.assertRaisesRegex(RuntimeError, "FAILED_PRECONDITION"):
            commands.end()

    def testCommandBufferExecute(self):
        # Doesn't test much but that calls succeed.
        commands = iree.runtime.HalCommandBuffer(self.device)
        target = self._alloc_device_local(12)
        commands.fill(target, np.int32(1), 0, 12, end=True)

        timeline = self.device.create_semaphore(0)
        self.device.queue_execute(
            [commands],
            wait_semaphores=[(timeline, 0)],
            signal_semaphores=[(timeline, 1)],
        )
        iree.runtime.HalFence.create_at(timeline, 1).wait()

    def testCommandBufferExecuteAcceptsFence(self):
        # Doesn't test much but that calls succeed.
        commands = iree.runtime.HalCommandBuffer(self.device)
        target = self._alloc_device_local(12)
        commands.fill(target, np.int32(1), 0, 12, end=True)

        timeline = self.device.create_semaphore(0)
        fence_at = iree.runtime.HalFence.create_at
        self.device.queue_execute(
            [commands],
            wait_semaphores=fence_at(timeline, 0),
            signal_semaphores=fence_at(timeline, 1),
        )
        fence_at(timeline, 1).wait()
| |
| |
| class DeviceDLPackTest(unittest.TestCase): |
| """Tests low level DLPack import/export against the CPU HAL backend. |
| |
| This test leverages the fact that numpy is a reasonable dlpack |
| producer/consumer. It has the caveat that our low level support does not |
| allow import of non page aligned data, so we have to take some extra |
| steps to prep it. For pure CPU/Numpy import/export, we have better |
| supported paths than this, but we leverage it here for its testing |
| value, as it exercises code paths that are otherwise only accessible |
| on devices. |
| """ |
| |
| def setUp(self): |
| super().setUp() |
| self.device = iree.runtime.get_device("local-task") |
| self.allocator = self.device.allocator |
| gc.collect() |
| |
| def roundtrip(self, input_array, element_type): |
| # We have top copy the input array into our own buffer to ensure |
| # alignment (dlpack import/export require aligned data). |
| orig_bv = self.allocator.allocate_buffer_copy( |
| memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, |
| allowed_usage=iree.runtime.BufferUsage.DEFAULT, |
| device=self.device, |
| buffer=input_array, |
| element_type=element_type, |
| ) |
| aligned_input_array = orig_bv.map().asarray( |
| input_array.shape, input_array.dtype |
| ) |
| |
| # Export the __dlpack__ capsule from numpy, which should be a plain |
| # view over the buffer we originally allocated (therefore, aligned |
| # and importable). |
| input_capsule = aligned_input_array.__dlpack__() |
| aligned_input_array = None |
| gc.collect() |
| imported_bv = self.device.from_dlpack_capsule(input_capsule) |
| |
| # Export a capsule from this imported buffer view and create a new |
| # array out of it. |
| class DummyProducer: |
| def __dlpack__(_, stream=None): |
| capsule = self.device.create_dlpack_capsule(imported_bv, 1, 0) |
| return capsule |
| |
| def __dlpack_device__(self): |
| return (1, 0) # CPU, id 0 |
| |
| reimported_array = np.from_dlpack(DummyProducer()) |
| imported_bv = None |
| gc.collect() |
| np.testing.assert_array_equal(input_array, reimported_array) |
| |
| def testImportExportF64(self): |
| self.roundtrip(np.random.rand(3, 4), iree.runtime.HalElementType.FLOAT_64) |
| |
| def testImportExportF32(self): |
| self.roundtrip( |
| np.random.rand(3, 4, 16, 32, 1, 5, 2).astype(np.float32), |
| iree.runtime.HalElementType.FLOAT_32, |
| ) |
| |
| def testImportExportF16(self): |
| self.roundtrip( |
| np.random.rand(3, 4).astype(np.float16), |
| iree.runtime.HalElementType.FLOAT_16, |
| ) |
| |
| def testImportExportSI8(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.int8), |
| iree.runtime.HalElementType.SINT_8, |
| ) |
| |
| def testImportExportSI16(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.int16), |
| iree.runtime.HalElementType.SINT_16, |
| ) |
| |
| def testImportExportSI32(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.int32), |
| iree.runtime.HalElementType.SINT_32, |
| ) |
| |
| def testImportExportSI64(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.int64), |
| iree.runtime.HalElementType.SINT_64, |
| ) |
| |
| def testImportExportUI8(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.uint8), |
| iree.runtime.HalElementType.UINT_8, |
| ) |
| |
| def testImportExportUI16(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.uint16), |
| iree.runtime.HalElementType.UINT_16, |
| ) |
| |
| def testImportExportUI32(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.uint32), |
| iree.runtime.HalElementType.UINT_32, |
| ) |
| |
| def testImportExportUI64(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.uint64), |
| iree.runtime.HalElementType.UINT_64, |
| ) |
| |
| def testImportExportBool(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.bool_), |
| iree.runtime.HalElementType.BOOL_8, |
| ) |
| |
| def testImportExportUI64(self): |
| self.roundtrip( |
| (np.random.rand(3, 4) * 255.0).astype(np.uint64), |
| iree.runtime.HalElementType.UINT_64, |
| ) |
| |
| def testImportExportComplex64(self): |
| shape = (3, 1, 5, 6, 12, 2, 3) |
| self.roundtrip( |
| np.random.uniform(-1, 1, shape) + 1.0j * np.random.uniform(-1, 1, shape), |
| iree.runtime.HalElementType.COMPLEX_64, |
| ) |
| |
| def testImportExportComplex64(self): |
| shape = (3, 1, 5, 6, 12, 2, 3) |
| self.roundtrip( |
| ( |
| np.random.uniform(-1, 1, shape) + 1.0j * np.random.uniform(-1, 1, shape) |
| ).astype(np.complex128), |
| iree.runtime.HalElementType.COMPLEX_64, |
| ) |
| |
| |
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()