Save and restore the number of trained episodes

This commit is contained in:
Christoffer Müller Madsen 2018-03-10 00:22:20 +01:00
parent fc88c64452
commit 9bc1a8ba9f
4 changed files with 20 additions and 6 deletions

View File

@ -91,12 +91,13 @@ class Game:
if episode % min(save_step_size, episodes) == 0: if episode % min(save_step_size, episodes) == 0:
sys.stderr.write("[TRAIN] Saving model...\n") sys.stderr.write("[TRAIN] Saving model...\n")
self.p1.get_network().save_model() self.p1.get_network().save_model(episode+trained_eps)
sys.stderr.write("[TRAIN] Loading model for training opponent...\n")
self.p2.restore_model() self.p2.restore_model()
sys.stderr.write("[TRAIN] Saving model for final episode...\n") sys.stderr.write("[TRAIN] Saving model for final episode...\n")
self.p1.get_network().save_model() self.p1.get_network().save_model(episode+trained_eps)
self.p2.restore_model() self.p2.restore_model()
return outcomes return outcomes

View File

@ -39,6 +39,9 @@ parser.add_argument('--train', action='store_true',
help='whether to train the neural network') help='whether to train the neural network')
parser.add_argument('--play', action='store_true', parser.add_argument('--play', action='store_true',
help='whether to play with the neural network') help='whether to play with the neural network')
parser.add_argument('--start-episode', action='store', dest='start_episode',
type=int, default=0,
help='episode count to start at; purely for display purposes')
args = parser.parse_args() args = parser.parse_args()
@ -48,7 +51,8 @@ config = {
'eval_methods': args.eval_methods, 'eval_methods': args.eval_methods,
'train': args.train, 'train': args.train,
'play': args.play, 'play': args.play,
'eval': args.eval 'eval': args.eval,
'start_episode': args.start_episode
} }
#print("-"*30) #print("-"*30)
@ -63,7 +67,7 @@ g.set_up_bots()
episode_count = args.episode_count episode_count = args.episode_count
if args.train: if args.train:
eps = 0 eps = config['start_episode']
while True: while True:
train_outcome = g.train_model(episodes = episode_count, trained_eps = eps) train_outcome = g.train_model(episodes = episode_count, trained_eps = eps)
eps += episode_count eps += episode_count

View File

@ -24,6 +24,13 @@ class Network:
self.session = session self.session = session
self.checkpoint_path = config['model_path'] self.checkpoint_path = config['model_path']
# Restore trained episode count for model
episode_count_path = os.path.join(self.checkpoint_path, "model.episodes")
if os.path.isfile(episode_count_path):
with open(episode_count_path, 'r') as f:
self.config['start_episode'] = int(f.read())
# input = x # input = x
self.x = tf.placeholder('float', [1, Network.input_size], name='x') self.x = tf.placeholder('float', [1, Network.input_size], name='x')
self.value_next = tf.placeholder('float', [1, Network.output_size], name="value_next") self.value_next = tf.placeholder('float', [1, Network.output_size], name="value_next")
@ -82,8 +89,10 @@ class Network:
return val return val
def save_model(self): def save_model(self, episode_count):
self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt')) self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt'))
with open(os.path.join(self.checkpoint_path, "model.episodes"), 'w+') as f:
f.write(str(episode_count) + "\n")
def restore_model(self): def restore_model(self):
if os.path.isfile(self.checkpoint_path): if os.path.isfile(self.checkpoint_path):