// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Shodan ELF loader capsule

use core::cell::Cell;
use core::marker::PhantomData;
use kernel::common::cells::{OptionalCell, TakeCell};
use kernel::hil;
use kernel::hil::flash::Flash;
use kernel::{AppId, AppSlice, Callback, Driver, ReturnCode, Shared};
use matcha_hal::dprintf;
use matcha_hal::mailbox_hal::MailboxHAL;
use matcha_hal::smc_ctrl_hal;
use matcha_utils::elf_loader::{Elf32Header, Elf32Phdr};
use matcha_utils::smc_ram_memcpy;
use matcha_utils::tar_loader::TarHeader;

// ELF program header |p_type| value identifying a loadable segment; only
// headers of this type are recorded by the loader.
const PT_LOAD: u32 = 1;  // Loadable segment
// Sentinel stored in |SEL4State::load_offset| when the requested file was
// not found in flash. A found file always records |cursor + page_len|.
const INVALID_LOAD_OFFSET: u32 = 0;

// Capsule that walks tar headers stored in flash to locate ELF images,
// copies their loadable segments into SMC memory and finally starts the
// SMC. All work is driven by flash |read_complete| callbacks dispatched on
// |state|.
pub struct ElfLoaderCapsule<'a, F: hil::flash::Flash + 'static> {
    // Mailbox used to deliver the seL4 boot message (see |set_mailbox|).
    mailbox_hal: Option<&'static dyn MailboxHAL>,
    // Control interface used to start the SMC (see |set_smc_ctrl|).
    smc_ctrl_hal: Option<&'static dyn smc_ctrl_hal::SmcCtrlHal>,
    // Flash device the images are read from, one page per request.
    flash: Option<&'static capsules::virtual_flash::FlashUser<'static, F>>,
    // Set when a page read is issued and cleared in |read_complete|.
    flash_busy: Cell<bool>,
    // Page buffer handed to the flash driver for each read; between reads it
    // holds the most recently read page.
    read_page: TakeCell<'static, F::Page>,
    // Length in bytes of the |read_page| buffer, captured by |set_flash|.
    page_len: u32,
    // Current step of the load state machine.
    state: Cell<ElfLoaderState>,
    // Carries the otherwise unused lifetime parameter 'a.
    phantom: PhantomData<&'a ()>,
    // Task table of (task, done) pairs; tasks run in array order.
    tasks: Cell<[(LoadTasks, bool); 3]>,
    // Task currently being executed, if any.
    current_task: OptionalCell<LoadTasks>,
    // Bookkeeping for the file being loaded plus values saved per image.
    sel4_state: Cell<SEL4State>,
}

// Units of work executed in sequence by |run_next_task|.
#[derive(Clone, Copy, PartialEq)]
enum LoadTasks {
    // Placeholder used to fill the initial (already done) task table; it is
    // never expected to be run.
    Noop,
    // Locate the named file in flash and load it as an ELF image.
    LoadElf(&'static str /* name */),
    // Send the seL4 boot message (if capdl-loader was loaded) and start the
    // SMC.
    StartSmc,
}

// Loader bookkeeping. The first group of fields describes the file that is
// currently being loaded and is reused for every file; the remaining fields
// keep values derived from each image after its headers are clobbered.
#[derive(Clone, Copy, Default)]
struct SEL4State {
    // State for current file being loaded.
    load_offset: u32, // File offset
    // Buffering used when a PHDR spans flash pages
    // (number of bytes already buffered, buffer).
    partial_phdr: Option<(usize, [u8; core::mem::size_of::<Elf32Phdr>()])>,
    // NB: 30 is enough for cheriot-rtos' test suite
    // NB: only PT_LOAD headers are stored; slots fill from index 0 and the
    //   first None marks the end of the list.
    phdrs: [Option<Elf32Phdr>; 30], // ELF phdrs

