From 504308a9af13efb0e15dd0cf2621e84c9dc15f63 Mon Sep 17 00:00:00 2001 From: Alexander Munch-Hansen Date: Thu, 10 May 2018 23:22:41 +0200 Subject: [PATCH] Yet another input argument, "--ply", 0 for no look-ahead, 1 for a single look-ahead. --- main.py | 5 ++++- network.py | 31 ++++++++++++++++++++----------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index a276220..e3ded40 100644 --- a/main.py +++ b/main.py @@ -40,6 +40,8 @@ parser.add_argument('--use-baseline', action='store_true', help='use the baseline model, note, has size 28') parser.add_argument('--verbose', action='store_true', help='If set, a lot of stuff will be printed') +parser.add_argument('--ply', action='store', dest='ply', + help='defines the amount of ply used when deciding what move to make') args = parser.parse_args() @@ -64,7 +66,8 @@ config = { 'force_creation': args.force_creation, 'use_baseline': args.use_baseline, 'global_step': 0, - 'verbose': args.verbose + 'verbose': args.verbose, + 'ply': args.ply } diff --git a/network.py b/network.py index 070f0b3..6a91198 100644 --- a/network.py +++ b/network.py @@ -23,6 +23,7 @@ class Network: 'tesauro-poop': (198, Board.board_features_tesauro_wrong) } + def custom_tanh(self, x, name=None): return tf.scalar_mul(tf.constant(2.00), tf.tanh(x, name)) @@ -31,6 +32,12 @@ class Network: :param config: :param name: """ + + move_options = { + '1': self.make_move_1_ply, + '0': self.make_move_0_ply + } + tf.enable_eager_execution() xavier_init = tf.contrib.layers.xavier_initializer() @@ -40,6 +47,10 @@ class Network: self.name = name + self.make_move = move_options[ + self.config['ply'] + ] + # Set board representation from config self.input_size, self.board_trans_func = Network.board_reps[ self.config['board_representation'] @@ -191,7 +202,7 @@ class Network: - def make_move(self, board, roll, player): + def make_move_0_ply(self, board, roll, player): """ Find the best move given a board, roll and a player, by finding all possible states one can go to and then picking the best, by using the network to evaluate each state. The highest score is picked @@ -218,17 +229,16 @@ class Network: return [best_move, best_score] - def make_move_n_ply(self, sess, board, roll, player, n = 1): + def make_move_1_ply(self, board, roll, player): """ - - :param sess: :param board: :param roll: :param player: - :param n: :return: """ - best_pair = self.calc_n_ply(n, sess, board, player, roll) + start = time.time() + best_pair = self.calculate_1_ply(board, roll, player) + print(time.time() - start) return best_pair @@ -303,7 +313,7 @@ class Network: all_rolls = gen_21_rolls() - start = time.time() + # start = time.time() list_of_moves = [] @@ -318,9 +328,8 @@ class Network: list_of_moves.append(np.array(all_board_moves)) - print(time.time() - start) - - start = time.time() + # print(time.time() - start) + # start = time.time() # Running data through networks all_scores = [self.model.predict_on_batch(board) for board in list_of_moves] @@ -328,7 +337,7 @@ class Network: transformed_means = [x if player == 1 else (1-x) for x in scores_means] - print(time.time() - start) + # print(time.time() - start) return ([scores_means, transformed_means])