blob: 3d7d4804401240161ffecf3eb049c30a4a0a8ac9 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from pyiree import binding as binding
def create_simple_mul_module():
  """Compiles a tiny MLIR module exporting a single `simple_mul` function.

  `simple_mul` takes two `tensor<4xf32>` arguments and returns their
  `xla_hlo.mul` product as a `tensor<4xf32>`.

  Returns:
    The VM module built by `binding.vm.create_module_from_blob` from the
    compiled blob.
  """
  asm_source = """
func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
attributes { iree.module.export } {
%0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
"""
  compiled_blob = binding.compiler.compile_module_from_asm(asm_source)
  return binding.vm.create_module_from_blob(compiled_blob)
def create_host_buffer_view(context):
  """Allocates a zeroed 16-byte buffer and returns a 4-element view over it.

  Args:
    context: runtime context; only its `allocate_device_visible` method is
      used here.

  Returns:
    The view produced by `create_view(binding.hal.Shape([4]), 4)` on the
    freshly allocated buffer (it is also printed for debugging).
  """
  element_count = 4
  # NOTE(review): presumably the per-element byte size expected by
  # create_view; 4 * 4 == 16 matches the allocation below — confirm.
  element_size = 4
  byte_length = element_count * element_size
  buf = context.allocate_device_visible(byte_length)
  # Presumably (offset, length): zero the whole allocation.
  buf.fill_zero(0, byte_length)
  view = buf.create_view(binding.hal.Shape([element_count]), element_size)
  print("BUFFER VIEW:", view)
  return view
class RuntimeTest(absltest.TestCase):
  """Smoke tests for the pyiree compiler, VM module and runtime bindings."""

  def testModuleAndFunction(self):
    """Checks function lookup by ordinal and by name on a compiled module."""
    m = create_simple_mul_module()
    print("Module:", m)
    print("Module name:", m.name)
    self.assertEqual("module", m.name)

    # Function 0: the module's only function.
    f = m.lookup_function_by_ordinal(0)
    print("Function 0:", f)
    self.assertIsNotNone(f)
    self.assertEqual("simple_mul", f.name)
    sig = f.signature
    self.assertEqual(2, sig.argument_count)
    self.assertEqual(1, sig.result_count)

    # Function 1: the module defines only one function, so there is none.
    self.assertIsNone(m.lookup_function_by_ordinal(1))

    # By name.
    f = m.lookup_function_by_name("simple_mul")
    self.assertIsNotNone(f)
    self.assertEqual("simple_mul", f.name)
    sig = f.signature
    self.assertEqual(2, sig.argument_count)
    self.assertEqual(1, sig.result_count)

    # By name not found.
    self.assertIsNone(m.lookup_function_by_name("not_here"))

  def testInitialization(self):
    """Checks that a Context can be created and reports a positive id."""
    policy = binding.rt.Policy()
    print("policy =", policy)
    instance = binding.rt.Instance()
    print("instance =", instance)
    context = binding.rt.Context(instance=instance, policy=policy)
    print("context =", context)
    context_id = context.context_id
    print("context_id =", context_id)
    self.assertGreater(context_id, 0)

  def testRegisterModule(self):
    """Checks module registration and qualified function resolution."""
    # `instance` is kept as a local so it stays referenced for the whole test.
    policy = binding.rt.Policy()
    instance = binding.rt.Instance()
    context = binding.rt.Context(instance=instance, policy=policy)
    m = create_simple_mul_module()
    context.register_module(m)
    self.assertIsNotNone(context.lookup_module_by_name("module"))
    self.assertIsNone(context.lookup_module_by_name("nothere"))
    # Functions are resolved by "<module name>.<function name>".
    f = context.resolve_function("module.simple_mul")
    self.assertIsNotNone(f)
    print("Resolved function:", f.name)
    self.assertIsNone(context.resolve_function("module.nothere"))

  def testInvoke(self):
    """Invokes simple_mul end to end and checks the element-wise product."""
    # `instance` is kept as a local so it stays referenced for the whole test.
    policy = binding.rt.Policy()
    instance = binding.rt.Instance()
    context = binding.rt.Context(instance=instance, policy=policy)
    m = create_simple_mul_module()
    context.register_module(m)
    f = context.resolve_function("module.simple_mul")
    self.assertIsNotNone(f)
    print("INVOKE F:", f)
    arg0 = context.wrap_for_input(np.array([1., 2., 3., 4.], dtype=np.float32))
    arg1 = context.wrap_for_input(np.array([4., 5., 6., 7.], dtype=np.float32))
    inv = context.invoke(f, policy, [arg0, arg1])
    print("Status:", inv.query_status())
    inv.await_ready()
    results = inv.results
    print("Results:", results)
    result = results[0].map()
    print("Mapped result:", result)
    # np.asarray avoids a copy when it can, like np.array(..., copy=False) did
    # under NumPy 1.x, but unlike that call it does not raise ValueError under
    # NumPy >= 2.0 if a copy turns out to be required.
    result_ary = np.asarray(result)
    print("NP result:", result_ary)
    # [1, 2, 3, 4] * [4, 5, 6, 7]; these products are exact in float32. This
    # also checks that exactly four elements came back.
    np.testing.assert_array_equal(result_ary, [4., 10., 18., 28.])
if __name__ == "__main__":
  # To start the native extension with custom flags, call
  # binding.initialize_extension(["--logtostderr"]) before absltest.main().
  absltest.main()