| # Copyright 2022 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import numpy as np |
| import os |
| import iree.compiler as compiler |
| import iree.runtime as rt |
| |
# Toy vocabulary for the detokenizer: a token id is simply an index into this
# list. Note that b"the" occurs twice (ids 6 and 8), so different ids can
# decode to the same word.
TOKEN_TABLE = [
    word.encode("ascii")
    for word in (
        "hi",  # 0
        "thanks",  # 1
        "bye",  # 2
        "for",  # 3
        "all",  # 4
        "so",  # 5
        "the",  # 6
        "fish",  # 7
        "the",  # 8
        "end",  # 9
        "long",  # 10
        "and",  # 11
        ".",  # 12
        "now",  # 13
        "there",  # 14
    )
]
| |
| |
def create_tokenizer_module():
    """Creates a Python-implemented VM module named "detokenizer".

    The module exports:

    * ``accumtokens(ids, token_list) -> i32``: looks up each id of the ``ids``
      buffer view in ``TOKEN_TABLE`` and appends the token bytes to
      ``token_list`` as ``VmBuffer`` refs. Returns the number of ids consumed.
      Raises ``IndexError`` for an id outside ``[0, len(TOKEN_TABLE))``.
    * ``jointokens(token_list) -> buffer``: joins the buffers in
      ``token_list`` into one space-separated, sentence-capitalized text
      buffer.
    * ``reset()``: restores the per-context joining state.

    Returns:
      The module object produced by ``rt.PyModuleInterface(...).create()``.
    """

    class Detokenizer:
        """Per-context state and implementations of the exported functions."""

        def __init__(self, iface):
            # Any class state here is maintained per-context. ``iface`` is
            # passed in by the module interface but is not needed here.
            self.reset()

        def reset(self):
            # True until the first non-"." token has been emitted; suppresses
            # the separating space before the very first word.
            self.start_of_text = True
            # True when the next word should have its first letter upper-cased.
            self.start_of_sentence = True

        def accumtokens(self, ids_tensor_ref, token_list_ref):
            # TODO: This little dance to turn BufferView refs into real arrays... is not good.
            ids_bv = ids_tensor_ref.deref(rt.HalBufferView)
            ids_array = ids_bv.map().asarray(
                ids_bv.shape, rt.HalElementType.map_to_dtype(ids_bv.element_type)
            )
            token_list = token_list_ref.deref(rt.VmVariantList)
            count = ids_array.shape[0]
            for index in range(count):
                token_id = ids_array[index]
                # Reject out-of-range ids explicitly: plain list indexing would
                # silently wrap negative ids around to the end of the table.
                if token_id < 0 or token_id >= len(TOKEN_TABLE):
                    raise IndexError(
                        f"token id {token_id} is out of range for a vocabulary "
                        f"of {len(TOKEN_TABLE)} tokens"
                    )
                token = TOKEN_TABLE[token_id]

                # And this dance to make a buffer... is also not good.
                # A real implementation would just map the constant memory, etc.
                buffer = rt.VmBuffer(len(token))
                buffer_view = memoryview(buffer)
                buffer_view[:] = token
                token_list.push_ref(buffer)
            return count

        def jointokens(self, token_list_ref):
            # The world's dumbest detokenizer. Ideally, the state tracking
            # would be in a module private type that got retained and passed
            # back through.
            token_list = token_list_ref.deref(rt.VmVariantList)
            text = bytearray()
            for i in range(len(token_list)):
                item = bytes(token_list.get_as_object(i, rt.VmBuffer))
                if item == b".":
                    # A period is attached without a space and ends the sentence.
                    text.extend(b".")
                    self.start_of_sentence = True
                    continue

                if self.start_of_text:
                    self.start_of_text = False
                else:
                    text.extend(b" ")

                if self.start_of_sentence:
                    # Decode the whole token before upper-casing its first
                    # character. Decoding only the first byte would split
                    # multi-byte UTF-8 characters and fail. "surrogateescape"
                    # round-trips any bytes that are not valid UTF-8 unchanged.
                    word = item.decode("utf-8", "surrogateescape")
                    capitalized = word[:1].upper() + word[1:]
                    text.extend(capitalized.encode("utf-8", "surrogateescape"))
                    self.start_of_sentence = False
                else:
                    text.extend(item)

            # TODO: This dance to make a buffer is still bad.
            results = rt.VmBuffer(len(text))
            memoryview(results)[:] = text
            return results.ref

    iface = rt.PyModuleInterface("detokenizer", Detokenizer)
    iface.export("accumtokens", "0rr_i", Detokenizer.accumtokens)
    iface.export("jointokens", "0r_r", Detokenizer.jointokens)
    iface.export("reset", "0v_v", Detokenizer.reset)
    return iface.create()
| |
| |
def compile():
    """Compiles the ``main.mlir`` file next to this script for VMVX.

    Returns:
      The compiled module contents from ``compiler.tools.compile_file``, which
      the caller loads as a VM flatbuffer.
    """
    # NOTE: this module-level name shadows the ``compile`` builtin; it is kept
    # as-is so existing callers continue to work.
    script_dir = os.path.dirname(__file__)
    mlir_path = os.path.join(script_dir, "main.mlir")
    backends = ["vmvx"]
    return compiler.tools.compile_file(mlir_path, target_backends=backends)
| |
| |
def main():
    """Compiles ``main.mlir``, runs it with the Python detokenizer module, and
    checks the decoded messages.

    Raises:
      AssertionError: if a decoded message differs from the expected text.
        The check is an explicit ``raise`` rather than an ``assert`` statement
        so it still runs under ``python -O``.
    """

    def add_tokens(context, ids):
        # Feeds one batch of int32 token ids to the compiled module and reports
        # the count it returns.
        count = context.modules.main.add_tokens(np.asarray(ids, dtype=np.int32))
        print(f"ADDED {count} tokens")

    def fetch_results(context):
        # Reads the joined text out of the VM buffer returned by get_results.
        text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
        print(f"RESULTS: {text}")
        return text

    def expect(actual, expected):
        # Explicit check instead of ``assert``: asserts are stripped under -O,
        # which would let a wrong result pass silently.
        if actual != expected:
            raise AssertionError(f"expected {expected!r}, got {actual!r}")

    print("Compiling...")
    vmfb_contents = compile()
    print("Decoding secret message...")
    config = rt.Config("local-sync")
    main_module = rt.VmModule.from_flatbuffer(config.vm_instance, vmfb_contents)
    modules = config.default_vm_modules + (
        create_tokenizer_module(),
        main_module,
    )
    context = rt.SystemContext(vm_modules=modules, config=config)

    # First message.
    add_tokens(context, [5, 10, 11, 1, 3, 4, 5, 7, 12])

    # Second message; the expected result shows it is joined after the first.
    add_tokens(context, [2, 13])

    expect(fetch_results(context), b"So long and thanks for all so fish. Bye now")

    # Reset and decode some more.
    context.modules.main.reset()
    add_tokens(context, [0, 14, 12])
    expect(fetch_results(context), b"Hi there.")
| |
| |
if __name__ == "__main__":
    # Run the compile-and-decode demo only when executed as a script, not when
    # this file is imported.
    main()