    // Saved state for the "kernel" image.
    kernel_entry_point: u32,
    kernel_phys_max: u32,

    // Saved state for the "capdl-loader" image.
    capdl_loader_entry_point: u32,
    capdl_phys_min: u32,
    capdl_phys_max: u32,
}
impl SEL4State {
    fn reset_phdrs(&mut self) {
        self.partial_phdr = None;
        self.phdrs = SEL4State::default().phdrs;
    }
    fn add_phdr(&mut self, phdr: Elf32Phdr) {
        let mut next_slot: Option<usize> = None;
        for i in 0..self.phdrs.len() {
            if self.phdrs[i] == None {
                next_slot = Some(i);
                break;
            }
        }
//dprintf!("[{:?}] {:#08x?}\r\n", next_slot, phdr);
        if let Some(slot) = next_slot {
            self.phdrs[slot] = Some(phdr);
        } else {
            panic!("Too many ELF program headers");
        }
    }
}

// Steps of the load state machine. |read_complete| dispatches on the
// current value once a flash page read finishes. Every "cursor" is a flash
// byte address; |read_page| converts it to a page number.
#[derive(Clone, Copy, PartialEq)]
enum ElfLoaderState {
    // No read outstanding.
    Idle,
    // Striding through tar headers looking for the named file.
    FindingFile(&'static str /* name */, u32 /* cursor */),
    // Reading the page holding the ELF header of the current file.
    LoadElfHeader(u32 /* cursor */),
    // Reading program headers: |offset| is the byte offset within the page
    // at |cursor| of the header to examine next, |num| the number of
    // headers in the table and |index| the number of the header to examine.
    LoadProgramHeaderTable(
        u32, /* cursor */
        usize, /* offset */
        u16, /* num */
        u16, /* index */
    ),
    // Copying one loadable segment: |index| selects the entry in
    // |SEL4State::phdrs|, |dst offset| is added to the segment's p_paddr,
    // |offset_in_page| is where the segment's data starts in the page at
    // |cursor| and |already loaded| counts the bytes copied so far.
    LoadSegmentsNew(
        Elf32Phdr,    /* phdr */
        usize,        /* index */
        u32,          /* cursor */
        u32,          /* dst offset */
        u32,          /* offset_in_page */
        u32,          /* already loaded */
    ),
}

// Logs where a program segment is being loaded.
//
// Prints one "seg" line for the file-backed bytes (when p_filesz > 0) and
// one "bss" line for the zero-filled tail (when p_memsz > p_filesz).
//   index:  slot of |phdr| in the recorded program headers
//   cursor: flash address the segment data is read from
//   offset: relocation added to p_paddr to form the destination address
//
// A header with p_memsz < p_filesz is treated as having no bss (matching
// the guard in |load_segment_callback|) instead of underflowing.
fn print_segment(index: usize, cursor: u32, offset: u32, phdr: &Elf32Phdr) {
    // Copy the sizes into locals so only plain values reach the formatter.
    let file_size = phdr.p_filesz;
    let mem_size = phdr.p_memsz;

    if file_size > 0 {
        dprintf!(
            "seg {:2}:{:X} -> {:X}:{:X} ({} bytes)\r\n",
            index,
            cursor,
            phdr.p_paddr + offset,
            phdr.p_paddr + offset + file_size,
            file_size,
        );
    }

    let bss_size = mem_size.saturating_sub(file_size);
    if bss_size > 0 {
        let bss_start = phdr.p_paddr + offset + file_size;
        let bss_end = bss_start + bss_size;

        dprintf!(
            "bss {:2}:{:X} -> {:X}:{:X} ({} bytes)\r\n",
            index,
            cursor,
            bss_start,
            bss_end,
            bss_size
        );
    }
}

impl<'a, F: hil::flash::Flash> ElfLoaderCapsule<'a, F> {
    // Creates a capsule with no hardware attached. |set_flash|,
    // |set_mailbox| and |set_smc_ctrl| must supply the peripherals before
    // a boot is requested. The task table starts out with every entry
    // already marked done.
    pub fn new() -> Self {
        // Every slot is a finished no-op until |load_sel4| installs real work.
        let no_pending_tasks = [(LoadTasks::Noop, true); 3];
        Self {
            // Peripherals, attached later by the setters.
            mailbox_hal: None,
            smc_ctrl_hal: None,
            flash: None,
            // Flash read bookkeeping.
            flash_busy: Cell::new(false),
            read_page: TakeCell::empty(),
            page_len: 0,
            // State machine and task sequencing.
            state: Cell::new(ElfLoaderState::Idle),
            tasks: Cell::new(no_pending_tasks),
            current_task: OptionalCell::empty(),
            sel4_state: Cell::new(SEL4State::default()),
            phantom: PhantomData,
        }
    }

