/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

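// NB: override default setup done for DIF code linked into this compartment:
// re-enable the ambient malloc capability (MALLOC_CAPABILITY is used by
// ml_top_isr() to create its multiwaiter).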
#ifdef CHERIOT_NO_AMBIENT_MALLOC
#undef CHERIOT_NO_AMBIENT_MALLOC
#endif

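/*
 * ML accelerator support for soundstream. Model loading is normally done by
 * the security core, which has access to the flash where model files are
 * placed. When an SPI host is present, ml_top_load_file_from_tar() can instead
 * load a file from a tar archive into ML DMEM; otherwise files are expected to
 * be side-loaded.
 */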
#include "ml_top.h"

#include <fail-simulator-on-error.h>
#include <futex.h>
#include <locks.h>
#include <multiwaiter.h>
#include <thread.h>

#include <debug.hh>
#include <platform/sencha/platform-ml_top.hh>

#include "compat.h"
#include "hw/top_matcha/sw/autogen/top_matcha.h"
#include "ml_top_regs.h"
#if DEVICE_EXISTS_spi_host
#include "spi.h"
#endif

// Byte-array view of the ML core's data memory (DMEM); used as the bounded
// type for the ml_top_dmem MMIO capability.
typedef uint8_t ml_top_dmem_t[TOP_MATCHA_RAM_ML_DMEM_SIZE_BYTES];
// Word-array view of the ML top core's MMIO register window; used as the
// bounded type for the ml_top_core MMIO capability.
typedef uint32_t
    ml_top_mmio_t[TOP_MATCHA_ML_TOP_CORE_SIZE_BYTES / sizeof(uint32_t)];

/// Debug helper for this compartment. The first template argument is the
/// enable flag; it is `false` here, so logging and assertions are compiled
/// out (set it to `true` to enable them).
using Debug = ConditionalDebug<false, "ML_TOP">;

/// DIF handle for the ML top block; populated by ml_top_init().
static dif_ml_top_t ml_top;
/// Fails the CHECK unless ml_top_init() has set the handle's base address.
#define CHECK_INIT() CHECK(ml_top.base_addr.base != 0)

/// Released once by ml_top_init(); ml_top_isr() blocks on it so it does not
/// service interrupts before the DIF handle is initialized. Initial count 0.
static CountingSemaphoreState startup = { 0, 1 };

/**
 * Initialize the ML top driver.
 *
 * Brings up SPI (when present), binds the DIF handle to the ml_top_core MMIO
 * window, enables the finish and fault interrupts, and finally releases
 * ml_top_isr(). Must run before any function that uses CHECK_INIT().
 */
void ml_top_init(void) {
#if DEVICE_EXISTS_spi_host
  // SPI is used in this file by ml_top_load_file_from_tar().
  spi_init();
#endif
  const auto regs = mmio_region_from_addr(
      (uintptr_t)MMIO_CAPABILITY(ml_top_mmio_t, ml_top_core));
  CHECK_DIF_OK(dif_ml_top_init(regs, &ml_top));

  ml_top_irq_set_enabled(kMlTopIrqFinish, /*enabled=*/true);
  ml_top_irq_set_enabled(kMlTopIrqFault, /*enabled=*/true);

  // The handle is valid now; let the ISR thread start waiting on interrupts.
  semaphore_put(&startup);
}

// Completion flag: set by ml_top_isr() when it handles a finish interrupt, and
// tested-and-cleared by ml_top_finish_done().
// NB: could be atomic but not needed for our usage. Note that the
// test-and-clear in ml_top_finish_done() is two separate accesses, not an
// atomic swap.
volatile bool finish_done = false;

void ml_top_isr(void) {
  Timeout t = {0, UnlimitedTimeout};
  semaphore_get(&t, &startup);
  Debug::log("ml_top_isr: ml_top {} (Thread {})",
             ml_top.base_addr.base, thread_id_get());

  MultiWaiter* mw;
  Timeout unlimited{0, UnlimitedTimeout};
  int error = multiwaiter_create(&unlimited, MALLOC_CAPABILITY, &mw, 2);
  Debug::Assert(error == 0 && mw != nullptr,
                "multiwaiter_create failed: {}", error);

  EventWaiterSource events[2];
  const uint32_t* mlTopFinishFutex =
      interrupt_futex_get(STATIC_SEALED_VALUE(mlTopFinishInterruptCapability));
  events[0] = {(void*)mlTopFinishFutex, *mlTopFinishFutex};
  const uint32_t* mlTopFaultFutex =
      interrupt_futex_get(STATIC_SEALED_VALUE(mlTopFaultInterruptCapability));
  events[1] = {(void*)mlTopFaultFutex, *mlTopFaultFutex};

  for (;;) {
    Debug::Assert(multiwaiter_wait(&unlimited, mw, events, 2) == 0, "multiwaiter_wait");
    if (events[1].value == 1) {  // Fault signaled
      Debug::log("ml_top_isr: Fault, Finish:{}", events[0].value);
      abort();

      events[1].value = *mlTopFaultFutex;
    }
    if (events[0].value == 1) {  // Finish signaled
      Debug::log("ml_top_isr: Finish");
      finish_done = true;
      CHECK_DIF_OK(dif_ml_top_reset_ctrl_en(&ml_top));
      CHECK_DIF_OK(dif_ml_top_irq_acknowledge(&ml_top, kMlTopIrqFinish));
      interrupt_complete(STATIC_SEALED_VALUE(mlTopFinishInterruptCapability));

      events[0].value = *mlTopFinishFutex;
    }
  }
}
/**
 * Test-and-clear the finish flag published by ml_top_isr().
 *
 * @return true if a finish was recorded since the last call (the flag is
 *         cleared in that case); false otherwise.
 */
bool ml_top_finish_done(void) {
  // NB: not an atomic swap; the load and the clear are separate accesses.
  if (!finish_done) {
    return false;
  }
  finish_done = false;
  return true;
}

void ml_top_wait_for_finish(void) {
  const uint32_t* mlTopFinishFutex =
      interrupt_futex_get(STATIC_SEALED_VALUE(mlTopFinishInterruptCapability));
  uint32_t last = *mlTopFinishFutex;
  while (!ml_top_finish_done()) {
    Debug::Assert(futex_wait(mlTopFinishFutex, last) == 0, "futex_wait");
    last = *mlTopFinishFutex;
  }
}

/// Acknowledge every pending ML top interrupt. Requires ml_top_init().
void ml_top_irq_acknowledge_all() {
  CHECK_INIT();
  auto* const handle = &ml_top;
  CHECK_DIF_OK(dif_ml_top_irq_acknowledge_all(handle));
}

/**
 * Enable or disable one ML top interrupt. Requires ml_top_init().
 *
 * @param irq_id   Interrupt to configure.
 * @param enabled  true to enable, false to disable.
 */
void ml_top_irq_set_enabled(ml_top_irq_t irq_id, bool enabled) {
  CHECK_INIT();
  const auto toggle = enabled ? kDifToggleEnabled : kDifToggleDisabled;
  CHECK_DIF_OK(dif_ml_top_irq_set_enabled(&ml_top, irq_id, toggle));
}

/**
 * Forward `resume_pc` to dif_ml_top_resume_ctrl_en(). Requires ml_top_init().
 *
 * @param resume_pc  Program counter value passed through to the DIF call.
 */
void ml_top_resume_ctrl_en(uint32_t resume_pc) {
  CHECK_INIT();
  auto* const handle = &ml_top;
  CHECK_DIF_OK(dif_ml_top_resume_ctrl_en(handle, resume_pc));
}

void ml_top_set_input(void* const data, size_t data_len_bytes) {
  uint8_t* ml_top_dmem_base = (uint8_t*)
      MMIO_CAPABILITY(ml_top_dmem_t, ml_top_dmem);
  void* input_ptr = ml_top_dmem_base + (TOP_MATCHA_RAM_ML_DMEM_SIZE_BYTES - 4096);
  memcpy(input_ptr, data, data_len_bytes);
}

/**
 * Copy the output header stored 0x40 bytes before the end of ML DMEM into
 * `*header` (return_code, output_ptr, length, resume_pc).
 *
 * @param header  Destination; must be non-null.
 */
void ml_top_get_output_header(struct output_header* header) {
  // Distance of the output header from the end of DMEM.
  constexpr size_t kOutputHeaderOffsetFromEnd = 0x40;
  const uint8_t* ml_top_dmem_base = (const uint8_t*)
      MMIO_CAPABILITY(ml_top_dmem_t, ml_top_dmem);
  const struct output_header* output_header_ptr =
      (const struct output_header*)(ml_top_dmem_base +
                                    (TOP_MATCHA_RAM_ML_DMEM_SIZE_BYTES -
                                     kOutputHeaderOffsetFromEnd));
  header->return_code = output_header_ptr->return_code;
  header->output_ptr = output_header_ptr->output_ptr;
  header->length = output_header_ptr->length;
  header->resume_pc = output_header_ptr->resume_pc;
#if 0
  // `header` is a pointer, so its fields are accessed with `->`.
  Debug::log("return_code: {}", header->return_code);
  Debug::log("output_ptr: {}", header->output_ptr);
  Debug::log("length: {}", header->length);
  Debug::log("resume_pc: {}", header->resume_pc);
#endif
}

void ml_top_get_output_data(struct output_header* const header, void* buffer) {
  const uint8_t* ml_top_dmem_base = (const uint8_t*)
      MMIO_CAPABILITY(ml_top_dmem_t, ml_top_dmem);
  memcpy(buffer, ml_top_dmem_base + header->output_ptr, header->length);
}

/**
 * Load `filename` from a tar archive over SPI into ML DMEM. Requires
 * ml_top_init().
 *
 * Without an SPI host this only logs that the file must be side-loaded.
 * The third argument to spi_load_file_from_tar() is the physical end address
 * of DMEM (base + size); presumably used as a bound — confirm.
 */
void ml_top_load_file_from_tar(const char* filename) {
  CHECK_INIT();
#if !DEVICE_EXISTS_spi_host
  Debug::log("No SPI support, expecting {} to be side-loaded", filename);
#else
  uint8_t* const dmem = (uint8_t*)MMIO_CAPABILITY(ml_top_dmem_t, ml_top_dmem);
  const auto dmem_end =
      TOP_MATCHA_ML_TOP_DMEM_BASE_ADDR + TOP_MATCHA_RAM_ML_DMEM_SIZE_BYTES;
  spi_load_file_from_tar(filename, dmem, dmem_end);
#endif
}
