blob: bd30f5552aee90632a1d304015cbb0e96ed43997 [file] [log] [blame]
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// NB: override default setup done for DIF code linked into this compartment
#ifdef CHERIOT_NO_AMBIENT_MALLOC
#undef CHERIOT_NO_AMBIENT_MALLOC
#endif
/*
* ML accelerator support for soundstream. There is no support for
* loading a model; that is assumed to be done by the security core, which
* has access to the flash where model files are stored.
*/
#include "ml_top.h"
#include <fail-simulator-on-error.h>
#include <futex.h>
#include <locks.h>
#include <multiwaiter.h>
#include <thread.h>
#include <debug.hh>
#include <platform/sencha/platform-ml_top.hh>
#include "compat.h"
#include "hw/top_matcha/sw/autogen/top_matcha.h"
#include "ml_top_regs.h"
// Byte view of the ML core's data memory (DMEM); used with
// MMIO_CAPABILITY(ml_top_dmem_t, ml_top_dmem) by the input/output helpers.
typedef uint8_t ml_top_dmem_t[TOP_MATCHA_RAM_ML_DMEM_SIZE_BYTES];
// 32-bit-word view of the ML_TOP core register block; used to build the MMIO
// region handed to the DIF in ml_top_init().
typedef uint32_t
ml_top_mmio_t[TOP_MATCHA_ML_TOP_CORE_SIZE_BYTES / sizeof(uint32_t)];
/// Debug logging/asserts for this compartment. The first template argument is
/// `false`, so this output is disabled (presumably compiled out); change it to
/// `true` to enable ML_TOP debugging.
using Debug = ConditionalDebug<false, "ML_TOP">;
// DIF handle for the ML_TOP block; populated by ml_top_init().
static dif_ml_top_t ml_top;
// Guard for entry points that require ml_top_init() to have run. ml_top has
// static storage and so starts zeroed; a zero base address means the DIF
// handle has not been initialised yet.
#define CHECK_INIT() CHECK(ml_top.base_addr.base != 0)
// Posted by ml_top_init() once ml_top is usable and taken by ml_top_isr()
// before it touches the device. NOTE(review): initialiser is presumably
// {initial count = 0, max count = 1} — confirm against CountingSemaphoreState.
static CountingSemaphoreState startup = { 0, 1 };
void ml_top_init(void) {
CHECK_DIF_OK(dif_ml_top_init(mmio_region_from_addr((uintptr_t)MMIO_CAPABILITY(
ml_top_mmio_t, ml_top_core)),
&ml_top));
ml_top_irq_set_enabled(kMlTopIrqFinish, /*enabled=*/true);
ml_top_irq_set_enabled(kMlTopIrqFault, /*enabled=*/true);
semaphore_put(&startup);
}
// NB: could be atomic but not needed for our usage.
volatile bool finish_done = false;
void ml_top_isr(void) {
Timeout t = {0, UnlimitedTimeout};
semaphore_get(&t, &startup);
Debug::log("ml_top_isr: ml_top {} (Thread {})",
ml_top.base_addr.base, thread_id_get());
MultiWaiter* mw;
Timeout unlimited{0, UnlimitedTimeout};
int error = multiwaiter_create(&unlimited, MALLOC_CAPABILITY, &mw, 2);
Debug::Assert(error == 0 && mw != nullptr,
"multiwaiter_create failed: {}", error);
EventWaiterSource events[2];
const uint32_t* mlTopFinishFutex =
interrupt_futex_get(STATIC_SEALED_VALUE(mlTopFinishInterruptCapability));
events[0] = {(void*)mlTopFinishFutex, EventWaiterFutex, *mlTopFinishFutex};
const uint32_t* mlTopFaultFutex =
interrupt_futex_get(STATIC_SEALED_VALUE(mlTopFaultInterruptCapability));
events[1] = {(void*)mlTopFaultFutex, EventWaiterFutex, *mlTopFaultFutex};
for (;;) {
Debug::Assert(multiwaiter_wait(&unlimited, mw, events, 2) == 0, "multiwaiter_wait");
if (events[1].value == 1) { // Fault signaled
Debug::log("ml_top_isr: Fault, Finish:{}", events[0].value);
abort();
events[1].value = *mlTopFaultFutex;
}
if (events[0].value == 1) { // Finish signaled
Debug::log("ml_top_isr: Finish");
finish_done = true;
CHECK_DIF_OK(dif_ml_top_reset_ctrl_en(&ml_top));
CHECK_DIF_OK(dif_ml_top_irq_acknowledge(&ml_top, kMlTopIrqFinish));
interrupt_complete(STATIC_SEALED_VALUE(mlTopFinishInterruptCapability));
events[0].value = *mlTopFinishFutex;
}
}
}
bool ml_top_finish_done(void) {
// return atomic_swap(&finish_done, false);
bool was_done = finish_done;
if (was_done) {
finish_done = false;
}
return was_done;
}
void ml_top_wait_for_finish(void) {
const uint32_t* mlTopFinishFutex =
interrupt_futex_get(STATIC_SEALED_VALUE(mlTopFinishInterruptCapability));
uint32_t last = *mlTopFinishFutex;
while (!ml_top_finish_done()) {
Debug::Assert(futex_wait(mlTopFinishFutex, last) == 0, "futex_wait");
last = *mlTopFinishFutex;
}
}
/**
 * Acknowledge all ML_TOP interrupts through dif_ml_top_irq_acknowledge_all().
 *
 * Requires ml_top_init() to have run (enforced by CHECK_INIT()). A DIF error
 * is reported through CHECK_DIF_OK.
 */
