diff --git a/main.py b/main.py index 4302e7a..a4d0cd0 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,8 @@ from itertools import combinations def euclid(a, b): - """returns the Greatest Common Divisor of a and b""" + """returns the Greatest Common Divisor of a and b. + Inspired by https://en.wikipedia.org/wiki/Euclidean_algorithm#Implementations""" a = abs(a) b = abs(b) if a < b: @@ -15,156 +16,136 @@ def euclid(a, b): return a -def coPrime(l): +def coPrime(x, y): """returns 'True' if the values in the list L are all co-prime otherwise, it returns 'False'. """ - for i, j in combinations(l, 2): - if euclid(i, j) != 1: - return False + if euclid(x, y) != 1: + return False return True -def extendedEuclid(a, b): - """return a tuple of three values: x, y and z, such that x is - the GCD of a and b, and x = y * a + z * b""" - if a == 0: - return b, 0, 1 - else: - g, y, x = extendedEuclid(b % a, a) - return g, x - (b // a) * y, y +def iterative_extended_euclid(a, b): + """ returns the GCD (in old_r) of a and b as well as the "Bezout's Identity" such that + a*old_s + b*old_t = GCD(a,b). + Algorithm is adopted from https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode + """ + old_r, r = a, b + old_s, s = 1, 0 + old_t, t = 0, 1 + + while r != 0: + quotient = old_r // r + old_r, r = r, old_r - (quotient * r) + old_s, s = s, old_s - (quotient * s) + old_t, t = t, old_t - (quotient * t) + + print(old_r, old_s, old_t) + return old_r, old_s, old_t + + def modInv(a, m): """returns the multiplicative inverse of a in modulo m as a - positive value between zero and m-1""" - # notice that a and m need to be co-prime to each other. - if coPrime([a, m]): - linearCombination = extendedEuclid(a, m) - return linearCombination[1] % m + positive value between zero and m-1 + Adopted from https://en.wikipedia.org/wiki/Modular_multiplicative_inverse#Extended_Euclidean_algorithm + """ + if coPrime(a, m): + linear_combination = iterative_extended_euclid(a, m) + print(linear_combination[1] % m) + return linear_combination[1] % m else: return 0 -def extractTwos(m): - """m is a positive integer. A tuple (s, d) of integers is returned - such that m = (2 ** s) * d.""" - # the problem can be reduced to counting how many '0's there are in - # the end of bin(m). This can be done this way: m & a stretch of '1's - # which can be represent as (2 ** n) - 1. - assert m >= 0 - i = 0 - while m & (2 ** i) == 0: - i += 1 - return i, m >> i +def miller_rabin(n, k): + # Implementation uses the Miller-Rabin Primality Test + # The optimal number of rounds for this test is 40 + # See http://stackoverflow.com/questions/6325576/how-many-iterations-of-rabin-miller-should-i-use-for-cryptographic-safe-primes + # for justification -def millerRabin(n, k): - """ - Miller Rabin pseudo-prime test - return True means likely a prime, (how sure about that, depending on k) - return False means definitely a composite. - Raise assertion error when n, k are not positive integers - and n is not 1 - """ - assert n >= 1 - # ensure n is bigger than 1 - assert k > 0 - # ensure k is a positive integer so everything down here makes sense + # If number is even, it's a composite number if n == 2: return True - # make sure to return True if n == 2 if n % 2 == 0: return False - # immediately return False for all the even numbers bigger than 2 - extract2 = extractTwos(n - 1) - s = extract2[0] - d = extract2[1] - assert 2 ** s * d == n - 1 - - def tryComposite(a): - """Inner function which will inspect whether a given witness - will reveal the true identity of n. Will only be called within - millerRabin""" - x = pow(a,d,n) + r, s = 0, n - 1 + while s % 2 == 0: + r += 1 + s //= 2 + for _ in range(k): + a = random.randrange(2, n - 1) + x = pow(a, s, n) if x == 1 or x == n - 1: - return None + continue + for _ in range(r - 1): + x = pow(x, 2, n) + if x == n - 1: + break else: - for j in range(1, s): - x = pow(x,2,n) - if x == 1: - return False - elif x == n - 1: - return None return False - - for i in range(0, k): - a = random.randint(2, n - 2) - if tryComposite(a) == False: - return False - return True # actually, we should return probably true. + return True def findAPrime(a, b, k): """Return a pseudo prime number roughly between a and b, - (could be larger than b). Raise ValueError if cannot find a - pseudo prime after 10 * ln(x) + 3 tries. """ + (could be larger than b). """ x = random.randint(a, b) - for i in range(0, int(10 * math.log(x) + 3)): - if millerRabin(x, k): + while True: + if miller_rabin(x, k): return x else: x += 1 - raise ValueError def newKey(a, b, k): """ Try to find two large pseudo primes roughly between a and b. - Generate public and private keys for RSA encryption. - Raises ValueError if it fails to find one""" - try: - p = findAPrime(a, b, k) - while True: - q = findAPrime(a, b, k) - if q != p: - break - except: - raise ValueError - - n = p * q - m = (p - 1) * (q - 1) + Generate public and private keys for RSA encryption.""" + p = findAPrime(a, b, k) while True: - e = random.randint(1, m) - if coPrime([e, m]): + q = findAPrime(a, b, k) + if q != p: break + n = p * q + m = (p - 1) * (q - 1) # Compute phi(n) for n=pq where p and q are prime + + # Find and e that is coprime to phi(n) to be used in the public key + while True: + e = random.randint(1, m) + if coPrime(e, m): + break + + # Let d be the modular inverse to e, to be used as private key d = modInv(e, m) return (n, e, d) -def encrypt(message, modN, e, blockSize): +def encrypt(message, modN, e): """given a string message, public keys and blockSize, encrypt using RSA algorithms.""" return pow(message, e, modN) -def decrypt(secret, modN, d, blockSize): +def decrypt(secret, modN, d): """reverse function of encrypt""" return pow(secret, d, modN) -if __name__ == '__main__': +if __name__ == '__main__': n, e, d = newKey(2**40, 2 ** 41, 20) message = 35274764 print("original message is {}".format(message)) print("-"*80) - cipher = encrypt(message, n, e, 15) + cipher = encrypt(message, n, e) print("cipher text is {}".format(cipher)) print("-"*80) - deciphered = decrypt(cipher, n, d, 15) + deciphered = decrypt(cipher, n, d) print("decrypted message is {}".format(deciphered))