1-ply now works again.
This commit is contained in:
parent
396d5b036d
commit
2d84cd5a0b
42
network.py
42
network.py
|
@ -262,17 +262,15 @@ class Network:
|
|||
|
||||
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1)
|
||||
|
||||
best_boards = [x[0] for x in sorted_moves_and_scores]
|
||||
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
|
||||
|
||||
|
||||
|
||||
self.do_ply(best_boards, player)
|
||||
scores, trans_scores = self.do_ply(best_boards, player)
|
||||
|
||||
best_score_idx = np.array(trans_scores).argmax()
|
||||
|
||||
#best_score_index = np.array(all_rolls_scores).argmax()
|
||||
#best_board = best_fifteen_boards[best_score_index]
|
||||
|
||||
#return [best_board, max(all_rolls_scores)]
|
||||
return [best_boards[best_score_idx], scores[best_score_idx]]
|
||||
|
||||
def do_ply(self, boards, player):
|
||||
"""
|
||||
|
@ -305,31 +303,29 @@ class Network:
|
|||
|
||||
all_rolls = gen_21_rolls()
|
||||
|
||||
all_rolls_scores = []
|
||||
|
||||
start = time.time()
|
||||
|
||||
list_of_moves = []
|
||||
|
||||
for idx, board in enumerate(boards):
|
||||
list_of_moves.append([])
|
||||
for roll in all_rolls:
|
||||
all_states = list(Board.calculate_legal_states(board, player, roll))
|
||||
list_of_moves[idx].append(all_states)
|
||||
|
||||
tmp = []
|
||||
for board in list_of_moves:
|
||||
all_board_moves = []
|
||||
for roll in board:
|
||||
for spec in roll:
|
||||
legal_state = np.array(self.board_trans_func(spec, player)[0])
|
||||
all_board_moves.append(legal_state)
|
||||
tmp.append(np.array(all_board_moves))
|
||||
for roll in all_rolls:
|
||||
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
|
||||
for state in all_states:
|
||||
state = np.array(self.board_trans_func(state, player*-1)[0])
|
||||
all_board_moves.append(state)
|
||||
list_of_moves.append(np.array(all_board_moves))
|
||||
|
||||
# print(tmp)
|
||||
|
||||
for board in tmp:
|
||||
print(self.model.predict_on_batch(board))
|
||||
all_scores = [self.model.predict_on_batch(board) for board in list_of_moves]
|
||||
transformed_scores = [x if player == 1 else (1-x) for x in all_scores]
|
||||
|
||||
scores_means = [tf.reduce_mean(score) for score in all_scores]
|
||||
transformed_means = [tf.reduce_mean(score) for score in transformed_scores]
|
||||
|
||||
|
||||
return ([scores_means, transformed_means])
|
||||
|
||||
|
||||
print(time.time() - start)
|
||||
|
||||
|
|
|
@ -55,4 +55,4 @@ network.print_variables()
|
|||
|
||||
network.save_model(2)
|
||||
|
||||
network.calculate_1_ply(Board.initial_state, [3,2], 1)
|
||||
print(network.calculate_1_ply(Board.initial_state, [3,2], 1))
|
Loading…
Reference in New Issue
Block a user