blob: 0a5af53ae9afd55866209461bd78fde31de74f2d [file] [log] [blame]
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import pytest
from ireers_tools import *
import os
import setuptools
from conftest import VmfbManager
from pathlib import Path
iree_test_path_extension = os.getenv("IREE_TEST_PATH_EXTENSION", default=Path.cwd())
vmfb_dir = os.getenv("TEST_OUTPUT_ARTIFACTS", default=Path.cwd())
rocm_chip = os.getenv("ROCM_CHIP", default="gfx90a")
###############################################################################
# Fixtures
###############################################################################
sdxl_unet_inference_input_0 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-scheduled-unet/inference_input.0.bin",
group="sdxl_unet",
)
sdxl_unet_inference_input_1 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-scheduled-unet/inference_input.1.bin",
group="sdxl_unet",
)
sdxl_unet_inference_input_2 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-scheduled-unet/inference_input.2.bin",
group="sdxl_unet",
)
sdxl_unet_inference_input_3 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-scheduled-unet/inference_input.3.bin",
group="sdxl_unet",
)
sdxl_unet_inference_output_0 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-scheduled-unet/inference_output.0.bin",
group="sdxl_unet",
)
sdxl_unet_real_weights = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-scheduled-unet/real_weights.irpa",
group="sdxl_unet",
)
sdxl_unet_mlir = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-scheduled-unet/model.mlirbc",
group="sdxl_unet",
)
sdxl_unet_pipeline_mlir = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-scheduled-unet/sdxl_unet_pipeline_bench_f16.mlir",
group="sdxl_unet",
)
CPU_COMPILE_FLAGS = [
"--iree-hal-target-backends=llvm-cpu",
"--iree-llvmcpu-target-cpu-features=host",
"--iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false",
"--iree-llvmcpu-distribution-size=32",
"--iree-opt-const-eval=false",
"--iree-llvmcpu-enable-ukernels=all",
"--iree-global-opt-enable-quantized-matmul-reassociation",
]
@pytest.fixture
def SDXL_UNET_COMMON_RUN_FLAGS(
sdxl_unet_inference_input_0,
sdxl_unet_inference_input_1,
sdxl_unet_inference_input_2,
sdxl_unet_inference_input_3,
sdxl_unet_inference_output_0,
):
return [
f"--input=1x4x128x128xf16=@{sdxl_unet_inference_input_0.path}",
f"--input=2x64x2048xf16=@{sdxl_unet_inference_input_1.path}",
f"--input=2x1280xf16=@{sdxl_unet_inference_input_2.path}",
f"--input=1xf16=@{sdxl_unet_inference_input_3.path}",
f"--expected_output=1x4x128x128xf16=@{sdxl_unet_inference_output_0.path}",
]
ROCM_COMPILE_FLAGS = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={rocm_chip}",
"--iree-opt-const-eval=false",
f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec.mlir",
"--iree-global-opt-propagate-transposes=true",
"--iree-flow-enable-fuse-horizontal-contractions=true",
"--iree-flow-enable-aggressive-fusion=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-opt-outer-dim-concat=true",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-data-tiling=false",
"--iree-codegen-gpu-native-math-precision=true",
"--iree-codegen-llvmgpu-use-vector-distribution",
"--iree-hip-waves-per-eu=2",
"--iree-execution-model=async-external",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
"--iree-scheduling-dump-statistics-format=json",
"--iree-scheduling-dump-statistics-file=compilation_info.json",
]
ROCM_PIPELINE_COMPILE_FLAGS = [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={rocm_chip}",
"--verify=false",
"--iree-opt-const-eval=false",
]
###############################################################################
# CPU
###############################################################################
def test_compile_unet_pipeline_cpu(sdxl_unet_pipeline_mlir):
VmfbManager.sdxl_unet_cpu_pipeline_vmfb = iree_compile(
sdxl_unet_pipeline_mlir,
CPU_COMPILE_FLAGS,
Path(vmfb_dir)
/ Path("sdxl_unet_vmfbs")
/ Path(sdxl_unet_pipeline_mlir.path.name).with_suffix(f".cpu.vmfb"),
)
def test_compile_unet_cpu(sdxl_unet_mlir):
VmfbManager.sdxl_unet_cpu_vmfb = iree_compile(
sdxl_unet_mlir,
CPU_COMPILE_FLAGS,
Path(vmfb_dir)
/ Path("sdxl_unet_vmfbs")
/ Path(sdxl_unet_mlir.path.name).with_suffix(f".cpu.vmfb"),
)
@pytest.mark.depends(on=["test_compile_unet_pipeline_cpu", "test_compile_unet_cpu"])
def test_run_unet_cpu(SDXL_UNET_COMMON_RUN_FLAGS, sdxl_unet_real_weights):
return iree_run_module(
VmfbManager.sdxl_unet_cpu_vmfb,
device="local-task",
function="produce_image_latents",
args=[
f"--parameters=model={sdxl_unet_real_weights.path}",
f"--module={VmfbManager.sdxl_unet_cpu_pipeline_vmfb}",
"--expected_f16_threshold=0.8f",
]
+ SDXL_UNET_COMMON_RUN_FLAGS,
)
###############################################################################
# ROCM
###############################################################################
def test_compile_unet_pipeline_rocm(sdxl_unet_pipeline_mlir):
VmfbManager.sdxl_unet_rocm_pipeline_vmfb = iree_compile(
sdxl_unet_pipeline_mlir,
ROCM_PIPELINE_COMPILE_FLAGS,
Path(vmfb_dir)
/ Path("sdxl_unet_vmfbs")
/ Path(sdxl_unet_pipeline_mlir.path.name).with_suffix(
f".rocm_{rocm_chip}.vmfb"
),
)
def test_compile_unet_rocm(sdxl_unet_mlir):
VmfbManager.sdxl_unet_rocm_vmfb = iree_compile(
sdxl_unet_mlir,
ROCM_COMPILE_FLAGS,
Path(vmfb_dir)
/ Path("sdxl_unet_vmfbs")
/ Path(sdxl_unet_mlir.path.name).with_suffix(f".rocm_{rocm_chip}.vmfb"),
)
@pytest.mark.depends(on=["test_compile_unet_pipeline_rocm", "test_compile_unet_rocm"])
def test_run_unet_rocm(SDXL_UNET_COMMON_RUN_FLAGS, sdxl_unet_real_weights):
return iree_run_module(
VmfbManager.sdxl_unet_rocm_vmfb,
device="hip",
function="produce_image_latents",
args=[
f"--parameters=model={sdxl_unet_real_weights.path}",
f"--module={VmfbManager.sdxl_unet_rocm_pipeline_vmfb}",
"--expected_f16_threshold=0.705f",
]
+ SDXL_UNET_COMMON_RUN_FLAGS,
)