From 9428a00c11bc3c65b403fed5f4f98431db95fea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoffer=20M=C3=BCller=20Madsen?= Date: Thu, 26 Apr 2018 11:43:19 +0200 Subject: [PATCH] add "--force-creation" flag to force model creation --- main.py | 8 +++++--- network.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 513f076..4df4f1c 100644 --- a/main.py +++ b/main.py @@ -31,11 +31,13 @@ parser.add_argument('--train-perpetually', action='store_true', help='start new training session as soon as the previous is finished') parser.add_argument('--list-models', action='store_true', help='list all known models') +parser.add_argument('--force-creation', action='store_true', + help='force model creation if model does not exist') args = parser.parse_args() -if args.model == "baseline": - print("Model name 'baseline' not allowed") +if args.model == "baseline_model": + print("Model name 'baseline_model' not allowed") exit() config = { @@ -52,6 +54,7 @@ config = { 'model_storage_path': 'models', 'bench_storage_path': 'bench', 'board_representation': 'quack' + 'force_creation': args.force_creation } # Create models folder @@ -67,7 +70,6 @@ if not os.path.isdir(model_path()): if not os.path.isdir(log_path): os.mkdir(log_path) - # Define helper functions def log_train_outcome(outcome, trained_eps = 0, log_path = os.path.join(model_path(), 'logs', "train.log")): diff --git a/network.py b/network.py index faed87a..82cd095 100644 --- a/network.py +++ b/network.py @@ -152,7 +152,7 @@ class Network: print("Variable: ", k) print("Shape: ", v.shape) print(v) - else: + elif not self.config['force_creation']: print("You need to have baseline_model inside models") exit() @@ -454,4 +454,4 @@ class Network: writer.close() - return outcomes \ No newline at end of file + return outcomes