renaming parameters
This commit is contained in:
parent
90b97da4ff
commit
55898d0e66
|
@ -56,7 +56,7 @@ The following examples describe commmon operations.
|
|||
|
||||
*** Evaluate model named =quack= using default evaluation method (currently =random=)
|
||||
|
||||
=python3 --eval --model-name=quack=
|
||||
=python3 --eval --model=quack=
|
||||
|
||||
*** Evaluate default model using evaluation methods =random= and =pubeval=
|
||||
|
||||
|
@ -73,14 +73,14 @@ directory. Otherwise, the model is stored in =models/$MODEL=.
|
|||
Along with the Tensorflow checkpoint files in the directory, the following files
|
||||
are stored:
|
||||
|
||||
- =model.episodes=: The number of episodes of training performed with the
|
||||
- =episodes_trained=: The number of episodes of training performed with the
|
||||
model
|
||||
- =logs/eval.log=: Log of all completed evaluations performed on the model. The
|
||||
format of this file is specified in [[Log format]].
|
||||
- =logs/train.log=: Log of all completed training sessions performed on the
|
||||
model. If a training session is aborted before the pre-specified episode
|
||||
target is reached, nothing will be written to this file, although
|
||||
=model.episodes= will be updated every time the model is saved to disk. The
|
||||
=episodes_trained= will be updated every time the model is saved to disk. The
|
||||
format of this file is specified in [[Log format]].
|
||||
|
||||
** Log format
|
||||
|
|
14
main.py
14
main.py
|
@ -3,6 +3,12 @@ import sys
|
|||
import os
|
||||
import time
|
||||
|
||||
models_storage_path = 'models'
|
||||
|
||||
# Create models folder
|
||||
if not os.path.exists(models_storage_path):
|
||||
os.makedirs(models_storage_path)
|
||||
|
||||
# Define helper functions
|
||||
def log_train_outcome(outcome, trained_eps = 0):
|
||||
format_vars = { 'trained_eps': trained_eps,
|
||||
|
@ -34,9 +40,9 @@ parser = argparse.ArgumentParser(description="Backgammon games")
|
|||
parser.add_argument('--episodes', action='store', dest='episode_count',
|
||||
type=int, default=1000,
|
||||
help='number of episodes to train')
|
||||
parser.add_argument('--model-path', action='store', dest='model_path',
|
||||
default='./model',
|
||||
help='path to Tensorflow model')
|
||||
parser.add_argument('--model', action='store', dest='model',
|
||||
default='default',
|
||||
help='name of Tensorflow model to use')
|
||||
parser.add_argument('--eval-methods', action='store',
|
||||
default=['random'], nargs='*',
|
||||
help='specifies evaluation methods')
|
||||
|
@ -55,7 +61,7 @@ parser.add_argument('--start-episode', action='store', dest='start_episode',
|
|||
args = parser.parse_args()
|
||||
|
||||
config = {
|
||||
'model_path': args.model_path,
|
||||
'model_path': os.path.join(models_storage_path, args.model),
|
||||
'episode_count': args.episode_count,
|
||||
'eval_methods': args.eval_methods,
|
||||
'train': args.train,
|
||||
|
|
|
@ -26,7 +26,7 @@ class Network:
|
|||
|
||||
|
||||
# Restore trained episode count for model
|
||||
episode_count_path = os.path.join(self.checkpoint_path, "model.episodes")
|
||||
episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained")
|
||||
if os.path.isfile(episode_count_path):
|
||||
with open(episode_count_path, 'r') as f:
|
||||
self.config['start_episode'] = int(f.read())
|
||||
|
@ -91,7 +91,7 @@ class Network:
|
|||
|
||||
def save_model(self, episode_count):
|
||||
self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt'))
|
||||
with open(os.path.join(self.checkpoint_path, "model.episodes"), 'w+') as f:
|
||||
with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f:
|
||||
f.write(str(episode_count) + "\n")
|
||||
|
||||
def restore_model(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user