blob: 93ea3f346887b90dcf778bbfd69da11893b50a49 [file] [log] [blame]
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import chisel3._
import chisel3.simulator.scalatest.ChiselSim
import chisel3.util._
import org.scalatest.freespec.AnyFreeSpec
import scala.util.Random
class GatherTester extends Module {
val io = IO(new Bundle {
val indices = Input(Vec(16, UInt(4.W)))
val data = Input(Vec(16, UInt(8.W)))
val out = Output(Vec(16, UInt(8.W)))
})
io.out := Gather(io.indices, io.data)
}
class ScatterTester extends Module {
val io = IO(new Bundle {
val indicesValid = Input(Vec(16, Bool()))
val indices = Input(Vec(16, UInt(4.W)))
val data = Input(Vec(16, UInt(8.W)))
val indicesSelected = Output(Vec(16, Bool()))
val writeMask = Output(Vec(16, Bool()))
val outData = Output(Vec(16, UInt(8.W)))
val maskCount = Output(UInt(5.W))
})
val (result, writeMask, indicesSelected) = Scatter(
io.indicesValid, io.indices, io.data)
io.indicesSelected := indicesSelected
io.writeMask := writeMask
io.outData := result
io.maskCount := PopCount(writeMask)
}
class GatherSpec extends AnyFreeSpec with ChiselSim {
"Random Test" in {
simulate(new GatherTester) { dut =>
for (_ <- 0 until 100) {
// Set inputs
val indices = Seq.fill(16)(Random.between(0, 16))
val data = Seq.fill(16)(Random.between(0, 256))
for (i <- 0 until 16) {
dut.io.indices(i).poke(indices(i))
dut.io.data(i).poke(data(i))
}
// Check results
for (i <- 0 until 16) {
dut.io.out(i).expect(data(indices(i)))
}
}
}
}
}
class ScatterSpec extends AnyFreeSpec with ChiselSim {
"Random Test" in {
simulate(new ScatterTester) { dut =>
for (_ <- 0 until 500) {
// Set inputs
val indices = Seq.fill(16)(Random.between(0, 16))
val data = Seq.fill(16)(Random.between(0, 256))
val valid = Seq.fill(16)(Random.nextBoolean())
for (i <- 0 until 16) {
dut.io.indices(i).poke(indices(i))
dut.io.data(i).poke(data(i))
dut.io.indicesValid(i).poke(valid(i))
}
// Check results
val indicesSet = Array.fill(16)(false)
var nIndicesSet = 0
for (i <- 0 until 16) {
if (dut.io.indicesSelected(i).peek().litValue == 1) {
assertResult(true) { valid(i) }
assertResult(false) { indicesSet(indices(i)) }
indicesSet(indices(i)) = true
nIndicesSet = nIndicesSet + 1
dut.io.outData(indices(i)).expect(data(i))
dut.io.writeMask(indices(i)).expect(1)
}
}
dut.io.maskCount.expect(nIndicesSet)
}
}
}
}