blob: 1e1e47fb632d34e82eccbd548360e4e8df3004e9 [file] [log] [blame]
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import unittest
import iree.compiler
import iree.runtime
from iree.runtime.benchmark import (
benchmark_module,
BenchmarkTimeoutError,
)
def create_simple_mul_module(instance):
    """Build a VM module holding a single elementwise f32 multiply.

    The MLIR below defines ``@simple_mul``, which takes two ``tensor<4xf32>``
    operands and returns their ``arith.mulf`` product. It is compiled for
    ``iree.compiler.core.DEFAULT_TESTING_BACKENDS``.

    Args:
      instance: VM instance handed to ``VmModule.from_flatbuffer``.

    Returns:
      The ``iree.runtime.VmModule`` created from the compiled flatbuffer.
    """
    mlir_source = """
module @test_module {
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}
}
"""
    flatbuffer = iree.compiler.compile_str(
        mlir_source,
        target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
    )
    return iree.runtime.VmModule.from_flatbuffer(instance, flatbuffer)
def create_multiple_entry_functions_module(instance):
    """Build a VM module that exposes two independent entry functions.

    ``@entry_1`` returns ``math.absf`` of a ``tensor<4xf32>`` and ``@entry_2``
    returns ``math.absf`` of a ``tensor<2xf32>``. Having two functions with
    different signatures lets callers check that ``benchmark_module``'s
    ``entry_function`` argument selects the intended one.

    Args:
      instance: VM instance handed to ``VmModule.from_flatbuffer``.

    Returns:
      The ``iree.runtime.VmModule`` created from the compiled flatbuffer.
    """
    # entry_2 returns the absf result (%0), matching entry_1. Previously it
    # returned %arg0, which left the absf dead and made the function an
    # identity passthrough, so benchmarking it measured no real work.
    binary = iree.compiler.compile_str(
        """
module @test_module {
func.func @entry_1(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%0 = math.absf %arg0 : tensor<4xf32>
return %0 : tensor<4xf32>
}
func.func @entry_2(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = math.absf %arg0 : tensor<2xf32>
return %0 : tensor<2xf32>
}
}
""",
        target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
    )
    return iree.runtime.VmModule.from_flatbuffer(instance, binary)
def create_large_matmul_module(instance):
    """Build a VM module containing one large ``linalg.matmul``.

    ``@large_matmul`` multiplies two 4000x4000 f32 tensors, accumulating into
    a third, and returns the result. The op is sized so that a single run
    takes a noticeable amount of time. The timeout test relies on that.

    Args:
      instance: VM instance handed to ``VmModule.from_flatbuffer``.

    Returns:
      The ``iree.runtime.VmModule`` created from the compiled flatbuffer.
    """
    mlir_source = """
module @test_module {
func.func @large_matmul(%arg0: tensor<4000x4000xf32>, %arg1: tensor<4000x4000xf32>, %arg2: tensor<4000x4000xf32>) -> tensor<4000x4000xf32> {
%0 = linalg.matmul ins(%arg0, %arg1: tensor<4000x4000xf32>, tensor<4000x4000xf32>)
outs(%arg2: tensor<4000x4000xf32>) -> tensor<4000x4000xf32>
return %0 : tensor<4000x4000xf32>
}
}
"""
    backends = iree.compiler.core.DEFAULT_TESTING_BACKENDS
    flatbuffer = iree.compiler.compile_str(mlir_source, target_backends=backends)
    return iree.runtime.VmModule.from_flatbuffer(instance, flatbuffer)
class BenchmarkTest(unittest.TestCase):
    """Tests for ``iree.runtime.benchmark.benchmark_module``."""

    def _assert_single_positive_timing(self, benchmark_results):
        """Check that exactly one result was produced and its time is positive.

        The ``time`` field is a string whose first space-separated token is
        parsed as a number.
        """
        self.assertEqual(len(benchmark_results), 1)
        benchmark_time = float(benchmark_results[0].time.split(" ")[0])
        self.assertGreater(benchmark_time, 0)

    def testBenchmarkModule(self):
        """Benchmark the default entry function of a one-function module."""
        ctx = iree.runtime.SystemContext()
        vm_module = create_simple_mul_module(ctx.instance)
        arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        arg1 = np.array([5.0, 6.0, 7.0, 8.0], dtype=np.float32)
        benchmark_results = benchmark_module(
            vm_module,
            device=iree.compiler.core.DEFAULT_TESTING_DRIVER,
            inputs=[arg0, arg1],
        )
        self._assert_single_positive_timing(benchmark_results)

    def testBenchmarkModuleWithEntryFunction(self):
        """Benchmark each entry function of a multi-function module by name."""
        ctx = iree.runtime.SystemContext()
        vm_module = create_multiple_entry_functions_module(ctx.instance)
        # Each entry function takes a tensor of a different length, so the
        # inputs only match the function that is actually selected.
        arg1 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        arg2 = np.array([1.0, 2.0], dtype=np.float32)
        benchmark_results_1 = benchmark_module(
            vm_module,
            entry_function="entry_1",
            device=iree.compiler.core.DEFAULT_TESTING_DRIVER,
            inputs=[arg1],
        )
        self._assert_single_positive_timing(benchmark_results_1)
        benchmark_results_2 = benchmark_module(
            vm_module,
            entry_function="entry_2",
            device=iree.compiler.core.DEFAULT_TESTING_DRIVER,
            inputs=[arg2],
        )
        self._assert_single_positive_timing(benchmark_results_2)

    def testBenchmarkModuleTimeout(self):
        """A workload that outlasts ``timeout`` raises ``BenchmarkTimeoutError``."""
        ctx = iree.runtime.SystemContext()
        vm_module = create_large_matmul_module(ctx.instance)
        arg0 = np.zeros([4000, 4000], dtype=np.float32)
        arg1 = np.zeros([4000, 4000], dtype=np.float32)
        arg2 = np.zeros([4000, 4000], dtype=np.float32)
        with self.assertRaises(BenchmarkTimeoutError):
            benchmark_module(
                vm_module,
                device=iree.compiler.core.DEFAULT_TESTING_DRIVER,
                inputs=[arg0, arg1, arg2],
                timeout=0.1,
            )
if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed directly;
    # importing the module (e.g. from a test runner) skips this.
    unittest.main()