Remove network simulation -- can just do it in Python. 50x benchmark speedup.

This commit is contained in:
Casper V. Kristensen 2019-12-03 14:54:11 +01:00
parent d4e33c0098
commit 202650b1dd
5 changed files with 13 additions and 47 deletions

View File

@ -72,64 +72,57 @@ public class Driver {
for (int numServers = 1; numServers <= 16; numServers = numServers*2) { for (int numServers = 1; numServers <= 16; numServers = numServers*2) {
for (int databaseSize = 2048; databaseSize <= 32_768; databaseSize = databaseSize*2) { for (int databaseSize = 2048; databaseSize <= 32_768; databaseSize = databaseSize*2) {
for (int blockSize = 64; blockSize <= 16_384; blockSize = blockSize*2) { for (int blockSize = 64; blockSize <= 16_384; blockSize = blockSize*2) {
for (int latency = 0; latency <= 500; latency = latency + 50) {
for (int bandwidth = 64; bandwidth <= 16_384; bandwidth = bandwidth*2) { // in kbit/s
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
runTest(numServers, databaseSize, blockSize, latency, bandwidth); runTest(numServers, databaseSize, blockSize);
}
}
} }
} }
} }
} }
} }
private static void runTest(int numServers, int databaseSize, int blockSize, int latency, int bandwidth) { private static void runTest(int numServers, int databaseSize, int blockSize) {
PIRSettings settings = new PIRSettings(databaseSize*blockSize, numServers, blockSize); PIRSettings settings = new PIRSettings(databaseSize*blockSize, numServers, blockSize);
int[] x = new int[databaseSize]; int[] x = new int[databaseSize*blockSize];
for (int i = 0; i < x.length; i++) { for (int i = 0; i < x.length; i++) {
x[i] = (int) (Math.random()*2); // 0 or 1 x[i] = (int) (Math.random()*2); // 0 or 1
} }
Database database = new MemoryDatabase(settings, x); Database database = new MemoryDatabase(settings, x);
Profiler profiler = new Profiler(latency, bandwidth/10, bandwidth); Profiler profiler = new Profiler();
profiler.reset(); profiler.reset();
testEvenSimplerScheme(settings, database, profiler); testEvenSimplerScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, latency, bandwidth, profiler, "EvenSimplerScheme"); reportResult(numServers, databaseSize, blockSize, profiler, "EvenSimplerScheme");
if (numServers == 2) { if (numServers == 2) {
profiler.reset(); profiler.reset();
testSimpleScheme(settings, database, profiler); testSimpleScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, latency, bandwidth, profiler, "SimpleScheme"); reportResult(numServers, databaseSize, blockSize, profiler, "SimpleScheme");
profiler.reset(); profiler.reset();
testSimpleBlockScheme(settings, database, profiler); testSimpleBlockScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, latency, bandwidth, profiler, "SimpleBlockScheme"); reportResult(numServers, databaseSize, blockSize, profiler, "SimpleBlockScheme");
} }
if (settings.getS() != 0 && numServers != 1) { if (settings.getS() != 0 && numServers != 1) {
profiler.reset(); profiler.reset();
testGeneralInterPolyScheme(settings, database, profiler); testGeneralInterPolyScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, latency, bandwidth, profiler, "GeneralInterPolyScheme"); reportResult(numServers, databaseSize, blockSize, profiler, "GeneralInterPolyScheme");
profiler.reset(); profiler.reset();
testGeneralInterPolyBlockScheme(settings, database, profiler); testGeneralInterPolyBlockScheme(settings, database, profiler);
reportResult(numServers, databaseSize, blockSize, latency, bandwidth, profiler, "GeneralInterPolyBlockScheme"); reportResult(numServers, databaseSize, blockSize, profiler, "GeneralInterPolyBlockScheme");
} }
} }
private static void reportResult(int numServers, int databaseSize, int blockSize, int latency, int bandwidth, Profiler profiler, String protocolName) { private static void reportResult(int numServers, int databaseSize, int blockSize, Profiler profiler, String protocolName) {
System.out.println( System.out.println(
numServers + " " + numServers + " " +
databaseSize + " " + databaseSize + " " +
blockSize + " " + blockSize + " " +
latency + " " +
bandwidth + " " +
protocolName + " " + protocolName + " " +
profiler.getTotalCPUTime() + " " + profiler.getTotalCPUTime() + " " +
profiler.getSent() + " " + profiler.getSent() + " " +
profiler.getReceived() + " " + profiler.getReceived()
profiler.getTotalNetworkTime()
); );
} }

View File

@ -4,20 +4,12 @@ import dk.au.pir.utils.FieldElement;
import dk.au.pir.utils.MathUtils; import dk.au.pir.utils.MathUtils;
public class Profiler { public class Profiler {
private final int latency;
private final int sendBandwidth;
private final int receiveBandwidth;
private int sent; private int sent;
private int received; private int received;
private int networkTime;
private long startTime; private long startTime;
private long stopTime; private long stopTime;
public Profiler(int latency, int sendBandwidth, int receiveBandwidth) { public Profiler() {
this.latency = latency;
this.sendBandwidth = sendBandwidth;
this.receiveBandwidth = receiveBandwidth;
reset(); reset();
} }
@ -33,19 +25,10 @@ public class Profiler {
public void reset() { public void reset() {
this.sent = 0; this.sent = 0;
this.received = 0; this.received = 0;
this.networkTime = 0;
this.startTime = 0; this.startTime = 0;
this.stopTime = 0; this.stopTime = 0;
} }
public void addNetworkDelay() {
this.addNetworkDelay(1);
}
public void addNetworkDelay(int n) {
this.networkTime += latency * n;
}
public int clientSend(int number) { public int clientSend(int number) {
this.sent += log2(number); this.sent += log2(number);
return number; return number;
@ -115,10 +98,6 @@ public class Profiler {
return this.stopTime - this.startTime; return this.stopTime - this.startTime;
} }
public int getTotalNetworkTime() {
return networkTime + (this.sent/this.sendBandwidth) + (this.received/this.receiveBandwidth);
}
public int log2(int n) { public int log2(int n) {
if (n == 0) { if (n == 0) {
return 1; // technically incorrect but for the sake of profiling, a 0-bit requires 1 bit of space return 1; // technically incorrect but for the sake of profiling, a 0-bit requires 1 bit of space

View File

@ -15,14 +15,12 @@ public class EvenSimplerClient {
} }
public int receiveBit(int index) { public int receiveBit(int index) {
this.profiler.addNetworkDelay(2);
int[] data = this.profiler.clientReceive(this.servers[0].giveDatabase()); int[] data = this.profiler.clientReceive(this.servers[0].giveDatabase());
return data[index]; return data[index];
} }
public int[] receiveBits(int record) { public int[] receiveBits(int record) {
int[] res = new int[settings.getBlocksize()]; int[] res = new int[settings.getBlocksize()];
this.profiler.addNetworkDelay();
int[] data = this.profiler.clientReceive(this.servers[0].giveDatabase()); int[] data = this.profiler.clientReceive(this.servers[0].giveDatabase());
System.arraycopy(data, (record * settings.getBlocksize()), res, 0, settings.getBlocksize()); System.arraycopy(data, (record * settings.getBlocksize()), res, 0, settings.getBlocksize());
return res; return res;

View File

@ -60,7 +60,6 @@ public class InterPolyClient {
private int receiveBit(int index) { private int receiveBit(int index) {
FieldElement[] randoms = this.getRandomFieldElements(); FieldElement[] randoms = this.getRandomFieldElements();
FieldElement[] Fs = new FieldElement[this.servers.length]; FieldElement[] Fs = new FieldElement[this.servers.length];
this.profiler.addNetworkDelay(2);
for (int z = 0; z < this.servers.length; z++) { for (int z = 0; z < this.servers.length; z++) {
Fs[z] = this.profiler.clientReceive(this.servers[z].F(this.profiler.clientSend(this.getGs(index, z+1, randoms)))); Fs[z] = this.profiler.clientReceive(this.servers[z].F(this.profiler.clientSend(this.getGs(index, z+1, randoms))));
} }
@ -84,7 +83,6 @@ public class InterPolyClient {
/** /**
* 1) Compute all the Gs for each server, s.t. the first index should be the blocksize and it should contain all the Gs for the given index * 1) Compute all the Gs for each server, s.t. the first index should be the blocksize and it should contain all the Gs for the given index
*/ */
this.profiler.addNetworkDelay(2);
for (int z = 0; z < this.servers.length; z++) { for (int z = 0; z < this.servers.length; z++) {
FieldElement[][] Gs = new FieldElement[settings.getBlocksize()][this.s]; FieldElement[][] Gs = new FieldElement[settings.getBlocksize()][this.s];
for (int i = 0; i < settings.getBlocksize(); i++) { for (int i = 0; i < settings.getBlocksize(); i++) {

View File

@ -36,7 +36,6 @@ public class SimpleClient {
S2[index] = 1; S2[index] = 1;
} }
this.profiler.addNetworkDelay(2);
int resBit1 = this.profiler.clientReceive(this.servers[0].computeBit(this.profiler.clientSend(S1))); int resBit1 = this.profiler.clientReceive(this.servers[0].computeBit(this.profiler.clientSend(S1)));
int resBit2 = this.profiler.clientReceive(this.servers[1].computeBit(this.profiler.clientSend(S2))); int resBit2 = this.profiler.clientReceive(this.servers[1].computeBit(this.profiler.clientSend(S2)));
@ -61,7 +60,6 @@ public class SimpleClient {
} }
} }
this.profiler.addNetworkDelay(2);
int[] resBit1 = this.profiler.clientReceive(this.servers[0].computeBits(this.profiler.clientSend(S1s))); int[] resBit1 = this.profiler.clientReceive(this.servers[0].computeBits(this.profiler.clientSend(S1s)));
int[] resBit2 = this.profiler.clientReceive(this.servers[1].computeBits(this.profiler.clientSend(S2s))); int[] resBit2 = this.profiler.clientReceive(this.servers[1].computeBits(this.profiler.clientSend(S2s)));