Added a lot of comments
This commit is contained in:
parent f2a67ca92e
commit 4efb229d34
network.py: 84 changed lines
@@ -19,13 +19,18 @@ class Network:
         'quack-fat' : (30, Board.board_features_quack_fat),
         'quack' : (28, Board.board_features_quack),
         'tesauro' : (198, Board.board_features_tesauro),
-        'quack-norm': (30, Board.board_features_quack_norm)
+        'quack-norm' : (30, Board.board_features_quack_norm),
+        'tesauro-poop': (198, Board.board_features_tesauro_wrong)
     }

     def custom_tanh(self, x, name=None):
         return tf.scalar_mul(tf.constant(2.00), tf.tanh(x, name))

     def __init__(self, config, name):
+        """
+        :param config: The configuration to use for the network
+        :param name: The name of the network
+        """
         tf.enable_eager_execution()

         xavier_init = tf.contrib.layers.xavier_initializer()
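The custom_tanh helper above simply doubles the range of tf.tanh. A minimal NumPy sketch, assuming the doubled (-2, 2) range is meant to span the backgammon outcome scale in {-2, -1, 1, 2} used later in the file:

    import numpy as np

    # Equivalent of custom_tanh above: scale tanh by 2 so the output
    # covers (-2, 2) rather than (-1, 1).
    def custom_tanh(x):
        return 2.0 * np.tanh(x)

    print(custom_tanh(np.array([-5.0, 0.0, 5.0])))  # ~[-2., 0., 2.]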
@@ -44,7 +49,6 @@ class Network:
         self.max_learning_rate = 0.1
         self.min_learning_rate = 0.001

-        #tf.train.get_or_create_global_step()
         # Restore trained episode count for model
         episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained")
         if os.path.isfile(episode_count_path):
@@ -61,7 +65,6 @@ class Network:
             self.global_step = 0


-
         self.model = tf.keras.Sequential([
             tf.keras.layers.Dense(40, activation="sigmoid", kernel_initializer=xavier_init,
                                   input_shape=(1,self.input_size)),
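For reference, a self-contained sketch of the layer this hunk shows, written against the modern tf.keras API. The 40-unit sigmoid hidden layer is taken from the diff; the input size of 30 (the 'quack-fat' representation) and the single-unit value head are assumptions, since the rest of the Sequential block falls outside the hunk:

    import tensorflow as tf

    input_size = 30  # assumed: the 'quack-fat' board representation

    model = tf.keras.Sequential([
        # 40 sigmoid units; glorot_uniform is the Keras name for the
        # Xavier initializer used in the diff
        tf.keras.layers.Dense(40, activation="sigmoid",
                              kernel_initializer="glorot_uniform",
                              input_shape=(1, input_size)),
        # assumed single-unit value head; not shown in the hunk
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.summary()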
@@ -69,19 +72,29 @@ class Network:
         ])




-    def exp_decay(self, max_lr, epi_counter, decay_rate, decay_steps):
-        res = max_lr * decay_rate**(epi_counter // decay_steps)
+    def exp_decay(self, max_lr, global_step, decay_rate, decay_steps):
+        """
+        Calculates the exponential decay of a learning rate
+        :param max_lr: The learning rate that the network starts at
+        :param global_step: The global step of the training
+        :param decay_rate: The rate at which the learning rate should decay
+        :param decay_steps: The number of steps between each decay
+        :return: The result of the exponential decay applied to the learning rate
+        """
+        res = max_lr * decay_rate**(global_step // decay_steps)
         return res

     def do_backprop(self, prev_state, value_next):
-
+        """
+        Performs the temporal-difference backpropagation step on the model
+        :param prev_state: The previous state of the game; its value is recalculated
+        :param value_next: The value of the current move
+        :return: Nothing; the update is applied directly to the model of the network
+        """
         self.learning_rate = tf.maximum(self.min_learning_rate,
                                         self.exp_decay(self.max_learning_rate, self.global_step, 0.96, 50000),
                                         name="learning_rate")


         with tf.GradientTape() as tape:
             value = self.model(prev_state.reshape(1,-1))
         grads = tape.gradient(value, self.model.variables)
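The staircase behaviour of exp_decay, and the clamping against min_learning_rate done in do_backprop, can be checked in plain Python with the constants from the diff (0.1 starting rate, 0.96 decay every 50000 steps, 0.001 floor):

    def exp_decay(max_lr, global_step, decay_rate, decay_steps):
        # integer division makes the rate drop in steps, not continuously
        return max_lr * decay_rate ** (global_step // decay_steps)

    for step in (0, 49999, 50000, 150000):
        lr = max(0.001, exp_decay(0.1, step, 0.96, 50000))
        print(step, lr)
    # 0 -> 0.1, 49999 -> 0.1 (same staircase step), 50000 -> 0.096,
    # 150000 -> 0.1 * 0.96**3 ~= 0.0885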
@@ -89,8 +102,6 @@ class Network:
         difference_in_values = tf.reshape(tf.subtract(value_next, value, name='difference_in_values'), [])
         tf.summary.scalar("difference_in_values", tf.abs(difference_in_values))

-        # global_step_op = self.global_step.assign_add(1)
-
         with tf.variable_scope('apply_gradients'):
             for grad, train_var in zip(grads, self.model.variables):
                 backprop_calc = self.learning_rate * difference_in_values * grad
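The loop in apply_gradients implements a TD(0)-style update: each variable moves along the gradient of the predicted value, scaled by the learning rate and by the error (value_next - value). A minimal NumPy sketch of the same arithmetic, with illustrative names:

    import numpy as np

    def td_update(weights, grads, learning_rate, value_next, value):
        # backprop_calc in the diff: learning_rate * difference_in_values * grad
        difference_in_values = value_next - value
        return [w + learning_rate * difference_in_values * g
                for w, g in zip(weights, grads)]

    w = [np.ones(3)]
    g = [np.array([0.5, -0.5, 0.0])]
    print(td_update(w, g, 0.1, 1.0, 0.75))  # small step toward the TD target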
@@ -99,16 +110,25 @@ class Network:


     def print_variables(self):
+        """
+        Prints all the variables of the model
+        :return: Nothing; the variables are written to standard output
+        """
         variables = self.model.variables

         for k in variables:
             print(k)

     def eval_state(self, state):
+        """
+        Evaluates a single state
+        :param state: The state to evaluate
+        :return: The value of the state as predicted by the model
+        """
         return self.model(state.reshape(1,-1))

     def save_model(self, episode_count):
+        """
+        Saves the model of the network; it references the global step as self.global_step
+        :param episode_count: The number of episodes trained so far
+        :return: Nothing; the model is written to the checkpoint path
+        """
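eval_state reshapes a flat feature vector into a batch of one before calling the model, matching the (1, input_size) input shape declared in __init__. A shape-only illustration (the feature size of 30 is an assumption):

    import numpy as np

    state = np.zeros(30)          # one flat feature vector (size assumed)
    batch = state.reshape(1, -1)  # shape (1, 30): a batch holding one state
    print(batch.shape)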
@@ -128,6 +148,10 @@ class Network:


     def calc_vals(self, states):
+        """
+        :param states: The states to calculate values for
+        :return: The values predicted for the given states
+        """
         values = self.model.predict_on_batch(states)
         return values

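calc_vals takes the batched route instead: many feature vectors stacked into one array are scored in a single predict_on_batch call. A sketch of the expected input layout, again assuming 30-element feature vectors:

    import numpy as np

    # five states, one row each; model.predict_on_batch(states) would
    # return one value per row
    states = np.stack([np.zeros(30) for _ in range(5)])
    print(states.shape)  # (5, 30)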
@@ -195,6 +219,15 @@ class Network:
         return [best_move, best_score]

     def make_move_n_ply(self, sess, board, roll, player, n = 1):
+        """
+        Finds the best move based on an n-ply look-ahead
+        :param sess: The TensorFlow session
+        :param board: The current board
+        :param roll: The dice roll
+        :param player: The current player
+        :param n: The number of plies to look ahead
+        :return: A pair of the best move and its score
+        """
         best_pair = self.calc_n_ply(n, sess, board, player, roll)
         return best_pair

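make_move_n_ply delegates to calc_n_ply; the underlying selection, at one ply, is just "score every legal successor and keep the best". A sketch with illustrative stand-ins (legal_states for Board.calculate_legal_states, evaluate for the network's value function):

    def best_move(legal_states, evaluate):
        # pair every legal successor board with its predicted value
        scored = [(board, evaluate(board)) for board in legal_states]
        best_board, best_score = max(scored, key=lambda pair: pair[1])
        return [best_board, best_score]

    print(best_move(["a", "b", "c"], {"a": 0.2, "b": 0.9, "c": 0.5}.get))
    # ['b', 0.9]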
@@ -232,6 +265,15 @@ class Network:
         return [best_board, max(all_rolls_scores)]

     def calc_n_ply(self, n_init, sess, board, player, roll):
+        """
+        Calculates the best move by looking n_init plies ahead
+        :param n_init: The initial depth of the look-ahead
+        :param sess: The TensorFlow session
+        :param board: The current board
+        :param player: The current player
+        :param roll: The dice roll
+        :return: A pair of the best board and its score
+        """

         # find all legal states from the given board and the given roll
         init_legal_states = Board.calculate_legal_states(board, player, roll)
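The name all_rolls_scores in the previous hunk suggests that deeper plies aggregate the value over the opponent's possible rolls. A sketch of that expectation with the standard two-dice weights (doubles occur with probability 1/36, mixed rolls 2/36); whether the code applies these exact weights is not visible in the hunk, and value_after_roll is an illustrative stand-in:

    def expected_value(board, possible_rolls, value_after_roll):
        total = 0.0
        for roll in possible_rolls:
            # 6 doubles at 1/36 plus 15 mixed rolls at 2/36 sum to 1
            weight = 1/36 if roll[0] == roll[1] else 2/36
            total += weight * value_after_roll(board, roll)
        return total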
@@ -251,6 +293,14 @@ class Network:


     def n_ply(self, n_init, sess, boards_init, player_init):
+        """
+        Performs an n-ply look-ahead from the given boards
+        :param n_init: The initial depth of the look-ahead
+        :param sess: The TensorFlow session
+        :param boards_init: The initial boards to search from
+        :param player_init: The player to move first
+        :return:
+        """
         def ply(n, boards, player):
             def calculate_possible_states(board):
                 possible_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
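The possible_rolls literal enumerates the 21 distinct outcomes of two dice (order is irrelevant for a backgammon roll). The same table can be generated:

    from itertools import combinations_with_replacement

    possible_rolls = list(combinations_with_replacement(range(1, 7), 2))
    print(len(possible_rolls))   # 21
    print(possible_rolls[:5])    # [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5)]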
@@ -504,6 +554,13 @@ class Network:


     def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
+        """
+        Trains the model for a given number of episodes
+        :param episodes: The number of episodes to train for
+        :param save_step_size: The number of episodes between each save of the model
+        :param trained_eps: The number of episodes the model has already been trained for
+        :return:
+        """
         with tf.Session() as sess:
             difference_in_vals = 0

@@ -563,11 +620,8 @@ class Network:
         final_score = np.array([Board.outcome(final_board)[1]])
         scaled_final_score = ((final_score + 2) / 4)

-
         self.do_backprop(self.board_trans_func(prev_board, player), scaled_final_score.reshape(1,1))

-
-
         sys.stderr.write("\n")

         if episode % min(save_step_size, episodes) == 0:
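The (final_score + 2) / 4 transform assumes the raw outcome lies in [-2, 2] (for example -2, -1, 1, 2 for lost/won gammons and single games) and maps it into [0, 1], the output range of a sigmoid value head:

    import numpy as np

    for final_score in (-2, -1, 1, 2):
        scaled = (np.array([final_score]) + 2) / 4
        print(final_score, scaled.reshape(1, 1))
    # -2 -> 0.0, -1 -> 0.25, 1 -> 0.75, 2 -> 1.0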