// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include "sw/device/lib/testing/test_rom/puppeteer.h"

#include <stdint.h>
#include <string.h>

#include "flash_ctrl_regs.h"  // NOLINT(build/include_subdir) Generated.
#include "hw/top_matcha/sw/autogen/top_matcha.h"  // Generated.
#include "sw/device/lib/base/memory.h"
#include "sw/device/lib/dif/dif_flash_ctrl.h"
#include "sw/device/lib/dif/dif_spi_host.h"
#include "sw/device/lib/testing/test_rom/puppeteer_utils/base64.h"
#include "sw/device/lib/testing/test_rom/puppeteer_utils/elf_loader.h"
#include "sw/device/lib/testing/test_rom/puppeteer_utils/opentitan_uart.h"
#include "sw/device/lib/testing/test_rom/puppeteer_utils/tar_loader.h"
#include "sw/device/lib/testing/test_rom/puppeteer_utils/tiny_hmac.h"
#include "sw/device/lib/testing/test_rom/puppeteer_utils/tiny_io.h"
#include "sw/device/silicon_creator/lib/drivers/flash_ctrl.h"
#include "sw/device/silicon_creator/lib/manifest.h"

static dif_flash_ctrl_state_t flash_ctrl;
static dif_spi_host spi_host0;

__attribute__((warn_unused_result)) bool eflash_chip_erase(void) {
  flash_ctrl_bank_erase_perms_set(kHardenedBoolTrue);
  rom_error_t err_0 = flash_ctrl_data_erase(0, kFlashCtrlEraseTypeBank);
  rom_error_t err_1 = flash_ctrl_data_erase(FLASH_CTRL_PARAM_BYTES_PER_BANK,
                                            kFlashCtrlEraseTypeBank);
  flash_ctrl_bank_erase_perms_set(kHardenedBoolFalse);
  if (err_0 != kErrorOk || err_1 != kErrorOk) {
    return false;
  }
  return true;
}

// Writes 256-bytes from src to dst.
// For now, lets take dst as a memory address, and mask off the top.
#define EFLASH_PAGE_SIZE (256)
__attribute__((warn_unused_result)) bool eflash_write_page(const void* dst,
                                                           uint8_t* src) {
  uint32_t flash_dst = (uint32_t)dst & ~TOP_MATCHA_EFLASH_BASE_ADDR;
  flash_ctrl_data_default_perms_set((flash_ctrl_perms_t){
      .read = kMultiBitBool4False,
      .write = kMultiBitBool4True,
      .erase = kMultiBitBool4False,
  });
  rom_error_t err = flash_ctrl_data_write(
      flash_dst, EFLASH_PAGE_SIZE / sizeof(uint32_t), src);
  if (err != kErrorOk) {
    return false;
  }
  flash_ctrl_data_default_perms_set((flash_ctrl_perms_t){
      .read = kMultiBitBool4True,
      .write = kMultiBitBool4False,
      .erase = kMultiBitBool4False,
  });
  return true;
}

// Busy-waits until the flash controller reports that its initialization is
// no longer in progress (`controller_init_wip` cleared).
//
// If the status cannot be read, the function returns immediately. Before
// this check, it would have tested an uninitialized `status` value.
void flash_ctrl_testutils_wait_for_init(dif_flash_ctrl_state_t* flash_state) {
  dif_flash_ctrl_status_t status;
  do {
    if (dif_flash_ctrl_get_status(flash_state, &status) != kDifOk) {
      return;
    }
  } while (status.controller_init_wip);
}

//-----------------------------------------------------------------------------
// Puppeteer console command table.
//
// Entry order is significant. on_help() prints the entries in this order, and
// find_command() returns the first entry whose name matches the first word of
// the input line. console_main() passes each handler the remainder of the
// line (after the command word) as Args.
Puppeteer::Command Puppeteer::commands[Puppeteer::command_count] = {
    {"help", &Puppeteer::on_help, "Print command help"},
    {"echo", &Puppeteer::on_echo, "Turn console prompt on or off"},
    {"quit", &Puppeteer::on_quit, "Exit console and continue boot"},
    {"peekb", &Puppeteer::on_peekb, "Read byte from address"},
    {"peekd", &Puppeteer::on_peekd, "Read dword from address"},
    {"pokeb", &Puppeteer::on_pokeb, "Write byte to address"},
    {"poked", &Puppeteer::on_poked, "Write dword to address"},
    {"hexb", &Puppeteer::on_hexb, "Dump address range as 8-bit hex values"},
    {"hexd", &Puppeteer::on_hexd, "Dump address range as 32-bit hex values"},
    {"dump", &Puppeteer::on_dump, "Dump address range as base64 blob"},
    {"boot_ottf", &Puppeteer::on_boot_ottf, "Boots OTTF from flash"},
    {"boot_tock", &Puppeteer::on_boot_tock, "Boots Tock from flash"},
    {"write", &Puppeteer::on_write, "Write blob to address"},
    {"hmac", &Puppeteer::on_test_hmac, "Test HMAC hardware"},
};

//-----------------------------------------------------------------------------
// Tiny console argument parser/matcher.
//
// Walks a space-separated, NUL-terminated command line. `cursor` points at
// the current token. Parsing failures set the sticky `error` flag instead of
// aborting, so a handler can parse all of its arguments and check `error`
// once at the end.
struct Args {
  const char* cursor = nullptr;
  bool error = false;

  explicit Args(const char* command) : cursor(command) {}

  // Advances `cursor` past the current token and any run of spaces after it.
  // Returns the new cursor, which points at '\0' when no tokens remain.
  const char* next() {
    if (cursor == nullptr) return nullptr;
    while (*cursor != '\0' && *cursor != ' ') cursor++;
    // Skip every separating space so "peekb  0x10" parses like "peekb 0x10".
    while (*cursor == ' ') cursor++;
    return cursor;
  }

  // Parses the current token as an unsigned 32-bit integer: decimal, or hex
  // with a "0x"/"0X" prefix. On success it advances to the next token and
  // returns the value.
  // Sets `error` and returns 0, leaving the cursor unchanged, if:
  //   - the token has no digits,
  //   - the token contains a non-digit character (e.g. "12z"), or
  //   - the value does not fit in 32 bits.
  uint32_t get_u32() {
    if (cursor == nullptr) {
      error = true;
      return 0;
    }

    const char* a = cursor;
    uint32_t base = 10;
    if (a[0] == '0' && (a[1] == 'x' || a[1] == 'X')) {
      base = 16;
      a += 2;
    }

    constexpr uint32_t kMaxU32 = 0xFFFFFFFFu;
    uint32_t value = 0;
    bool any_digit = false;
    for (;; a++) {
      uint32_t digit;
      if (*a >= '0' && *a <= '9') {
        digit = static_cast<uint32_t>(*a - '0');
      } else if (base == 16 && *a >= 'a' && *a <= 'f') {
        digit = static_cast<uint32_t>(*a - 'a') + 10;
      } else if (base == 16 && *a >= 'A' && *a <= 'F') {
        digit = static_cast<uint32_t>(*a - 'A') + 10;
      } else {
        break;
      }
      // value * base + digit <= kMaxU32  <=>  value <= (kMaxU32 - digit) / base
      if (value > (kMaxU32 - digit) / base) {
        error = true;
        return 0;
      }
      value = value * base + digit;
      any_digit = true;
    }

    // The token must end cleanly; otherwise "0x1g" would silently parse as 1.
    if (!any_digit || (*a != '\0' && *a != ' ')) {
      error = true;
      return 0;
    }
    next();
    return value;
  }

  // Returns true if the current token equals `b`. A token ends at '\0' or
  // ' ' on either side. Does not move the cursor.
  bool match(const char* b) const {
    if (cursor == nullptr || b == nullptr) return false;
    const char* a = cursor;
    while (true) {
      const bool end_a = *a == '\0' || *a == ' ';
      const bool end_b = *b == '\0' || *b == ' ';
      if (end_a && end_b) return true;
      if (*a != *b) return false;
      a++;
      b++;
    }
  }
};

//-----------------------------------------------------------------------------

// Self-test for the HMAC block. Computes the HMAC of "hello world" with a
// fixed 8-word key through TinyHMAC and prints the 8 digest words.
// Always returns 0. No digest is verified here; the printed values must be
// checked by whoever reads the console.
int Puppeteer::on_test_hmac(Args /*args*/) {
  uint32_t key[8] = {
      0xDEADBEEF, 0xDEADBEEF, 0xDEADBEEF, 0xDEADBEEF,
      0xDEADBEEF, 0xDEADBEEF, 0xDEADBEEF, 0xDEADBEEF,
  };

  constexpr int kDigestWords = 8;
  uint32_t digest[kDigestWords];

  uint8_t message[] = {'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'};

  TinyHMAC hmac(io);

  hmac.init(key);
  // Derive the length from the array so it cannot drift from the message.
  hmac.hash(message, sizeof(message));
  hmac.finish(digest);

  for (int i = 0; i < kDigestWords; i++) {
    io.printf("Digest %d %p\n", i, digest[i]);
  }

  io.printf("test_hmac() done\n");

  return 0;
}

