1-ply runs even faster.
This commit is contained in:
parent
260c32d909
commit
a77c13a0a4
25
network.py
25
network.py
|
@ -327,12 +327,6 @@ class Network:
|
||||||
|
|
||||||
list_of_lengths = [len(board) for board in list_of_moves]
|
list_of_lengths = [len(board) for board in list_of_moves]
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
for i in range(len(test_list)):
|
|
||||||
self.model.predict_on_batch(np.array([state]))
|
|
||||||
print("Indiviual rolls:", time.time() - start)
|
|
||||||
all_scores = [self.model.predict_on_batch(board) for board in list_of_moves]
|
|
||||||
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
all_scores_legit = self.model.predict_on_batch(np.array(test_list))
|
all_scores_legit = self.model.predict_on_batch(np.array(test_list))
|
||||||
|
@ -343,23 +337,10 @@ class Network:
|
||||||
split_scores.append(all_scores_legit[from_idx:from_idx+length])
|
split_scores.append(all_scores_legit[from_idx:from_idx+length])
|
||||||
from_idx += length
|
from_idx += length
|
||||||
|
|
||||||
transformed_splits = [tf.reduce_mean(scores) for scores in split_scores]
|
means_splits = [tf.reduce_mean(scores) for scores in split_scores]
|
||||||
|
transformed_means_splits = [x if player == 1 else (1-x) for x in means_splits]
|
||||||
|
|
||||||
print(transformed_splits)
|
return ([means_splits, transformed_means_splits])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("All in one:", time.time() - start)
|
|
||||||
|
|
||||||
scores_means = [tf.reduce_mean(score) for score in all_scores]
|
|
||||||
|
|
||||||
print(scores_means)
|
|
||||||
|
|
||||||
transformed_means = [x if player == 1 else (1-x) for x in scores_means]
|
|
||||||
|
|
||||||
# print(time.time() - start)
|
|
||||||
return ([scores_means, transformed_means])
|
|
||||||
|
|
||||||
|
|
||||||
def calc_n_ply(self, n_init, sess, board, player, roll):
|
def calc_n_ply(self, n_init, sess, board, player, roll):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user