fix and clean

This commit is contained in:
Alexander Munch-Hansen 2018-05-18 14:55:10 +02:00
parent 3e379b40c4
commit 816cdfae00
2 changed files with 35 additions and 173 deletions

View File

@ -93,7 +93,7 @@ class Network:
:param decay_steps: The amount of steps between each decay :param decay_steps: The amount of steps between each decay
:return: The result of the exponential decay performed on the learning rate :return: The result of the exponential decay performed on the learning rate
""" """
res = max_lr * decay_rate**(global_step // decay_steps) res = max_lr * decay_rate ** (global_step // decay_steps)
return res return res
def do_backprop(self, prev_state, value_next): def do_backprop(self, prev_state, value_next):
@ -104,8 +104,8 @@ class Network:
:return: Nothing, the calculation is performed on the model of the network :return: Nothing, the calculation is performed on the model of the network
""" """
self.learning_rate = tf.maximum(self.min_learning_rate, self.learning_rate = tf.maximum(self.min_learning_rate,
self.exp_decay(self.max_learning_rate, self.global_step, 0.96, 50000), self.exp_decay(self.max_learning_rate, self.global_step, 0.96, 50000),
name="learning_rate") name="learning_rate")
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
value = self.model(prev_state.reshape(1,-1)) value = self.model(prev_state.reshape(1,-1))
@ -165,8 +165,7 @@ class Network:
:param states: A number of states. The states have to be transformed before being given to this function. :param states: A number of states. The states have to be transformed before being given to this function.
:return: :return:
""" """
values = self.model.predict_on_batch(states) return self.model.predict_on_batch(states)
return values
def restore_model(self): def restore_model(self):
@ -174,7 +173,6 @@ class Network:
Restore a model for a session, such that a trained model and either be further trained or Restore a model for a session, such that a trained model and either be further trained or
used for evaluation used for evaluation
:param sess: Current session
:return: Nothing. It's a side-effect that a model gets restored for the network. :return: Nothing. It's a side-effect that a model gets restored for the network.
""" """
@ -211,7 +209,6 @@ class Network:
and then picking the best, by using the network to evaluate each state. This is 0-ply, ie. no look-ahead. and then picking the best, by using the network to evaluate each state. This is 0-ply, ie. no look-ahead.
The highest score is picked for the 1-player and the max(1-score) is picked for the -1-player. The highest score is picked for the 1-player and the max(1-score) is picked for the -1-player.
:param sess:
:param board: Current board :param board: Current board
:param roll: Current roll :param roll: Current roll
:param player: Current player :param player: Current player
@ -224,10 +221,9 @@ class Network:
transformed_scores = [x if np.sign(player) > 0 else 1 - x for x in scores] transformed_scores = [x if np.sign(player) > 0 else 1 - x for x in scores]
best_score_idx = np.argmax(np.array(transformed_scores)) best_score_idx = np.argmax(np.array(transformed_scores))
best_move = legal_moves[best_score_idx] best_move, best_score = legal_moves[best_score_idx], scores[best_score_idx]
best_score = scores[best_score_idx]
return [best_move, best_score] return (best_move, best_score)
def make_move_1_ply(self, board, roll, player): def make_move_1_ply(self, board, roll, player):
""" """
@ -237,9 +233,9 @@ class Network:
:param player: :param player:
:return: :return:
""" """
# start = time.time() start = time.time()
best_pair = self.calculate_1_ply(board, roll, player) best_pair = self.calculate_1_ply(board, roll, player)
# print(time.time() - start) print(time.time() - start)
return best_pair return best_pair
@ -248,35 +244,30 @@ class Network:
Find the best move based on a 1-ply look-ahead. First the x best moves are picked from a 0-ply and then Find the best move based on a 1-ply look-ahead. First the x best moves are picked from a 0-ply and then
all moves and scores are found for them. The expected score is then calculated for each of the boards from the all moves and scores are found for them. The expected score is then calculated for each of the boards from the
0-ply. 0-ply.
:param sess:
:param board: :param board:
:param roll: The original roll :param roll: The original roll
:param player: The current player :param player: The current player
:return: Best possible move based on 1-ply look-ahead :return: Best possible move based on 1-ply look-ahead
""" """
# find all legal states from the given board and the given roll # find all legal states from the given board and the given roll
init_legal_states = Board.calculate_legal_states(board, player, roll) init_legal_states = Board.calculate_legal_states(board, player, roll)
legal_states = np.array([self.board_trans_func(state, player)[0] for state in init_legal_states]) legal_states = np.array([self.board_trans_func(state, player)[0] for state in init_legal_states])
scores = self.calc_vals(legal_states) scores = [ score.numpy()
scores = [score.numpy() for score in scores] for score
in self.calc_vals(legal_states) ]
moves_and_scores = list(zip(init_legal_states, scores)) moves_and_scores = list(zip(init_legal_states, scores))
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=(player == 1))
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1) best_boards = [ x[0] for x in sorted_moves_and_scores[:10] ]
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
scores, trans_scores = self.do_ply(best_boards, player) scores, trans_scores = self.do_ply(best_boards, player)
best_score_idx = np.array(trans_scores).argmax() best_score_idx = np.array(trans_scores).argmax()
return [best_boards[best_score_idx], scores[best_score_idx]] return (best_boards[best_score_idx], scores[best_score_idx])
def do_ply(self, boards, player): def do_ply(self, boards, player):
""" """
@ -285,7 +276,6 @@ class Network:
allowing the function to search deeper, which could result in an even larger search space. If we wish allowing the function to search deeper, which could result in an even larger search space. If we wish
to have more than 2-ply, this should be fixed, so we could extend this method to allow for 3-ply. to have more than 2-ply, this should be fixed, so we could extend this method to allow for 3-ply.
:param sess:
:param boards: The boards to try all rolls on :param boards: The boards to try all rolls on
:param player: The player of the previous ply :param player: The player of the previous ply
:return: An array of scores where each index describes one of the boards which was given as param :return: An array of scores where each index describes one of the boards which was given as param
@ -305,11 +295,11 @@ class Network:
length_list = [] length_list = []
test_list = [] test_list = []
# Prepping of data # Prepping of data
start= time.time() start = time.time()
for board in boards: for board in boards:
length = 0 length = 0
for roll in all_rolls: for roll in all_rolls:
all_states = list(Board.calculate_legal_states(board, player*-1, roll)) all_states = Board.calculate_legal_states(board, player*-1, roll)
for state in all_states: for state in all_states:
state = np.array(self.board_trans_func(state, player*-1)[0]) state = np.array(self.board_trans_func(state, player*-1)[0])
test_list.append(state) test_list.append(state)
@ -318,148 +308,21 @@ class Network:
# print(time.time() - start) # print(time.time() - start)
start = time.time() # start = time.time()
all_scores_legit = self.model.predict_on_batch(np.array(test_list)) all_scores = self.model.predict_on_batch(np.array(test_list))
split_scores = [] split_scores = []
from_idx = 0 from_idx = 0
for length in length_list: for length in length_list:
split_scores.append(all_scores_legit[from_idx:from_idx+length]) split_scores.append(all_scores[from_idx:from_idx+length])
from_idx += length from_idx += length
means_splits = [tf.reduce_mean(scores) for scores in split_scores] means_splits = [tf.reduce_mean(scores) for scores in split_scores]
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits] transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
# print(time.time() - start) # print(time.time() - start)
return ([means_splits, transformed_means_splits]) return (means_splits, transformed_means_splits)
def calc_n_ply(self, n_init, sess, board, player, roll):
"""
:param n_init:
:param sess:
:param board:
:param player:
:param roll:
:return:
"""
# find all legal states from the given board and the given roll
init_legal_states = Board.calculate_legal_states(board, player, roll)
# find all values for the above boards
zero_ply_moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in init_legal_states]
# pythons reverse is in place and I can't call [:15] on it, without applying it to an object like so. Fuck.
sorted_moves_and_scores = sorted(zero_ply_moves_and_scores, key=itemgetter(1), reverse=player==1)
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
best_move_score_pair = self.n_ply(n_init, sess, best_boards, player)
return best_move_score_pair
def n_ply(self, n_init, sess, boards_init, player_init):
"""
:param n_init:
:param sess:
:param boards_init:
:param player_init:
:return:
"""
def ply(n, boards, player):
def calculate_possible_states(board):
possible_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
(1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
(2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
(6, 6) ]
# for roll in possible_rolls:
# print(len(Board.calculate_legal_states(board, player, roll)))
return [ Board.calculate_legal_states(board, player, roll)
for roll
in possible_rolls ]
def find_best_state_score(boards):
score_pairs = [ (board, self.eval_state(sess, self.board_trans_func(board, player)))
for board
in boards ]
scores = [ pair[1]
for pair
in score_pairs ]
best_score_pair = score_pairs[np.array(scores).argmax()]
return best_score_pair
def average_score(boards):
return sum(boards)/len(boards)
def average_ply_score(board):
states_for_rolls = calculate_possible_states(board)
best_state_score_for_each_roll = [
find_best_state_score(states)
for states
in states_for_rolls ]
best_score_for_each_roll = [ x[1]
for x
in best_state_score_for_each_roll ]
average_score_var = average_score(best_score_for_each_roll)
return average_score_var
if n == 1:
average_score_pairs = [ (board, average_ply_score(board))
for board
in boards ]
return average_score_pairs
elif n > 1: # n != 1
def average_for_score_pairs(score_pairs):
scores = [ pair[1]
for pair
in score_pairs ]
return sum(scores)/len(scores)
def average_plain(scores):
return sum(scores)/len(scores)
print("+"*20)
print(n)
print(type(boards))
print(boards)
possible_states_for_boards = [
(board, calculate_possible_states(board))
for board
in boards ]
average_score_pairs = [
(inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1 if n == 1 else player))
for inner_board
in inner_boards[1] ]))
for inner_boards
in possible_states_for_boards ]
return average_score_pairs
else:
assert False
if n_init < 1: print("Unexpected argument n = {}".format(n_init)); exit()
boards_with_scores = ply(n_init, boards_init, -1 * player_init)
#print("Boards with scores:",boards_with_scores)
scores = [ ( pair[1] if player_init == 1 else (1 - pair[1]) )
for pair
in boards_with_scores ]
#print("All the scores:",scores)
best_score_pair = boards_with_scores[np.array(scores).argmax()]
return best_score_pair
def eval(self, episode_count, trained_eps = 0): def eval(self, episode_count, trained_eps = 0):
@ -477,7 +340,6 @@ class Network:
""" """
Do the actual evaluation Do the actual evaluation
:param sess:
:param method: Either pubeval or dumbeval :param method: Either pubeval or dumbeval
:param episodes: Amount of episodes to use in the evaluation :param episodes: Amount of episodes to use in the evaluation
:param trained_eps: :param trained_eps:
@ -509,11 +371,9 @@ class Network:
board = Board.initial_state board = Board.initial_state
while Board.outcome(board) is None: while Board.outcome(board) is None:
roll = (random.randrange(1, 7), random.randrange(1, 7)) roll = (random.randrange(1, 7), random.randrange(1, 7))
board = (self.make_move(board, roll, 1))[0] board = (self.make_move(board, roll, 1))[0]
roll = (random.randrange(1, 7), random.randrange(1, 7)) roll = (random.randrange(1, 7), random.randrange(1, 7))
board = Eval.make_pubeval_move(board, -1, roll)[0][0:26] board = Eval.make_pubeval_move(board, -1, roll)[0][0:26]
sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1])) sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))
@ -532,11 +392,9 @@ class Network:
board = Board.initial_state board = Board.initial_state
while Board.outcome(board) is None: while Board.outcome(board) is None:
roll = (random.randrange(1, 7), random.randrange(1, 7)) roll = (random.randrange(1, 7), random.randrange(1, 7))
board = (self.make_move(board, roll, 1))[0] board = (self.make_move(board, roll, 1))[0]
roll = (random.randrange(1, 7), random.randrange(1, 7)) roll = (random.randrange(1, 7), random.randrange(1, 7))
board = Eval.make_dumbeval_move(board, -1, roll)[0][0:26] board = Eval.make_dumbeval_move(board, -1, roll)[0][0:26]
sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1])) sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))

