/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "sw/device/lib/spi_flash.h"

#include <string.h>

#include "hw/top_matcha/sw/autogen/top_matcha.h"
#include "sw/device/lib/arch/device.h"
#include "sw/device/lib/dif/dif_spi_host.h"
#include "sw/device/lib/eflash.h"
#include "sw/device/lib/tar.h"
#include "sw/device/lib/testing/test_framework/check.h"

// SPI host handle used for every SPI flash transaction in this file;
// initialized by spi_flash_init().
static dif_spi_host_t spi_host0;
// Flash controller base address recorded by spi_flash_init() and passed to
// the eflash_* helpers when a load targets the eFLASH region.
static uintptr_t flash_ctrl_addr;
// Byte offset at which the tar scan over SPI flash stops (64 MiB).
static const int kSpiFlashBytes = (64 * 1024 * 1024);

dif_result_t spi_flash_read_page(uint32_t page, uint8_t* buf) {
  /* clang-format off */
  dif_spi_host_segment_t segments[] = {
      {
          .type = kDifSpiHostSegmentTypeOpcode,
          .opcode = 0x13,  // Page Read, 4b Addr
      },
      {.type = kDifSpiHostSegmentTypeAddress,
       .address = {
               .width = kDifSpiHostWidthStandard,
               .mode = kDifSpiHostAddrMode4b,
               .address = page,
           }},
      {.type = kDifSpiHostSegmentTypeRx,
       .rx = {.width = kDifSpiHostWidthStandard,
              .buf = buf,
              .length = SPI_PAGE_SIZE >> 1}},
  };
  dif_spi_host_segment_t segments2[] = {
      {
          .type = kDifSpiHostSegmentTypeOpcode,
          .opcode = 0x13,  // Page Read, 4b Addr
      },
      {.type = kDifSpiHostSegmentTypeAddress,
       .address = {
               .width = kDifSpiHostWidthStandard,
               .mode = kDifSpiHostAddrMode4b,
               .address = page + (SPI_PAGE_SIZE >> 1),
           }},
      {.type = kDifSpiHostSegmentTypeRx,
       .rx = {.width = kDifSpiHostWidthStandard,
              .buf = buf + (SPI_PAGE_SIZE >> 1),
              .length = SPI_PAGE_SIZE >> 1}},
  };
  /* clang-format on */
  CHECK_DIF_OK(dif_spi_host_transaction(&spi_host0, /*csid=*/0, segments,
                                        ARRAYSIZE(segments)));
  CHECK_DIF_OK(dif_spi_host_transaction(&spi_host0, /*csid=*/0, segments2,
                                        ARRAYSIZE(segments2)));
  return kDifOk;
}

// Prepares the SPI host and the eFLASH controller for later loads.
//
// spi_host_addr:    base address of the SPI host registers.
// flash_ctrl_addr_: base address of the flash controller; stored in the
//                   file-scope `flash_ctrl_addr` for later eFLASH operations.
// otp_addr:         forwarded to eflash_init().
//
// Steps, in order: initialize the SPI host handle, configure its clocks,
// enable its output, then initialize the eFLASH. Failures are handled by
// CHECK_DIF_OK; on success kDifOk is returned.
dif_result_t spi_flash_init(uintptr_t spi_host_addr,
                            uintptr_t flash_ctrl_addr_, uintptr_t otp_addr) {
  // Record the controller base first so it is available to the eFLASH calls.
  flash_ctrl_addr = flash_ctrl_addr_;

  CHECK_DIF_OK(
      dif_spi_host_init(mmio_region_from_addr(spi_host_addr), &spi_host0));

  dif_spi_host_config_t host_config = {
      .spi_clock = kClockFreqSpiFlashHz,
      .peripheral_clock_freq_hz = kClockFreqCpuHz,
  };
  CHECK_DIF_OK(dif_spi_host_configure(&spi_host0, host_config));
  CHECK_DIF_OK(dif_spi_host_output_set_enabled(&spi_host0, /*enabled=*/true));

  CHECK_DIF_OK(eflash_init(flash_ctrl_addr, otp_addr));
  return kDifOk;
}

// Parses an ASCII octal number such as a tar header numeric field.
//
// Leading spaces are skipped and parsing stops at the first byte that is not
// an octal digit ('0'..'7'), so both NUL- and space-terminated fields are
// accepted. At most 12 bytes are examined (the width of a ustar numeric size
// field), so an unterminated field cannot run into the following header
// field.
//
// Returns the parsed value, or 0 when no digits are present.
static int parse_octal(uint8_t* cursor) {
  enum { kMaxOctalFieldLen = 12 };
  // Accumulate as unsigned so that an oversized field cannot trigger
  // signed-overflow undefined behavior in the shift.
  unsigned int value = 0;
  size_t i = 0;

  while (i < kMaxOctalFieldLen && cursor[i] == ' ') {
    i++;
  }
  for (; i < kMaxOctalFieldLen && cursor[i] >= '0' && cursor[i] <= '7'; i++) {
    value = (value << 3) | (unsigned int)(cursor[i] - '0');
  }
  return (int)value;
}

