diff --git a/board.py b/board.py index d32197c..56c2737 100644 --- a/board.py +++ b/board.py @@ -40,7 +40,7 @@ class Board: def board_features_quack(board, player): board = list(board) board += ([1, 0] if np.sign(player) > 0 else [0, 1]) - return np.array(board).reshape(1, -1) + return np.array(board).reshape(1,28) # quack-fat @staticmethod @@ -51,7 +51,7 @@ class Board: board.append( 15 - sum(positives)) board.append(-15 - sum(negatives)) board += ([1, 0] if np.sign(player) > 0 else [0, 1]) - return np.array(board).reshape(1,-1) + return np.array(board).reshape(1,30) # quack-fatter @@ -68,7 +68,7 @@ class Board: board.append(15 - sum(positives)) board.append(-15 - sum(negatives)) board += ([1, 0] if np.sign(player) > 0 else [0, 1]) - return np.array(board).reshape(1, -1) + return np.array(board).reshape(1,30) # tesauro @staticmethod @@ -124,9 +124,9 @@ class Board: # Calculate how many pieces there must be in the home state and divide it by 15 features.append((15 - sum) / 15) features += ([1,0] if np.sign(cur_player) > 0 else [0,1]) - test = np.array(features).reshape(1,-1) + test = np.array(features) #print("TEST:",test) - return test + return test.reshape(1,198) diff --git a/network.py b/network.py index 56e183b..d14e1ea 100644 --- a/network.py +++ b/network.py @@ -183,8 +183,7 @@ class Network: legal_states = [list(tmp) for tmp in legal_moves] - legal_states = np.array([Board.board_features_quack_fat(tmp, player)[0] for tmp in legal_states]) - + legal_states = np.array([self.board_trans_func(tmp, player)[0] for tmp in legal_states]) scores = self.model.predict_on_batch(legal_states) transformed_scores = [x if np.sign(player) > 0 else 1 - x for x in scores] diff --git a/network_test.py b/network_test.py index 5fb6d6e..4f64612 100644 --- a/network_test.py +++ b/network_test.py @@ -36,46 +36,12 @@ boards = {initial_state, initial_state_2 } -def gen_21_rolls(): - """ - Calculate all possible rolls, [[1,1], [1,2] ..] - :return: All possible rolls - """ - a = [] - for x in range(1, 7): - for y in range(1, 7): - if not [x, y] in a and not [y, x] in a: - a.append([x, y]) - - return a -def calculate_possible_states(board): - possible_rolls = [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), - (1, 6), (2, 2), (2, 3), (2, 4), (2, 5), - (2, 6), (3, 3), (3, 4), (3, 5), (3, 6), - (4, 4), (4, 5), (4, 6), (5, 5), (5, 6), - (6, 6)] - - for roll in possible_rolls: - meh = Board.calculate_legal_states(board, -1, roll) - print(len(meh)) - return [Board.calculate_legal_states(board, -1, roll) - for roll - in possible_rolls] - - - -#for board in boards: -# calculate_possible_states(board) - -#print("-"*30) -#print(network.calculate_1_ply(session, Board.initial_state, [2,4], 1)) - board = network.board_trans_func(Board.initial_state, 1) -#print(board) + pair = network.make_move(Board.initial_state, [3,2], 1) @@ -83,26 +49,9 @@ print(pair[1]) network.do_backprop(board, 0.9) -network.save_model(2, 342) -# all_input = np.array([input for _ in range(20)]) -# print(network.calc_vals(all_input)) +network.print_variables() -#print(" "*10 + "network_test") -#print(" "*20 + "Depth 1") -#print(network.calc_n_ply(1, session, Board.initial_state, 1, [2, 4])) +network.save_model(2) -#print(scores) - -#print(" "*20 + "Depth 2") -#print(network.n_ply(2, session, boards, 1)) - -# #print(x.shape) -# with graph_lol.as_default(): -# session_2 = tf.Session(graph = graph_lol) -# network_2 = Network(session_2) -# network_2.restore_model() -# print(network_2.eval_state(initial_state)) - -# print(network.eval_state(initial_state))