add "--force-creation" flag to force model creation

This commit is contained in:
Christoffer Müller Madsen 2018-04-26 11:43:19 +02:00
parent 48a5f6cbb6
commit 9428a00c11
2 changed files with 7 additions and 5 deletions

View File

@ -31,11 +31,13 @@ parser.add_argument('--train-perpetually', action='store_true',
help='start new training session as soon as the previous is finished') help='start new training session as soon as the previous is finished')
parser.add_argument('--list-models', action='store_true', parser.add_argument('--list-models', action='store_true',
help='list all known models') help='list all known models')
parser.add_argument('--force-creation', action='store_true',
help='force model creation if model does not exist')
args = parser.parse_args() args = parser.parse_args()
if args.model == "baseline": if args.model == "baseline_model":
print("Model name 'baseline' not allowed") print("Model name 'baseline_model' not allowed")
exit() exit()
config = { config = {
@ -52,6 +54,7 @@ config = {
'model_storage_path': 'models', 'model_storage_path': 'models',
'bench_storage_path': 'bench', 'bench_storage_path': 'bench',
'board_representation': 'quack' 'board_representation': 'quack'
'force_creation': args.force_creation
} }
# Create models folder # Create models folder
@ -67,7 +70,6 @@ if not os.path.isdir(model_path()):
if not os.path.isdir(log_path): if not os.path.isdir(log_path):
os.mkdir(log_path) os.mkdir(log_path)
# Define helper functions # Define helper functions
def log_train_outcome(outcome, trained_eps = 0, log_path = os.path.join(model_path(), 'logs', "train.log")): def log_train_outcome(outcome, trained_eps = 0, log_path = os.path.join(model_path(), 'logs', "train.log")):

View File

@ -152,7 +152,7 @@ class Network:
print("Variable: ", k) print("Variable: ", k)
print("Shape: ", v.shape) print("Shape: ", v.shape)
print(v) print(v)
else: elif not self.config['force_creation']:
print("You need to have baseline_model inside models") print("You need to have baseline_model inside models")
exit() exit()
@ -454,4 +454,4 @@ class Network:
writer.close() writer.close()
return outcomes return outcomes