/* 
 * File:   bit_storage.h
 * Author: jdaniels
 *
 * Created on September 22, 2015, 7:24 AM
 */

#ifndef BIT_STORAGE_H
#define BIT_STORAGE_H
#include <limits>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/*
 * ELEMENTS_NEEDED(bits, T): number of T objects required to hold `bits` bits,
 * i.e. the ceiling of bits / (sizeof(T) * 8). Yields a size_t.
 */
#define ELEMENTS_NEEDED(bits,TUnitOfStorage) \
    ( size_t((bits) / (sizeof(TUnitOfStorage) * 8)) \
      + size_t(((bits) % (sizeof(TUnitOfStorage) * 8)) != 0) )

/* BYTES_NEEDED(bits): number of bytes required to hold `bits` bits. */
#define BYTES_NEEDED(bits) (ELEMENTS_NEEDED(bits,uint8_t))

template<
    size_t BitsOfData, 
    class TUnitOfStorage=uint8_t,
    bool MSBInIndexZero=true>
struct bit_storage_t
{
    /*
     * Fixed-size container of BitsOfData bits packed into TUnitOfStorage elements.
     *
     * Bit numbering: bit 0 is the least significant bit. Bit b is stored at
     * position (b % bits_per_element()) of element number (b / bits_per_element())
     * counted from the least significant element. When MSBInIndexZero is true
     * the element order is reversed, so _elements[0] holds the most significant
     * bits.
     *
     * TUnitOfStorage is expected to be an unsigned integer type; the masks and
     * shifts below rely on unsigned arithmetic.
     *
     * If BitsOfData is not a multiple of bits_per_element(), the most
     * significant element has unused high ("padding") bits. shift_left() and
     * shift_right() keep padding bits from leaking into the data; set_all(),
     * toggle_all() and raw element/byte access can still set them.
     */

    typedef bit_storage_t<BitsOfData,TUnitOfStorage,MSBInIndexZero> data_type;
    typedef TUnitOfStorage element_type;

    private:

    // ceil(BitsOfData / bits-per-element). Same formula as ELEMENTS_NEEDED,
    // spelled out so the type does not depend on a file-level macro.
    element_type _elements[
        BitsOfData / (sizeof(TUnitOfStorage) * 8)
        + (BitsOfData % (sizeof(TUnitOfStorage) * 8) != 0 ? 1 : 0)];

    public:

    // Element-sized mask with the low `bits` bits set (1).
    // bits == 0 (the default) or bits >= bits_per_element() yields a mask with
    // every bit of the element set.
    static element_type element_mask(size_t bits=0) {
        const element_type all_bits = std::numeric_limits<element_type>::max();
        if (bits == 0 || bits >= bits_per_element())
            return all_bits;
        // The shift is relative to the element width, not to BitsOfData.
        return element_type(all_bits >> (bits_per_element() - bits));
    }

    // Number of bits in one storage element.
    static size_t bits_per_element() { return sizeof(element_type) * 8; }

    // Reference to the element at raw array position `index` (unchecked).
    element_type& element(const size_t& index) {
        return _elements[index];
    }

    // Array position of the element that holds `bit`.
    size_t element_index_of(const size_t& bit) const {
        const size_t from_lsb = bit / bits_per_element();
        return MSBInIndexZero ? (element_count() - 1) - from_lsb : from_lsb;
    }

    // Reference to the element that holds `bit`.
    element_type& element_for(const size_t& bit) {
        return element(element_index_of(bit));
    }

    // Pointer to the first storage element.
    element_type* elements() { return _elements; }

    // Reference to the byte at raw position `index` of the storage (unchecked).
    uint8_t& byte(const size_t& index) {
        return bytes()[index];
    }

    // Byte position of the byte that holds `bit`, mirroring element_index_of()
    // at byte granularity.
    // NOTE(review): for multi-byte element types the order of bytes inside each
    // element depends on the platform's endianness, so this only agrees with the
    // element layout for uint8_t elements (or a matching byte order) - confirm
    // before relying on it with wider elements.
    size_t byte_index_of(const size_t& bit) const {
        const size_t from_lsb = bit / 8;
        return MSBInIndexZero ? (byte_count() - 1) - from_lsb : from_lsb;
    }