//-----------------------------------------------------------------------------
// Prints one "Command '<name>' - '<help>'" line per entry in the command
// table, in table order. Always returns 0.
int Puppeteer::on_help(Args /*args*/) {
  for (const Command& entry : commands) {
    io.printf("Command '%s' - '%s'\n", entry.name, entry.help);
  }
  return 0;
}

//----------
// Turns echoing of received console characters on or off. console_main()
// writes each received character back when `echo` is set.
// Usage: echo on|off. Returns 0 on success, or -1 for any other argument.
int Puppeteer::on_echo(Args args) {
  const bool want_on = args.match("on");
  if (!want_on && !args.match("off")) {
    return -1;
  }
  echo = want_on;
  return 0;
}

//----------
// Exits the bootrom console. It prints a notice and sets `quit`, which ends
// the console_main() loop. Always returns 0.
int Puppeteer::on_quit(Args /*args*/) {
  io.printf("Exiting bootrom console\n");
  this->quit = true;
  return 0;
}

//-----------------------------------------------------------------------------
// Reads one byte from the given address (a single volatile access) and
// prints it.
// Usage: peekb <address>. Returns 0, or -1 if the address is malformed.
int Puppeteer::on_peekb(Args args) {
  const uint32_t addr = args.get_u32();
  if (!args.error) {
    const uint8_t value = *reinterpret_cast<volatile uint8_t*>(addr);
    io.printf("%p\n", value);
    return 0;
  }
  return -1;
}

//----------
// Reads one 32-bit word from the given address (a single volatile access)
// and prints it.
// Usage: peekd <address>. Returns 0, or -1 if the address is malformed.
int Puppeteer::on_peekd(Args args) {
  const uint32_t addr = args.get_u32();
  if (!args.error) {
    const uint32_t value = *reinterpret_cast<volatile uint32_t*>(addr);
    io.printf("%p\n", value);
    return 0;
  }
  return -1;
}

//----------
// Writes the given value as a single byte to the given address, then reads
// the byte back and prints it. Values above 0xFF are truncated to their low
// 8 bits.
// Usage: pokeb <address> <value>. Returns 0, or -1 if an argument is
// malformed.
int Puppeteer::on_pokeb(Args args) {
  const uint32_t addr = args.get_u32();
  const uint32_t value = args.get_u32();
  if (args.error) {
    return -1;
  }

  volatile uint8_t* const target = reinterpret_cast<volatile uint8_t*>(addr);
  *target = static_cast<uint8_t>(value);
  io.printf("%p\n", *target);
  return 0;
}

//----------
// Writes the given value as a single 32-bit word to the given address, then
// reads the word back and prints it.
// Usage: poked <address> <value>. Returns 0, or -1 if an argument is
// malformed.
int Puppeteer::on_poked(Args args) {
  const uint32_t addr = args.get_u32();
  const uint32_t value = args.get_u32();
  if (args.error) {
    return -1;
  }

  volatile uint32_t* const target = reinterpret_cast<volatile uint32_t*>(addr);
  *target = value;
  io.printf("%p\n", *target);
  return 0;
}

// Reads one SPI_PAGE_SIZE (512-byte) page from the device on chip select 0
// of `spi_host` into `buf`. It sends opcode 0x13 followed by a 4-byte
// address.
//
// `page` is sent verbatim as the address; the callers in this file pass byte
// offsets. The transaction status is discarded, so after a failed transfer
// `buf` may hold stale data.
#define SPI_PAGE_SIZE (512)
void spi_read_page(dif_spi_host* spi_host, uint32_t page, uint8_t* buf) {
  // The RX FIFO holds only 256 bytes, so the page is received as two halves.
  constexpr size_t kHalfPage = SPI_PAGE_SIZE / 2;
  uint8_t* const first_half = buf;
  uint8_t* const second_half = buf + kHalfPage;

  /* clang-format off */
  dif_spi_host_segment_t segments[] = {
      {.type = kDifSpiHostSegmentTypeOpcode,
       .opcode = 0x13},  // Page Read, 4b Addr
      {.type = kDifSpiHostSegmentTypeAddress,
       .address = {.width = kDifSpiHostWidthStandard,
                   .mode = kDifSpiHostAddrMode4b,
                   .address = page}},
      {.type = kDifSpiHostSegmentTypeRx,
       .rx = {.width = kDifSpiHostWidthStandard,
              .buf = first_half,
              .length = kHalfPage}},
      {.type = kDifSpiHostSegmentTypeRx,
       .rx = {.width = kDifSpiHostWidthStandard,
              .buf = second_half,
              .length = kHalfPage}},
  };
  /* clang-format on */

  // Keep the result in a named variable. A (void) cast does not silence
  // GCC's warn_unused_result.
  dif_result_t status = dif_spi_host_transaction(
      spi_host, /*csid=*/0, segments, ARRAYSIZE(segments));
  (void)status;
}

