From 7d29fc02f24370c82553f010649bd04b9e05aa35 Mon Sep 17 00:00:00 2001 From: Pownie Date: Sat, 14 Apr 2018 23:11:20 +0200 Subject: [PATCH] Added global step + exponential decay --- network.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/network.py b/network.py index 2722f6a..9b08e06 100644 --- a/network.py +++ b/network.py @@ -36,7 +36,14 @@ class Network: self.output_size = 1 self.hidden_size = 40 # Can't remember the best learning_rate, look this up - self.learning_rate = 0.01 + self.max_learning_rate = 0.1 + self.min_learning_rate = 0.001 + # self.learning_rate = 0.01 + + self.global_step = tf.Variable(0, trainable=False, name="global_step") + self.learning_rate = tf.maximum(self.min_learning_rate, tf.train.exponential_decay(self.max_learning_rate, self.global_step, 50000, 0.96, staircase=True), name="alpha") + + # Restore trained episode count for model episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained") @@ -80,6 +87,8 @@ class Network: apply_gradients = [] + global_step_op = self.global_step.assign_add(1) + with tf.variable_scope('apply_gradients'): for gradient, trainable_var in zip(gradients, trainable_vars): # Hopefully this is Δw_t = α(V_t+1 - V_t)▿_wV_t. @@ -128,8 +137,8 @@ class Network: return sess.run(self.value, feed_dict={self.x: state}) - def save_model(self, sess, episode_count): - self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt')) + def save_model(self, sess, episode_count, global_step): + self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt'), global_step=global_step) with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f: print("[NETWK] ({name}) Saving model to:".format(name=self.name), os.path.join(self.checkpoint_path, 'model.ckpt')) @@ -396,7 +405,7 @@ class Network: with tf.name_scope("final"): merged = tf.summary.merge_all() - summary, _ = sess.run([merged, self.training_op], + summary, _, global_step = sess.run([merged, self.training_op, self.global_step], feed_dict={self.x: self.board_trans_func(prev_board, player), self.value_next: scaled_final_score.reshape((1, 1))}) writer.add_summary(summary, episode + trained_eps) @@ -405,13 +414,13 @@ class Network: if episode % min(save_step_size, episodes) == 0: sys.stderr.write("[TRAIN] Saving model...\n") - self.save_model(sess, episode + trained_eps) + self.save_model(sess, episode + trained_eps, global_step) if episode % 50 == 0: print_time_estimate(episode) sys.stderr.write("[TRAIN] Saving model for final episode...\n") - self.save_model(sess, episode+trained_eps) + self.save_model(sess, episode+trained_eps, global_step=global_step) writer.close()