diff --git a/main.py b/main.py index 8916d3f..e2e8988 100644 --- a/main.py +++ b/main.py @@ -33,6 +33,9 @@ parser.add_argument('--list-models', action='store_true', help='list all known models') parser.add_argument('--force-creation', action='store_true', help='force model creation if model does not exist') +parser.add_argument('--board-rep', action='store', dest='board_rep', + default='tesauro', + help='name of board representation to use as input to neural network') parser.add_argument('--use-baseline', action='store_true', help='use the baseline model, note, has size 28') @@ -55,7 +58,7 @@ config = { 'train_perpetually': args.train_perpetually, 'model_storage_path': 'models', 'bench_storage_path': 'bench', - 'board_representation': 'quack-fat', + 'board_representation': args.board_rep, 'force_creation': args.force_creation, 'use_baseline': args.use_baseline } @@ -83,6 +86,7 @@ def log_train_outcome(outcome, diff_in_values, trained_eps = 0, log_path = os.pa 'time': int(time.time()), 'average_diff_in_vals': diff_in_values/len(outcome) } + with open(log_path, 'a+') as f: f.write("{time};{trained_eps};{count};{sum};{mean};{average_diff_in_vals}".format(**format_vars) + "\n") diff --git a/network.py b/network.py index f30e724..84802e3 100644 --- a/network.py +++ b/network.py @@ -582,6 +582,6 @@ class Network: writer.close() - return outcomes, difference_in_vals + return outcomes, difference_in_vals[0][0]