| commit | e9f4b1edf11f066d82e8743a3ca0b5717380f35c | |
|---|---|---|
| author | Quinn Dawkins <quinn.dawkins@gmail.com> | Wed Oct 25 19:59:46 2023 -0400 |
| committer | GitHub <noreply@github.com> | Wed Oct 25 23:59:46 2023 +0000 |
| tree | bff2028b3d5ca934fedc140cae35960e2027c934 | |
| parent | c4be76f6ba2e8c64c5dcf4750442a3299b016ce9 | |
[Flow] Add TensorBitCastOp (#15260)

This patch adds `flow.tensor.bitcast` as a near mirror of `flow.tensor.reshape` that additionally allows changing element type bit widths. This lets earlier lowerings skip materializing a constant tensor of some difficult-to-represent type and instead bitcast from a friendlier byte-aligned and/or integer type. Similarly, this can help bridge the gap between frameworks, which might have limited support even for integers of varying bit widths, and IREE.

In terms of direct applications today, this removes the need to materialize sub-byte constant tensors for quantized LLMs like LLaMa that have seen recent burn-downs. As a result, we can store the constants as i8 instead of converting the elements one by one to APInt so that MLIR can represent the constant tensor, and instead keep the values exactly as the frontend provides them (the frontend is giving them to us packed!). This should improve memory usage at compile time by at least a factor of 2 for `i4` (or more; it is unclear how APInt stores those values), as well as give significant compile-time gains at both load and serialization time.

One potential issue with this op: the storage of a non-power-of-two sub-byte resource can become ambiguous from codegen's perspective. Say we have a constant of type `tensor<64xi3>`. Currently this would be serialized with two `i3`s packed per byte and 2 bits wasted, for 32 bytes total. If we instead had a constant of `tensor<24xi8>` and bitcast it to `tensor<64xi3>`, a bitcast is semantically a no-op, so the resulting `i3` tensor is stored in 24 bytes rather than 32 (see compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir). Codegen would only see the interface binding + offset for a `tensor<64xi3>` and thus cannot know which layout to generate code for. However, codegen cannot really handle non-power-of-two types yet anyway, so I ignored the problem for now :/
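As a rough sketch of what this enables (the exact assembly format is assumed here, modeled on `flow.tensor.reshape`, and the constant contents are placeholders), a lowering can keep the packed, byte-aligned storage the frontend hands over and bitcast it to the sub-byte type:

```mlir
// Keep the quantized weights in their packed, byte-aligned form.
// dense<0> is a placeholder for the real packed payload.
%packed = arith.constant dense<0> : tensor<24xi8>

// Reinterpret the same 24 bytes (192 bits) as 64 x i3 without touching
// the data; the bitcast is semantically a no-op on the underlying bits.
%weights = flow.tensor.bitcast %packed : tensor<24xi8> -> tensor<64xi3>
```

Note this is exactly the ambiguous case described above: the bitcast result occupies 24 bytes, while a directly serialized `tensor<64xi3>` constant would occupy 32.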
IREE (Intermediate Representation Execution Environment, pronounced as “eerie”) is an MLIR-based end-to-end compiler and runtime that lowers Machine Learning (ML) models to a unified IR that scales up to meet the needs of the datacenter and down to satisfy the constraints and special considerations of mobile and edge deployments.
See our website for project details, user guides, and instructions on building from source.
IREE is still in its early phase. We have settled on the overarching infrastructure and are actively improving various software components as well as project logistics. It is still quite far from ready for everyday use and is made available without any support at the moment. With that said, we welcome any kind of feedback through any of our communication channels!
IREE is licensed under the terms of the Apache 2.0 License with LLVM Exceptions. See LICENSE for more information.