fix and clean
This commit is contained in:
parent
3e379b40c4
commit
816cdfae00
174
network.py
174
network.py
|
@ -165,8 +165,7 @@ class Network:
|
||||||
:param states: A number of states. The states have to be transformed before being given to this function.
|
:param states: A number of states. The states have to be transformed before being given to this function.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
values = self.model.predict_on_batch(states)
|
return self.model.predict_on_batch(states)
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
def restore_model(self):
|
def restore_model(self):
|
||||||
|
@ -174,7 +173,6 @@ class Network:
|
||||||
Restore a model for a session, such that a trained model and either be further trained or
|
Restore a model for a session, such that a trained model and either be further trained or
|
||||||
used for evaluation
|
used for evaluation
|
||||||
|
|
||||||
:param sess: Current session
|
|
||||||
:return: Nothing. It's a side-effect that a model gets restored for the network.
|
:return: Nothing. It's a side-effect that a model gets restored for the network.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -211,7 +209,6 @@ class Network:
|
||||||
and then picking the best, by using the network to evaluate each state. This is 0-ply, ie. no look-ahead.
|
and then picking the best, by using the network to evaluate each state. This is 0-ply, ie. no look-ahead.
|
||||||
The highest score is picked for the 1-player and the max(1-score) is picked for the -1-player.
|
The highest score is picked for the 1-player and the max(1-score) is picked for the -1-player.
|
||||||
|
|
||||||
:param sess:
|
|
||||||
:param board: Current board
|
:param board: Current board
|
||||||
:param roll: Current roll
|
:param roll: Current roll
|
||||||
:param player: Current player
|
:param player: Current player
|
||||||
|
@ -224,10 +221,9 @@ class Network:
|
||||||
transformed_scores = [x if np.sign(player) > 0 else 1 - x for x in scores]
|
transformed_scores = [x if np.sign(player) > 0 else 1 - x for x in scores]
|
||||||
|
|
||||||
best_score_idx = np.argmax(np.array(transformed_scores))
|
best_score_idx = np.argmax(np.array(transformed_scores))
|
||||||
best_move = legal_moves[best_score_idx]
|
best_move, best_score = legal_moves[best_score_idx], scores[best_score_idx]
|
||||||
best_score = scores[best_score_idx]
|
|
||||||
|
|
||||||
return [best_move, best_score]
|
return (best_move, best_score)
|
||||||
|
|
||||||
def make_move_1_ply(self, board, roll, player):
|
def make_move_1_ply(self, board, roll, player):
|
||||||
"""
|
"""
|
||||||
|
@ -237,9 +233,9 @@ class Network:
|
||||||
:param player:
|
:param player:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# start = time.time()
|
start = time.time()
|
||||||
best_pair = self.calculate_1_ply(board, roll, player)
|
best_pair = self.calculate_1_ply(board, roll, player)
|
||||||
# print(time.time() - start)
|
print(time.time() - start)
|
||||||
return best_pair
|
return best_pair
|
||||||
|
|
||||||
|
|
||||||
|
@ -248,35 +244,30 @@ class Network:
|
||||||
Find the best move based on a 1-ply look-ahead. First the x best moves are picked from a 0-ply and then
|
Find the best move based on a 1-ply look-ahead. First the x best moves are picked from a 0-ply and then
|
||||||
all moves and scores are found for them. The expected score is then calculated for each of the boards from the
|
all moves and scores are found for them. The expected score is then calculated for each of the boards from the
|
||||||
0-ply.
|
0-ply.
|
||||||
:param sess:
|
|
||||||
:param board:
|
:param board:
|
||||||
:param roll: The original roll
|
:param roll: The original roll
|
||||||
:param player: The current player
|
:param player: The current player
|
||||||
:return: Best possible move based on 1-ply look-ahead
|
:return: Best possible move based on 1-ply look-ahead
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# find all legal states from the given board and the given roll
|
# find all legal states from the given board and the given roll
|
||||||
init_legal_states = Board.calculate_legal_states(board, player, roll)
|
init_legal_states = Board.calculate_legal_states(board, player, roll)
|
||||||
|
|
||||||
legal_states = np.array([self.board_trans_func(state, player)[0] for state in init_legal_states])
|
legal_states = np.array([self.board_trans_func(state, player)[0] for state in init_legal_states])
|
||||||
|
|
||||||
scores = self.calc_vals(legal_states)
|
scores = [ score.numpy()
|
||||||
scores = [score.numpy() for score in scores]
|
for score
|
||||||
|
in self.calc_vals(legal_states) ]
|
||||||
|
|
||||||
moves_and_scores = list(zip(init_legal_states, scores))
|
moves_and_scores = list(zip(init_legal_states, scores))
|
||||||
|
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=(player == 1))
|
||||||
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1)
|
|
||||||
|
|
||||||
best_boards = [ x[0] for x in sorted_moves_and_scores[:10] ]
|
best_boards = [ x[0] for x in sorted_moves_and_scores[:10] ]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
scores, trans_scores = self.do_ply(best_boards, player)
|
scores, trans_scores = self.do_ply(best_boards, player)
|
||||||
|
|
||||||
best_score_idx = np.array(trans_scores).argmax()
|
best_score_idx = np.array(trans_scores).argmax()
|
||||||
|
|
||||||
return [best_boards[best_score_idx], scores[best_score_idx]]
|
return (best_boards[best_score_idx], scores[best_score_idx])
|
||||||
|
|
||||||
def do_ply(self, boards, player):
|
def do_ply(self, boards, player):
|
||||||
"""
|
"""
|
||||||
|
@ -285,7 +276,6 @@ class Network:
|
||||||
allowing the function to search deeper, which could result in an even larger search space. If we wish
|
allowing the function to search deeper, which could result in an even larger search space. If we wish
|
||||||
to have more than 2-ply, this should be fixed, so we could extend this method to allow for 3-ply.
|
to have more than 2-ply, this should be fixed, so we could extend this method to allow for 3-ply.
|
||||||
|
|
||||||
:param sess:
|
|
||||||
:param boards: The boards to try all rolls on
|
:param boards: The boards to try all rolls on
|
||||||
:param player: The player of the previous ply
|
:param player: The player of the previous ply
|
||||||
:return: An array of scores where each index describes one of the boards which was given as param
|
:return: An array of scores where each index describes one of the boards which was given as param
|
||||||
|
@ -309,7 +299,7 @@ class Network:
|
||||||
for board in boards:
|
for board in boards:
|
||||||
length = 0
|
length = 0
|
||||||
for roll in all_rolls:
|
for roll in all_rolls:
|
||||||
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
|
all_states = Board.calculate_legal_states(board, player*-1, roll)
|
||||||
for state in all_states:
|
for state in all_states:
|
||||||
state = np.array(self.board_trans_func(state, player*-1)[0])
|
state = np.array(self.board_trans_func(state, player*-1)[0])
|
||||||
test_list.append(state)
|
test_list.append(state)
|
||||||
|
@ -318,148 +308,21 @@ class Network:
|
||||||
|
|
||||||
# print(time.time() - start)
|
# print(time.time() - start)
|
||||||
|
|
||||||
start = time.time()
|
# start = time.time()
|
||||||
|
|
||||||
all_scores_legit = self.model.predict_on_batch(np.array(test_list))
|
all_scores = self.model.predict_on_batch(np.array(test_list))
|
||||||
|
|
||||||
split_scores = []
|
split_scores = []
|
||||||
from_idx = 0
|
from_idx = 0
|
||||||
for length in length_list:
|
for length in length_list:
|
||||||
split_scores.append(all_scores_legit[from_idx:from_idx+length])
|
split_scores.append(all_scores[from_idx:from_idx+length])
|
||||||
from_idx += length
|
from_idx += length
|
||||||
|
|
||||||
means_splits = [tf.reduce_mean(scores) for scores in split_scores]
|
means_splits = [tf.reduce_mean(scores) for scores in split_scores]
|
||||||
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
|
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
|
||||||
# print(time.time() - start)
|
# print(time.time() - start)
|
||||||
|
|
||||||
return ([means_splits, transformed_means_splits])
|
return (means_splits, transformed_means_splits)
|
||||||
|
|
||||||
|
|
||||||
def calc_n_ply(self, n_init, sess, board, player, roll):
|
|
||||||
"""
|
|
||||||
:param n_init:
|
|
||||||
:param sess:
|
|
||||||
:param board:
|
|
||||||
:param player:
|
|
||||||
:param roll:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
# find all legal states from the given board and the given roll
|
|
||||||
init_legal_states = Board.calculate_legal_states(board, player, roll)
|
|
||||||
|
|
||||||
# find all values for the above boards
|
|
||||||
zero_ply_moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in init_legal_states]
|
|
||||||
|
|
||||||
# pythons reverse is in place and I can't call [:15] on it, without applying it to an object like so. Fuck.
|
|
||||||
sorted_moves_and_scores = sorted(zero_ply_moves_and_scores, key=itemgetter(1), reverse=player==1)
|
|
||||||
|
|
||||||
|
|
||||||
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
|
|
||||||
|
|
||||||
best_move_score_pair = self.n_ply(n_init, sess, best_boards, player)
|
|
||||||
|
|
||||||
return best_move_score_pair
|
|
||||||
|
|
||||||
|
|
||||||
def n_ply(self, n_init, sess, boards_init, player_init):
|
|
||||||
"""
|
|
||||||
:param n_init:
|
|
||||||
:param sess:
|
|
||||||
:param boards_init:
|
|
||||||
:param player_init:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
def ply(n, boards, player):
|
|
||||||
def calculate_possible_states(board):
|
|
||||||
possible_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
|
|
||||||
(1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
|
|
||||||
(2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
|
|
||||||
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
|
|
||||||
(6, 6) ]
|
|
||||||
|
|
||||||
# for roll in possible_rolls:
|
|
||||||
# print(len(Board.calculate_legal_states(board, player, roll)))
|
|
||||||
|
|
||||||
return [ Board.calculate_legal_states(board, player, roll)
|
|
||||||
for roll
|
|
||||||
in possible_rolls ]
|
|
||||||
|
|
||||||
def find_best_state_score(boards):
|
|
||||||
score_pairs = [ (board, self.eval_state(sess, self.board_trans_func(board, player)))
|
|
||||||
for board
|
|
||||||
in boards ]
|
|
||||||
scores = [ pair[1]
|
|
||||||
for pair
|
|
||||||
in score_pairs ]
|
|
||||||
best_score_pair = score_pairs[np.array(scores).argmax()]
|
|
||||||
|
|
||||||
return best_score_pair
|
|
||||||
|
|
||||||
def average_score(boards):
|
|
||||||
return sum(boards)/len(boards)
|
|
||||||
|
|
||||||
def average_ply_score(board):
|
|
||||||
states_for_rolls = calculate_possible_states(board)
|
|
||||||
|
|
||||||
best_state_score_for_each_roll = [
|
|
||||||
find_best_state_score(states)
|
|
||||||
for states
|
|
||||||
in states_for_rolls ]
|
|
||||||
best_score_for_each_roll = [ x[1]
|
|
||||||
for x
|
|
||||||
in best_state_score_for_each_roll ]
|
|
||||||
|
|
||||||
average_score_var = average_score(best_score_for_each_roll)
|
|
||||||
return average_score_var
|
|
||||||
|
|
||||||
|
|
||||||
if n == 1:
|
|
||||||
average_score_pairs = [ (board, average_ply_score(board))
|
|
||||||
for board
|
|
||||||
in boards ]
|
|
||||||
return average_score_pairs
|
|
||||||
elif n > 1: # n != 1
|
|
||||||
def average_for_score_pairs(score_pairs):
|
|
||||||
scores = [ pair[1]
|
|
||||||
for pair
|
|
||||||
in score_pairs ]
|
|
||||||
return sum(scores)/len(scores)
|
|
||||||
|
|
||||||
def average_plain(scores):
|
|
||||||
return sum(scores)/len(scores)
|
|
||||||
|
|
||||||
print("+"*20)
|
|
||||||
print(n)
|
|
||||||
print(type(boards))
|
|
||||||
print(boards)
|
|
||||||
possible_states_for_boards = [
|
|
||||||
(board, calculate_possible_states(board))
|
|
||||||
for board
|
|
||||||
in boards ]
|
|
||||||
|
|
||||||
average_score_pairs = [
|
|
||||||
(inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1 if n == 1 else player))
|
|
||||||
for inner_board
|
|
||||||
in inner_boards[1] ]))
|
|
||||||
for inner_boards
|
|
||||||
in possible_states_for_boards ]
|
|
||||||
|
|
||||||
return average_score_pairs
|
|
||||||
|
|
||||||
else:
|
|
||||||
assert False
|
|
||||||
|
|
||||||
if n_init < 1: print("Unexpected argument n = {}".format(n_init)); exit()
|
|
||||||
|
|
||||||
boards_with_scores = ply(n_init, boards_init, -1 * player_init)
|
|
||||||
#print("Boards with scores:",boards_with_scores)
|
|
||||||
scores = [ ( pair[1] if player_init == 1 else (1 - pair[1]) )
|
|
||||||
for pair
|
|
||||||
in boards_with_scores ]
|
|
||||||
#print("All the scores:",scores)
|
|
||||||
best_score_pair = boards_with_scores[np.array(scores).argmax()]
|
|
||||||
return best_score_pair
|
|
||||||
|
|
||||||
|
|
||||||
def eval(self, episode_count, trained_eps = 0):
|
def eval(self, episode_count, trained_eps = 0):
|
||||||
|
@ -477,7 +340,6 @@ class Network:
|
||||||
"""
|
"""
|
||||||
Do the actual evaluation
|
Do the actual evaluation
|
||||||
|
|
||||||
:param sess:
|
|
||||||
:param method: Either pubeval or dumbeval
|
:param method: Either pubeval or dumbeval
|
||||||
:param episodes: Amount of episodes to use in the evaluation
|
:param episodes: Amount of episodes to use in the evaluation
|
||||||
:param trained_eps:
|
:param trained_eps:
|
||||||
|
@ -509,11 +371,9 @@ class Network:
|
||||||
board = Board.initial_state
|
board = Board.initial_state
|
||||||
while Board.outcome(board) is None:
|
while Board.outcome(board) is None:
|
||||||
roll = (random.randrange(1, 7), random.randrange(1, 7))
|
roll = (random.randrange(1, 7), random.randrange(1, 7))
|
||||||
|
|
||||||
board = (self.make_move(board, roll, 1))[0]
|
board = (self.make_move(board, roll, 1))[0]
|
||||||
|
|
||||||
roll = (random.randrange(1, 7), random.randrange(1, 7))
|
roll = (random.randrange(1, 7), random.randrange(1, 7))
|
||||||
|
|
||||||
board = Eval.make_pubeval_move(board, -1, roll)[0][0:26]
|
board = Eval.make_pubeval_move(board, -1, roll)[0][0:26]
|
||||||
|
|
||||||
sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))
|
sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))
|
||||||
|
@ -532,11 +392,9 @@ class Network:
|
||||||
board = Board.initial_state
|
board = Board.initial_state
|
||||||
while Board.outcome(board) is None:
|
while Board.outcome(board) is None:
|
||||||
roll = (random.randrange(1, 7), random.randrange(1, 7))
|
roll = (random.randrange(1, 7), random.randrange(1, 7))
|
||||||
|
|
||||||
board = (self.make_move(board, roll, 1))[0]
|
board = (self.make_move(board, roll, 1))[0]
|
||||||
|
|
||||||
roll = (random.randrange(1, 7), random.randrange(1, 7))
|
roll = (random.randrange(1, 7), random.randrange(1, 7))
|
||||||
|
|
||||||
board = Eval.make_dumbeval_move(board, -1, roll)[0][0:26]
|
board = Eval.make_dumbeval_move(board, -1, roll)[0][0:26]
|
||||||
|
|
||||||
sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))
|
sys.stderr.write("\t outcome {}".format(Board.outcome(board)[1]))
|
||||||
|
|
18
player.py
18
player.py
|
@ -20,21 +20,22 @@ class Player:
|
||||||
sets.append([Board.calculate_legal_states(board, player, [r,0]), r])
|
sets.append([Board.calculate_legal_states(board, player, [r,0]), r])
|
||||||
total += r
|
total += r
|
||||||
sets.append([Board.calculate_legal_states(board, player, [total,0]), total])
|
sets.append([Board.calculate_legal_states(board, player, [total,0]), total])
|
||||||
|
print(sets)
|
||||||
return sets
|
return sets
|
||||||
|
|
||||||
|
|
||||||
def tmp_name(self, from_board, to_board, roll, player, total_moves):
|
def tmp_name(self, from_board, to_board, roll, player, total_moves, is_quad = False):
|
||||||
sets = self.calc_move_sets(from_board, roll, player)
|
sets = self.calc_move_sets(from_board, roll, player)
|
||||||
return_board = from_board
|
return_board = from_board
|
||||||
for idx, board_set in enumerate(sets):
|
for idx, board_set in enumerate(sets):
|
||||||
|
|
||||||
board_set[0] = list(board_set[0])
|
board_set[0] = list(board_set[0])
|
||||||
print(to_board)
|
# print(to_board)
|
||||||
print(board_set)
|
# print(board_set)
|
||||||
if to_board in board_set[0]:
|
if to_board in board_set[0]:
|
||||||
total_moves -= board_set[1]
|
total_moves -= board_set[1]
|
||||||
# if it's not the sum of the moves
|
# if it's not the sum of the moves
|
||||||
if idx < 2:
|
if idx < (4 if is_quad else 2):
|
||||||
roll[idx] = 0
|
roll[idx] = 0
|
||||||
else:
|
else:
|
||||||
roll = [0,0]
|
roll = [0,0]
|
||||||
|
@ -43,8 +44,11 @@ class Player:
|
||||||
return total_moves, roll, return_board
|
return total_moves, roll, return_board
|
||||||
|
|
||||||
def make_human_move(self, board, roll):
|
def make_human_move(self, board, roll):
|
||||||
total_moves = roll[0] + roll[1] if roll[0] != roll[1] else int(roll[0])*4
|
is_quad = roll[0] == roll[1]
|
||||||
move = ""
|
total_moves = roll[0] + roll[1] if not is_quad else int(roll[0])*4
|
||||||
|
if is_quad:
|
||||||
|
roll = [roll[0]]*4
|
||||||
|
|
||||||
while total_moves != 0:
|
while total_moves != 0:
|
||||||
while True:
|
while True:
|
||||||
print("You have {roll} left!".format(roll=total_moves))
|
print("You have {roll} left!".format(roll=total_moves))
|
||||||
|
@ -60,6 +64,6 @@ class Player:
|
||||||
print("The correct syntax is: 2/5 for a move from index 2 to 5.")
|
print("The correct syntax is: 2/5 for a move from index 2 to 5.")
|
||||||
|
|
||||||
to_board = Board.apply_moves_to_board(board, self.get_sym(), move)
|
to_board = Board.apply_moves_to_board(board, self.get_sym(), move)
|
||||||
total_moves, roll, board = self.tmp_name(board, to_board, list(roll), self.get_sym(), total_moves)
|
total_moves, roll, board = self.tmp_name(board, to_board, list(roll), self.get_sym(), total_moves, is_quad)
|
||||||
print(Board.pretty(board))
|
print(Board.pretty(board))
|
||||||
return board
|
return board
|
Loading…
Reference in New Issue
Block a user