Montgomery reduction algorithm
The Montgomery reduction algorithm is a fast way to compute modular products of the form \(x y \pmod N\) and, by repeated multiplication, modular powers of the form \(x^y \pmod N\). In this case we take a bit length for \(N\) and generate a random valid value for \(N\) of that size (it must be odd).
Outline
With RSA and the Diffie-Hellman method we perform large modular exponentiations, such as:
\(C = M^e \pmod N\)
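where we repeatedly multiply large integers modulo \(N\) in order to raise a value to an exponent. If we were to calculate \(x^y\) in full and only then reduce it \(\pmod N\), the intermediate value would become enormous and the result would take a long time to produce. Thus we reduce after each multiplication and use Montgomery modular multiplication, which significantly reduces the time to compute the result. In a traditional multiplication of two values (\(x\) and \(y\)) for a modulus of \(N\), we multiply \(x\) by \(y\) and then divide by \(N\) to find the remainder. The number of bits in the product will be the number of bits in \(x\) plus the number of bits in \(y\), and the division by \(N\) is the slow step. In Montgomery reduction we instead choose a power of two \(R > N\) (this is why \(N\) must be odd, so that \(\gcd(R, N) = 1\)) and keep values in Montgomery form, \(\bar{x} = xR \bmod N\). To reduce a product, we add a suitable multiple of \(N\) so that its lowest bits become zero; dividing by \(R\) is then just a right shift, so no division by \(N\) is needed. Reducing \(\bar{x}\bar{y}\) this way gives \(\bar{x}\bar{y}R^{-1} \bmod N\), which is the Montgomery form of \(xy\).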
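To make the reduction step concrete, here is a minimal Python sketch of the reduction (often called REDC), assuming \(R = 2^k > N\) with \(N\) odd. The helper names `montgomery_reduce` and `neg_inverse_mod_r` are hypothetical, and `pow(n, -1, r)` needs Python 3.8 or later. The value `n_prime` satisfies \(N \cdot N' \equiv -1 \pmod R\) and plays the same role as `self.factor` in the sample code.

```python
def neg_inverse_mod_r(n: int, r_bits: int) -> int:
    """Return n' = -n^-1 mod R, where R = 2**r_bits (n must be odd)."""
    r = 1 << r_bits
    return (-pow(n, -1, r)) % r


def montgomery_reduce(t: int, n: int, r_bits: int, n_prime: int) -> int:
    """Return t * R^-1 mod n for 0 <= t < n * R, with R = 2**r_bits > n."""
    mask = (1 << r_bits) - 1
    # Choose m so that t + m*n is divisible by R (its low r_bits bits are zero)
    m = ((t & mask) * n_prime) & mask
    # Exact division by R is just a right shift
    u = (t + m * n) >> r_bits
    # u < 2n, so at most one subtraction brings it into [0, n)
    return u - n if u >= n else u


if __name__ == "__main__":
    n, r_bits = 29, 8
    n_prime = neg_inverse_mod_r(n, r_bits)               # 203
    xm = (10 << r_bits) % n                              # 10 in Montgomery form: 8
    ym = (5 << r_bits) % n                               # 5 in Montgomery form: 4
    zm = montgomery_reduce(xm * ym, n, r_bits, n_prime)  # 11, i.e. 50*R mod 29
    print(montgomery_reduce(zm, n, r_bits, n_prime))     # 21, i.e. 50 mod 29
```

Reducing a Montgomery-form value once more multiplies it by \(R^{-1}\), which is why the same function also converts the result back to standard form.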
A sample run for \(x=10\), \(y=5\) and \(N=29\) is:
x= 10
y= 5
N= 29

x*y (mod N)
Result (Montgomery)= 21
Result (x*y % mod)= 21

x^y (mod N)
Result (Montgomery)= 8
Result (x^y % mod)= 8
In this case we get \(50 \pmod {29}\) which is 21, and \(10^5 \pmod {29}\) which is 8.
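As a worked check of the multiplication step, assume \(R\) is chosen by the rule in the sample code, which for \(N = 29\) gives \(R = 2^8 = 256\). Converting \(x\) and \(y\) into Montgomery form, reducing their product and converting back gives:

```latex
\begin{aligned}
R &= 2^8 = 256, \quad R \bmod 29 = 24, \quad R^{-1} \equiv 23 \pmod{29} \quad (24 \cdot 23 = 552 = 19 \cdot 29 + 1) \\
N' &= \frac{R R^{-1} - 1}{N} = \frac{256 \cdot 23 - 1}{29} = 203, \quad N N' = 5887 \equiv -1 \pmod{256} \\
\bar{x} &= 10 \cdot 256 \bmod 29 = 8, \quad \bar{y} = 5 \cdot 256 \bmod 29 = 4 \\
T &= \bar{x}\bar{y} = 32 \\
m &= (T \bmod R)\, N' \bmod R = 32 \cdot 203 \bmod 256 = 96 \\
\bar{z} &= \frac{T + mN}{R} = \frac{32 + 96 \cdot 29}{256} = \frac{2816}{256} = 11 \\
z &= \bar{z} R^{-1} \bmod 29 = 11 \cdot 23 \bmod 29 = 21
\end{aligned}
```

The division by \(R\) is exact because \(m\) was chosen to clear the low 8 bits, and \(\bar{z} = 11 < N\), so no final subtraction is needed.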
The sample code is:
#
# Montgomery reduction algorithm (Python)
#
# Copyright (c) 2018 Project Nayuki
# All rights reserved. Contact Nayuki for licensing.
# https://www.nayuki.io/page/montgomery-reduction-algorithm
#

import random,sys
import math


class MontgomeryReducer(object):

    def __init__(self, mod):
        # Modulus
        if mod < 3 or mod % 2 == 0:
            raise ValueError("Modulus must be an odd number at least 3")
        self.modulus = mod

        # Reducer
        self.reducerbits = (mod.bit_length() // 8 + 1) * 8  # This is a multiple of 8
        self.reducer = 1 << self.reducerbits  # This is a power of 256
        self.mask = self.reducer - 1
        assert self.reducer > mod and math.gcd(self.reducer, mod) == 1

        # Other computed numbers
        self.reciprocal = MontgomeryReducer.reciprocal_mod(self.reducer % mod, mod)
        self.factor = (self.reducer * self.reciprocal - 1) // mod
        self.convertedone = self.reducer % mod

    # The range of x is unlimited
    def convert_in(self, x):
        return (x << self.reducerbits) % self.modulus

    # The range of x is unlimited
    def convert_out(self, x):
        return (x * self.reciprocal) % self.modulus

    # Inputs and output are in Montgomery form and in the range [0, modulus)
    def multiply(self, x, y):
        mod = self.modulus
        assert 0 <= x < mod and 0 <= y < mod
        product = x * y
        temp = ((product & self.mask) * self.factor) & self.mask
        reduced = (product + temp * mod) >> self.reducerbits
        result = reduced if (reduced < mod) else (reduced - mod)
        assert 0 <= result < mod
        return result

    # Input x (base) and output (power) are in Montgomery form and in the range [0, modulus); input y (exponent) is in standard form
    def pow(self, x, y):
        assert 0 <= x < self.modulus
        if y < 0:
            raise ValueError("Negative exponent")
        z = self.convertedone
        while y != 0:
            if y & 1 != 0:
                z = self.multiply(z, x)
            x = self.multiply(x, x)
            y >>= 1
        return z

    @staticmethod
    def reciprocal_mod(x, mod):
        # Based on a simplification of the extended Euclidean algorithm
        assert mod > 0 and 0 <= x < mod
        y = x
        x = mod
        a = 0
        b = 1
        while y != 0:
            a, b = b, a - x // y * b
            x, y = y, x % y
        if x == 1:
            return a % mod
        else:
            raise ValueError("Reciprocal does not exist")


bitlen=20
x=100
y=50

if (len(sys.argv)>1):
    x=int(sys.argv[1])
if (len(sys.argv)>2):
    y=int(sys.argv[2])
if (len(sys.argv)>3):
    bitlen=int(sys.argv[3])

#bitlen = random.randint(2, 100)
mod = random.randrange(1 << bitlen, 2 << bitlen) | 1  # Force it to be odd
mr = MontgomeryReducer(mod)

#x = random.randrange(0, mod)
#y = random.randrange(0, mod)
u = mr.convert_in(x)
v = mr.convert_in(y)
w = mr.multiply(u, v)
print("x=\t",x)
print("y=\t",y)
print("Bits in modulus=\t",bitlen)
print("N=\t",mod)
print("\nx*y (mod N)")
print("\nResult (Montgomery)=\t",mr.convert_out(w))
print("Result (x*y % mod)=\t",x*y % mod)

#x = random.randrange(0, mod)
#y = random.randrange(0, mod)
u = mr.convert_in(x)
v = mr.pow(u, y)
print("\nx^y (mod N)")
print("Result (Montgomery)=\t",mr.convert_out(v))
print("Result (x^y % mod)=\t",x**y % mod)
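The `pow` method in the sample code is right-to-left binary exponentiation (square-and-multiply), with `multiply` performing each Montgomery step. To show why reducing after every product beats computing \(x^y\) first, here is a minimal sketch of the same loop using an ordinary `%` reduction instead of Montgomery form; the name `mod_pow` is hypothetical, and Python's built-in three-argument `pow` already does this job in practice.

```python
def mod_pow(base: int, exp: int, n: int) -> int:
    """Compute base**exp mod n by square-and-multiply, reducing after every product."""
    if exp < 0:
        raise ValueError("Negative exponent")
    result = 1 % n              # gives 0 when n == 1
    base %= n
    while exp:
        if exp & 1:             # this exponent bit is set: fold the current base in
            result = result * base % n
        base = base * base % n  # square for the next exponent bit
        exp >>= 1
    return result


print(mod_pow(10, 5, 29))  # 8
```

Every intermediate value stays below \(N^2\), and the loop runs once per bit of the exponent. Montgomery multiplication then replaces each `% n` with masks, multiplications and a shift.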
Presentation
The following is a presentation on the method [slides]: