import math
import os
import re
import statistics
from collections import defaultdict

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.artist import setp

import util

# Column layout of one whitespace-separated line in a results log, paired with
# the converter applied to that column.
_RESULT_FIELDS = (
    ("num_servers", int),
    ("database_size", int),
    ("block_size", int),
    ("protocol_name", str),
    ("total_cpu_time", int),
    ("bits_sent", int),
    ("bits_received", int),
)

# Metrics that clean_results averages across repeated runs of one configuration.
_AVERAGED_METRICS = ("total_cpu_time", "bits_sent", "bits_received")


def load_results(filename) -> dict:
    """Parse ``results/<filename>`` into raw measurements grouped by configuration.

    Lines containing ``"error"`` and blank lines are skipped; every other line is
    split on whitespace and its columns converted according to ``_RESULT_FIELDS``.

    Returns:
        A ``defaultdict(list)`` keyed by
        ``(num_servers, database_size, block_size, protocol_name)``; each value is
        a list with one dict per run of that configuration.
    """
    results = defaultdict(list)
    with open(f"results/{filename}", "r", encoding="utf-8") as file:
        for line in file:
            if "error" in line:
                continue
            values = line.split()
            if not values:
                # Blank lines (e.g. a trailing newline) carry no measurement.
                continue
            d = {name: convert(value) for (name, convert), value in zip(_RESULT_FIELDS, values)}
            key = (d["num_servers"], d["database_size"], d["block_size"], d["protocol_name"])
            results[key].append(d)
    return results


def clean_results(results) -> dict:
    """Average repeated runs of each configuration and regroup them by protocol.

    Args:
        results: Mapping as returned by :func:`load_results`.

    Returns:
        A ``defaultdict(list)`` mapping protocol name to one dict per
        configuration: the first run's fields, with every metric in
        ``_AVERAGED_METRICS`` replaced by its mean over all runs.
    """
    cleaned_results = defaultdict(list)
    for runs in results.values():
        averaged = {
            metric: statistics.mean(int(r[metric]) for r in runs)
            for metric in _AVERAGED_METRICS
        }
        cleaned_results[runs[0]["protocol_name"]].append({**runs[0], **averaged})
    return cleaned_results


def filter_results(results: dict, func: callable):
    """Return a new protocol -> results mapping keeping only entries where ``func(r)`` is true."""
    return {
        protocol_name: [r for r in protocol_results if func(r)]
        for protocol_name, protocol_results in results.items()
    }


def save_fig(plt, title):
    """Save the current figure of ``plt`` as ``plots/<title>.pdf``.

    Every non-word character of ``title`` becomes ``_`` so the name is filesystem
    safe; a missing title falls back to ``"untitled"``. The ``plots`` directory is
    created if needed.
    """
    clean_title = re.sub(r"\W", "_", title if title is not None else "untitled")
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{clean_title}.pdf")


def with_bandwidth(result: dict, bandwidth=10):
    """Estimate total time in ms as CPU time plus transfer time at ``bandwidth`` Mbit/s.

    The value is clamped to at least 1 so it stays plottable on log axes.
    """
    # 1 Mbit/s = 1000 bits/ms, so total bits / (bandwidth * 1000) is the transfer time in ms.
    transfer_ms = (result["bits_sent"] + result["bits_received"]) / (bandwidth * 1000)
    return max(1, result["total_cpu_time"] + transfer_ms)


def _total_database_bits(result: dict):
    """Return ``database_size * block_size``, plotted as the total database size in bits."""
    return result["database_size"] * result["block_size"]


def _pow2_tick_labels(values):
    """Format each value as a LaTeX ``$2^{k}$`` label with ``k = int(log2(value))``."""
    return [f"$2^{{{int(math.log2(v))}}}$" for v in values]


def plot(all_results: dict, y_func: callable, x_func: callable, title=None, y_label=None, x_label=None,
         logx=False, logy=False, scatter=False):
    """Draw one series per protocol on a new figure and save it via :func:`save_fig`.

    Args:
        all_results: Mapping protocol name -> list of result dicts.
        y_func, x_func: Extract the y / x coordinate from a result dict; points are
            ordered by ``x_func``.
        title: Used only for the output file name (not drawn on the figure).
        y_label, x_label: Optional axis labels.
        logx, logy: Use a base-2 log scale on that axis.
        scatter: Draw markers instead of connected lines.
    """
    _, ax = plt.subplots()
    plot_func = ax.scatter if scatter else ax.plot
    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=x_func)
        plot_func(
            [x_func(r) for r in sorted_results],
            [y_func(r) for r in sorted_results],
            label=protocol_name.replace("_", " "),
        )
    # `base` replaces the `basex`/`basey` keywords removed in matplotlib 3.5.
    if logx:
        ax.set_xscale("log", base=2)
    if logy:
        ax.set_yscale("log", base=2)
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    ax.legend(loc="upper left")
    # The title is intentionally not drawn on the figure; it only names the output file.
    save_fig(plt, title)


