// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::errors::*;
use crate::utils::*;

/// A bit-addressable view over a caller-owned slice of `u32` words.
///
/// Bit `pos` is stored in word `pos >> 5` at bit position `pos & 31`
/// (least-significant bit first). The view only borrows the words.
pub struct BitVector<'a> {
    /// Backing words; the vector spans `bits.len() * 32` bits.
    pub bits: &'a mut [u32],
}

impl<'a> BitVector<'a> {
    pub fn new(bits: &'a mut [u32]) -> Self {
        return BitVector { bits: bits };
    }

    pub fn check_pos(&self, pos: usize) -> Result<(), BFSErr> {
        dcheck!(pos <= self.bits.len() * 32, BFSErr::OutOfBounds);
        return Ok(());
    }

    pub fn check_range(&self, begin: usize, end: usize) -> Result<(), BFSErr> {
        self.check_pos(begin)?;
        self.check_pos(end)?;
        dcheck!(end >= begin, BFSErr::OutOfBounds);
        return Ok(());
    }

    pub fn clear_all(&mut self) {
        for i in 0..self.bits.len() {
            self.bits[i] = 0x00000000;
        }
    }

    pub fn set_all(&mut self) {
        for i in 0..self.bits.len() {
            self.bits[i] = 0xFFFFFFFF;
        }
    }

    pub fn get_bit(&self, pos: usize) -> Result<u32, BFSErr> {
        self.check_pos(pos)?;
        return Ok((self.bits[pos >> 5] >> (pos & 31)) & 1);
    }

    pub fn set_bit(&mut self, pos: usize) -> Result<(), BFSErr> {
        self.check_pos(pos)?;
        self.bits[pos >> 5] |= 1 << (pos & 31);
        return Ok(());
    }

    pub fn clear_bit(&mut self, pos: usize) -> Result<(), BFSErr> {
        self.check_pos(pos)?;
        self.bits[pos >> 5] &= !(1 << (pos & 31));
        return Ok(());
    }

    fn bit_mask(head: usize, tail: usize) -> Result<u32, BFSErr> {
        dcheck!(head < 32, BFSErr::OutOfBounds);
        dcheck!(tail < 32, BFSErr::OutOfBounds);

        let a = 0xFFFFFFFF << head;
        let b = 0xFFFFFFFF >> (32 - tail - 1);
        return Ok(a & b);
    }

    pub fn set_range(&mut self, begin: usize, end: usize) -> Result<(), BFSErr> {
        self.check_range(begin, end)?;

        let block_head = begin >> 5;
        let block_tail = (end - 1) >> 5;
        let bit_head = begin & 31;
        let bit_tail = (end - 1) & 31;

        if block_head == block_tail {
            let mask = BitVector::bit_mask(bit_head, bit_tail)?;
            self.bits[block_head as usize] |= mask;
        } else {
            let mask_head = BitVector::bit_mask(bit_head, 31)?;
            self.bits[block_head as usize] |= mask_head;

            for i in block_head + 1..block_tail {
                self.bits[i as usize] = 0xFFFFFFFF;
            }

            let mask_tail = BitVector::bit_mask(0, bit_tail)?;
            self.bits[block_tail as usize] |= mask_tail;
        }

        return Ok(());
    }

    pub fn clear_range(&mut self, begin: usize, end: usize) -> Result<(), BFSErr> {
        self.check_range(begin, end)?;

        let block_head = begin >> 5;
        let block_tail = (end - 1) >> 5;
        let bit_head = begin & 31;
        let bit_tail = (end - 1) & 31;

        if block_head == block_tail {
            let mask = BitVector::bit_mask(bit_head, bit_tail)?;
            self.bits[block_head] &= !mask;
        } else {
            let mask_head = BitVector::bit_mask(bit_head, 31)?;
            self.bits[block_head] &= !mask_head;

            for i in block_head + 1..block_tail {
                self.bits[i] = 0x00000000;
            }

            let mask_tail = BitVector::bit_mask(0, bit_tail)?;
            self.bits[block_tail] &= !mask_tail;
        }

        return Ok(());
    }

    pub fn count_range(&mut self, begin: usize, end: usize) -> Result<usize, BFSErr> {
        self.check_range(begin, end)?;

        let block_head = begin >> 5;
        let block_tail = (end - 1) >> 5;
        let bit_head = begin & 31;
        let bit_tail = (end - 1) & 31;

        let mut count: usize = 0;

        if block_head == block_tail {
            let mask = BitVector::bit_mask(bit_head, bit_tail)?;
            count += (self.bits[block_head] & mask).count_ones() as usize;
        } else {
            let mask_head = BitVector::bit_mask(bit_head, 31)?;
            count += (self.bits[block_head] & mask_head).count_ones() as usize;

            for i in block_head + 1..block_tail {
                count += self.bits[i].count_ones() as usize;
            }

            let mask_tail = BitVector::bit_mask(0, bit_tail)?;
            count += (self.bits[block_tail] & mask_tail).count_ones() as usize;
        }

        return Ok(count);
    }

