def n_ply(self, n_init, sess, boards_init, player_init):
    """Pick the candidate board with the best n-ply look-ahead score.

    Every candidate board is expanded over the 21 distinct rolls of two
    dice.  At depth 1 a board's score is the mean, over rolls, of the
    highest ``self.eval_state`` value among the legal states for that roll.
    At depth n > 1 a board's score is the mean, over rolls, of the mean
    (n - 1)-ply score of the states reachable with that roll, where the
    deeper ply is computed for the other player (``player * -1``).

    Args:
        n_init: Search depth; must be >= 1.
        sess: Session object, forwarded unchanged to ``self.eval_state``.
        boards_init: Iterable of candidate boards to choose between.
        player_init: Player whose rolls are expanded from the candidates.

    Returns:
        The board from ``boards_init`` with the highest score; the first
        one wins on ties.

    Raises:
        ValueError: If ``n_init`` is less than 1 (or otherwise never reaches
            depth 1), or if ``boards_init`` yields no boards.
    """
    if n_init < 1:
        # Raise instead of print + exit(): exit() would terminate the whole
        # interpreter with a success status on what is a caller error.
        raise ValueError("Unexpected argument n = {}".format(n_init))

    # The 21 distinct unordered rolls, (1, 1), (1, 2), ..., (6, 6), in the
    # same order as the original hand-written table.
    possible_rolls = [(low, high)
                      for low in range(1, 7)
                      for high in range(low, 7)]

    def mean(values):
        """Return the arithmetic mean of a non-empty sequence."""
        return sum(values) / len(values)

    def states_per_roll(board, player):
        """Return the legal successor states of ``board``, one entry per roll."""
        return [Board.calculate_legal_states(board, player, roll)
                for roll in possible_rolls]

    def best_score(states, player):
        """Return the highest network evaluation among ``states``."""
        # NOTE(review): assumes every roll yields at least one legal state;
        # an empty collection makes argmax raise -- confirm against Board.
        scores = [self.eval_state(sess, self.board_trans_func(state, player))
                  for state in states]
        return scores[np.array(scores).argmax()]

    def ply(n, boards, player):
        """Return ``(board, score)`` pairs for ``boards`` searched ``n`` plies deep."""
        if n == 1:
            return [(board,
                     mean([best_score(states, player)
                           for states in states_per_roll(board, player)]))
                    for board in boards]

        if n > 1:
            scored = []
            for board in boards:
                roll_scores = []
                for states in states_per_roll(board, player):
                    # The states reached with this roll are scored one ply
                    # shallower, from the other player's point of view.
                    replies = ply(n - 1, states, player * -1)
                    roll_scores.append(mean([score for _, score in replies]))
                scored.append((board, mean(roll_scores)))
            return scored

        # Only reachable when the depth never hits exactly 1 (for example a
        # fractional n_init).  An explicit raise survives ``python -O``,
        # unlike the former ``assert False``.
        raise ValueError("Unexpected argument n = {}".format(n))

    boards_with_scores = ply(n_init, boards_init, player_init)
    if not boards_with_scores:
        raise ValueError("n_ply needs at least one candidate board")

    scores = [score for _, score in boards_with_scores]
    return boards_with_scores[np.array(scores).argmax()][0]