// Returns the number of characters in `str` that precede its terminating
// NUL byte. `str` must be a valid NUL-terminated string.
size_t str_size(const char* str) {
  size_t length = 0;
  for (const char* p = str; *p != '\0'; ++p) {
    ++length;
  }
  return length;
}

// Scans the tar archive held in SPI flash for the first entry whose name
// begins with `filename`.
//
// filename:         name (or name prefix) to look for; must not be NULL.
// start_cursor:     SPI flash byte offset of the tar header to start from.
// size_out:         receives the size of the matching entry; must not be
//                   NULL.
// found_filename:   optional. When non-NULL and `filename_max_len` > 0 it
//                   receives the entry's name, truncated if necessary and
//                   always NUL-terminated.
// filename_max_len: capacity of `found_filename` in bytes.
//
// Returns the SPI flash offset of the entry's data (the 512-byte block that
// follows its header), or 0 when the arguments are invalid, a header without
// the "ustar" magic is reached, or the scan passes kSpiFlashBytes.
//
// NOTE(review): a whole SPI page is read into `tar`, so this assumes
// sizeof(tar_header) >= SPI_PAGE_SIZE — confirm against tar.h.
static size_t find_file_in_tar(const char* filename, size_t start_cursor,
                               size_t* size_out, char* found_filename,
                               size_t filename_max_len) {
  if (!size_out || !filename) {
    return 0;
  }
  const char kTarMagic[] = "ustar";
  const size_t magic_len = str_size(kTarMagic);
  const size_t wanted_len = str_size(filename);
  tar_header tar;
  size_t cursor = start_cursor;

  // Check the bound before every header read so that no read is issued at
  // or beyond the scan limit.
  while (cursor < (size_t)kSpiFlashBytes) {
    CHECK_DIF_OK(spi_flash_read_page(cursor, (uint8_t*)&tar));

    // Check the tar header magic field to validate the header info.
    if (memcmp((const char*)&tar.magic, kTarMagic, magic_len) != 0) {
      break;
    }
    int size = parse_octal((uint8_t*)&tar.size);
    // Step over the header; `cursor` now points at the entry's data.
    cursor += 512;

    const char* tar_name = (const char*)&tar.name;
    if (memcmp(tar_name, filename, wanted_len) == 0) {
      *size_out = size;
      // Only copy the name out when there is room for at least the NUL
      // terminator; `filename_max_len - 1` would wrap for a zero capacity.
      if (found_filename && filename_max_len > 0) {
        size_t tar_name_len = str_size(tar_name);
        size_t copy_len = (tar_name_len < filename_max_len - 1)
                              ? tar_name_len
                              : filename_max_len - 1;
        memcpy(found_filename, tar_name, copy_len);
        found_filename[copy_len] = '\0';
      }
      return cursor;
    }

    // Skip the entry's data, which is padded to a multiple of 512 bytes.
    cursor += (size + 511) & ~511;
  }

  return 0;
}

