blob: 543f64ecf011b9fbf4080a873b6e60911fba4720 [file] [log] [blame]
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import pytest
from typing import Dict, Sequence, Union
from pathlib import Path
import subprocess
import time
from .artifacts import (
Artifact,
FetchedArtifact,
ProducedArtifact,
)
# "Quality of life" flags prepended to the caller's flags on every
# ``iree-compile`` invocation made by ``iree_compile``:
#   --mlir-timing / --mlir-timing-display=list: MLIR pass timing report,
#     displayed as a flat list.
#   --iree-consteval-jit-debug: presumably extra debug output from the
#     const-eval JIT (NOTE(review): confirm against iree-compile --help).
# Must remain a list: ``iree_compile`` concatenates it with ``+``.
IREE_COMPILE_QOL_FLAGS = [
"--mlir-timing",
"--mlir-timing-display=list",
"--iree-consteval-jit-debug",
]
def fetch_source_fixture(url: str, *, group: str):
    """Return a pytest fixture that provides a started ``FetchedArtifact``.

    Each time the fixture is requested, it constructs
    ``FetchedArtifact(url=url, group=group)``, calls ``start()`` on it and
    hands the artifact to the test.

    Args:
      url: Location the artifact is fetched from.
      group: Artifact group the fetched file belongs to.
    """

    def fetcher() -> FetchedArtifact:
        fetched = FetchedArtifact(url=url, group=group)
        fetched.start()
        return fetched

    # Equivalent to decorating ``fetcher`` with ``@pytest.fixture``.
    return pytest.fixture(fetcher)
def iree_compile(source: Artifact, compiled_variant: str, flags: Sequence[str]):
    """Create and start a ``ProducedArtifact`` that compiles ``source`` to a ``.vmfb``.

    The output name is ``source.name`` with its suffix replaced by
    ``.{compiled_variant}.vmfb``, placed in ``source.group``. ``source`` is
    registered as a dependency. The actual ``iree-compile`` invocation lives in
    ``callback``, which ``ProducedArtifact`` invokes with the output artifact.
    When and on which thread that happens is decided by ``ProducedArtifact``
    and is not visible here.

    Args:
      source: Input program artifact.
      compiled_variant: Tag distinguishing this compilation from other
        variants of the same source.
      flags: Extra ``iree-compile`` flags, appended after
        ``IREE_COMPILE_QOL_FLAGS``. Any sequence of strings (list or tuple).

    Returns:
      Whatever ``ProducedArtifact.start()`` returns (presumably the artifact
      itself; NOTE(review): confirm in ``artifacts.py``).

    Raises (from ``callback``):
      subprocess.CalledProcessError: if ``iree-compile`` exits non-zero. The
        captured stdout/stderr are printed before re-raising.
    """
    name = Path(source.name).with_suffix(f".{compiled_variant}.vmfb")
    # Snapshot as a list. ``list + tuple`` raises TypeError, so the argv
    # concatenation would otherwise only accept lists despite the Sequence
    # annotation. Copying also insulates the callback from later mutation of
    # the caller's sequence.
    flags = list(flags)

    def callback(vmfb_artifact: ProducedArtifact):
        sep = "\n "
        print("**************************************************************")
        print(f"Compiling {source} -> {vmfb_artifact} with flags:")
        print(f" {sep.join(flags)}")
        exec_args = [
            "iree-compile",
            "-o",
            str(vmfb_artifact.path),
            str(source.path),
            *IREE_COMPILE_QOL_FLAGS,
            *flags,
        ]
        # perf_counter is monotonic; time.time() can jump with clock changes.
        start_time = time.perf_counter()
        try:
            subprocess.run(
                exec_args, check=True, capture_output=True, cwd=source.group.directory
            )
        except subprocess.CalledProcessError as e:
            # Output is captured and CalledProcessError's message does not
            # include it, so surface it here or the failure is undiagnosable.
            print(f"Compilation failed with exit code {e.returncode}")
            print("--- stdout ---")
            print((e.stdout or b"").decode(errors="replace"))
            print("--- stderr ---")
            print((e.stderr or b"").decode(errors="replace"))
            print("**************************************************************")
            raise
        run_time = time.perf_counter() - start_time
        print(f"Compilation succeeded in {run_time}s")
        print("**************************************************************")

    return ProducedArtifact(source.group, name, callback, depends=[source]).start()
def iree_run_module(vmfb: Artifact, *, device, function, args: Sequence[str] = ()):
    """Invoke ``function`` from the compiled module ``vmfb`` via ``iree-run-module``.

    Args:
      vmfb: Compiled module artifact. ``join()`` is called first, presumably
        to wait until the artifact has been produced (NOTE(review): confirm
        in ``artifacts.py``).
      device: Value passed as ``--device=``.
      function: Entry point passed as ``--function=``.
      args: Extra command-line arguments, appended verbatim.

    Returns:
      The ``subprocess.CompletedProcess``. Its ``stdout``/``stderr`` hold the
      captured output as bytes. Older callers that ignore the return value
      are unaffected.

    Raises:
      subprocess.CalledProcessError: if the tool exits non-zero. The captured
        stdout/stderr are printed before re-raising.
    """
    vmfb.join()
    exec_args = [
        "iree-run-module",
        f"--device={device}",
        f"--module={vmfb.path}",
        f"--function={function}",
    ]
    exec_args.extend(args)
    print("**************************************************************")
    print("Exec:", " ".join(exec_args))
    try:
        return subprocess.run(
            exec_args, check=True, capture_output=True, cwd=vmfb.group.directory
        )
    except subprocess.CalledProcessError as e:
        # Output is captured and CalledProcessError's message does not include
        # it, so print it or the failing run leaves no diagnostics.
        print(f"iree-run-module failed with exit code {e.returncode}")
        print("--- stdout ---")
        print((e.stdout or b"").decode(errors="replace"))
        print("--- stderr ---")
        print((e.stderr or b"").decode(errors="replace"))
        raise
def iree_benchmark_module(
    vmfb: Artifact, *, device, function, args: Sequence[str] = ()
):
    """Benchmark ``function`` from the compiled module ``vmfb`` with ``iree-benchmark-module``.

    Unlike ``iree_run_module``, the tool's output is not captured: it goes
    straight to the inherited stdout/stderr.

    Args:
      vmfb: Compiled module artifact. ``join()`` is called first, presumably
        to wait until it has been produced (NOTE(review): confirm in
        ``artifacts.py``).
      device: Value passed as ``--device=``.
      function: Entry point passed as ``--function=``.
      args: Extra command-line arguments, appended verbatim.

    Raises:
      subprocess.CalledProcessError: if the benchmark exits non-zero.
    """
    vmfb.join()
    command = [
        "iree-benchmark-module",
        f"--device={device}",
        f"--module={vmfb.path}",
        f"--function={function}",
        *args,
    ]
    print("**************************************************************")
    print("Exec:", " ".join(command))
    subprocess.check_call(command, cwd=vmfb.group.directory)