From 695a3d43dbc875c1d46db4b86ab166ae3c0ff5ed Mon Sep 17 00:00:00 2001 From: Alexander Munch-Hansen Date: Tue, 1 May 2018 20:39:29 +0200 Subject: [PATCH] Fixed n_ply and actually added a comma in main.py. *clap Christoffer* --- main.py | 2 +- network.py | 5 ++++- network_test.py | 53 ++++++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 4df4f1c..bcf891d 100644 --- a/main.py +++ b/main.py @@ -53,7 +53,7 @@ config = { 'train_perpetually': args.train_perpetually, 'model_storage_path': 'models', 'bench_storage_path': 'bench', - 'board_representation': 'quack' + 'board_representation': 'quack', 'force_creation': args.force_creation } diff --git a/network.py b/network.py index bc8d601..9c0e1db 100644 --- a/network.py +++ b/network.py @@ -225,6 +225,9 @@ class Network: (4, 4), (4, 5), (4, 6), (5, 5), (5, 6), (6, 6) ] + # for roll in possible_rolls: + # print(len(Board.calculate_legal_states(board, player, roll))) + return [ Board.calculate_legal_states(board, player, roll) for roll in possible_rolls ] @@ -284,7 +287,7 @@ class Network: in boards ] average_score_pairs = [ - (inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1)) + (inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1 if n == 1 else player)) for inner_board in inner_boards[1] ])) for inner_boards diff --git a/network_test.py b/network_test.py index fb343aa..bc948c3 100644 --- a/network_test.py +++ b/network_test.py @@ -38,15 +38,58 @@ boards = {initial_state, initial_state_2 } +def gen_21_rolls(): + """ + Calculate all possible rolls, [[1,1], [1,2] ..] + :return: All possible rolls + """ + a = [] + for x in range(1, 7): + for y in range(1, 7): + if not [x, y] in a and not [y, x] in a: + a.append([x, y]) + + return a + +def calc_all_scores(board, player): + scores = [] + trans_board = network.board_trans_func(board, player) + rolls = gen_21_rolls() + for roll in rolls: + score = network.eval_state(session, trans_board) + scores.append(score) + return scores + + +def calculate_possible_states(board): + possible_rolls = [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), + (1, 6), (2, 2), (2, 3), (2, 4), (2, 5), + (2, 6), (3, 3), (3, 4), (3, 5), (3, 6), + (4, 4), (4, 5), (4, 6), (5, 5), (5, 6), + (6, 6)] + + for roll in possible_rolls: + meh = Board.calculate_legal_states(board, -1, roll) + print(len(meh)) + return [Board.calculate_legal_states(board, -1, roll) + for roll + in possible_rolls] + + + +#for board in boards: +# calculate_possible_states(board) + print("-"*30) print(network.do_ply(session, boards, 1)) -print(" "*10 + "network_test") -print(" "*20 + "Depth 1") -print(network.n_ply(1, session, boards, 1)) +#print(" "*10 + "network_test") +#print(" "*20 + "Depth 1") +scores = network.n_ply(1, session, boards, 1) -print(" "*20 + "Depth 2") -print(network.n_ply(2, session, boards, 1)) + +#print(" "*20 + "Depth 2") +#print(network.n_ply(2, session, boards, 1)) # #print(x.shape) # with graph_lol.as_default():