| # Copyright 2023 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import iree.compiler |
| import argparse |
| import iree.runtime |
| from iree.runtime.array_interop import DeviceArray |
| from mpi4py import MPI |
| import test_utils |
| |
| """ |
| Run 1 rank in a distributed context. |
| To start 4 ranks you would use |
| ``` |
| mpirun -n 4 python run_rank.py ... |
| ``` |
| """ |
| |
| |
def parse_args():
    """Parse this script's command-line options.

    Registers ``--driver`` (optional, defaults to ``"local-task"``), the
    required single-valued ``--module_filepath`` and ``--function``, and the
    required multi-valued ``--inputs`` and ``--outputs``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Run 1 rank in a destributed context.")
    cli.add_argument("--driver", type=str, default="local-task", help="Device URI.")

    # Required options that take exactly one value.
    single_valued = (
        ("--module_filepath", "Path to IREE module."),
        ("--function", "Name of function to call."),
    )
    for flag, help_text in single_valued:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Required options that take one or more values (a list of file paths).
    multi_valued = (
        ("--inputs", "Path to IREE module inputs for all ranks in npy format."),
        ("--outputs", "Path to IREE module outputs form all ranks in npy format."),
    )
    for flag, help_text in multi_valued:
        cli.add_argument(flag, nargs="+", type=str, required=True, help=help_text)

    return cli.parse_args()
| |
| |
def run_module(
    device: iree.runtime.HalDevice,
    module_filepath: str,
    function: str,
    input_filepath: str,
    output_filepath: str,
) -> None:
    """Load a compiled module, call one of its functions, and save the results.

    Args:
        device: HAL device used to build the runtime ``Config``.
        module_filepath: Path to the module file; it is read in binary mode
            and handed to ``VmModule.from_flatbuffer``.
        function: Name of the attribute on the loaded module to call.
        input_filepath: File passed to
            ``test_utils.read_numpy_arrays_from_file``; the arrays it returns
            are passed to the function as positional arguments.
        output_filepath: File passed to
            ``test_utils.write_numpy_arrays_to_file`` together with the
            function's results.

    Returns:
        None. The results are only written to ``output_filepath``.
    """
    config = iree.runtime.Config(device=device)
    with open(module_filepath, "rb") as module_file:
        vm_flatbuffer = module_file.read()
    vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer)
    bound_module = iree.runtime.load_vm_module(vm_module, config)

    input_args = test_utils.read_numpy_arrays_from_file(input_filepath)
    results = getattr(bound_module, function)(*input_args)
    # A lone DeviceArray is wrapped in a list before being handed to the writer.
    if isinstance(results, DeviceArray):
        results = [results]
    test_utils.write_numpy_arrays_to_file(filepath=output_filepath, arrays=results)
| |
| |
def run_shard(
    driver: str,
    module_filepath: str,
    function: str,
    inputs: "list[str]",
    outputs: "list[str]",
) -> None:
    """Run the module for the calling MPI rank.

    The rank is taken from ``MPI.COMM_WORLD``. It selects both the device
    (round-robin over the devices the driver reports) and the input/output
    file pair (``inputs[rank]`` / ``outputs[rank]``).

    Args:
        driver: Name passed to ``iree.runtime.get_driver``.
        module_filepath: Path to the module file, forwarded to ``run_module``.
        function: Name of the function to call, forwarded to ``run_module``.
        inputs: Input file paths, indexed by rank.
        outputs: Output file paths, indexed by rank.

    Raises:
        IndexError: If ``inputs`` or ``outputs`` has no entry for this rank.
        RuntimeError: If the driver reports no available devices.
    """
    rank = MPI.COMM_WORLD.Get_rank()

    # Fail early with an actionable message instead of a bare IndexError
    # raised after the device has already been created.
    if rank >= len(inputs) or rank >= len(outputs):
        raise IndexError(
            f"Rank {rank} has no input/output file: got {len(inputs)} input "
            f"path(s) and {len(outputs)} output path(s)."
        )

    hal_driver = iree.runtime.get_driver(driver)
    device_infos = hal_driver.query_available_devices()
    if not device_infos:
        # Without this guard the modulo below raises an opaque ZeroDivisionError.
        raise RuntimeError(f"Driver '{driver}' reports no available devices.")

    # Ranks wrap around the device list when there are more ranks than devices.
    device_info = device_infos[rank % len(device_infos)]
    device = hal_driver.create_device(device_info["device_id"])

    run_module(
        device=device,
        module_filepath=module_filepath,
        function=function,
        input_filepath=inputs[rank],
        output_filepath=outputs[rank],
    )
| |
| |
if __name__ == "__main__":
    # The parsed option names are passed straight through as run_shard's
    # keyword arguments.
    cli_options = vars(parse_args())
    run_shard(**cli_options)