renaming parameters
This commit is contained in:
parent
90b97da4ff
commit
55898d0e66
|
@ -56,7 +56,7 @@ The following examples describe commmon operations.
|
||||||
|
|
||||||
*** Evaluate model named =quack= using default evaluation method (currently =random=)
|
*** Evaluate model named =quack= using default evaluation method (currently =random=)
|
||||||
|
|
||||||
=python3 --eval --model-name=quack=
|
=python3 --eval --model=quack=
|
||||||
|
|
||||||
*** Evaluate default model using evaluation methods =random= and =pubeval=
|
*** Evaluate default model using evaluation methods =random= and =pubeval=
|
||||||
|
|
||||||
|
@ -73,14 +73,14 @@ directory. Otherwise, the model is stored in =models/$MODEL=.
|
||||||
Along with the Tensorflow checkpoint files in the directory, the following files
|
Along with the Tensorflow checkpoint files in the directory, the following files
|
||||||
are stored:
|
are stored:
|
||||||
|
|
||||||
- =model.episodes=: The number of episodes of training performed with the
|
- =episodes_trained=: The number of episodes of training performed with the
|
||||||
model
|
model
|
||||||
- =logs/eval.log=: Log of all completed evaluations performed on the model. The
|
- =logs/eval.log=: Log of all completed evaluations performed on the model. The
|
||||||
format of this file is specified in [[Log format]].
|
format of this file is specified in [[Log format]].
|
||||||
- =logs/train.log=: Log of all completed training sessions performed on the
|
- =logs/train.log=: Log of all completed training sessions performed on the
|
||||||
model. If a training session is aborted before the pre-specified episode
|
model. If a training session is aborted before the pre-specified episode
|
||||||
target is reached, nothing will be written to this file, although
|
target is reached, nothing will be written to this file, although
|
||||||
=model.episodes= will be updated every time the model is saved to disk. The
|
=episodes_trained= will be updated every time the model is saved to disk. The
|
||||||
format of this file is specified in [[Log format]].
|
format of this file is specified in [[Log format]].
|
||||||
|
|
||||||
** Log format
|
** Log format
|
||||||
|
|
14
main.py
14
main.py
|
@ -3,6 +3,12 @@ import sys
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
models_storage_path = 'models'
|
||||||
|
|
||||||
|
# Create models folder
|
||||||
|
if not os.path.exists(models_storage_path):
|
||||||
|
os.makedirs(models_storage_path)
|
||||||
|
|
||||||
# Define helper functions
|
# Define helper functions
|
||||||
def log_train_outcome(outcome, trained_eps = 0):
|
def log_train_outcome(outcome, trained_eps = 0):
|
||||||
format_vars = { 'trained_eps': trained_eps,
|
format_vars = { 'trained_eps': trained_eps,
|
||||||
|
@ -34,9 +40,9 @@ parser = argparse.ArgumentParser(description="Backgammon games")
|
||||||
parser.add_argument('--episodes', action='store', dest='episode_count',
|
parser.add_argument('--episodes', action='store', dest='episode_count',
|
||||||
type=int, default=1000,
|
type=int, default=1000,
|
||||||
help='number of episodes to train')
|
help='number of episodes to train')
|
||||||
parser.add_argument('--model-path', action='store', dest='model_path',
|
parser.add_argument('--model', action='store', dest='model',
|
||||||
default='./model',
|
default='default',
|
||||||
help='path to Tensorflow model')
|
help='name of Tensorflow model to use')
|
||||||
parser.add_argument('--eval-methods', action='store',
|
parser.add_argument('--eval-methods', action='store',
|
||||||
default=['random'], nargs='*',
|
default=['random'], nargs='*',
|
||||||
help='specifies evaluation methods')
|
help='specifies evaluation methods')
|
||||||
|
@ -55,7 +61,7 @@ parser.add_argument('--start-episode', action='store', dest='start_episode',
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'model_path': args.model_path,
|
'model_path': os.path.join(models_storage_path, args.model),
|
||||||
'episode_count': args.episode_count,
|
'episode_count': args.episode_count,
|
||||||
'eval_methods': args.eval_methods,
|
'eval_methods': args.eval_methods,
|
||||||
'train': args.train,
|
'train': args.train,
|
||||||
|
|
|
@ -26,7 +26,7 @@ class Network:
|
||||||
|
|
||||||
|
|
||||||
# Restore trained episode count for model
|
# Restore trained episode count for model
|
||||||
episode_count_path = os.path.join(self.checkpoint_path, "model.episodes")
|
episode_count_path = os.path.join(self.checkpoint_path, "episodes_trained")
|
||||||
if os.path.isfile(episode_count_path):
|
if os.path.isfile(episode_count_path):
|
||||||
with open(episode_count_path, 'r') as f:
|
with open(episode_count_path, 'r') as f:
|
||||||
self.config['start_episode'] = int(f.read())
|
self.config['start_episode'] = int(f.read())
|
||||||
|
@ -91,7 +91,7 @@ class Network:
|
||||||
|
|
||||||
def save_model(self, episode_count):
|
def save_model(self, episode_count):
|
||||||
self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt'))
|
self.saver.save(self.session, os.path.join(self.checkpoint_path, 'model.ckpt'))
|
||||||
with open(os.path.join(self.checkpoint_path, "model.episodes"), 'w+') as f:
|
with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f:
|
||||||
f.write(str(episode_count) + "\n")
|
f.write(str(episode_count) + "\n")
|
||||||
|
|
||||||
def restore_model(self):
|
def restore_model(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user