This library provides a way to easily handle arbitrary large integers.

This library provides the following operations :

  • addition, substraction, multiplication, division and modulo
  • bits operators (AND, OR, XOR, left and right shifts)
  • boolean operators
  • modular exponentiation (using montgomery algorithm)
  • modular inverse

Example

In this example, we use a 1024 bits long RSA key to encrypt and decrypt a message. We first encrypt the value 0x41 (65 in decimal) and then decrypt it. At the end, m should be equal to 0x41. The encryption is fast (0, 4 second) while the decryption is really slow. This code will take between 30 seconds and 2 minutes to execute depending on the compiler and optimization flags.

main.cpp

#include "mbed.h"
#include "BigInt.h"
#include <stdlib.h>
#include <stdio.h>

uint8_t modbits[] = {
0xd9, 0x4d, 0x88, 0x9e, 0x88, 0x85, 0x3d, 0xd8, 0x97, 0x69, 0xa1, 0x80, 0x15, 0xa0, 0xa2, 0xe6,
0xbf, 0x82, 0xbf, 0x35, 0x6f, 0xe1, 0x4f, 0x25, 0x1f, 0xb4, 0xf5, 0xe2, 0xdf, 0x0d, 0x9f, 0x9a,
0x94, 0xa6, 0x8a, 0x30, 0xc4, 0x28, 0xb3, 0x9e, 0x33, 0x62, 0xfb, 0x37, 0x79, 0xa4, 0x97, 0xec,
0xea, 0xea, 0x37, 0x10, 0x0f, 0x26, 0x4d, 0x7f, 0xb9, 0xfb, 0x1a, 0x97, 0xfb, 0xf6, 0x21, 0x13,
0x3d, 0xe5, 0x5f, 0xdc, 0xb9, 0xb1, 0xad, 0x0d, 0x7a, 0x31, 0xb3, 0x79, 0x21, 0x6d, 0x79, 0x25,
0x2f, 0x5c, 0x52, 0x7b, 0x9b, 0xc6, 0x3d, 0x83, 0xd4, 0xec, 0xf4, 0xd1, 0xd4, 0x5c, 0xbf, 0x84,
0x3e, 0x84, 0x74, 0xba, 0xbc, 0x65, 0x5e, 0x9b, 0xb6, 0x79, 0x9c, 0xba, 0x77, 0xa4, 0x7e, 0xaf,
0xa8, 0x38, 0x29, 0x64, 0x74, 0xaf, 0xc2, 0x4b, 0xeb, 0x9c, 0x82, 0x5b, 0x73, 0xeb, 0xf5, 0x49
};

uint8_t dbits[] = {
0x04, 0x7b, 0x9c, 0xfd, 0xe8, 0x43, 0x17, 0x6b, 0x88, 0x74, 0x1d, 0x68, 0xcf, 0x09, 0x69, 0x52,
0xe9, 0x50, 0x81, 0x31, 0x51, 0x05, 0x8c, 0xe4, 0x6f, 0x2b, 0x04, 0x87, 0x91, 0xa2, 0x6e, 0x50,
0x7a, 0x10, 0x95, 0x79, 0x3c, 0x12, 0xba, 0xe1, 0xe0, 0x9d, 0x82, 0x21, 0x3a, 0xd9, 0x32, 0x69,
0x28, 0xcf, 0x7c, 0x23, 0x50, 0xac, 0xb1, 0x9c, 0x98, 0xf1, 0x9d, 0x32, 0xd5, 0x77, 0xd6, 0x66,
0xcd, 0x7b, 0xb8, 0xb2, 0xb5, 0xba, 0x62, 0x9d, 0x25, 0xcc, 0xf7, 0x2a, 0x5c, 0xeb, 0x8a, 0x8d,
0xa0, 0x38, 0x90, 0x6c, 0x84, 0xdc, 0xdb, 0x1f, 0xe6, 0x77, 0xdf, 0xfb, 0x2c, 0x02, 0x9f, 0xd8,
0x92, 0x63, 0x18, 0xee, 0xde, 0x1b, 0x58, 0x27, 0x2a, 0xf2, 0x2b, 0xda, 0x5c, 0x52, 0x32, 0xbe,
0x06, 0x68, 0x39, 0x39, 0x8e, 0x42, 0xf5, 0x35, 0x2d, 0xf5, 0x88, 0x48, 0xad, 0xad, 0x11, 0xa1
};

int main() 
{
    BigInt e = 65537, mod, d;
    mod.importData(modbits, sizeof(modbits));
    d.importData(dbits, sizeof(dbits));

    BigInt c = modPow(0x41,e,mod);
    c.print();
    BigInt m = modPow(c,d,mod);
    m.print();
    printf("done\n");
    
    return 0;
}

Revision:
20:d747159d77c4
Parent:
19:412b822df7bf
Child:
21:cfa04e0e6b59
--- a/BigInt.cpp	Sat Mar 08 09:35:10 2014 +0000
+++ b/BigInt.cpp	Mon Mar 10 12:51:47 2014 +0000
@@ -105,10 +105,8 @@
 
     BigInt result;
         