    // Attaches the SMC control interface. |run_next_task| asserts that it
    // is present when the StartSmc task runs.
    pub fn set_smc_ctrl(&mut self, smc_ctrl: &'static dyn smc_ctrl_hal::SmcCtrlHal) {
        self.smc_ctrl_hal = Some(smc_ctrl);
    }

    // Attaches the mailbox used to send the seL4 boot message (and, in
    // debug builds, passed to |matcha_utils::load_sel4|).
    pub fn set_mailbox(&mut self, mailbox: &'static dyn MailboxHAL) {
        self.mailbox_hal = Some(mailbox);
    }

    // Attaches the flash device together with the page buffer used for all
    // reads, and records the buffer length as the page size.
    pub fn set_flash(
        &mut self,
        flash: &'static capsules::virtual_flash::FlashUser<'static, F>,
        read_page: &'static mut F::Page,
    ) {
        self.flash = Some(flash);
        self.read_page.replace(read_page);
        // Measure the buffer now that it lives in the TakeCell.
        self.page_len = self
            .read_page
            .map(|page| page.as_mut().len() as u32)
            .unwrap_or(0);
    }

    // Claims the first task in the table that is not yet marked done,
    // marks it done and executes it. LoadElf starts a file lookup; StartSmc
    // boots the SMC. When no task is pending this logs and spins forever.
    pub fn run_next_task(&self) {
        self.current_task.clear();

        // Work on a copy of the table and write it back afterwards.
        let mut task_table = self.tasks.get();
        if let Some(entry) = task_table.iter_mut().find(|entry| !entry.1) {
            entry.1 = true;
            self.current_task.replace(entry.0);
        }
        self.tasks.replace(task_table);

        self.current_task.map_or_else(
            || {
                dprintf!("No current task\r\n");
                loop {}
            },
            |task| match task {
                LoadTasks::LoadElf(file) => self.find_file(file),
                LoadTasks::StartSmc => self.start_smc(),
                LoadTasks::Noop => unreachable!(),
            },
        );
    }

    // Helper for |run_next_task|: sends the seL4 boot message when
    // capdl-loader was loaded, then starts the SMC.
    fn start_smc(&self) {
        let sel4_state = self.sel4_state.get();
        if sel4_state.capdl_phys_max > 0 {
            // If capdl-loader was loaded we are running seL4 and it will
            // expect a boot message to be waiting in the mailbox FIFO
            // telling it how to setup the kernel and the rootserver.

            // NB: ui_p_reg_start & co. come from the code in seL4 that
            //   processes these values.
            // NOTE(review): the subtractions below assume capdl_phys_min
            //   does not exceed the values it is subtracted from; confirm.
            let ui_p_reg_start = matcha_utils::round_up_to_page(sel4_state.kernel_phys_max);
            let capdl_span = sel4_state.capdl_phys_max - sel4_state.capdl_phys_min;
            let ui_p_reg_end = ui_p_reg_start + matcha_utils::round_up_to_page(capdl_span);
            let pv_offset = ui_p_reg_start - sel4_state.capdl_phys_min;
            let v_entry = sel4_state.capdl_loader_entry_point;

            // Debug aid (disabled):
            // dprintf!(
            //     "StartSmc: ui_p_reg {:X}:{:X} pv_offset {:X} v_entry {:X}\r\n",
            //     ui_p_reg_start,
            //     ui_p_reg_end,
            //     pv_offset,
            //     v_entry
            // );

            // XXX suppress if not seL4?
            let boot_msg = [ui_p_reg_start, ui_p_reg_end, pv_offset, v_entry];
            self.mailbox_hal
                .map(|mb| matcha_utils::smc_send_bootmsg(mb, boot_msg));
        }

        assert!(self.smc_ctrl_hal.is_some());
        self.smc_ctrl_hal.unwrap().smc_ctrl_start();
    }