    // Reference to the byte that holds `bit` (see byte_index_of()).
    uint8_t& byte_for(const size_t& bit) {
        return byte(byte_index_of(bit));
    }

    // The raw storage viewed as bytes.
    uint8_t* bytes() { return reinterpret_cast<uint8_t*>(_elements); }

    // Number of data bits (BitsOfData).
    size_t bit_size() const { return BitsOfData; }

    // Number of storage elements.
    size_t element_count() const {
        return sizeof(_elements) / sizeof(element_type);
    }

    size_t element_size() const { return sizeof(element_type); }
    size_t element_bit_size() const { return element_size() * 8; }
    size_t byte_count() const { return sizeof(_elements); }

    // Single-bit mask selecting `bit` within its element.
    element_type element_mask_for(size_t bit) const {
        return element_type(element_type(1) << (bit % bits_per_element()));
    }

    // Sets `bit` to `value` (default true).
    void set(const size_t& bit, const bool& value=true)
    {
        if (value)
            element_for(bit) |= element_mask_for(bit);
        else
            clear(bit);
    }

    // Inverts `bit`.
    void toggle(const size_t& bit)
    {
        element_for(bit) ^= element_mask_for(bit);
    }

    // Sets `bit` to 0.
    void clear(const size_t& bit) {
        element_for(bit) &= element_type(~element_mask_for(bit));
    }

    // Returns the value of `bit`.
    bool get(const size_t& bit) const
    {
        return (_elements[element_index_of(bit)] & element_mask_for(bit)) != 0;
    }

    // Sets every storage bit (including any padding bits) to 1.
    void set_all() {
        const element_type mask = element_mask();
        for (size_t idx = 0; idx < element_count(); idx++)
            _elements[idx] = mask;
    }

    // Inverts every storage bit (including any padding bits).
    void toggle_all() {
        const element_type mask = element_mask();
        for (size_t idx = 0; idx < element_count(); idx++)
            _elements[idx] ^= mask;
    }

    // Sets every storage bit to 0.
    void clear_all() {
        memset(_elements, 0x00, sizeof(_elements));
    }

    // True when the element-sized window addressed by (bit, is_msb) extends
    // past the data: below bit 0 when `bit` is the window's MSB, or past
    // BitsOfData - 1 when `bit` is the window's LSB.
    bool calc_window_endpoint_crossed(const size_t& bit, const bool& is_msb=true) const
    {
        if (is_msb)
            return bit < (element_bit_size() - 1);
        // Written without subtraction from BitsOfData to avoid unsigned underflow
        // when BitsOfData < element_bit_size().
        return bit + (element_bit_size() - 1) >= BitsOfData;
    }

    // Element index holding the most significant bit of the window addressed
    // by (bit, is_msb). When an LSB-anchored window runs past the data, the
    // element of `bit` itself is returned.
    size_t msb_index_for(const size_t& bit, const bool& is_msb=true) const
    {
        const size_t element_high_bit = element_bit_size() - 1;
        const size_t msb = (is_msb || (bit + element_high_bit >= BitsOfData)) ?
                           bit : bit + element_high_bit;
        return element_index_of(msb);
    }

    // Element index holding the least significant bit of the window addressed
    // by (bit, is_msb). When an MSB-anchored window would reach below bit 0,
    // the element of `bit` itself is returned.
    size_t lsb_index_for(const size_t& bit, const bool& is_msb=false) const
    {
        size_t lsb = bit;
        if (is_msb && bit >= element_bit_size())
            lsb = bit - (element_bit_size() - 1);
        return element_index_of(lsb);
    }

