Fucking meget.

This commit is contained in:
Casper V. Kristensen 2019-12-10 17:48:24 +01:00
parent a6383b0d30
commit 6e5fcd013b
24 changed files with 463 additions and 465 deletions

View File

@ -1,119 +1,133 @@
package dk.au.pir; package dk.au.pir;
import dk.au.pir.databases.Database; import dk.au.pir.databases.Database;
import dk.au.pir.databases.MemoryDatabase; import dk.au.pir.databases.FakeDatabase;
import dk.au.pir.profilers.Profiler; import dk.au.pir.profilers.Profiler;
import dk.au.pir.protocols.evenSimpler.EvenSimplerClient;
import dk.au.pir.protocols.evenSimpler.EvenSimplerServer;
import dk.au.pir.protocols.interpoly.InterPolyClient; import dk.au.pir.protocols.interpoly.InterPolyClient;
import dk.au.pir.protocols.interpoly.InterPolyServer; import dk.au.pir.protocols.interpoly.InterPolyServer;
import dk.au.pir.protocols.simple.SimpleClient; import dk.au.pir.protocols.stupid.SendAllClient;
import dk.au.pir.protocols.simple.SimpleServer; import dk.au.pir.protocols.stupid.SendAllServer;
import dk.au.pir.protocols.xor.SqrtXORClient;
import dk.au.pir.protocols.xor.SqrtXORServer;
import dk.au.pir.protocols.xor.XORClient;
import dk.au.pir.protocols.xor.XORServer;
import dk.au.pir.settings.PIRSettings; import dk.au.pir.settings.PIRSettings;
import java.util.Arrays;
public class Driver { public class Driver {
private static void testEvenSimplerScheme(PIRSettings settings, Database database, Profiler profiler) { private static int[] testSendAllScheme(PIRSettings settings, Database database, Profiler profiler) {
EvenSimplerServer[] servers = new EvenSimplerServer[settings.getNumServers()]; SendAllServer[] servers = new SendAllServer[settings.getNumServers()];
for (int i = 0; i < settings.getNumServers(); i++) { for (int i = 0; i < settings.getNumServers(); i++) {
servers[i] = new EvenSimplerServer(database, settings); servers[i] = new SendAllServer(database, settings);
} }
EvenSimplerClient client = new EvenSimplerClient(settings, servers, profiler); SendAllClient client = new SendAllClient(settings, servers, profiler);
profiler.start(); profiler.start();
client.receiveBits(0); int[] res = client.receive(0);
profiler.stop(); profiler.stop();
return res;
} }
private static void testSimpleScheme(PIRSettings settings, Database database, Profiler profiler) { private static int[] testXORScheme(PIRSettings settings, Database database, Profiler profiler) {
SimpleServer[] servers = new SimpleServer[settings.getNumServers()]; XORServer[] servers = new XORServer[settings.getNumServers()];
for (int i = 0; i < settings.getNumServers(); i++) { for (int i = 0; i < settings.getNumServers(); i++) {
servers[i] = new SimpleServer(database, settings); servers[i] = new XORServer(database, settings);
} }
SimpleClient client = new SimpleClient(settings, servers, profiler); XORClient client = new XORClient(settings, servers, profiler);
profiler.start(); profiler.start();
client.receiveBit(0); int[] res = client.receive(0);
profiler.stop(); profiler.stop();
return res;
} }
private static void testSimpleBlockScheme(PIRSettings settings, Database database, Profiler profiler) { private static int[] testSqrtXORScheme(PIRSettings settings, Database database, Profiler profiler) {
SimpleServer[] servers = new SimpleServer[settings.getNumServers()]; SqrtXORServer[] servers = new SqrtXORServer[settings.getNumServers()];
for (int i = 0; i < settings.getNumServers(); i++) { for (int i = 0; i < settings.getNumServers(); i++) {
servers[i] = new SimpleServer(database, settings); servers[i] = new SqrtXORServer(database, settings);
} }
SimpleClient client = new SimpleClient(settings, servers, profiler); SqrtXORClient client = new SqrtXORClient(settings, servers, profiler);
profiler.start(); profiler.start();
client.receiveBits(0); int[] res = client.receive(0);
profiler.stop(); profiler.stop();
return res;
} }
private static void testGeneralInterPolyScheme(PIRSettings settings, Database database, Profiler profiler) { private static int[] testInterPolyScheme(PIRSettings settings, Database database, Profiler profiler) throws IllegalArgumentException {
InterPolyServer[] servers = new InterPolyServer[settings.getNumServers()]; InterPolyServer[] servers = new InterPolyServer[settings.getNumServers()];
for (int i = 0; i < settings.getNumServers(); i++) { for (int i = 0; i < settings.getNumServers(); i++) {
servers[i] = new InterPolyServer(database, settings); servers[i] = new InterPolyServer(database, settings);
} }
InterPolyClient client = new InterPolyClient(settings, servers, profiler); InterPolyClient client = new InterPolyClient(settings, servers, profiler);
profiler.start(); profiler.start();
client.receive(0); int[] res = client.receive(0);
profiler.stop(); profiler.stop();
return res;
} }
private static void testGeneralInterPolyBlockScheme(PIRSettings settings, Database database, Profiler profiler) { private static void runTests(int numServers, int databaseSize, int blockSize) {
InterPolyServer[] servers = new InterPolyServer[settings.getNumServers()]; PIRSettings settings = new PIRSettings(databaseSize, numServers, blockSize);
for (int i = 0; i < settings.getNumServers(); i++) { for (int i = 0; i < 3; i++) { // TODO: repeat x times to warm-up
servers[i] = new InterPolyServer(database, settings); runTest(numServers, databaseSize, blockSize, settings);
}
InterPolyClient client = new InterPolyClient(settings, servers, profiler);
profiler.start();
client.receiveBlock(0);
profiler.stop();
}
private static void runTests() {
for (int numServers = 1; numServers <= 16; numServers = numServers*2) {
for (int databaseSize = 2048; databaseSize <= 32_768; databaseSize = databaseSize*2) {
for (int blockSize = 64; blockSize <= 16_384; blockSize = blockSize*2) {
for (int i = 0; i < 5; i++) {
runTest(numServers, databaseSize, blockSize);
}
}
}
} }
} }
private static void runTest(int numServers, int databaseSize, int blockSize) { private static void runTest(int numServers, int databaseSize, int blockSize, PIRSettings settings) {
PIRSettings settings = new PIRSettings(databaseSize*blockSize, numServers, blockSize); Database database = new FakeDatabase(settings);
int[] x = new int[databaseSize*blockSize];
for (int i = 0; i < x.length; i++) {
x[i] = (int) (Math.random()*2); // 0 or 1
}
Database database = new MemoryDatabase(settings, x);
Profiler profiler = new Profiler(); Profiler profiler = new Profiler();
try {
profiler.reset(); profiler.reset();
testEvenSimplerScheme(settings, database, profiler); testSendAllScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, profiler, "EvenSimplerScheme"); reportResult(numServers, databaseSize, blockSize, profiler, "SendAllScheme");
} catch (OutOfMemoryError error) {
reportFailure(numServers, databaseSize, blockSize, "oom", "SendAllScheme");
}
if (numServers == 2) { if (numServers == 2) {
try {
profiler.reset(); profiler.reset();
testSimpleScheme(settings, database, profiler); testXORScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, profiler, "SimpleScheme"); reportResult(numServers, databaseSize, blockSize, profiler, "XORScheme");
} catch (OutOfMemoryError error) {
profiler.reset(); reportFailure(numServers, databaseSize, blockSize, "oom", "XORScheme");
testSimpleBlockScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, profiler, "SimpleBlockScheme");
} }
if (settings.getS() != 0 && numServers != 1) { try {
profiler.reset(); profiler.reset();
testGeneralInterPolyScheme(settings, database, profiler); testSqrtXORScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, profiler, "GeneralInterPolyScheme"); reportResult(numServers, databaseSize, blockSize, profiler, "SqrtXORScheme");
} catch (OutOfMemoryError error) {
profiler.reset(); reportFailure(numServers, databaseSize, blockSize, "oom", "SqrtXORScheme");
testGeneralInterPolyBlockScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, profiler, "GeneralInterPolyBlockScheme");
} }
} }
try {
boolean interPolySchemeShouldFuckOff = true;
if (numServers != 1 && !interPolySchemeShouldFuckOff) {
try {
profiler.reset();
testInterPolyScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, profiler, "InterPolyScheme");
} catch (OutOfMemoryError error) {
reportFailure(numServers, databaseSize, blockSize, "oom", "InterPolyScheme");
}
}
} catch (IllegalArgumentException ignored) {
}
}
private static void reportFailure(int numServers, int databaseSize, int blockSize, String msg, String protocolName) {
System.out.println(
numServers + " " +
databaseSize + " " +
blockSize + " " +
protocolName + " " +
"error:" + msg
);
}
private static void reportResult(int numServers, int databaseSize, int blockSize, Profiler profiler, String protocolName) { private static void reportResult(int numServers, int databaseSize, int blockSize, Profiler profiler, String protocolName) {
System.out.println( System.out.println(
numServers + " " + numServers + " " +
@ -126,7 +140,19 @@ public class Driver {
); );
} }
private static void testSanity() {
PIRSettings settings = new PIRSettings(64, 2, 4);
Database database = new FakeDatabase(settings);
Profiler profiler = new Profiler();
System.out.println(Arrays.toString(testSendAllScheme(settings, database, profiler)));
System.out.println(Arrays.toString(testXORScheme(settings, database, profiler)));
System.out.println(Arrays.toString(testSqrtXORScheme(settings, database, profiler)));
System.out.println(Arrays.toString(testInterPolyScheme(settings, database, profiler)));
}
public static void main(String[] args) { public static void main(String[] args) {
runTests(); //testSanity();
runTests(Integer.parseInt(args[1]), Integer.parseInt(args[2]), Integer.parseInt(args[3]));
} }
} }