// Read callback for find_file_in_tar(). Reads one tar_header from external
// SPI flash at byte `offset` into `buf`, one SPI page at a time, with each
// page stored at its own position in `buf`.
//
// Returns sizeof(tar_header) on success. Returns 0 if:
//   - the header would extend past the end of external flash, or
//   - the header size is not a whole number of SPI pages.
size_t tar_read_data_from_spi(tar_header* buf, size_t offset) {
  // Assumes a 16 MiB (24-bit address space) external flash.
  // NOTE(review): confirm against the actual part.
  constexpr size_t kExtflashSize = size_t{1} << 24;
  constexpr size_t kHeaderSize = sizeof(tar_header);

  if (kHeaderSize % SPI_PAGE_SIZE != 0) {
    return 0;
  }
  // Written as a subtraction so that a huge `offset` cannot overflow the
  // check. A header that ends exactly at kExtflashSize is still accepted.
  if (offset > kExtflashSize || kExtflashSize - offset < kHeaderSize) {
    return 0;
  }

  uint8_t* const dst = reinterpret_cast<uint8_t*>(buf);
  for (size_t done = 0; done < kHeaderSize; done += SPI_PAGE_SIZE) {
    spi_read_page(&spi_host0, offset + done, dst + done);
  }
  return kHeaderSize;
}

// Copies `filename` from the tar archive in external SPI flash into eflash,
// starting at `base_addr`.
//
// The file is transferred in SPI_PAGE_SIZE chunks. Each chunk is written to
// eflash as two EFLASH_PAGE_SIZE pages. The last chunk is always a full SPI
// page, so up to SPI_PAGE_SIZE - 1 bytes that follow the file in external
// flash are also written.
// Returns 1 on success, or 0 if the file is not found or any eflash page
// write fails. Callers must not boot the image after a 0 return, because it
// may be incomplete.
int load_bin_from_tar(const char* filename, uint8_t* base_addr) {
  size_t size = 0;
  // The result is used as the file data's byte offset in external flash;
  // 0 means "not found".
  size_t bin_offset = reinterpret_cast<size_t>(
      find_file_in_tar(filename, tar_read_data_from_spi, &size));
  if (!bin_offset) {
    return 0;
  }

  uint8_t page[SPI_PAGE_SIZE];
  for (size_t offset = 0; offset < size; offset += SPI_PAGE_SIZE) {
    spi_read_page(&spi_host0, bin_offset + offset, page);
    if (!eflash_write_page(base_addr + offset, page) ||
        !eflash_write_page(base_addr + offset + EFLASH_PAGE_SIZE,
                           page + EFLASH_PAGE_SIZE)) {
      return 0;
    }
  }
  return 1;
}

// Boots OTTF from flash. The entry point is read from the image manifest.
int Puppeteer::on_boot_ottf(Args args) {
  constexpr bool kIsTock = false;
  return on_boot_internal(args, kIsTock);
}

// Boots Tock from flash. The entry point is the first word of the image.
int Puppeteer::on_boot_tock(Args args) {
  constexpr bool kIsTock = true;
  return on_boot_internal(args, kIsTock);
}

