From 9bc1a8ba9f8b9e37cebebf2df71f521a6ca2db09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoffer=20M=C3=BCller=20Madsen?= Date: Sat, 10 Mar 2018 00:22:20 +0100 Subject: [PATCH] save and restore number of trained episodes --- bot.py | 2 +- game.py | 5 +++-- main.py | 8 ++++++-- network.py | 11 ++++++++++- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/bot.py b/bot.py index a008134..b79e2b8 100644 --- a/bot.py +++ b/bot.py @@ -15,7 +15,7 @@ class Bot: with self.graph.as_default(): self.session = tf.Session() self.network = Network(self.session, config) - self.network.restore_model() + self.network.restore_model() def roll(self): diff --git a/game.py b/game.py index 60ad212..1b6d94e 100644 --- a/game.py +++ b/game.py @@ -91,12 +91,13 @@ class Game: if episode % min(save_step_size, episodes) == 0: sys.stderr.write("[TRAIN] Saving model...\n") - self.p1.get_network().save_model() + self.p1.get_network().save_model(episode+trained_eps) + sys.stderr.write("[TRAIN] Loading model for training opponent...\n") self.p2.restore_model() sys.stderr.write("[TRAIN] Saving model for final episode...\n") - self.p1.get_network().save_model() + self.p1.get_network().save_model(episode+trained_eps) self.p2.restore_model() return outcomes diff --git a/main.py b/main.py index 6598a59..2b1d91f 100644 --- a/main.py +++ b/main.py @@ -39,6 +39,9 @@ parser.add_argument('--train', action='store_true', help='whether to train the neural network') parser.add_argument('--play', action='store_true', help='whether to play with the neural network') +parser.add_argument('--start-episode', action='store', dest='start_episode', + type=int, default=0, + help='episode count to start at; purely for display purposes') args = parser.parse_args() @@ -48,7 +51,8 @@ config = { 'eval_methods': args.eval_methods, 'train': args.train, 'play': args.play, - 'eval': args.eval + 'eval': args.eval, + 'start_episode': args.start_episode } #print("-"*30) @@ -63,7 +67,7 @@ g.set_up_bots() episode_count = args.episode_count if args.train: - eps = 0 + eps = config['start_episode'] while True: train_outcome = g.train_model(episodes = episode_count, trained_eps = eps) eps += episode_count diff --git a/network.py b/network.py index cdb3510..435ccde 100644 --- a/network.py +++ b/network.py @@ -23,6 +23,13 @@ class Network: self.config = config self.session = session self.checkpoint_path = config['model_path'] + + + # Restore trained episode count for model + episode_count_path = os.path.join(self.checkpoint_path, "model.episodes") + if os.path.isfile(episode_count_path): + with open(episode_count_path, 'r') as f: + self.config['start_episode'] = int(f.read()) # input = x self.x = tf.placeholder('float', [1, Network.input_size], name='x') @@ -82,8 +89,10 @@ class Network: return val - def save_model(self): + def save_model(self, episode_count): self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt')) + with open(os.path.join(self.checkpoint_path, "model.episodes"), 'w+') as f: + f.write(str(episode_count) + "\n") def restore_model(self): if os.path.isfile(self.checkpoint_path):