# File: alex-retrieval/pir/pir_Master_supreme.py
# Date: 2019-11-14 12:21:30 +01:00
# (Original listing: 108 lines, 1.9 KiB, Python.)

import math
import random
from fractions import Fraction
# Protocol parameters:
#   n -- number of entries stored in the database ``x`` (one per index).
#   k -- number of servers queried; each index is encoded as a bit vector
#        with k - 1 ones, so the servers evaluate a degree-(k - 1) polynomial.
n, k = 8, 3
def createX(x_string):
    """Convert a string of decimal digits into a list of ints.

    Each character is passed through ``int`` in order, so ``"101"`` becomes
    ``[1, 0, 1]``. A character ``int`` cannot parse raises ``ValueError``.
    """
    return [int(digit) for digit in x_string]
# Database read by every server: one int per character of x_string.
x_string = "11111010"
x = createX(x_string)
# server_comp and the self-check index x[j] for j in range(n). Fail fast with a
# clear message instead of an IndexError (or silently unused entries) later.
if len(x) != n:
    raise ValueError(
        "x_string has {} entries but n is {}; keep them in sync".format(len(x), n)
    )
def findS(n_items=None, ones=None):
    """Return the bit-vector length ``s`` used to encode database indices.

    Each of ``n_items`` indices gets a distinct length-``s`` bit vector with
    exactly ``ones`` ones. That requires ``C(s, ones) >= n_items``. The search
    starts at ``ones + 1`` (``k`` for the module defaults) and returns the first
    length that fits. Unlike a search capped at ``n``, it never falls off the
    end and returns ``None``.

    Args:
        n_items: number of indices to encode; defaults to the module-level ``n``.
        ones: number of ones per vector; defaults to the module-level ``k - 1``.

    Returns:
        The smallest ``s >= ones + 1`` with ``C(s, ones) >= n_items``.

    Raises:
        ValueError: if no such ``s`` exists. That happens when ``ones == 0``
            and more than one index must be encoded, since only the all-zero
            vector exists.
    """
    if n_items is None:
        n_items = n
    if ones is None:
        ones = k - 1
    if ones < 1 and n_items > 1:
        raise ValueError(
            "cannot encode {} indices with {} ones per vector".format(n_items, ones)
        )
    s_candidate = ones + 1
    while True:
        # Exact integer binomial coefficient (no float rounding/overflow).
        binom = math.factorial(s_candidate) // (
            math.factorial(ones) * math.factorial(s_candidate - ones)
        )
        if binom >= n_items:
            return s_candidate
        s_candidate += 1
# Length of every index encoding (a bit vector holding k - 1 ones).
s = findS()
def makeOneSeq(seq, ones=None):
    """Return a copy of ``seq`` with ``ones`` randomly chosen 0 entries set to 1.

    ``seq`` itself is not modified. The positions to flip are drawn uniformly
    at random, without replacement, from the indices where ``seq`` holds 0.

    Args:
        seq: list whose 0 entries are candidates for flipping.
        ones: how many zeros to flip; defaults to the module-level ``k - 1``.

    Returns:
        A new list equal to ``seq`` except at the flipped positions.

    Raises:
        ValueError: if ``ones`` is negative or ``seq`` has fewer than ``ones``
            zeros. A rejection-sampling loop would spin forever instead.
    """
    if ones is None:
        ones = k - 1
    seq_temp = seq.copy()
    zero_positions = [idx for idx, bit in enumerate(seq_temp) if bit == 0]
    if ones < 0 or ones > len(zero_positions):
        raise ValueError(
            "cannot set {} ones in a sequence with {} zeros".format(
                ones, len(zero_positions)
            )
        )
    for idx in random.sample(zero_positions, ones):
        seq_temp[idx] = 1
    return seq_temp
