More optimizations.

Alexander Munch-Hansen 2018-05-15 23:37:35 +02:00
parent a77c13a0a4
commit 90fad334b9


@@ -292,55 +292,47 @@ class Network:
         to this function.
         """
-        import time
-        def gen_21_rolls():
-            """
-            Calculate all possible rolls, [[1,1], [1,2] ..]
-            :return: All possible rolls
-            """
-            a = []
-            for x in range(1, 7):
-                for y in range(1, 7):
-                    if not [x, y] in a and not [y, x] in a:
-                        a.append([x, y])
-            return a
-
-        all_rolls = gen_21_rolls()
+        all_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
+                      (1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
+                      (2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
+                      (4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
+                      (6, 6) ]
 
         # start = time.time()
-        list_of_moves = []
+        # print("/"*50)
+        length_list = []
 
         test_list = []
         # Prepping of data
-        for idx, board in enumerate(boards):
-            all_board_moves = []
+        start = time.time()
+        for board in boards:
+            length = 0
             for roll in all_rolls:
                 all_states = list(Board.calculate_legal_states(board, player*-1, roll))
                 for state in all_states:
                     state = np.array(self.board_trans_func(state, player*-1)[0])
-                    all_board_moves.append(state)
                     test_list.append(state)
-            list_of_moves.append(np.array(all_board_moves))
-
-        list_of_lengths = [len(board) for board in list_of_moves]
+                    length += 1
+            length_list.append(length)
+        # print(time.time() - start)
 
         start = time.time()
         all_scores_legit = self.model.predict_on_batch(np.array(test_list))
 
         split_scores = []
         from_idx = 0
-        for length in list_of_lengths:
+        for length in length_list:
             split_scores.append(all_scores_legit[from_idx:from_idx+length])
             from_idx += length
 
         means_splits = [tf.reduce_mean(scores) for scores in split_scores]
         transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
+        # print(time.time() - start)
 
         return ([means_splits, transformed_means_splits])
 
     def calc_n_ply(self, n_init, sess, board, player, roll):
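
For reference, a small standalone sketch (not from this repository) of the batching pattern the hunk above moves to: enumerate the 21 distinct dice rolls, flatten every candidate state from every board into one list, score the whole batch with a single call, then split the scores back out per board using the recorded lengths. The names toy_score and states_for are hypothetical placeholders standing in for self.model.predict_on_batch and the Board.calculate_legal_states / board_trans_func pipeline.

    import itertools
    import numpy as np

    # the same 21 unique rolls that the commit hardcodes as a literal list
    all_rolls = list(itertools.combinations_with_replacement(range(1, 7), 2))
    assert len(all_rolls) == 21

    def toy_score(states):
        # stand-in for the network: one scalar score per candidate state
        return states.mean(axis=1)

    def score_boards(boards, states_for):
        test_list = []     # every candidate state, across all boards
        length_list = []   # how many candidates each board contributed
        for board in boards:
            length = 0
            for roll in all_rolls:
                for state in states_for(board, roll):
                    test_list.append(state)
                    length += 1
            length_list.append(length)

        # one batched call instead of one call per board
        all_scores = toy_score(np.array(test_list))

        # split the flat score vector back into per-board chunks
        split_scores, from_idx = [], 0
        for length in length_list:
            split_scores.append(all_scores[from_idx:from_idx + length])
            from_idx += length
        return [float(scores.mean()) for scores in split_scores]

    # toy usage: two fake "boards", each expanded to one state per roll
    boards = [np.zeros(4), np.ones(4)]
    print(score_boards(boards, lambda board, roll: [board + sum(roll)]))
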
@@ -570,6 +562,10 @@ class Network:
     def play_against_network(self):
+        """
+        Allows you to play against a supplied model.
+        :return:
+        """
         self.restore_model()
         human_player = Player(-1)
         cur_player = 1
@@ -593,7 +589,7 @@ class Network:
     def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
         """
+        Train a model by self-learning.
         :param episodes:
         :param save_step_size:
         :param trained_eps: