// Trivial ELF segment loader: copies each loadable segment straight from the
// in-memory blob to the physical address given in its program header, and
// zero-fills the remainder of the segment (bss).

use core::cmp;
use core::mem::transmute;
use core::ptr;
use core::slice;

use crate::tar_loader;
use matcha_hal::dprintf;

/// The ELF identification magic `\x7f 'E' 'L' 'F'`, packed as a little-endian u32
/// (numerically `0x464c457f`).
pub const ELF_MAGIC: u32 = u32::from_le_bytes([0x7f, b'E', b'L', b'F']);

/// ELF32 file header, as it appears at the very start of an ELF image.
///
/// `repr(C, packed)` gives the on-disk layout with no padding and an alignment
/// of 1. `Elf32Header::from_bytes` fills it by a raw copy, so multi-byte fields
/// are interpreted in the host's native byte order.
#[repr(C, packed)]
#[derive(Debug, Copy, Clone)]
pub struct Elf32Header {
    pub e_ident: [u8; 16], /* Identification bytes; the first four hold the magic (see ELF_MAGIC) */
    pub e_type: u16,      /* Relocatable=1, Executable=2 (+ some more ..) */
    pub e_machine: u16,   /* Target architecture: MIPS=8 */
    pub e_version: u32,   /* Elf version (should be 1) */
    pub e_entry: u32,     /* Code entry point */
    pub e_phoff: u32,     /* Program header table: byte offset from the start of this header */
    pub e_shoff: u32,     /* Section header table */
    pub e_flags: u32,     /* Flags */
    pub e_ehsize: u16,    /* ELF header size */
    pub e_phentsize: u16, /* Size of one program segment header */
    pub e_phnum: u16,     /* Number of program segment headers */
    pub e_shentsize: u16, /* Size of one section header */
    pub e_shnum: u16,     /* Number of section headers */
    pub e_shstrndx: u16,  /* Section header index of the string table for section
                          header names */
}

/// ELF32 program header: one entry of the table located `e_phoff` bytes past
/// the ELF header.
///
/// Its layout matches the bytes in the image (`repr(C, packed)`, no padding,
/// alignment 1).
#[repr(C, packed)]
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Elf32Phdr {
    pub p_type: u32,   /* Segment type: Loadable segment (PT_LOAD) = 1 */
    pub p_offset: u32, /* Offset of the segment's file data from the start of the image (the ELF header) */
    pub p_vaddr: u32,  /* Reqd virtual address of segment when loading */
    pub p_paddr: u32,  /* Reqd physical address; load_elf_segments copies the segment here (plus dst_offset) */
    pub p_filesz: u32, /* How many bytes this segment occupies in file */
    pub p_memsz: u32,  /* How many bytes this segment should occupy in memory; the excess over p_filesz is zero-filled */
    pub p_flags: u32,  /* Flags: logical "or" of PF_* flag bits (not read by the code in this file) */
    pub p_align: u32,  /* Reqd alignment of segment in memory */
}

impl Elf32Phdr {
    /// Decodes a program header from the first `size_of::<Elf32Phdr>()` bytes of `b`.
    ///
    /// Fields are read in the host's native byte order. `b` is only read; the
    /// `&mut` receiver type is kept for API compatibility.
    ///
    /// # Panics
    ///
    /// Panics if `b` is shorter than one `Elf32Phdr`. Reading past the end of the
    /// slice would otherwise be undefined behaviour inside a safe function.
    pub fn from_bytes(b: &mut [u8]) -> Elf32Phdr {
        assert!(
            b.len() >= core::mem::size_of::<Elf32Phdr>(),
            "buffer too small for Elf32Phdr"
        );
        // SAFETY: the length is checked above. `read_unaligned` imposes no
        // alignment requirement, and the struct consists only of integer
        // fields, so every byte pattern is a valid value.
        unsafe { ptr::read_unaligned(b.as_ptr().cast::<Elf32Phdr>()) }
    }
}