def makeSequences():
    """Return ``n`` distinct random index encodings, sorted lexicographically.

    Each encoding is a length-``s`` list produced by ``makeOneSeq`` from an
    all-zero list, so it holds ``k - 1`` ones. Duplicates are rejected and
    redrawn, with already-chosen encodings tracked in a set for O(1) lookups.
    The old approach scanned a list, which was O(n) per check.

    Returns:
        A sorted list of ``n`` distinct encodings; row ``j`` is used as the
        encoding of database index ``j``.

    Raises:
        ValueError: if fewer than ``n`` distinct vectors of length ``s`` with
            ``k - 1`` ones exist. Rejection sampling could never finish then.
    """
    ones = k - 1
    if 0 <= ones <= s:
        available = math.factorial(s) // (
            math.factorial(ones) * math.factorial(s - ones)
        )
    else:
        available = 0
    if available < n:
        raise ValueError(
            "only {} distinct encodings of length {} with {} ones exist, "
            "need {}".format(available, s, ones, n)
        )
    zero_seq = [0] * s
    seen = set()
    encodings = []
    while len(encodings) < n:
        candidate = makeOneSeq(zero_seq)
        key = tuple(candidate)
        if key not in seen:
            seen.add(key)
            encodings.append(candidate)
    encodings.sort()
    return encodings
# Encodings of all n database indices: row j is the bit vector for index j.
sequnces = makeSequences()
def user_send(rands, i):
    """Build the query vector sent to each of the k servers.

    Server number ``z`` (1..k) receives ``rands[l] * z + int(i[l])`` for every
    position ``l`` in ``range(s)``. That is the point ``i + z * rands`` on the
    line through the encoded index ``i`` with direction ``rands``.

    NOTE(review): the arithmetic is over the plain integers. For a 0/1 bit
    ``b`` and ``z >= 2``, ``(r * z + b) % z == b``, so server ``z`` can read
    the encoded index straight off its query. Hiding it would need working
    modulo a prime with ``rands`` uniform over that field. Confirm whether
    privacy is a goal of this demo.

    Returns:
        A list of k lists (one per server, in order z = 1..k), each of length s.
    """
    queries = []
    for z in range(1, k + 1):
        queries.append([rands[pos] * z + int(i[pos]) for pos in range(s)])
    return queries
def server_comp(gs):
    """Evaluate the database polynomial at the query point ``gs``.

    Computes the sum over ``j`` in ``range(n)`` of ``x[j]`` times the product
    of ``gs[l]`` for every ``l`` in ``range(s)`` where ``sequnces[j][l] == 1``.
    Each stored value is weighted by the query coordinates its index encoding
    selects.
    """
    F = 0
    for j in range(n):
        picked = [gs[l] for l in range(s) if sequnces[j][l] == 1]
        monomial = 1
        for coord in picked:
            monomial *= coord
        F += monomial * x[j]
    return F
def servers_comp(gs_for_servers):
    """Return each server's answer, in the same order as the query vectors.

    Every query vector is evaluated independently with ``server_comp``.
    """
    return [server_comp(gs) for gs in gs_for_servers]
def poly_interpolation(x, Fs):
    """Evaluate, at point ``x``, the polynomial through ``(z, Fs[z - 1])``.

    The nodes are ``z = 1, 2, ..., len(Fs)``, matching the server numbering
    used for the queries. The Lagrange form is evaluated with exact
    ``Fraction`` arithmetic. Float weights such as 1.333… lose precision on
    large answers, and the final truncation could then turn e.g. 3.9999… into 3.

    Args:
        x: point to evaluate at (``0`` recovers the retrieved value).
        Fs: polynomial values at nodes 1..len(Fs).

    Returns:
        The exact value converted with ``int`` (truncation toward zero if it is
        not already an integer); ``0`` for an empty ``Fs``.
    """
    point = Fraction(x)
    nodes = range(1, len(Fs) + 1)
    total = Fraction(0)
    for node_i, value in zip(nodes, Fs):
        weight = Fraction(1)
        for node_j in nodes:
            if node_j != node_i:
                weight *= (point - node_j) / (node_i - node_j)
        total += weight * value
    return int(total)
def protocol(i):
    """Run one retrieval round for the index encoded by the bit vector ``i``.

    Draws ``s`` random integers in 1..20 (inclusive) as the line direction and
    sends each server its query point. It then collects the k answers and
    interpolates them at 0. The self-check at the bottom of this file expects
    that value to equal the database entry that ``i`` encodes.
    """
    direction = [random.randint(1, 20) for _ in range(s)]
    answers = servers_comp(user_send(direction, i))
    return poly_interpolation(0, answers)
# Self-check: retrieving every index must reproduce the stored entry. Raise
# explicitly instead of using ``assert`` so the check still runs under
# ``python -O``, and report which index failed.
for i, i_bitlist in enumerate(sequnces):
    F0 = protocol(i_bitlist)
    if F0 != x[i]:
        raise AssertionError(
            "PIR retrieval of index {} returned {}, expected {}".format(i, F0, x[i])
        )