1-ply now works again.
This commit is contained in:
parent
396d5b036d
commit
2d84cd5a0b
42
network.py
42
network.py
|
@ -262,17 +262,15 @@ class Network:
|
||||||
|
|
||||||
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1)
|
sorted_moves_and_scores = sorted(moves_and_scores, key=itemgetter(1), reverse=player==1)
|
||||||
|
|
||||||
best_boards = [x[0] for x in sorted_moves_and_scores]
|
best_boards = [x[0] for x in sorted_moves_and_scores[:10]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.do_ply(best_boards, player)
|
scores, trans_scores = self.do_ply(best_boards, player)
|
||||||
|
|
||||||
|
best_score_idx = np.array(trans_scores).argmax()
|
||||||
|
|
||||||
#best_score_index = np.array(all_rolls_scores).argmax()
|
return [best_boards[best_score_idx], scores[best_score_idx]]
|
||||||
#best_board = best_fifteen_boards[best_score_index]
|
|
||||||
|
|
||||||
#return [best_board, max(all_rolls_scores)]
|
|
||||||
|
|
||||||
def do_ply(self, boards, player):
|
def do_ply(self, boards, player):
|
||||||
"""
|
"""
|
||||||
|
@ -305,31 +303,29 @@ class Network:
|
||||||
|
|
||||||
all_rolls = gen_21_rolls()
|
all_rolls = gen_21_rolls()
|
||||||
|
|
||||||
all_rolls_scores = []
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
list_of_moves = []
|
list_of_moves = []
|
||||||
|
|
||||||
for idx, board in enumerate(boards):
|
for idx, board in enumerate(boards):
|
||||||
list_of_moves.append([])
|
|
||||||
for roll in all_rolls:
|
|
||||||
all_states = list(Board.calculate_legal_states(board, player, roll))
|
|
||||||
list_of_moves[idx].append(all_states)
|
|
||||||
|
|
||||||
tmp = []
|
|
||||||
for board in list_of_moves:
|
|
||||||
all_board_moves = []
|
all_board_moves = []
|
||||||
for roll in board:
|
for roll in all_rolls:
|
||||||
for spec in roll:
|
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
|
||||||
legal_state = np.array(self.board_trans_func(spec, player)[0])
|
for state in all_states:
|
||||||
all_board_moves.append(legal_state)
|
state = np.array(self.board_trans_func(state, player*-1)[0])
|
||||||
tmp.append(np.array(all_board_moves))
|
all_board_moves.append(state)
|
||||||
|
list_of_moves.append(np.array(all_board_moves))
|
||||||
|
|
||||||
# print(tmp)
|
|
||||||
|
|
||||||
for board in tmp:
|
all_scores = [self.model.predict_on_batch(board) for board in list_of_moves]
|
||||||
print(self.model.predict_on_batch(board))
|
transformed_scores = [x if player == 1 else (1-x) for x in all_scores]
|
||||||
|
|
||||||
|
scores_means = [tf.reduce_mean(score) for score in all_scores]
|
||||||
|
transformed_means = [tf.reduce_mean(score) for score in transformed_scores]
|
||||||
|
|
||||||
|
|
||||||
|
return ([scores_means, transformed_means])
|
||||||
|
|
||||||
|
|
||||||
print(time.time() - start)
|
print(time.time() - start)
|
||||||
|
|
||||||
|
|
|
@ -55,4 +55,4 @@ network.print_variables()
|
||||||
|
|
||||||
network.save_model(2)
|
network.save_model(2)
|
||||||
|
|
||||||
network.calculate_1_ply(Board.initial_state, [3,2], 1)
|
print(network.calculate_1_ply(Board.initial_state, [3,2], 1))
|
Loading…
Reference in New Issue
Block a user