fix and clean

parent 3e379b40c4
commit 816cdfae00

network.py (180 changed lines)
@@ -93,7 +93,7 @@ class Network:
         :param decay_steps: The amount of steps between each decay
         :return: The result of the exponential decay performed on the learning rate
         """
-        res = max_lr * decay_rate**(global_step // decay_steps)
+        res = max_lr * decay_rate ** (global_step // decay_steps)
         return res

     def do_backprop(self, prev_state, value_next):
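The reformatted line implements a staircase exponential decay: the integer division means the learning rate only drops once every decay_steps steps and is constant in between. A minimal sketch with made-up values for max_lr, decay_rate and decay_steps (the real ones are defined elsewhere in network.py):

def exp_decay(global_step, max_lr=0.1, decay_rate=0.96, decay_steps=10000):
    # Integer division makes the rate fall in discrete steps rather than
    # continuously; within one window of decay_steps the rate is constant.
    return max_lr * decay_rate ** (global_step // decay_steps)

print(exp_decay(0))      # 0.1
print(exp_decay(9999))   # still 0.1
print(exp_decay(25000))  # 0.1 * 0.96 ** 2 = 0.09216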
@@ -165,8 +165,7 @@ class Network:
         :param states: A number of states. The states have to be transformed before being given to this function.
         :return:
         """
-        values = self.model.predict_on_batch(states)
-        return values
+        return self.model.predict_on_batch(states)


     def restore_model(self):
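The simplification returns the model's batched predictions directly. A minimal sketch of what that call does, assuming a tf.keras model and a made-up 30-feature state encoding (the real input width comes from board_trans_func):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(40, activation="sigmoid", input_shape=(30,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Eight already-transformed states, one row per state.
states = np.random.rand(8, 30).astype("float32")
values = model.predict_on_batch(states)  # shape (8, 1): one value per state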
@@ -174,7 +173,6 @@ class Network:
         Restore a model for a session, such that a trained model can either be further trained or
         used for evaluation

-        :param sess: Current session
         :return: Nothing. It's a side-effect that a model gets restored for the network.
         """

@@ -211,7 +209,6 @@ class Network:
         and then picking the best, by using the network to evaluate each state. This is 0-ply, i.e. no look-ahead.
         The highest score is picked for the 1-player and max(1 - score) is picked for the -1-player.

-        :param sess:
         :param board: Current board
         :param roll: Current roll
         :param player: Current player
@@ -224,10 +221,9 @@ class Network:
         transformed_scores = [x if np.sign(player) > 0 else 1 - x for x in scores]

         best_score_idx = np.argmax(np.array(transformed_scores))
-        best_move = legal_moves[best_score_idx]
-        best_score = scores[best_score_idx]
+        best_move, best_score = legal_moves[best_score_idx], scores[best_score_idx]

-        return [best_move, best_score]
+        return (best_move, best_score)

     def make_move_1_ply(self, board, roll, player):
         """
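The merged assignment and tuple return do not change the 0-ply logic: scores are win estimates for player 1, so for player -1 they are flipped to 1 - x before the argmax. A small self-contained illustration with made-up scores:

import numpy as np

scores = [0.42, 0.71, 0.13]  # hypothetical network outputs for three moves
player = -1

# Player -1 prefers states that are bad for player 1.
transformed_scores = [x if np.sign(player) > 0 else 1 - x for x in scores]
best_score_idx = np.argmax(np.array(transformed_scores))
print(best_score_idx)  # 2: score 0.13 transforms to 0.87, best for player -1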
@@ -237,9 +233,9 @@ class Network:
         :param player:
         :return:
         """
-        # start = time.time()
+        start = time.time()
         best_pair = self.calculate_1_ply(board, roll, player)
-        # print(time.time() - start)
+        print(time.time() - start)
         return best_pair


@@ -248,35 +244,30 @@ class Network:
         Find the best move based on a 1-ply look-ahead. First the x best moves are picked from a 0-ply and then
         all moves and scores are found for them. The expected score is then calculated for each of the boards from the
         0-ply.
-        :param sess:

         :param board:
         :param roll: The original roll
         :param player: The current player
         :return: Best possible move based on 1-ply look-ahead

         """

         # find all legal states from the given board and the given roll
         init_legal_states = Board.calculate_legal_states(board, player, roll)

         legal_states = np.array([self.board_trans_func(state, player)[0] for state in init_legal_states])

-        scores = self.calc_vals(legal_states)
-        scores = [score.numpy() for score in scores]
+        scores = [ score.numpy()
+                   for score
+                   in self.calc_vals(legal_states) ]

         moves_and_scores = list(zip(init_legal_states, scores))

-        sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1)
-
-        best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
-
-
+        sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=(player == 1))
+        best_boards = [ x[0] for x in sorted_moves_and_scores[:10] ]

         scores, trans_scores = self.do_ply(best_boards, player)

         best_score_idx = np.array(trans_scores).argmax()

-        return [best_boards[best_score_idx], scores[best_score_idx]]
+        return (best_boards[best_score_idx], scores[best_score_idx])

     def do_ply(self, boards, player):
         """
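The rewritten sort keeps the same pruning idea: rank all 0-ply moves by score and keep only the ten best boards for the expensive 1-ply expansion. A toy illustration of that step, with made-up (state, score) pairs:

from operator import itemgetter

moves_and_scores = [("a", 0.30), ("b", 0.80), ("c", 0.55)]
player = 1

# Player 1 sorts descending (best first); player -1 sorts ascending.
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=(player == 1))
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
print(best_boards)  # ['b', 'c', 'a'] -- only these survive to the 1-ply pass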
@@ -285,7 +276,6 @@ class Network:
         allowing the function to search deeper, which could result in an even larger search space. If we wish
         to have more than 2-ply, this should be fixed, so we could extend this method to allow for 3-ply.

-        :param sess:
         :param boards: The boards to try all rolls on
         :param player: The player of the previous ply
         :return: An array of scores where each index describes one of the boards which was given as param
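do_ply tries every roll on every board; since (a, b) and (b, a) are the same roll, there are only 21 distinct rolls. A one-liner that reproduces the hard-coded possible_rolls table removed further down, assuming all_rolls in network.py holds exactly these pairs:

from itertools import combinations_with_replacement

all_rolls = list(combinations_with_replacement(range(1, 7), 2))
print(len(all_rolls))  # 21 distinct rolls: (1, 1), (1, 2), ..., (6, 6)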
@@ -305,11 +295,11 @@ class Network:
         length_list = []
         test_list = []
         # Prepping of data
-        start= time.time()
+        start = time.time()
         for board in boards:
             length = 0
             for roll in all_rolls:
-                all_states = list(Board.calculate_legal_states(board, player*-1, roll))
+                all_states = Board.calculate_legal_states(board, player*-1, roll)
                 for state in all_states:
                     state = np.array(self.board_trans_func(state, player*-1)[0])
                     test_list.append(state)
@@ -318,148 +308,21 @@ class Network:

         # print(time.time() - start)

-        start = time.time()
+        # start = time.time()

-        all_scores_legit = self.model.predict_on_batch(np.array(test_list))
+        all_scores = self.model.predict_on_batch(np.array(test_list))

         split_scores = []
         from_idx = 0
         for length in length_list:
-            split_scores.append(all_scores_legit[from_idx:from_idx+length])
+            split_scores.append(all_scores[from_idx:from_idx+length])
             from_idx += length

         means_splits = [tf.reduce_mean(scores) for scores in split_scores]
         transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
         # print(time.time() - start)

-        return ([means_splits, transformed_means_splits])
-
-
-    def calc_n_ply(self, n_init, sess, board, player, roll):
-        """
-        :param n_init:
-        :param sess:
-        :param board:
-        :param player:
-        :param roll:
-        :return:
-        """
-
-        # find all legal states from the given board and the given roll
-        init_legal_states = Board.calculate_legal_states(board, player, roll)
-
-        # find all values for the above boards
-        zero_ply_moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in init_legal_states]
-
-        # pythons reverse is in place and I can't call [:15] on it, without applying it to an object like so. Fuck.
-        sorted_moves_and_scores = sorted(zero_ply_moves_and_scores, key=itemgetter(1), reverse=player==1)
-
-        best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
-
-        best_move_score_pair = self.n_ply(n_init, sess, best_boards, player)
-
-        return best_move_score_pair
-
-
-    def n_ply(self, n_init, sess, boards_init, player_init):
-        """
-        :param n_init:
-        :param sess:
-        :param boards_init:
-        :param player_init:
-        :return:
-        """
-        def ply(n, boards, player):
-            def calculate_possible_states(board):
-                possible_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
-                                   (1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
-                                   (2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
-                                   (4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
-                                   (6, 6) ]
-
-                # for roll in possible_rolls:
-                #     print(len(Board.calculate_legal_states(board, player, roll)))
-
-                return [ Board.calculate_legal_states(board, player, roll)
-                         for roll
-                         in possible_rolls ]
-
-            def find_best_state_score(boards):
-                score_pairs = [ (board, self.eval_state(sess, self.board_trans_func(board, player)))
-                                for board
-                                in boards ]
-                scores = [ pair[1]
-                           for pair
-                           in score_pairs ]
-                best_score_pair = score_pairs[np.array(scores).argmax()]
-
-                return best_score_pair
-
-            def average_score(boards):
-                return sum(boards)/len(boards)
-
-            def average_ply_score(board):
-                states_for_rolls = calculate_possible_states(board)
-
-                best_state_score_for_each_roll = [
-                    find_best_state_score(states)
-                    for states
-                    in states_for_rolls ]
-                best_score_for_each_roll = [ x[1]
-                                             for x
-                                             in best_state_score_for_each_roll ]
-
-                average_score_var = average_score(best_score_for_each_roll)
-                return average_score_var
-
-            if n == 1:
-                average_score_pairs = [ (board, average_ply_score(board))
-                                        for board
-                                        in boards ]
-                return average_score_pairs
-            elif n > 1: # n != 1
-                def average_for_score_pairs(score_pairs):
-                    scores = [ pair[1]
-                               for pair
-                               in score_pairs ]
-                    return sum(scores)/len(scores)
-
-                def average_plain(scores):
-                    return sum(scores)/len(scores)
-
-                print("+"*20)
-                print(n)
-                print(type(boards))
-                print(boards)
-                possible_states_for_boards = [
-                    (board, calculate_possible_states(board))
-                    for board
-                    in boards ]
-
-                average_score_pairs = [
-                    (inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1 if n == 1 else player))
-                                                      for inner_board
-                                                      in inner_boards[1] ]))
-                    for inner_boards
-                    in possible_states_for_boards ]
-
-                return average_score_pairs
-
-            else:
-                assert False
-
-        if n_init < 1: print("Unexpected argument n = {}".format(n_init)); exit()
-
-        boards_with_scores = ply(n_init, boards_init, -1 * player_init)
-        #print("Boards with scores:",boards_with_scores)
-        scores = [ ( pair[1] if player_init == 1 else (1 - pair[1]) )
-                   for pair
-                   in boards_with_scores ]
-        #print("All the scores:",scores)
-        best_score_pair = boards_with_scores[np.array(scores).argmax()]
-        return best_score_pair
+        return (means_splits, transformed_means_splits)


     def eval(self, episode_count, trained_eps = 0):
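The renamed all_scores buffer is one flat batch of predictions; length_list remembers how many successor states each input board contributed, so the scores can be split back per board and averaged into an expected value. A self-contained sketch, using np.mean in place of the tf.reduce_mean the real code calls:

import numpy as np

all_scores = np.array([0.2, 0.4, 0.9, 0.1, 0.5])  # flat batch, made-up values
length_list = [2, 3]  # board 0 had 2 successor states, board 1 had 3

split_scores = []
from_idx = 0
for length in length_list:
    split_scores.append(all_scores[from_idx:from_idx + length])
    from_idx += length

means = [float(np.mean(s)) for s in split_scores]
print(means)  # [0.3, 0.5] -- the expected value for each input board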
@@ -477,7 +340,6 @@ class Network:
         """
         Do the actual evaluation

-        :param sess:
         :param method: Either pubeval or dumbeval
         :param episodes: Amount of episodes to use in the evaluation
         :param trained_eps:
@@ -509,11 +371,9 @@ class Network:
             board = Board.initial_state
             while Board.outcome(board) is None:
                 roll = (random.randrange(1, 7), random.randrange(1, 7))
-
                 board = (self.make_move(board, roll, 1))[0]
-
                 roll = (random.randrange(1, 7), random.randrange(1, 7))

                 board = Eval.make_pubeval_move(board, -1, roll)[0][0:26]

             sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))
@@ -532,11 +392,9 @@ class Network:
             board = Board.initial_state
             while Board.outcome(board) is None:
                 roll = (random.randrange(1, 7), random.randrange(1, 7))
-
                 board = (self.make_move(board, roll, 1))[0]
-
                 roll = (random.randrange(1, 7), random.randrange(1, 7))

                 board = Eval.make_dumbeval_move(board, -1, roll)[0][0:26]

             sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))
player.py (18 changed lines)
@@ -20,21 +20,22 @@ class Player:
             sets.append([Board.calculate_legal_states(board, player, [r,0]), r])
             total += r
         sets.append([Board.calculate_legal_states(board, player, [total,0]), total])
+        print(sets)
         return sets


-    def tmp_name(self, from_board, to_board, roll, player, total_moves):
+    def tmp_name(self, from_board, to_board, roll, player, total_moves, is_quad = False):
         sets = self.calc_move_sets(from_board, roll, player)
         return_board = from_board
         for idx, board_set in enumerate(sets):

             board_set[0] = list(board_set[0])
-            print(to_board)
-            print(board_set)
+            # print(to_board)
+            # print(board_set)
             if to_board in board_set[0]:
                 total_moves -= board_set[1]
                 # if it's not the sum of the moves
-                if idx < 2:
+                if idx < (4 if is_quad else 2):
                     roll[idx] = 0
                 else:
                     roll = [0,0]
@@ -43,8 +44,11 @@ class Player:
         return total_moves, roll, return_board

     def make_human_move(self, board, roll):
-        total_moves = roll[0] + roll[1] if roll[0] != roll[1] else int(roll[0])*4
-        move = ""
+        is_quad = roll[0] == roll[1]
+        total_moves = roll[0] + roll[1] if not is_quad else int(roll[0])*4
+        if is_quad:
+            roll = [roll[0]]*4

         while total_moves != 0:
             while True:
                 print("You have {roll} left!".format(roll=total_moves))
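The new is_quad flag captures the backgammon rule that doubles grant four moves of the same pip value, which is why the roll list is padded to four entries. A small sketch of that expansion, factored into a hypothetical helper for illustration:

def expand_roll(roll):
    # Doubles ("quads") give four moves of the same value.
    is_quad = roll[0] == roll[1]
    total_moves = roll[0] + roll[1] if not is_quad else int(roll[0]) * 4
    if is_quad:
        roll = [roll[0]] * 4
    return total_moves, list(roll)

print(expand_roll((3, 5)))  # (8, [3, 5])
print(expand_roll((6, 6)))  # (24, [6, 6, 6, 6])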
@@ -60,6 +64,6 @@ class Player:
                     print("The correct syntax is: 2/5 for a move from index 2 to 5.")

             to_board = Board.apply_moves_to_board(board, self.get_sym(), move)
-            total_moves, roll, board = self.tmp_name(board, to_board, list(roll), self.get_sym(), total_moves)
+            total_moves, roll, board = self.tmp_name(board, to_board, list(roll), self.get_sym(), total_moves, is_quad)
             print(Board.pretty(board))
         return board