void ml_top_irq_acknowledge_all() {
CHECK_INIT();
CHECK_DIF_OK(dif_ml_top_irq_acknowledge_all(&ml_top));
}
/**
 * Enable or disable a single ML_TOP interrupt.
 *
 * @param irq_id   Interrupt to change (e.g. kMlTopIrqFinish, kMlTopIrqFault).
 * @param enabled  true maps to kDifToggleEnabled, false to kDifToggleDisabled.
 *
 * Requires ml_top_init() to have initialised the DIF handle (CHECK_INIT()).
 * Note that ml_top_init() itself calls this after setting up the handle.
 */
void ml_top_irq_set_enabled(ml_top_irq_t irq_id, bool enabled) {
CHECK_INIT();
CHECK_DIF_OK(dif_ml_top_irq_set_enabled(
&ml_top, irq_id, enabled ? kDifToggleEnabled : kDifToggleDisabled));
}
/**
 * Forward `resume_pc` to dif_ml_top_resume_ctrl_en().
 *
 * NOTE(review): presumably restarts the ML core at `resume_pc`, e.g. the
 * resume_pc reported in the output header read by ml_top_get_output_header().
 * Confirm against the DIF.
 *
 * Requires ml_top_init() to have run (CHECK_INIT()).
 */
void ml_top_resume_ctrl_en(uint32_t resume_pc) {
CHECK_INIT();
CHECK_DIF_OK(dif_ml_top_resume_ctrl_en(&ml_top, resume_pc));
}
void ml_top_set_input(void* const data, size_t data_len_bytes) {
uint8_t* ml_top_dmem_base = (uint8_t*) MMIO_CAPABILITY(ml_top_dmem_t, ml_top_dmem);
void* input_ptr = ml_top_dmem_base + (TOP_MATCHA_RAM_ML_DMEM_SIZE_BYTES - 4096);
memcpy(input_ptr, data, data_len_bytes);
}
void ml_top_get_output_header(struct output_header* header) {
const uint8_t* ml_top_dmem_base = (uint8_t*) MMIO_CAPABILITY(ml_top_dmem_t, ml_top_dmem);
const struct output_header* output_header_ptr = (const struct output_header*)
(ml_top_dmem_base + (TOP_MATCHA_RAM_ML_DMEM_SIZE_BYTES - 0x40));
header->return_code = output_header_ptr->return_code;
header->output_ptr = output_header_ptr->output_ptr;
header->length = output_header_ptr->length;
header->resume_pc = output_header_ptr->resume_pc;
#if 0
Debug::log("return_code: {}", header.return_code);
Debug::log("output_ptr: {}", header.output_ptr);
Debug::log("length: {}", header.length);
Debug::log("resume_pc: {}", header.resume_pc);
#endif
}
void ml_top_get_output_data(struct output_header* const header, void* buffer) {
const uint8_t* ml_top_dmem_base = (const uint8_t*) MMIO_CAPABILITY(ml_top_dmem_t, ml_top_dmem);
memcpy(buffer, ml_top_dmem_base + header->output_ptr, header->length);
}