Fixed potential bug in regards to scores in 2-ply calculation.
This commit is contained in:
parent
ea3f05846d
commit
8899c5c2d9
37
network.py
37
network.py
|
@ -113,6 +113,14 @@ class Network:
|
||||||
f.write(str(episode_count) + "\n")
|
f.write(str(episode_count) + "\n")
|
||||||
|
|
||||||
def restore_model(self, sess):
|
def restore_model(self, sess):
|
||||||
|
"""
|
||||||
|
Restore a model for a session, such that a trained model and either be further trained or
|
||||||
|
used for evaluation
|
||||||
|
|
||||||
|
:param sess: Current session
|
||||||
|
:return: Nothing. It's a side-effect that a model gets restored for the network.
|
||||||
|
"""
|
||||||
|
|
||||||
if glob.glob(os.path.join(self.checkpoint_path, 'model.ckpt*.index')):
|
if glob.glob(os.path.join(self.checkpoint_path, 'model.ckpt*.index')):
|
||||||
|
|
||||||
latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
|
latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
|
||||||
|
@ -150,6 +158,17 @@ class Network:
|
||||||
|
|
||||||
|
|
||||||
def make_move(self, sess, board, roll, player):
|
def make_move(self, sess, board, roll, player):
|
||||||
|
"""
|
||||||
|
Find the best move given a board, roll and a player, by finding all possible states one can go to
|
||||||
|
and then picking the best, by using the network to evaluate each state. The highest score is picked
|
||||||
|
for the 1-player and the max(1-score) is picked for the -1-player.
|
||||||
|
|
||||||
|
:param sess:
|
||||||
|
:param board: Current board
|
||||||
|
:param roll: Current roll
|
||||||
|
:param player: Current player
|
||||||
|
:return: A pair of the best state to go to, together with the score of that state
|
||||||
|
"""
|
||||||
legal_moves = Board.calculate_legal_states(board, player, roll)
|
legal_moves = Board.calculate_legal_states(board, player, roll)
|
||||||
moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in legal_moves]
|
moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in legal_moves]
|
||||||
scores = [x[1] if np.sign(player) > 0 else 1-x[1] for x in moves_and_scores]
|
scores = [x[1] if np.sign(player) > 0 else 1-x[1] for x in moves_and_scores]
|
||||||
|
@ -189,7 +208,11 @@ class Network:
|
||||||
|
|
||||||
# pythons reverse is in place and I can't call [:15] on it, without applying it to an object like so. Fuck.
|
# pythons reverse is in place and I can't call [:15] on it, without applying it to an object like so. Fuck.
|
||||||
best_fifteen = sorted(zero_ply_moves_and_scores, key=itemgetter(1))
|
best_fifteen = sorted(zero_ply_moves_and_scores, key=itemgetter(1))
|
||||||
best_fifteen.reverse()
|
|
||||||
|
# They're sorted from smallest to largest, therefore we wan't to reverse if the current player is 1, since
|
||||||
|
# player 1 wishes to maximize. It's not needed for player -1, since that player seeks to minimize.
|
||||||
|
if player == 1:
|
||||||
|
best_fifteen.reverse()
|
||||||
best_fifteen_boards = [x[0] for x in best_fifteen[:15]]
|
best_fifteen_boards = [x[0] for x in best_fifteen[:15]]
|
||||||
|
|
||||||
all_rolls = gen_21_rolls()
|
all_rolls = gen_21_rolls()
|
||||||
|
@ -198,12 +221,13 @@ class Network:
|
||||||
for a_board in best_fifteen_boards:
|
for a_board in best_fifteen_boards:
|
||||||
a_board_scores = []
|
a_board_scores = []
|
||||||
for roll in all_rolls:
|
for roll in all_rolls:
|
||||||
spec_roll_scores = []
|
|
||||||
all_rolls_boards = Board.calculate_legal_states(a_board, player*-1, roll)
|
all_rolls_boards = Board.calculate_legal_states(a_board, player*-1, roll)
|
||||||
|
|
||||||
spec_roll_scores.append(
|
spec_roll_scores = [self.eval_state(sess, self.board_trans_func(new_board, player*-1))
|
||||||
[self.eval_state(sess, self.board_trans_func(new_board, player*-1)) for new_board in all_rolls_boards]
|
for new_board in all_rolls_boards]
|
||||||
)
|
|
||||||
|
# We need 1-score for the -1 player
|
||||||
|
spec_roll_scores = [x if player == 1 else (1-x) for x in spec_roll_scores]
|
||||||
|
|
||||||
best_score = max(spec_roll_scores)
|
best_score = max(spec_roll_scores)
|
||||||
|
|
||||||
|
@ -372,8 +396,7 @@ class Network:
|
||||||
(random.randrange(1, 7), random.randrange(1, 7)),
|
(random.randrange(1, 7), random.randrange(1, 7)),
|
||||||
player)
|
player)
|
||||||
|
|
||||||
# print("The evaluation of the previous state:\n", self.eval_state(sess, self.board_trans_func(prev_board, player)))
|
|
||||||
# print("The evaluation of the current_state:\n", cur_board_value)
|
|
||||||
|
|
||||||
# adjust weights
|
# adjust weights
|
||||||
sess.run(self.training_op,
|
sess.run(self.training_op,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user