From 2d84cd5a0b80c3ff46e94bcd08c7fc9dab9176e1 Mon Sep 17 00:00:00 2001 From: Alexander Munch-Hansen Date: Thu, 10 May 2018 19:06:53 +0200 Subject: [PATCH] 1-ply now works again. --- network.py | 42 +++++++++++++++++++----------------------- network_test.py | 2 +- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/network.py b/network.py index d84036f..657b924 100644 --- a/network.py +++ b/network.py @@ -262,17 +262,15 @@ class Network: sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1) - best_boards = [x[0] for x in sorted_moves_and_scores] + best_boards = [x[0] for x in sorted_moves_and_scores[:10]] - self.do_ply(best_boards, player) + scores, trans_scores = self.do_ply(best_boards, player) + best_score_idx = np.array(trans_scores).argmax() - #best_score_index = np.array(all_rolls_scores).argmax() - #best_board = best_fifteen_boards[best_score_index] - - #return [best_board, max(all_rolls_scores)] + return [best_boards[best_score_idx], scores[best_score_idx]] def do_ply(self, boards, player): """ @@ -305,31 +303,29 @@ class Network: all_rolls = gen_21_rolls() - all_rolls_scores = [] - start = time.time() list_of_moves = [] for idx, board in enumerate(boards): - list_of_moves.append([]) - for roll in all_rolls: - all_states = list(Board.calculate_legal_states(board, player, roll)) - list_of_moves[idx].append(all_states) - - tmp = [] - for board in list_of_moves: all_board_moves = [] - for roll in board: - for spec in roll: - legal_state = np.array(self.board_trans_func(spec, player)[0]) - all_board_moves.append(legal_state) - tmp.append(np.array(all_board_moves)) + for roll in all_rolls: + all_states = list(Board.calculate_legal_states(board, player*-1, roll)) + for state in all_states: + state = np.array(self.board_trans_func(state, player*-1)[0]) + all_board_moves.append(state) + list_of_moves.append(np.array(all_board_moves)) - # print(tmp) - for board in tmp: - print(self.model.predict_on_batch(board)) + all_scores = [self.model.predict_on_batch(board) for board in list_of_moves] + transformed_scores = [x if player == 1 else (1-x) for x in all_scores] + + scores_means = [tf.reduce_mean(score) for score in all_scores] + transformed_means = [tf.reduce_mean(score) for score in transformed_scores] + + + return ([scores_means, transformed_means]) + print(time.time() - start) diff --git a/network_test.py b/network_test.py index 243d2df..a4d8dda 100644 --- a/network_test.py +++ b/network_test.py @@ -55,4 +55,4 @@ network.print_variables() network.save_model(2) -network.calculate_1_ply(Board.initial_state, [3,2], 1) \ No newline at end of file +print(network.calculate_1_ply(Board.initial_state, [3,2], 1)) \ No newline at end of file