From 90fad334b917055aab543f03133d3c10ee1751f1 Mon Sep 17 00:00:00 2001 From: Alexander Munch-Hansen Date: Tue, 15 May 2018 23:37:35 +0200 Subject: [PATCH] More optimizations. --- network.py | 48 ++++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/network.py b/network.py index 9405924..d61f458 100644 --- a/network.py +++ b/network.py @@ -292,55 +292,47 @@ class Network: to this function. """ - import time + all_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), + (1, 6), (2, 2), (2, 3), (2, 4), (2, 5), + (2, 6), (3, 3), (3, 4), (3, 5), (3, 6), + (4, 4), (4, 5), (4, 6), (5, 5), (5, 6), + (6, 6) ] - def gen_21_rolls(): - """ - Calculate all possible rolls, [[1,1], [1,2] ..] - :return: All possible rolls - """ - a = [] - for x in range(1, 7): - for y in range(1, 7): - if not [x, y] in a and not [y, x] in a: - a.append([x, y]) - - return a - - all_rolls = gen_21_rolls() # start = time.time() - list_of_moves = [] + # print("/"*50) + length_list = [] test_list = [] # Prepping of data - for idx, board in enumerate(boards): - all_board_moves = [] + start= time.time() + for board in boards: + length = 0 for roll in all_rolls: all_states = list(Board.calculate_legal_states(board, player*-1, roll)) for state in all_states: state = np.array(self.board_trans_func(state, player*-1)[0]) - all_board_moves.append(state) test_list.append(state) - list_of_moves.append(np.array(all_board_moves)) - - - list_of_lengths = [len(board) for board in list_of_moves] + length += 1 + length_list.append(length) + # print(time.time() - start) start = time.time() + all_scores_legit = self.model.predict_on_batch(np.array(test_list)) split_scores = [] from_idx = 0 - for length in list_of_lengths: + for length in length_list: split_scores.append(all_scores_legit[from_idx:from_idx+length]) from_idx += length means_splits = [tf.reduce_mean(scores) for scores in split_scores] transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits] + # print(time.time() - start) - return ([means_splits, transformed_means_splits]) + return ([means_sp5lits, transformed_means_splits]) def calc_n_ply(self, n_init, sess, board, player, roll): @@ -570,6 +562,10 @@ class Network: def play_against_network(self): + """ + Allows you to play against a supplied model. + :return: + """ self.restore_model() human_player = Player(-1) cur_player = 1 @@ -593,7 +589,7 @@ class Network: def train_model(self, episodes=1000, save_step_size=100, trained_eps=0): """ - + Train a model to by self-learning. :param episodes: :param save_step_size: :param trained_eps: