| # Copyright 2023 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import pytest |
| from ireers_tools import * |
| |
| ############################################################################### |
| # Fixtures |
| ############################################################################### |
| |
# Compiler flags prepended to the target-specific flags of every vmfb fixture
# in this file.
COMMON_FLAGS = ["--iree-input-type=none"]

# Source fixture for the argmax linalg MLIR program; the vmfb fixtures in this
# file request it by name and hand it to iree_compile.
# NOTE(review): the exact download/caching behavior and the meaning of `group`
# are defined by ireers_tools.fetch_source_fixture -- confirm there.
argmax_ukernel_source = fetch_source_fixture(
    "https://sharktank.blob.core.windows.net/sharktank/ukernel_regression/20231217/argmax/argmax_3d_linalg.mlir",
    group="argmax_ukernel_linalg",
)
| |
| |
@pytest.fixture
def argmax_ukernel_host_cpu_vmfb(argmax_ukernel_source):
    """Compile the argmax source with the llvm-cpu backend for the host CPU.

    Args:
        argmax_ukernel_source: The argmax MLIR source fixture.

    Returns:
        Whatever ``iree_compile`` returns for the "host_cpu" variant; the
        correctness tests pass it straight to ``iree_run_module``.
    """
    host_cpu_flags = [
        *COMMON_FLAGS,
        "--iree-hal-target-backends=llvm-cpu",
        "--iree-llvmcpu-target-cpu-features=host",
    ]
    return iree_compile(argmax_ukernel_source, "host_cpu", flags=host_cpu_flags)
| |
| |
@pytest.fixture
def argmax_ukernel_gfx90a_rocm_vmfb(argmax_ukernel_source):
    """Compile the argmax source with the rocm backend for gfx90a.

    The argmax ukernel is enabled through ``--iree-hip-enable-ukernels``.

    Args:
        argmax_ukernel_source: The argmax MLIR source fixture.

    Returns:
        Whatever ``iree_compile`` returns for the "gfx90a_rocm" variant; the
        correctness tests pass it straight to ``iree_run_module``.
    """
    gfx90a_flags = [
        *COMMON_FLAGS,
        "--iree-hal-target-backends=rocm",
        "--iree-hip-target=gfx90a",
        "--iree-hip-enable-ukernels=argmax",
    ]
    return iree_compile(argmax_ukernel_source, "gfx90a_rocm", flags=gfx90a_flags)
| |
| |
@pytest.fixture
def argmax_ukernel_gfx940_rocm_vmfb(argmax_ukernel_source):
    """Compile the argmax source with the rocm backend for gfx940.

    The argmax ukernel is enabled through ``--iree-hip-enable-ukernels``.

    Args:
        argmax_ukernel_source: The argmax MLIR source fixture.

    Returns:
        Whatever ``iree_compile`` returns for the "gfx940_rocm" variant; the
        correctness tests pass it straight to ``iree_run_module``.
    """
    gfx940_flags = [
        *COMMON_FLAGS,
        "--iree-hal-target-backends=rocm",
        "--iree-hip-target=gfx940",
        "--iree-hip-enable-ukernels=argmax",
    ]
    return iree_compile(argmax_ukernel_source, "gfx940_rocm", flags=gfx940_flags)
| |
| |
| ############################################################################### |
| # Correctness |
| ############################################################################### |
| |
| # Generation script: |
| # argmax_input_f16 = np.random.normal(size=[2, 4, 33000]).astype(np.float16) |
| # argmax_output_f16 = np.argmax(argmax_input_f16,axis=-1).astype(np.float32) |
| # argmax_input_f32 = np.random.normal(size=[2, 4, 33000]).astype(np.float32) |
| # argmax_output_f32 = np.argmax(argmax_input_f32,axis=-1).astype(np.float32) |
| # TODO: We currently force sitofp (i32 -> f32 and i64 -> f32) because expected_output |
| # cannot compare the signless i64 from the vmfb with the si64 that npy uses by default. |
| |
# Pre-generated f16 test data: the input tensor and the matching expected
# output. The tests read each fixture's `.path` to build iree_run_module args.
argmax_input_f16 = fetch_source_fixture(
    "https://sharktank.blob.core.windows.net/sharktank/ukernel_regression/20231217/argmax/argmax_3d_input_f16.npy",
    group="argmax_ukernel_input_f16",
)
argmax_output_f16 = fetch_source_fixture(
    "https://sharktank.blob.core.windows.net/sharktank/ukernel_regression/20231217/argmax/argmax_3d_output_f16.npy",
    group="argmax_ukernel_output_f16",
)

