// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Trivial Shodan elf loader capsule

// TODO(sleffler): cleanup panic msgs (currently fuzzy line numbers)

use core::cell::Cell;
use core::marker::PhantomData;
use kernel::common::cells::{OptionalCell, TakeCell};
use kernel::hil;
use kernel::hil::flash::Flash;
use kernel::{AppId, AppSlice, Callback, Driver, ReturnCode, Shared};
use matcha_hal::dprintf;
use matcha_hal::mailbox_hal::MailboxHAL;
use matcha_hal::smc_ctrl_hal;
use matcha_utils::elf_loader::{Elf32Header, Elf32Phdr};
use matcha_utils::smc_ram_memcpy;
use matcha_utils::tar_loader::TarHeader;

const PT_LOAD: u32 = 1;  // Loadable segment

/// Capsule that boots seL4 on the SMC by copying the seL4 kernel and the
/// capdl-loader ELF images out of a tar archive in flash into SMC RAM.
///
/// Progress is driven by a fixed, ordered task list (`tasks`) together with a
/// per-read state machine (`state`) that is advanced from the flash
/// `read_complete` callback.
pub struct ElfLoaderCapsule<'a, F: hil::flash::Flash + 'static> {
    // Collaborators, installed through the `set_*` methods.
    /// Mailbox used to deliver the boot message to the SMC.
    mailbox_hal: Option<&'static dyn MailboxHAL>,
    /// Controller used to start the SMC once its RAM is populated.
    smc_ctrl_hal: Option<&'static dyn smc_ctrl_hal::SmcCtrlHal>,
    /// Flash device holding the tar archive.
    flash: Option<&'static capsules::virtual_flash::FlashUser<'static, F>>,

    // Flash read plumbing.
    /// One-page read buffer; empty while it is lent to an outstanding read.
    read_page: TakeCell<'static, F::Page>,
    /// Length in bytes of `read_page` (0 until `set_flash` is called).
    page_len: u32,
    /// Set when a read is issued, cleared when it completes.
    flash_busy: Cell<bool>,

    // Boot progress.
    /// Which read is outstanding and how its result must be interpreted.
    state: Cell<ElfLoaderState>,
    /// Ordered boot steps, each paired with its "already started" flag.
    tasks: Cell<[(LoadTasks, bool); 7]>,
    /// The step currently being executed, if any.
    current_task: OptionalCell<LoadTasks>,
    /// Offsets, headers and entry points gathered while loading.
    sel4_state: Cell<SEL4State>,

    phantom: PhantomData<&'a ()>,
}

