Yet another input argument, "--ply", 0 for no look-ahead, 1 for a single

look-ahead.
This commit is contained in:
Alexander Munch-Hansen 2018-05-10 23:22:41 +02:00
parent 3b57c10b5a
commit 504308a9af
2 changed files with 24 additions and 12 deletions

View File

@ -40,6 +40,8 @@ parser.add_argument('--use-baseline', action='store_true',
help='use the baseline model, note, has size 28') help='use the baseline model, note, has size 28')
parser.add_argument('--verbose', action='store_true', parser.add_argument('--verbose', action='store_true',
help='If set, a lot of stuff will be printed') help='If set, a lot of stuff will be printed')
parser.add_argument('--ply', action='store', dest='ply',
help='defines the amount of ply used when deciding what move to make')
args = parser.parse_args() args = parser.parse_args()
@ -64,7 +66,8 @@ config = {
'force_creation': args.force_creation, 'force_creation': args.force_creation,
'use_baseline': args.use_baseline, 'use_baseline': args.use_baseline,
'global_step': 0, 'global_step': 0,
'verbose': args.verbose 'verbose': args.verbose,
'ply': args.ply
} }

View File

@ -23,6 +23,7 @@ class Network:
'tesauro-poop': (198, Board.board_features_tesauro_wrong) 'tesauro-poop': (198, Board.board_features_tesauro_wrong)
} }
def custom_tanh(self, x, name=None): def custom_tanh(self, x, name=None):
return tf.scalar_mul(tf.constant(2.00), tf.tanh(x, name)) return tf.scalar_mul(tf.constant(2.00), tf.tanh(x, name))
@ -31,6 +32,12 @@ class Network:
:param config: :param config:
:param name: :param name:
""" """
move_options = {
'1': self.make_move_1_ply,
'0': self.make_move_0_ply
}
tf.enable_eager_execution() tf.enable_eager_execution()
xavier_init = tf.contrib.layers.xavier_initializer() xavier_init = tf.contrib.layers.xavier_initializer()
@ -40,6 +47,10 @@ class Network:
self.name = name self.name = name
self.make_move = move_options[
self.config['ply']
]
# Set board representation from config # Set board representation from config
self.input_size, self.board_trans_func = Network.board_reps[ self.input_size, self.board_trans_func = Network.board_reps[
self.config['board_representation'] self.config['board_representation']
@ -191,7 +202,7 @@ class Network:
def make_move(self, board, roll, player): def make_move_0_ply(self, board, roll, player):
""" """
Find the best move given a board, roll and a player, by finding all possible states one can go to Find the best move given a board, roll and a player, by finding all possible states one can go to
and then picking the best, by using the network to evaluate each state. The highest score is picked and then picking the best, by using the network to evaluate each state. The highest score is picked
@ -218,17 +229,16 @@ class Network:
return [best_move, best_score] return [best_move, best_score]
def make_move_n_ply(self, sess, board, roll, player, n = 1): def make_move_1_ply(self, board, roll, player):
""" """
:param sess:
:param board: :param board:
:param roll: :param roll:
:param player: :param player:
:param n:
:return: :return:
""" """
best_pair = self.calc_n_ply(n, sess, board, player, roll) start = time.time()
best_pair = self.calculate_1_ply(board, roll, player)
print(time.time() - start)
return best_pair return best_pair
@ -303,7 +313,7 @@ class Network:
all_rolls = gen_21_rolls() all_rolls = gen_21_rolls()
start = time.time() # start = time.time()
list_of_moves = [] list_of_moves = []
@ -318,9 +328,8 @@ class Network:
list_of_moves.append(np.array(all_board_moves)) list_of_moves.append(np.array(all_board_moves))
print(time.time() - start) # print(time.time() - start)
# start = time.time()
start = time.time()
# Running data through networks # Running data through networks
all_scores = [self.model.predict_on_batch(board) for board in list_of_moves] all_scores = [self.model.predict_on_batch(board) for board in list_of_moves]
@ -328,7 +337,7 @@ class Network:
transformed_means = [x if player == 1 else (1-x) for x in scores_means] transformed_means = [x if player == 1 else (1-x) for x in scores_means]
print(time.time() - start) # print(time.time() - start)
return ([scores_means, transformed_means]) return ([scores_means, transformed_means])