renaming parameters

This commit is contained in:
Christoffer Müller Madsen 2018-03-12 00:11:55 +01:00
parent 90b97da4ff
commit 55898d0e66
3 changed files with 15 additions and 9 deletions

View File

@ -56,7 +56,7 @@ The following examples describe commmon operations.
*** Evaluate model named =quack= using default evaluation method (currently =random=) *** Evaluate model named =quack= using default evaluation method (currently =random=)
=python3 --eval --model-name=quack= =python3 --eval --model=quack=
*** Evaluate default model using evaluation methods =random= and =pubeval= *** Evaluate default model using evaluation methods =random= and =pubeval=
@ -73,14 +73,14 @@ directory. Otherwise, the model is stored in =models/$MODEL=.
Along with the Tensorflow checkpoint files in the directory, the following files Along with the Tensorflow checkpoint files in the directory, the following files
are stored: are stored:
- =model.episodes=: The number of episodes of training performed with the - =episodes_trained=: The number of episodes of training performed with the
model model
- =logs/eval.log=: Log of all completed evaluations performed on the model. The - =logs/eval.log=: Log of all completed evaluations performed on the model. The
format of this file is specified in [[Log format]]. format of this file is specified in [[Log format]].
- =logs/train.log=: Log of all completed training sessions performed on the - =logs/train.log=: Log of all completed training sessions performed on the
model. If a training session is aborted before the pre-specified episode model. If a training session is aborted before the pre-specified episode
target is reached, nothing will be written to this file, although target is reached, nothing will be written to this file, although
=model.episodes= will be updated every time the model is saved to disk. The =episodes_trained= will be updated every time the model is saved to disk. The
format of this file is specified in [[Log format]]. format of this file is specified in [[Log format]].
** Log format ** Log format

14
main.py
View File

@ -3,6 +3,12 @@ import sys
import os import os
import time import time
models_storage_path = 'models'
# Create models folder
if not os.path.exists(models_storage_path):
os.makedirs(models_storage_path)
# Define helper functions # Define helper functions
def log_train_outcome(outcome, trained_eps = 0): def log_train_outcome(outcome, trained_eps = 0):
format_vars = { 'trained_eps': trained_eps, format_vars = { 'trained_eps': trained_eps,
@ -34,9 +40,9 @@ parser = argparse.ArgumentParser(description="Backgammon games")
parser.add_argument('--episodes', action='store', dest='episode_count', parser.add_argument('--episodes', action='store', dest='episode_count',
type=int, default=1000, type=int, default=1000,
help='number of episodes to train') help='number of episodes to train')
parser.add_argument('--model-path', action='store', dest='model_path', parser.add_argument('--model', action='store', dest='model',
default='./model', default='default',
help='path to Tensorflow model') help='name of Tensorflow model to use')
parser.add_argument('--eval-methods', action='store', parser.add_argument('--eval-methods', action='store',
default=['random'], nargs='*', default=['random'], nargs='*',
help='specifies evaluation methods') help='specifies evaluation methods')
@ -55,7 +61,7 @@ parser.add_argument('--start-episode', action='store', dest='start_episode',
args = parser.parse_args() args = parser.parse_args()
config = { config = {
'model_path': args.model_path, 'model_path': os.path.join(models_storage_path, args.model),
'episode_count': args.episode_count, 'episode_count': args.episode_count,
'eval_methods': args.eval_methods, 'eval_methods': args.eval_methods,
'train': args.train, 'train': args.train,

View File

@ -26,7 +26,7 @@ class Network:
# Restore trained episode count for model # Restore trained episode count for model
episode_count_path = os.path.join(self.checkpoint_path, "model.episodes") episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained")
if os.path.isfile(episode_count_path): if os.path.isfile(episode_count_path):
with open(episode_count_path, 'r') as f: with open(episode_count_path, 'r') as f:
self.config['start_episode'] = int(f.read()) self.config['start_episode'] = int(f.read())
@ -91,7 +91,7 @@ class Network:
def save_model(self, episode_count): def save_model(self, episode_count):
self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt')) self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt'))
with open(os.path.join(self.checkpoint_path, "model.episodes"), 'w+') as f: with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f:
f.write(str(episode_count) + "\n") f.write(str(episode_count) + "\n")
def restore_model(self): def restore_model(self):