blob: 42682b4e90d771c284926daa034f33f518e14a54 [file] [log] [blame]
Stella Laurenzo2616ac72022-08-01 10:47:25 -07001# Copyright 2022 The IREE Authors
2#
3# Licensed under the Apache License v2.0 with LLVM Exceptions.
4# See https://llvm.org/LICENSE.txt for license information.
5# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7import numpy as np
8import os
9import iree.compiler as compiler
10import iree.runtime as rt
11
# Fixed toy vocabulary for the demo detokenizer: a token id is simply an
# index into this list (see Detokenizer.accumtokens in
# create_tokenizer_module below).
# NOTE(review): b"the" appears twice (ids 6 and 8); the demo messages only
# use id 6 territory via id 5 ("so") — presumably intentional for the joke
# output, but confirm against main.mlir.
TOKEN_TABLE = [
    b"hi",
    b"thanks",
    b"bye",
    b"for",
    b"all",
    b"so",
    b"the",
    b"fish",
    b"the",
    b"end",
    b"long",
    b"and",
    b".",
    b"now",
    b"there",
]
29
30
def create_tokenizer_module():
  """Creates a module which defines some custom methods for decoding.

  Returns:
    A VM module (via rt.PyModuleInterface.create()) named "detokenizer" that
    exports three functions: accumtokens, jointokens, and reset.
  """

  class Detokenizer:
    """Python implementation backing the exported "detokenizer" VM module."""

    def __init__(self, iface):
      # Any class state here is maintained per-context.
      self.start_of_text = True
      self.start_of_sentence = True

    def reset(self):
      """Restores the capitalization/spacing state to its initial values."""
      self.start_of_text = True
      self.start_of_sentence = True

    def accumtokens(self, ids_tensor_ref, token_list_ref):
      """Appends the token bytes for each id in ids_tensor_ref to token_list_ref.

      Args:
        ids_tensor_ref: ref to a HalBufferView of token ids (indices into
          TOKEN_TABLE). Iterated over its first dimension; presumably a 1-D
          i32 tensor — confirm against main.mlir.
        token_list_ref: ref to a VmVariantList that receives one VmBuffer
          per token.

      Returns:
        The number of tokens appended (ids_array.shape[0]).
      """
      # TODO: This little dance to turn BufferView refs into real arrays... is not good.
      ids_bv = ids_tensor_ref.deref(rt.HalBufferView)
      ids_array = ids_bv.map().asarray(
          ids_bv.shape, rt.HalElementType.map_to_dtype(ids_bv.element_type))
      token_list = token_list_ref.deref(rt.VmVariantList)
      for index in range(ids_array.shape[0]):
        token_id = ids_array[index]
        token = TOKEN_TABLE[token_id]

        # And this dance to make a buffer... is also not good.
        # A real implementation would just map the constant memory, etc.
        # Copy the token bytes into a fresh VmBuffer and push it on the list.
        buffer = rt.VmBuffer(len(token))
        buffer_view = memoryview(buffer)
        buffer_view[:] = token
        token_list.push_ref(buffer)
      return ids_array.shape[0]

    def jointokens(self, token_list_ref):
      """Joins accumulated token buffers into one text buffer.

      Inserts a space between tokens (except before the very first token of
      the text and after a "."-token, which is appended with no preceding
      space), and upper-cases the first byte of each sentence's first token.
      Capitalization state persists across calls until reset().

      Args:
        token_list_ref: ref to a VmVariantList of VmBuffers as produced by
          accumtokens.

      Returns:
        A ref to a new VmBuffer holding the joined text.
      """
      # The world's dumbest detokenizer. Ideally, the state tracking
      # would be in a module private type that got retained and passed
      # back through.
      token_list = token_list_ref.deref(rt.VmVariantList)
      text = bytearray()
      for i in range(len(token_list)):
        item = bytes(token_list.get_as_object(i, rt.VmBuffer))
        if item == b".":
          # Sentence terminator: append directly and mark the next token as
          # a sentence start.
          text.extend(b".")
          self.start_of_sentence = True
        else:
          if not self.start_of_text:
            text.extend(b" ")
          else:
            self.start_of_text = False
          if self.start_of_sentence:
            # Upper-case only the first character (decoded as UTF-8).
            text.extend(item[0:1].decode("utf-8").upper().encode("utf-8"))
            text.extend(item[1:])
            self.start_of_sentence = False
          else:
            text.extend(item)

      # TODO: This dance to make a buffer is still bad.
      results = rt.VmBuffer(len(text))
      memoryview(results)[:] = text
      return results.ref

  # Export the methods on the VM module. The signature strings (e.g. "0rr_i")
  # describe the VM calling convention — presumably ref/int argument and
  # result encodings; confirm against the IREE Python bindings docs.
  iface = rt.PyModuleInterface("detokenizer", Detokenizer)
  iface.export("accumtokens", "0rr_i", Detokenizer.accumtokens)
  iface.export("jointokens", "0r_r", Detokenizer.jointokens)
  iface.export("reset", "0v_v", Detokenizer.reset)
  return iface.create()
96
97
def compile():
  """Compiles main.mlir (located next to this script) for the vmvx backend.

  Returns:
    The compiled VMFB module contents as bytes.
  """
  mlir_path = os.path.join(os.path.dirname(__file__), "main.mlir")
  return compiler.tools.compile_file(mlir_path, target_backends=["vmvx"])
Stella Laurenzo2616ac72022-08-01 10:47:25 -0700102
103
def main():
  """Compiles the demo program and decodes a few token sequences through it."""
  print("Compiling...")
  compiled_binary = compile()

  print("Decoding secret message...")
  config = rt.Config("local-sync")
  compiled_module = rt.VmModule.from_flatbuffer(config.vm_instance,
                                               compiled_binary)
  # The custom detokenizer module must be registered before the main module
  # that imports it.
  vm_modules = (*config.default_vm_modules, create_tokenizer_module(),
                compiled_module)
  context = rt.SystemContext(vm_modules=vm_modules, config=config)

  # First message.
  first_ids = np.asarray([5, 10, 11, 1, 3, 4, 5, 7, 12], dtype=np.int32)
  count = context.modules.main.add_tokens(first_ids)
  print(f"ADDED {count} tokens")

  # Second message.
  second_ids = np.asarray([2, 13], dtype=np.int32)
  count = context.modules.main.add_tokens(second_ids)
  print(f"ADDED {count} tokens")

  text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
  print(f"RESULTS: {text}")

  assert text == b"So long and thanks for all so fish. Bye now"

  # Reset and decode some more.
  context.modules.main.reset()
  third_ids = np.asarray([0, 14, 12], dtype=np.int32)
  count = context.modules.main.add_tokens(third_ids)
  print(f"ADDED {count} tokens")
  text = bytes(context.modules.main.get_results().deref(rt.VmBuffer))
  print(f"RESULTS: {text}")
  assert text == b"Hi there."
138
139
# Script entry point: only run the demo when executed directly, not on import.
if __name__ == "__main__":
  main()