diff --git a/network.py b/network.py index 76a21f8..006b3e9 100644 --- a/network.py +++ b/network.py @@ -113,6 +113,14 @@ class Network: f.write(str(episode_count) + "\n") def restore_model(self, sess): + """ + Restore a model for a session, such that a trained model and either be further trained or + used for evaluation + + :param sess: Current session + :return: Nothing. It's a side-effect that a model gets restored for the network. + """ + if glob.glob(os.path.join(self.checkpoint_path, 'model.ckpt*.index')): latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path) @@ -150,6 +158,17 @@ class Network: def make_move(self, sess, board, roll, player): + """ + Find the best move given a board, roll and a player, by finding all possible states one can go to + and then picking the best, by using the network to evaluate each state. The highest score is picked + for the 1-player and the max(1-score) is picked for the -1-player. + + :param sess: + :param board: Current board + :param roll: Current roll + :param player: Current player + :return: A pair of the best state to go to, together with the score of that state + """ legal_moves = Board.calculate_legal_states(board, player, roll) moves_and_scores = [(move, self.eval_state(sess, self.board_trans_func(move, player))) for move in legal_moves] scores = [x[1] if np.sign(player) > 0 else 1-x[1] for x in moves_and_scores] @@ -189,7 +208,11 @@ class Network: # pythons reverse is in place and I can't call [:15] on it, without applying it to an object like so. Fuck. best_fifteen = sorted(zero_ply_moves_and_scores, key=itemgetter(1)) - best_fifteen.reverse() + + # They're sorted from smallest to largest, therefore we wan't to reverse if the current player is 1, since + # player 1 wishes to maximize. It's not needed for player -1, since that player seeks to minimize. + if player == 1: + best_fifteen.reverse() best_fifteen_boards = [x[0] for x in best_fifteen[:15]] all_rolls = gen_21_rolls() @@ -198,12 +221,13 @@ class Network: for a_board in best_fifteen_boards: a_board_scores = [] for roll in all_rolls: - spec_roll_scores = [] all_rolls_boards = Board.calculate_legal_states(a_board, player*-1, roll) - spec_roll_scores.append( - [self.eval_state(sess, self.board_trans_func(new_board, player*-1)) for new_board in all_rolls_boards] - ) + spec_roll_scores = [self.eval_state(sess, self.board_trans_func(new_board, player*-1)) + for new_board in all_rolls_boards] + + # We need 1-score for the -1 player + spec_roll_scores = [x if player == 1 else (1-x) for x in spec_roll_scores] best_score = max(spec_roll_scores) @@ -372,8 +396,7 @@ class Network: (random.randrange(1, 7), random.randrange(1, 7)), player) - # print("The evaluation of the previous state:\n", self.eval_state(sess, self.board_trans_func(prev_board, player))) - # print("The evaluation of the current_state:\n", cur_board_value) + # adjust weights sess.run(self.training_op,