This library provides a way to easily handle arbitrary large integers.
This library provides the following operations :
- addition, substraction, multiplication, division and modulo
- bits operators (AND, OR, XOR, left and right shifts)
- boolean operators
- modular exponentiation (using montgomery algorithm)
- modular inverse
Example
In this example, we use a 1024 bits long RSA key to encrypt and decrypt a message. We first encrypt the value 0x41 (65 in decimal) and then decrypt it. At the end, m should be equal to 0x41. The encryption is fast (0, 4 second) while the decryption is really slow. This code will take between 30 seconds and 2 minutes to execute depending on the compiler and optimization flags.
main.cpp
#include "mbed.h" #include "BigInt.h" #include <stdlib.h> #include <stdio.h> uint8_t modbits[] = { 0xd9, 0x4d, 0x88, 0x9e, 0x88, 0x85, 0x3d, 0xd8, 0x97, 0x69, 0xa1, 0x80, 0x15, 0xa0, 0xa2, 0xe6, 0xbf, 0x82, 0xbf, 0x35, 0x6f, 0xe1, 0x4f, 0x25, 0x1f, 0xb4, 0xf5, 0xe2, 0xdf, 0x0d, 0x9f, 0x9a, 0x94, 0xa6, 0x8a, 0x30, 0xc4, 0x28, 0xb3, 0x9e, 0x33, 0x62, 0xfb, 0x37, 0x79, 0xa4, 0x97, 0xec, 0xea, 0xea, 0x37, 0x10, 0x0f, 0x26, 0x4d, 0x7f, 0xb9, 0xfb, 0x1a, 0x97, 0xfb, 0xf6, 0x21, 0x13, 0x3d, 0xe5, 0x5f, 0xdc, 0xb9, 0xb1, 0xad, 0x0d, 0x7a, 0x31, 0xb3, 0x79, 0x21, 0x6d, 0x79, 0x25, 0x2f, 0x5c, 0x52, 0x7b, 0x9b, 0xc6, 0x3d, 0x83, 0xd4, 0xec, 0xf4, 0xd1, 0xd4, 0x5c, 0xbf, 0x84, 0x3e, 0x84, 0x74, 0xba, 0xbc, 0x65, 0x5e, 0x9b, 0xb6, 0x79, 0x9c, 0xba, 0x77, 0xa4, 0x7e, 0xaf, 0xa8, 0x38, 0x29, 0x64, 0x74, 0xaf, 0xc2, 0x4b, 0xeb, 0x9c, 0x82, 0x5b, 0x73, 0xeb, 0xf5, 0x49 }; uint8_t dbits[] = { 0x04, 0x7b, 0x9c, 0xfd, 0xe8, 0x43, 0x17, 0x6b, 0x88, 0x74, 0x1d, 0x68, 0xcf, 0x09, 0x69, 0x52, 0xe9, 0x50, 0x81, 0x31, 0x51, 0x05, 0x8c, 0xe4, 0x6f, 0x2b, 0x04, 0x87, 0x91, 0xa2, 0x6e, 0x50, 0x7a, 0x10, 0x95, 0x79, 0x3c, 0x12, 0xba, 0xe1, 0xe0, 0x9d, 0x82, 0x21, 0x3a, 0xd9, 0x32, 0x69, 0x28, 0xcf, 0x7c, 0x23, 0x50, 0xac, 0xb1, 0x9c, 0x98, 0xf1, 0x9d, 0x32, 0xd5, 0x77, 0xd6, 0x66, 0xcd, 0x7b, 0xb8, 0xb2, 0xb5, 0xba, 0x62, 0x9d, 0x25, 0xcc, 0xf7, 0x2a, 0x5c, 0xeb, 0x8a, 0x8d, 0xa0, 0x38, 0x90, 0x6c, 0x84, 0xdc, 0xdb, 0x1f, 0xe6, 0x77, 0xdf, 0xfb, 0x2c, 0x02, 0x9f, 0xd8, 0x92, 0x63, 0x18, 0xee, 0xde, 0x1b, 0x58, 0x27, 0x2a, 0xf2, 0x2b, 0xda, 0x5c, 0x52, 0x32, 0xbe, 0x06, 0x68, 0x39, 0x39, 0x8e, 0x42, 0xf5, 0x35, 0x2d, 0xf5, 0x88, 0x48, 0xad, 0xad, 0x11, 0xa1 }; int main() { BigInt e = 65537, mod, d; mod.importData(modbits, sizeof(modbits)); d.importData(dbits, sizeof(dbits)); BigInt c = modPow(0x41,e,mod); c.print(); BigInt m = modPow(c,d,mod); m.print(); printf("done\n"); return 0; }
Revision 26:94e26bcd229d, committed 2014-05-11
- Comitter:
- feb11
- Date:
- Sun May 11 10:33:20 2014 +0000
- Parent:
- 25:3d5c1f299da2
- Commit message:
- Support signed integers. Add invMod function
Changed in this revision
BigInt.cpp | Show annotated file Show diff for this revision Revisions of this file |
BigInt.h | Show annotated file Show diff for this revision Revisions of this file |
--- a/BigInt.cpp Sun Apr 13 07:35:47 2014 +0000 +++ b/BigInt.cpp Sun May 11 10:33:20 2014 +0000 @@ -27,13 +27,22 @@ BigInt::BigInt(): +sign(POS), size(0), bits(0) { } -BigInt::BigInt(const uint32_t a) +BigInt::BigInt(int32_t a) { + if(a < 0) + { + a = -a; + sign = NEG; + } + else + sign = POS; + if(a >> 24) size = 4; else if(a >> 16) @@ -47,6 +56,7 @@ } BigInt::BigInt(const BigInt &a): +sign(a.sign), size(a.size) { uint32_t l = num(size); @@ -66,6 +76,7 @@ BigInt& BigInt::operator=(const BigInt& a) { + sign = a.sign; size = a.size; uint32_t l = num(size); if(bits) @@ -76,8 +87,9 @@ return *this; } -void BigInt::importData(uint8_t *data, uint32_t length) +void BigInt::importData(uint8_t *data, uint32_t length, bool sign) { + this->sign = sign; size = length; if(bits) delete[] bits; @@ -89,12 +101,14 @@ trim(); } -void BigInt::exportData(uint8_t *data, uint32_t length) +void BigInt::exportData(uint8_t *data, uint32_t length, bool &sign) { assert(isValid() && data != 0); if(length < size) return; + + sign = this->sign; uint32_t offset = length-size; memset(data, 0, offset); for(int i = size-1; i >= 0; --i) @@ -105,29 +119,14 @@ { assert(a.isValid() && b.isValid()); - BigInt result; - - result.size = std::max(a.size, b.size) + 1; - size_t l = num(result.size); - result.bits = new uint32_t[l]; - memset(result.bits, 0, sizeof(uint32_t)*l); - uint32_t al = num(a.size); - uint32_t bl = num(b.size); - uint32_t carry = 0; - for(int i = 0; i < (int)l; ++i) - { - uint32_t tmpA = 0, tmpB = 0; - if(i < (int)al) - tmpA = a.bits[i]; - if(i < (int)bl) - tmpB = b.bits[i]; - result.bits[i] = tmpA + tmpB + carry; - carry = result.bits[i] < std::max(tmpA, tmpB); - } - - result.trim(); - - return result; + if(a.sign == POS && b.sign == POS) // a+b + return add(a, b); + if(a.sign == NEG && b.sign == NEG) // (-a)+(-b) = -(a+b) + return -add(a, b); + else if(a.sign == POS) // a + (-b) = a-b + return a - (-b); + else // (-a) + b = b-a + return b - (-a); } BigInt& BigInt::operator+=(const BigInt &b) @@ -147,49 +146,33 @@ return t; } -// a - b, if b >= a, returns 0 -// No negative number allowed BigInt operator-(const BigInt& a, const BigInt& b) { assert(a.isValid() && b.isValid()); - - if(b >= a) - return 0; - - BigInt result; - result.size = a.size; - uint32_t l = num(a.size); - result.bits = new uint32_t[l]; - memset(result.bits, 0, sizeof(uint32_t)*l); - uint32_t bl = num(b.size); - uint8_t borrow = 0; - for(uint32_t i = 0; i < l; ++i) + + if(a.sign == POS && b.sign == POS) { - uint32_t tmpA = a.bits[i], tmpB = 0; - if(i < bl) - tmpB = b.bits[i]; - - if(borrow) - { - if(tmpA == 0) - tmpA = 0xFFFFFFFF; - else - { - --tmpA; - borrow = 0; - } - } - if(tmpA >= tmpB) - result.bits[i] = tmpA - tmpB; - else - { - result.bits[i] = 0xFFFFFFFF - tmpB; - result.bits[i] += tmpA + 1; - borrow = 1; - } + if(equals(a, b)) + return 0; + else if(greater(a, b)) + return sub(a, b); + else + return -sub(b, a); } - result.trim(); - + else if(a.sign == NEG && b.sign == NEG) + return (-b) - (-a); + else if(a.sign == NEG && b.sign == POS) + return -add(a, b); + else + return add(a, b); +} + +BigInt operator-(const BigInt &a) +{ + assert(a.isValid()); + + BigInt result = a; + result.sign = !a.sign; return result; } @@ -217,37 +200,19 @@ // if a == 0 or b == 0 then result = 0 if(!a || !b) return 0; - - // if a == 1, then result = b - if(a == 1) - return b; - - // if b == 1, then result = a - if(b == 1) - return a; - - BigInt result; - result.size = a.size + b.size; - result.bits = new uint32_t[num(result.size)]; - memset(result.bits, 0, sizeof(uint32_t)*num(result.size)); - for(int i = 0; i < (int)num(a.size); ++i) - { - uint64_t carry = 0; - for(int j = 0; j < (int)num(b.size); ++j) - { - uint64_t tmp = (uint64_t)a.bits[i] * (uint64_t)b.bits[j] + carry; - uint32_t t = result.bits[i+j]; - result.bits[i+j] += tmp; - carry = tmp >> 32; - if(t > result.bits[i+j]) - ++carry; - } - if(carry != 0) - result.bits[i+num(b.size)] += carry; - } - - result.trim(); - + + BigInt result; + if(equals(a, 1)) + result = b; + else if(equals(b, 1)) + result = a; + else + result = mul(a, b); + + if(a.sign == b.sign) + result.sign = POS; + else + result.sign = NEG; return result; } @@ -256,38 +221,28 @@ return (*this = (*this) * b); } - BigInt operator/(const BigInt &a, const BigInt &b) { assert(a.isValid() && b.isValid() && b != 0); - if(b == 1) - return a; - if(a < b) + if(lesser(a, b)) return 0; - if(a == b) - return 1; - BigInt u = a; - const uint32_t m = a.numBits() - b.numBits(); - BigInt q; - q.size = m/8 + 1; - q.bits = new uint32_t[num(q.size)]; - memset(q.bits, 0, num(q.size)*sizeof(uint32_t)); - BigInt tmp = b; - tmp <<= m; - for(int j = m; j >= 0; --j) - { - if(tmp <= u) - { - u -= tmp; - q.bits[j/32] |= BITS[j%32]; - } - tmp >>= 1; - } - q.trim(); - - return q; + BigInt result; + + if(equals(a, b)) + result = 1; + else if(equals(b, 1)) + result = a; + else + result = div(a, b); + + if(a.sign == b.sign) + result.sign = POS; + else + result.sign = NEG; + + return result; } BigInt& BigInt::operator/=(const BigInt &b) @@ -318,7 +273,8 @@ } result.trim(); - + result.checkZero(); + return result; } @@ -331,11 +287,11 @@ { assert(a.isValid()); + if(m == 0) + return a; + BigInt result; - if(m == 0) - return result = a; - result.size = m/8 + a.size; if((m%32)%8 != 0) ++result.size; @@ -346,8 +302,8 @@ result.bits[m/32] = a.bits[0] << s; for(uint32_t i = 1; i < num(a.size); ++i) result.bits[m/32+i] = (a.bits[i] << s) | (a.bits[i-1] >> (32-s)); - if(s != 0) - result.bits[num(result.size)-1] = a.bits[num(a.size)-1] >> (32-s); + if(s != 0 && num(result.size) != 1) + result.bits[num(result.size)-1] |= a.bits[num(a.size)-1] >> (32-s); result.trim(); @@ -362,7 +318,10 @@ BigInt operator%(const BigInt &a, const BigInt &b) { assert(a.isValid() && b.isValid() && b > 0); - + if(a < b) + return a; + if(a == b) + return 0; return a - (a/b)*b; } @@ -375,14 +334,7 @@ { assert(a.isValid() && b.isValid()); - if(a.size != b.size) - return false; - - uint32_t l = num(a.size); - for(int i = 0; i < (int)l; ++i) - if(a.bits[i] != b.bits[i]) - return false; - return true; + return a.sign == b.sign && equals(a, b); } bool operator!=(const BigInt &a, const BigInt &b) @@ -393,20 +345,15 @@ bool operator<(const BigInt &a, const BigInt &b) { assert(a.isValid() && b.isValid()); - - if(a.size < b.size) + + if(a.sign == NEG && b.sign == NEG) + return !lesser(a, b); + else if(a.sign == NEG && b.sign == POS) return true; - if(a.size > b.size) + else if(a.sign == POS && b.sign == NEG) return false; - uint32_t l = num(a.size); - for(int i = l-1; i >= 0; --i) - { - if(a.bits[i] < b.bits[i]) - return true; - else if(a.bits[i] > b.bits[i]) - return false; - } - return false; + else + return lesser(a, b); } bool operator<=(const BigInt &a, const BigInt &b) @@ -418,19 +365,14 @@ { assert(a.isValid() && b.isValid()); - if(a.size > b.size) - return true; - if(a.size < b.size) + if(a.sign == NEG && b.sign == NEG) + return !greater(a, b); + else if(a.sign == NEG && b.sign == POS) return false; - uint32_t l = num(a.size); - for(int i = l-1; i >= 0; --i) - { - if(a.bits[i] > b.bits[i]) - return true; - else if(a.bits[i] < b.bits[i]) - return false; - } - return false; + else if(a.sign == POS && b.sign == NEG) + return true; + else + return greater(a, b); } bool operator>=(const BigInt &a, const BigInt &b) @@ -544,14 +486,14 @@ while(r > 0) { if(a.bits[j/32] & BITS[j%32]) - result.add(b); + result.fastAdd(b); if(result.bits[0] & BITS[0]) - result.add(m); + result.fastAdd(m); ++j; --r; - result.shr(); + result.fastShr(); } if(result >= m) @@ -602,6 +544,36 @@ return montgomeryStep(tmp, 1, modulus, r); } +// Implementation as described in FIPS.186-4, Appendix C.1 +BigInt invMod(const BigInt &a, const BigInt &modulus) +{ + assert(a.isValid() && modulus.isValid() && 0 < a && a < modulus); + + BigInt i = modulus; + BigInt j = a; + BigInt y2 = 0; + BigInt y1 = 1; + do + { + BigInt quotient = i / j; + BigInt remainder = i - (j * quotient); + BigInt y = y2 - (y1 * quotient); + i = j; + j = remainder; + y2 = y1; + y1 = y; + }while(j > 0); + + + assert(i == 1); + + y2 %= modulus; + if(y2 < 0) + y2 += modulus; + + return y2; +} + bool BigInt::isValid() const { return size != 0 && bits != 0; @@ -613,12 +585,184 @@ printf("size: %lu bytes\n", size); uint32_t n = num(size); + if(sign == NEG) + printf("- "); for(int i = n-1; i >= 0; --i) printf("%08x ", (int)bits[i]); printf("\n"); } -void BigInt::add(const BigInt &b) +// return a + b +BigInt add(const BigInt &a, const BigInt &b) +{ + BigInt result; + + result.size = std::max(a.size, b.size) + 1; + size_t l = num(result.size); + result.bits = new uint32_t[l]; + memset(result.bits, 0, sizeof(uint32_t)*l); + uint32_t al = num(a.size); + uint32_t bl = num(b.size); + uint32_t carry = 0; + for(int i = 0; i < (int)l; ++i) + { + uint32_t tmpA = 0, tmpB = 0; + if(i < (int)al) + tmpA = a.bits[i]; + if(i < (int)bl) + tmpB = b.bits[i]; + result.bits[i] = tmpA + tmpB + carry; + carry = result.bits[i] < std::max(tmpA, tmpB); + } + + result.trim(); + result.checkZero(); + + return result; +} + +// return a - b +// Assume that a > b +BigInt sub(const BigInt &a, const BigInt &b) +{ + BigInt result; + result.size = a.size; + uint32_t l = num(a.size); + result.bits = new uint32_t[l]; + memset(result.bits, 0, sizeof(uint32_t)*l); + uint32_t bl = num(b.size); + uint8_t borrow = 0; + for(uint32_t i = 0; i < l; ++i) + { + uint32_t tmpA = a.bits[i], tmpB = 0; + if(i < bl) + tmpB = b.bits[i]; + + if(borrow) + { + if(tmpA == 0) + tmpA = 0xFFFFFFFF; + else + { + --tmpA; + borrow = 0; + } + } + if(tmpA >= tmpB) + result.bits[i] = tmpA - tmpB; + else + { + result.bits[i] = 0xFFFFFFFF - tmpB; + result.bits[i] += tmpA + 1; + borrow = 1; + } + } + result.trim(); + result.checkZero(); + + return result; +} + +BigInt mul(const BigInt &a, const BigInt &b) +{ + BigInt result; + result.size = a.size + b.size; + result.bits = new uint32_t[num(result.size)]; + memset(result.bits, 0, sizeof(uint32_t)*num(result.size)); + for(int i = 0; i < (int)num(a.size); ++i) + { + uint64_t carry = 0; + for(int j = 0; j < (int)num(b.size); ++j) + { + uint64_t tmp = (uint64_t)a.bits[i] * (uint64_t)b.bits[j] + carry; + uint32_t t = result.bits[i+j]; + result.bits[i+j] += tmp; + carry = tmp >> 32; + if(t > result.bits[i+j]) + ++carry; + } + if(carry != 0) + result.bits[i+num(b.size)] += carry; + } + + result.trim(); + + return result; +} + +BigInt div(const BigInt &a, const BigInt &b) +{ + BigInt u = a; + const uint32_t m = a.numBits() - b.numBits(); + BigInt q; + q.size = m/8 + 1; + q.bits = new uint32_t[num(q.size)]; + memset(q.bits, 0, num(q.size)*sizeof(uint32_t)); + BigInt tmp = b; + tmp <<= m; + for(int j = m; j >= 0; --j) + { + if(tmp <= u) + { + u -= tmp; + q.bits[j/32] |= BITS[j%32]; + } + tmp >>= 1; + } + q.trim(); + + return q; +} + +bool equals(const BigInt &a, const BigInt &b) +{ + if(a.size != b.size) + return false; + + uint32_t l = num(a.size); + for(int i = 0; i < (int)l; ++i) + if(a.bits[i] != b.bits[i]) + return false; + return true; +} + +bool lesser(const BigInt &a, const BigInt &b) +{ + if(a.size < b.size) + return true; + if(a.size > b.size) + return false; + + uint32_t l = num(a.size); + for(int i = l-1; i >= 0; --i) + { + if(a.bits[i] < b.bits[i]) + return true; + else if(a.bits[i] > b.bits[i]) + return false; + } + return false; +} + +bool greater(const BigInt &a, const BigInt &b) +{ + if(a.size > b.size) + return true; + if(a.size < b.size) + return false; + uint32_t l = num(a.size); + for(int i = l-1; i >= 0; --i) + { + if(a.bits[i] > b.bits[i]) + return true; + else if(a.bits[i] < b.bits[i]) + return false; + } + return false; +} + + +void BigInt::fastAdd(const BigInt &b) { uint32_t al = num(size); uint32_t bl = num(b.size); @@ -647,7 +791,7 @@ trim(); } -void BigInt::shr() +void BigInt::fastShr() { uint32_t lastBit = 0; uint32_t tmp; @@ -665,6 +809,7 @@ void BigInt::trim() { assert(isValid()); + uint8_t *tmp = (uint8_t*)bits; uint32_t newSize = size; @@ -697,4 +842,13 @@ n += tmp2; return n; -} \ No newline at end of file +} + +// Ensure that there is no negative zero +void BigInt::checkZero() +{ + assert(isValid()); + + if(size == 1 && bits[0] == 0) + sign = POS; +}
--- a/BigInt.h Sun Apr 13 07:35:47 2014 +0000 +++ b/BigInt.h Sun May 11 10:33:20 2014 +0000 @@ -3,18 +3,21 @@ #include <stdint.h> +#define POS (true) +#define NEG (false) + class BigInt { public : BigInt(); - BigInt(const uint32_t a); + BigInt(int32_t a); BigInt(const BigInt &a); BigInt& operator=(const BigInt& a); virtual ~BigInt(); - void importData(uint8_t *data, uint32_t length); - void exportData(uint8_t *data, uint32_t length); + void importData(uint8_t *data, uint32_t length, bool sign = POS); + void exportData(uint8_t *data, uint32_t length, bool &sign); // Arithmetic operations friend BigInt operator+(const BigInt &a, const BigInt &b); @@ -23,10 +26,11 @@ BigInt operator++(int); friend BigInt operator-(const BigInt &a, const BigInt &b); + friend BigInt operator-(const BigInt &a); BigInt& operator-=(const BigInt &b); BigInt& operator--(); BigInt operator--(int); - + friend BigInt operator*(const BigInt &a, const BigInt &b); BigInt& operator*=(const BigInt &b); @@ -60,24 +64,39 @@ // modular exponentiation friend BigInt modPow(const BigInt &a, const BigInt &expn, const BigInt &modulus); - + + // invert modular + friend BigInt invMod(const BigInt &a, const BigInt &modulus); + + // miscellaneous bool isValid() const; + uint32_t numBits() const; // debug void print() const; private : + friend BigInt add(const BigInt &a, const BigInt &b); + friend BigInt sub(const BigInt &a, const BigInt &b); + friend BigInt mul(const BigInt &a, const BigInt &b); + friend BigInt div(const BigInt &a, const BigInt &b); + + friend bool equals(const BigInt &a, const BigInt &b); + friend bool lesser(const BigInt &a, const BigInt &b); + friend bool greater(const BigInt &a, const BigInt &b); + // fast operations - void add(const BigInt &b); - void shr(); + void fastAdd(const BigInt &b); + void fastShr(); void trim(); - uint32_t numBits() const; + void checkZero(); friend BigInt montgomeryStep(const BigInt &a, const BigInt &b, const BigInt &m, uint32_t r); friend BigInt montgomeryStep2(const BigInt &a, const BigInt &m, uint32_t r); - + + bool sign; uint32_t size; // smaller number of bytes needed to represent integer uint32_t *bits; };