// Copies `size` bytes from SPI flash, starting at `bin_offset`, to `addr`.
//
// bin_offset: SPI flash byte offset of the data; must be a multiple of
//             SPI_PAGE_SIZE.
// size:       number of bytes to copy.
// addr:       destination; must not be NULL.
//
// If `addr` lies inside the eFLASH address range, the eFLASH is chip-erased
// first and every SPI page is then programmed with two eflash_write_page()
// calls, one per EFLASH_PAGE_SIZE half. In that path the full page buffer is
// written even for a final partial page. Otherwise the data is copied with
// memcpy, and only the remaining bytes of a final partial page are copied.
//
// NOTE(review): the eFLASH path assumes SPI_PAGE_SIZE == 2 * EFLASH_PAGE_SIZE
// — confirm against the headers.
//
// Returns kDifBadArg for a NULL `addr` or a misaligned `bin_offset`, and
// kDifOk otherwise. Lower-level failures are handled by CHECK_DIF_OK.
static dif_result_t copy_flash_to_mem(size_t bin_offset, size_t size,
                                      void* addr) {
  if (!addr) {
    return kDifBadArg;
  }
  if (bin_offset & (SPI_PAGE_SIZE - 1)) {
    return kDifBadArg;
  }
  uint8_t page[SPI_PAGE_SIZE];
  size_t remaining = size;
  bool eflash = false;
  if ((uintptr_t)addr >= TOP_MATCHA_EFLASH_BASE_ADDR &&
      (uintptr_t)addr <
          TOP_MATCHA_EFLASH_BASE_ADDR + TOP_MATCHA_EFLASH_SIZE_BYTES) {
    eflash = true;
    CHECK_DIF_OK(eflash_chip_erase(flash_ctrl_addr));
  }
  while (remaining > 0) {
    size_t page_used = remaining >= SPI_PAGE_SIZE ? SPI_PAGE_SIZE : remaining;
    size_t offset = size - remaining;
    CHECK_DIF_OK(spi_flash_read_page(bin_offset + offset, page));
    void* dst = (uint8_t*)addr + offset;
    if (eflash) {
      // Pointer arithmetic is done on uint8_t* because arithmetic on void*
      // is a compiler extension rather than standard C.
      void* dst_upper = (void*)((uint8_t*)dst + EFLASH_PAGE_SIZE);
      CHECK_DIF_OK(eflash_write_page(flash_ctrl_addr, dst, page));
      CHECK_DIF_OK(eflash_write_page(flash_ctrl_addr, dst_upper,
                                     page + EFLASH_PAGE_SIZE));
    } else {
      memcpy(dst, page, page_used);
    }
    remaining -= page_used;
  }

  return kDifOk;
}

// Loads a file from the tar archive in SPI flash to `addr`.
//
// filename:     name to look for, searching from offset 0 of the archive.
//               The lookup matches the first entry whose name begins with
//               `filename`. Must not be NULL.
// addr:         destination address for the file contents.
// max_mem_addr: the file must satisfy addr + size <= max_mem_addr.
//
// Returns kDifBadArg when `filename` is NULL or the file does not fit below
// `max_mem_addr`, kDifError when no matching entry is found, and otherwise
// the result of copying the data to `addr`.
dif_result_t load_file_from_tar(const char* filename, void* addr,
                                size_t max_mem_addr) {
  if (!filename) {
    return kDifBadArg;
  }
  size_t size = 0;
  size_t bin_offset =
      find_file_in_tar(filename, /*start_cursor=*/0, &size,
                       /*found_filename=*/NULL, /*max_filename_len=*/0);
  if (!bin_offset) return kDifError;
  // Equivalent to `addr + size > max_mem_addr`, but written so that a huge
  // `size` cannot wrap the sum around and slip past the check.
  const size_t base = (size_t)addr;
  if (base > max_mem_addr || size > max_mem_addr - base) {
    return kDifBadArg;
  }
  return copy_flash_to_mem(bin_offset, size, addr);
}

// Loads the next file whose name begins with `file_prefix` from the tar
// archive in SPI flash to `addr`.
//
// file_prefix:  name prefix to match; must not be NULL.
// addr:         destination address for the file contents.
// max_mem_addr: the file must satisfy addr + size <= max_mem_addr.
// tar_offset:   in: SPI flash offset of the tar header to start searching
//               from. out: offset of the header that follows the matched
//               entry, so the call can be repeated to iterate over all
//               matches. Updated only after the size check passes. Must not
//               be NULL.
// filename:     receives the matched entry's name, NUL-terminated and
//               truncated to fit. Must not be NULL.
// filename_len: capacity of `filename` in bytes; must be non-zero.
//
// Returns kDifBadArg for invalid arguments or when the file does not fit
// below `max_mem_addr`, kDifOutOfRange when no matching entry is found, and
// otherwise the result of copying the data to `addr`.
dif_result_t load_file_prefix_from_tar(const char* file_prefix, void* addr,
                                       size_t max_mem_addr, size_t* tar_offset,
                                       char* filename, size_t filename_len) {
  if (!tar_offset || !file_prefix || !filename) {
    return kDifBadArg;
  }
  // A zero-capacity name buffer cannot hold even the NUL terminator.
  if (filename_len == 0) {
    return kDifBadArg;
  }
  size_t size = 0;
  size_t start_offset = *tar_offset;
  size_t bin_offset = find_file_in_tar(file_prefix, start_offset, &size,
                                       filename, filename_len);
  if (!bin_offset) return kDifOutOfRange;
  // Equivalent to `addr + size > max_mem_addr`, but written so that a huge
  // `size` cannot wrap the sum around and slip past the check.
  const size_t base = (size_t)addr;
  if (base > max_mem_addr || size > max_mem_addr - base) {
    return kDifBadArg;
  }
  // Update the tar_offset for the future call.
  *tar_offset = bin_offset + ((size + 511) & ~511);
  return copy_flash_to_mem(bin_offset, size, addr);
}
