| // Copyright 2025 Google LLC | 
 | // | 
 | // Licensed under the Apache License, Version 2.0 (the "License"); | 
 | // you may not use this file except in compliance with the License. | 
 | // You may obtain a copy of the License at | 
 | // | 
 | //     http://www.apache.org/licenses/LICENSE-2.0 | 
 | // | 
 | // Unless required by applicable law or agreed to in writing, software | 
 | // distributed under the License is distributed on an "AS IS" BASIS, | 
 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
 | // See the License for the specific language governing permissions and | 
 | // limitations under the License. | 
 |  | 
 | package common | 
 |  | 
 | import chisel3._ | 
 | import chisel3.util._ | 
 |  | 
 // Implements the vector gather operation such that:
 //   result[i] = data[indices[i]]
 //
 // @param indices A `Vec[UInt]` of selectors, one per output element. The bit
 //                width of `indices(0)` determines how many `data` entries
 //                are addressable (2^width).
 // @param data A `Vec[T]` where `T` is a subtype of `Data`. Its length must be
 //             exactly 2^width, where width is the bit width of `indices(0)`.
 //             This is checked at elaboration time and fails with an
 //             assertion error otherwise.
 // @return A `Vec[T]` with the same length as `indices`, where element `i` is
 //         `data` dynamically indexed by `indices(i)`.
 object Gather {
   def apply[T <: Data](indices: Vec[UInt], data: Vec[T]): Vec[T] = {
     val indexWidth = indices(0).getWidth
     // Compute the power of two as a BigInt. An Int `1 << w` only uses the low
     // five bits of `w`, so a 32-bit index would wrap around to 1 and let a
     // mismatched `data` length slip through the check below.
     val addressableEntries = BigInt(1) << indexWidth
     assert(addressableEntries == BigInt(data.length),
            s"Gather: data has ${data.length} entries, but a " +
            s"${indexWidth}-bit index addresses ${addressableEntries} entries")
     VecInit(indices.map(idx => data(idx)))
   }
 }
 |  | 
 // Performs a scatter operation.
 //
 // Each input element `data(i)` is routed to position `indices(i)` of a result
 // vector, provided `valid(i)` is set. Elements whose valid bit is clear are
 // ignored.
 //
 // If several valid elements target the same position, the one with the lowest
 // input position wins. A selection vector is returned so the caller can tell
 // which input elements were actually written.
 //
 // @param valid A `Vec[Bool]` where each element indicates if the corresponding
 //              data element and index take part in the current scatter
 //              operation.
 // @param indices A `Vec[UInt]` where `indices(i)` is the target position in
 //                the output vector for `data(i)`. The bit width of the
 //                elements of this vector determines the number of entries in
 //                the output `result` vector (2^width); the width may be at
 //                most 16. Must have the same length as `data`.
 // @param data A `Vec[T]` where `T` is a subtype of `Data`. Holds the input
 //             elements to be scattered. Must have the same length as the
 //             `valid` vector.
 // @return A tuple containing three elements:
 //    1. `result` (`Vec[T]`): The vector produced by scattering the input
 //        data. Positions that no valid element targets are driven with zero.
 //    2. `resultMask` (`Vec[Bool]`): A bitmask indicating which positions of
 //        `result` were written during this scatter operation. Same length as
 //        `result`.
 //    3. `indicesSelected` (`Vec[Bool]`): A bitmask indicating which elements
 //        of the input `data` (and `indices`) vectors were written in this
 //        scatter operation. Same length as the `valid`, `indices` and `data`
 //        inputs.
 object Scatter {
   def apply[T <: Data](valid: Vec[Bool],
                        indices: Vec[UInt],
                        data: Vec[T]): (Vec[T], Vec[Bool], Vec[Bool]) = {
     // Elaboration-time shape checks: all three inputs run in lock-step.
     assert(valid.length == data.length)
     assert(indices.length == data.length)
     val elemType = chiselTypeOf(data(0))
     val idxBits = indices(0).getWidth
     // Guard against an unreasonably wide result vector (cap at 2^16 entries).
     assert(idxBits <= 16)
     val numSlots = 1 << idxBits

     // Build a binary "selection matrix" with one row per input element and
     // one column per result position:
     //  - validMatrix: row i is the one-hot encoding of indices(i), or all
     //    zeros when valid(i) is clear.
     //  - valueSet: entry i is the OR of all validMatrix rows before row i,
     //    i.e. the positions already claimed by earlier elements.
     //  - selectionMatrix: row i keeps its bit only if no earlier row claimed
     //    the same position.
     // OR-ing the rows together yields resultMask; testing each row for a set
     // bit yields indicesSelected.
     val validMatrix = valid.zip(indices).map { case (ok, target) =>
       Mux(ok, UIntToOH(target), 0.U(numSlots.W))
     }
     val valueSet = validMatrix.scanLeft(0.U(numSlots.W))(_ | _)
     val selectionMatrix = validMatrix.zip(valueSet).map {
       case (row, claimed) => row & ~claimed
     }
     val resultMask = VecInit(selectionMatrix.reduce(_ | _).asBools)
     val indicesSelected = VecInit(selectionMatrix.map(_ =/= 0.U))

     // Assertions
     // Every row and every column of the selection matrix should have at most
     // one bit set. (Disabled for speed)
     // selectionMatrix.foreach(x => assert(PopCount(x) <= 1.U))
     // (0 until resultLength).foreach(
     //     i => assert(PopCount(selectionMatrix.map(_(i))) <= 1.U))
     // An element may only be reported as selected if it was valid.
     assert(PopCount(indicesSelected.zip(valid).map {
       case (picked, ok) => picked & ~ok
     }) === 0.U)

     // TODO(derekjchow): Review semantics for "ordered" and "unordered", and
     // implement behaviours correctly.

     // Each result position takes the first valid element that targets it and
     // falls back to zero when there is none.
     val result = Wire(Vec(numSlots, elemType))
     result.zipWithIndex.foreach { case (slot, pos) =>
       slot := MuxCase(0.U.asTypeOf(elemType),
                       valid.zip(indices).zip(data).map {
         case ((ok, target), elem) => (ok && (target === pos.U)) -> elem
       })
     }

     (result, resultMask, indicesSelected)
   }
 }