Added global step + exponential decay

This commit is contained in:
Alexander Munch-Hansen 2018-04-14 23:11:20 +02:00
parent 1d9c94896d
commit 7d29fc02f2

View File

@ -36,7 +36,14 @@ class Network:
self.output_size = 1
self.hidden_size = 40
# Can't remember the best learning_rate, look this up
self.learning_rate = 0.01
self.max_learning_rate = 0.1
self.min_learning_rate = 0.001
# self.learning_rate = 0.01
self.global_step = tf.Variable(0, trainable=False, name="global_step")
self.learning_rate = tf.maximum(self.min_learning_rate, tf.train.exponential_decay(self.max_learning_rate, self.global_step, 50000, 0.96, staircase=True), name="alpha")
# Restore trained episode count for model
episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained")
@ -80,6 +87,8 @@ class Network:
apply_gradients = []
global_step_op = self.global_step.assign_add(1)
with tf.variable_scope('apply_gradients'):
for gradient, trainable_var in zip(gradients, trainable_vars):
# Hopefully this is Δw_t = α(V_t+1 - V_t)▿_wV_t.
@ -128,8 +137,8 @@ class Network:
return sess.run(self.value, feed_dict={self.x: state})
def save_model(self, sess, episode_count):
self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt'))
def save_model(self, sess, episode_count, global_step):
self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt'), global_step=global_step)
with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f:
print("[NETWK] ({name}) Saving model to:".format(name=self.name),
os.path.join(self.checkpoint_path, 'model.ckpt'))
@ -396,7 +405,7 @@ class Network:
with tf.name_scope("final"):
merged = tf.summary.merge_all()
summary, _ = sess.run([merged, self.training_op],
summary, _, global_step = sess.run([merged, self.training_op, self.global_step],
feed_dict={self.x: self.board_trans_func(prev_board, player),
self.value_next: scaled_final_score.reshape((1, 1))})
writer.add_summary(summary, episode + trained_eps)
@ -405,13 +414,13 @@ class Network:
if episode % min(save_step_size, episodes) == 0:
sys.stderr.write("[TRAIN] Saving model...\n")
self.save_model(sess, episode + trained_eps)
self.save_model(sess, episode + trained_eps, global_step)
if episode % 50 == 0:
print_time_estimate(episode)
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
self.save_model(sess, episode+trained_eps)
self.save_model(sess, episode+trained_eps, global_step=global_step)
writer.close()