blob: 4dd0b7cad1852f62fb8042db73afb5ff937e5668 [file] [log] [blame]
# Copyright 2023 Stella Laurenzo
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# TODO: Upstream this to IREE.
import platform
if platform.system() == "Windows":
print("WARNING: Test disabled on Windows due to suspected MSVC bug")
else:
from contextlib import closing
import os
from pathlib import Path
import tempfile
import unittest
from iree.compiler.api import (
Session,
Source,
Output,
)
from iree.compiler import ir
class DlFlagsTest(unittest.TestCase):
    """Tests for querying and setting compiler flags on a ``Session``."""

    def _assertOptFlags(self, flags, opt_level):
        """Assert the optimization-related entries of ``flags``.

        Only ``--iree-opt-level`` is expected to vary between checks; the
        global-optimization level stays at ``O0`` and assertion stripping stays
        disabled in every scenario exercised below.
        """
        self.assertIn(f"--iree-opt-level={opt_level}", flags)
        self.assertIn("--iree-global-optimization-opt-level=O0", flags)
        self.assertIn("--iree-opt-strip-assertions=false", flags)

    def testDefaultFlags(self):
        session = Session()
        flags = session.get_flags()
        # assertIn reports the full flag list on failure, so no print needed.
        self.assertIn("--iree-input-type=auto", flags)

    def testNonDefaultFlags(self):
        session = Session()
        flags = session.get_flags(non_default_only=True)
        self.assertEqual(flags, [])
        session.set_flags("--iree-input-type=none")
        flags = session.get_flags(non_default_only=True)
        self.assertIn("--iree-input-type=none", flags)

    def testFlagsAreScopedToSession(self):
        session1 = Session()
        session2 = Session()
        session1.set_flags("--iree-input-type=tosa")
        session2.set_flags("--iree-input-type=none")
        self.assertIn("--iree-input-type=tosa", session1.get_flags())
        self.assertIn("--iree-input-type=none", session2.get_flags())

    def testFlagError(self):
        session = Session()
        with self.assertRaises(ValueError):
            session.set_flags("--does-not-exist=1")

    def testOptFlags(self):
        session = Session()
        self._assertOptFlags(session.get_flags(), "O0")

        session.set_flags("--iree-opt-level=O2")
        self._assertOptFlags(session.get_flags(), "O2")

        # Run a trivial invocation and confirm it leaves the session's flags
        # unchanged.
        inv = session.invocation()
        with tempfile.NamedTemporaryFile("w", delete=False) as tf:
            tf.write("module {}")
        # Leaving the ``with`` closed (and flushed) the file; it must be
        # removed by hand because of ``delete=False``.
        try:
            source = Source.open_file(session, tf.name)
            inv.parse_source(source)
        finally:
            os.unlink(tf.name)
        # closing() guarantees the Output is released even if a call raises.
        with closing(Output.open_membuffer()) as out:
            inv.output_ir(out)
            inv.execute()
        self._assertOptFlags(session.get_flags(), "O2")
class DlInvocationTest(unittest.TestCase):
    """Tests for parsing, transforming and emitting IR through an Invocation."""

    @staticmethod
    def _capture(emit):
        """Call ``emit(out)`` on a fresh in-memory Output and return its bytes.

        The bytes are copied out before the Output is closed, and closing()
        ensures the Output is closed even when ``emit`` raises.
        """
        with closing(Output.open_membuffer()) as out:
            emit(out)
            return bytes(out.map_memory())

    @staticmethod
    def _read_testdata(name):
        """Return the raw contents of ``testdata/<name>`` beside this file."""
        return (Path(__file__).parent / "testdata" / name).read_bytes()

    def testCreate(self):
        session = Session()
        inv = session.invocation()
        self.assertIsNotNone(inv)

    def testInputFile(self):
        session = Session()
        inv = session.invocation()
        with tempfile.NamedTemporaryFile("w", delete=False) as tf:
            tf.write("module {}")
        # The file is closed on leaving the ``with``; remove it by hand since
        # ``delete=False``.
        try:
            source = Source.open_file(session, tf.name)
            inv.parse_source(source)
        finally:
            os.unlink(tf.name)
        self.assertIn(b"module", self._capture(inv.output_ir))

    def testInputBuffer(self):
        session = Session()
        inv = session.invocation()
        source = Source.wrap_buffer(session, b"builtin.module {}")
        inv.parse_source(source)
        self.assertIn(b"module", self._capture(inv.output_ir))

    def testInputBytecode(self):
        bytecode = self._read_testdata("bytecode_testfile.bc")
        session = Session()
        inv = session.invocation()
        source = Source.wrap_buffer(session, bytecode)
        inv.parse_source(source)
        self.assertIn(b"module", self._capture(inv.output_ir))

    def testInputZeroTerminatedBytecode(self):
        bytecode = self._read_testdata("bytecode_zero_terminated_testfile.bc")
        session = Session()
        inv = session.invocation()
        source = Source.wrap_buffer(session, bytecode)
        inv.parse_source(source)
        self.assertIn(b"module", self._capture(inv.output_ir))

    def testInputRoundtrip(self):
        # Text -> bytecode in one session...
        session = Session()
        inv = session.invocation()
        source = Source.wrap_buffer(session, b"builtin.module {}")
        inv.parse_source(source)
        bytecode = self._capture(inv.output_ir_bytecode)

        # ...then bytecode -> text in a fresh session.
        session = Session()
        inv = session.invocation()
        source = Source.wrap_buffer(session, bytecode)
        inv.parse_source(source)
        self.assertIn(b"module", self._capture(inv.output_ir))

    def testOutputBytecode(self):
        session = Session()
        inv = session.invocation()
        source = Source.wrap_buffer(session, b"builtin.module {}")
        inv.parse_source(source)
        self.assertIn(b"module", self._capture(inv.output_ir_bytecode))

    def testExecutePassPipeline(self):
        session = Session()
        inv = session.invocation()
        source = Source.wrap_buffer(
            session,
            b"""
builtin.module {
func.func private @foobar() -> ()
}
""",
        )
        inv.parse_source(source)
        # The private, unused function should be removed by symbol DCE.
        inv.execute_text_pass_pipeline("symbol-dce")
        self.assertNotIn(b"func", self._capture(inv.output_ir))

    def testExecuteStdPipeline(self):
        session = Session()
        session.set_flags("--iree-hal-target-device=local")
        session.set_flags("--iree-hal-local-target-device-backends=vmvx")
        inv = session.invocation()
        source = Source.wrap_buffer(
            session,
            b"""
builtin.module {
func.func @main(%arg0: i32) -> (i32) {
return %arg0 : i32
}
}
""",
        )
        inv.parse_source(source)
        inv.execute()
        with closing(Output.open_membuffer()) as out:
            inv.output_vm_bytecode(out)
