diff --git a/network.py b/network.py index d548f26..1ce3704 100644 --- a/network.py +++ b/network.py @@ -93,8 +93,9 @@ class Network: self.saver.save(self.session, self.checkpoint_path + 'model.ckpt') def restore_model(self): - latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path) - self.saver.restore(self.session, latest_checkpoint) + if os.path.isfile(self.checkpoint_path): + latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path) + self.saver.restore(self.session, latest_checkpoint) # Have a circular dependency, #fuck, need to rewrite something def train(self, x, v_next):