//-----------------------------------------------------------------------------
// Loads "matcha-tock-bundle.bin" from the tar archive in external SPI flash
// into eflash and jumps to it on the security core.
//
// Steps:
//   1. Initialize flash_ctrl and erase eflash.
//   2. Bring up SPI host 0.
//   3. Copy the file into eflash.
//   4. Locate the entry point.
// When `tock` is set, the first word of eflash is used as the entry address.
// Otherwise the address comes from the manifest at the start of eflash.
// Returns 0 if any step fails. On success it returns only if the booted
// image itself returns.
int Puppeteer::on_boot_internal(Args args, bool tock) {
  uint8_t* kEflashBase =
      reinterpret_cast<uint8_t*>(TOP_MATCHA_EFLASH_BASE_ADDR);
  // Initialize and erase the eflash.
  // NOTE(review): these DIF results are intentionally left unchecked. Verify
  // before checking dif_flash_ctrl_start_controller_init(): it may report an
  // error when controller init was already started (e.g. by an earlier boot
  // command).
  dif_flash_ctrl_init_state(
      &flash_ctrl, mmio_region_from_addr(TOP_MATCHA_FLASH_CTRL_CORE_BASE_ADDR));
  dif_flash_ctrl_start_controller_init(&flash_ctrl);
  flash_ctrl_testutils_wait_for_init(&flash_ctrl);
  dif_flash_ctrl_set_flash_enablement(&flash_ctrl, kDifToggleEnabled);
  dif_flash_ctrl_set_exec_enablement(&flash_ctrl, kDifToggleEnabled);
  if (!eflash_chip_erase()) {
    io.printf("Erasing eflash failed\n");
    io.etx();
    return 0;
  }

  // Initialize the SPI host used to read the external flash. Stop on failure
  // rather than reading garbage from an unconfigured controller.
  // NOTE(review): 25 * 100 * 1000 is 2.5 MHz; confirm it matches the real
  // peripheral clock.
  dif_spi_host_config_t config = {
      .spi_clock = 1000000,
      .peripheral_clock_freq_hz = 25 * 100 * 1000,
  };
  if (dif_spi_host_init(mmio_region_from_addr(TOP_MATCHA_SPI_HOST0_BASE_ADDR),
                        &spi_host0) != kDifOk ||
      dif_spi_host_configure(&spi_host0, config) != kDifOk ||
      dif_spi_host_output_set_enabled(&spi_host0, /*enabled=*/true) !=
          kDifOk) {
    io.printf("Initializing SPI host failed\n");
    io.etx();
    return 0;
  }

  // Copy the bundle from external flash into eflash. The same file name is
  // used for both OTTF and Tock boots.
  const char* kScBinary = "matcha-tock-bundle.bin";
  int ret = load_bin_from_tar(kScBinary, kEflashBase);

  if (!ret) {
    io.printf("Didn't find %s, don't try to start.\n", kScBinary);
    io.etx();
    return 0;
  }

  generic_call entry_point;
  if (tock) {
    // Tock: the first word of eflash holds the entry address.
    entry_point = *reinterpret_cast<generic_call*>(kEflashBase);
  } else {
    entry_point = reinterpret_cast<generic_call>(manifest_entry_point_get(
        reinterpret_cast<const manifest_t*>(kEflashBase)));
  }

  // Execute at the returned entry point. The arguments must follow the order
  // of the conversions: image name for %s, entry address for %p.
  io.printf("Starting %s @ 0x%p on security core\n", tock ? "TockOS" : "OTTF",
            entry_point);
  io.etx();
  entry_point(0, 0, 0, 0, 0, 0, 0, 0);

  return 0;
}

//-----------------------------------------------------------------------------
// Dumps 'count' bytes starting at 'start' as 8-bit hex values, 16 per line.
// It prints the start address and count first.
// Usage: hexb <start> <count>. Returns 0, or -1 if an argument is malformed.
int Puppeteer::on_hexb(Args args) {
  uint8_t* const base = reinterpret_cast<uint8_t*>(args.get_u32());
  const uint32_t total = args.get_u32();
  if (args.error) {
    return -1;
  }

  io.printf("Start %p\n", base);
  io.printf("Count %d\n", total);

  for (uint32_t idx = 0; idx < total; ++idx) {
    const bool starts_new_row = idx != 0 && idx % 16 == 0;
    if (starts_new_row) {
      io.printf("\n");
    }
    io.printf("0x%b ", base[idx]);
  }
  io.printf("\n");
  return 0;
}

//-----------------------------------------------------------------------------
// Dumps 'count' bytes at 'start' as 32-bit hex words, 8 per line. The byte
// count is rounded up to whole words. It prints the start address and count
// first.
// Usage: hexd <start> <count>. Returns 0, or -1 if an argument is malformed.
int Puppeteer::on_hexd(Args args) {
  auto start = reinterpret_cast<uint32_t*>(args.get_u32());
  auto count = args.get_u32();
  if (args.error) return -1;

  io.printf("Start %p\n", start);
  io.printf("Count %d\n", count);

  // Round up without computing count + 3. That sum wraps for counts within 3
  // of UINT32_MAX, and the dump would then print nothing.
  const uint32_t word_count = count / 4 + (count % 4 != 0 ? 1u : 0u);
  for (uint32_t i = 0; i < word_count; i++) {
    if (i && ((i % 8) == 0)) io.printf("\n");
    io.printf("0x%p ", start[i]);
  }
  io.printf("\n");
  return 0;
}

