Add VmModule.mmap() to Python API. (#14124)
Flatbuffer blobs need to be page-aligned rather than having ordinary malloc alignment.
The best way to load such a file is via mmap, so expose that directly
as an API.
Callers could do this themselves in Python, but doing so is error-prone;
a public API makes it more robust.
Provides the mechanism to fix #13887.
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index bde6ae9..7ff7600 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -8,6 +8,7 @@
import logging
import numpy as np
+import tempfile
import unittest
import iree.compiler
@@ -106,6 +107,27 @@
logging.info("result: %s", result)
self.assertEqual(result, 11)
+ def test_mmap(self):
+ """Loads a compiled module through VmModule.mmap() and invokes add_scalar."""
+ binary = iree.compiler.compile_str(
+ """
+ func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
+ %0 = arith.addi %arg0, %arg1 : i32
+ return %0 : i32
+ }
+ """,
+ target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS)
+ # NOTE(review): VmModule.mmap() reopens the file by path while `tf` is still
+ # open; NamedTemporaryFile does not allow that on Windows, so this test
+ # presumably only targets POSIX hosts -- confirm.
+ with tempfile.NamedTemporaryFile() as tf:
+ tf.write(binary)
+ # Flush so the compiled bytes are on disk before mmap() reopens the path.
+ tf.flush()
+ m = iree.runtime.VmModule.mmap(self.instance, tf.name)
+ context = iree.runtime.VmContext(self.instance,
+ modules=[self.hal_module, m])
+ f = m.lookup_function("add_scalar")
+ finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
+ result = finv(5, 6)
+ logging.info("result: %s", result)
+ self.assertEqual(result, 11)
+
def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
m = create_simple_dynamic_abs_module(self.instance)
context = iree.runtime.VmContext(self.instance,
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index b72ecde..9a3a4fd 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -150,6 +150,23 @@
return py_module;
}
+VmModule VmModule::MMap(VmInstance* instance, std::string filepath) {
+  IREE_TRACE_SCOPE_NAMED("VmModule::MMap");
+  // Memory-maps `filepath` read-only and loads it as a flatbuffer module.
+  // mmap yields page-aligned data without copying the file into a malloc'd
+  // buffer.
+  // NOTE(review): assumes FromFlatbufferBlob retains a reference to the mmap
+  // object for as long as the module needs the bytes -- confirm.
+  auto mmap_module = py::module::import("mmap");
+  auto open_func = py::module::import("io").attr("open");
+  // The mapping is read-only, so open read-only: "r+b" would needlessly fail
+  // on files the caller cannot write.
+  auto file_obj = open_func(filepath, "rb");
+  py::object mapped_file;
+  try {
+    auto fileno = file_obj.attr("fileno")();
+    if (py::hasattr(mmap_module, "MAP_SHARED")) {
+      // POSIX: pass explicit flags/prot so MAP_POPULATE can be requested.
+      auto flags = py::cast<int64_t>(mmap_module.attr("MAP_SHARED"));
+      // MAP_POPULATE isn't available on all versions/platforms.
+      if (py::hasattr(mmap_module, "MAP_POPULATE")) {
+        flags |= py::cast<int64_t>(mmap_module.attr("MAP_POPULATE"));
+      }
+      auto prot = py::cast<int64_t>(mmap_module.attr("PROT_READ"));
+      mapped_file = mmap_module.attr("mmap")(fileno, 0, flags, prot);
+    } else {
+      // Windows has no flags/prot parameters; request a read-only view.
+      mapped_file = mmap_module.attr("mmap")(
+          fileno, 0, py::arg("access") = mmap_module.attr("ACCESS_READ"));
+    }
+  } catch (...) {
+    // Don't leak the file object if mapping fails.
+    file_obj.attr("close")();
+    throw;
+  }
+  // The mmap object holds its own duplicate of the file handle, so the Python
+  // file object can be closed now instead of lingering until GC.
+  file_obj.attr("close")();
+  // madvise() requires Python 3.8+ and MADV_RANDOM is platform dependent; the
+  // hint is purely advisory, so skip it when unavailable.
+  if (py::hasattr(mapped_file, "madvise") &&
+      py::hasattr(mmap_module, "MADV_RANDOM")) {
+    mapped_file.attr("madvise")(mmap_module.attr("MADV_RANDOM"));
+  }
+  return FromFlatbufferBlob(instance, mapped_file);
+}
+
VmModule VmModule::FromFlatbufferBlob(VmInstance* instance,
py::object flatbuffer_blob_object) {
IREE_TRACE_SCOPE_NAMED("VmModule::FromFlatbufferBlob");
@@ -661,6 +678,7 @@
.def_static("resolve_module_dependency",
&VmModule::ResolveModuleDependency)
.def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
+ .def_static("mmap", &VmModule::MMap)
.def_property_readonly("name", &VmModule::name)
.def_property_readonly("version",
[](VmModule& self) {
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index 4b097e7..9bf25f4 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -133,6 +133,7 @@
const std::string& name,
uint32_t minimum_version);
+ // Memory-maps the file at `filepath` with a read-only mapping and loads it
+ // as a flatbuffer module via FromFlatbufferBlob. Mapping yields page-aligned
+ // data, unlike an ordinary malloc'd buffer.
+ static VmModule MMap(VmInstance* instance, std::string filepath);
static VmModule FromFlatbufferBlob(VmInstance* instance,
py::object flatbuffer_blob_object);