More optimizations.
This commit is contained in:
parent
a77c13a0a4
commit
90fad334b9
48
network.py
48
network.py
|
@ -292,55 +292,47 @@ class Network:
|
||||||
to this function.
|
to this function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
all_rolls = [ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
|
||||||
|
(1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
|
||||||
|
(2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
|
||||||
|
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
|
||||||
|
(6, 6) ]
|
||||||
|
|
||||||
def gen_21_rolls():
|
|
||||||
"""
|
|
||||||
Calculate all possible rolls, [[1,1], [1,2] ..]
|
|
||||||
:return: All possible rolls
|
|
||||||
"""
|
|
||||||
a = []
|
|
||||||
for x in range(1, 7):
|
|
||||||
for y in range(1, 7):
|
|
||||||
if not [x, y] in a and not [y, x] in a:
|
|
||||||
a.append([x, y])
|
|
||||||
|
|
||||||
return a
|
|
||||||
|
|
||||||
all_rolls = gen_21_rolls()
|
|
||||||
|
|
||||||
# start = time.time()
|
# start = time.time()
|
||||||
|
|
||||||
list_of_moves = []
|
# print("/"*50)
|
||||||
|
length_list = []
|
||||||
test_list = []
|
test_list = []
|
||||||
# Prepping of data
|
# Prepping of data
|
||||||
for idx, board in enumerate(boards):
|
start= time.time()
|
||||||
all_board_moves = []
|
for board in boards:
|
||||||
|
length = 0
|
||||||
for roll in all_rolls:
|
for roll in all_rolls:
|
||||||
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
|
all_states = list(Board.calculate_legal_states(board, player*-1, roll))
|
||||||
for state in all_states:
|
for state in all_states:
|
||||||
state = np.array(self.board_trans_func(state, player*-1)[0])
|
state = np.array(self.board_trans_func(state, player*-1)[0])
|
||||||
all_board_moves.append(state)
|
|
||||||
test_list.append(state)
|
test_list.append(state)
|
||||||
list_of_moves.append(np.array(all_board_moves))
|
length += 1
|
||||||
|
length_list.append(length)
|
||||||
|
|
||||||
list_of_lengths = [len(board) for board in list_of_moves]
|
|
||||||
|
|
||||||
|
# print(time.time() - start)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
all_scores_legit = self.model.predict_on_batch(np.array(test_list))
|
all_scores_legit = self.model.predict_on_batch(np.array(test_list))
|
||||||
|
|
||||||
split_scores = []
|
split_scores = []
|
||||||
from_idx = 0
|
from_idx = 0
|
||||||
for length in list_of_lengths:
|
for length in length_list:
|
||||||
split_scores.append(all_scores_legit[from_idx:from_idx+length])
|
split_scores.append(all_scores_legit[from_idx:from_idx+length])
|
||||||
from_idx += length
|
from_idx += length
|
||||||
|
|
||||||
means_splits = [tf.reduce_mean(scores) for scores in split_scores]
|
means_splits = [tf.reduce_mean(scores) for scores in split_scores]
|
||||||
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
|
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
|
||||||
|
# print(time.time() - start)
|
||||||
|
|
||||||
return ([means_splits, transformed_means_splits])
|
return ([means_sp5lits, transformed_means_splits])
|
||||||
|
|
||||||
|
|
||||||
def calc_n_ply(self, n_init, sess, board, player, roll):
|
def calc_n_ply(self, n_init, sess, board, player, roll):
|
||||||
|
@ -570,6 +562,10 @@ class Network:
|
||||||
|
|
||||||
|
|
||||||
def play_against_network(self):
|
def play_against_network(self):
|
||||||
|
"""
|
||||||
|
Allows you to play against a supplied model.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
self.restore_model()
|
self.restore_model()
|
||||||
human_player = Player(-1)
|
human_player = Player(-1)
|
||||||
cur_player = 1
|
cur_player = 1
|
||||||
|
@ -593,7 +589,7 @@ class Network:
|
||||||
|
|
||||||
def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
|
def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
|
||||||
"""
|
"""
|
||||||
|
Train a model to by self-learning.
|
||||||
:param episodes:
|
:param episodes:
|
||||||
:param save_step_size:
|
:param save_step_size:
|
||||||
:param trained_eps:
|
:param trained_eps:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user