fix and clean

This commit is contained in:
Alexander Munch-Hansen 2018-05-18 14:55:10 +02:00
parent 3e379b40c4
commit 816cdfae00
2 changed files with 35 additions and 173 deletions

View File

@ -93,7 +93,7 @@ class Network:
:param decay_steps: The amount of steps between each decay
:return: The result of the exponential decay performed on the learning rate
"""
res = max_lr * decay_rate**(global_step // decay_steps)
res = max_lr * decay_rate ** (global_step // decay_steps)
return res
def do_backprop(self, prev_state, value_next):
@ -165,8 +165,7 @@ class Network:
:param states: A number of states. The states have to be transformed before being given to this function.
:return:
"""
values = self.model.predict_on_batch(states)
return values
return self.model.predict_on_batch(states)
def restore_model(self):
@ -174,7 +173,6 @@ class Network:
Restore a model for a session, such that a trained model and either be further trained or
used for evaluation
:param sess: Current session
:return: Nothing. It's a side-effect that a model gets restored for the network.
"""
@ -211,7 +209,6 @@ class Network:
and then picking the best, by using the network to evaluate each state. This is 0-ply, ie. no look-ahead.
The highest score is picked for the 1-player and the max(1-score) is picked for the -1-player.
:param sess:
:param board: Current board
:param roll: Current roll
:param player: Current player
@ -224,10 +221,9 @@ class Network:
transformed_scores = [x if np.sign(player) > 0 else 1 - x for x in scores]
best_score_idx = np.argmax(np.array(transformed_scores))
best_move = legal_moves[best_score_idx]
best_score = scores[best_score_idx]
best_move, best_score = legal_moves[best_score_idx], scores[best_score_idx]
return [best_move, best_score]
return (best_move, best_score)
def make_move_1_ply(self, board, roll, player):
"""
@ -237,9 +233,9 @@ class Network:
:param player:
:return:
"""
# start = time.time()
start = time.time()
best_pair = self.calculate_1_ply(board, roll, player)
# print(time.time() - start)
print(time.time() - start)
return best_pair
@ -248,35 +244,30 @@ class Network:
Find the best move based on a 1-ply look-ahead. First the x best moves are picked from a 0-ply and then
all moves and scores are found for them. The expected score is then calculated for each of the boards from the
0-ply.
:param sess:
:param board:
:param roll: The original roll
:param player: The current player
:return: Best possible move based on 1-ply look-ahead
"""
# find all legal states from the given board and the given roll
init_legal_states = Board.calculate_legal_states(board, player, roll)
legal_states = np.array([self.board_trans_func(state, player)[0] for state in init_legal_states])
scores = self.calc_vals(legal_states)
scores = [score.numpy() for score in scores]
scores = [ score.numpy()
for score
in self.calc_vals(legal_states) ]
moves_and_scores = list(zip(init_legal_states, scores))
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1)
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=(player == 1))
best_boards = [ x[0] for x in sorted_moves_and_scores[:10] ]
scores, trans_scores = self.do_ply(best_boards, player)
best_score_idx = np.array(trans_scores).argmax()
return [best_boards[best_score_idx], scores[best_score_idx]]
return (best_boards[best_score_idx], scores[best_score_idx])
def do_ply(self, boards, player):
"""
@ -285,7 +276,6 @@ class Network:
allowing the function to search deeper, which could result in an even larger search space. If we wish
to have more than 2-ply, this should be fixed, so we could extend this method to allow for 3-ply.
:param sess:
:param boards: The boards to try all rolls on
:param player: The player of the previous ply
:return: An array of scores where each index describes one of the boards which was given as param
@ -305,11 +295,11 @@ class Network:
length_list = []
test_list = []
# Prepping of data
start= time.time()
start = time.time()
for board in boards:
length = 0
for roll in all_rolls:
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
all_states = Board.calculate_legal_states(board, player*-1, roll)
for state in all_states:
state = np.array(self.board_trans_func(state, player*-1)[0])
test_list.append(state)
@ -318,148 +308,21 @@ class Network:
# print(time.time() - start)
start = time.time()
# start = time.time()
all_scores_legit = self.model.predict_on_batch(np.array(test_list))
all_scores = self.model.predict_on_batch(np.array(test_list))
split_scores = []
from_idx = 0
for length in length_list:
split_scores.append(all_scores_legit[from_idx:from_idx+length])
split_scores.append(all_scores[from_idx:from_idx+length])
from_idx += length
means_splits = [tf.reduce_mean(scores) for scores in split_scores]
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
# print(time.time() - start)
return ([means_splits, transformed_means_splits])
def calc_n_ply(self, n_init, sess, board, player, roll):
"""
:param n_init:
:param sess:
:param board:
:param player:
:param roll:
:return:
"""
# find all legal states from the given board and the given roll
init_legal_states = Board.calculate_legal_states(board, player, roll)
# find all values for the above boards
zero_ply_moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in init_legal_states]
# pythons reverse is in place and I can't call [:15] on it, without applying it to an object like so. Fuck.
sorted_moves_and_scores = sorted(zero_ply_moves_and_scores, key=itemgetter(1), reverse=player==1)
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
best_move_score_pair = self.n_ply(n_init, sess, best_boards, player)
return best_move_score_pair
def n_ply(self, n_init, sess, boards_init, player_init):
"""
:param n_init:
:param sess:
:param boards_init:
:param player_init:
:return:
"""
def ply(n, boards, player):
def calculate_possible_states(board):
possible_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
(1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
(2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
(6, 6) ]
# for roll in possible_rolls:
# print(len(Board.calculate_legal_states(board, player, roll)))
return [ Board.calculate_legal_states(board, player, roll)
for roll
in possible_rolls ]
def find_best_state_score(boards):
score_pairs = [ (board, self.eval_state(sess, self.board_trans_func(board, player)))
for board
in boards ]
scores = [ pair[1]
for pair
in score_pairs ]
best_score_pair = score_pairs[np.array(scores).argmax()]
return best_score_pair
def average_score(boards):
return sum(boards)/len(boards)
def average_ply_score(board):
states_for_rolls = calculate_possible_states(board)
best_state_score_for_each_roll = [
find_best_state_score(states)
for states
in states_for_rolls ]
best_score_for_each_roll = [ x[1]
for x
in best_state_score_for_each_roll ]
average_score_var = average_score(best_score_for_each_roll)
return average_score_var
if n == 1:
average_score_pairs = [ (board, average_ply_score(board))
for board
in boards ]
return average_score_pairs
elif n > 1: # n != 1
def average_for_score_pairs(score_pairs):
scores = [ pair[1]
for pair
in score_pairs ]
return sum(scores)/len(scores)
def average_plain(scores):
return sum(scores)/len(scores)
print("+"*20)
print(n)
print(type(boards))
print(boards)
possible_states_for_boards = [
(board, calculate_possible_states(board))
for board
in boards ]
average_score_pairs = [
(inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1 if n == 1 else player))
for inner_board
in inner_boards[1] ]))
for inner_boards
in possible_states_for_boards ]
return average_score_pairs
else:
assert False
if n_init < 1: print("Unexpected argument n = {}".format(n_init)); exit()
boards_with_scores = ply(n_init, boards_init, -1 * player_init)
#print("Boards with scores:",boards_with_scores)
scores = [ ( pair[1] if player_init == 1 else (1 - pair[1]) )
for pair
in boards_with_scores ]
#print("All the scores:",scores)
best_score_pair = boards_with_scores[np.array(scores).argmax()]
return best_score_pair
return (means_splits, transformed_means_splits)
def eval(self, episode_count, trained_eps = 0):
@ -477,7 +340,6 @@ class Network:
"""
Do the actual evaluation
:param sess:
:param method: Either pubeval or dumbeval
:param episodes: Amount of episodes to use in the evaluation
:param trained_eps:
@ -509,11 +371,9 @@ class Network:
board = Board.initial_state
while Board.outcome(board) is None:
roll = (random.randrange(1, 7), random.randrange(1, 7))
board = (self.make_move(board, roll, 1))[0]
roll = (random.randrange(1, 7), random.randrange(1, 7))
board = Eval.make_pubeval_move(board, -1, roll)[0][0:26]
sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))
@ -532,11 +392,9 @@ class Network:
board = Board.initial_state
while Board.outcome(board) is None:
roll = (random.randrange(1, 7), random.randrange(1, 7))
board = (self.make_move(board, roll, 1))[0]
roll = (random.randrange(1, 7), random.randrange(1, 7))
board = Eval.make_dumbeval_move(board, -1, roll)[0][0:26]
sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))

