# blob: 32fe33e9b1ef1dd44a86736b1d121ab3b2039428
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cocotb
from cocotb.clock import Clock
from cocotb.triggers import ClockCycles, FallingEdge
class SPIMaster:
    """Cocotb driver that acts as the SPI controller towards the DUT.

    Drives ``clk``, ``csb`` (active-low chip select) and ``mosi``, and samples
    ``miso``. The SPI clock is produced by a cocotb ``Clock`` that only runs
    while a transfer is in progress; ``main_clk`` is only used to count
    delay cycles around transfers.

    Command bytes have the layout ``{rw, addr[6:0]}``: bit 7 set selects a
    register write, bit 7 clear a register read.
    """

    # Number of bytes clocked out per beat by packed_write_transaction().
    _BEAT_BYTES = 16

    def __init__(self, clk, csb, mosi, miso, main_clk, log):
        """Store the signal handles and drive the bus to its idle state.

        Args:
            clk: SPI clock signal handle (driven by this class).
            csb: Active-low chip-select signal handle (driven).
            mosi: Controller-out data signal handle (driven).
            miso: Controller-in data signal handle (sampled).
            main_clk: Clock handle used to time delays between transfers.
            log: Logger used by poll_reg_for_value().
        """
        self.clk = clk
        self.csb = csb
        self.mosi = mosi
        self.miso = miso
        self.main_clk = main_clk
        self.log = log
        # Period of 10 in the Clock's default time unit. It is not started
        # here; start_clock()/stop_clock() gate it around each transfer.
        self.spi_clk_driver = Clock(self.clk, 10)
        self.clock_task = None
        # Idle bus: clock low, chip select deasserted, MOSI low.
        self.clk.value = 0
        self.csb.value = 1
        self.mosi.value = 0

    @staticmethod
    def _read_cmd(reg_addr):
        """Return the command byte that reads register ``reg_addr``.

        Raises:
            ValueError: if ``reg_addr`` is outside 0..0x7F. Bit 7 of the
                command byte is the write flag, so an out-of-range address
                would otherwise silently turn a read into a write.
        """
        if not 0 <= reg_addr <= 0x7F:
            raise ValueError(
                f"register address {reg_addr:#x} does not fit in 7 bits"
            )
        return reg_addr

    @classmethod
    def _write_cmd(cls, reg_addr):
        """Return the command byte that writes register ``reg_addr``.

        Raises:
            ValueError: if ``reg_addr`` is outside 0..0x7F.
        """
        return 0x80 | cls._read_cmd(reg_addr)

    async def start_clock(self):
        """Start the SPI clock unless it is already running."""
        if self.clock_task is None:
            self.clock_task = cocotb.start_soon(self.spi_clk_driver.start())

    async def stop_clock(self):
        """Stop the SPI clock (if running) and park it low."""
        if self.clock_task:
            self.clock_task.kill()
            self.clock_task = None
        self.clk.value = 0

    async def _set_cs(self, active):
        """Assert (``active`` truthy) or deassert the active-low chip select."""
        self.csb.value = 0 if active else 1

    async def _clock_byte(self, data_out):
        """Shift one byte out on MOSI, MSB first, and return the byte read on MISO.

        Each MOSI bit is driven before waiting for the next falling edge of
        ``clk``, and MISO is sampled at that falling edge. Only the low 8 bits
        of ``data_out`` are sent. The SPI clock must already be running.
        """
        data_in = 0
        for i in range(8):
            self.mosi.value = (data_out >> (7 - i)) & 1
            await FallingEdge(self.clk)
            data_in = (data_in << 1) | int(self.miso.value)
        return data_in

    async def idle_clocking(self, cycles):
        """Run the SPI clock for ``cycles`` cycles without touching CSb or MOSI."""
        await self.start_clock()
        await ClockCycles(self.clk, cycles)
        await self.stop_clock()

    async def spi_transaction(self, byte_out):
        """Perform one chip-select-framed single-byte transfer.

        Returns:
            The byte shifted in on MISO while ``byte_out`` was shifted out.
        """
        # Provide a setup time for CSb before the clock starts
        await self._set_cs(True)
        await ClockCycles(self.main_clk, 1)
        await self.start_clock()
        byte_in = await self._clock_byte(byte_out)
        # Two extra SPI clocks with CSb still asserted. NOTE(review):
        # presumably gives the DUT time to consume the byte -- confirm.
        await ClockCycles(self.clk, 2)
        await self.stop_clock()
        # Provide a hold time for CSb after the clock stops
        await ClockCycles(self.main_clk, 1)
        await self._set_cs(False)
        await ClockCycles(self.main_clk, 2)  # Small delay between transactions
        return byte_in

    async def write_reg(self, reg_addr, data, wait_cycles=10):
        """Writes a byte to a register via SPI.

        The write command and the data byte are sent as two separate
        transactions. Afterwards, ``wait_cycles`` cycles of ``main_clk`` are
        awaited (skipped when ``wait_cycles <= 0``).

        Raises:
            ValueError: if ``reg_addr`` is outside 0..0x7F.
        """
        await self.spi_transaction(self._write_cmd(reg_addr))
        await self.spi_transaction(data)
        if wait_cycles > 0:
            await ClockCycles(self.main_clk, wait_cycles)

    async def read_reg(self, reg_addr):
        """Reads a byte from a register via SPI.

        Sends the read command, idles the bus with a few free-running SPI
        clocks, then returns the byte received during a dummy 0x00 transfer.

        Raises:
            ValueError: if ``reg_addr`` is outside 0..0x7F.
        """
        read_cmd = self._read_cmd(reg_addr)  # MSB is 0 for read
        await self.spi_transaction(read_cmd)
        await ClockCycles(self.main_clk, 10)
        await self.idle_clocking(5)
        await ClockCycles(self.main_clk, 10)
        read_data = await self.spi_transaction(0x00)
        return read_data

    async def poll_reg_for_value(self, reg_addr, expected_value, max_polls=20):
        """Polls a register until it reads an expected value.

        Returns:
            True as soon as a poll returns ``expected_value``; False (after
            logging an error) once ``max_polls`` polls have failed.

        Raises:
            ValueError: if ``reg_addr`` is outside 0..0x7F.
        """
        read_cmd = self._read_cmd(reg_addr)  # MSB is 0 for read
        read_data = -1
        # The first transaction just kicks off the read pipeline. The data is junk.
        await self.spi_transaction(read_cmd)
        for i in range(max_polls):
            # Each subsequent transaction sends a new read command and receives the
            # result of the PREVIOUS command.
            read_data = await self.spi_transaction(read_cmd)
            if read_data == expected_value:
                self.log.info(f"Successfully polled 0x{reg_addr:x} and got 0x{expected_value:x} after {i+1} attempts.")
                return True
            await ClockCycles(self.main_clk, 5)  # Wait before next poll
        self.log.error(f"Timed out after {max_polls} polls waiting for register 0x{reg_addr:x} to be 0x{expected_value:x}, got 0x{read_data:x}")
        return False

    async def bulk_read_data(self, reg_addr, num_bytes):
        """Reads a block of data from a pipelined port.

        Returns:
            The ``num_bytes`` received bytes packed little-endian into one
            int (the first valid byte lands in bits 7:0).

        Raises:
            ValueError: if ``reg_addr`` is outside 0..0x7F.
        """
        read_cmd = self._read_cmd(reg_addr)
        # The read pipeline is two stages deep. We need to send two commands
        # to discard two junk bytes before the first valid data byte is received.
        for _ in range(2):
            await self.spi_transaction(read_cmd)
            await ClockCycles(self.main_clk, 10)
            await self.idle_clocking(5)
            await ClockCycles(self.main_clk, 10)
        # Read the valid bytes.
        received_bytes = []
        for _ in range(num_bytes):
            received_bytes.append(await self.spi_transaction(read_cmd))
            await ClockCycles(self.main_clk, 5)
        # Assemble the received bytes into a single large integer
        read_data = 0
        for i, byte in enumerate(received_bytes):
            read_data |= byte << (i * 8)
        return read_data

    async def bulk_write_data(self, reg_addr, data, num_bytes):
        """Writes a block of data to a port.

        Writes the low ``num_bytes`` bytes of ``data`` to ``reg_addr``, least
        significant byte first, one write_reg() call per byte.

        Raises:
            ValueError: if ``reg_addr`` is outside 0..0x7F.
        """
        for i in range(num_bytes):
            byte = (data >> (i * 8)) & 0xFF
            await self.write_reg(reg_addr, byte, wait_cycles=5)

    async def _clock_reg_write(self, reg_addr, value):
        """Clock a write command plus its data byte inside an already-open transfer."""
        await self._clock_byte(self._write_cmd(reg_addr))
        await self._clock_byte(value & 0xFF)

    async def packed_write_transaction(self, target_addr, num_beats, data_generator):
        """Issue a multi-beat write as one continuous chip-select-framed transfer.

        Within a single CSb assertion, this sends:
        - registers 0-3: ``target_addr``, least significant byte first;
        - register 4: ``num_beats - 1``;
        - register 7: ``_BEAT_BYTES`` bytes per beat (LSB first) of
          ``data_generator(beat_index)``;
        - register 5: 0x02.

        Args:
            target_addr: Address written into registers 0-3 (low 32 bits).
            num_beats: Number of beats to send, 1..256.
            data_generator: Callable mapping a beat index to the int payload
                for that beat.

        Raises:
            ValueError: if ``num_beats`` is outside 1..256. The count is sent
                as one byte holding ``num_beats - 1``, so other values would
                wrap and announce a different beat count than is actually sent.
        """
        if not 1 <= num_beats <= 256:
            raise ValueError(f"num_beats must be in 1..256, got {num_beats}")
        await self._set_cs(True)
        await ClockCycles(self.main_clk, 1)
        await self.start_clock()
        # Write addr (registers 0..3, LSB first)
        for i in range(4):
            await self._clock_reg_write(i, target_addr >> (8 * i))
        # Write beats (register 4 holds the count minus one)
        await self._clock_reg_write(4, num_beats - 1)
        # Write data (register 7, one byte at a time)
        for beat in range(num_beats):
            data = data_generator(beat)
            for i in range(self._BEAT_BYTES):
                await self._clock_reg_write(7, data >> (8 * i))
        # NOTE(review): register 5 <- 0x02 presumably launches the write --
        # confirm against the DUT's register map.
        await self._clock_reg_write(5, 0x02)
        await self.stop_clock()
        await ClockCycles(self.main_clk, 1)
        await self._set_cs(False)