    // Starts the seL4 boot sequence: load "kernel", load "capdl-loader",
    // then start the SMC. Installs the task table and runs its first entry;
    // later entries are run as each task completes.
    pub fn load_sel4(&self) {
        let boot_sequence = [
            (LoadTasks::LoadElf("kernel"), false),
            (LoadTasks::LoadElf("capdl-loader"), false),
            (LoadTasks::StartSmc, false),
        ];
        self.tasks.replace(boot_sequence);
        self.run_next_task();
    }

    // Reads the data at flash address |page|. |read_complete| (below)
    // dispatchs the callback depending on the current ElfLoaderState.
    fn read_page(&self, page: u32) {
        self.flash.map(|flash| {
            self.read_page.take().map(|read_page| {
                self.flash_busy.set(true);
                if let Err((_, buf)) = flash.read_page((page / self.page_len) as usize, read_page) {
                    self.read_page.replace(buf);
                }
            });
        });
    }

    // Initiates a lookup of the file |name|, starting with the header at
    // the very beginning of flash.
    fn find_file(&self, name: &'static str) {
        let first_header = 0;
        self.state
            .set(ElfLoaderState::FindingFile(name, first_header));
        self.read_page(first_header);
    }

    // Callback from |read_complete| for a file lookup. The page just read
    // is interpreted as a tar header:
    //  - name matches: record where the file data starts (the page after
    //    the header) and begin loading the ELF headers;
    //  - no tar magic: the search failed; log it and run the next task;
    //  - otherwise: skip the header plus the file data (rounded up to 512
    //    bytes) and read the next header.
    fn find_file_callback(&self) {
        // Stores the data offset of the current file (or the sentinel).
        let record_load_offset = |load_offset: u32| {
            let mut sel4_state = self.sel4_state.get();
            sel4_state.load_offset = load_offset;
            self.sel4_state.replace(sel4_state);
        };

        self.read_page.map(|page| {
            let header_bytes = page.as_mut();
            let tar_header = TarHeader::from_bytes(header_bytes);
            let (name, cursor) = match self.state.get() {
                ElfLoaderState::FindingFile(name, cursor) => (name, cursor),
                _ => unreachable!(),
            };

            if tar_header.name().contains(name) {
                record_load_offset(cursor + self.page_len);
                self.state.set(ElfLoaderState::Idle);
            } else if !tar_header.has_magic() {
                // File not found.
                dprintf!("{} not found!\r\n", name);
                record_load_offset(INVALID_LOAD_OFFSET);
                self.state.set(ElfLoaderState::Idle);
            } else {
                let new_cursor = self.page_len + cursor + ((tar_header.size() + 511) & !511);
                self.state
                    .set(ElfLoaderState::FindingFile(name, new_cursor));
            }
        });

        // Act on the decision made above now that the page is released.
        match self.state.get() {
            ElfLoaderState::FindingFile(_, cursor) => self.read_page(cursor),
            ElfLoaderState::Idle
                if self.sel4_state.get().load_offset == INVALID_LOAD_OFFSET =>
            {
                self.run_next_task()
            }
            ElfLoaderState::Idle => self.load_elf_headers(),
            _ => panic!("507"),
        }
    }

