More optimizations.

This commit is contained in:
Alexander Munch-Hansen 2018-05-15 23:37:35 +02:00
parent a77c13a0a4
commit 90fad334b9

View File

@ -292,55 +292,47 @@ class Network:
to this function.
"""
import time
all_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
(1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
(2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
(6, 6) ]
def gen_21_rolls():
"""
Calculate all possible rolls, [[1,1], [1,2] ..]
:return: All possible rolls
"""
a = []
for x in range(1, 7):
for y in range(1, 7):
if not [x, y] in a and not [y, x] in a:
a.append([x, y])
return a
all_rolls = gen_21_rolls()
# start = time.time()
list_of_moves = []
# print("/"*50)
length_list = []
test_list = []
# Prepping of data
for idx, board in enumerate(boards):
all_board_moves = []
start= time.time()
for board in boards:
length = 0
for roll in all_rolls:
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
for state in all_states:
state = np.array(self.board_trans_func(state, player*-1)[0])
all_board_moves.append(state)
test_list.append(state)
list_of_moves.append(np.array(all_board_moves))
list_of_lengths = [len(board) for board in list_of_moves]
length += 1
length_list.append(length)
# print(time.time() - start)
start = time.time()
all_scores_legit = self.model.predict_on_batch(np.array(test_list))
split_scores = []
from_idx = 0
for length in list_of_lengths:
for length in length_list:
split_scores.append(all_scores_legit[from_idx:from_idx+length])
from_idx += length
means_splits = [tf.reduce_mean(scores) for scores in split_scores]
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
# print(time.time() - start)
return ([means_splits, transformed_means_splits])
return ([means_sp5lits, transformed_means_splits])
def calc_n_ply(self, n_init, sess, board, player, roll):
@ -570,6 +562,10 @@ class Network:
def play_against_network(self):
"""
Allows you to play against a supplied model.
:return:
"""
self.restore_model()
human_player = Player(-1)
cur_player = 1
@ -593,7 +589,7 @@ class Network:
def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
"""
Train a model to by self-learning.
:param episodes:
:param save_step_size:
:param trained_eps: