diff --git a/network.py b/network.py index 9eff2a8..76a21f8 100644 --- a/network.py +++ b/network.py @@ -131,6 +131,22 @@ class Network: if os.path.isfile(episode_count_path): with open(episode_count_path, 'r') as f: self.config['start_episode'] = int(f.read()) + elif glob.glob(os.path.join(os.path.join(self.config['model_storage_path'], "baseline_model"), 'model.ckpt*.index')): + checkpoint_path = os.path.join(self.config['model_storage_path'], "baseline_model") + latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path) + print("[NETWK] ({name}) Restoring model from:".format(name=self.name), + str(latest_checkpoint)) + self.saver.restore(sess, latest_checkpoint) + + variables_names = [v.name for v in tf.trainable_variables()] + values = sess.run(variables_names) + for k, v in zip(variables_names, values): + print("Variable: ", k) + print("Shape: ", v.shape) + print(v) + else: + print("You need to have baseline_model inside models") + exit() def make_move(self, sess, board, roll, player):