blob: 9052f2b334adf21fbf8edde2fdad341520ac7c6a [file] [log] [blame]
## Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Defines e2e model tests on benchmark models."""
from typing import List
from dataclasses import dataclass
import dataclasses
import enum
from e2e_test_framework.definitions import common_definitions, iree_definitions
from e2e_test_framework.models import tflite_models
from benchmark_suites.iree import (
riscv_benchmarks,
x86_64_benchmarks,
armv8_a_benchmarks,
module_execution_configs,
)
class CMakePlatform(enum.Enum):
    """CMake system platform strings that the e2e model tests are keyed on.

    Each member's value is the platform string itself; the visible values all
    have the form "<architecture>-<system>".
    """

    # Android target.
    ANDROID_ARMV8_A = "arm_64-Android"

    # Linux targets.
    LINUX_RISCV32 = "riscv_32-Linux"
    LINUX_RISCV64 = "riscv_64-Linux"
    LINUX_X86_64 = "x86_64-Linux"
# Lookup table from each CMake system platform to the compile config used when
# building test modules for that platform. Every CMakePlatform member has
# exactly one entry.
PLATFORM_COMPILE_CONFIG_MAP = {
    CMakePlatform.ANDROID_ARMV8_A: (
        armv8_a_benchmarks.Android_ARMv8_A_Benchmarks.DT_UK_COMPILE_CONFIG
    ),
    CMakePlatform.LINUX_RISCV32: (
        riscv_benchmarks.Linux_RV32_Benchmarks.DEFAULT_COMPILE_CONFIG
    ),
    CMakePlatform.LINUX_RISCV64: (
        riscv_benchmarks.Linux_RV64_Benchmarks.DEFAULT_COMPILE_CONFIG
    ),
    CMakePlatform.LINUX_X86_64: (
        x86_64_benchmarks.Linux_x86_64_Benchmarks.CASCADELAKE_DT_UK_COMPILE_CONFIG
    ),
}
@dataclasses.dataclass(frozen=True)
class ModelTestConfig:
    """Describes a single e2e model test that is run by iree-run-module.

    Instances are immutable (frozen dataclass). The three list-valued fields
    default to a fresh empty list per instance.
    """

    # Name of the test as it appears in the test rule.
    name: str
    # Imported model under test.
    imported_model: iree_definitions.ImportedModel
    # Module execution config the test runs with.
    execution_config: iree_definitions.ModuleExecutionConfig
    # Expected output: either a string literal or a file path.
    expected_output: str
    # Platforms on which this test is ignored.
    unsupported_platforms: List[CMakePlatform] = dataclasses.field(
        default_factory=list
    )
    # Platforms on which this test is expected to fail.
    xfail_platforms: List[CMakePlatform] = dataclasses.field(
        default_factory=list
    )
    # Additional flags passed to `iree-run-module`.
    extra_test_flags: List[str] = dataclasses.field(default_factory=list)
# All e2e model correctness tests. Every entry runs with
# ELF_LOCAL_SYNC_CONFIG; entries differ in the model, the expected output and
# the platforms they are skipped on.
TEST_CONFIGS = [
    # MobileNet V1 (fp32): expected output given as a .txt file name.
    ModelTestConfig(
        name="mobilenet_v1_fp32_correctness_test",
        imported_model=iree_definitions.ImportedModel.from_model(
            tflite_models.MOBILENET_V1,
        ),
        execution_config=module_execution_configs.ELF_LOCAL_SYNC_CONFIG,
        expected_output="mobilenet_v1_fp32_expected_output.txt",
        unsupported_platforms=[
            CMakePlatform.LINUX_RISCV32,
            CMakePlatform.ANDROID_ARMV8_A,
        ],
    ),
    # EfficientNet (int8): only LINUX_X86_64 is not listed as unsupported.
    ModelTestConfig(
        name="efficientnet_int8_correctness_test",
        imported_model=iree_definitions.ImportedModel.from_model(
            tflite_models.EFFICIENTNET_INT8,
        ),
        execution_config=module_execution_configs.ELF_LOCAL_SYNC_CONFIG,
        expected_output="efficientnet_int8_expected_output.txt",
        unsupported_platforms=[
            CMakePlatform.ANDROID_ARMV8_A,
            CMakePlatform.LINUX_RISCV32,
            CMakePlatform.LINUX_RISCV64,
        ],
    ),
    # DeepLab V3 (fp32): expected output is a .npy URL and an explicit f32
    # threshold flag is passed to the runner.
    ModelTestConfig(
        name="deeplab_v3_fp32_correctness_test",
        imported_model=iree_definitions.ImportedModel.from_model(
            tflite_models.DEEPLABV3_FP32,
        ),
        execution_config=module_execution_configs.ELF_LOCAL_SYNC_CONFIG,
        expected_output="https://storage.googleapis.com/iree-model-artifacts/deeplab_v3_fp32_input_0_expected_output.npy",
        unsupported_platforms=[
            CMakePlatform.LINUX_RISCV32,
            CMakePlatform.LINUX_RISCV64,
        ],
        extra_test_flags=["--expected_f32_threshold=0.001"],
    ),
    # Person detection (int8): expected output is written inline as a string
    # rather than as a file name.
    ModelTestConfig(
        name="person_detect_int8_correctness_test",
        imported_model=iree_definitions.ImportedModel.from_model(
            tflite_models.PERSON_DETECT_INT8,
        ),
        execution_config=module_execution_configs.ELF_LOCAL_SYNC_CONFIG,
        expected_output="1x2xi8=[72 -72]",
        unsupported_platforms=[CMakePlatform.ANDROID_ARMV8_A],
    ),
]