diff --git a/bin/run_all_tests.rb b/bin/run_all_tests.rb index 3530490..c9ae967 100644 --- a/bin/run_all_tests.rb +++ b/bin/run_all_tests.rb @@ -25,7 +25,7 @@ board_rep = "quack" model_name = "quack_test_0_ply" -run_stuff(board_rep, model_name) +#run_stuff(board_rep, model_name) #board_rep = "quack" @@ -42,7 +42,7 @@ board_rep = "quack-fat" model_name = "quack-fat_test_0_ply" -run_stuff(board_rep, model_name) +#run_stuff(board_rep, model_name) #board_rep = "quack-fat" #model_name = "quack-fat_test_1_ply" @@ -59,7 +59,7 @@ board_rep = "quack-norm" model_name = "quack-norm_test_0_ply" -run_stuff(board_rep, model_name) +#run_stuff(board_rep, model_name) #board_rep = "quack-norm" #model_name = "quack-norm_test_1_ply" @@ -73,7 +73,7 @@ run_stuff(board_rep, model_name) board_rep = "tesauro" -model_name = "tesauro_test_0_ply" +model_name = "tesauro_test3_0_ply" run_stuff(board_rep, model_name) diff --git a/plot.py b/plot.py index 5957854..7329523 100644 --- a/plot.py +++ b/plot.py @@ -47,27 +47,36 @@ def dataframes(model_name): if __name__ == '__main__': - fig, ax = plt.subplots(1, 1) + fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, sharey=True) - plt.ion() - plt.title('Mean over episodes') - plt.xlabel('Episodes trained') - plt.ylabel('Mean') - plt.grid(True) + #plt.ion() + ax1.set_title('Mean over episodes') + ax2.set_xlabel('Episodes trained') + ax1.set_ylabel('Points-per-game') + ax1.grid(True) + ax2.grid(True) #ax.set_xlim(left=0) - ax.set_ylim([-2, 2]) + ax1.set_ylim([-2, 2]) - plt.show() + df = dataframes('tesauro-5')['eval'] - while True: - df = dataframes('a')['eval'] + print(df) - print(df) - + dumbeval_df = df.query("method == 'dumbeval'") + pubeval_df = df.query("method == 'pubeval'") + + def plot_eval(axis, label, df, c): x = df['eps_train'] y = df['mean'] - plt.scatter(x, y, c=[[1, 0.5, 0]]) - #fig.canvas.draw() - plt.pause(2) + axis.scatter(x, y, label=label, c=c, marker="x") + + plot_eval(ax1, "dumbeval", dumbeval_df, [[1, 0.5, 0]]) + plot_eval(ax2, "pubeval", pubeval_df, [[0, 0.5, 1]]) + + ax1.legend() + ax2.legend() + + + plt.show()