"""Plotting helpers for model training/evaluation logs and evaluation benchmarks."""

import os
import pandas as pd
from datetime import datetime
import csv
# NOTE: this rebinds ``datetime`` to the *module*, shadowing the class imported
# above; code below therefore refers to ``datetime.datetime``.
import datetime
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.dates as mdates

# Column names for the headerless, ';'-separated log files.
train_headers = ['timestamp', 'eps_train', 'eps_trained_session', 'sum', 'mean']
eval_headers = ['timestamp', 'method', 'eps_train', 'eval_eps_used', 'sum', 'mean']
bench_headers = ['method', 'sample_count', 'i', 'time', 'sum', 'mean']

# Root directory holding one sub-directory per model (see ``dataframes``).
model_path = 'models'


def plot_bench(data_path):
    """Show a box plot of mean scores per sample count for every method in a benchmark log.

    Args:
        data_path: path to a headerless, ';'-separated file whose columns are
            ``bench_headers``.

    For each method this also prints the variance of the ``mean`` column for
    each sample count.
    """
    df = pd.read_csv(data_path, sep=";", names=bench_headers, index_col=[0, 1, 2])
    for method_label in df.index.levels[0]:
        # Unstack the ``i`` level and transpose: one column per sample_count,
        # one row per value of ``i``, so the box plot draws one box per sample_count.
        df_prime = df[['mean']].loc[method_label].unstack().T
        plot = df_prime.plot.box()
        plot.set_title("Evaluation variance, {}".format(method_label))
        plot.set_xlabel("Sample count")
        plot.set_ylabel("Mean score")
        # pyplot.show() takes no positional figure argument (only keyword
        # ``block``); passing the figure raises TypeError on current matplotlib.
        plt.show()
        # Kept for later use: per-sample-count variance of the mean score.
        variances = df_prime.var()
        print(variances)


def dataframes(model_name, base_path=None):
    """Load the evaluation and training logs of a model as DataFrames.

    Reads ``<base_path>/<model_name>/logs/eval.log`` and ``.../train.log``
    (headerless, ';'-separated, columns ``eval_headers`` / ``train_headers``).
    It converts each ``timestamp`` column with ``datetime.datetime.fromtimestamp``
    (epoch seconds -> naive local-time datetime).

    Args:
        model_name: name of the model's sub-directory.
        base_path: root directory of the models; defaults to the module-level
            ``model_path`` (looked up at call time).

    Returns:
        dict with keys ``'eval'`` and ``'train'`` mapping to the loaded DataFrames.
    """
    if base_path is None:
        base_path = model_path
    log_path = os.path.join(base_path, model_name, 'logs')

    def load(file_name, headers):
        # Read one log file and convert its timestamp column to datetimes.
        df = pd.read_csv(os.path.join(log_path, file_name), sep=';', names=headers)
        df['timestamp'] = df['timestamp'].map(datetime.datetime.fromtimestamp)
        return df

    # Same read order as before: eval first, then train.
    return {
        'eval': load('eval.log', eval_headers),
        'train': load('train.log', train_headers),
    }


def _plot_eval(axis, label, df, c):
    """Scatter ``mean`` against ``eps_train`` of one evaluation method onto *axis*."""
    axis.scatter(df['eps_train'], df['mean'], label=label, c=c, marker="x")


def main(model_name='tesauro-5'):
    """Plot the dumbeval and pubeval evaluation scores of *model_name* on two stacked axes."""
    _, (ax1, ax2) = plt.subplots(2, 1, sharex=True, sharey=True)
    ax1.set_title('Mean over episodes')
    ax2.set_xlabel('Episodes trained')
    ax1.set_ylabel('Points-per-game')
    ax1.grid(True)
    ax2.grid(True)
    # Axes share y, so this limit applies to both subplots.
    ax1.set_ylim([-2, 2])

    df = dataframes(model_name)['eval']
    print(df)

    dumbeval_df = df.query("method == 'dumbeval'")
    pubeval_df = df.query("method == 'pubeval'")

    _plot_eval(ax1, "dumbeval", dumbeval_df, [[1, 0.5, 0]])
    _plot_eval(ax2, "pubeval", pubeval_df, [[0, 0.5, 1]])

    ax1.legend()
    ax2.legend()
    plt.show()


if __name__ == '__main__':
    main()