    // Initiates reading the ELF header of the current file (the one whose
    // data offset was recorded by the file lookup).
    // NB: this clobbers program headers so any state derived from a file's
    //   headers must be recorded before the next file is loaded.
    fn load_elf_headers(&self) {
        let mut sel4_state = self.sel4_state.get();
        let cursor = sel4_state.load_offset;
        assert!(cursor != INVALID_LOAD_OFFSET);

        // Reset shared phdr state before the new file's headers arrive.
        sel4_state.reset_phdrs();
        self.sel4_state.replace(sel4_state);

        self.state.set(ElfLoaderState::LoadElfHeader(cursor));
        self.read_page(cursor);
    }

    // Callback from |read_complete| after reading the ELF header. Saves
    // the ELF entry point for the current task's image and initiates
    // reading the ELF program headers.
    //
    // The program header table offset (phoff) is split into the flash page
    // that contains it and the offset within that page, the same way
    // |load_elf_segments| treats p_offset, so a table that does not start
    // in the file's first page is still read from the right place.
    // Panics if the page does not start with the ELF magic.
    fn load_elf_header_callback(&self) {
        self.read_page.map(|page| {
            let mut_page = page.as_mut();
            let elf_header = Elf32Header::from_bytes(mut_page);
            assert!(elf_header.check_magic());
            self.current_task.map(|task| {
                let mut sel4_state = self.sel4_state.get();
                match task {
                    LoadTasks::LoadElf("kernel") => {
                        sel4_state.kernel_entry_point = elf_header.e_entry;
                    }
                    LoadTasks::LoadElf("capdl-loader") => {
                        sel4_state.capdl_loader_entry_point = elf_header.e_entry;
                    }
                    _ => unreachable!(),
                };
                self.sel4_state.replace(sel4_state);
            });
            match self.state.get() {
                ElfLoaderState::LoadElfHeader(cursor) => {
                    // Kick off reading the ELF program headers.
                    let page_len = self.page_len as usize;
                    let phoff = elf_header.phoff() as usize;
                    // Whole pages between the start of the file and the table.
                    let table_cursor = cursor + ((phoff / page_len) * page_len) as u32;
                    self.state.set(ElfLoaderState::LoadProgramHeaderTable(
                        table_cursor,
                        phoff % page_len,
                        elf_header.phnum(),
                        0,
                    ));
                }
                _ => unreachable!(),
            }
        });
        match self.state.get() {
            ElfLoaderState::LoadProgramHeaderTable(cursor, ..) => self.read_page(cursor),
            _ => unreachable!(),
        }
    }