class DlOutputTest(unittest.TestCase):
    """Tests for the ``Output`` sink: in-memory buffers and files."""

    def testOpenMembuffer(self):
        # Deliberately never closed explicitly; presumably exercises cleanup
        # when the object is dropped.
        out = Output.open_membuffer()

    def testOpenMembufferExplicitClose(self):
        out = Output.open_membuffer()
        out.close()

    def testOpenMembufferWrite(self):
        out = Output.open_membuffer()
        out.write(b"foobar")
        mem = out.map_memory()
        self.assertEqual(b"foobar", bytes(mem))
        out.close()

    def testOpenFileNoKeep(self):
        # A fresh path inside a private temp directory replaces the
        # deprecated, race-prone tempfile.mktemp().
        with tempfile.TemporaryDirectory() as tmp_dir:
            file_path = Path(tmp_dir) / "output.bin"
            out = Output.open_file(str(file_path))
            try:
                out.write(b"foobar")
                self.assertTrue(file_path.exists())
            finally:
                out.close()
            # keep() was never called, so closing should delete the file.
            self.assertFalse(file_path.exists())

    def testOpenFileKeep(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            file_path = Path(tmp_dir) / "output.bin"
            out = Output.open_file(str(file_path))
            try:
                out.write(b"foobar")
                out.keep()
            finally:
                out.close()
            # keep() was called, so the file must survive close() with its
            # contents intact. The temp directory handles removal.
            self.assertEqual(b"foobar", file_path.read_bytes())
class DlInteropTest(unittest.TestCase):
    """Tests for exchanging MLIR objects between the Python ``ir`` API and a
    compiler Session/Invocation."""

    @staticmethod
    def _output_ir_text(inv):
        """Emit ``inv``'s IR to an in-memory Output and return it as ``str``.

        closing() ensures the Output is released even if emission raises.
        """
        with closing(Output.open_membuffer()) as output:
            inv.output_ir(output)
            return bytes(output.map_memory()).decode()

    def testContextFromSession(self):
        s = Session()
        # TODO: Do gc stuff to verify memory.
        context1 = s.context
        context2 = s.context
        self.assertIsNotNone(context1)
        # Repeated accesses must hand back the very same context object.
        self.assertIs(context1, context2)

    def testImportModule(self):
        s = Session()
        with ir.Location.unknown(s.context):
            module_op = ir.Module.create().operation
            module_op.attributes["test.test"] = ir.Attribute.parse('"working"')
            inv = s.invocation()
            inv.import_module(module_op)
        # Round-trip it back through an Output and verify that the attribute
        # we set is still there. assertIn shows the text on failure.
        contents = self._output_ir_text(inv)
        self.assertIn('test.test = "working"', contents)

    def testExportModule(self):
        s = Session()
        with ir.Location.unknown(s.context):
            source = Source.wrap_buffer(s, b"builtin.module {}")
            inv = s.invocation()
            self.assertTrue(inv.parse_source(source))
            module_op = inv.export_module()
            module_op.attributes["test.test"] = ir.Attribute.parse('"working"')
        # Round-trip it back through an Output and verify that the attribute
        # we set is still there.
        contents = self._output_ir_text(inv)
        self.assertIn('test.test = "working"', contents)
# Script entry point: discover and run every TestCase in this module.
if "__main__" == __name__:
    unittest.main()