# blob: ad95a8b3f4dfac9770d772e3dac0787936f6bc0b [file] [log] [blame]
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import array
import logging
import numpy as np
from pathlib import Path
import unittest
import iree.compiler
import iree.runtime as rt
# Cache for the compiled form of MM_TEST_ASM. It stays None until
# compile_mm_test() first compiles the program.
MM_TEST_COMPILED = None
# MLIR for a small linear layer used by the parameter tests:
#   @run(x: 128x20 f32) -> 128x30 f32 computes x @ transpose(weight) + bias,
# where `weight` (30x20 f32) and `bias` (30 f32) are not embedded constants.
# They are loaded from the "params" parameter scope under the keys "weight"
# and "bias", so each test supplies them through a ParameterIndex provider.
MM_TEST_ASM = r"""
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
#map2 = affine_map<(d0, d1) -> (d1)>
module @main {
util.global private @_params.classifier.weight {noinline} = #stream.parameter.named<"params"::"weight"> : tensor<30x20xf32>
util.global private @_params.classifier.bias {noinline} = #stream.parameter.named<"params"::"bias"> : tensor<30xf32>
func.func @run(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> {
%0 = call @forward(%arg0) : (tensor<128x20xf32>) -> tensor<128x30xf32>
return %0 : tensor<128x30xf32>
}
func.func private @forward(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> attributes {torch.assume_strict_symbolic_shapes} {
%cst = arith.constant 0.000000e+00 : f32
%_params.classifier.weight = util.global.load @_params.classifier.weight : tensor<30x20xf32>
%0 = tensor.empty() : tensor<20x30xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%_params.classifier.weight : tensor<30x20xf32>) outs(%0 : tensor<20x30xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<20x30xf32>
%2 = tensor.empty() : tensor<128x30xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<128x30xf32>) -> tensor<128x30xf32>
%4 = linalg.matmul ins(%arg0, %1 : tensor<128x20xf32>, tensor<20x30xf32>) outs(%3 : tensor<128x30xf32>) -> tensor<128x30xf32>
%_params.classifier.bias = util.global.load @_params.classifier.bias : tensor<30xf32>
%5 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel"]} ins(%4, %_params.classifier.bias : tensor<128x30xf32>, tensor<30xf32>) outs(%2 : tensor<128x30xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%6 = arith.addf %in, %in_0 : f32
linalg.yield %6 : f32
} -> tensor<128x30xf32>
return %5 : tensor<128x30xf32>
}
}
"""
def compile_mm_test():
    """Compile MM_TEST_ASM once and return the cached compiled module binary.

    The first call compiles for ``iree.compiler.core.DEFAULT_TESTING_BACKENDS``
    and stores the output in the module-level ``MM_TEST_COMPILED``. Later calls
    return the cached value without recompiling.

    Returns:
        The output of ``iree.compiler.compile_str`` for MM_TEST_ASM.
    """
    global MM_TEST_COMPILED
    # Test the None sentinel explicitly; the cached compiler output should be
    # reused whatever its truthiness.
    if MM_TEST_COMPILED is None:
        MM_TEST_COMPILED = iree.compiler.compile_str(
            MM_TEST_ASM, target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS
        )
    return MM_TEST_COMPILED
def create_mm_test_module(instance):
    """Return an ``rt.VmModule`` for the MM test program, bound to ``instance``.

    The compiled binary comes from compile_mm_test() (compiled at most once)
    and is copied into the new module via ``VmModule.copy_buffer``.
    """
    return rt.VmModule.copy_buffer(instance, compile_mm_test())
