| Stella Laurenzo | 2616ac7 | 2022-08-01 10:47:25 -0700 | [diff] [blame] | 1 | # Copyright 2022 The IREE Authors |
| 2 | # |
| 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | # See https://llvm.org/LICENSE.txt for license information. |
| 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | |
| 7 | import numpy as np |
| 8 | import os |
| 9 | import iree.compiler as compiler |
| 10 | import iree.runtime as rt |
| 11 | |
# Vocabulary for the toy detokenizer: token id -> token bytes.
# (The duplicate b"the" and the b"so" at index 5 are deliberate quirks of the
# demo's "secret message".)
TOKEN_TABLE = [
    b"hi", b"thanks", b"bye", b"for", b"all",
    b"so", b"the", b"fish", b"the", b"end",
    b"long", b"and", b".", b"now", b"there",
]
| 29 | |
| 30 | |
def create_tokenizer_module():
  """Creates a module which defines some custom methods for decoding.

  Returns a VM module (via rt.PyModuleInterface.create) named "detokenizer"
  exporting three functions implemented in Python: accumtokens, jointokens,
  and reset.
  """

  class Detokenizer:
    """Per-context implementation backing the "detokenizer" module."""

    def __init__(self, iface):
      # Any class state here is maintained per-context.
      # start_of_text: no token emitted yet (suppresses the leading space).
      # start_of_sentence: next word should be capitalized.
      self.start_of_text = True
      self.start_of_sentence = True

    def reset(self):
      """Restores the initial detokenization state so decoding can restart."""
      self.start_of_text = True
      self.start_of_sentence = True

    def accumtokens(self, ids_tensor_ref, token_list_ref):
      """Looks up each token id and appends its bytes to the variant list.

      Args:
        ids_tensor_ref: ref to a HalBufferView holding token ids; indexed
          along its first dimension.
        token_list_ref: ref to a VmVariantList that receives one VmBuffer of
          token bytes per id.

      Returns:
        The number of ids processed (ids_array.shape[0]).
      """
      # TODO: This little dance to turn BufferView refs into real arrays... is not good.
      ids_bv = ids_tensor_ref.deref(rt.HalBufferView)
      ids_array = ids_bv.map().asarray(
          ids_bv.shape, rt.HalElementType.map_to_dtype(ids_bv.element_type))
      token_list = token_list_ref.deref(rt.VmVariantList)
      for index in range(ids_array.shape[0]):
        token_id = ids_array[index]
        token = TOKEN_TABLE[token_id]

        # And this dance to make a buffer... is also not good.
        # A real implementation would just map the constant memory, etc.
        buffer = rt.VmBuffer(len(token))
        buffer_view = memoryview(buffer)
        buffer_view[:] = token
        token_list.push_ref(buffer)
      return ids_array.shape[0]

    def jointokens(self, token_list_ref):
      """Joins accumulated token buffers into a single text VmBuffer.

      Tokens are space-separated except that a "." token is appended directly
      and marks the following word as the start of a sentence, which gets its
      first character upper-cased. Uses (and updates) the per-context
      start_of_text/start_of_sentence state.

      Returns:
        A ref to a VmBuffer containing the assembled text bytes.
      """
      # The world's dumbest detokenizer. Ideally, the state tracking
      # would be in a module private type that got retained and passed
      # back through.
      token_list = token_list_ref.deref(rt.VmVariantList)
      text = bytearray()
      for i in range(len(token_list)):
        item = bytes(token_list.get_as_object(i, rt.VmBuffer))
        if item == b".":
          text.extend(b".")
          self.start_of_sentence = True
        else:
          if not self.start_of_text:
            text.extend(b" ")
          else:
            self.start_of_text = False
          if self.start_of_sentence:
            # Capitalize the first character (decode/encode round-trip keeps
            # this correct for multi-byte UTF-8 leading characters).
            text.extend(item[0:1].decode("utf-8").upper().encode("utf-8"))
            text.extend(item[1:])
            self.start_of_sentence = False
          else:
            text.extend(item)

      # TODO: This dance to make a buffer is still bad.
      results = rt.VmBuffer(len(text))
      memoryview(results)[:] = text
      return results.ref

  # NOTE(review): the second argument appears to be VM calling-convention
  # strings (e.g. "0rr_i" ~ two refs in, i32 out) — verify against the
  # PyModuleInterface documentation.
  iface = rt.PyModuleInterface("detokenizer", Detokenizer)
  iface.export("accumtokens", "0rr_i", Detokenizer.accumtokens)
  iface.export("jointokens", "0r_r", Detokenizer.jointokens)
  iface.export("reset", "0v_v", Detokenizer.reset)
  return iface.create()
| 96 | |
| 97 | |
| 98 | def compile(): |
| Ben Vanik | 9aa83ed | 2022-08-06 12:55:34 -0700 | [diff] [blame] | 99 | return compiler.tools.compile_file(os.path.join(os.path.dirname(__file__), |
| 100 | "main.mlir"), |
| 101 | target_backends=["vmvx"]) |
| Stella Laurenzo | 2616ac7 | 2022-08-01 10:47:25 -0700 | [diff] [blame] | 102 | |
| 103 | |
def main():
  """Compiles the program, wires up the detokenizer module, and decodes.

  Runs two decode passes through the compiled "main" module, asserting the
  expected detokenized text after each.
  """
  print("Compiling...")
  vmfb = compile()
  print("Decoding secret message...")
  config = rt.Config("local-sync")
  main_module = rt.VmModule.from_flatbuffer(config.vm_instance, vmfb)
  # The custom Python detokenizer module must precede "main" so its exports
  # resolve when the main module is loaded.
  context = rt.SystemContext(
      vm_modules=config.default_vm_modules + (
          create_tokenizer_module(),
          main_module,
      ),
      config=config)

  # First message, fed in two chunks.
  for ids in ([5, 10, 11, 1, 3, 4, 5, 7, 12], [2, 13]):
    count = context.modules.main.add_tokens(np.asarray(ids, dtype=np.int32))
    print(f"ADDED {count} tokens")

  text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
  print(f"RESULTS: {text}")
  assert text == b"So long and thanks for all so fish. Bye now"

  # Reset the detokenizer state and decode a second message.
  context.modules.main.reset()
  count = context.modules.main.add_tokens(
      np.asarray([0, 14, 12], dtype=np.int32))
  print(f"ADDED {count} tokens")
  text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
  print(f"RESULTS: {text}")
  assert text == b"Hi there."
| 138 | |
| 139 | |
# Script entry point: run the demo when invoked directly.
if __name__ == "__main__":
  main()