save and restore number of trained episodes
This commit is contained in:
parent
fc88c64452
commit
9bc1a8ba9f
2
bot.py
2
bot.py
|
@ -15,7 +15,7 @@ class Bot:
|
|||
with self.graph.as_default():
|
||||
self.session = tf.Session()
|
||||
self.network = Network(self.session, config)
|
||||
self.network.restore_model()
|
||||
self.network.restore_model()
|
||||
|
||||
|
||||
def roll(self):
|
||||
|
|
5
game.py
5
game.py
|
@ -91,12 +91,13 @@ class Game:
|
|||
|
||||
if episode % min(save_step_size, episodes) == 0:
|
||||
sys.stderr.write("[TRAIN] Saving model...\n")
|
||||
self.p1.get_network().save_model()
|
||||
self.p1.get_network().save_model(episode+trained_eps)
|
||||
sys.stderr.write("[TRAIN] Loading model for training opponent...\n")
|
||||
self.p2.restore_model()
|
||||
|
||||
|
||||
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
||||
self.p1.get_network().save_model()
|
||||
self.p1.get_network().save_model(episode+trained_eps)
|
||||
self.p2.restore_model()
|
||||
|
||||
return outcomes
|
||||
|
|
8
main.py
8
main.py
|
@ -39,6 +39,9 @@ parser.add_argument('--train', action='store_true',
|
|||
help='whether to train the neural network')
|
||||
parser.add_argument('--play', action='store_true',
|
||||
help='whether to play with the neural network')
|
||||
parser.add_argument('--start-episode', action='store', dest='start_episode',
|
||||
type=int, default=0,
|
||||
help='episode count to start at; purely for display purposes')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -48,7 +51,8 @@ config = {
|
|||
'eval_methods': args.eval_methods,
|
||||
'train': args.train,
|
||||
'play': args.play,
|
||||
'eval': args.eval
|
||||
'eval': args.eval,
|
||||
'start_episode': args.start_episode
|
||||
}
|
||||
|
||||
#print("-"*30)
|
||||
|
@ -63,7 +67,7 @@ g.set_up_bots()
|
|||
episode_count = args.episode_count
|
||||
|
||||
if args.train:
|
||||
eps = 0
|
||||
eps = config['start_episode']
|
||||
while True:
|
||||
train_outcome = g.train_model(episodes = episode_count, trained_eps = eps)
|
||||
eps += episode_count
|
||||
|
|
11
network.py
11
network.py
|
@ -23,6 +23,13 @@ class Network:
|
|||
self.config = config
|
||||
self.session = session
|
||||
self.checkpoint_path = config['model_path']
|
||||
|
||||
|
||||
# Restore trained episode count for model
|
||||
episode_count_path = os.path.join(self.checkpoint_path, "model.episodes")
|
||||
if os.path.isfile(episode_count_path):
|
||||
with open(episode_count_path, 'r') as f:
|
||||
self.config['start_episode'] = int(f.read())
|
||||
|
||||
# input = x
|
||||
self.x = tf.placeholder('float', [1, Network.input_size], name='x')
|
||||
|
@ -82,8 +89,10 @@ class Network:
|
|||
|
||||
return val
|
||||
|
||||
def save_model(self):
|
||||
def save_model(self, episode_count):
|
||||
self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt'))
|
||||
with open(os.path.join(self.checkpoint_path, "model.episodes"), 'w+') as f:
|
||||
f.write(str(episode_count) + "\n")
|
||||
|
||||
def restore_model(self):
|
||||
if os.path.isfile(self.checkpoint_path):
|
||||
|
|
Loading…
Reference in New Issue
Block a user