def plot_3x_with_simulated_bandwidth(all_results: dict, title: str):
    """Plot estimated total time vs. total database size at 10 and 100 Mbit/s.

    Draws two side-by-side panels on the current figure that share both axes
    (base-2 log scales). Each panel has one series per protocol, computed with
    :func:`with_bandwidth`. The figure is saved via :func:`save_fig`.
    """
    ax1 = plt.subplot(121)
    ax2 = plt.subplot(122, sharex=ax1, sharey=ax1)
    ax1.set_ylabel("Time (ms)")
    # Both panels share the y axis, so only the left one needs tick labels.
    setp(ax2.get_yticklabels(), visible=False)

    panels = ((ax1, 10, "10 Mbit/s"), (ax2, 100, "100 Mbit/s"))
    for ax, _, panel_title in panels:
        ax.set_xlabel("Total Database Size (bits)")
        ax.tick_params("y")
        ax.set_xscale("log", base=2)
        ax.set_yscale("log", base=2)
        ax.set_title(panel_title)

    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=_total_database_bits)
        xs = [_total_database_bits(r) for r in sorted_results]
        label = protocol_name.replace("_", " ")
        for ax, bandwidth, _ in panels:
            ax.plot(xs, [with_bandwidth(r, bandwidth) for r in sorted_results], label=label)

    ax1.legend(loc="upper left")
    save_fig(plt, title)


def plot_send_receive(all_results: dict, title: str):
    """Plot bits sent (left) and bits received (right) vs. total database size.

    Both panels are drawn on the current figure and share the x and y axes
    (base-2 log scales). Values are clamped to at least 1 for the log scale.
    The figure is saved via :func:`save_fig`.
    """
    ax1 = plt.subplot(121)
    # Share y so the right panel, whose tick labels are hidden, is still readable
    # against the left panel's scale.
    ax2 = plt.subplot(122, sharex=ax1, sharey=ax1)
    ax1.set_ylabel("Sent (bits)")
    ax2.set_ylabel("Received (bits)")
    setp(ax2.get_yticklabels(), visible=False)
    ax2.yaxis.set_label_position("left")
    for ax in (ax1, ax2):
        ax.set_xlabel("Total Database Size (bits)")
        ax.tick_params("y")
        ax.set_xscale("log", base=2)
        ax.set_yscale("log", base=2)

    for protocol_name, results in all_results.items():
        sorted_results = sorted(results, key=_total_database_bits)
        xs = [_total_database_bits(r) for r in sorted_results]
        label = protocol_name.replace("_", " ")
        ax1.plot(xs, [max(1, r["bits_sent"]) for r in sorted_results], label=label)
        ax2.plot(xs, [max(1, r["bits_received"]) for r in sorted_results], label=label)

    ax1.legend(loc="upper left")
    save_fig(plt, title)


def matrixify(results: list, x_func: callable, y_func: callable, z_func: callable):
    """Pivot a list of results into a 2-D grid.

    Returns:
        ``(matrix, x_labels, y_labels)``. ``x_labels`` / ``y_labels`` are the
        sorted distinct values of ``x_func`` / ``y_func``. ``matrix[i, j]`` is
        ``z_func`` of the result at ``(y_labels[i], x_labels[j])``, or 1 where no
        result exists. When several results share a cell, the last one wins.
    """
    x_labels = sorted({x_func(r) for r in results})
    y_labels = sorted({y_func(r) for r in results})
    cells = {(y_func(r), x_func(r)): z_func(r) for r in results}
    # NOTE(review): the fill value 1 is presumably chosen so empty cells stay valid
    # on a log color scale — confirm.
    matrix = np.array([[cells.get((y, x), 1) for x in x_labels] for y in y_labels])
    return matrix, x_labels, y_labels


def _time_grid(results: list, bandwidth):
    """Return :func:`matrixify` of simulated total time: rows = block size, columns = database size."""
    return matrixify(
        results,
        x_func=lambda r: r["database_size"],
        y_func=lambda r: r["block_size"],
        z_func=lambda r: with_bandwidth(r, bandwidth),
    )


def plot_scheme_heatmap(results: list, title: str, bandwidth: int):
    """Heatmap of simulated total time for one scheme over block size x database size."""
    data, x_labels, y_labels = _time_grid(results, bandwidth)
    util.heatmap(
        data,
        _pow2_tick_labels(y_labels),
        _pow2_tick_labels(x_labels),
        xlabel="Database Size (bits)",
        ylabel="Block Size (bits)",
        cbarlabel="Time (ms)",
        logcolor=True,
        origin="lower",
        cmap=cm.gray,
    )
    save_fig(plt, title)