impl Elf32Header {
    /// Decodes an ELF header from the first `size_of::<Elf32Header>()` bytes of `b`.
    ///
    /// Fields are read in the host's native byte order. `b` is only read; the
    /// `&mut` receiver type is kept for API compatibility.
    ///
    /// # Panics
    ///
    /// Panics if `b` is shorter than one `Elf32Header`. Reading past the end of
    /// the slice would otherwise be undefined behaviour inside a safe function.
    pub fn from_bytes(b: &mut [u8]) -> Elf32Header {
        assert!(
            b.len() >= core::mem::size_of::<Elf32Header>(),
            "buffer too small for Elf32Header"
        );
        // SAFETY: the length is checked above. `read_unaligned` imposes no
        // alignment requirement, and the struct consists only of integer
        // fields and a byte array, so every byte pattern is a valid value.
        unsafe { ptr::read_unaligned(b.as_ptr().cast::<Elf32Header>()) }
    }

    /// Returns `true` if `e_ident` starts with the ELF magic `\x7f 'E' 'L' 'F'`
    /// (the bytes of `ELF_MAGIC` in little-endian order).
    pub fn check_magic(&self) -> bool {
        // Compare bytes instead of reading a u32 through a pointer. The struct is
        // packed, so `e_ident` carries no 4-byte alignment guarantee (a u32 read
        // there may be misaligned). A byte comparison also gives the same answer
        // on any host endianness.
        self.e_ident.starts_with(&ELF_MAGIC.to_le_bytes())
    }

    /// Byte offset of the program header table from the start of this header.
    pub fn phoff(&self) -> u32 {
        self.e_phoff
    }

    /// Number of entries in the program header table.
    pub fn phnum(&self) -> u16 {
        self.e_phnum
    }
}

pub fn elf_phys_min(segments: &[Elf32Phdr]) -> u32 {
    return segments.iter().fold(0xFFFFFFFF, |acc, seg| {
        if seg.p_type != 1 {
            return acc;
        } else {
            return cmp::min(acc, seg.p_paddr);
        }
    });
}

/// Lowest physical address (`p_paddr`) among the loadable (`p_type == 1`)
/// entries of `segments`; `None` slots are skipped.
///
/// Returns `0xFFFFFFFF` when there is no loadable entry.
pub fn elf_phys_min_opt(segments: &[Option<Elf32Phdr>]) -> u32 {
    segments
        .iter()
        .filter_map(Option::as_ref)
        .filter(|seg| seg.p_type == 1)
        .map(|seg| seg.p_paddr)
        .fold(0xFFFFFFFF, cmp::min)
}

/// Exclusive end of the physical range covered by the loadable (`p_type == 1`)
/// entries of `segments`, i.e. the largest `p_paddr + p_memsz`. `None` slots
/// are skipped.
///
/// Returns 0 when there is no loadable entry. If a segment's end address does
/// not fit in a u32, the result saturates at `u32::MAX`. The previous plain `+`
/// panicked in debug builds and wrapped to a bogus small value in release.
pub fn elf_phys_max_opt(segments: &[Option<Elf32Phdr>]) -> u32 {
    segments
        .iter()
        .flatten()
        .filter(|seg| seg.p_type == 1)
        .map(|seg| u32::saturating_add(seg.p_paddr, seg.p_memsz))
        .fold(0, cmp::max)
}

pub fn elf_phys_max(segments: &[Elf32Phdr]) -> u32 {
    return segments.iter().fold(0, |acc, seg| {
        if seg.p_type != 1 {
            return acc;
        } else {
            return cmp::max(acc, seg.p_paddr + seg.p_memsz);
        }
    });
}

/// Lowest virtual address (`p_vaddr`) among the loadable (`p_type == 1`) segments.
///
/// Returns `0xFFFFFFFF` when `segments` contains no loadable segment.
pub fn elf_virt_min(segments: &[Elf32Phdr]) -> u32 {
    let mut lowest: u32 = 0xFFFFFFFF;
    for seg in segments {
        if seg.p_type == 1 {
            lowest = cmp::min(lowest, seg.p_vaddr);
        }
    }
    lowest
}

