More optimizations.
This commit is contained in:
parent
a77c13a0a4
commit
90fad334b9
48
network.py
48
network.py
|
@ -292,55 +292,47 @@ class Network:
|
|||
to this function.
|
||||
"""
|
||||
|
||||
import time
|
||||
all_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
|
||||
(1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
|
||||
(2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
|
||||
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
|
||||
(6, 6) ]
|
||||
|
||||
def gen_21_rolls():
|
||||
"""
|
||||
Calculate all possible rolls, [[1,1], [1,2] ..]
|
||||
:return: All possible rolls
|
||||
"""
|
||||
a = []
|
||||
for x in range(1, 7):
|
||||
for y in range(1, 7):
|
||||
if not [x, y] in a and not [y, x] in a:
|
||||
a.append([x, y])
|
||||
|
||||
return a
|
||||
|
||||
all_rolls = gen_21_rolls()
|
||||
|
||||
# start = time.time()
|
||||
|
||||
list_of_moves = []
|
||||
# print("/"*50)
|
||||
length_list = []
|
||||
test_list = []
|
||||
# Prepping of data
|
||||
for idx, board in enumerate(boards):
|
||||
all_board_moves = []
|
||||
start= time.time()
|
||||
for board in boards:
|
||||
length = 0
|
||||
for roll in all_rolls:
|
||||
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
|
||||
for state in all_states:
|
||||
state = np.array(self.board_trans_func(state, player*-1)[0])
|
||||
all_board_moves.append(state)
|
||||
test_list.append(state)
|
||||
list_of_moves.append(np.array(all_board_moves))
|
||||
|
||||
|
||||
list_of_lengths = [len(board) for board in list_of_moves]
|
||||
length += 1
|
||||
length_list.append(length)
|
||||
|
||||
# print(time.time() - start)
|
||||
|
||||
start = time.time()
|
||||
|
||||
all_scores_legit = self.model.predict_on_batch(np.array(test_list))
|
||||
|
||||
split_scores = []
|
||||
from_idx = 0
|
||||
for length in list_of_lengths:
|
||||
for length in length_list:
|
||||
split_scores.append(all_scores_legit[from_idx:from_idx+length])
|
||||
from_idx += length
|
||||
|
||||
means_splits = [tf.reduce_mean(scores) for scores in split_scores]
|
||||
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
|
||||
# print(time.time() - start)
|
||||
|
||||
return ([means_splits, transformed_means_splits])
|
||||
return ([means_sp5lits, transformed_means_splits])
|
||||
|
||||
|
||||
def calc_n_ply(self, n_init, sess, board, player, roll):
|
||||
|
@ -570,6 +562,10 @@ class Network:
|
|||
|
||||
|
||||
def play_against_network(self):
|
||||
"""
|
||||
Allows you to play against a supplied model.
|
||||
:return:
|
||||
"""
|
||||
self.restore_model()
|
||||
human_player = Player(-1)
|
||||
cur_player = 1
|
||||
|
@ -593,7 +589,7 @@ class Network:
|
|||
|
||||
def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
|
||||
"""
|
||||
|
||||
Train a model to by self-learning.
|
||||
:param episodes:
|
||||
:param save_step_size:
|
||||
:param trained_eps:
|
||||
|
|
Loading…
Reference in New Issue
Block a user