Added global step + exponential decay
This commit is contained in:
parent
1d9c94896d
commit
7d29fc02f2
21
network.py
21
network.py
|
@ -36,7 +36,14 @@ class Network:
|
|||
self.output_size = 1
|
||||
self.hidden_size = 40
|
||||
# Can't remember the best learning_rate, look this up
|
||||
self.learning_rate = 0.01
|
||||
self.max_learning_rate = 0.1
|
||||
self.min_learning_rate = 0.001
|
||||
# self.learning_rate = 0.01
|
||||
|
||||
self.global_step = tf.Variable(0, trainable=False, name="global_step")
|
||||
self.learning_rate = tf.maximum(self.min_learning_rate, tf.train.exponential_decay(self.max_learning_rate, self.global_step, 50000, 0.96, staircase=True), name="alpha")
|
||||
|
||||
|
||||
|
||||
# Restore trained episode count for model
|
||||
episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained")
|
||||
|
@ -80,6 +87,8 @@ class Network:
|
|||
|
||||
apply_gradients = []
|
||||
|
||||
global_step_op = self.global_step.assign_add(1)
|
||||
|
||||
with tf.variable_scope('apply_gradients'):
|
||||
for gradient, trainable_var in zip(gradients, trainable_vars):
|
||||
# Hopefully this is Δw_t = α(V_t+1 - V_t)▿_wV_t.
|
||||
|
@ -128,8 +137,8 @@ class Network:
|
|||
|
||||
return sess.run(self.value, feed_dict={self.x: state})
|
||||
|
||||
def save_model(self, sess, episode_count):
|
||||
self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt'))
|
||||
def save_model(self, sess, episode_count, global_step):
|
||||
self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt'), global_step=global_step)
|
||||
with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f:
|
||||
print("[NETWK] ({name}) Saving model to:".format(name=self.name),
|
||||
os.path.join(self.checkpoint_path, 'model.ckpt'))
|
||||
|
@ -396,7 +405,7 @@ class Network:
|
|||
|
||||
with tf.name_scope("final"):
|
||||
merged = tf.summary.merge_all()
|
||||
summary, _ = sess.run([merged, self.training_op],
|
||||
summary, _, global_step = sess.run([merged, self.training_op, self.global_step],
|
||||
feed_dict={self.x: self.board_trans_func(prev_board, player),
|
||||
self.value_next: scaled_final_score.reshape((1, 1))})
|
||||
writer.add_summary(summary, episode + trained_eps)
|
||||
|
@ -405,13 +414,13 @@ class Network:
|
|||
|
||||
if episode % min(save_step_size, episodes) == 0:
|
||||
sys.stderr.write("[TRAIN] Saving model...\n")
|
||||
self.save_model(sess, episode + trained_eps)
|
||||
self.save_model(sess, episode + trained_eps, global_step)
|
||||
|
||||
if episode % 50 == 0:
|
||||
print_time_estimate(episode)
|
||||
|
||||
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
||||
self.save_model(sess, episode+trained_eps)
|
||||
self.save_model(sess, episode+trained_eps, global_step=global_step)
|
||||
|
||||
writer.close()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user