| // Copyright 2025 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package common |
| |
| import chisel3._ |
| import chisel3.util._ |
| |
// Implements the vector gather operation such that:
//   result[i] = data[indices[i]]
//
// @param indices A `Vec[UInt]` of element selectors. The width of its elements
//                must address `data` exactly (data.length == 2^width).
// @param data A `Vec[T]` holding the elements to select from.
// @return A `Vec[T]` with one entry per element of `indices`.
object Gather {
  def apply[T <: Data](indices: Vec[UInt], data: Vec[T]): Vec[T] = {
    val indexWidth = indices(0).getWidth
    // Compare in BigInt: an Int shift only uses the low five bits of the shift
    // count, so `1 << 32` evaluates to 1 and a 32-bit index would wrongly be
    // accepted for a single-element `data`.
    assert((BigInt(1) << indexWidth) == BigInt(data.length),
           s"Gather requires data.length == 2^indexWidth, got " +
           s"data.length=${data.length}, indexWidth=$indexWidth")
    VecInit(indices.map(idx => data(idx)))
  }
}
| |
| // Performs a scatter operation. |
| // |
| // Scatters a vector of data elements into a result vector at the positions |
| // given by an indices vector. A validity vector determines whether each |
| // element takes part in the scatter. |
| // |
| // If two elements are scattered to the same location, the value of the first |
| // element is stored in the result vector. A selection vector is returned to |
| // indicate to the caller which elements were written. |
| // |
| // @param valid A `Vec[Bool]` where each element indicates if the corresponding |
| // data element and index are valid for the current scatter |
| // operation. |
| // @param indices A `Vec[UInt]` where each element `indices(i)` specifies the |
| // target index in the output vector for the corresponding |
| // `data(i)` element. The width of the elements in this vector |
| // determines the number of entries in the output `result` vector |
| // (2^width). Must have the same length as the `valid` vector. |
| // @param data A `Vec[T]` where `T` is a subtype of `Data`. This vector contains |
| // the input data elements to be scattered. Must have the same |
| // length as the `valid` vector. |
| // @return A tuple containing three elements: |
| // 1. result (`Vec[T]`): The resulting vector after scattering the input |
| // data. |
| // 2. resultMask (`Vec[Bool]`): A bitmask indicating which positions in the |
| // `result` vector were written to during this scatter operation. Same |
| // length as `result`. |
| // 3. `indicesSelected` (`Vec[Bool]`): A bitmask indicating which elements |
| // from the input `data` (and `indices`) vector were successfully |
| // written in this scatter operation. Same length as `valid`, `indices` |
| // and `data` inputs. |
object Scatter {
  def apply[T <: Data](valid: Vec[Bool],
                       indices: Vec[UInt],
                       data: Vec[T]): (Vec[T], Vec[Bool], Vec[Bool]) = {
    // Elaboration-time shape checks: one valid bit and one index per element.
    val numElements = data.length
    assert(valid.length == numElements)
    assert(indices.length == numElements)

    val dtype = chiselTypeOf(data(0))
    val indexBits = indices(0).getWidth
    // Keep the result vector reasonably sized: at most 2^16 (~65k) entries.
    assert(indexBits <= 16)
    val numSlots = 1 << indexBits

    // resultMask and indicesSelected are both derived from a binary
    // "selection matrix" with one row per input element and one column per
    // result position.
    //
    // Row k of validMatrix is the one-hot encoding of indices(k), or all zeros
    // when element k is not valid.
    val validMatrix = valid.zip(indices).map { case (isValid, target) =>
      Mux(isValid, UIntToOH(target), 0.U(numSlots.W))
    }
    // valueSet(k) is the OR of every row before k, i.e. the positions already
    // requested by earlier elements. It starts from all zeros and therefore
    // holds one more entry than validMatrix.
    val valueSet = validMatrix.scanLeft(0.U(numSlots.W))(_ | _)
    // Removing already-requested positions from each row leaves only the
    // first valid element that targets a given position.
    val selectionMatrix = validMatrix.zip(valueSet).map {
      case (requested, claimed) => requested & ~claimed
    }
    // OR of all rows: which result positions are written by this operation.
    val resultMask = VecInit(selectionMatrix.reduce(_ | _).asBools)
    // Non-empty row: the corresponding input element was written.
    val indicesSelected = VecInit(selectionMatrix.map(_ =/= 0.U))

    // Assertions
    // Every row and every column of the selection matrix should contain at
    // most one set bit. (Disabled for speed)
    // selectionMatrix.foreach(x => assert(PopCount(x) <= 1.U))
    // (0 until resultLength).foreach(
    //     i => assert(PopCount(selectionMatrix.map(_(i))) <= 1.U))
    // Run-time check that no invalid element is reported as selected.
    assert(PopCount((0 until indices.length).map(
      i => indicesSelected(i) & ~valid(i))) === 0.U)

    // TODO(derekjchow): Review semantics for "ordered" and "unordered", and
    // implement behaviours correctly.

    // Each result position takes the data of the first valid element whose
    // index matches it (MuxCase gives priority to earlier entries) and falls
    // back to an all-zero value when no element targets it.
    val result = Wire(Vec(numSlots, dtype))
    (0 until numSlots).foreach { position =>
      result(position) := MuxCase(
        0.U.asTypeOf(dtype),
        (0 until numElements).map { k =>
          (valid(k) && (indices(k) === position.U)) -> data(k)
        })
    }

    (result, resultMask, indicesSelected)
  }
}