View File

@ -1,5 +1,5 @@
package dk.au.pir.databases; package dk.au.pir.databases;
public interface Database { public interface Database extends Iterable<Byte[]> {
public int[] getX(); public int get(long index);
} }

View File

@ -0,0 +1,64 @@
package dk.au.pir.databases;
import dk.au.pir.settings.PIRSettings;
import dk.au.pir.utils.ArrayUtils;
import java.io.*;
import java.util.Iterator;
public class DiskDatabase implements Database {
private PIRSettings settings;
private String fileName;
public DiskDatabase(PIRSettings settings, String fileName) {
this.settings = settings;
this.fileName = fileName;
}
@Override
public int get(long index) {
return 0;
}
private BufferedInputStream getBufferedInputStream() {
try {
File file = new File(this.fileName);
FileInputStream fileInputStream = new FileInputStream(file);
return new BufferedInputStream(fileInputStream);
} catch (FileNotFoundException ignored) {
return null; // yeah, fuck that
}
}
@Override
public Iterator<Byte[]> iterator() {
return new Iterator<Byte[]>() {
int blockSize = settings.getBlocksize();
byte[] cbuf = new byte[blockSize];
long hasRead = 0;
BufferedInputStream bufferedInputStream = null;
@Override
public boolean hasNext() {
return hasRead < settings.getDatabaseSize();
}
@Override
public Byte[] next() {
try {
if (!hasNext()) {
return null;
}
while (bufferedInputStream == null || bufferedInputStream.read(cbuf, 0, blockSize) == -1) {
System.out.println("Resetting bufferedInputStream");
bufferedInputStream = getBufferedInputStream();
}
hasRead += blockSize;
return ArrayUtils.toWrapper(cbuf);
} catch (IOException error) {
return null;
}
}
};
}
}

