diff --git a/main.py b/main.py
index e3ded40..0631df3 100644
--- a/main.py
+++ b/main.py
@@ -34,14 +34,15 @@ parser.add_argument('--list-models', action='store_true',
 parser.add_argument('--force-creation', action='store_true',
                     help='force model creation if model does not exist')
 parser.add_argument('--board-rep', action='store', dest='board_rep',
-                    default='tesauro',
                     help='name of board representation to use as input to neural network')
 parser.add_argument('--use-baseline', action='store_true',
                     help='use the baseline model, note, has size 28')
 parser.add_argument('--verbose', action='store_true',
                     help='If set, a lot of stuff will be printed')
-parser.add_argument('--ply', action='store', dest='ply',
+parser.add_argument('--ply', action='store', dest='ply', default='0',
                     help='defines the amount of ply used when deciding what move to make')
+parser.add_argument('--repeat-eval', action='store', dest='repeat_eval', default='1',
+                    help='the amount of times the evaluation method should be repeated')
 
 args = parser.parse_args()
 
@@ -67,10 +68,11 @@ config = {
     'use_baseline': args.use_baseline,
     'global_step': 0,
     'verbose': args.verbose,
-    'ply': args.ply
-
+    'ply': args.ply,
+    'repeat_eval': args.repeat_eval
 }
 
+
 # Create models folder
 if not os.path.exists(config['model_storage_path']):
     os.makedirs(config['model_storage_path'])
@@ -133,6 +135,24 @@ def log_bench_eval_outcomes(outcomes, log_path, index, time, trained_eps = 0):
     with open(log_path, 'a+') as f:
         f.write("{method};{count};{index};{time};{sum};{mean}".format(**format_vars) + "\n")
 
+def find_board_rep():
+    checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
+    board_rep_path = os.path.join(checkpoint_path, "board_representation")
+    with open(board_rep_path, 'r') as f:
+        return f.read()
+
+
+def board_rep_file_exists():
+    checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
+    board_rep_path = os.path.join(checkpoint_path, "board_representation")
+    return os.path.isfile(board_rep_path)
+
+def create_board_rep():
+    checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
+    board_rep_path = os.path.join(checkpoint_path, "board_representation")
+    with open(board_rep_path, 'a+') as f:
+        f.write(config['board_representation'])
+
 # Do actions specified by command-line
 if args.list_models:
     def get_eps_trained(folder):
@@ -155,6 +175,22 @@ if __name__ == "__main__":
     # Set up variables
     episode_count = config['episode_count']
 
+
+    if config['board_representation'] is None:
+        if board_rep_file_exists():
+            config['board_representation'] = find_board_rep()
+        else:
+            sys.stderr.write("Was not given a board_rep and was unable to find a board_rep file\n")
+            exit()
+    else:
+        if not board_rep_file_exists():
+            create_board_rep()
+        else:
+            if config['board_representation'] != find_board_rep():
+                sys.stderr.write("Board representation \"{given}\", does not match one in board_rep file, \"{board_rep}\"\n".
+                                 format(given = config['board_representation'], board_rep = find_board_rep()))
+                exit()
+
     if args.train:
         network = Network(config, config['model'])
 
@@ -172,12 +208,13 @@ if __name__ == "__main__":
 
     elif args.eval:
         network = Network(config, config['model'])
-        start_episode = network.episodes_trained
-        # Evaluation measures are described in `config`
-        outcomes = network.eval(config['episode_count'])
-        log_eval_outcomes(outcomes, trained_eps = start_episode)
-    # elif args.play:
-    #    g.play(episodes = episode_count)
+        for i in range(int(config['repeat_eval'])):
+            start_episode = network.episodes_trained
+            # Evaluation measures are described in `config`
+            outcomes = network.eval(config['episode_count'])
+            log_eval_outcomes(outcomes, trained_eps = start_episode)
+        # elif args.play:
+        #     g.play(episodes = episode_count)
 
 
     elif args.bench_eval_scores:
diff --git a/network.py b/network.py
index 5ea80cd..ad8e27a 100644
--- a/network.py
+++ b/network.py
@@ -177,6 +177,7 @@ class Network:
 
         :return: Nothing. It's a side-effect that a model gets restored for the network.
         """
+
         if glob.glob(os.path.join(self.checkpoint_path, 'model.ckpt*.index')):
             latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
 
@@ -235,9 +236,9 @@ class Network:
         :param player:
         :return:
         """
-        start = time.time()
+        # start = time.time()
         best_pair = self.calculate_1_ply(board, roll, player)
-        print(time.time() - start)
+        # print(time.time() - start)
         return best_pair
 
 
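Not part of the patch above: a minimal, self-contained sketch of how the new --repeat-eval option is intended to behave. argparse stores the value as a string (default '1'), and the evaluation loop converts it with int() and repeats that many times, mirroring the for-loop added to main.py. The evaluate_once() helper and the --episode-count flag name are illustrative stand-ins, not code from this repository.

import argparse

def evaluate_once(episode_count):
    # Hypothetical stand-in for Network.eval(episode_count); returns dummy outcomes.
    return [0] * episode_count

parser = argparse.ArgumentParser()
parser.add_argument('--episode-count', action='store', dest='episode_count', default='100',
                    help='episodes per evaluation run (assumed flag name)')
parser.add_argument('--repeat-eval', action='store', dest='repeat_eval', default='1',
                    help='the amount of times the evaluation method should be repeated')
args = parser.parse_args()

# Both values arrive as strings, so they are converted with int() before use,
# just like int(config['repeat_eval']) in the patched main.py.
for i in range(int(args.repeat_eval)):
    outcomes = evaluate_once(int(args.episode_count))
    print("evaluation run", i + 1, "evaluated", len(outcomes), "episodes")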