Montgomery reduction algorithm with Go

The Montgomery reduction algorithm is a fast way to perform modular multiplications of the form \(x y \pmod N\), and therefore modular exponentiations of the form \(x^y \pmod N\), which are built from repeated multiplications. The method needs an odd modulus \(N\); in this case we will take a prime \(N\) and determine \(x \times y \pmod N\) and \(x^y \pmod N\).
Outline
With RSA and the Diffie-Hellman method, we perform large modular exponentiations, such as:
\(C = M^e \pmod N\)
where we repeatedly multiply large integers to build up the power. If we were to calculate \(x^y\) in full and only then take \(\pmod N\), the intermediate value would have roughly \(y \log_2 x\) bits, so producing the result would take a long time. Reducing modulo \(N\) after every multiplication keeps the numbers small, but each of those reductions still needs a division by \(N\). Montgomery modular multiplication removes that division, which significantly reduces the time to compute the result. In a traditional multiplication of two values (\(x\) and \(y\)) for a modulus of \(N\), we multiply \(x\) by \(y\) and then divide by \(N\) to find the remainder; the product has up to as many bits as \(x\) and \(y\) together. In Montgomery reduction we instead add multiples of \(N\) so that the low-order bits of the product become zero, and then divide by a power of two, \(R = 2^n\), which is just a right shift.
A sample run for \(x=10\), \(y=5\) and \(N=29\) (passed to the program as the command-line arguments a, b and p) is:
```
a=10, b=5, p=29

Result: 10*5 (mod 29) = 21
Traditional method result = 21

Result: 10^5 (mod 29) = 8
Traditional method result = 8
```
In this case \(10 \times 5 = 50\) and \(50 \pmod{29} = 21\), while \(10^5 = 100000\) and \(100000 \pmod{29} = 8\); the Montgomery and traditional methods agree.
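The multiplication in this run can also be traced by hand. With \(N = 29\), the sample code sets \(n\) to the bit length of \(N\), so \(n = 5\) and \(R = 2^5 = 32\). Writing \(\mathrm{REDC}(T) = T R^{-1} \bmod N\) for the code's reduce step (REDC is the conventional name, not one used in the code), each value is first mapped to Montgomery form \(\bar{x} = xR \bmod N\). Reducing the product of two such values gives a result that is again in Montgomery form, because \(\mathrm{REDC}(\bar{x}\bar{y}) \equiv xR \cdot yR \cdot R^{-1} \equiv xyR \pmod N\):

\[
\begin{aligned}
R^{-1} &\equiv 10 \pmod{29} && \text{since } 32 \times 10 = 320 = 11 \times 29 + 1\\
\bar{x} &= 10 \times 32 \bmod 29 = 1, \quad \bar{y} = 5 \times 32 \bmod 29 = 15\\
\mathrm{REDC}(\bar{x}\bar{y}) &= 15 \times 10 \bmod 29 = 5 && (= 21 \times 32 \bmod 29)\\
\mathrm{REDC}(5) &= 5 \times 10 \bmod 29 = 21
\end{aligned}
\]

The final reduce strips the remaining factor of \(R\) and gives 21, as in the sample run. The code's reduce reaches the same intermediate value with shifts alone: starting from 15 it adds 29 whenever the value is odd and then halves, producing 22, 11, 20, 10 and finally 5 after \(n = 5\) steps.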
The sample code is:
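```go
package main

import (
	"fmt"
	"math/big"
	"os"
	"strconv"
)

// == From https://rosettacode.org/wiki/Montgomery_reduction#Go
type mont struct {
	n  uint     // m.BitLen(), so R = 2^n
	m  *big.Int // modulus, must be odd
	r2 *big.Int // (1<<2n) mod m, i.e. R^2 mod m
}

// newMont precomputes the parameters for modulus m. It returns nil if m
// is even, because R = 2^n must be coprime to m.
func newMont(m *big.Int) *mont {
	if m.Bit(0) != 1 {
		return nil
	}
	n := uint(m.BitLen())
	x := big.NewInt(1)
	x.Sub(x.Lsh(x, n), m) // x = 2^n - m, which is congruent to R (mod m)
	return &mont{n, new(big.Int).Set(m), x.Mod(x.Mul(x, x), m)}
}

// reduce returns t * R^-1 (mod m) for 0 <= t < m*R. Each of the n steps
// adds m if the value is odd (making it even) and then halves it.
func (m mont) reduce(t *big.Int) *big.Int {
	a := new(big.Int).Set(t)
	for i := uint(0); i < m.n; i++ {
		if a.Bit(0) == 1 {
			a.Add(a, m.m)
		}
		a.Rsh(a, 1)
	}
	if a.Cmp(m.m) >= 0 {
		a.Sub(a, m.m)
	}
	return a
}

// ==

func main() {
	p := 13
	a := 5
	b := 6

	argCount := len(os.Args[1:])
	if argCount > 0 {
		a, _ = strconv.Atoi(os.Args[1])
	}
	if argCount > 1 {
		b, _ = strconv.Atoi(os.Args[2])
	}
	if argCount > 2 {
		p, _ = strconv.Atoi(os.Args[3])
	}
	if a < 0 || b < 0 || p <= 0 {
		fmt.Println("a and b must be non-negative and p must be positive")
		return
	}

	m := big.NewInt(int64(p))
	fmt.Printf("a=%d, b=%d, p=%d\n\n", a, b, p)

	mr := newMont(m)
	if mr == nil {
		fmt.Println("p must be odd for Montgomery reduction")
		return
	}

	x1 := big.NewInt(int64(a))
	x2 := big.NewInt(int64(b))

	// Convert to Montgomery form (x*R mod p). Reducing mod p first keeps
	// the input to reduce below p*R, even when a or b is >= p.
	t1 := mr.reduce(new(big.Int).Mul(new(big.Int).Mod(x1, m), mr.r2))
	t2 := mr.reduce(new(big.Int).Mul(new(big.Int).Mod(x2, m), mr.r2))

	// reduce(aR * bR) = abR (mod p); one more reduce strips R to give ab.
	res := mr.reduce(new(big.Int).Mul(t1, t2))
	fmt.Printf("Result: %s*%s (mod %s) = %s\n", x1, x2, m, mr.reduce(res))

	mul := new(big.Int).Mul(x1, x2)
	fmt.Printf("Traditional method result = %s\n\n", mul.Mod(mul, m))

	// Right-to-left square-and-multiply, with every value in Montgomery form.
	prod := mr.reduce(mr.r2) // R mod p, i.e. 1 in Montgomery form
	base := new(big.Int).Set(t1)
	exp := new(big.Int).Set(x2)
	for exp.BitLen() > 0 {
		if exp.Bit(0) == 1 {
			prod = mr.reduce(prod.Mul(prod, base))
		}
		exp.Rsh(exp, 1)
		base = mr.reduce(base.Mul(base, base))
	}
	fmt.Printf("\nResult: %s^%s (mod %s) = %s\n", x1, x2, m, mr.reduce(prod))
	fmt.Printf("Traditional method result = %s\n", new(big.Int).Exp(x1, x2, m))
}
```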
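The reduce function above clears one bit of the value per loop iteration. Implementations aimed at speed usually clear a whole machine word at once, using a precomputed constant \(N' = -N^{-1} \bmod R\). The following is a minimal sketch of that word-level variant, which is a different technique from the bit-serial loop in the sample code. It assumes an odd modulus \(3 \le N < 2^{31}\) and \(R = 2^{32}\), so that every intermediate value fits in a `uint64`; the names `montCtx`, `redc`, `toMont`, `fromMont`, `mulMod` and `expMod` are our own.

```go
package main

import "fmt"

// montCtx holds word-level Montgomery parameters for an odd modulus
// 3 <= n < 2^31, with R = 2^32 (both are assumptions of this sketch).
type montCtx struct {
	n      uint64 // modulus
	nPrime uint32 // N' = -n^-1 mod 2^32
	r2     uint64 // R^2 mod n, used to enter Montgomery form
}

func newMontCtx(n uint64) montCtx {
	if n%2 == 0 || n < 3 || n >= 1<<31 {
		panic("modulus must be odd and in [3, 2^31)")
	}
	n32 := uint32(n)
	inv := n32 // n*n ≡ 1 (mod 8) for odd n, so 3 low bits are already correct
	for i := 0; i < 4; i++ {
		inv *= 2 - n32*inv // each Newton step doubles the correct bits: 3, 6, 12, 24, 48
	}
	rModN := (uint64(1) << 32) % n
	return montCtx{n: n, nPrime: -inv, r2: rModN * rModN % n}
}

// redc returns t * R^-1 mod n for t < n*R, using a multiply, an add and
// a shift instead of a division by n.
func (m montCtx) redc(t uint64) uint64 {
	u := uint32(t) * m.nPrime      // u = (t mod R) * N' mod R
	s := (t + uint64(u)*m.n) >> 32 // t + u*n ≡ 0 (mod R), so this shift is exact
	if s >= m.n {
		s -= m.n
	}
	return s
}

func (m montCtx) toMont(a uint64) uint64   { return m.redc((a % m.n) * m.r2) }
func (m montCtx) fromMont(a uint64) uint64 { return m.redc(a) }

// mulMod returns x*y mod n via one Montgomery product.
func (m montCtx) mulMod(x, y uint64) uint64 {
	return m.fromMont(m.redc(m.toMont(x) * m.toMont(y)))
}

// expMod returns x^y mod n with right-to-left square-and-multiply,
// keeping every intermediate value in Montgomery form.
func (m montCtx) expMod(x, y uint64) uint64 {
	prod, base := m.toMont(1), m.toMont(x)
	for ; y > 0; y >>= 1 {
		if y&1 == 1 {
			prod = m.redc(prod * base)
		}
		base = m.redc(base * base)
	}
	return m.fromMont(prod)
}

func main() {
	m := newMontCtx(29)
	fmt.Println(m.mulMod(10, 5)) // 21
	fmt.Println(m.expMod(10, 5)) // 8
}
```

The bound \(N < 2^{31}\) keeps \(t + uN\) below \(2^{64}\); lifting it would need 128-bit intermediate products, for example via `math/bits.Mul64`.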
Presentation
The following is a presentation on the method: