Added baseline model for testing
This commit is contained in:
parent
349ad718f1
commit
0509a51fd3
16
network.py
16
network.py
|
@ -131,6 +131,22 @@ class Network:
|
||||||
if os.path.isfile(episode_count_path):
|
if os.path.isfile(episode_count_path):
|
||||||
with open(episode_count_path, 'r') as f:
|
with open(episode_count_path, 'r') as f:
|
||||||
self.config['start_episode'] = int(f.read())
|
self.config['start_episode'] = int(f.read())
|
||||||
|
elif glob.glob(os.path.join(os.path.join(self.config['model_storage_path'], "baseline_model"), 'model.ckpt*.index')):
|
||||||
|
checkpoint_path = os.path.join(self.config['model_storage_path'], "baseline_model")
|
||||||
|
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||||
|
print("[NETWK] ({name}) Restoring model from:".format(name=self.name),
|
||||||
|
str(latest_checkpoint))
|
||||||
|
self.saver.restore(sess, latest_checkpoint)
|
||||||
|
|
||||||
|
variables_names = [v.name for v in tf.trainable_variables()]
|
||||||
|
values = sess.run(variables_names)
|
||||||
|
for k, v in zip(variables_names, values):
|
||||||
|
print("Variable: ", k)
|
||||||
|
print("Shape: ", v.shape)
|
||||||
|
print(v)
|
||||||
|
else:
|
||||||
|
print("You need to have baseline_model inside models")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
def make_move(self, sess, board, roll, player):
|
def make_move(self, sess, board, roll, player):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user