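The Paillier cryptosystem supports homomorphic encryption: two encrypted values can be added together, or one subtracted from the other, without decrypting them. Decrypting the result then gives the sum or the difference of the two original values. In this case we will compute the difference between two values: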
Homomorphic Difference in Python (Paillier)
Theory
The method follows the outline of the Paillier cryptosystem given on Wikipedia.
We start with two prime numbers (p and q) and compute n = p*q. Next we compute the least common multiple of (p-1) and (q-1), which gives lambda, and then we pick a random number g:
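from random import randint

def gcd(a, b):
    while b > 0:
        a, b = b, a % b
    return a

def lcm(a, b):
    return a * b // gcd(a, b)    # integer division keeps lambda an int

n = p*q
gLambda = lcm(p-1, q-1)
g = randint(0, 100)    # g must be coprime with n*n for decryption to work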
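As a quick check of the key-generation step, here is a small sketch using the toy primes from the sample run below (p=17, q=19). It uses the standard-library math.gcd in place of the hand-written gcd, and choose_g is a hypothetical helper that keeps drawing g until it is coprime with n*n. This alone does not guarantee a usable g; the later step that inverts L(g^lambda) mod n still has to succeed.

```python
from math import gcd
from random import randrange

def lcm(a: int, b: int) -> int:
    # Integer division keeps lambda an exact int (a * b / gcd would give a float in Python 3)
    return a * b // gcd(a, b)

def choose_g(n: int) -> int:
    # Hypothetical helper: draw g in [2, n*n) until it is coprime with n*n
    while True:
        g = randrange(2, n * n)
        if gcd(g, n * n) == 1:
            return g

p, q = 17, 19                # toy primes, far too small for real use
n = p * q                    # 323
gLambda = lcm(p - 1, q - 1)  # lcm(16, 18) = 144
g = choose_g(n)
print(n, gLambda, g)
```

The value lambda = 144 matches the gLambda reported in the sample run.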
The next two steps calculate the value of the L function, L(x) = (x-1)/n, applied to g^lambda mod n², and then gMu, which is the inverse of l mod n (the full program at the end of the article uses libnum.invmod for this):

import libnum

l = (pow(g, gLambda, n*n)-1)//n    # L(g^lambda mod n^2)
gMu = libnum.invmod(l, n)          # inverse of l mod n
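Here is a self-contained sketch of the L function and the gMu step with the toy primes p=17 and q=19. Two assumptions differ from the article: it uses the common simple choice g = n + 1 instead of a random g, and it computes the modular inverse with Python's built-in pow(l, -1, n) (Python 3.8+) rather than libnum.

```python
import math

def L(x: int, n: int) -> int:
    # Paillier's L function: L(x) = (x - 1) / n, exact when x ≡ 1 (mod n)
    return (x - 1) // n

p, q = 17, 19
n = p * q
gLambda = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)   # lcm(16, 18) = 144
g = n + 1                     # assumption: g = n + 1 always gives an invertible l
l = L(pow(g, gLambda, n * n), n)
gMu = pow(l, -1, n)           # modular inverse; raises ValueError if none exists
print(l, gMu)                 # 144 83
```

With g = n + 1, (1 + n)^lambda mod n² equals 1 + lambda*n, so l is simply lambda (144), and 144 * 83 = 11952 = 37*323 + 1, which confirms gMu = 83.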
The public key is then (n,g) and the private key is (gLambda,gMu).
The cipher is then computed from the message m and a random value r (coprime with n) as g^m * r^n mod n². Here the function pow(a,b,n) raises a to the power of b and then takes the result mod n:
k1 = pow(g, m, n*n)            # g^m mod n^2
k2 = pow(r, n, n*n)            # r^n mod n^2
cipher = (k1 * k2) % (n*n)
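The encryption step can be wrapped in a function that draws a fresh r for each message. This is a sketch under stated assumptions: it uses the toy key n = 323 (p=17, q=19) with g = n + 1, and it takes r from Python's secrets module, which is a choice made here, not something the article specifies.

```python
import math
import secrets

def encrypt(m: int, n: int, g: int) -> int:
    """Paillier encryption c = g^m * r^n mod n^2 with a fresh random r coprime with n."""
    n2 = n * n
    r = secrets.randbelow(n - 1) + 1          # r in [1, n-1]
    while math.gcd(r, n) != 1:
        r = secrets.randbelow(n - 1) + 1
    k1 = pow(g, m, n2)                        # g^m mod n^2
    k2 = pow(r, n, n2)                        # r^n mod n^2
    return (k1 * k2) % n2

n, g = 323, 324        # toy key: p=17, q=19, assumed g = n + 1
c1 = encrypt(10, n, g)
c2 = encrypt(10, n, g)
print(c1, c2)          # usually different values, since r changes on each call
```

Because r is random, encrypting the same message twice normally gives different ciphers, which is what stops an observer from spotting repeated values.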
It is decrypted with:
l = (pow(cipher, gLambda, n*n)-1) // n   # L(c^lambda mod n^2)
mess = (l * gMu) % n                     # recovered message
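To see encryption and decryption working together, here is a round-trip sketch on the toy key p=17, q=19. As before, g = n + 1 and the secrets-based choice of r are assumptions made for this example, and pow(x, -1, n) needs Python 3.8 or later.

```python
import math
import secrets

p, q = 17, 19
n, n2 = p * q, (p * q) ** 2
gLambda = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)   # 144
g = n + 1                                               # assumed generator choice
gMu = pow((pow(g, gLambda, n2) - 1) // n, -1, n)         # 83

def encrypt(m: int) -> int:
    r = secrets.randbelow(n - 1) + 1
    while math.gcd(r, n) != 1:
        r = secrets.randbelow(n - 1) + 1
    return (pow(g, m, n2) * pow(r, n, n2)) % n2

def decrypt(cipher: int) -> int:
    # l = L(c^lambda mod n^2), then m = l * mu mod n
    l = (pow(cipher, gLambda, n2) - 1) // n
    return (l * gMu) % n

print(decrypt(encrypt(10)))   # 10
```

Any message in the range 0 to n-1 comes back unchanged; larger messages are reduced mod n.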
A sample run with p=17, q=19, and m=10 is:
p= 17
q= 19
g= 45
r= 59
================
Mu: 66
gLambda: 144
================
Public key (n,g): 323 45
Private key (lambda,mu): 144 66
================
Message: 10
Cipher: 336
Decrypted: 10
With Paillier we can encrypt values with the public key and then add them while they are still encrypted, by multiplying the ciphers together mod n²:
m1 = 2
k3 = pow(g, m1, n*n)
cipher2 = (k3 * k2) % (n*n)                # reuses r^n (k2) from the first cipher
ciphertotal = (cipher * cipher2) % (n*n)   # homomorphic addition
l = (pow(ciphertotal, gLambda, n*n)-1) // n
mess2 = (l * gMu) % n
print("Result:\t\t", mess2)
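The same addition can be sketched with an independent random r for each cipher. This uses the toy key p=17, q=19 with the assumed choice g = n + 1; add is simply the product of the two ciphers mod n².

```python
import math
import secrets

p, q = 17, 19
n, n2 = p * q, (p * q) ** 2
gLambda = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)
g = n + 1                                         # assumed generator choice
gMu = pow((pow(g, gLambda, n2) - 1) // n, -1, n)

def encrypt(m: int) -> int:
    r = secrets.randbelow(n - 1) + 1
    while math.gcd(r, n) != 1:
        r = secrets.randbelow(n - 1) + 1
    return (pow(g, m, n2) * pow(r, n, n2)) % n2

def decrypt(cipher: int) -> int:
    return ((pow(cipher, gLambda, n2) - 1) // n * gMu) % n

def add(c1: int, c2: int) -> int:
    # Multiplying ciphers mod n^2 adds the underlying messages mod n
    return (c1 * c2) % n2

print(decrypt(add(encrypt(10), encrypt(2))))   # 12
```

Sums wrap round mod n, so with this tiny key 300 + 50 decrypts to 27 (350 - 323); real keys use a much larger n.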
When we run this, we get:
p= 17
q= 19
g= 86
r= 91
================
Mu: 40
gLambda: 144
================
Public key (n,g): 323 86
Private key (lambda,mu): 144 40
================
Message: 10
Cipher: 95297
Decrypted: 10
================
Now we will add a ciphered value of 2 to the encrypted value
Result: 12
The result is 12 (10 + 2), so the addition has been carried out correctly on the encrypted values.
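Why this works can be shown directly from the encryption formula E(m) = g^m r^n mod n². Multiplying two ciphers combines the exponents of g, and multiplying by a modular inverse (which is how a difference is taken) subtracts them:

$$
\begin{aligned}
E(m_1)\,E(m_2) \bmod n^2 &= g^{m_1} r_1^{\,n}\, g^{m_2} r_2^{\,n} \bmod n^2 = g^{m_1+m_2}\,(r_1 r_2)^{n} \bmod n^2 \\
E(m_1)\,E(m_2)^{-1} \bmod n^2 &= g^{m_1-m_2}\,\left(r_1 r_2^{-1}\right)^{n} \bmod n^2
\end{aligned}
$$

Both right-hand sides have the form g^m r^n mod n², so they decrypt to (m1 + m2) mod n and (m1 - m2) mod n. When m1 < m2 the difference wraps round to n - (m2 - m1), a value close to n, so a program reading back a difference has to decide whether a large result stands for a negative number.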
Coding
Here is the Python code:
import sys
from random import randint

import libnum
from Crypto.Util.number import getPrime
from Crypto.Random import get_random_bytes

def gcd(a, b):
    """Compute the greatest common divisor of a and b"""
    while b > 0:
        a, b = b, a % b
    return a

def lcm(a, b):
    """Compute the lowest common multiple of a and b"""
    return a * b // gcd(a, b)

def genparams():
    p = getPrime(primebits, randfunc=get_random_bytes)
    q = getPrime(primebits, randfunc=get_random_bytes)
    n = p*q
    # Pick a g that is coprime with n^2
    g = n
    while (gcd(g, n*n) != 1):
        g = randint(2000, 3000)
    gLambda = lcm(p-1, q-1)
    l = (pow(g, gLambda, n*n)-1)//n
    gMu = libnum.invmod(l, n)
    return gLambda, n, g, gMu, primebits

def encrypt(k):
    # Note: no random factor r^n is applied here, so this encryption is deterministic
    # (full Paillier computes (g^k * r^n) mod n^2)
    return pow(g, k, n*n)

def decrypt(cipher):
    l = (pow(cipher, gLambda, n*n)-1) // n
    mess = (l * gMu) % n
    return mess

def add(cipher, cipher2):
    return (cipher * cipher2) % (n*n)

def sub(cipher, cipher2):
    # Multiply by a modular inverse: v1 decrypts to (a-b) mod n, v2 to (b-a) mod n
    v1 = (cipher * libnum.invmod(cipher2, n*n)) % (n*n)
    v2 = (cipher2 * libnum.invmod(cipher, n*n)) % (n*n)
    return v1, v2

def L(x, n):
    return ((x-1)//n)

primebits = 60
a = 9
b = 3
if (len(sys.argv) > 1): a = int(sys.argv[1])
if (len(sys.argv) > 2): b = int(sys.argv[2])
if (len(sys.argv) > 3): primebits = int(sys.argv[3])

######## Encrypting
gLambda, n, g, gMu, primesize = genparams()
enc_a = encrypt(a)
enc_b = encrypt(b)

print(f"Public key: g={g}, n={n}\n")
print(f"Private key: lambda={gLambda}, Mu={gMu}\n")
print(f"\nEncrypted a={enc_a}")
print(f"Encrypted b={enc_b}")

cipher2_1, cipher2_2 = sub(enc_a, enc_b)
mess1 = decrypt(cipher2_1)
mess2 = decrypt(cipher2_2)

# A negative difference wraps round to a value close to n, so use the other ordering
res = mess1
if (mess1 > n/8):
    res = mess2
print(f"\nVal1={a},Val2={b}. Diff: {res}")
A sample run (with a=8 and b=3 given on the command line) gives:

Public key: g=131, n=758408400447992819650352783065189811

Private key: lambda=75840840044799281790208637606550480, Mu=371146169651011623634107680609529963

Encrypted a=86730203469006241
Encrypted b=2248091

Val1=8,Val2=3. Diff: 5
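The difference can also be sketched with randomised encryption and a single signed result. This differs from the program above in two labelled ways: each cipher gets a fresh random r (the program's encrypt omits r), and instead of computing both orderings and testing against n/8, the hypothetical helper to_signed treats any decrypted value above n // 2 as negative, a common convention assumed here. The toy key p=17, q=19 with g = n + 1 is also an assumption.

```python
import math
import secrets

p, q = 17, 19                     # toy primes only
n, n2 = p * q, (p * q) ** 2
gLambda = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)
g = n + 1                         # assumed generator choice
gMu = pow((pow(g, gLambda, n2) - 1) // n, -1, n)

def encrypt(m: int) -> int:
    r = secrets.randbelow(n - 1) + 1
    while math.gcd(r, n) != 1:
        r = secrets.randbelow(n - 1) + 1
    return (pow(g, m, n2) * pow(r, n, n2)) % n2

def decrypt(cipher: int) -> int:
    return ((pow(cipher, gLambda, n2) - 1) // n * gMu) % n

def sub(c1: int, c2: int) -> int:
    # E(a) * E(b)^-1 mod n^2 decrypts to (a - b) mod n
    return (c1 * pow(c2, -1, n2)) % n2

def to_signed(x: int) -> int:
    # Hypothetical helper: values above n // 2 are read as negatives (x - n)
    return x - n if x > n // 2 else x

print(to_signed(decrypt(sub(encrypt(8), encrypt(3)))))   # 5
print(to_signed(decrypt(sub(encrypt(3), encrypt(8)))))   # -5
```

This keeps the sign of the result, at the cost of limiting the usable range to differences smaller than n/2 in magnitude.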