This library provides a way to easily handle arbitrary large integers.
This library provides the following operations :
- addition, substraction, multiplication, division and modulo
- bits operators (AND, OR, XOR, left and right shifts)
- boolean operators
- modular exponentiation (using montgomery algorithm)
- modular inverse
Example
In this example, we use a 1024 bits long RSA key to encrypt and decrypt a message. We first encrypt the value 0x41 (65 in decimal) and then decrypt it. At the end, m should be equal to 0x41. The encryption is fast (0, 4 second) while the decryption is really slow. This code will take between 30 seconds and 2 minutes to execute depending on the compiler and optimization flags.
main.cpp
#include "mbed.h" #include "BigInt.h" #include <stdlib.h> #include <stdio.h> uint8_t modbits[] = { 0xd9, 0x4d, 0x88, 0x9e, 0x88, 0x85, 0x3d, 0xd8, 0x97, 0x69, 0xa1, 0x80, 0x15, 0xa0, 0xa2, 0xe6, 0xbf, 0x82, 0xbf, 0x35, 0x6f, 0xe1, 0x4f, 0x25, 0x1f, 0xb4, 0xf5, 0xe2, 0xdf, 0x0d, 0x9f, 0x9a, 0x94, 0xa6, 0x8a, 0x30, 0xc4, 0x28, 0xb3, 0x9e, 0x33, 0x62, 0xfb, 0x37, 0x79, 0xa4, 0x97, 0xec, 0xea, 0xea, 0x37, 0x10, 0x0f, 0x26, 0x4d, 0x7f, 0xb9, 0xfb, 0x1a, 0x97, 0xfb, 0xf6, 0x21, 0x13, 0x3d, 0xe5, 0x5f, 0xdc, 0xb9, 0xb1, 0xad, 0x0d, 0x7a, 0x31, 0xb3, 0x79, 0x21, 0x6d, 0x79, 0x25, 0x2f, 0x5c, 0x52, 0x7b, 0x9b, 0xc6, 0x3d, 0x83, 0xd4, 0xec, 0xf4, 0xd1, 0xd4, 0x5c, 0xbf, 0x84, 0x3e, 0x84, 0x74, 0xba, 0xbc, 0x65, 0x5e, 0x9b, 0xb6, 0x79, 0x9c, 0xba, 0x77, 0xa4, 0x7e, 0xaf, 0xa8, 0x38, 0x29, 0x64, 0x74, 0xaf, 0xc2, 0x4b, 0xeb, 0x9c, 0x82, 0x5b, 0x73, 0xeb, 0xf5, 0x49 }; uint8_t dbits[] = { 0x04, 0x7b, 0x9c, 0xfd, 0xe8, 0x43, 0x17, 0x6b, 0x88, 0x74, 0x1d, 0x68, 0xcf, 0x09, 0x69, 0x52, 0xe9, 0x50, 0x81, 0x31, 0x51, 0x05, 0x8c, 0xe4, 0x6f, 0x2b, 0x04, 0x87, 0x91, 0xa2, 0x6e, 0x50, 0x7a, 0x10, 0x95, 0x79, 0x3c, 0x12, 0xba, 0xe1, 0xe0, 0x9d, 0x82, 0x21, 0x3a, 0xd9, 0x32, 0x69, 0x28, 0xcf, 0x7c, 0x23, 0x50, 0xac, 0xb1, 0x9c, 0x98, 0xf1, 0x9d, 0x32, 0xd5, 0x77, 0xd6, 0x66, 0xcd, 0x7b, 0xb8, 0xb2, 0xb5, 0xba, 0x62, 0x9d, 0x25, 0xcc, 0xf7, 0x2a, 0x5c, 0xeb, 0x8a, 0x8d, 0xa0, 0x38, 0x90, 0x6c, 0x84, 0xdc, 0xdb, 0x1f, 0xe6, 0x77, 0xdf, 0xfb, 0x2c, 0x02, 0x9f, 0xd8, 0x92, 0x63, 0x18, 0xee, 0xde, 0x1b, 0x58, 0x27, 0x2a, 0xf2, 0x2b, 0xda, 0x5c, 0x52, 0x32, 0xbe, 0x06, 0x68, 0x39, 0x39, 0x8e, 0x42, 0xf5, 0x35, 0x2d, 0xf5, 0x88, 0x48, 0xad, 0xad, 0x11, 0xa1 }; int main() { BigInt e = 65537, mod, d; mod.importData(modbits, sizeof(modbits)); d.importData(dbits, sizeof(dbits)); BigInt c = modPow(0x41,e,mod); c.print(); BigInt m = modPow(c,d,mod); m.print(); printf("done\n"); return 0; }
Diff: BigInt.cpp
- Revision:
- 20:d747159d77c4
- Parent:
- 19:412b822df7bf
- Child:
- 21:cfa04e0e6b59
--- a/BigInt.cpp Sat Mar 08 09:35:10 2014 +0000 +++ b/BigInt.cpp Mon Mar 10 12:51:47 2014 +0000 @@ -105,10 +105,8 @@ BigInt result; - result.size = a.size > b.size ? a.size : b.size; - size_t l = result.size/4; - if(result.size % 4 != 0) - l++; + result.size = std::max(a.size, b.size) + 1; + size_t l = num(result.size); result.bits = new uint32_t[l]; memset(result.bits, 0, sizeof(uint32_t)*l); uint32_t al = num(a.size); @@ -124,21 +122,9 @@ result.bits[i] = tmpA + tmpB + carry; carry = result.bits[i] < std::max(tmpA, tmpB); } - if(carry) - { - result.size++; - if(result.size > l*4) - { - l++; - result.bits = (uint32_t*)realloc(result.bits, l * sizeof(uint32_t)); - result.bits[l-1] = 0x00000001; - } - else - { - result.bits[l-1] += 1 << (8 *((result.size-1)%4)); - } - } - + + result.trim(); + return result; } @@ -161,58 +147,48 @@ // a - b, if b >= a, returns 0 // No negative number allowed -BigInt operator-(const BigInt &a, const BigInt& b) +BigInt operator-(const BigInt& a, const BigInt& b) { assert(a.isValid() && b.isValid()); - BigInt result; + if(b >= a) + return 0; - if(b >= a) - { - return result = 0; - } - else + BigInt result; + result.size = a.size; + uint32_t l = num(a.size); + result.bits = new uint32_t[l]; + memset(result.bits, 0, sizeof(uint32_t)*l); + uint32_t bl = num(b.size); + uint8_t borrow = 0; + for(uint32_t i = 0; i < l; ++i) { - result.size = a.size; - uint32_t l = num(a.size); - result.bits = new uint32_t[l]; - memset(result.bits, 0, sizeof(uint32_t)*l); - uint32_t bl = num(b.size); - uint8_t borrow = 0; - for(uint32_t i = 0; i < l; ++i) + uint32_t tmpA = a.bits[i], tmpB = 0; + if(i < bl) + tmpB = b.bits[i]; + + if(borrow) { - uint32_t tmpA = a.bits[i], tmpB = 0; - if(i < bl) - tmpB = b.bits[i]; - if(borrow) - { - if(tmpA == 0) - tmpA = ULONG_MAX; - else - --tmpA; - - if(tmpA < tmpB) - result.bits[i] = tmpA + 1 + (ULONG_MAX - tmpB); - else - result.bits[i] = tmpA - tmpB; - - if(a.bits[i] != 0 && tmpA > tmpB) - borrow = 0; - } + if(tmpA == 0) + tmpA = 0xFFFFFFFF; else { - if(tmpA < tmpB) - result.bits[i] = tmpA + 1 + (ULONG_MAX - tmpB); - else - result.bits[i] = tmpA - tmpB; - borrow = tmpA < tmpB; + --tmpA; + borrow = 0; } } - - result.trim(); + if(tmpA >= tmpB) + result.bits[i] = tmpA - tmpB; + else + { + result.bits[i] = 0xFFFFFFFF - tmpB; + result.bits[i] += tmpA + 1; + borrow = 1; + } + } + result.trim(); - return result; - } + return result; } BigInt& BigInt::operator-=(const BigInt &b) @@ -290,9 +266,10 @@ if(a == b) return 1; BigInt u = a; - int m = a.numBits() - b.numBits(); + const uint32_t m = a.numBits() - b.numBits(); + BigInt q; - q.size = m/8 + ((m%8 != 0) ? 1 : 0); + q.size = m/8 + 1; q.bits = new uint32_t[num(q.size)]; memset(q.bits, 0, num(q.size)*sizeof(uint32_t)); BigInt tmp = b; @@ -300,13 +277,14 @@ for(int j = m; j >= 0; --j) { if(tmp <= u) - { + { u -= tmp; q.bits[j/32] |= BITS[j%32]; - } + } tmp >>= 1; } q.trim(); + return q; } @@ -327,6 +305,7 @@ BigInt result; result.size = a.size - m/8; result.bits = new uint32_t[num(result.size)]; + memset(result.bits, 0, sizeof(uint32_t)*num(result.size)); uint8_t s = m%32; for(uint32_t i = 0; i < num(result.size); ++i) { @@ -343,10 +322,10 @@ BigInt& BigInt::operator>>=(const uint32_t m) { - return *this = *this >> m; + return (*this = (*this >> m)); } -BigInt operator<<(const BigInt &a, const uint32_t m) +BigInt operator<<(const BigInt& a, const uint32_t m) { assert(a.isValid()); @@ -356,22 +335,17 @@ return result = a; result.size = m/8 + a.size; - uint32_t h = a.bits[num(a.size)-1]; if((m%32)%8 != 0) ++result.size; uint32_t l = num(result.size); result.bits = new uint32_t[l]; memset(result.bits, 0, sizeof(uint32_t)*l); uint32_t s = m % 32; - for(uint32_t i = 0; i < num(a.size); ++i) - { - if(i == 0) - result.bits[m/32+i] = a.bits[i] << s; - else - result.bits[m/32+i] = (a.bits[i] << s) | (a.bits[i-1] >> (32-s)); - } - if(a.bits[num(a.size)-1] && s != 0) - result.bits[m/32+num(result.size)-1] |= a.bits[num(a.size)-1] >> (32-s); + result.bits[m/32] = a.bits[0] << s; + for(uint32_t i = 1; i < num(a.size); ++i) + result.bits[m/32+i] = (a.bits[i] << s) | (a.bits[i-1] >> (32-s)); + if(s != 0) + result.bits[num(result.size)-1] = a.bits[num(a.size)-1] >> (32-s); result.trim(); @@ -386,6 +360,7 @@ BigInt operator%(const BigInt &a, const BigInt &b) { assert(a.isValid() && b.isValid() && b > 0); + return a - (a/b)*b; } @@ -518,6 +493,8 @@ result.bits[i] = b.bits[i]; } + result.trim(); + return result; } @@ -530,7 +507,6 @@ { assert(a.isValid() && b.isValid()); - BigInt result; uint32_t na = num(a.size); @@ -568,19 +544,19 @@ while(r > 0) { if(a.bits[j/32] & BITS[j%32]) - result += b; + result += b; - if(result.bits[0] & 0x01) + if(result.bits[0] & BITS[0]) result += m; ++j; --r; - result >>= 1; + result >>= 1; } if(result >= m) return result - m; - + return result; } @@ -588,25 +564,27 @@ BigInt modPow(const BigInt &a, const BigInt &expn, const BigInt &modulus) { assert(a.isValid() && expn.isValid() && modulus.isValid() && modulus != 0); - + + if(modulus == 1) + return 0; if(expn == 0) return 1; - if(modulus == 1) - return 0; if(a == 1) return 1; + if(expn == 1) + return a % modulus; - uint32_t r = 8*modulus.size; - + const uint32_t r = 8*modulus.size; // convert a in montgomery world BigInt montA = (a << r) % modulus; + BigInt tmp; - if(expn.bits[0] & 0x01) + if(expn.bits[0] & BITS[0]) tmp = montA; - uint32_t n = expn.numBits(); + const uint32_t n = expn.numBits(); uint32_t j = 1; - while(j < n) + while(j <= n) { montA = montgomeryStep(montA, montA, modulus, r); @@ -620,7 +598,7 @@ ++j; } - // convert a to normal world + // convert tmp to normal world return montgomeryStep(tmp, 1, modulus, r); } @@ -644,15 +622,23 @@ void BigInt::trim() { + assert(isValid()); + uint8_t *tmp = (uint8_t*)bits; uint32_t newSize = size; - while(tmp[newSize-1] == 0 && newSize > 0) + while(newSize > 0 && tmp[newSize-1] == 0) newSize--; if(newSize == 0) newSize = 1; if(num(newSize) < num(size)) { - bits = (uint32_t*)realloc(bits, sizeof(uint32_t)*num(newSize)); + uint32_t *tmp = new uint32_t[num(size)]; + memcpy(tmp, bits, num(size)*sizeof(uint32_t)); + delete[] bits; + bits = new uint32_t[num(newSize)]; + memset(bits, 0, sizeof(uint32_t)*num(newSize)); + memcpy(bits, tmp, newSize); + delete[] tmp; } size = newSize; }