    // Callback from |read_complete| after reading a page of the ELF
    // program header table. Examines the header at the current offset,
    // records it if it is PT_LOAD, and then either
    //  - reads the next page to complete a header that spans two pages,
    //  - advances to the next header (re-reading the same page or moving
    //    to the next one), or
    //  - once all headers have been examined, records per-file data based
    //    on the full set of headers (that may be clobbered if another file
    //    is read) and initiates loading the ELF segments.
    //
    // The header index and page offset are tracked in locals shared with
    // the page closure so that the extra header consumed when a spanning
    // header is completed is reflected in the next state. Re-reading the
    // index from |self.state| afterwards would lag by one header and walk
    // one entry past the end of the table.
    fn load_program_header_table_callback(&self) {
        let (cursor, mut page_offset, count, mut index) = match self.state.get() {
            ElfLoaderState::LoadProgramHeaderTable(cursor, page_offset, count, index) => {
                (cursor, page_offset, count, index)
            }
            _ => unreachable!(),
        };
        let mut have_partial_phdr: bool = false;

        self.read_page.map(|page| {
            let mut_page = page.as_mut();
            let mut sel4_state = self.sel4_state.get();
            if let Some((part_len, mut part)) = sel4_state.partial_phdr {
                // There is a partial PHDR, append the remainder and process.
                // NB: PHDR is a fixed size that cannot span >2 flash pages
                //   so there must be room for the 2nd part
                let remaining_bytes = core::mem::size_of::<Elf32Phdr>() - part_len;
                let buffer = &mut part[..];
                assert!(page_offset == 0);
                buffer[part_len..].copy_from_slice(&mut_page[..remaining_bytes]);
                sel4_state.add_phdr(Elf32Phdr::from_bytes(buffer));
                sel4_state.partial_phdr = None;

                // Advance to next PHDR.
                page_offset += remaining_bytes;
                index += 1;
            }
            if index < count {
                let page_bytes = &mut_page[page_offset..];
                // NB: assume word-alignment so it's safe to read the p_type field
                if Elf32Phdr::get_type(page_bytes) == PT_LOAD {
                    if page_bytes.len() < core::mem::size_of::<Elf32Phdr>() {
                        // This PHDR spans flash pages; stash what we have and arrange
                        // for the next flash page to be read without advancing the PHDR
                        // state. Note the logic below depends on there not being
                        // back-to-back spanning headers (which cannot happen since a
                        // header is 32 bytes and a page is 512 bytes).
                        let mut buffer = [0u8; core::mem::size_of::<Elf32Phdr>()];
                        buffer[..page_bytes.len()].copy_from_slice(page_bytes);
                        sel4_state.partial_phdr = Some((page_bytes.len(), buffer));
                        have_partial_phdr = true;
                    } else {
                        // We have the entire PHDR; process without using the copy buffer.
                        sel4_state.add_phdr(Elf32Phdr::from_bytes(page_bytes));
                    }
                }
            }
            self.sel4_state.replace(sel4_state);
        });

        if have_partial_phdr {
            // We just recorded a partial PHDR; read the next page without
            // advancing the PHDR index.
            let new_cursor = cursor + self.page_len;
            self.state.set(ElfLoaderState::LoadProgramHeaderTable(
                new_cursor,
                0,
                count,
                index,
            ));
            self.read_page(new_cursor);
        } else if index + 1 < count {
            // We just examined a complete PHDR; advance to the next PHDR
            // and read the next page required.
            let mut new_page_offset = page_offset + core::mem::size_of::<Elf32Phdr>();
            let new_cursor = if new_page_offset >= self.page_len as usize {
                new_page_offset %= self.page_len as usize;
                cursor + self.page_len
            } else {
                // NB: this reads the same flash page for multiple PHDRs
                cursor
            };
            self.state.set(ElfLoaderState::LoadProgramHeaderTable(
                new_cursor,
                new_page_offset,
                count,
                index + 1,
            ));
            self.read_page(new_cursor);
        } else {
            // Record state derived from phdrs before they get clobbered.
            let mut segment_load_offset = 0;
            self.current_task.map(|task| {
                let mut sel4_state = self.sel4_state.get();
                match task {
                    LoadTasks::LoadElf("kernel") => {
                        // Stash end of kernel for loading capdl-loader
                        sel4_state.kernel_phys_max =
                            matcha_utils::elf_loader::elf_phys_max_opt(&sel4_state.phdrs);

                        dprintf!("Loading seL4 kernel\r\n");
                        segment_load_offset = 0;
                    }
                    LoadTasks::LoadElf("capdl-loader") => {
                        sel4_state.capdl_phys_max =
                            matcha_utils::elf_loader::elf_phys_max_opt(&sel4_state.phdrs);
                        sel4_state.capdl_phys_min =
                            matcha_utils::elf_loader::elf_phys_min_opt(&sel4_state.phdrs);

                        dprintf!("Loading capdl-loader to the page after seL4\r\n");
                        segment_load_offset =
                            matcha_utils::round_up_to_page(sel4_state.kernel_phys_max)
                                - sel4_state.capdl_phys_min;
                    }
                    _ => unreachable!(),
                };
                self.sel4_state.replace(sel4_state);
            });
            self.load_elf_segments(segment_load_offset);
        }
    }

