diff --git a/network.py b/network.py index f299e50..9405924 100644 --- a/network.py +++ b/network.py @@ -327,12 +327,6 @@ class Network: list_of_lengths = [len(board) for board in list_of_moves] - start = time.time() - for i in range(len(test_list)): - self.model.predict_on_batch(np.array([state])) - print("Indiviual rolls:", time.time() - start) - all_scores = [self.model.predict_on_batch(board) for board in list_of_moves] - start = time.time() all_scores_legit = self.model.predict_on_batch(np.array(test_list)) @@ -343,23 +337,10 @@ class Network: split_scores.append(all_scores_legit[from_idx:from_idx+length]) from_idx += length - transformed_splits = [tf.reduce_mean(scores) for scores in split_scores] + means_splits = [tf.reduce_mean(scores) for scores in split_scores] + transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits] - print(transformed_splits) - - - - - print("All in one:", time.time() - start) - - scores_means = [tf.reduce_mean(score) for score in all_scores] - - print(scores_means) - - transformed_means = [x if player == 1 else (1-x) for x in scores_means] - - # print(time.time() - start) - return ([scores_means, transformed_means]) + return ([means_splits, transformed_means_splits]) def calc_n_ply(self, n_init, sess, board, player, roll):