def plot_old_vs_new_heatmap(all_results: dict, old_func: callable, new_func: callable, title: str, bandwidth=10):
    """Heatmap of ``new - old`` simulated total time per (block size, database size) cell.

    Args:
        all_results: Mapping protocol name -> results, passed to both selectors.
        old_func, new_func: Select the baseline / compared scheme's results.
        title: Output file name (see :func:`save_fig`).
        bandwidth: Simulated bandwidth in Mbit/s (default 10).

    The grid follows the new scheme's labels. Cells are matched by label rather
    than position, so differing label sets cannot misalign them. Cells missing
    from the old grid show 0.
    """
    data_old, old_x_labels, old_y_labels = _time_grid(old_func(all_results), bandwidth)
    data_new, x_labels, y_labels = _time_grid(new_func(all_results), bandwidth)
    old_row = {y: i for i, y in enumerate(old_y_labels)}
    old_col = {x: j for j, x in enumerate(old_x_labels)}

    def cell_diff(i, y, j, x):
        if y in old_row and x in old_col:
            return data_new[i, j] - data_old[old_row[y], old_col[x]]
        return 0

    diff = np.array([
        [cell_diff(i, y, j, x) for j, x in enumerate(x_labels)]
        for i, y in enumerate(y_labels)
    ])
    util.heatmap(
        diff,
        _pow2_tick_labels(y_labels),
        _pow2_tick_labels(x_labels),
        xlabel="Database Size (bits)",
        ylabel="Block Size (bits)",
        cbarlabel="Time Difference (ms)",
        sym_logcolor=True,
        origin="lower",
    )
    save_fig(plt, title)


def main():
    """Generate every plot in ``plots/`` from the logs in ``results/``."""
    combined = clean_results(load_results("results_combined.log"))
    one_bit_blocks = filter_results(combined, lambda r: r["block_size"] == 1)

    # Simple CPU Time
    plot(
        one_bit_blocks,
        y_label="Time (ms)",
        x_label="Total Database Size (bits)",
        title="Computation Time - 1-bit Block Size",
        y_func=lambda r: max(1, r["total_cpu_time"]),
        x_func=_total_database_bits,
        logx=True,
        logy=True,
    )
    plt.close()

    # ... with simulated bandwidth, e.g. estimated total real time
    plot_3x_with_simulated_bandwidth(
        one_bit_blocks,
        title="Total Time with Simulated Bandwidth - 1-bit Block Size",
    )
    # Disabled: CPU time per bit as a function of the block/database-size ratio.
    # plot(
    #     filter_results(combined,
    #                    lambda r: r["protocol_name"] != "Interpolation" and _total_database_bits(r) > 1024),
    #     y_label="Time (ms)",
    #     x_label="Block Size / Database Size (ratio)",
    #     title="Computation Time per bit - Block Size / Database Size Ratio",
    #     y_func=lambda r: max(1, r["total_cpu_time"] / _total_database_bits(r)),
    #     x_func=lambda r: r["block_size"] / r["database_size"],
    #     logx=True,
    # )
    plt.close()

    # Simple Network Traffic
    plot_send_receive(one_bit_blocks, title="Network Traffic - 1-bit Block Size")

    var_bs_db = clean_results(load_results("results_fast_var-bs_var-db.log"))
    # Disabled: scatter plot of total simulated time, varying block size and database size.
    # plot(
    #     var_bs_db,
    #     y_label="Time (ms)",
    #     x_label="Total Database Size (bits)",
    #     title="Total Time with Simulated Bandwidth - Varying Block and Database Size",
    #     y_func=lambda r: with_bandwidth(r, 10),
    #     x_func=_total_database_bits,
    #     scatter=True,
    # )
    plt.close()

    # 2D heatmaps of simulated total time per scheme, varying database size and block size
    for protocol_name in ("Send_All", "XOR", "Balanced_XOR"):
        plot_scheme_heatmap(
            var_bs_db[protocol_name],
            title=f"Total Simulated Time Heatmap: {protocol_name.replace('_', ' ')}"
                  f" - Varying Database Size and Block Size - 10Mbit/s",
            bandwidth=10,
        )
        plt.close()

    # 2D heatmaps of scheme differences (CPU + simulated bandwidth), varying block size and database size
    for old_name in ("Send_All", "XOR"):
        plot_old_vs_new_heatmap(
            var_bs_db,
            old_func=lambda rs, name=old_name: rs[name],
            new_func=lambda rs: rs["Balanced_XOR"],
            title=f"Total Simulated Time Heatmap: {old_name.replace('_', ' ')} vs Balanced XOR"
                  f" - Varying Database Size and Block Size - 10 Mbit/s",
        )
        plt.close()


if __name__ == '__main__':
    main()