    // Initiates loading the ELF segments for the current file, starting
    // with the first recorded program header. |offset| is added to each
    // segment's p_paddr to form its destination address.
    // Panics if no program header was recorded, if the file offset is the
    // not-found sentinel, or if the first header has no file data.
    fn load_elf_segments(&self, offset: u32) {
        let sel4_state = self.sel4_state.get();
        let phdr = sel4_state.phdrs[0].unwrap();
        let load_offset = sel4_state.load_offset;
        assert!(load_offset != INVALID_LOAD_OFFSET);

        // XXX assumes p_filesz > 0; consolidate with |load_segment_callback|
        // Split the segment's file offset into a page and an in-page offset.
        let first_page = phdr.p_offset / self.page_len;
        let offset_in_page = phdr.p_offset % self.page_len;
        let cursor = load_offset + (first_page * self.page_len);

        self.state.set(ElfLoaderState::LoadSegmentsNew(
            phdr,
            0,
            cursor,
            offset,
            offset_in_page,
            0,
        ));
        // XXX wrong, but preserve for now
        assert!(phdr.p_type == PT_LOAD && phdr.p_filesz > 0);
        print_segment(0, cursor, offset, &phdr);
        self.read_page(cursor);
    }

    // Callback from |read_complete| after reading a page for the current
    // loadable program segment. Data may be copied to the SMC memory or
    // SMC memory may be zero'd (e.g. for bss). If the segment is completed
    // advance to the next load segment. If all segments have been processed
    // advance to the next task.
    fn load_segment_callback(&self) {
        let (phdr, index, cursor, offset, mod_offset, loaded) = match self.state.get() {
            ElfLoaderState::LoadSegmentsNew(phdr, index, cursor, offset, mod_offset, loaded) => {
                (phdr, index, cursor, offset, mod_offset, loaded)
            }
            _ => unreachable!(),
        };

        // Copy read data to SMC ram.
        let bytes_in_this_page = self.page_len - mod_offset;
        let progress = loaded + bytes_in_this_page;
        self.read_page.map(|page| {
            let page_bytes = page.as_mut();
            // NB: read_page always reads a full page from flash, maybe copy less
            let bytes_to_copy = if progress < phdr.p_filesz {
                bytes_in_this_page
            } else {
                phdr.p_filesz - loaded
            };
            smc_ram_memcpy(
                &mut page_bytes[(mod_offset as usize)..],
                phdr.p_paddr + offset + loaded,
                bytes_to_copy as usize,
            );
        });

        if progress < phdr.p_filesz {
            // Same program header, more file data; issue a new read.
            let next_cursor = cursor + self.page_len;
            self.state.set(ElfLoaderState::LoadSegmentsNew(
                phdr,
                index,
                next_cursor,
                offset,
                /*mod_offset=*/ 0,
                progress,
            ));
            self.read_page(next_cursor);
            return;
        }

        // No more file data to read for this program header.
        if phdr.p_filesz < phdr.p_memsz {
            // Zero-pad remainder of segment.
            let zero_bytes = phdr.p_memsz - phdr.p_filesz;
            matcha_utils::smc_ram_zero(
                phdr.p_paddr + offset + phdr.p_filesz,
                zero_bytes as usize,
            );
        }

        // Current segment is complete, advance to the next segment. The
        // list ends at the end of the array or at the first empty slot.
        let sel4_state = self.sel4_state.get();
        let next_index = index + 1;
        match sel4_state.phdrs.get(next_index).copied().flatten() {
            None => {
                // Last load segment, advance to the next task.
                self.state.set(ElfLoaderState::Idle);
                self.run_next_task();
            }
            Some(next_phdr) => {
                // Setup reading the next load segment. A segment without
                // file data uses an in-page offset of |page_len| so that no
                // bytes are copied from the page that gets read.
                let (next_mod_offset, next_page) = if next_phdr.p_filesz > 0 {
                    (
                        next_phdr.p_offset % self.page_len,
                        next_phdr.p_offset / self.page_len,
                    )
                } else {
                    (self.page_len, 0)
                };
                let next_cursor = sel4_state.load_offset + (next_page * self.page_len);

                print_segment(next_index, next_cursor, offset, &next_phdr);

                self.state.set(ElfLoaderState::LoadSegmentsNew(
                    next_phdr,
                    next_index,
                    next_cursor,
                    offset,
                    next_mod_offset,
                    0,
                ));
                self.read_page(next_cursor);
            }
        }
    }
}

// Syscall interface. Only |command| does anything; subscribe and allow are
// not supported and always return EINVAL.
impl<'a, F: hil::flash::Flash> Driver for ElfLoaderCapsule<'a, F> {
    fn subscribe(&self, _: usize, _: Option<Callback>, _: AppId) -> ReturnCode {
        ReturnCode::EINVAL
    }