-    result.size = a.size > b.size ? a.size : b.size;
-    size_t l = result.size/4;
-    if(result.size % 4 != 0)
-        l++;
+    result.size = std::max(a.size, b.size) + 1;
+    size_t l = num(result.size);
     result.bits = new uint32_t[l];
     memset(result.bits, 0, sizeof(uint32_t)*l);
     uint32_t al = num(a.size);
@@ -124,21 +122,9 @@
         result.bits[i] = tmpA + tmpB + carry;
         carry = result.bits[i] < std::max(tmpA, tmpB);
     }
-    if(carry)
-    {
-        result.size++;
-        if(result.size > l*4)
-        { 
-            l++;
-            result.bits = (uint32_t*)realloc(result.bits, l * sizeof(uint32_t));
-            result.bits[l-1] = 0x00000001;
-        }
-        else
-        {
-            result.bits[l-1] += 1 << (8 *((result.size-1)%4));
-        }
-    }
-    
+
+    result.trim();
+
     return result;
 }
 
@@ -161,58 +147,48 @@
 
 // a - b, if b >= a, returns 0
 // No negative number allowed
-BigInt operator-(const BigInt &a, const BigInt& b)
+BigInt operator-(const BigInt& a, const BigInt& b)
 {
     assert(a.isValid() && b.isValid());
 
-    BigInt result;
+    if(b >= a)
+        return 0;
 
-    if(b >= a)
-    {
-        return result = 0;
-    }
-    else
+    BigInt result;
+    result.size = a.size;
+    uint32_t l = num(a.size);
+    result.bits = new uint32_t[l];
+    memset(result.bits, 0, sizeof(uint32_t)*l);
+    uint32_t bl = num(b.size);
+    uint8_t borrow = 0;
+    for(uint32_t i = 0; i < l; ++i)
     {
-        result.size = a.size;
-        uint32_t l = num(a.size);
-        result.bits = new uint32_t[l];
-        memset(result.bits, 0, sizeof(uint32_t)*l);
-        uint32_t bl = num(b.size);
-        uint8_t borrow = 0;
-        for(uint32_t i = 0; i < l; ++i)
+        uint32_t tmpA = a.bits[i], tmpB = 0;
+        if(i < bl)
+            tmpB = b.bits[i];
+            
+        if(borrow)  
         {
-            uint32_t tmpA = a.bits[i], tmpB = 0;
-            if(i < bl)
-                tmpB = b.bits[i];
-            if(borrow)
-            {
-                if(tmpA == 0)
-                    tmpA = ULONG_MAX;
-                else
-                    --tmpA;
-                    
-                if(tmpA < tmpB)            
-                    result.bits[i] = tmpA + 1 + (ULONG_MAX - tmpB);
-                else
-                    result.bits[i] = tmpA - tmpB;
-
-                if(a.bits[i] != 0 && tmpA > tmpB)
-                    borrow = 0;
-            }
+            if(tmpA == 0)
+                tmpA = 0xFFFFFFFF;
             else
             {
-                if(tmpA < tmpB)            
-                    result.bits[i] = tmpA + 1 + (ULONG_MAX - tmpB);
-                else
-                    result.bits[i] = tmpA - tmpB;
-                borrow = tmpA < tmpB;
+                --tmpA;
+                borrow = 0;
             }
         }
-        
-        result.trim();
+        if(tmpA >= tmpB)
+            result.bits[i] = tmpA - tmpB;
+        else 
+        {
+            result.bits[i] = 0xFFFFFFFF - tmpB;
+            result.bits[i] += tmpA + 1;
+            borrow = 1;
+        }
+    }
+    result.trim();
         
-        return result;
-    }
+    return result;
 }
 
 BigInt& BigInt::operator-=(const BigInt &b)
@@ -290,9 +266,10 @@
     if(a == b)
         return 1;
     BigInt u = a; 
-    int m = a.numBits() - b.numBits();
+    const uint32_t m = a.numBits() - b.numBits();
+
     BigInt q;
-    q.size = m/8 + ((m%8 != 0) ? 1 : 0);
+    q.size = m/8 + 1;
     q.bits = new uint32_t[num(q.size)];
     memset(q.bits, 0, num(q.size)*sizeof(uint32_t));
     BigInt tmp = b;
@@ -300,13 +277,14 @@
     for(int j = m; j >= 0; --j)
     {
         if(tmp <= u)
-        {
+        {   
             u -= tmp;
             q.bits[j/32] |= BITS[j%32]; 
-        }   
+        }
         tmp >>= 1;
     }
     q.trim();
+
     return q;
 }
 
@@ -327,6 +305,7 @@
     BigInt result;
     result.size = a.size - m/8;
     result.bits = new uint32_t[num(result.size)];
+    memset(result.bits, 0, sizeof(uint32_t)*num(result.size));
     uint8_t s = m%32;
     for(uint32_t i = 0; i < num(result.size); ++i)
     {
@@ -343,10 +322,10 @@
 
 BigInt& BigInt::operator>>=(const uint32_t m)
 {
-    return *this = *this >> m;
+    return (*this = (*this >> m));
 }
  
-BigInt operator<<(const BigInt &a, const uint32_t m)
+BigInt operator<<(const BigInt& a, const uint32_t m)
 {
     assert(a.isValid());
 
@@ -356,22 +335,17 @@
         return result = a;
 
     result.size = m/8 + a.size;
-    uint32_t h = a.bits[num(a.size)-1];
     if((m%32)%8 != 0)
         ++result.size;
     uint32_t l = num(result.size);
     result.bits = new uint32_t[l];
     memset(result.bits, 0, sizeof(uint32_t)*l);
     uint32_t s = m % 32;
-    for(uint32_t i = 0; i < num(a.size); ++i)
-    {
-        if(i == 0)
-            result.bits[m/32+i] = a.bits[i] << s;
-        else
-            result.bits[m/32+i] = (a.bits[i] << s) | (a.bits[i-1] >> (32-s));
-    }
-    if(a.bits[num(a.size)-1] && s != 0)
-        result.bits[m/32+num(result.size)-1] |= a.bits[num(a.size)-1] >> (32-s);
+    result.bits[m/32] = a.bits[0] << s;
+    for(uint32_t i = 1; i < num(a.size); ++i)
+        result.bits[m/32+i] = (a.bits[i] << s) | (a.bits[i-1] >> (32-s));
+    if(s != 0)
+        result.bits[num(result.size)-1] = a.bits[num(a.size)-1] >> (32-s);
 
     result.trim();
     
@@ -386,6 +360,7 @@
 BigInt operator%(const BigInt &a, const BigInt &b)
 {
     assert(a.isValid() && b.isValid() && b > 0);
+    
     return a - (a/b)*b;
 }
 
@@ -518,6 +493,8 @@
             result.bits[i] = b.bits[i];
     }
     
+    result.trim();
+    
     return result;
 }
 
@@ -530,7 +507,6 @@
 {
     assert(a.isValid() && b.isValid());
 
-
     BigInt result;
 
     uint32_t na = num(a.size);
@@ -568,19 +544,19 @@
     while(r > 0)
     {
         if(a.bits[j/32] & BITS[j%32])
-            result += b;   
+            result += b;
         
-        if(result.bits[0] & 0x01)
+        if(result.bits[0] & BITS[0])
             result += m;
      
         ++j; 
         --r;
-        result >>= 1;    
+        result >>= 1;
     }
     
     if(result >= m)
         return result - m;
-
+    
     return result;
 }
   
@@ -588,25 +564,27 @@
 BigInt modPow(const BigInt &a, const BigInt &expn, const BigInt &modulus)
 {
     assert(a.isValid() && expn.isValid() && modulus.isValid() && modulus != 0);
-    
+
+    if(modulus == 1)
+        return 0;    
     if(expn == 0)
         return 1;
-    if(modulus == 1)
-        return 0;
     if(a == 1)
         return 1;
+    if(expn == 1)
+       return a % modulus;
     
-    uint32_t r = 8*modulus.size;
-
+    const uint32_t r = 8*modulus.size;
     // convert a in montgomery world
     BigInt montA = (a << r) % modulus;
+    
     BigInt tmp;
-    if(expn.bits[0] & 0x01)
+    if(expn.bits[0] & BITS[0])
         tmp = montA;
     
-    uint32_t n = expn.numBits();
+    const uint32_t n = expn.numBits();
     uint32_t j = 1; 
-    while(j < n)
+    while(j <= n)
     {
         montA = montgomeryStep(montA, montA, modulus, r);
 
@@ -620,7 +598,7 @@
         ++j;
     }
     
-    // convert a to normal world
+    // convert tmp to normal world
     return montgomeryStep(tmp, 1, modulus, r);
 }
 
@@ -644,15 +622,23 @@
 
 void BigInt::trim()
 {
+    assert(isValid());
+    
     uint8_t *tmp = (uint8_t*)bits;
     uint32_t newSize = size;
-    while(tmp[newSize-1] == 0 && newSize > 0)
+    while(newSize > 0 && tmp[newSize-1] == 0)
         newSize--;
     if(newSize == 0)
         newSize = 1;
     if(num(newSize) < num(size))
     {
-        bits = (uint32_t*)realloc(bits, sizeof(uint32_t)*num(newSize)); 
+        uint32_t *tmp = new uint32_t[num(size)];
+        memcpy(tmp, bits, num(size)*sizeof(uint32_t));
+        delete[] bits;
+        bits = new uint32_t[num(newSize)];
+        memset(bits, 0, sizeof(uint32_t)*num(newSize));
+        memcpy(bits, tmp, newSize);
+        delete[] tmp;
     }
     size = newSize; 
 }