Add yet another input argument, "--ply": 0 for no look-ahead, 1 for a single
look-ahead.
This commit is contained in:
parent
3b57c10b5a
commit
504308a9af
5
main.py
5
main.py
|
@ -40,6 +40,8 @@ parser.add_argument('--use-baseline', action='store_true',
|
||||||
help='use the baseline model, note, has size 28')
|
help='use the baseline model, note, has size 28')
|
||||||
parser.add_argument('--verbose', action='store_true',
|
parser.add_argument('--verbose', action='store_true',
|
||||||
help='If set, a lot of stuff will be printed')
|
help='If set, a lot of stuff will be printed')
|
||||||
|
parser.add_argument('--ply', action='store', dest='ply',
|
||||||
|
help='defines the amount of ply used when deciding what move to make')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -64,7 +66,8 @@ config = {
|
||||||
'force_creation': args.force_creation,
|
'force_creation': args.force_creation,
|
||||||
'use_baseline': args.use_baseline,
|
'use_baseline': args.use_baseline,
|
||||||
'global_step': 0,
|
'global_step': 0,
|
||||||
'verbose': args.verbose
|
'verbose': args.verbose,
|
||||||
|
'ply': args.ply
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
31
network.py
31
network.py
|
@ -23,6 +23,7 @@ class Network:
|
||||||
'tesauro-poop': (198, Board.board_features_tesauro_wrong)
|
'tesauro-poop': (198, Board.board_features_tesauro_wrong)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def custom_tanh(self, x, name=None):
|
def custom_tanh(self, x, name=None):
|
||||||
return tf.scalar_mul(tf.constant(2.00), tf.tanh(x, name))
|
return tf.scalar_mul(tf.constant(2.00), tf.tanh(x, name))
|
||||||
|
|
||||||
|
@ -31,6 +32,12 @@ class Network:
|
||||||
:param config:
|
:param config:
|
||||||
:param name:
|
:param name:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
move_options = {
|
||||||
|
'1': self.make_move_1_ply,
|
||||||
|
'0': self.make_move_0_ply
|
||||||
|
}
|
||||||
|
|
||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
|
|
||||||
xavier_init = tf.contrib.layers.xavier_initializer()
|
xavier_init = tf.contrib.layers.xavier_initializer()
|
||||||
|
@ -40,6 +47,10 @@ class Network:
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
self.make_move = move_options[
|
||||||
|
self.config['ply']
|
||||||
|
]
|
||||||
|
|
||||||
# Set board representation from config
|
# Set board representation from config
|
||||||
self.input_size, self.board_trans_func = Network.board_reps[
|
self.input_size, self.board_trans_func = Network.board_reps[
|
||||||
self.config['board_representation']
|
self.config['board_representation']
|
||||||
|
@ -191,7 +202,7 @@ class Network:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def make_move(self, board, roll, player):
|
def make_move_0_ply(self, board, roll, player):
|
||||||
"""
|
"""
|
||||||
Find the best move given a board, roll and a player, by finding all possible states one can go to
|
Find the best move given a board, roll and a player, by finding all possible states one can go to
|
||||||
and then picking the best, by using the network to evaluate each state. The highest score is picked
|
and then picking the best, by using the network to evaluate each state. The highest score is picked
|
||||||
|
@ -218,17 +229,16 @@ class Network:
|
||||||
|
|
||||||
return [best_move, best_score]
|
return [best_move, best_score]
|
||||||
|
|
||||||
def make_move_n_ply(self, sess, board, roll, player, n = 1):
|
def make_move_1_ply(self, board, roll, player):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param sess:
|
|
||||||
:param board:
|
:param board:
|
||||||
:param roll:
|
:param roll:
|
||||||
:param player:
|
:param player:
|
||||||
:param n:
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
best_pair = self.calc_n_ply(n, sess, board, player, roll)
|
start = time.time()
|
||||||
|
best_pair = self.calculate_1_ply(board, roll, player)
|
||||||
|
print(time.time() - start)
|
||||||
return best_pair
|
return best_pair
|
||||||
|
|
||||||
|
|
||||||
|
@ -303,7 +313,7 @@ class Network:
|
||||||
|
|
||||||
all_rolls = gen_21_rolls()
|
all_rolls = gen_21_rolls()
|
||||||
|
|
||||||
start = time.time()
|
# start = time.time()
|
||||||
|
|
||||||
list_of_moves = []
|
list_of_moves = []
|
||||||
|
|
||||||
|
@ -318,9 +328,8 @@ class Network:
|
||||||
list_of_moves.append(np.array(all_board_moves))
|
list_of_moves.append(np.array(all_board_moves))
|
||||||
|
|
||||||
|
|
||||||
print(time.time() - start)
|
# print(time.time() - start)
|
||||||
|
# start = time.time()
|
||||||
start = time.time()
|
|
||||||
|
|
||||||
# Running data through networks
|
# Running data through networks
|
||||||
all_scores = [self.model.predict_on_batch(board) for board in list_of_moves]
|
all_scores = [self.model.predict_on_batch(board) for board in list_of_moves]
|
||||||
|
@ -328,7 +337,7 @@ class Network:
|
||||||
|
|
||||||
transformed_means = [x if player == 1 else (1-x) for x in scores_means]
|
transformed_means = [x if player == 1 else (1-x) for x in scores_means]
|
||||||
|
|
||||||
print(time.time() - start)
|
# print(time.time() - start)
|
||||||
return ([scores_means, transformed_means])
|
return ([scores_means, transformed_means])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user