    // Handles CMD_ELFLOADER_BOOT_SEL4: in release builds this starts the
    // flash-based load sequence; in debug builds loading from flash is
    // bypassed and |matcha_utils::load_sel4| is called with the mailbox
    // instead. Returns EINVAL for any other command or when no flash has
    // been attached.
    fn command(&self, minor_num: usize, _r2: usize, _r3: usize, _app_id: AppId) -> ReturnCode {
        if minor_num != matcha_config::CMD_ELFLOADER_BOOT_SEL4 {
            return ReturnCode::EINVAL;
        }
        if self.flash.is_none() {
            dprintf!("No flash available\r\n");
            return ReturnCode::EINVAL;
        }

        if cfg!(debug_assertions) {
            dprintf!("Debug; bypass loading from SPI flash\r\n");
            self.mailbox_hal.map(|mb| {
                matcha_utils::load_sel4(mb);
            });
        } else {
            self.load_sel4();
        }
        ReturnCode::SUCCESS
    }

    fn allow(&self, _: AppId, _: usize, _: Option<AppSlice<Shared, u8>>) -> ReturnCode {
        ReturnCode::EINVAL
    }
}

// Flash client: completed reads drive the load state machine. Writes and
// erases are never issued by this capsule, so those callbacks are stubs.
impl<'a, F: hil::flash::Flash> hil::flash::Client<capsules::virtual_flash::FlashUser<'a, F>>
    for ElfLoaderCapsule<'a, F>
{
    // Takes back the page buffer, clears the busy flag and hands control to
    // the callback for the current state. Panics if a read completes while
    // the state machine is idle.
    // NOTE(review): |_err| is ignored, so a failed read is processed as if
    //   the page held valid data; confirm whether that is intended.
    fn read_complete(&self, read_page: &'static mut F::Page, _err: kernel::hil::flash::Error) {
        // The per-state callbacks read the page through |self.read_page|.
        self.read_page.replace(read_page);
        self.flash_busy.set(false);

        let current_state = self.state.get();
        match current_state {
            ElfLoaderState::FindingFile(..) => self.find_file_callback(),
            ElfLoaderState::LoadElfHeader(..) => self.load_elf_header_callback(),
            ElfLoaderState::LoadProgramHeaderTable(..) => {
                self.load_program_header_table_callback()
            }
            ElfLoaderState::LoadSegmentsNew(..) => self.load_segment_callback(),
            ElfLoaderState::Idle => panic!("583"),
        }
    }

    fn write_complete(
        &self,
        _: &'static mut <F as kernel::hil::flash::Flash>::Page,
        _: kernel::hil::flash::Error,
    ) {
        todo!()
    }

    fn erase_complete(&self, _: kernel::hil::flash::Error) {
        todo!()
    }
}
