Fork of François Berder's Crypto library, with AES CBC fixed and some small rework

Dependents:   AES_example shaun_larada Smartage

Fork of Crypto by François Berder

Revision:
3:85c6ee25cf3e
Parent:
2:473bac39ae7c
Child:
4:0da19393bd57
--- a/SHA2_32.cpp	Mon Sep 09 16:16:24 2013 +0000
+++ b/SHA2_32.cpp	Wed Sep 11 17:22:40 2013 +0000
@@ -23,47 +23,19 @@
     0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
 };
 
-static uint32_t rotLeft(uint32_t w, uint8_t n)
-{
-    return (w << n) | (w >> (32-n));
-}
-
-static uint32_t rotRight(uint32_t w, uint8_t n)
-{
-    return rotLeft(w,32-n);
-}
-
-static uint32_t CH(uint32_t x, uint32_t y, uint32_t z)
-{
-    return (x & y) ^ ((~x) & z);
-}
-
-static uint32_t MAJ(uint32_t x, uint32_t y, uint32_t z)
-{
-    return (x & y) ^ (x & z) ^ (y & z);
-}
-
-static uint32_t BSIG0(uint32_t x)
-{
-    return rotRight(x,2) ^ rotRight(x,13) ^ rotRight(x,22);
-}
-
-static uint32_t BSIG1(uint32_t x)
-{
-    return rotRight(x,6) ^ rotRight(x,11) ^ rotRight(x,25);
-}
-
-static uint32_t SSIG0(uint32_t x)
-{
-    return rotRight(x,7) ^ rotRight(x,18) ^ (x >> 3);
-}
- 
-static uint32_t SSIG1(uint32_t x)
-{
-    return rotRight(x,17) ^ rotRight(x,19) ^ (x >> 10);
-}
-
-
+#define ROTL(W,N) (((W) << (N)) | ((W) >> (32-(N))))
+#define ROTR(W,N) (((W) >> (N)) | ((W) << (32-(N))))
+#define CH(X,Y,Z) (((X) & (Y)) ^ ((~(X)) & (Z)))
+#define MAJ(X,Y,Z) (((X) & (Y)) ^ ((X) & (Z)) ^ ((Y) & (Z)))
+#define BSIG0(X) (ROTR(X,2) ^ ROTR(X,13) ^ ROTR(X,22))
+#define BSIG1(X) (ROTR(X,6) ^ ROTR(X,11) ^ ROTR(X,25))
+#define SSIG0(X) (ROTR((X),7) ^ ROTR((X),18) ^ ((X) >> 3))
+#define SSIG1(X) (ROTR((X),17) ^ ROTR((X),19) ^ ((X) >> 10))
+#define R(A,B,C,D,E,F,G,H,T)  T1 = H + BSIG1(E) + CH(E,F,G) + K[T] + w[T]; \
+                              T2 = BSIG0(A) + MAJ(A,B,C); \
+                              D += T1; \
+                              H = T1 + T2;
+        
 static const uint32_t H[] =
 {
     // SHA-224
@@ -138,62 +110,46 @@
         padding = 56 - (totalBufferLength % 64);
     else
         padding = 56 + (64 - (totalBufferLength % 64));
-    uint8_t val = 0x80;
-    add(&val, 1);
-    val = 0;
-    for(int i = 0; i < padding-1; ++i)
-        add(&val,1);
-    totalBufferLength -= padding;
-    uint64_t lengthBit = totalBufferLength * 8;
+
+    buffer[bufferLength++] = 0x80;
+    padding--;
+    if(padding+bufferLength == 56)
+        memset(&buffer[bufferLength], 0, padding);
+    else
+    {
+        memset(&buffer[bufferLength], 0, 64-bufferLength);
+        computeBlock(&h0, &h1, &h2, &h3, &h4, &h5, &h6, &h7, buffer);
+        memset(buffer, 0, bufferLength);
+    }
+    
+    uint64_t lengthBit = totalBufferLength << 3;
     uint32_t lengthBitLow = lengthBit;
     uint32_t lengthBitHigh = lengthBit >> 32;
-    uint8_t tmp[4];
-    tmp[0] = lengthBitHigh >> 24;
-    tmp[1] = lengthBitHigh >> 16;
-    tmp[2] = lengthBitHigh >> 8;
-    tmp[3] = lengthBitHigh;
-    add(tmp, 4);
-    tmp[0] = lengthBitLow >> 24;
-    tmp[1] = lengthBitLow >> 16;
-    tmp[2] = lengthBitLow >> 8;
-    tmp[3] = lengthBitLow;
-    add(tmp, 4);
+    lengthBitLow = __rev(lengthBitLow);
+    lengthBitHigh = __rev(lengthBitHigh);
+    memcpy(&buffer[60], &lengthBitLow, 4);    
+    memcpy(&buffer[56], &lengthBitHigh, 4);    
+    computeBlock(&h0, &h1, &h2, &h3, &h4, &h5, &h6, &h7, buffer);
 
-    digest[0] = h0 >> 24;
-    digest[1] = h0 >> 16;
-    digest[2] = h0 >> 8;
-    digest[3] = h0;
-    digest[4] = h1 >> 24;
-    digest[5] = h1 >> 16;
-    digest[6] = h1 >> 8;
-    digest[7] = h1;
-    digest[8] = h2 >> 24;
-    digest[9] = h2 >> 16;
-    digest[10] = h2 >> 8;
-    digest[11] = h2;
-    digest[12] = h3 >> 24;
-    digest[13] = h3 >> 16;
-    digest[14] = h3 >> 8;
-    digest[15] = h3;
-    digest[16] = h4 >> 24;
-    digest[17] = h4 >> 16;
-    digest[18] = h4 >> 8;
-    digest[19] = h4;
-    digest[20] = h5 >> 24;
-    digest[21] = h5 >> 16;
-    digest[22] = h5 >> 8;
-    digest[23] = h5;
-    digest[24] = h6 >> 24;
-    digest[25] = h6 >> 16;
-    digest[26] = h6 >> 8;
-    digest[27] = h6;
-
+    h0 = __rev(h0);
+    h1 = __rev(h1);
+    h2 = __rev(h2);
+    h3 = __rev(h3);
+    h4 = __rev(h4);
+    h5 = __rev(h5);
+    h6 = __rev(h6);
+    memcpy(digest, &h0, 4);
+    memcpy(&digest[4], &h1, 4);
+    memcpy(&digest[8], &h2, 4);
+    memcpy(&digest[12], &h3, 4);
+    memcpy(&digest[16], &h4, 4);
+    memcpy(&digest[20], &h5, 4);
+    memcpy(&digest[24], &h6, 4);
+    
     if(type == SHA_256)
     {
-        digest[28] = h7 >> 24;
-        digest[29] = h7 >> 16;
-        digest[30] = h7 >> 8;
-        digest[31] = h7;
+        h7 = __rev(h7);
+        memcpy(&digest[28], &h7, 4);
     }
     
     // reset state
@@ -236,27 +192,102 @@
                         uint8_t *buffer)
 {
     uint32_t w[64];
-    for(int t = 0; t < 16; ++t)
-    {
-        w[t] = (buffer[t*4] << 24) | (buffer[t*4+1] << 16) | (buffer[t*4+2] << 8) | buffer[t*4+3]; 
-    }
+    uint32_t *buffer2 = (uint32_t*)buffer;
+    w[0] = __rev(buffer2[0]);
+    w[1] = __rev(buffer2[1]);
+    w[2] = __rev(buffer2[2]);
+    w[3] = __rev(buffer2[3]);
+    w[4] = __rev(buffer2[4]);
+    w[5] = __rev(buffer2[5]);
+    w[6] = __rev(buffer2[6]);
+    w[7] = __rev(buffer2[7]);
+    w[8] = __rev(buffer2[8]);
+    w[9] = __rev(buffer2[9]);
+    w[10] = __rev(buffer2[10]);
+    w[11] = __rev(buffer2[11]);
+    w[12] = __rev(buffer2[12]);
+    w[13] = __rev(buffer2[13]);
+    w[14] = __rev(buffer2[14]);
+    w[15] = __rev(buffer2[15]);
+
     for(int t = 16; t < 64; ++t)
         w[t] = SSIG1(w[t-2]) + w[t-7] + SSIG0(w[t-15]) + w[t-16];
     
-     uint32_t a = *h02, b = *h12, c = *h22, d = *h32, e = *h42, f = *h52, g = *h62, h = *h72;
-    for(int t = 0; t < 64; ++t)
-    {
-        uint32_t T1 = h + BSIG1(e) + CH(e,f,g) + K[t] + w[t];
-        uint32_t T2 = BSIG0(a) + MAJ(a,b,c);
-        h = g;
-        g = f;
-        f = e;
-        e = d + T1;
-        d = c;
-        c = b;
-        b = a;
-        a = T1 + T2;
-    }
+    uint32_t a = *h02, b = *h12, c = *h22, d = *h32, e = *h42, f = *h52, g = *h62, h = *h72;
+    uint32_t T1, T2;
+    
+    R(a,b,c,d,e,f,g,h,0)
+    R(h,a,b,c,d,e,f,g,1)
+    R(g,h,a,b,c,d,e,f,2)
+    R(f,g,h,a,b,c,d,e,3)
+    R(e,f,g,h,a,b,c,d,4)
+    R(d,e,f,g,h,a,b,c,5)
+    R(c,d,e,f,g,h,a,b,6)
+    R(b,c,d,e,f,g,h,a,7)
+
+    R(a,b,c,d,e,f,g,h,8)
+    R(h,a,b,c,d,e,f,g,9)
+    R(g,h,a,b,c,d,e,f,10)
+    R(f,g,h,a,b,c,d,e,11)
+    R(e,f,g,h,a,b,c,d,12)
+    R(d,e,f,g,h,a,b,c,13)
+    R(c,d,e,f,g,h,a,b,14)
+    R(b,c,d,e,f,g,h,a,15)
+    
+    R(a,b,c,d,e,f,g,h,16)
+    R(h,a,b,c,d,e,f,g,17)
+    R(g,h,a,b,c,d,e,f,18)
+    R(f,g,h,a,b,c,d,e,19)
+    R(e,f,g,h,a,b,c,d,20)
+    R(d,e,f,g,h,a,b,c,21)
+    R(c,d,e,f,g,h,a,b,22)
+    R(b,c,d,e,f,g,h,a,23)
+    
+    R(a,b,c,d,e,f,g,h,24)
+    R(h,a,b,c,d,e,f,g,25)
+    R(g,h,a,b,c,d,e,f,26)
+    R(f,g,h,a,b,c,d,e,27)
+    R(e,f,g,h,a,b,c,d,28)
+    R(d,e,f,g,h,a,b,c,29)
+    R(c,d,e,f,g,h,a,b,30)
+    R(b,c,d,e,f,g,h,a,31) 
+    
+    R(a,b,c,d,e,f,g,h,32)
+    R(h,a,b,c,d,e,f,g,33)
+    R(g,h,a,b,c,d,e,f,34)
+    R(f,g,h,a,b,c,d,e,35)
+    R(e,f,g,h,a,b,c,d,36)
+    R(d,e,f,g,h,a,b,c,37)
+    R(c,d,e,f,g,h,a,b,38)
+    R(b,c,d,e,f,g,h,a,39)
+    
+    R(a,b,c,d,e,f,g,h,40)
+    R(h,a,b,c,d,e,f,g,41)
+    R(g,h,a,b,c,d,e,f,42)
+    R(f,g,h,a,b,c,d,e,43)
+    R(e,f,g,h,a,b,c,d,44)
+    R(d,e,f,g,h,a,b,c,45)
+    R(c,d,e,f,g,h,a,b,46)
+    R(b,c,d,e,f,g,h,a,47)
+
+    R(a,b,c,d,e,f,g,h,48)
+    R(h,a,b,c,d,e,f,g,49)
+    R(g,h,a,b,c,d,e,f,50)
+    R(f,g,h,a,b,c,d,e,51)
+    R(e,f,g,h,a,b,c,d,52)
+    R(d,e,f,g,h,a,b,c,53)
+    R(c,d,e,f,g,h,a,b,54)
+    R(b,c,d,e,f,g,h,a,55)
+    
+    R(a,b,c,d,e,f,g,h,56)
+    R(h,a,b,c,d,e,f,g,57)
+    R(g,h,a,b,c,d,e,f,58)
+    R(f,g,h,a,b,c,d,e,59)
+    R(e,f,g,h,a,b,c,d,60)
+    R(d,e,f,g,h,a,b,c,61)
+    R(c,d,e,f,g,h,a,b,62)
+    R(b,c,d,e,f,g,h,a,63)
+    
     
     *h02 += a;
     *h12 += b;
@@ -272,35 +303,32 @@
 {
     uint32_t h0 = H[type*8], h1 = H[type*8+1], h2 = H[type*8+2], h3 = H[type*8+3];
     uint32_t h4 = H[type*8+4], h5 = H[type*8+5], h6 = H[type*8+6], h7 = H[type*8+7];
-    int offset = 0;
-    while(length - offset >= 64)
-    {
-        computeBlock(&h0, &h1, &h2, &h3, &h4, &h5, &h6, &h7, &in[offset]);
-        offset += 64;
-    }
-    uint8_t bufferLength = length-offset;
-    uint8_t buffer[64];
-    memcpy(buffer, &in[offset],bufferLength); 
+    uint64_t lengthBit = length << 3;
     uint16_t padding;
     if(length % 64 < 56)
         padding = 56 - (length % 64);
     else
         padding = 56 + (64 - (length % 64));
-    buffer[bufferLength] = 0x80;
-    bufferLength++;
-    padding--;
-    while(padding > 0)
+        
+    while(length >= 64)
     {
-        if(bufferLength == 64)
-        {
-            computeBlock(&h0, &h1, &h2, &h3, &h4, &h5, &h6, &h7, buffer);
-            bufferLength = 0;
-        }
-        buffer[bufferLength] = 0;
-        bufferLength++;
-        padding--;
+        computeBlock(&h0, &h1, &h2, &h3, &h4, &h5, &h6, &h7, in);
+        length -= 64;
+        in += 64;
     }
-    uint64_t lengthBit = length * 8;
+    uint8_t buffer[64];
+    memcpy(buffer, in,length); 
+    buffer[length++] = 0x80;
+    padding--;
+    if(padding+length == 56)
+        memset(&buffer[length], 0, padding);
+    else
+    {
+        memset(&buffer[length], 0, 64-length);
+        computeBlock(&h0, &h1, &h2, &h3, &h4, &h5, &h6, &h7, buffer);
+        memset(buffer, 0, length);
+    }
+    
     uint32_t lengthBitLow = lengthBit;
     uint32_t lengthBitHigh = lengthBit >> 32;
     lengthBitLow = __rev(lengthBitLow);