1-ply now works again.

This commit is contained in:
Alexander Munch-Hansen 2018-05-10 19:06:53 +02:00
parent 396d5b036d
commit 2d84cd5a0b
2 changed files with 20 additions and 24 deletions

View File

@ -262,17 +262,15 @@ class Network:
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1)
best_boards = [x[0] for x in sorted_moves_and_scores]
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
self.do_ply(best_boards, player)
scores, trans_scores = self.do_ply(best_boards, player)
best_score_idx = np.array(trans_scores).argmax()
#best_score_index = np.array(all_rolls_scores).argmax()
#best_board = best_fifteen_boards[best_score_index]
#return [best_board, max(all_rolls_scores)]
return [best_boards[best_score_idx], scores[best_score_idx]]
def do_ply(self, boards, player):
"""
@ -305,31 +303,29 @@ class Network:
all_rolls = gen_21_rolls()
all_rolls_scores = []
start = time.time()
list_of_moves = []
for idx, board in enumerate(boards):
list_of_moves.append([])
for roll in all_rolls:
all_states = list(Board.calculate_legal_states(board, player, roll))
list_of_moves[idx].append(all_states)
tmp = []
for board in list_of_moves:
all_board_moves = []
for roll in board:
for spec in roll:
legal_state = np.array(self.board_trans_func(spec, player)[0])
all_board_moves.append(legal_state)
tmp.append(np.array(all_board_moves))
for roll in all_rolls:
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
for state in all_states:
state = np.array(self.board_trans_func(state, player*-1)[0])
all_board_moves.append(state)
list_of_moves.append(np.array(all_board_moves))
# print(tmp)
for board in tmp:
print(self.model.predict_on_batch(board))
all_scores = [self.model.predict_on_batch(board) for board in list_of_moves]
transformed_scores = [x if player == 1 else (1-x) for x in all_scores]
scores_means = [tf.reduce_mean(score) for score in all_scores]
transformed_means = [tf.reduce_mean(score) for score in transformed_scores]
return ([scores_means, transformed_means])
print(time.time() - start)

View File

@ -55,4 +55,4 @@ network.print_variables()
network.save_model(2)
network.calculate_1_ply(Board.initial_state, [3,2], 1)
print(network.calculate_1_ply(Board.initial_state, [3,2], 1))