# Pre-generated f32 test data, same layout as the f16 pair.
argmax_input_f32 = fetch_source_fixture(
    "https://sharktank.blob.core.windows.net/sharktank/ukernel_regression/20231217/argmax/argmax_3d_input_f32.npy",
    group="argmax_ukernel_input_f32",
)
argmax_output_f32 = fetch_source_fixture(
    "https://sharktank.blob.core.windows.net/sharktank/ukernel_regression/20231217/argmax/argmax_3d_output_f32.npy",
    group="argmax_ukernel_output_f32",
)
| |
| |
@pytest.mark.presubmit
@pytest.mark.unstable_linalg
@pytest.mark.plat_host_cpu
def test_correctness_host_cpu(
    argmax_ukernel_host_cpu_vmfb,
    argmax_input_f16,
    argmax_output_f16,
    argmax_input_f32,
    argmax_output_f32,
):
    """Run all four argmax entry points of the host-CPU module on local-task.

    Each entry point is invoked with the pre-generated input file for its
    element type and the matching expected-output file.
    NOTE(review): presumably iree_run_module fails the test when the result
    differs from --expected_output -- confirm in ireers_tools.
    """
    # (entry point, input fixture, expected-output fixture), in run order.
    cases = (
        ("argmax_3d_dyn_f16i32", argmax_input_f16, argmax_output_f16),
        ("argmax_3d_dyn_f16i64", argmax_input_f16, argmax_output_f16),
        ("argmax_3d_dyn_f32i32", argmax_input_f32, argmax_output_f32),
        ("argmax_3d_dyn_f32i64", argmax_input_f32, argmax_output_f32),
    )
    for entry_point, given, expected in cases:
        iree_run_module(
            argmax_ukernel_host_cpu_vmfb,
            device="local-task",
            function=entry_point,
            args=[
                f"--input=@{given.path}",
                f"--expected_output=@{expected.path}",
            ],
        )
| |
| |
@pytest.mark.presubmit
@pytest.mark.unstable_linalg
@pytest.mark.plat_gfx90a_rocm
def test_correctness_gfx90a_rocm(
    argmax_ukernel_gfx90a_rocm_vmfb,
    argmax_input_f16,
    argmax_output_f16,
    argmax_input_f32,
    argmax_output_f32,
):
    """Run all four argmax entry points of the gfx90a module on the hip device.

    Each entry point is invoked with the pre-generated input file for its
    element type and the matching expected-output file.
    NOTE(review): presumably iree_run_module fails the test when the result
    differs from --expected_output -- confirm in ireers_tools.
    """
    # (entry point, input fixture, expected-output fixture), in run order.
    cases = (
        ("argmax_3d_dyn_f16i32", argmax_input_f16, argmax_output_f16),
        ("argmax_3d_dyn_f16i64", argmax_input_f16, argmax_output_f16),
        ("argmax_3d_dyn_f32i32", argmax_input_f32, argmax_output_f32),
        ("argmax_3d_dyn_f32i64", argmax_input_f32, argmax_output_f32),
    )
    for entry_point, given, expected in cases:
        iree_run_module(
            argmax_ukernel_gfx90a_rocm_vmfb,
            device="hip",
            function=entry_point,
            args=[
                f"--input=@{given.path}",
                f"--expected_output=@{expected.path}",
            ],
        )
| |
| |
@pytest.mark.presubmit
@pytest.mark.unstable_linalg
@pytest.mark.plat_gfx940_rocm
def test_correctness_gfx940_rocm(
    argmax_ukernel_gfx940_rocm_vmfb,
    argmax_input_f16,
    argmax_output_f16,
    argmax_input_f32,
    argmax_output_f32,
):
    """Run all four argmax entry points of the gfx940 module on the hip device.

    Each entry point is invoked with the pre-generated input file for its
    element type and the matching expected-output file.
    NOTE(review): presumably iree_run_module fails the test when the result
    differs from --expected_output -- confirm in ireers_tools.
    """
    # (entry point, input fixture, expected-output fixture), in run order.
    cases = (
        ("argmax_3d_dyn_f16i32", argmax_input_f16, argmax_output_f16),
        ("argmax_3d_dyn_f16i64", argmax_input_f16, argmax_output_f16),
        ("argmax_3d_dyn_f32i32", argmax_input_f32, argmax_output_f32),
        ("argmax_3d_dyn_f32i64", argmax_input_f32, argmax_output_f32),
    )
    for entry_point, given, expected in cases:
        iree_run_module(
            argmax_ukernel_gfx940_rocm_vmfb,
            device="hip",
            function=entry_point,
            args=[
                f"--input=@{given.path}",
                f"--expected_output=@{expected.path}",
            ],
        )