Added baseline model for testing

This commit is contained in:
Alexander Munch-Hansen 2018-04-24 22:30:58 +02:00
parent 349ad718f1
commit 0509a51fd3

View File

@ -131,6 +131,22 @@ class Network:
if os.path.isfile(episode_count_path):
with open(episode_count_path, 'r') as f:
self.config['start_episode'] = int(f.read())
elif glob.glob(os.path.join(os.path.join(self.config['model_storage_path'], "baseline_model"), 'model.ckpt*.index')):
checkpoint_path = os.path.join(self.config['model_storage_path'], "baseline_model")
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
print("[NETWK] ({name}) Restoring model from:".format(name=self.name),
str(latest_checkpoint))
self.saver.restore(sess, latest_checkpoint)
variables_names = [v.name for v in tf.trainable_variables()]
values = sess.run(variables_names)
for k, v in zip(variables_names, values):
print("Variable: ", k)
print("Shape: ", v.shape)
print(v)
else:
print("You need to have baseline_model inside models")
exit()
def make_move(self, sess, board, roll, player):