//-----------------------------------------------------------------------------
// Dumps 'count' bytes at 'start' encoded as base64, followed by a newline.
// Characters are emitted as soon as 6 bits are buffered. Any leftover bits
// are flushed as one final character. This function writes no '=' padding.
// Usage: dump <start> <count>. Returns 0, or -1 if an argument is malformed.
int Puppeteer::on_dump(Args args) {
  uint8_t* const first = reinterpret_cast<uint8_t*>(args.get_u32());
  const uint32_t length = args.get_u32();
  if (args.error) {
    return -1;
  }
  uint8_t* const last = first + length;

  Base64 encoder;
  uint8_t* p = first;
  while (p < last) {
    encoder.put_byte(*p);
    ++p;
    while (encoder.count >= 6) {
      io.putc(encoder.get_b64());
    }
  }

  if (encoder.count) {
    io.putc(encoder.get_b64());
  }

  io.printf("\n");

  return 0;
}

//-----------------------------------------------------------------------------
// Writes a base64 blob to 'address'. Issue this command, then paste a base64-
// encoded block of data into the console and press Enter.
//
// Usage: write <address>. Decoded bytes are stored at consecutive addresses
// from 'address' as they arrive. The blob ends at the first CR or LF, so it
// works both with terminals that send only CR on Enter (screen) and with
// those that send LF.
// NOTE(review): when a terminal sends CR+LF after the command line,
// console_main dispatches on the CR and the LF arrives here first. That ends
// the blob immediately with 0 bytes written.
// Returns the number of bytes written, or -1 if the address is malformed.
int Puppeteer::on_write(Args args) {
  auto start = reinterpret_cast<uint8_t*>(args.get_u32());
  if (args.error) return -1;

  io.printf("Writing to %p, send a Base64 string\n", start);
  io.etx();

  Base64 b64;
  auto cursor = start;

  while (true) {
    uint8_t b = io.get();
    // Treat CR as a terminator too. Otherwise it is fed to the decoder and a
    // CR-only terminal can never end the blob.
    if (b == '\n' || b == '\r') {
      break;
    }
    b64.put_b64(b);
    while (b64.count >= 8) {
      *cursor++ = b64.get_byte();
    }
  }

  const int written = static_cast<int>(cursor - start);
  io.printf("Wrote %d bytes to 0x%p\n", written, start);

  return written;
}

// Reports an unrecognized console command. Always returns -1.
int Puppeteer::on_err(char* command) {
  constexpr int kUnknownCommand = -1;
  io.printf("Unknown command '%s'\n", command);
  return kUnknownCommand;
}

//-----------------------------------------------------------------------------

int Puppeteer::find_command(char* console_buf) {
  for (int i = 0; i < command_count; i++) {
    Args args(console_buf);
    if (args.match(commands[i].name)) {
      return i;
    }
  }
  return -1;
}

//-----------------------------------------------------------------------------

void Puppeteer::console_main() {
  io.printf("Puppeteer::console_main()\n");

  uint8_t console_cursor = 0;
  memset(console_buf, 0, sizeof(console_buf));

  io.printf("BOOTROM> ");

  // Screen does not echo, sends \r on enter but no \n
  // Telnet echos locally, buffers until enter, then sends buffer + \r\n.
  // Bootshell sends \n

  uint8_t prev = 0;
  uint8_t next = 0;

  while (!quit) {
    prev = next;
    next = io.get();

    // Ignore LF if it comes immediately after CR
    if (prev == '\r' && next == '\n') continue;

    if (echo) io.putc(next);

    if (next == '\r' || next == '\n') {
      if (console_cursor > 0 && console_cursor < 255) {
        int command_idx = find_command(console_buf);
        if (command_idx != -1) {
          Args args(console_buf);
          args.next();
          (this->*commands[command_idx].handler)(args);
        } else {
          io.printf("Unknown command '%s'\n", console_buf);
        }
      }
      memset(console_buf, 0, sizeof(console_buf));
      io.etx();
      io.printf("BOOTROM> ");
      console_cursor = 0;
    } else {
      if (console_cursor < 255) console_buf[console_cursor] = next;
      console_cursor++;
    }
  }
}

//-----------------------------------------------------------------------------