    /**
        Retrieve an element-sized window of bits.

        The window is element_bit_size() bits wide and ends at `bit` (is_msb
        true) or starts at `bit` (is_msb false). It is clipped to the data
        bits [0, BitsOfData); the clipped bits are returned right-aligned, i.e.
        the window's lowest in-range bit lands in bit 0 of the result.

            @param bit the anchoring bit; expected to be < BitsOfData
            @param is_msb true if `bit` is the most significant bit of the window
            @return the window's bits, or 0 when `bit` is out of range
    */
    element_type get_window(const size_t& bit, const bool& is_msb=true) const {
        size_t lo = 0;
        size_t count = 0;
        if (!_window_bounds(bit, is_msb, lo, count))
            return element_type(0);
        return _read_bits(lo, count);
    }

    /**
        Store `data` into an element-sized window of bits; the inverse of
        get_window() for the same (bit, is_msb).

        The low bits of `data` are written to the clipped window starting at its
        lowest in-range bit; bits of `data` that fall outside the data range
        are ignored, and all storage bits outside the window are preserved.

            @param bit the anchoring bit; expected to be < BitsOfData (otherwise no-op)
            @param data the bits to store
            @param is_msb true if `bit` is the most significant bit of the window
    */
    void set_window(const size_t& bit, const element_type& data, const bool& is_msb=true) {
        size_t lo = 0;
        size_t count = 0;
        if (!_window_bounds(bit, is_msb, lo, count))
            return;
        _write_bits(lo, count, data);
    }

    // Right shift by n = ws * element_bit_size() + s bits for MSBInIndexZero
    // storage (s1 == element_bit_size() - s). Data moves toward higher indices.
    // NOTE(review): names starting with "__" are reserved for the implementation;
    // kept as-is because they are publicly reachable.
    void __rs_msb_zero(const size_t& n, const int& ws, const int& s, const int& s1) {
        (void)n;
        int i = int(element_count()) - 1;
        if (s == 0) {
            for (; i >= ws; i--)
                _elements[i] = _elements[i - ws];
        }
        else {
            for (; i > ws; i--)
                _elements[i] = element_type(element_type(_elements[i - ws] >> s)
                                            | element_type(_elements[i - ws - 1] << s1));
            _elements[ws] = element_type(_elements[0] >> s);
        }
        for (i = ws - 1; i >= 0; i--)
            _elements[i] = 0;
    }

    // Right shift for LSB-in-index-zero storage: data moves toward lower indices.
    void __rs_lsb_zero(const size_t& n, const int& ws, const int& s, const int& s1) {
        (void)n;
        const int count = int(element_count());
        const int z = count - ws - 1;
        int i = 0;
        if (s == 0) {
            for (; i <= z; i++)
                _elements[i] = _elements[i + ws];
        }
        else {
            for (; i < z; i++)
                _elements[i] = element_type(element_type(_elements[i + ws] >> s)
                                            | element_type(_elements[i + ws + 1] << s1));
            _elements[z] = element_type(_elements[count - 1] >> s);
        }
        for (i = z + 1; i < count; i++)
            _elements[i] = 0;
    }

    // Left shift for LSB-in-index-zero storage: data moves toward higher indices.
    void __ls_lsb_zero(const size_t& n, const int& ws, const int& s, const int& s1) {
        (void)n;
        int i = int(element_count()) - 1;
        if (s == 0) {
            // Must include i == ws so that element 0 lands in element ws.
            for (; i >= ws; i--)
                _elements[i] = _elements[i - ws];
        }
        else {
            for (; i > ws; i--)
                _elements[i] = element_type(element_type(_elements[i - ws] << s)
                                            | element_type(_elements[i - ws - 1] >> s1));
            _elements[ws] = element_type(_elements[0] << s);
        }
        for (i = ws - 1; i >= 0; i--)
            _elements[i] = 0;
    }

    // Left shift for MSBInIndexZero storage: data moves toward lower indices.
    void __ls_msb_zero(const size_t& n, const int& ws, const int& s, const int& s1) {
        (void)n;
        const int count = int(element_count());
        const int z = count - ws - 1;
        int i = 0;
        if (s == 0) {
            for (; i <= z; i++)
                _elements[i] = _elements[i + ws];
        }
        else {
            for (; i < z; i++)
                _elements[i] = element_type(element_type(_elements[i + ws] << s)
                                            | element_type(_elements[i + ws + 1] >> s1));
            _elements[z] = element_type(_elements[count - 1] << s);
        }
        for (i = z + 1; i < count; i++)
            _elements[i] = 0;
    }

