| # Copyright 2025 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import cocotb |
| from cocotb.clock import Clock |
| from cocotb.triggers import ClockCycles, FallingEdge |
| |
class SPIMaster:
    """cocotb-side SPI master for a DUT's byte-wide register interface.

    Chip select is the active-low ``csb`` signal. The SPI clock is generated
    by a :class:`cocotb.clock.Clock` (period 10, in that class's default
    units). It is started and killed around each transfer and parked low
    while idle.

    Register commands are a single byte: bit 7 set means write, bit 7 clear
    means read, and bits 6..0 carry the register address. A register write is
    a command transaction followed by a data transaction. Reads are
    pipelined: the byte returned by a transaction is the result of an
    *earlier* read command, so the read helpers issue extra transactions and
    discard the junk bytes they return.
    """

    def __init__(self, clk, csb, mosi, miso, main_clk, log):
        """Bind the SPI signals and drive them to their idle state.

        Args:
            clk: SPI clock signal driven by this master.
            csb: Active-low chip-select signal.
            mosi: Master-out data signal.
            miso: Master-in data signal, sampled by this master.
            main_clk: Clock used for the ``ClockCycles`` delays between
                transfers (presumably the DUT's system clock).
            log: Logger used to report polling results.
        """
        self.clk = clk
        self.csb = csb
        self.mosi = mosi
        self.miso = miso
        self.main_clk = main_clk
        self.log = log
        self.spi_clk_driver = Clock(self.clk, 10)
        self.clock_task = None  # running clock task, or None while stopped

        # Idle state: clock low, chip deselected, MOSI low.
        self.clk.value = 0
        self.csb.value = 1
        self.mosi.value = 0

    @staticmethod
    def _command_byte(reg_addr, write):
        """Return the command byte that reads or writes register ``reg_addr``.

        Raises:
            ValueError: if ``reg_addr`` is outside 0x00-0x7F. Bit 7 is the
                read/write flag, so a wider address would silently turn a
                read into a write, or be truncated by the 8-bit shifter.
        """
        if not 0 <= reg_addr <= 0x7F:
            raise ValueError(
                f"register address {reg_addr:#x} does not fit in 7 bits"
            )
        return (0x80 | reg_addr) if write else reg_addr

    async def start_clock(self):
        """Start toggling the SPI clock. Does nothing if it is already running."""
        if self.clock_task is None:
            self.clock_task = cocotb.start_soon(self.spi_clk_driver.start())

    async def stop_clock(self):
        """Kill the SPI clock task, if any, and park the clock low."""
        if self.clock_task is not None:
            self.clock_task.kill()
            self.clock_task = None
        self.clk.value = 0

    async def _set_cs(self, active):
        """Assert (``active`` truthy) or release the active-low chip select."""
        self.csb.value = 0 if active else 1

    async def _clock_byte(self, data_out):
        """Shift one byte out on MOSI and return the byte read back on MISO.

        Bits go out MSB first. Each bit is placed on MOSI, then MISO is
        sampled on the next falling edge of the SPI clock. The caller must
        already have started the clock.
        """
        data_in = 0
        for bit in range(7, -1, -1):
            self.mosi.value = (data_out >> bit) & 1
            await FallingEdge(self.clk)
            data_in = (data_in << 1) | int(self.miso.value)
        return data_in

    async def idle_clocking(self, cycles):
        """Run the SPI clock for ``cycles`` cycles without touching chip select."""
        await self.start_clock()
        await ClockCycles(self.clk, cycles)
        await self.stop_clock()

    async def spi_transaction(self, byte_out):
        """Perform one chip-select-framed, single-byte transfer.

        Returns:
            The byte shifted in on MISO during the transfer.
        """
        # Provide a setup time for CSb before the clock starts.
        await self._set_cs(True)
        await ClockCycles(self.main_clk, 1)

        await self.start_clock()
        byte_in = await self._clock_byte(byte_out)
        # Two trailing SPI clocks after the byte, before the clock is stopped.
        await ClockCycles(self.clk, 2)
        await self.stop_clock()

        # Provide a hold time for CSb after the clock stops.
        await ClockCycles(self.main_clk, 1)
        await self._set_cs(False)
        await ClockCycles(self.main_clk, 2)  # Small delay between transactions
        return byte_in

    async def write_reg(self, reg_addr, data, wait_cycles=10):
        """Write the byte ``data`` to register ``reg_addr`` via SPI.

        Args:
            reg_addr: Register address, 0x00-0x7F.
            data: Byte to write.
            wait_cycles: ``main_clk`` cycles to wait afterwards; 0 skips it.

        Raises:
            ValueError: if ``reg_addr`` does not fit in 7 bits.
        """
        await self.spi_transaction(self._command_byte(reg_addr, write=True))
        await self.spi_transaction(data)
        if wait_cycles > 0:
            await ClockCycles(self.main_clk, wait_cycles)

    async def read_reg(self, reg_addr):
        """Read and return one byte from register ``reg_addr`` via SPI.

        Raises:
            ValueError: if ``reg_addr`` does not fit in 7 bits.
        """
        read_cmd = self._command_byte(reg_addr, write=False)
        await self.spi_transaction(read_cmd)
        # Give the DUT time to fetch the register before clocking it out.
        # NOTE(review): the free-running SPI clocks presumably move the result
        # into the SPI shift register -- confirm against the DUT.
        await ClockCycles(self.main_clk, 10)
        await self.idle_clocking(5)
        await ClockCycles(self.main_clk, 10)
        # 0x00 is a dummy command (a read of register 0). Its own result is
        # never collected.
        return await self.spi_transaction(0x00)

    async def poll_reg_for_value(self, reg_addr, expected_value, max_polls=20):
        """Poll register ``reg_addr`` until it reads ``expected_value``.

        Returns:
            True once the value is seen, False after ``max_polls`` misses
            (an error is logged in that case).

        Raises:
            ValueError: if ``reg_addr`` does not fit in 7 bits.
        """
        read_cmd = self._command_byte(reg_addr, write=False)
        read_data = None  # stays None if no poll is ever issued

        # The first transaction just kicks off the read pipeline. The data is junk.
        await self.spi_transaction(read_cmd)

        for attempt in range(1, max_polls + 1):
            # Each subsequent transaction sends a new read command and receives
            # the result of the PREVIOUS command.
            read_data = await self.spi_transaction(read_cmd)
            if read_data == expected_value:
                self.log.info(f"Successfully polled 0x{reg_addr:x} and got 0x{expected_value:x} after {attempt} attempts.")
                return True
            await ClockCycles(self.main_clk, 5)  # Wait before next poll

        got = "no data" if read_data is None else f"0x{read_data:x}"
        self.log.error(f"Timed out after {max_polls} polls waiting for register 0x{reg_addr:x} to be 0x{expected_value:x}, got {got}")
        return False

    async def bulk_read_data(self, reg_addr, num_bytes):
        """Read ``num_bytes`` bytes from a pipelined port at ``reg_addr``.

        Returns:
            The bytes assembled little-endian into one integer: the first
            byte received is the least-significant byte. Returns 0 when
            ``num_bytes`` is 0.

        Raises:
            ValueError: if ``reg_addr`` does not fit in 7 bits.
        """
        read_cmd = self._command_byte(reg_addr, write=False)

        # The read pipeline is two stages deep. Send two commands and discard
        # the two junk bytes before the first valid data byte is received.
        for _ in range(2):
            await self.spi_transaction(read_cmd)
            await ClockCycles(self.main_clk, 10)
            await self.idle_clocking(5)
            await ClockCycles(self.main_clk, 10)

        # Read the valid bytes. Each transfer returns 0..255.
        received = bytearray()
        for _ in range(num_bytes):
            received.append(await self.spi_transaction(read_cmd))
            await ClockCycles(self.main_clk, 5)

        return int.from_bytes(received, "little")

    async def bulk_write_data(self, reg_addr, data, num_bytes):
        """Write the low ``num_bytes`` bytes of ``data`` to ``reg_addr``, LSB first.

        Each byte is sent as a separate register write.

        Raises:
            ValueError: if ``reg_addr`` does not fit in 7 bits.
        """
        for i in range(num_bytes):
            await self.write_reg(reg_addr, (data >> (8 * i)) & 0xFF, wait_cycles=5)

    async def packed_write_transaction(self, target_addr, num_beats, data_generator,
                                       bytes_per_beat=16):
        """Stream a multi-beat write as register writes inside one CS window.

        The whole sequence is clocked out without releasing chip select. Each
        register write is a command byte followed by a value byte:

        * registers 0-3 <- ``target_addr``, least-significant byte first;
        * register 4    <- ``num_beats - 1``;
        * register 7    <- every data byte, ``bytes_per_beat`` bytes per beat,
          LSB first, where beat ``j`` is ``data_generator(j)``;
        * register 5    <- 0x02 (presumably starts the transfer -- confirm
          against the DUT register map).

        Args:
            target_addr: 32-bit target address.
            num_beats: Number of beats, 1..256. It is sent as an 8-bit
                ``num_beats - 1``.
            data_generator: Callable mapping a beat index to an integer.
            bytes_per_beat: Bytes sent from each beat's integer (default 16).

        Raises:
            ValueError: if ``num_beats`` cannot be encoded in 8 bits as
                ``num_beats - 1``. This is checked before any signal is driven.
        """
        if not 1 <= num_beats <= 256:
            raise ValueError(f"num_beats must be in 1..256, got {num_beats}")

        async def put(reg, value):
            # One register write within the already-open CS window.
            await self._clock_byte(self._command_byte(reg, write=True))
            await self._clock_byte(value)

        await self._set_cs(True)
        await ClockCycles(self.main_clk, 1)

        await self.start_clock()

        # Write addr
        for i in range(4):
            await put(i, (target_addr >> (8 * i)) & 0xFF)

        # Write beats
        await put(4, num_beats - 1)

        # Write data
        for j in range(num_beats):
            data = data_generator(j)
            for i in range(bytes_per_beat):
                await put(7, (data >> (8 * i)) & 0xFF)

        await put(5, 0x02)

        await self.stop_clock()
        await ClockCycles(self.main_clk, 1)
        await self._set_cs(False)