def _float_constant(val: float) -> array.array:
return array.array("f", [val])
class ParameterTest(unittest.TestCase):
    """Tests for ``rt.ParameterIndex`` and parameter-backed module loading.

    Several tests run the MM_TEST_ASM program with its "params" scope backed
    by different sources (splats, host buffers, GGUF and safetensors files)
    on a constant input.
    """

    # @run input: a 128x20 f32 tensor filled with 2.0.
    _INPUT_SHAPE = (128, 20)
    _INPUT_VALUE = 2.0
    # @run output shape.
    _RESULT_SHAPE = (128, 30)
    # With weight == 2.0 and bias == 1.0 everywhere:
    # 20 * (2.0 * 2.0) + 1.0 == 81.0 for every output element.
    _EXPECTED_VALUE = 81.0

    def setUp(self):
        self.instance = rt.VmInstance()
        self.device = rt.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
        self.config = rt.Config(device=self.device)

    def _run_mm_test(self, index):
        """Run @run with ``index`` serving the "params" scope.

        Args:
            index: an ``rt.ParameterIndex`` expected to provide "weight" and
                "bias" entries.

        Returns:
            The result of @run copied to the host via ``to_host()``.
        """
        modules = rt.load_vm_modules(
            rt.create_io_parameters_module(
                self.instance, index.create_provider(scope="params")
            ),
            rt.create_hal_module(self.instance, self.device),
            create_mm_test_module(self.instance),
            config=self.config,
        )
        # The MM test module is loaded last, so it is the last entry.
        main = modules[-1]
        arg = np.full(self._INPUT_SHAPE, self._INPUT_VALUE, dtype=np.float32)
        result = main.run(arg).to_host()
        logging.debug("MM test result:\n%s", result)
        return result

    def _assert_mm_result(self, result):
        """Assert ``result`` equals the expected all-81.0 128x30 output."""
        expected = np.full(
            self._RESULT_SHAPE, self._EXPECTED_VALUE, dtype=np.float32
        )
        np.testing.assert_array_almost_equal(result, expected)

    @staticmethod
    def _testdata_path(file_name):
        """Return the path of ``file_name`` in the ``testdata`` dir next to this file."""
        return str(Path(__file__).resolve().parent / "testdata" / file_name)

    def testParameterIndex(self):
        index = rt.ParameterIndex()
        self.assertEqual(len(index), 0)
        # reserve() must not add entries.
        index.reserve(25)
        self.assertEqual(len(index), 0)
        provider = index.create_provider()
        rt.create_io_parameters_module(self.instance, provider)

    def testFileHandleWrap(self):
        fh = rt.FileHandle.wrap_memory(b"foobar")
        del fh

    def testParameterIndexAddFromFile(self):
        index = rt.ParameterIndex()
        fh = rt.FileHandle.wrap_memory(b"foobar")
        # Register a 3-long range at offset 3 of the wrapped memory as "data".
        index.add_from_file_handle("data", fh, length=3, offset=3)
        self.assertEqual(len(index), 1)

    def testSplats(self):
        index = rt.ParameterIndex()
        # Total lengths are element count * 4 (size of an f32 element).
        index.add_splat("weight", _float_constant(2.0), 30 * 20 * 4)
        index.add_splat("bias", _float_constant(1.0), 30 * 4)
        result = self._run_mm_test(index)
        self.assertEqual(result.shape, self._RESULT_SHAPE)
        # TODO: Fix splat in the parameter code so it is not all zeros. Once
        # fixed, assert the values too with self._assert_mm_result(result).

    def testBuffers(self):
        index = rt.ParameterIndex()
        # Keep the host arrays bound to locals for the whole test so they
        # outlive every use of the index built from them.
        weight = np.full([30, 20], 2.0, dtype=np.float32)
        bias = np.full([30], 1.0, dtype=np.float32)
        index.add_buffer("weight", weight)
        index.add_buffer("bias", bias)
        self._assert_mm_result(self._run_mm_test(index))

    def testGguf(self):
        index = rt.ParameterIndex()
        index.load(self._testdata_path("parameter_weight_bias_1.gguf"))
        self._assert_mm_result(self._run_mm_test(index))

    def testSafetensors(self):
        index = rt.ParameterIndex()
        index.load(self._testdata_path("parameter_weight_bias_1.safetensors"))
        self._assert_mm_result(self._run_mm_test(index))

    def testSplatTooBig(self):
        splat_index = rt.ParameterIndex()
        # A 5-element ("f") pattern must be rejected with ValueError;
        # presumably it exceeds the maximum supported splat pattern size.
        with self.assertRaises(ValueError):
            splat_index.add_splat(
                "weight", array.array("f", [1.0, 2.0, 3.0, 4.0, 5.0]), 30 * 20 * 4
            )
if __name__ == "__main__":
    # When executed as a script, enable DEBUG-level log output on the root
    # logger, then hand control to the unittest runner.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()