    // Shifts the data toward bit 0 by n bits, filling with zeros.
    void shift_right(const size_t& n) {
        if (n >= BitsOfData) {
            memset(_elements, 0, sizeof(_elements));
            return;
        }
        // Padding bits above BitsOfData must not shift down into the data.
        _clear_unused_bits();

        const int ws = int(n / element_bit_size());
        const int s = int(n % element_bit_size());
        const int s1 = int(element_bit_size()) - s;

        if (MSBInIndexZero)
            __rs_msb_zero(n, ws, s, s1);
        else
            __rs_lsb_zero(n, ws, s, s1);
    }

    // Shifts the data away from bit 0 by n bits, filling with zeros; bits
    // shifted past BitsOfData - 1 are discarded.
    void shift_left(const size_t& n) {
        if (n >= BitsOfData) {
            memset(_elements, 0, sizeof(_elements));
            return;
        }

        const int ws = int(n / element_bit_size());
        const int s = int(n % element_bit_size());
        const int s1 = int(element_bit_size()) - s;

        if (MSBInIndexZero)
            __ls_msb_zero(n, ws, s, s1);
        else
            __ls_lsb_zero(n, ws, s, s1);

        // Drop anything that was shifted into the padding bits.
        _clear_unused_bits();
    }

    private:

    // Computes the clipped window [lo, lo + count) addressed by (bit, is_msb).
    // Returns false when `bit` is outside the data, i.e. there is no window.
    bool _window_bounds(size_t bit, bool is_msb, size_t& lo, size_t& count) const {
        const size_t width = bits_per_element();
        if (bit >= BitsOfData)
            return false;
        if (is_msb) {
            // Window is [bit - (width - 1), bit]; clip at bit 0.
            if (bit >= width - 1) {
                lo = bit - (width - 1);
                count = width;
            }
            else {
                lo = 0;
                count = bit + 1;
            }
        }
        else {
            // Window is [bit, bit + width - 1]; clip at BitsOfData - 1.
            lo = bit;
            const size_t remaining = BitsOfData - bit;
            count = remaining < width ? remaining : width;
        }
        return true;
    }

    // Reads `count` (1..width) bits starting at data bit `lo`, right-aligned.
    // The field spans at most two elements.
    element_type _read_bits(size_t lo, size_t count) const {
        const size_t width = bits_per_element();
        const size_t offset = lo % width;
        element_type value = element_type(_elements[element_index_of(lo)] >> offset);
        if (offset + count > width) {
            // The field continues in the next more-significant element; offset > 0
            // here, so `taken` is a valid shift amount (< width).
            const size_t taken = width - offset;
            value = element_type(value
                                 | element_type(_elements[element_index_of(lo + taken)] << taken));
        }
        return element_type(value & element_mask(count));
    }

    // Writes the low `count` (1..width) bits of `data` starting at data bit
    // `lo`, preserving every other storage bit.
    void _write_bits(size_t lo, size_t count, element_type data) {
        const size_t width = bits_per_element();
        const size_t offset = lo % width;
        const element_type field = element_mask(count);

        element_type& low = _elements[element_index_of(lo)];
        const element_type low_mask = element_type(field << offset);
        low = element_type((low & element_type(~low_mask))
                           | (element_type(data << offset) & low_mask));

        if (offset + count > width) {
            const size_t taken = width - offset;
            element_type& high = _elements[element_index_of(lo + taken)];
            const element_type high_mask = element_type(field >> taken);
            high = element_type((high & element_type(~high_mask))
                                | (element_type(data >> taken) & high_mask));
        }
    }

    // Zeroes the padding bits above BitsOfData in the most significant element
    // (a no-op when BitsOfData is a multiple of the element width, because
    // element_mask(0) selects every bit).
    void _clear_unused_bits() {
        element_type& top = _elements[element_index_of(BitsOfData - 1)];
        top = element_type(top & element_mask(BitsOfData % bits_per_element()));
    }

};

#endif  /* BIT_STORAGE_H */