pub fn elf_virt_min_opt(segments: &[Option<Elf32Phdr>]) -> u32 {
    return segments.iter().fold(0xFFFFFFFF, |acc, seg| match seg {
        Some(seg) => {
            if seg.p_type != 1 {
                return acc;
            } else {
                return cmp::min(acc, seg.p_vaddr);
            }
        }
        None => {
            return acc;
        }
    });
}

/// Exclusive end of the virtual range covered by the loadable (`p_type == 1`)
/// segments, i.e. the largest `p_vaddr + p_memsz`.
///
/// Returns 0 when there is no loadable segment. If a segment's end address does
/// not fit in a u32, the result saturates at `u32::MAX`. The previous plain `+`
/// panicked in debug builds and wrapped to a bogus small value in release.
pub fn elf_virt_max(segments: &[Elf32Phdr]) -> u32 {
    segments
        .iter()
        .filter(|seg| seg.p_type == 1)
        .map(|seg| u32::saturating_add(seg.p_vaddr, seg.p_memsz))
        .fold(0, cmp::max)
}

/// Exclusive end of the virtual range covered by the loadable (`p_type == 1`)
/// entries of `segments`, i.e. the largest `p_vaddr + p_memsz`. `None` slots
/// are skipped.
///
/// Returns 0 when there is no loadable entry. If a segment's end address does
/// not fit in a u32, the result saturates at `u32::MAX`. The previous plain `+`
/// panicked in debug builds and wrapped to a bogus small value in release.
pub fn elf_virt_max_opt(segments: &[Option<Elf32Phdr>]) -> u32 {
    let mut highest: u32 = 0;
    for seg in segments.iter().flatten() {
        if seg.p_type == 1 {
            highest = cmp::max(highest, u32::saturating_add(seg.p_vaddr, seg.p_memsz));
        }
    }
    highest
}

pub unsafe fn elf_get_segments(elf: *const Elf32Header) -> &'static [Elf32Phdr] {
    let base: *const u8 = transmute(elf);
    let seg_count = (*elf).e_phnum as usize;
    let segments: &[Elf32Phdr] =
        slice::from_raw_parts(transmute(base.offset((*elf).e_phoff as isize)), seg_count);
    return segments;
}

pub unsafe fn find_elf(name: &str) -> *const Elf32Header {
    let (tar_offset, tar_size) = tar_loader::find_file(name);

    if tar_size > 0 {
        let pmagic: *const u32 = transmute(tar_offset);
        if pmagic.read_volatile() == ELF_MAGIC {
            let elf_header: *const Elf32Header = transmute(tar_offset);
            return elf_header;
        }
    }

    dprintf!("Could not find elf for '{}'\n", name);
    return transmute(0);
}

pub unsafe fn load_elf_segments(elf: &Elf32Header, dst_offset: u32) {
    let base: *const u8 = transmute(elf);
    let seg_count = elf.e_phnum as usize;
    let segments: &[Elf32Phdr] =
        slice::from_raw_parts(transmute(base.offset(elf.e_phoff as isize)), seg_count);

    for i in 0..segments.len() {
        let seg = &segments[i];

        // Loadable segment type.
        const PT_LOAD: u32 = 1;
        if seg.p_type != PT_LOAD {
            continue;
        }
        if seg.p_filesz == 0 {
            continue;
        }

        let src: *const u8 = transmute(base.offset(seg.p_offset as isize));

        let txt_size = seg.p_filesz as usize;
        let txt_start: *mut u8 = transmute(seg.p_paddr + dst_offset);
        let txt_end = txt_start.offset(txt_size as isize);

        dprintf!(
            "seg {}: {:X} -> {:X}:{:X} ({} bytes)\n",
            i,
            src as usize,
            txt_start as usize,
            txt_end as usize,
            txt_size
        );
        ptr::copy_nonoverlapping(src, txt_start, txt_size);

        let bss_size = (seg.p_memsz - seg.p_filesz) as usize;
        let bss_start = txt_end;
        let bss_end = bss_start.offset(bss_size as isize);

        dprintf!(
            "bss {}: {:X} -> {:X}:{:X} ({} bytes)\n",
            i,
            src as usize,
            bss_start as usize,
            bss_end as usize,
            bss_size
        );
        ptr::write_bytes(bss_start, 0, bss_size as usize);
    }
}