View File

@ -20,21 +20,22 @@ class Player:
sets.append([Board.calculate_legal_states(board, player, [r,0]), r]) sets.append([Board.calculate_legal_states(board, player, [r,0]), r])
total += r total += r
sets.append([Board.calculate_legal_states(board, player, [total,0]), total]) sets.append([Board.calculate_legal_states(board, player, [total,0]), total])
print(sets)
return sets return sets
def tmp_name(self, from_board, to_board, roll, player, total_moves): def tmp_name(self, from_board, to_board, roll, player, total_moves, is_quad = False):
sets = self.calc_move_sets(from_board, roll, player) sets = self.calc_move_sets(from_board, roll, player)
return_board = from_board return_board = from_board
for idx, board_set in enumerate(sets): for idx, board_set in enumerate(sets):
board_set[0] = list(board_set[0]) board_set[0] = list(board_set[0])
print(to_board) # print(to_board)
print(board_set) # print(board_set)
if to_board in board_set[0]: if to_board in board_set[0]:
total_moves -= board_set[1] total_moves -= board_set[1]
# if it's not the sum of the moves # if it's not the sum of the moves
if idx < 2: if idx < (4 if is_quad else 2):
roll[idx] = 0 roll[idx] = 0
else: else:
roll = [0,0] roll = [0,0]
@ -43,8 +44,11 @@ class Player:
return total_moves, roll, return_board return total_moves, roll, return_board
def make_human_move(self, board, roll): def make_human_move(self, board, roll):
total_moves = roll[0] + roll[1] if roll[0] != roll[1] else int(roll[0])*4 is_quad = roll[0] == roll[1]
move = "" total_moves = roll[0] + roll[1] if not is_quad else int(roll[0])*4
if is_quad:
roll = [roll[0]]*4
while total_moves != 0: while total_moves != 0:
while True: while True:
print("You have {roll} left!".format(roll=total_moves)) print("You have {roll} left!".format(roll=total_moves))
@ -60,6 +64,6 @@ class Player:
print("The correct syntax is: 2/5 for a move from index 2 to 5.") print("The correct syntax is: 2/5 for a move from index 2 to 5.")
to_board = Board.apply_moves_to_board(board, self.get_sym(), move) to_board = Board.apply_moves_to_board(board, self.get_sym(), move)
total_moves, roll, board = self.tmp_name(board, to_board, list(roll), self.get_sym(), total_moves) total_moves, roll, board = self.tmp_name(board, to_board, list(roll), self.get_sym(), total_moves, is_quad)
print(Board.pretty(board)) print(Board.pretty(board))
return board return board