    pub fn find_hole(&self, mut begin: usize, end: usize, width: usize) -> Result<usize, BFSErr> {
        self.check_range(begin, end)?;
        dcheck!(width <= end - begin, BFSErr::NotFound);

        let mut skip = false;
        let mut block_head = begin >> 5;
        let block_tail = (end - 1) >> 5;

        while (block_head <= block_tail) && (self.bits[block_head] == 0xFFFFFFFF) {
            skip = true;
            block_head = block_head + 1;
        }

        if block_head > block_tail {
            return Err(BFSErr::NotFound);
        }

        if skip {
            begin = block_head << 5;
        }

        for mut i in begin..=(end - width) {
            skip = false;
            for j in (i..i + width).rev() {
                if self.get_bit(j)? == 1 {
                    i = j + 1;
                    skip = true;
                    break;
                }
            }
            if !skip {
                return Ok(i);
            }
        }

        return Err(BFSErr::NotFound);
    }

    pub fn find_span(&self, mut begin: usize, end: usize, width: usize) -> Result<usize, BFSErr> {
        self.check_range(begin, end)?;
        dcheck!(width <= end - begin, BFSErr::NotFound);

        let mut skip = false;
        let mut block_head = begin >> 5;
        let block_tail = (end - 1) >> 5;

        while (block_head <= block_tail) && (self.bits[block_head] == 0x00000000) {
            skip = true;
            block_head = block_head + 1;
        }

        if block_head > block_tail {
            return Err(BFSErr::NotFound);
        }

        if skip {
            begin = block_head << 5;
        }

        for mut i in begin..=(end - width) {
            skip = false;
            for j in (i..i + width).rev() {
                if self.get_bit(j)? == 0 {
                    i = j + 1;
                    skip = true;
                    break;
                }
            }
            if !skip {
                return Ok(i);
            }
        }

        return Err(BFSErr::NotFound);
    }
}

/// For every vector size up to 96 bits, and every hole width and position that fits,
/// build a vector whose bits are set everywhere except inside the hole.
///
/// Verify three things:
/// - `find_hole` finds the punched hole at its exact offset.
/// - It also finds a smaller hole.
/// - It fails to find a hole larger than the one punched.
///
/// (This test is a bit slow, so the maximum bit vector size is limited to 96 and the
/// tests in this crate are run in optimized mode.)
#[test]
fn test_find_hole() {
    for size in 1..=96usize {
        let block_count = (size + 31) / 32;
        let mut bits: Vec<u32> = vec![0xFFFFFFFF as u32; block_count];
        let mut bit_vec = BitVector::new(bits.as_mut());
        for width in 1..size {
            for begin in 0..=(size - width) {
                let end = begin + width;

                // Punch a hole in the bit vector.
                assert_ok!(bit_vec.clear_range(begin, end));

                // The punched hole is the only clear run, so it must be found exactly at `begin`.
                assert_eq!(bit_vec.find_hole(0, size, width).ok(), Some(begin));

                // We should be able to find a hole smaller than the one we punched.
                assert_ok!(bit_vec.find_hole(0, size, width - 1));

                // If we look for a hole larger than the one we punched, we should find nothing.
                assert_err!(bit_vec.find_hole(0, size, width + 1));

                // Fill the hole back up.
                bit_vec.set_range(begin, end).unwrap();

                // We should no longer be able to find it.
                assert_err!(bit_vec.find_hole(0, size, width));
            }
        }
    }
}

/// For every vector size up to 96 bits, and every span width and position that fits,
/// build a vector whose bits are clear everywhere except inside the span.
///
/// Verify three things:
/// - `find_span` finds the span at its exact offset.
/// - It also finds a narrower span.
/// - It fails to find a span wider than the one created.
#[test]
fn test_find_span() {
    for size in 1..=96usize {
        let block_count = (size + 31) / 32;
        let mut bits: Vec<u32> = vec![0x00000000 as u32; block_count];
        let mut bit_vec = BitVector::new(bits.as_mut());
        for width in 1..size {
            for begin in 0..=(size - width) {
                let end = begin + width;

                // Create a span of set bits in the bit vector.
                assert_ok!(bit_vec.set_range(begin, end));

                // The created span is the only set run, so it must be found exactly at `begin`.
                assert_eq!(bit_vec.find_span(0, size, width).ok(), Some(begin));

                // We should be able to find a span narrower than the one we created.
                assert_ok!(bit_vec.find_span(0, size, width - 1));

                // If we look for a span wider than the one we created, we should find nothing.
                assert_err!(bit_vec.find_span(0, size, width + 1));

                // Erase the span.
                bit_vec.clear_range(begin, end).unwrap();

                // We should no longer be able to find it.
                assert_err!(bit_vec.find_span(0, size, width));
            }
        }
    }
}
