Board rep can now be inferred from file after being given once.
We can also evaluate multiple times by using the flag "--repeat-eval". The flag defaults to 1, if not provided.
This commit is contained in:
parent
c3f5e909d6
commit
ba4ef86bb5
57
main.py
57
main.py
|
@ -34,14 +34,15 @@ parser.add_argument('--list-models', action='store_true',
|
|||
parser.add_argument('--force-creation', action='store_true',
|
||||
help='force model creation if model does not exist')
|
||||
parser.add_argument('--board-rep', action='store', dest='board_rep',
|
||||
default='tesauro',
|
||||
help='name of board representation to use as input to neural network')
|
||||
parser.add_argument('--use-baseline', action='store_true',
|
||||
help='use the baseline model, note, has size 28')
|
||||
parser.add_argument('--verbose', action='store_true',
|
||||
help='If set, a lot of stuff will be printed')
|
||||
parser.add_argument('--ply', action='store', dest='ply',
|
||||
parser.add_argument('--ply', action='store', dest='ply', default='0',
|
||||
help='defines the amount of ply used when deciding what move to make')
|
||||
parser.add_argument('--repeat-eval', action='store', dest='repeat_eval', default='1',
|
||||
help='the amount of times the evaluation method should be repeated')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -67,10 +68,11 @@ config = {
|
|||
'use_baseline': args.use_baseline,
|
||||
'global_step': 0,
|
||||
'verbose': args.verbose,
|
||||
'ply': args.ply
|
||||
|
||||
'ply': args.ply,
|
||||
'repeat_eval': args.repeat_eval
|
||||
}
|
||||
|
||||
|
||||
# Create models folder
|
||||
if not os.path.exists(config['model_storage_path']):
|
||||
os.makedirs(config['model_storage_path'])
|
||||
|
@ -133,6 +135,24 @@ def log_bench_eval_outcomes(outcomes, log_path, index, time, trained_eps = 0):
|
|||
with open(log_path, 'a+') as f:
|
||||
f.write("{method};{count};{index};{time};{sum};{mean}".format(**format_vars) + "\n")
|
||||
|
||||
def find_board_rep():
|
||||
checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
|
||||
board_rep_path = os.path.join(checkpoint_path, "board_representation")
|
||||
with open(board_rep_path, 'r') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def board_rep_file_exists():
|
||||
checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
|
||||
board_rep_path = os.path.join(checkpoint_path, "board_representation")
|
||||
return os.path.isfile(board_rep_path)
|
||||
|
||||
def create_board_rep():
|
||||
checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
|
||||
board_rep_path = os.path.join(checkpoint_path, "board_representation")
|
||||
with open(board_rep_path, 'a+') as f:
|
||||
f.write(config['board_representation'])
|
||||
|
||||
# Do actions specified by command-line
|
||||
if args.list_models:
|
||||
def get_eps_trained(folder):
|
||||
|
@ -155,6 +175,22 @@ if __name__ == "__main__":
|
|||
|
||||
# Set up variables
|
||||
episode_count = config['episode_count']
|
||||
|
||||
if config['board_representation'] is None:
|
||||
if board_rep_file_exists():
|
||||
config['board_representation'] = find_board_rep()
|
||||
else:
|
||||
sys.stderr.write("Was not given a board_rep and was unable to find a board_rep file\n")
|
||||
exit()
|
||||
else:
|
||||
if not board_rep_file_exists():
|
||||
create_board_rep()
|
||||
else:
|
||||
if config['board_representation'] != find_board_rep():
|
||||
sys.stderr.write("Board representation \"{given}\", does not match one in board_rep file, \"{board_rep}\"\n".
|
||||
format(given = config['board_representation'], board_rep = find_board_rep()))
|
||||
exit()
|
||||
|
||||
|
||||
if args.train:
|
||||
network = Network(config, config['model'])
|
||||
|
@ -172,12 +208,13 @@ if __name__ == "__main__":
|
|||
|
||||
elif args.eval:
|
||||
network = Network(config, config['model'])
|
||||
start_episode = network.episodes_trained
|
||||
# Evaluation measures are described in `config`
|
||||
outcomes = network.eval(config['episode_count'])
|
||||
log_eval_outcomes(outcomes, trained_eps = start_episode)
|
||||
# elif args.play:
|
||||
# g.play(episodes = episode_count)
|
||||
for i in range(int(config['repeat_eval'])):
|
||||
start_episode = network.episodes_trained
|
||||
# Evaluation measures are described in `config`
|
||||
outcomes = network.eval(config['episode_count'])
|
||||
log_eval_outcomes(outcomes, trained_eps = start_episode)
|
||||
# elif args.play:
|
||||
# g.play(episodes = episode_count)
|
||||
|
||||
|
||||
elif args.bench_eval_scores:
|
||||
|
|
|
@ -177,6 +177,7 @@ class Network:
|
|||
:return: Nothing. It's a side-effect that a model gets restored for the network.
|
||||
"""
|
||||
|
||||
|
||||
if glob.glob(os.path.join(self.checkpoint_path, 'model.ckpt*.index')):
|
||||
|
||||
latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
|
||||
|
@ -235,9 +236,9 @@ class Network:
|
|||
:param player:
|
||||
:return:
|
||||
"""
|
||||
start = time.time()
|
||||
# start = time.time()
|
||||
best_pair = self.calculate_1_ply(board, roll, player)
|
||||
print(time.time() - start)
|
||||
# print(time.time() - start)
|
||||
return best_pair
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user