ply again again

This commit is contained in:
Christoffer Müller Madsen 2018-04-26 16:49:49 +02:00 committed by Alexander Munch-Hansen
parent 9428a00c11
commit afa6504b05

View File

@ -211,6 +211,96 @@ class Network:
return [best_board, max(all_rolls_scores)]
def n_ply(self, n_init, sess, boards_init, player_init):
def ply(n, boards, player):
def calculate_possible_states(board):
possible_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
(1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
(2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
(6, 6) ]
return [ Board.calculate_legal_states(board, player, roll)
for roll
in possible_rolls ]
def find_best_state_score(boards):
score_pairs = [ (board, self.eval_state(sess, self.board_trans_func(board, player)))
for board
in boards ]
scores = [ pair[1]
for pair
in score_pairs ]
best_score_pair = score_pairs[np.array(scores).argmax()]
return best_score_pair
def average_score(boards):
return sum(boards)/len(boards)
def average_ply_score(board):
states_for_rolls = calculate_possible_states(board)
best_state_score_for_each_roll = [
find_best_state_score(states)
for states
in states_for_rolls ]
best_score_for_each_roll = [ x[1]
for x
in best_state_score_for_each_roll ]
average_score_var = average_score(best_score_for_each_roll)
return average_score_var
if n == 1:
print("blalhlalha")
average_score_pairs = [ (board, average_ply_score(board))
for board
in boards ]
return average_score_pairs
elif n > 1: # n != 1
def average_for_score_pairs(score_pairs):
scores = [ pair[1]
for pair
in score_pairs ]
return sum(scores)/len(scores)
def average_plain(scores):
return sum(scores)/len(scores)
print("+"*20)
print(n)
print(type(boards))
print(boards)
possible_states_for_boards = [
(board, calculate_possible_states(board))
for board
in boards ]
average_score_pairs = [
(inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1))
for inner_board
in inner_boards[1] ]))
for inner_boards
in possible_states_for_boards ]
return average_score_pairs
else:
assert False
if n_init < 1: print("Unexpected argument n = {}".format(n_init)); exit()
boards_with_scores = ply(n_init, boards_init, player_init)
print(boards_with_scores)
scores = [ pair[1]
for pair
in boards_with_scores ]
best_score_pair = boards_with_scores[np.array(scores).argmax()]
return best_score_pair[0]
def do_ply(self, sess, boards, player):
"""
Calculates a single extra ply, resulting in a larger search space for our best move.