View File

@ -0,0 +1,20 @@
package dk.au.pir.databases;
import dk.au.pir.settings.PIRSettings;
import java.util.Iterator;
public class FakeDatabase implements Database {
public FakeDatabase(PIRSettings settings) {
}
@Override
public int get(long index) {
return (int) (index % 2);
}
@Override
public Iterator<Byte[]> iterator() {
return null;
}
}

View File

@ -1,15 +0,0 @@
package dk.au.pir.databases;
import dk.au.pir.settings.PIRSettings;
public class MemoryDatabase implements Database{
private final int[] x;
public MemoryDatabase(PIRSettings settings, int[] x) {
this.x = x;
}
public int[] getX() {
return x;
}
}

View File

@ -4,8 +4,8 @@ import dk.au.pir.utils.FieldElement;
import dk.au.pir.utils.MathUtils; import dk.au.pir.utils.MathUtils;
public class Profiler { public class Profiler {
private int sent; private long sent;
private int received; private long received;
private long startTime; private long startTime;
private long stopTime; private long stopTime;
@ -48,6 +48,23 @@ public class Profiler {
return numbersArrays; return numbersArrays;
} }
public boolean clientSend(boolean bool) {
this.sent += 1;
return bool;
}
public boolean[] clientSend(boolean[] bools) {
this.sent += bools.length;
return bools;
}
public boolean[][] clientSend(boolean[][] boolsArray) {
for (boolean[] bools: boolsArray) {
clientSend(bools);
}
return boolsArray;
}
public FieldElement[] clientSend(FieldElement[] elements) { public FieldElement[] clientSend(FieldElement[] elements) {
for (FieldElement element : elements) { for (FieldElement element : elements) {
this.sent += element.getValue().bitLength(); this.sent += element.getValue().bitLength();
@ -55,11 +72,15 @@ public class Profiler {
return elements; return elements;
} }
public FieldElement[][] clientSend(FieldElement[][] elements) { public FieldElement[][] clientSend(FieldElement[][] elementsArray) {
for (FieldElement[] fe : elements) { for (FieldElement[] fe : elementsArray) {
this.clientSend(fe); this.clientSend(fe);
} }
return elements; return elementsArray;
}
public void addClientReceived(long bits) {
this.received += bits;
} }
public int clientReceive(int number) { public int clientReceive(int number) {
@ -86,11 +107,11 @@ public class Profiler {
return elements; return elements;
} }
public int getSent() { public long getSent() {
return this.sent; return this.sent;
} }
public int getReceived() { public long getReceived() {
return this.received; return this.received;
} }
@ -98,10 +119,10 @@ public class Profiler {
return this.stopTime - this.startTime; return this.stopTime - this.startTime;
} }
public int log2(int n) { private long log2(long n) {
if (n == 0) { if (n == 0) {
return 1; // technically incorrect but for the sake of profiling, a 0-bit requires 1 bit of space return 1; // technically incorrect but for the sake of profiling, a 0-bit requires 1 bit of space
} }
return Integer.SIZE - Integer.numberOfLeadingZeros(n); return Long.SIZE - Long.numberOfLeadingZeros(n);
} }
} }

View File

@ -1,84 +0,0 @@
package dk.au.pir.protocols.balancedBlockScheme;
import dk.au.pir.databases.Database;
import dk.au.pir.databases.MemoryDatabase;
import dk.au.pir.profilers.Profiler;
import dk.au.pir.protocols.simple.SimpleClient;
import dk.au.pir.protocols.simple.SimpleServer;
import dk.au.pir.settings.PIRSettings;
import java.util.Arrays;
import java.util.Random;
public class balancedBlockClient {
private final PIRSettings settings;
private final balancedBlockServer[] servers;
private final int sqrtSize;
private Profiler profiler;
public balancedBlockClient(PIRSettings settings, balancedBlockServer[] servers, Profiler profiler) {
this.settings = settings;
this.servers = servers;
this.profiler = profiler;
this.sqrtSize = (int) Math.ceil(Math.sqrt(settings.getDatabaseSize()));
}
public int[] selectIndexes(int n) {
int[] indexes = new int[n];
Random rand = new Random();
for (int i=0; i < n; i++) {
indexes[i] = rand.nextInt(2);
}
return indexes;
}
public int receiveBit(int index) {
/**
* PLAN:
* Divide n into sqrt(n)
* Compute which index we want find this within a block
* Send block
*/
int[] S1 = selectIndexes(this.sqrtSize);
int[] S2 = S1.clone();
int impBlock = (int) Math.floor(index/this.sqrtSize);
System.out.println("ImpBlock: " + impBlock);
if (S1[index % this.sqrtSize] == 1) {
S2[index % this.sqrtSize] = 0; // Remove the index, if it's contained in S.
} else {
S2[index % this.sqrtSize] = 1;
}
System.out.println("S1: " + Arrays.toString(S1));
System.out.println("S2: " + Arrays.toString(S2));
int[] resBit1 = this.servers[0].computeBit(S1);
int[] resBit2 = this.servers[1].computeBit(S2);
return ((resBit1[impBlock] + resBit2[impBlock]) % 2);
}
public static void main(String[] args) {
PIRSettings settings = new PIRSettings(16, 2, 1);
balancedBlockServer[] servers = new balancedBlockServer[settings.getNumServers()];
Database database = new MemoryDatabase(settings, new int[] {0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0});
for (int i = 0; i < settings.getNumServers(); i++) {
servers[i] = new balancedBlockServer(database, settings);
}
balancedBlockClient client = new balancedBlockClient(settings, servers, null);
System.out.println(client.receiveBit(11));
}
}

View File

@ -1,54 +0,0 @@
package dk.au.pir.protocols.balancedBlockScheme;
import dk.au.pir.databases.Database;
import dk.au.pir.settings.PIRSettings;
import java.util.Arrays;
public class balancedBlockServer {
private final Database database;
private final PIRSettings settings;
private final int sqrtSize;
public balancedBlockServer(Database database, PIRSettings settings) {
this.database = database;
this.settings = settings;
this.sqrtSize = (int) Math.ceil(Math.sqrt(settings.getDatabaseSize()));
}
public int[] computeBit(int[] indexes) {
int[] db = database.getX();
/*
Divide n in the sqrt(n) size chunks
Get sqrt(n) size array from client, which we cycle through sqrt(n) times
We return a sqrt(n) size list of bits. One from each cycle.
*/
int[] resList = new int[this.sqrtSize];
for (int i = 0; i < this.sqrtSize; i++) {
int tmpRes = 0;
for (int j = 0; j < this.sqrtSize; j++) {
try {
boolean test = indexes[j] == 1;
if (test) {
System.out.println("Looking at index: " + (j + (this.sqrtSize * i)));
tmpRes = (tmpRes + db[j + (this.sqrtSize * i)]) % 2;
}
} catch (ArrayIndexOutOfBoundsException e) {
tmpRes = (tmpRes) % 2;
}
}
resList[i] = tmpRes;
}
System.out.println("ResList: " + Arrays.toString(resList));
return resList;
}
}

View File

@ -1,28 +0,0 @@
package dk.au.pir.protocols.evenSimpler;
import dk.au.pir.profilers.Profiler;
import dk.au.pir.settings.PIRSettings;
public class EvenSimplerClient {
private final PIRSettings settings;
private final EvenSimplerServer[] servers;
private final Profiler profiler;
public EvenSimplerClient(PIRSettings settings, EvenSimplerServer[] servers, Profiler profiler) {
this.settings = settings;
this.servers = servers;
this.profiler = profiler;
}
public int receiveBit(int index) {
int[] data = this.profiler.clientReceive(this.servers[0].giveDatabase());
return data[index];
}
public int[] receiveBits(int record) {
int[] res = new int[settings.getBlocksize()];
int[] data = this.profiler.clientReceive(this.servers[0].giveDatabase());
System.arraycopy(data, (record * settings.getBlocksize()), res, 0, settings.getBlocksize());
return res;
}
}

View File

@ -1,16 +0,0 @@
package dk.au.pir.protocols.evenSimpler;
import dk.au.pir.databases.Database;
import dk.au.pir.settings.PIRSettings;
public class EvenSimplerServer {
private final Database database;
public EvenSimplerServer(Database database, PIRSettings settings) {
this.database = database;
}
public int[] giveDatabase() {
return this.database.getX(); // lol
}
}

View File

@ -5,28 +5,34 @@ import dk.au.pir.profilers.Profiler;
import dk.au.pir.settings.PIRSettings; import dk.au.pir.settings.PIRSettings;
import dk.au.pir.utils.FieldElement; import dk.au.pir.utils.FieldElement;
import dk.au.pir.utils.FieldElementLagrange; import dk.au.pir.utils.FieldElementLagrange;
import dk.au.pir.utils.MathUtils;
import java.math.BigInteger; import dk.au.pir.utils.ProtocolUtils;
import java.util.Arrays;
import static dk.au.pir.utils.ProtocolUtils.printIntArrayArray;
public class InterPolyClient { public class InterPolyClient {
private PIRSettings settings; private PIRSettings settings;
private InterPolyServer[] servers; private InterPolyServer[] servers;
private final int s; private final int s;
private final BigIntegerField field; private final BigIntegerField field;
private final int[][] sequences; private final boolean[][] sequences;
private Profiler profiler; private Profiler profiler;
public InterPolyClient(PIRSettings settings, InterPolyServer[] servers, Profiler profiler) { public InterPolyClient(PIRSettings settings, InterPolyServer[] servers, Profiler profiler) throws IllegalArgumentException {
this.settings = settings; this.settings = settings;
this.servers = servers; this.servers = servers;
this.s = settings.getS();
this.field = settings.getField(); this.field = settings.getField();
this.sequences = settings.getSequences();
this.profiler = profiler; this.profiler = profiler;
this.s = calculateS(this.settings.getNumServers(), this.settings.getDatabaseSize() * this.settings.getBlocksize()); // TODO: Should be long-multiplication
this.sequences = ProtocolUtils.createSequences(s, this.settings.getNumServers(), this.settings.getDatabaseSize() * this.settings.getBlocksize());
}
private int calculateS(int k, int n) throws IllegalArgumentException {
for (int s = k-1; s <= n; s++) {
if (MathUtils.binomial(s, k-1) >= n) {
return s;
}
}
throw new IllegalArgumentException();
} }
private FieldElement[] getRandomFieldElements() { private FieldElement[] getRandomFieldElements() {
@ -48,47 +54,28 @@ public class InterPolyClient {
private FieldElement[] getGs(int index, int serverNumber, FieldElement[] random) { private FieldElement[] getGs(int index, int serverNumber, FieldElement[] random) {
FieldElement[] gs = new FieldElement[this.s]; FieldElement[] gs = new FieldElement[this.s];
int[] i = this.sequences[index]; boolean[] i = this.sequences[index];
for (int l = 0; l < this.s; l++) { for (int l = 0; l < this.s; l++) {
gs[l] = random[l].multiply(this.field.valueOf(serverNumber));
gs[l] = random[l].multiply(this.field.valueOf(serverNumber)).add(this.field.valueOf(i[l])); if (i[l]) {
gs[l].add(this.field.valueOf(1));
}
} }
return gs; return gs;
}
private int receiveBit(int index) {
FieldElement[] randoms = this.getRandomFieldElements();
FieldElement[] Fs = new FieldElement[this.servers.length];
for (int z = 0; z < this.servers.length; z++) {
Fs[z] = this.profiler.clientReceive(this.servers[z].F(this.profiler.clientSend(this.getGs(index, z+1, randoms))));
}
FieldElement res = FieldElementLagrange.interpolate(this.field, Fs);
return res.getValue().intValue();
} }
public int[] receive(int record) { public int[] receive(int record) {
int[] results = new int[settings.getBlocksize()];
for (int i = 0; i < settings.getBlocksize(); i++) {
results[i] = this.receiveBit((settings.getBlocksize() * record) + i);
}
return results;
}
public int[] receiveBlock(int record) {
int[] results = new int[settings.getBlocksize()]; int[] results = new int[settings.getBlocksize()];
FieldElement[][] randoms = this.getRandomFieldElementsBlock(); FieldElement[][] randoms = this.getRandomFieldElementsBlock();
FieldElement[][] Fs = new FieldElement[this.servers.length][settings.getBlocksize()]; FieldElement[][] Fs = new FieldElement[this.servers.length][settings.getBlocksize()];
/** //Compute all the Gs for each server, s.t. the first index should be the blocksize and it should contain all the Gs for the given index
* 1) Compute all the Gs for each server, s.t. the first index should be the blocksize and it should contain all the Gs for the given index
*/
for (int z = 0; z < this.servers.length; z++) { for (int z = 0; z < this.servers.length; z++) {
FieldElement[][] Gs = new FieldElement[settings.getBlocksize()][this.s]; FieldElement[][] Gs = new FieldElement[settings.getBlocksize()][this.s];
for (int i = 0; i < settings.getBlocksize(); i++) { for (int i = 0; i < settings.getBlocksize(); i++) {
Gs[i] = this.getGs(record*settings.getBlocksize() + i, z+1, randoms[i]); Gs[i] = this.getGs(record*settings.getBlocksize() + i, z+1, randoms[i]);
} }
Fs[z] = profiler.clientReceive(this.servers[z].FBlock(profiler.clientSend(Gs))); Fs[z] = profiler.clientReceive(this.servers[z].FBlock(profiler.clientSend(Gs), this.s, this.sequences));
} }
for (int i = 0; i < settings.getBlocksize(); i++) { for (int i = 0; i < settings.getBlocksize(); i++) {

View File

@ -17,29 +17,29 @@ public class InterPolyServer {
this.field = settings.getField(); this.field = settings.getField();
} }
public FieldElement F(FieldElement[] gs) { private FieldElement F(FieldElement[] gs, int s, boolean[][] sequences) {
FieldElement sum = this.field.valueOf(0); FieldElement sum = this.field.valueOf(0);
for (int j = 0; j < this.settings.getDatabaseSize(); j++) { for (int j = 0; j < this.settings.getDatabaseSize() * this.settings.getBlocksize(); j++) { // TODO: Should be long-multiplcation
FieldElement product = this.field.valueOf(1); FieldElement product = this.field.valueOf(1);
for (int l = 0; l < this.settings.getS(); l++) { for (int l = 0; l < s; l++) {
if (this.settings.getSequences()[j][l] == 1) { if (sequences[j][l]) {
product = product.multiply(gs[l]); product = product.multiply(gs[l]);
//System.out.println("gs: " + gs[l]); //System.out.println("gs: " + gs[l]);
} }
} }
sum = sum.add(product.multiply(this.field.valueOf(this.database.getX()[j]))); sum = sum.add(product.multiply(this.field.valueOf(this.database.get(j))));
} }
return sum; return sum;
} }
public FieldElement[] FBlock(FieldElement[][] gs) { public FieldElement[] FBlock(FieldElement[][] gs, int s, boolean[][] sequences) {
FieldElement[] sum = new FieldElement[this.settings.getBlocksize()]; FieldElement[] sum = new FieldElement[this.settings.getBlocksize()];
for (int i = 0; i < sum.length; i++) { for (int i = 0; i < sum.length; i++) {
sum[i] = this.field.valueOf(0); sum[i] = this.field.valueOf(0);
} }
for (int i = 0; i < this.settings.getBlocksize(); i++) { for (int i = 0; i < this.settings.getBlocksize(); i++) {
sum[i] = F(gs[i]); sum[i] = F(gs[i], s, sequences);
} }
return sum; return sum;
} }

View File

@ -1,72 +0,0 @@
package dk.au.pir.protocols.simple;
import dk.au.pir.profilers.Profiler;
import dk.au.pir.settings.PIRSettings;
import java.util.Random;
public class SimpleClient {
private final PIRSettings settings;
private final SimpleServer[] servers;
private Profiler profiler;
public SimpleClient(PIRSettings settings, SimpleServer[] servers, Profiler profiler) {
this.settings = settings;
this.servers = servers;
this.profiler = profiler;
}
public int[] selectIndexes() {
int[] indexes = new int[settings.getDatabaseSize()];
Random rand = new Random();
for (int i=0; i < settings.getDatabaseSize(); i++) {
indexes[i] = rand.nextInt(2);
}
return indexes;
}
public int receiveBit(int index) {
int[] S1 = selectIndexes();
int[] S2 = S1.clone();
if (S1[index] == 1) {
S2[index] = 0; // Remove the index, if it's contained in S.
} else {
S2[index] = 1;
}
int resBit1 = this.profiler.clientReceive(this.servers[0].computeBit(this.profiler.clientSend(S1)));
int resBit2 = this.profiler.clientReceive(this.servers[1].computeBit(this.profiler.clientSend(S2)));
return ((resBit1 + resBit2) % 2);
}
public int[] receiveBits(int record) {
int[] result = new int[settings.getBlocksize()];
int[][] S1s = new int[settings.getBlocksize()][settings.getDatabaseSize()];
int[][] S2s = new int[settings.getBlocksize()][settings.getDatabaseSize()];
for (int i = 0; i < settings.getBlocksize(); i++) {
S1s[i] = selectIndexes();
S2s[i] = S1s[i].clone();
if (S1s[i][(record*settings.getBlocksize())+i] == 1) {
// Remove the index, if it's contained in S.
S2s[i][(record*settings.getBlocksize())+i] = 0;
} else {
S2s[i][(record*settings.getBlocksize())+i] = 1;
}
}
int[] resBit1 = this.profiler.clientReceive(this.servers[0].computeBits(this.profiler.clientSend(S1s)));
int[] resBit2 = this.profiler.clientReceive(this.servers[1].computeBits(this.profiler.clientSend(S2s)));
for (int i = 0; i < settings.getBlocksize(); i++) {
result[i] = (resBit1[i] + resBit2[i]) % 2;
}
return result;
}
}

View File

@ -0,0 +1,27 @@
package dk.au.pir.protocols.stupid;
import dk.au.pir.databases.Database;
import dk.au.pir.profilers.Profiler;
import dk.au.pir.settings.PIRSettings;
public class SendAllClient {
private final PIRSettings settings;
private final SendAllServer[] servers;
private final Profiler profiler;
public SendAllClient(PIRSettings settings, SendAllServer[] servers, Profiler profiler) {
this.settings = settings;
this.servers = servers;
this.profiler = profiler;
}
public int[] receive(int record) {
int[] res = new int[settings.getBlocksize()];
Database database = this.servers[0].giveDatabase();
this.profiler.addClientReceived((long) this.settings.getDatabaseSize() * (long) this.settings.getBlocksize());
for (int i = 0; i < this.settings.getBlocksize(); i++) {
res[i] = database.get(record * settings.getBlocksize() + i);
}
return res;
}
}

View File

@ -0,0 +1,16 @@
package dk.au.pir.protocols.stupid;
import dk.au.pir.databases.Database;
import dk.au.pir.settings.PIRSettings;
public class SendAllServer {
private final Database database;
public SendAllServer(Database database, PIRSettings settings) {
this.database = database;
}
public Database giveDatabase() {
return this.database; // lol
}
}

View File

@ -0,0 +1,58 @@
package dk.au.pir.protocols.xor;
import dk.au.pir.profilers.Profiler;
import dk.au.pir.settings.PIRSettings;
import java.util.Random;
public class SqrtXORClient {
private final PIRSettings settings;
private final SqrtXORServer[] servers;
private final int sqrtSize;
private Profiler profiler;
public SqrtXORClient(PIRSettings settings, SqrtXORServer[] servers, Profiler profiler) {
this.settings = settings;
this.servers = servers;
this.profiler = profiler;
this.sqrtSize = (int) Math.ceil(Math.sqrt((long) settings.getDatabaseSize() * (long) settings.getBlocksize()));
}
public boolean[] selectIndexes(int n) {
boolean[] indexes = new boolean[n];
Random rand = new Random();
for (int i=0; i < n; i++) {
indexes[i] = rand.nextBoolean();
}
return indexes;
}
public int receiveBit(int index) {
/**
* PLAN:
* Divide n into sqrt(n)
* Compute which index we want find this within a block
* Send block
*/
boolean[] S1 = selectIndexes(this.sqrtSize);
boolean[] S2 = S1.clone();
int impBlock = (int) Math.floor(index/this.sqrtSize);
S2[index % this.sqrtSize] = !S1[index % this.sqrtSize]; // Remove the index, if it's contained in S.
int[] resBit1 = this.profiler.clientReceive(this.servers[0].computeBit(this.profiler.clientSend(S1)));
int[] resBit2 = this.profiler.clientReceive(this.servers[1].computeBit(this.profiler.clientSend(S2)));
return ((resBit1[impBlock] + resBit2[impBlock]) % 2);
}
public int[] receive(int record) {
// TODO: This is bad - should merge with above receiveBit-method to send entire array of bits at once (like the simple XORScheme)
int[] result = new int[settings.getBlocksize()];
for (int i = 0; i < settings.getBlocksize(); i++) {
result[i] = this.receiveBit(record * this.settings.getBlocksize() + i);
}
return result;
}
}

View File

@ -0,0 +1,34 @@
package dk.au.pir.protocols.xor;
import dk.au.pir.databases.Database;
import dk.au.pir.settings.PIRSettings;
public class SqrtXORServer {
private final Database database;
private final PIRSettings settings;
private final int sqrtSize;
public SqrtXORServer(Database database, PIRSettings settings) {
this.database = database;
this.settings = settings;
this.sqrtSize = (int) Math.ceil(Math.sqrt((long) settings.getDatabaseSize() * (long) settings.getBlocksize()));
}
public int[] computeBit(boolean[] indexes) {
int[] resList = new int[this.sqrtSize];
for (int i = 0; i < this.sqrtSize; i++) {
int tmpRes = 0;
for (int j = 0; j < this.sqrtSize; j++) {
try {
if (indexes[j]) {
tmpRes = (tmpRes + this.database.get(j + (this.sqrtSize * i))) % 2;
}
} catch (ArrayIndexOutOfBoundsException ignored) {
}
}
resList[i] = tmpRes;
}
return resList;
}
}

View File

@ -0,0 +1,51 @@
package dk.au.pir.protocols.xor;
import dk.au.pir.profilers.Profiler;
import dk.au.pir.settings.PIRSettings;
import java.util.Random;
public class XORClient {
private final PIRSettings settings;
private final XORServer[] servers;
private Profiler profiler;
public XORClient(PIRSettings settings, XORServer[] servers, Profiler profiler) {
this.settings = settings;
this.servers = servers;
this.profiler = profiler;
}
private boolean[] selectIndexes() {
boolean[] indexes = new boolean[settings.getDatabaseSize() * settings.getBlocksize()]; // TODO: should be long-multiplication
Random rand = new Random();
for (int i=0; i < settings.getDatabaseSize() * settings.getBlocksize(); i++) {
indexes[i] = rand.nextBoolean();
}
return indexes;
}
public int[] receive(int record) {
int[] result = new int[settings.getBlocksize()];
boolean[][] S1s = new boolean[settings.getBlocksize()][settings.getDatabaseSize() * settings.getBlocksize()]; // TODO: Should be long-multiplication
boolean[][] S2s = new boolean[settings.getBlocksize()][settings.getDatabaseSize() * settings.getBlocksize()];
for (int i = 0; i < settings.getBlocksize(); i++) {
S1s[i] = selectIndexes(); // TODO
S2s[i] = S1s[i].clone();
// Remove the index, if it's contained in S.
S2s[i][(record*settings.getBlocksize())+i] = !S1s[i][(record * settings.getBlocksize()) + i];
}
int[] resBit1 = this.profiler.clientReceive(this.servers[0].computeBits(this.profiler.clientSend(S1s)));
int[] resBit2 = this.profiler.clientReceive(this.servers[1].computeBits(this.profiler.clientSend(S2s)));
for (int i = 0; i < settings.getBlocksize(); i++) {
result[i] = (resBit1[i] + resBit2[i]) % 2;
}
return result;
}
}

View File

@ -1,27 +1,27 @@
package dk.au.pir.protocols.simple; package dk.au.pir.protocols.xor;
import dk.au.pir.databases.Database; import dk.au.pir.databases.Database;
import dk.au.pir.settings.PIRSettings; import dk.au.pir.settings.PIRSettings;
public class SimpleServer { public class XORServer {
private final Database database; private final Database database;
private final PIRSettings settings; private final PIRSettings settings;
public SimpleServer(Database database, PIRSettings settings) { public XORServer(Database database, PIRSettings settings) {
this.database = database; this.database = database;
this.settings = settings; this.settings = settings;
} }
public int computeBit(int[] indexes) { private int computeBit(boolean[] indexes) {
int res = database.getX()[indexes[0]]; int res = 0;
for (int i=1; i<indexes.length; i++) { for (int i=1; i<indexes.length; i++) {
if (indexes[i] == 1) if (indexes[i])
res = (res + database.getX()[i]) % 2; res = (res + database.get(i)) % 2;
} }
return res; return res;
} }
public int[] computeBits(int[][] indexes) { public int[] computeBits(boolean[][] indexes) {
int[] res = new int[settings.getBlocksize()]; int[] res = new int[settings.getBlocksize()];
for (int i = 0; i < settings.getBlocksize(); i++) { for (int i = 0; i < settings.getBlocksize(); i++) {
res[i] = computeBit(indexes[i]); res[i] = computeBit(indexes[i]);

View File

@ -1,15 +1,11 @@
package dk.au.pir.settings; package dk.au.pir.settings;
import dk.au.pir.BigIntegerField; import dk.au.pir.BigIntegerField;
import dk.au.pir.utils.MathUtils;
import dk.au.pir.utils.ProtocolUtils;
public class PIRSettings { public class PIRSettings {
private final int databaseSize; private final int databaseSize;
private final int numServers; private final int numServers;
private int s;
private int[][] sequences;
private final BigIntegerField field; private final BigIntegerField field;
private final int blocksize; private final int blocksize;
@ -19,25 +15,9 @@ public class PIRSettings {
this.numServers = numServers; this.numServers = numServers;
this.blocksize = blocksize; this.blocksize = blocksize;
try {
this.s = calculateS(numServers, databaseSize);
this.sequences = ProtocolUtils.createSequences(s, numServers, databaseSize);
} catch (IllegalArgumentException error) {
this.s = 0;
}
this.field = new BigIntegerField(); this.field = new BigIntegerField();
} }
private int calculateS(int k, int n) throws IllegalArgumentException {
for (int s = k-1; s <= n; s++) {
if (MathUtils.binomial(s, k-1) >= n) {
return s;
}
}
throw new IllegalArgumentException();
}
public int getDatabaseSize() { public int getDatabaseSize() {
return databaseSize; return databaseSize;
} }
@ -46,14 +26,6 @@ public class PIRSettings {
return numServers; return numServers;
} }
public int getS() {
return s;
}
public int[][] getSequences() {
return sequences;
}
public BigIntegerField getField() { public BigIntegerField getField() {
return field; return field;
} }

View File

@ -0,0 +1,11 @@
package dk.au.pir.utils;
import java.util.Arrays;
public class ArrayUtils {
public static Byte[] toWrapper(byte[] bytesPrim) {
Byte[] bytes = new Byte[bytesPrim.length];
Arrays.setAll(bytes, n -> bytesPrim[n]);
return bytes;
}
}

View File

@ -2,10 +2,12 @@ package dk.au.pir.utils;
public class MathUtils { public class MathUtils {
public static int binomial(int n, int k) { public static int binomial(int n, int k) {
if ((n == k) || (k == 0)) { if (k > n - k)
return 1; k = n - k;
} else {
return binomial(n - 1, k) + binomial(n - 1, k - 1); int b = 1;
} for (int i=1, m=n; i<=k; i++, m--)
b = b * m / i;
return b;
} }
} }

View File

@ -1,43 +1,14 @@
package dk.au.pir.utils; package dk.au.pir.utils;
import java.util.*;
import java.util.stream.Collectors;
public class ProtocolUtils { public class ProtocolUtils {
private static int[] createSequence(int s, int k) { public static boolean[][] createSequences(int s, int k, int n) {
Random rand = new Random(); // TODO: Un-hardcode for k!=2
int[] sequence = new int[s]; boolean[][] arrays = new boolean[n][s];
int kRemaining = k - 1; for (int i = 0; i < n; i++) {
while (kRemaining != 0) { for (int j = 0; j < s; j++) {
int rand_idx = rand.nextInt(s); arrays[i][j] = i == j;
if (sequence[rand_idx] == 0) {
sequence[rand_idx] = 1;
kRemaining--;
} }
} }
return sequence;
}
public static int[][] createSequences(int s, int k, int n) {
Set<List<Integer>> sequences = new HashSet<>();
while (sequences.size() < n) {
sequences.add(Arrays.stream(createSequence(s, k)).boxed().collect(Collectors.toList()));
}
List<List<Integer>> lists = new ArrayList<>(sequences);
lists.sort((l1, l2) -> {
for (int i = 0; i < l1.size(); i++) {
int equals = l1.get(i).compareTo(l2.get(i));
if (equals != 0) {
return equals;
}
}
return 0;
});
int[][] arrays = new int[n][s];
for (int j = 0; j < n; j++) {
int[] array = lists.get(j).stream().mapToInt(i -> i).toArray();
arrays[j] = array;
}
return arrays; return arrays;
} }
@ -48,6 +19,5 @@ public class ProtocolUtils {
} }
System.out.println(""); System.out.println("");
} }
} }
} }

8
pir/test.sh Executable file
View File

@ -0,0 +1,8 @@
apt update
apt install -y htop tmux openjdk-11-jdk
rm -f ~/results.log
cd classes/
tmux \
new-session 'python3 ../collect.py | tee ~/results.log' \; \
split-window -h 'htop' \;