#include "BlockCipher.h"
#include <string.h>

/**
 * Constructs a block cipher front-end.
 *
 * @param bs  Block size in bytes; every scratch buffer is allocated with
 *            this many bytes.
 * @param m   Chaining mode. In ECB_MODE no buffers are allocated and the
 *            iv argument is ignored.
 * @param iv  Initialisation vector of at least bs bytes, copied into an
 *            internal buffer for every mode other than ECB_MODE. When it is
 *            null the internal IV is zero-filled instead, so the object is
 *            in a defined state until setIV() is called.
 */
BlockCipher::BlockCipher(uint32_t bs, BLOCK_CIPHER_MODE m, uint8_t *iv):
Cipher(),
blockSize(bs),
mode(m),
IV(0),
tmpIV(0),
tmpdatain(0),
tmpdata(0)
{
    // All four pointers start out null so that the destructor never sees an
    // indeterminate value when ECB_MODE skips the allocations below.
    if(mode != ECB_MODE)
    {
        IV = new uint8_t[blockSize];
        tmpIV = new uint8_t[blockSize];
        tmpdatain = new uint8_t[blockSize];
        tmpdata = new uint8_t[blockSize];
        if(iv != 0)
        {
            memcpy(IV, iv, blockSize);
        }
        else
        {
            // memcpy from a null source is undefined; fall back to zeros.
            memset(IV, 0, blockSize);
        }
    }
}

/**
 * Releases the IV and the scratch buffers owned by this object.
 *
 * Each pointer must be either null or the result of new[].
 */
BlockCipher::~BlockCipher()
{
    // delete[] on a null pointer is a no-op, so no explicit checks are needed.
    delete[] IV;
    delete[] tmpIV;
    delete[] tmpdatain;
    delete[] tmpdata;
}

/**
 * Reports the cipher family of this object.
 *
 * @return Always BLOCK_CIPHER.
 */
CIPHER_TYPE BlockCipher::getType() const
{
    const CIPHER_TYPE family = BLOCK_CIPHER;
    return family;
}

uint32_t BlockCipher::getBlockSize() const
{
    return blockSize;
}

/**
 * Encrypts length bytes from in into out using the configured chaining mode.
 *
 * The data is processed in steps of blockSize bytes. Each step reads and
 * writes a whole block, so length is expected to be a multiple of blockSize;
 * otherwise the final step touches bytes beyond length in both buffers.
 *
 * - ECB_MODE:  every block is passed to encryptBlock() on its own.
 * - CBC_MODE:  each input block is XORed with the previous output block
 *              (the stored IV for the first block) before encryptBlock().
 * - PCBC_MODE: as CBC, but the value chained into the next block is the
 *              output block XORed with the original input block.
 *
 * Any other mode value leaves out untouched. The stored IV itself is not
 * modified; chaining works on a private copy.
 *
 * @param out     Destination buffer.
 * @param in      Source buffer. In the chaining modes the input block is
 *                copied to scratch storage before out is written.
 * @param length  Number of bytes to process.
 */
void BlockCipher::encrypt(uint8_t *out, uint8_t *in, uint32_t length)
{
    // With a zero block size the loops below would never advance.
    if(blockSize == 0) return;

    switch (mode)
    {
        case ECB_MODE:
            for(uint32_t i = 0; i < length; i += blockSize)
            {
                encryptBlock(out+i, in+i);
            }
            break;
        case PCBC_MODE:
        case CBC_MODE:
            // Work on a copy so the configured IV survives this call.
            memcpy(tmpIV, IV, blockSize);
            for(uint32_t i = 0; i < length; i += blockSize)
            {
                // PCBC needs the untouched input block after out+i has been
                // written, so keep a copy of it first.
                if(mode==PCBC_MODE) memcpy(tmpdata, in+i, blockSize);
                memcpy(tmpdatain, in+i, blockSize);
                for(uint32_t j = 0; j < blockSize; ++j) tmpdatain[j] ^= tmpIV[j];
                encryptBlock(out+i, tmpdatain);
                // The output block becomes the chaining value for the next block.
                memcpy(tmpIV, out+i, blockSize);
                if(mode==PCBC_MODE)
                {
                    for(uint32_t j = 0; j < blockSize; ++j) tmpIV[j] ^= tmpdata[j];
                }
            }
            break;
    }
}

/**
 * Decrypts length bytes from in into out using the configured chaining mode.
 *
 * The data is processed in steps of blockSize bytes. Each step reads and
 * writes a whole block, so length is expected to be a multiple of blockSize;
 * otherwise the final step touches bytes beyond length in both buffers.
 *
 * - ECB_MODE:  every block is passed to decryptBlock() on its own.
 * - CBC_MODE:  each block is run through decryptBlock() and then XORed with
 *              the previous input block (the stored IV for the first block).
 * - PCBC_MODE: as CBC, but the value chained into the next block is the
 *              input block XORed with the recovered output block.
 *
 * Any other mode value leaves out untouched. The stored IV itself is not
 * modified; chaining works on a private copy.
 *
 * @param out     Destination buffer.
 * @param in      Source buffer. In the chaining modes the input block is
 *                copied to scratch storage before out is written.
 * @param length  Number of bytes to process.
 */
void BlockCipher::decrypt(uint8_t *out, uint8_t *in, uint32_t length)
{
    // With a zero block size the loops below would never advance.
    if(blockSize == 0) return;

    switch (mode)
    {
        case ECB_MODE:
            for(uint32_t i = 0; i < length; i += blockSize)
            {
                decryptBlock(out+i, in+i);
            }
            break;
        case PCBC_MODE:
        case CBC_MODE:
            // Work on a copy so the configured IV survives this call.
            memcpy(tmpIV, IV, blockSize);
            for(uint32_t i = 0; i < length; i += blockSize)
            {
                // Keep the input block: it is needed as the next chaining
                // value after out+i has been written.
                memcpy(tmpdatain, in+i, blockSize);
                decryptBlock(out+i, tmpdatain);
                for(uint32_t j = 0; j < blockSize; ++j) out[i+j] ^= tmpIV[j];
                memcpy(tmpIV, tmpdatain, blockSize);
                if(mode==PCBC_MODE)
                {
                    for(uint32_t j = 0; j < blockSize; ++j) tmpIV[j] ^= out[i+j];
                }
            }
            break;
    }
}

/**
 * Replaces the stored initialisation vector.
 *
 * Does nothing when no IV buffer exists (the constructor allocates one only
 * for modes other than ECB_MODE) or when iv is null; in the latter case the
 * current IV is kept.
 *
 * @param iv  New initialisation vector; blockSize bytes are copied from it.
 */
void BlockCipher::setIV(uint8_t *iv)
{
    // memcpy with a null source or destination is undefined behaviour.
    if(IV == 0 || iv == 0) return;
    memcpy(IV, iv, blockSize);
}
