make_move now calls n_ply to search deeper and potentially give better moves. It is extremely slow.
This commit is contained in:
Alexander Munch-Hansen 2018-05-02 01:06:23 +02:00
parent 695a3d43db
commit 1db469709a
2 changed files with 46 additions and 18 deletions

View File

@ -157,7 +157,7 @@ class Network:
exit() exit()
def make_move(self, sess, board, roll, player): #def make_move(self, sess, board, roll, player):
""" """
Find the best move given a board, roll and a player, by finding all possible states one can go to Find the best move given a board, roll and a player, by finding all possible states one can go to
and then picking the best, by using the network to evaluate each state. The highest score is picked and then picking the best, by using the network to evaluate each state. The highest score is picked
@ -169,24 +169,28 @@ class Network:
:param player: Current player :param player: Current player
:return: A pair of the best state to go to, together with the score of that state :return: A pair of the best state to go to, together with the score of that state
""" """
legal_moves = Board.calculate_legal_states(board, player, roll) # legal_moves = Board.calculate_legal_states(board, player, roll)
moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in legal_moves] # moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in legal_moves]
scores = [x[1] if np.sign(player) > 0 else 1-x[1] for x in moves_and_scores] # scores = [x[1] if np.sign(player) > 0 else 1-x[1] for x in moves_and_scores]
best_score_index = np.array(scores).argmax() # best_score_index = np.array(scores).argmax()
best_move_pair = moves_and_scores[best_score_index] # best_move_pair = moves_and_scores[best_score_index]
return best_move_pair # return best_move_pair
def make_move(self, sess, board, roll, player, n = 1):
    """
    Pick a move for the given board, roll and player by delegating to the n-ply search.

    :param sess: Session handed through to calc_n_ply
    :param board: Current board
    :param roll: Current roll
    :param player: Current player
    :param n: Search depth forwarded to calc_n_ply (defaults to 1)
    :return: Whatever calc_n_ply returns for these arguments
    """
    # Note the argument order: calc_n_ply takes (n, sess, board, player, roll).
    return self.calc_n_ply(n, sess, board, player, roll)
def calculate_2_ply(self, sess, board, roll, player): def calculate_1_ply(self, sess, board, roll, player):
""" """
Find the best move based on a 2-ply look-ahead. First the best move is found for a single ply and then an Find the best move based on a 1-ply look-ahead. First the best move is found for a single ply and then an
exhaustive search is performed on the best 15 moves from the single ply. exhaustive search is performed on the best 15 moves from the single ply.
:param sess: :param sess:
:param board: :param board:
:param roll: The original roll :param roll: The original roll
:param player: The current player :param player: The current player
:return: Best possible move based on 2-ply look-ahead :return: Best possible move based on 1-ply look-ahead
""" """
@ -205,7 +209,7 @@ class Network:
if player == 1: if player == 1:
best_fifteen.reverse() best_fifteen.reverse()
best_fifteen_boards = [x[0] for x in best_fifteen[:15]] best_fifteen_boards = [x[0] for x in best_fifteen[:10]]
all_rolls_scores = self.do_ply(sess, best_fifteen_boards, player) all_rolls_scores = self.do_ply(sess, best_fifteen_boards, player)
@ -215,6 +219,29 @@ class Network:
return [best_board, max(all_rolls_scores)] return [best_board, max(all_rolls_scores)]
def calc_n_ply(self, n_init, sess, board, player, roll):
    """
    Score every legal state reachable with the given roll, keep the 10 most promising
    ones for the current player and hand those to n_ply for a deeper search.

    :param n_init: Search depth forwarded to n_ply
    :param sess: Session used when evaluating states
    :param board: Board to move from
    :param player: Current player (1 prefers high scores, -1 prefers low scores)
    :param roll: The roll to find legal states for
    :return: The result of n_ply on the shortlisted boards
    """
    # Every state reachable from this board with this roll.
    candidates = Board.calculate_legal_states(board, player, roll)

    # Pair each candidate state with the network's evaluation of it.
    scored = []
    for state in candidates:
        scored.append((state, self.eval_state(sess, self.board_trans_func(state, player))))

    # Ascending by score. Player 1 wants the largest scores first, so the
    # order is flipped for that player; player -1 wants the smallest first.
    scored.sort(key=itemgetter(1))
    ranked = scored[::-1] if player == 1 else scored

    # Only the ten best states are searched further.
    shortlist = [pair[0] for pair in ranked[:10]]

    return self.n_ply(n_init, sess, shortlist, player)
def n_ply(self, n_init, sess, boards_init, player_init): def n_ply(self, n_init, sess, boards_init, player_init):
def ply(n, boards, player): def ply(n, boards, player):
@ -262,7 +289,6 @@ class Network:
if n == 1: if n == 1:
print("blalhlalha")
average_score_pairs = [ (board, average_ply_score(board)) average_score_pairs = [ (board, average_ply_score(board))
for board for board
in boards ] in boards ]
@ -301,12 +327,13 @@ class Network:
if n_init < 1: print("Unexpected argument n = {}".format(n_init)); exit() if n_init < 1: print("Unexpected argument n = {}".format(n_init)); exit()
boards_with_scores = ply(n_init, boards_init, -1 * player_init) boards_with_scores = ply(n_init, boards_init, -1 * player_init)
print(boards_with_scores) #print("Boards with scores:",boards_with_scores)
scores = [ ( pair[1] if player_init == 1 else (1 - pair[1]) ) scores = [ ( pair[1] if player_init == 1 else (1 - pair[1]) )
for pair for pair
in boards_with_scores ] in boards_with_scores ]
#print("All the scores:",scores)
best_score_pair = boards_with_scores[np.array(scores).argmax()] best_score_pair = boards_with_scores[np.array(scores).argmax()]
return best_score_pair[0] return best_score_pair
def do_ply(self, sess, boards, player): def do_ply(self, sess, boards, player):
""" """

View File

@ -80,13 +80,14 @@ def calculate_possible_states(board):
#for board in boards: #for board in boards:
# calculate_possible_states(board) # calculate_possible_states(board)
print("-"*30) #print("-"*30)
print(network.do_ply(session, boards, 1)) #print(network.calculate_1_ply(session, Board.initial_state, [2,4], 1))
#print(" "*10 + "network_test") #print(" "*10 + "network_test")
#print(" "*20 + "Depth 1") print(" "*20 + "Depth 1")
scores = network.n_ply(1, session, boards, 1) print(network.calc_n_ply(2, session, Board.initial_state, 1, [2, 4]))
#print(scores)
#print(" "*20 + "Depth 2") #print(" "*20 + "Depth 2")
#print(network.n_ply(2, session, boards, 1)) #print(network.n_ply(2, session, boards, 1))