save and restore number of trained episodes
This commit is contained in:
parent
fc88c64452
commit
9bc1a8ba9f
2
bot.py
2
bot.py
|
@ -15,7 +15,7 @@ class Bot:
|
||||||
with self.graph.as_default():
|
with self.graph.as_default():
|
||||||
self.session = tf.Session()
|
self.session = tf.Session()
|
||||||
self.network = Network(self.session, config)
|
self.network = Network(self.session, config)
|
||||||
self.network.restore_model()
|
self.network.restore_model()
|
||||||
|
|
||||||
|
|
||||||
def roll(self):
|
def roll(self):
|
||||||
|
|
5
game.py
5
game.py
|
@ -91,12 +91,13 @@ class Game:
|
||||||
|
|
||||||
if episode % min(save_step_size, episodes) == 0:
|
if episode % min(save_step_size, episodes) == 0:
|
||||||
sys.stderr.write("[TRAIN] Saving model...\n")
|
sys.stderr.write("[TRAIN] Saving model...\n")
|
||||||
self.p1.get_network().save_model()
|
self.p1.get_network().save_model(episode+trained_eps)
|
||||||
|
sys.stderr.write("[TRAIN] Loading model for training opponent...\n")
|
||||||
self.p2.restore_model()
|
self.p2.restore_model()
|
||||||
|
|
||||||
|
|
||||||
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
||||||
self.p1.get_network().save_model()
|
self.p1.get_network().save_model(episode+trained_eps)
|
||||||
self.p2.restore_model()
|
self.p2.restore_model()
|
||||||
|
|
||||||
return outcomes
|
return outcomes
|
||||||
|
|
8
main.py
8
main.py
|
@ -39,6 +39,9 @@ parser.add_argument('--train', action='store_true',
|
||||||
help='whether to train the neural network')
|
help='whether to train the neural network')
|
||||||
parser.add_argument('--play', action='store_true',
|
parser.add_argument('--play', action='store_true',
|
||||||
help='whether to play with the neural network')
|
help='whether to play with the neural network')
|
||||||
|
parser.add_argument('--start-episode', action='store', dest='start_episode',
|
||||||
|
type=int, default=0,
|
||||||
|
help='episode count to start at; purely for display purposes')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -48,7 +51,8 @@ config = {
|
||||||
'eval_methods': args.eval_methods,
|
'eval_methods': args.eval_methods,
|
||||||
'train': args.train,
|
'train': args.train,
|
||||||
'play': args.play,
|
'play': args.play,
|
||||||
'eval': args.eval
|
'eval': args.eval,
|
||||||
|
'start_episode': args.start_episode
|
||||||
}
|
}
|
||||||
|
|
||||||
#print("-"*30)
|
#print("-"*30)
|
||||||
|
@ -63,7 +67,7 @@ g.set_up_bots()
|
||||||
episode_count = args.episode_count
|
episode_count = args.episode_count
|
||||||
|
|
||||||
if args.train:
|
if args.train:
|
||||||
eps = 0
|
eps = config['start_episode']
|
||||||
while True:
|
while True:
|
||||||
train_outcome = g.train_model(episodes = episode_count, trained_eps = eps)
|
train_outcome = g.train_model(episodes = episode_count, trained_eps = eps)
|
||||||
eps += episode_count
|
eps += episode_count
|
||||||
|
|
11
network.py
11
network.py
|
@ -23,6 +23,13 @@ class Network:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.session = session
|
self.session = session
|
||||||
self.checkpoint_path = config['model_path']
|
self.checkpoint_path = config['model_path']
|
||||||
|
|
||||||
|
|
||||||
|
# Restore trained episode count for model
|
||||||
|
episode_count_path = os.path.join(self.checkpoint_path, "model.episodes")
|
||||||
|
if os.path.isfile(episode_count_path):
|
||||||
|
with open(episode_count_path, 'r') as f:
|
||||||
|
self.config['start_episode'] = int(f.read())
|
||||||
|
|
||||||
# input = x
|
# input = x
|
||||||
self.x = tf.placeholder('float', [1, Network.input_size], name='x')
|
self.x = tf.placeholder('float', [1, Network.input_size], name='x')
|
||||||
|
@ -82,8 +89,10 @@ class Network:
|
||||||
|
|
||||||
return val
|
return val
|
||||||
|
|
||||||
def save_model(self):
|
def save_model(self, episode_count):
|
||||||
self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt'))
|
self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt'))
|
||||||
|
with open(os.path.join(self.checkpoint_path, "model.episodes"), 'w+') as f:
|
||||||
|
f.write(str(episode_count) + "\n")
|
||||||
|
|
||||||
def restore_model(self):
|
def restore_model(self):
|
||||||
if os.path.isfile(self.checkpoint_path):
|
if os.path.isfile(self.checkpoint_path):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user