/// One step of the boot sequence executed by `run_next_task`.
///
/// The `&'static str` payload is the name of a file in the flash tar archive
/// ("kernel" or "capdl-loader").
#[derive(PartialEq, Clone, Copy)]
enum LoadTasks {
    /// Locate the named file in the archive and record where its data starts.
    FindFile(&'static str),
    /// Read the ELF header and the PT_LOAD program headers of the named file.
    LoadElfHeaders(&'static str),
    /// Copy the named file's first loadable segment into SMC RAM.
    LoadElf(&'static str),
    /// Zero the recorded BSS ranges, send the boot message and start the SMC.
    StartSmc,
}

/// Everything learned about the two boot images while loading them.
#[derive(Clone, Copy)]
struct SEL4State {
    // seL4 kernel.
    /// Flash byte offset of the kernel file data (the page after its tar header).
    kernel_offset: u32,
    /// `e_entry` taken from the kernel ELF header.
    kernel_entry_point: u32,
    /// PT_LOAD program headers in file order; `None` marks an unused slot.
    kernel_headers: [Option<Elf32Phdr>; 4],
    /// Start address of the kernel BSS zeroed by the `StartSmc` task.
    kernel_bss_start: u32,
    /// Length in bytes of that BSS range.
    kernel_bss_size: u32,

    // capdl-loader.
    /// Flash byte offset of the capdl-loader file data.
    capdl_loader_offset: u32,
    /// `e_entry` taken from the capdl-loader ELF header.
    capdl_loader_entry_point: u32,
    /// PT_LOAD program headers in file order; `None` marks an unused slot.
    capdl_loader_headers: [Option<Elf32Phdr>; 4],
    /// Start address of the capdl-loader BSS zeroed by the `StartSmc` task.
    capdl_bss_start: u32,
    /// Length in bytes of that BSS range.
    capdl_bss_size: u32,
}

impl Default for SEL4State {
    fn default() -> SEL4State {
        SEL4State {
            kernel_offset: 0,
            capdl_loader_offset: 0,
            kernel_entry_point: 0,
            capdl_loader_entry_point: 0,
            kernel_headers: [None, None, None, None],
            capdl_loader_headers: [None, None, None, None],
            kernel_bss_start: 0,
            kernel_bss_size: 0,
            capdl_bss_start: 0,
            capdl_bss_size: 0,
        }
    }
}

/// Describes the outstanding flash read; `read_complete` dispatches on it.
///
/// Every `cursor` is the flash byte offset of the page being read.
#[derive(PartialEq, Clone, Copy)]
enum ElfLoaderState {
    /// No read outstanding.
    Idle,
    /// Scanning the tar archive for a file; `cursor` points at a tar header.
    FindingFile(&'static str /* name */, u32 /* cursor */),
    /// Reading the page holding the ELF header (the start of the file data).
    LoadElfHeader(u32 /* cursor */),
    /// Parsing program header `index` of `num`, located `offset` bytes into
    /// the page at `cursor`.
    LoadProgramHeaderTable(
        u32, /* cursor */
        u32, /* offset */
        u16, /* num */
        u16, /* index */
    ),
    /// Copying segment `index` of `name` into SMC RAM: the page at `cursor`
    /// holds segment data from `offset_in_page` on, it is written at
    /// `p_paddr + dst offset`, and `already loaded` bytes have been copied.
    LoadSegmentsNew(
        &'static str, /* name */
        Elf32Phdr,    /* phdr */
        usize,        /* index */
        u32,          /* cursor */
        u32,          /* dst offset */
        u32,          /* offset_in_page */
        u32,          /* already loaded */
    ),
}

impl<'a, F: hil::flash::Flash> ElfLoaderCapsule<'a, F> {
    /// Creates an idle loader with no flash, mailbox or SMC controller attached
    /// and every boot task marked as not yet started.
    pub fn new() -> Self {
        // The boot pipeline, executed front to back by `run_next_task`.
        let boot_tasks = [
            (LoadTasks::FindFile("kernel"), false),
            (LoadTasks::FindFile("capdl-loader"), false),
            (LoadTasks::LoadElfHeaders("kernel"), false),
            (LoadTasks::LoadElfHeaders("capdl-loader"), false),
            (LoadTasks::LoadElf("kernel"), false),
            (LoadTasks::LoadElf("capdl-loader"), false),
            (LoadTasks::StartSmc, false),
        ];

        Self {
            mailbox_hal: None,
            smc_ctrl_hal: None,
            flash: None,
            read_page: TakeCell::empty(),
            page_len: 0,
            flash_busy: Cell::new(false),
            state: Cell::new(ElfLoaderState::Idle),
            tasks: Cell::new(boot_tasks),
            current_task: OptionalCell::empty(),
            sel4_state: Cell::new(SEL4State::default()),
            phantom: PhantomData,
        }
    }

    /// Installs the SMC controller used by the `StartSmc` task (replacing any
    /// previously installed one).
    pub fn set_smc_ctrl(&mut self, smc_ctrl: &'static dyn smc_ctrl_hal::SmcCtrlHal) {
        self.smc_ctrl_hal.replace(smc_ctrl);
    }

    /// Installs the mailbox used to send the boot message (replacing any
    /// previously installed one).
    pub fn set_mailbox(&mut self, mailbox: &'static dyn MailboxHAL) {
        self.mailbox_hal.replace(mailbox);
    }

    /// Attaches the flash holding the tar archive and the one-page buffer used
    /// for every read. `page_len` is taken from the buffer's length.
    pub fn set_flash(
        &mut self,
        flash: &'static capsules::virtual_flash::FlashUser<'static, F>,
        read_page: &'static mut F::Page,
    ) {
        // Measure the buffer before handing it over to the TakeCell.
        let buffer_len = read_page.as_mut().len() as u32;
        self.flash = Some(flash);
        self.page_len = buffer_len;
        self.read_page.replace(read_page);
    }

    /// Marks the first not-yet-started boot task as started, records it in
    /// `current_task`, and runs it.
    ///
    /// The flash-driven tasks (`FindFile`, `LoadElfHeaders`, `LoadElf`) only
    /// issue a read and return; their completion callbacks call back into this
    /// method. `StartSmc` runs to completion here. If every task has already
    /// been started, this prints a message and spins forever.
    pub fn run_next_task(&self) {
        self.current_task.clear();

        let mut task_list = self.tasks.get();
        let picked = task_list
            .iter_mut()
            .find(|entry| !entry.1)
            .map(|entry| {
                entry.1 = true;
                entry.0
            });
        self.tasks.set(task_list);

        let task = match picked {
            Some(task) => task,
            None => {
                dprintf!("No current task\r\n");
                loop {}
            }
        };
        self.current_task.replace(task);

        match task {
            LoadTasks::FindFile(file) => self.find_file(file),
            LoadTasks::LoadElfHeaders(file) => self.load_elf_headers(file),
            LoadTasks::LoadElf("kernel") => {
                dprintf!("Loading seL4 kernel\r\n");
                self.load_elf("kernel", 0);
            }
            LoadTasks::LoadElf("capdl-loader") => {
                dprintf!("Loading capdl-loader to the page after seL4\r\n");
                let sel4 = self.sel4_state.get();
                let kernel_end = matcha_utils::round_up_to_page(
                    matcha_utils::elf_loader::elf_phys_max_opt(&sel4.kernel_headers),
                );
                // Relocate so the capdl-loader's lowest address lands on
                // `kernel_end`.
                let relocation = kernel_end
                    - matcha_utils::elf_loader::elf_phys_min_opt(&sel4.capdl_loader_headers);
                self.load_elf("capdl-loader", relocation);
            }
            LoadTasks::StartSmc => self.start_smc(),
            _ => panic!("225"),
        }
    }

    /// Final boot step: zeroes both recorded BSS ranges, sends the boot
    /// message (when a mailbox is installed) and starts the SMC.
    ///
    /// The names `ui_p_reg_start` & co. follow the seL4 code that consumes
    /// these values.
    ///
    /// # Panics
    /// If no SMC controller has been installed.
    fn start_smc(&self) {
        let sel4 = self.sel4_state.get();
        let kernel_max = matcha_utils::elf_loader::elf_phys_max_opt(&sel4.kernel_headers);
        let capdl_max = matcha_utils::elf_loader::elf_phys_max_opt(&sel4.capdl_loader_headers);
        let capdl_min = matcha_utils::elf_loader::elf_phys_min_opt(&sel4.capdl_loader_headers);

        let ui_p_reg_start = matcha_utils::round_up_to_page(kernel_max);
        let ui_p_reg_end = ui_p_reg_start + matcha_utils::round_up_to_page(capdl_max - capdl_min);
        let pv_offset = ui_p_reg_start - capdl_min;
        let v_entry = sel4.capdl_loader_entry_point;

        matcha_utils::smc_ram_zero(sel4.kernel_bss_start, sel4.kernel_bss_size as usize);
        matcha_utils::smc_ram_zero(sel4.capdl_bss_start, sel4.capdl_bss_size as usize);

        let smc_ctrl = match self.smc_ctrl_hal {
            Some(smc_ctrl) => smc_ctrl,
            None => panic!("221"),
        };
        self.mailbox_hal.map(|mb| {
            matcha_utils::smc_send_bootmsg(mb, [ui_p_reg_start, ui_p_reg_end, pv_offset, v_entry])
        });
        smc_ctrl.smc_ctrl_start();
    }

    /// Restarts the boot sequence: marks every task as not started and runs
    /// the first one.
    pub fn load_sel4(&self) {
        let mut task_list = self.tasks.get();
        for (_, started) in task_list.iter_mut() {
            *started = false;
        }
        self.tasks.set(task_list);
        self.run_next_task();
    }

    /// Starts parsing the ELF header of `name` ("kernel" or "capdl-loader").
    ///
    /// Reads the flash page at the file-data offset recorded by the `FindFile`
    /// task. `load_elf_header_callback` handles the result.
    ///
    /// # Panics
    /// If `name` is not one of the two known images.
    fn load_elf_headers(&self, name: &'static str) {
        let sel4 = self.sel4_state.get();
        let file_start = match name {
            "kernel" => sel4.kernel_offset,
            "capdl-loader" => sel4.capdl_loader_offset,
            _ => panic!("207"),
        };
        self.state.set(ElfLoaderState::LoadElfHeader(file_start));
        self.read_page(file_start);
    }

    /// Starts copying the first recorded PT_LOAD segment of `name` into SMC RAM.
    ///
    /// `name` is "kernel" or "capdl-loader". `offset` is added to every
    /// physical address taken from the ELF: `run_next_task` passes 0 for the
    /// kernel and a relocation for the capdl-loader.
    ///
    /// Flash reads are page-granular. The read therefore starts at the page
    /// containing `p_offset`, and the position of the data inside that page is
    /// carried in the state as `offset_in_page`. The segment's BSS range
    /// (`p_memsz - p_filesz` bytes right after the file data) is only recorded
    /// here; it is zeroed later by the `StartSmc` task. `load_segment_callback`
    /// continues the copy.
    ///
    /// # Panics
    /// If `name` is unknown, if no program header was recorded for it, or if
    /// the first header is not a PT_LOAD segment with file data.
    fn load_elf(&self, name: &'static str, offset: u32) {
        let mut sel4_state = self.sel4_state.get();
        let (phdr, file_start) = match name {
            "kernel" => (sel4_state.kernel_headers[0].unwrap(), sel4_state.kernel_offset),
            "capdl-loader" => (
                sel4_state.capdl_loader_headers[0].unwrap(),
                sel4_state.capdl_loader_offset,
            ),
            _ => panic!("290"),
        };
        let mod_offset = phdr.p_offset % self.page_len;
        let div_offset = phdr.p_offset / self.page_len;
        let cursor = file_start + (div_offset * self.page_len);
        self.state.set(ElfLoaderState::LoadSegmentsNew(
            name, phdr, 0, cursor, offset, mod_offset, 0,
        ));

        if phdr.p_type != PT_LOAD || phdr.p_filesz == 0 {
            panic!("295");
        }

        // Copy the (possibly packed) header fields into locals before use.
        let filesz = phdr.p_filesz;
        let memsz = phdr.p_memsz;
        let seg_start = phdr.p_paddr + offset;
        dprintf!(
            "seg 0:{:X} -> {:X}:{:X} ({} bytes)\r\n",
            cursor,
            seg_start,
            seg_start + filesz,
            filesz,
        );

        // p_memsz < p_filesz only occurs in a malformed ELF. Saturate so it
        // yields an empty BSS. Unchecked subtraction would wrap (release) to a
        // ~4 GiB range that StartSmc would then try to zero.
        let bss_size = memsz.saturating_sub(filesz);
        let bss_start = seg_start + filesz;
        let bss_end = bss_start + bss_size;
        dprintf!(
            "bss 0:{:X} -> {:X}:{:X} ({} bytes)\r\n",
            cursor,
            bss_start,
            bss_end,
            bss_size
        );

        match name {
            "kernel" => {
                sel4_state.kernel_bss_start = bss_start;
                sel4_state.kernel_bss_size = bss_size;
            }
            "capdl-loader" => {
                sel4_state.capdl_bss_start = bss_start;
                sel4_state.capdl_bss_size = bss_size;
            }
            _ => panic!("349"),
        }
        self.sel4_state.set(sel4_state);

        // The BSS is not zeroed here; StartSmc zeroes it after both images are copied.
        self.read_page(cursor);
    }

    /// Starts scanning the tar archive for `name` from the beginning of flash.
    /// `find_file_callback` handles each header read.
    fn find_file(&self, name: &'static str) {
        const ARCHIVE_START: u32 = 0;
        self.state.set(ElfLoaderState::FindingFile(name, ARCHIVE_START));
        self.read_page(ARCHIVE_START);
    }

    /// Handles the read of the first page of an ELF file.
    ///
    /// Checks the ELF magic and stores `e_entry` as the entry point of the
    /// image named by the current `LoadElfHeaders` task. It then switches to
    /// `LoadProgramHeaderTable` (starting at program header 0) and re-reads the
    /// same page to parse the program headers.
    ///
    /// # Panics
    /// On bad ELF magic, an unexpected current task, or an unexpected state.
    fn load_elf_header_callback(&self) {
        self.read_page.map(|page| {
            let mut_page = page.as_mut();
            let header = Elf32Header::from_bytes(mut_page);
            if !header.check_magic() {
                dprintf!("bad elf magic\r\n");
                panic!("233");
            }

            self.current_task.map(|task| {
                let mut sel4 = self.sel4_state.get();
                match task {
                    LoadTasks::LoadElfHeaders("kernel") => {
                        sel4.kernel_entry_point = header.e_entry
                    }
                    LoadTasks::LoadElfHeaders("capdl-loader") => {
                        sel4.capdl_loader_entry_point = header.e_entry
                    }
                    _ => panic!("290"),
                }
                self.sel4_state.replace(sel4);
            });

            let cursor = match self.state.get() {
                ElfLoaderState::LoadElfHeader(cursor) => cursor,
                _ => panic!("245"),
            };
            self.state.set(ElfLoaderState::LoadProgramHeaderTable(
                cursor,
                header.phoff(),
                header.phnum(),
                0,
            ));
        });

        if let ElfLoaderState::LoadProgramHeaderTable(cursor, ..) = self.state.get() {
            self.read_page(cursor);
        } else {
            panic!("254");
        }
    }

    /// Parses one program header from the page just read.
    ///
    /// The header is taken `offset + index * size_of::<Elf32Phdr>()` bytes
    /// into the page. PT_LOAD headers are appended to the first free slot of
    /// the current image's header table. The method then either re-reads the
    /// page for the next header or, after the last one, goes idle and runs the
    /// next task.
    ///
    /// NOTE(review): this assumes the whole program header table lies in the
    /// page at `cursor`; an entry beyond `page_len` would panic on the slice.
    ///
    /// # Panics
    /// On an unexpected state or task, or when more than four PT_LOAD headers
    /// are found for one image.
    fn load_program_header_table_callback(&self) {
        self.read_page.map(|page| {
            let mut_page = page.as_mut();
            let (table_offset, entry) = match self.state.get() {
                ElfLoaderState::LoadProgramHeaderTable(_, offset, _, index) => (offset, index),
                _ => panic!("335"),
            };
            let entry_start =
                (table_offset as usize) + (entry as usize * core::mem::size_of::<Elf32Phdr>());
            let phdr = Elf32Phdr::from_bytes(&mut mut_page[entry_start..]);
            if phdr.p_type != PT_LOAD {
                // Only loadable segments are recorded.
                return;
            }

            self.current_task.map(|task| {
                let mut sel4 = self.sel4_state.get();
                let slots = match task {
                    LoadTasks::LoadElfHeaders("kernel") => &mut sel4.kernel_headers,
                    LoadTasks::LoadElfHeaders("capdl-loader") => &mut sel4.capdl_loader_headers,
                    _ => panic!("317"),
                };
                match slots.iter().position(|slot| slot.is_none()) {
                    Some(free) => slots[free] = Some(phdr),
                    None => panic!("380"),
                }
                self.sel4_state.replace(sel4);
            });
        });

        match self.state.get() {
            ElfLoaderState::LoadProgramHeaderTable(cursor, offset, count, index)
                if index + 1 < count =>
            {
                self.state.set(ElfLoaderState::LoadProgramHeaderTable(
                    cursor,
                    offset,
                    count,
                    index + 1,
                ));
                self.read_page(cursor);
            }
            ElfLoaderState::LoadProgramHeaderTable(..) => {
                self.state.set(ElfLoaderState::Idle);
                self.run_next_task();
            }
            _ => panic!("375"),
        }
    }

    /// Copies the page just read into SMC RAM as part of the current segment.
    ///
    /// In state `LoadSegmentsNew(name, phdr, index, cursor, offset,
    /// offset_in_page, loaded)`, the page at `cursor` holds segment data
    /// starting at `offset_in_page`, and `loaded` bytes have already been
    /// written at `p_paddr + offset`. While fewer than `p_filesz` bytes have
    /// been copied, the next flash page is requested.
    ///
    /// Once the segment is complete: if the next header slot describes a
    /// pure-BSS segment (`p_filesz == 0`, `p_memsz != 0`), its memory is
    /// zeroed. Then the loader goes idle and runs the next task. Further
    /// segments are not loaded yet (see TODO).
    ///
    /// # Panics
    /// If the state is not `LoadSegmentsNew` or `name` is unknown.
    fn load_segment_callback(&self) {
        let (name, phdr, index, cursor, offset, mod_offset, loaded) = match self.state.get() {
            ElfLoaderState::LoadSegmentsNew(
                name,
                phdr,
                index,
                cursor,
                offset,
                mod_offset,
                loaded,
            ) => (name, phdr, index, cursor, offset, mod_offset, loaded),
            _ => panic!("445"),
        };
        let filesz = phdr.p_filesz;

        // Copy what this page holds of the segment, but never beyond p_filesz.
        // On the last page, the bytes after the segment's file data belong to
        // whatever follows it in the archive and must not be written to SMC RAM
        // (they could land past p_memsz).
        let bytes_in_this_page = self.page_len - mod_offset;
        let to_copy = bytes_in_this_page.min(filesz.saturating_sub(loaded));
        self.read_page.map(|page| {
            let mut_page = page.as_mut();
            smc_ram_memcpy(
                &mut mut_page[(mod_offset as usize)..],
                phdr.p_paddr + offset + loaded,
                to_copy as usize,
            );
        });

        let progress = loaded + to_copy;
        if progress < filesz {
            // More file data: continue at the start of the next flash page.
            let new_cursor = cursor + self.page_len;
            self.state.set(ElfLoaderState::LoadSegmentsNew(
                name, phdr, index, new_cursor, offset, 0, progress,
            ));
            self.read_page(new_cursor);
            return;
        }

        let sel4_state = self.sel4_state.get();
        let segments = match name {
            "kernel" => sel4_state.kernel_headers,
            "capdl-loader" => sel4_state.capdl_loader_headers,
            _ => panic!("461"),
        };
        if let Some(&Some(next_segment)) = segments.get(index + 1) {
            if next_segment.p_memsz != 0 && next_segment.p_filesz == 0 {
                // Pure-BSS segment: nothing to copy, just clear its memory.
                matcha_utils::smc_ram_zero(
                    next_segment.p_paddr + offset,
                    next_segment.p_memsz as usize,
                );
            }
            // TODO(atv): load any following segment that carries file data
            // (cursor = file start + page containing p_offset, offset_in_page =
            // p_offset % page_len) instead of moving on to the next task.
        }
        self.state.set(ElfLoaderState::Idle);
        self.run_next_task();
    }

    /// Examines the tar header just read while searching for a file.
    ///
    /// For a `FindFile` task, a matching header records the file's data offset
    /// (the page after the header) in `sel4_state` and goes idle. For a
    /// `LoadElf` task, a match moves on to reading the ELF header. Otherwise
    /// the cursor skips the header page plus the file data rounded up to
    /// 512-byte tar blocks.
    ///
    /// NOTE(review): the cursor arithmetic treats one flash page as one tar
    /// header block, i.e. it appears to assume `page_len == 512` — confirm.
    ///
    /// # Panics
    /// On an unexpected state or task.
    fn find_file_callback(&self) {
        self.read_page.map(|page| {
            let mut_page = page.as_mut();
            let header = TarHeader::from_bytes(mut_page);
            let (name, cursor) = match self.state.get() {
                ElfLoaderState::FindingFile(name, cursor) => (name, cursor),
                _ => panic!("491"),
            };
            let found = header.name().contains(name);

            self.current_task.map(|task| {
                let next_state = match task {
                    LoadTasks::FindFile(_) | LoadTasks::LoadElf(_) if !found => {
                        let next_header =
                            self.page_len + cursor + ((header.size() + 511) & !511);
                        ElfLoaderState::FindingFile(name, next_header)
                    }
                    LoadTasks::FindFile(_) => {
                        let mut sel4 = self.sel4_state.get();
                        match name {
                            "kernel" => sel4.kernel_offset = cursor + self.page_len,
                            "capdl-loader" => sel4.capdl_loader_offset = cursor + self.page_len,
                            _ => {}
                        }
                        self.sel4_state.replace(sel4);
                        ElfLoaderState::Idle
                    }
                    LoadTasks::LoadElf(_) => ElfLoaderState::LoadElfHeader(cursor + self.page_len),
                    _ => panic!("487"),
                };
                self.state.set(next_state);
            });
        });

        match self.state.get() {
            ElfLoaderState::FindingFile(_, cursor) | ElfLoaderState::LoadElfHeader(cursor) => {
                self.read_page(cursor)
            }
            ElfLoaderState::Idle => self.run_next_task(),
            _ => panic!("507"),
        }
    }

    fn read_page(&self, page: u32) {
        self.flash.map(|flash| {
            self.read_page.take().map(|read_page| {
                self.flash_busy.set(true);
                if let Err((_, buf)) = flash.read_page((page / self.page_len) as usize, read_page) {
                    self.read_page.replace(buf);
                }
            });
        });
    }
}

impl<'a, F: hil::flash::Flash> Driver for ElfLoaderCapsule<'a, F> {
    /// No upcalls are supported.
    fn subscribe(&self, _: usize, _: Option<Callback>, _: AppId) -> ReturnCode {
        ReturnCode::EINVAL
    }

    /// Handles `CMD_ELFLOADER_BOOT_SEL4`; every other command is `EINVAL`.
    ///
    /// Requires a flash to be attached (`EINVAL` otherwise). Release builds
    /// run the flash-based loader (`load_sel4`). Builds with debug assertions
    /// bypass it and call `matcha_utils::load_sel4` with the mailbox instead.
    fn command(&self, minor_num: usize, _r2: usize, _r3: usize, _app_id: AppId) -> ReturnCode {
        if minor_num != matcha_config::CMD_ELFLOADER_BOOT_SEL4 {
            return ReturnCode::EINVAL;
        }
        if self.flash.is_none() {
            dprintf!("No flash available\r\n");
            return ReturnCode::EINVAL;
        }

        if cfg!(debug_assertions) {
            dprintf!("Debug; bypass loading from SPI flash\r\n");
            self.mailbox_hal.map(|mb| {
                matcha_utils::load_sel4(mb);
            });
        } else {
            self.load_sel4();
        }
        ReturnCode::SUCCESS
    }

    /// No shared buffers are supported.
    fn allow(&self, _: AppId, _: usize, _: Option<AppSlice<Shared, u8>>) -> ReturnCode {
        ReturnCode::EINVAL
    }
}

impl<'a, F: hil::flash::Flash> hil::flash::Client<capsules::virtual_flash::FlashUser<'a, F>>
    for ElfLoaderCapsule<'a, F>
{
    /// Takes the read buffer back, clears `flash_busy` and hands the page to
    /// the handler selected by the current state.
    ///
    /// NOTE(review): `_err` is ignored, so a failed read is processed as if it
    /// had succeeded — consider checking it.
    ///
    /// # Panics
    /// If a read completes while the loader is `Idle`.
    fn read_complete(&self, read_page: &'static mut F::Page, _err: kernel::hil::flash::Error) {
        self.read_page.replace(read_page);
        self.flash_busy.set(false);

        let state = self.state.get();
        match state {
            ElfLoaderState::FindingFile(..) => self.find_file_callback(),
            ElfLoaderState::LoadElfHeader(..) => self.load_elf_header_callback(),
            ElfLoaderState::LoadProgramHeaderTable(..) => {
                self.load_program_header_table_callback()
            }
            ElfLoaderState::LoadSegmentsNew(..) => self.load_segment_callback(),
            ElfLoaderState::Idle => panic!("583"),
        }
    }

    /// The loader never writes flash.
    fn write_complete(
        &self,
        _: &'static mut <F as kernel::hil::flash::Flash>::Page,
        _: kernel::hil::flash::Error,
    ) {
        todo!()
    }

    /// The loader never erases flash.
    fn erase_complete(&self, _: kernel::hil::flash::Error) {
        todo!()
    }
}