View File

@ -20,21 +20,22 @@ class Player:
sets.append([Board.calculate_legal_states(board, player, [r,0]), r])
total += r
sets.append([Board.calculate_legal_states(board, player, [total,0]), total])
print(sets)
return sets
def tmp_name(self, from_board, to_board, roll, player, total_moves):
def tmp_name(self, from_board, to_board, roll, player, total_moves, is_quad = False):
sets = self.calc_move_sets(from_board, roll, player)
return_board = from_board
for idx, board_set in enumerate(sets):
board_set[0] = list(board_set[0])
print(to_board)
print(board_set)
# print(to_board)
# print(board_set)
if to_board in board_set[0]:
total_moves -= board_set[1]
# if it's not the sum of the moves
if idx < 2:
if idx < (4 if is_quad else 2):
roll[idx] = 0
else:
roll = [0,0]
@ -43,8 +44,11 @@ class Player:
return total_moves, roll, return_board
def make_human_move(self, board, roll):
total_moves = roll[0] + roll[1] if roll[0] != roll[1] else int(roll[0])*4
move = ""
is_quad = roll[0] == roll[1]
total_moves = roll[0] + roll[1] if not is_quad else int(roll[0])*4
if is_quad:
roll = [roll[0]]*4
while total_moves != 0:
while True:
print("You have {roll} left!".format(roll=total_moves))
@ -60,6 +64,6 @@ class Player:
print("The correct syntax is: 2/5 for a move from index 2 to 5.")
to_board = Board.apply_moves_to_board(board, self.get_sym(), move)
total_moves, roll, board = self.tmp_name(board, to_board, list(roll), self.get_sym(), total_moves)
total_moves, roll, board = self.tmp_name(board, to_board, list(roll), self.get_sym(), total_moves, is_quad)
print(Board.pretty(board))
return board