Added global step + exponential decay
This commit is contained in:
parent
1d9c94896d
commit
7d29fc02f2
21
network.py
21
network.py
|
@ -36,7 +36,14 @@ class Network:
|
||||||
self.output_size = 1
|
self.output_size = 1
|
||||||
self.hidden_size = 40
|
self.hidden_size = 40
|
||||||
# Can't remember the best learning_rate, look this up
|
# Can't remember the best learning_rate, look this up
|
||||||
self.learning_rate = 0.01
|
self.max_learning_rate = 0.1
|
||||||
|
self.min_learning_rate = 0.001
|
||||||
|
# self.learning_rate = 0.01
|
||||||
|
|
||||||
|
self.global_step = tf.Variable(0, trainable=False, name="global_step")
|
||||||
|
self.learning_rate = tf.maximum(self.min_learning_rate, tf.train.exponential_decay(self.max_learning_rate, self.global_step, 50000, 0.96, staircase=True), name="alpha")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Restore trained episode count for model
|
# Restore trained episode count for model
|
||||||
episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained")
|
episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained")
|
||||||
|
@ -80,6 +87,8 @@ class Network:
|
||||||
|
|
||||||
apply_gradients = []
|
apply_gradients = []
|
||||||
|
|
||||||
|
global_step_op = self.global_step.assign_add(1)
|
||||||
|
|
||||||
with tf.variable_scope('apply_gradients'):
|
with tf.variable_scope('apply_gradients'):
|
||||||
for gradient, trainable_var in zip(gradients, trainable_vars):
|
for gradient, trainable_var in zip(gradients, trainable_vars):
|
||||||
# Hopefully this is Δw_t = α(V_{t+1} - V_t)∇_w V_t.
|
# Hopefully this is Δw_t = α(V_{t+1} - V_t)∇_w V_t.
|
||||||
|
@ -128,8 +137,8 @@ class Network:
|
||||||
|
|
||||||
return sess.run(self.value, feed_dict={self.x: state})
|
return sess.run(self.value, feed_dict={self.x: state})
|
||||||
|
|
||||||
def save_model(self, sess, episode_count):
|
def save_model(self, sess, episode_count, global_step):
|
||||||
self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt'))
|
self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt'), global_step=global_step)
|
||||||
with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f:
|
with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f:
|
||||||
print("[NETWK] ({name}) Saving model to:".format(name=self.name),
|
print("[NETWK] ({name}) Saving model to:".format(name=self.name),
|
||||||
os.path.join(self.checkpoint_path, 'model.ckpt'))
|
os.path.join(self.checkpoint_path, 'model.ckpt'))
|
||||||
|
@ -396,7 +405,7 @@ class Network:
|
||||||
|
|
||||||
with tf.name_scope("final"):
|
with tf.name_scope("final"):
|
||||||
merged = tf.summary.merge_all()
|
merged = tf.summary.merge_all()
|
||||||
summary, _ = sess.run([merged, self.training_op],
|
summary, _, global_step = sess.run([merged, self.training_op, self.global_step],
|
||||||
feed_dict={self.x: self.board_trans_func(prev_board, player),
|
feed_dict={self.x: self.board_trans_func(prev_board, player),
|
||||||
self.value_next: scaled_final_score.reshape((1, 1))})
|
self.value_next: scaled_final_score.reshape((1, 1))})
|
||||||
writer.add_summary(summary, episode + trained_eps)
|
writer.add_summary(summary, episode + trained_eps)
|
||||||
|
@ -405,13 +414,13 @@ class Network:
|
||||||
|
|
||||||
if episode % min(save_step_size, episodes) == 0:
|
if episode % min(save_step_size, episodes) == 0:
|
||||||
sys.stderr.write("[TRAIN] Saving model...\n")
|
sys.stderr.write("[TRAIN] Saving model...\n")
|
||||||
self.save_model(sess, episode + trained_eps)
|
self.save_model(sess, episode + trained_eps, global_step)
|
||||||
|
|
||||||
if episode % 50 == 0:
|
if episode % 50 == 0:
|
||||||
print_time_estimate(episode)
|
print_time_estimate(episode)
|
||||||
|
|
||||||
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
||||||
self.save_model(sess, episode+trained_eps)
|
self.save_model(sess, episode+trained_eps, global_step=global_step)
|
||||||
|
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user