diff --git a/main.py b/main.py index 513f076..4df4f1c 100644 --- a/main.py +++ b/main.py @@ -31,11 +31,13 @@ parser.add_argument('--train-perpetually', action='store_true', help='start new training session as soon as the previous is finished') parser.add_argument('--list-models', action='store_true', help='list all known models') +parser.add_argument('--force-creation', action='store_true', + help='force model creation if model does not exist') args = parser.parse_args() -if args.model == "baseline": - print("Model name 'baseline' not allowed") +if args.model == "baseline_model": + print("Model name 'baseline_model' not allowed") exit() config = { @@ -52,6 +54,7 @@ config = { 'model_storage_path': 'models', 'bench_storage_path': 'bench', 'board_representation': 'quack' + 'force_creation': args.force_creation } # Create models folder @@ -67,7 +70,6 @@ if not os.path.isdir(model_path()): if not os.path.isdir(log_path): os.mkdir(log_path) - # Define helper functions def log_train_outcome(outcome, trained_eps = 0, log_path = os.path.join(model_path(), 'logs', "train.log")): diff --git a/network.py b/network.py index faed87a..82cd095 100644 --- a/network.py +++ b/network.py @@ -152,7 +152,7 @@ class Network: print("Variable: ", k) print("Shape: ", v.shape) print(v) - else: + elif not self.config['force_creation']: print("You need to have baseline_model inside models") exit() @@ -454,4 +454,4 @@ class Network: writer.close() - return outcomes \ No newline at end of file + return outcomes