Board rep can now be inferred from file after being given once.
We can also evaluate multiple times by using the flag "--repeat-eval". The flag defaults to 1, if not provided.
This commit is contained in:
parent
c3f5e909d6
commit
ba4ef86bb5
45
main.py
45
main.py
|
@ -34,14 +34,15 @@ parser.add_argument('--list-models', action='store_true',
|
||||||
parser.add_argument('--force-creation', action='store_true',
|
parser.add_argument('--force-creation', action='store_true',
|
||||||
help='force model creation if model does not exist')
|
help='force model creation if model does not exist')
|
||||||
parser.add_argument('--board-rep', action='store', dest='board_rep',
|
parser.add_argument('--board-rep', action='store', dest='board_rep',
|
||||||
default='tesauro',
|
|
||||||
help='name of board representation to use as input to neural network')
|
help='name of board representation to use as input to neural network')
|
||||||
parser.add_argument('--use-baseline', action='store_true',
|
parser.add_argument('--use-baseline', action='store_true',
|
||||||
help='use the baseline model, note, has size 28')
|
help='use the baseline model, note, has size 28')
|
||||||
parser.add_argument('--verbose', action='store_true',
|
parser.add_argument('--verbose', action='store_true',
|
||||||
help='If set, a lot of stuff will be printed')
|
help='If set, a lot of stuff will be printed')
|
||||||
parser.add_argument('--ply', action='store', dest='ply',
|
parser.add_argument('--ply', action='store', dest='ply', default='0',
|
||||||
help='defines the amount of ply used when deciding what move to make')
|
help='defines the amount of ply used when deciding what move to make')
|
||||||
|
parser.add_argument('--repeat-eval', action='store', dest='repeat_eval', default='1',
|
||||||
|
help='the amount of times the evaluation method should be repeated')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -67,10 +68,11 @@ config = {
|
||||||
'use_baseline': args.use_baseline,
|
'use_baseline': args.use_baseline,
|
||||||
'global_step': 0,
|
'global_step': 0,
|
||||||
'verbose': args.verbose,
|
'verbose': args.verbose,
|
||||||
'ply': args.ply
|
'ply': args.ply,
|
||||||
|
'repeat_eval': args.repeat_eval
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Create models folder
|
# Create models folder
|
||||||
if not os.path.exists(config['model_storage_path']):
|
if not os.path.exists(config['model_storage_path']):
|
||||||
os.makedirs(config['model_storage_path'])
|
os.makedirs(config['model_storage_path'])
|
||||||
|
@ -133,6 +135,24 @@ def log_bench_eval_outcomes(outcomes, log_path, index, time, trained_eps = 0):
|
||||||
with open(log_path, 'a+') as f:
|
with open(log_path, 'a+') as f:
|
||||||
f.write("{method};{count};{index};{time};{sum};{mean}".format(**format_vars) + "\n")
|
f.write("{method};{count};{index};{time};{sum};{mean}".format(**format_vars) + "\n")
|
||||||
|
|
||||||
|
def find_board_rep():
|
||||||
|
checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
|
||||||
|
board_rep_path = os.path.join(checkpoint_path, "board_representation")
|
||||||
|
with open(board_rep_path, 'r') as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def board_rep_file_exists():
|
||||||
|
checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
|
||||||
|
board_rep_path = os.path.join(checkpoint_path, "board_representation")
|
||||||
|
return os.path.isfile(board_rep_path)
|
||||||
|
|
||||||
|
def create_board_rep():
|
||||||
|
checkpoint_path = os.path.join(config['model_storage_path'], config['model'])
|
||||||
|
board_rep_path = os.path.join(checkpoint_path, "board_representation")
|
||||||
|
with open(board_rep_path, 'a+') as f:
|
||||||
|
f.write(config['board_representation'])
|
||||||
|
|
||||||
# Do actions specified by command-line
|
# Do actions specified by command-line
|
||||||
if args.list_models:
|
if args.list_models:
|
||||||
def get_eps_trained(folder):
|
def get_eps_trained(folder):
|
||||||
|
@ -156,6 +176,22 @@ if __name__ == "__main__":
|
||||||
# Set up variables
|
# Set up variables
|
||||||
episode_count = config['episode_count']
|
episode_count = config['episode_count']
|
||||||
|
|
||||||
|
if config['board_representation'] is None:
|
||||||
|
if board_rep_file_exists():
|
||||||
|
config['board_representation'] = find_board_rep()
|
||||||
|
else:
|
||||||
|
sys.stderr.write("Was not given a board_rep and was unable to find a board_rep file\n")
|
||||||
|
exit()
|
||||||
|
else:
|
||||||
|
if not board_rep_file_exists():
|
||||||
|
create_board_rep()
|
||||||
|
else:
|
||||||
|
if config['board_representation'] != find_board_rep():
|
||||||
|
sys.stderr.write("Board representation \"{given}\", does not match one in board_rep file, \"{board_rep}\"\n".
|
||||||
|
format(given = config['board_representation'], board_rep = find_board_rep()))
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
if args.train:
|
if args.train:
|
||||||
network = Network(config, config['model'])
|
network = Network(config, config['model'])
|
||||||
start_episode = network.episodes_trained
|
start_episode = network.episodes_trained
|
||||||
|
@ -172,6 +208,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
elif args.eval:
|
elif args.eval:
|
||||||
network = Network(config, config['model'])
|
network = Network(config, config['model'])
|
||||||
|
for i in range(int(config['repeat_eval'])):
|
||||||
start_episode = network.episodes_trained
|
start_episode = network.episodes_trained
|
||||||
# Evaluation measures are described in `config`
|
# Evaluation measures are described in `config`
|
||||||
outcomes = network.eval(config['episode_count'])
|
outcomes = network.eval(config['episode_count'])
|
||||||
|
|
|
@ -177,6 +177,7 @@ class Network:
|
||||||
:return: Nothing. It's a side-effect that a model gets restored for the network.
|
:return: Nothing. It's a side-effect that a model gets restored for the network.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
if glob.glob(os.path.join(self.checkpoint_path, 'model.ckpt*.index')):
|
if glob.glob(os.path.join(self.checkpoint_path, 'model.ckpt*.index')):
|
||||||
|
|
||||||
latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
|
latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
|
||||||
|
@ -235,9 +236,9 @@ class Network:
|
||||||
:param player:
|
:param player:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
start = time.time()
|
# start = time.time()
|
||||||
best_pair = self.calculate_1_ply(board, roll, player)
|
best_pair = self.calculate_1_ply(board, roll, player)
|
||||||
print(time.time() - start)
|
# print(time.time() - start)
|
||||||
return best_pair
|
return best_pair
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user