blob: 6633d3bb6eaca527ff4e5b02e348d28ea049c303 [file] [log] [blame] [edit]
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import iree.runtime as rt
import iree.compiler as compiler
import numpy as np
import os
def main():
    """Compile model.mlir for llvm-cpu and verify per-dispatch tensor sums.

    Registers a HAL-module debug sink whose callback records, for every
    traced dispatch tensor, a ``[dispatch key, tensor index, tensor sum]``
    triple, runs ``softmax(matmul(A, B))`` through the compiled module, and
    asserts the recorded sums match the expected values.
    """
    # Create a vmfb for the llvm-cpu backend.
    #
    # GPU debug hint: use iree-compile with all the flags to reproduce the
    # numerical error, plus `--iree-flow-trace-dispatch-tensors` for
    # trace points. Then use rt.VmModule.copy_buffer(config.vm_instance, .)
    # to load the vmfb.
    vmfb_contents = compiler.compile_file(
        os.path.join(os.path.dirname(__file__), "model.mlir"),
        target_backends=["llvm-cpu"],
        extra_args=[
            "--iree-flow-trace-dispatch-tensors",
            "--iree-llvmcpu-target-cpu=host",
        ],
    )
    # GPU debug hint: Use 'hip' if running on an AMD GPU.
    config = rt.Config("local-sync")
    # The callback records the sums of all traced tensors (the original
    # comment said "means", but arr.sum() records sums).
    callback_results = []

    def callback(key: str, buffer_views: list[rt.HalBufferView]):
        # Record one [key, index, sum] entry per traced buffer view,
        # in the order the runtime reports them.
        for i, bv in enumerate(buffer_views):
            arr = bv.map().asarray(
                bv.shape, rt.HalElementType.map_to_dtype(bv.element_type)
            )
            callback_results.append([key, i, float(arr.sum())])

    hal_module = rt.create_hal_module(
        config.vm_instance,
        config.device,
        debug_sink=rt.HalModuleDebugSink(callback),
    )
    module = rt.VmModule.copy_buffer(config.vm_instance, vmfb_contents)
    vm_modules = rt.load_vm_modules(hal_module, module, config=config)
    # Perform softmax(matmul(A, B)):
    A = 0.25 * np.ones((4, 16), dtype=np.float32)
    B = np.ones((16, 16), dtype=np.float32)
    vm_modules[-1].main(A, B)
    # We assume that softmax(matmul(A, B)) compiled to 2 dispatches.
    assert (
        len(callback_results) == 5
    ), "2 dispatches. mm:2->1, sm:1->1. Expected 5 tensors."

    def _check(result, dispatch_prefix, kind_suffix, tensor_index, tensor_sum):
        # Validate one recorded [key, index, sum] triple.
        key, index, total = result
        assert key.startswith(dispatch_prefix), key
        assert key.endswith(kind_suffix), key
        assert index == tensor_index
        assert total == tensor_sum

    # A: sum(0.25 * ones(4x16)) == 16. All sums below are exact in fp32.
    _check(callback_results[0], "main_dispatch_0", "inputs", 0, 16.0)
    # B: sum(ones(16x16)) == 256.
    _check(callback_results[1], "main_dispatch_0", "inputs", 1, 256.0)
    # matmul(A, B): every element is 0.25*16 == 4; 4x16 elements -> 256.
    _check(callback_results[2], "main_dispatch_0", "outputs", 0, 256.0)
    # softmax inputs: matmul(A, B), traced again as dispatch 1's input.
    _check(callback_results[3], "main_dispatch_1", "inputs", 0, 256.0)
    assert (
        callback_results[3][2] == callback_results[2][2]
    ), "Softmax input sum should match matmul output sum."
    # softmax(matmul(A, B)): each of the 4 rows sums to 1 -> total 4.
    _check(callback_results[4], "main_dispatch_1", "outputs", 0, 4.0)
# Run the debug-sink check when executed as a script.
if __name__ == "__main__":
    main()