diff --git a/bin/train-evaluate-save b/bin/train-evaluate-save new file mode 100755 index 0000000..00b6411 --- /dev/null +++ b/bin/train-evaluate-save @@ -0,0 +1,47 @@ +#!/usr/bin/env ruby +def save(model_name) + require 'date' + + models_dir = 'models' + model_path = File.join(models_dir, model_name) + if not File.exists? model_path then + return false + end + + episode_count = (File.read File.join(model_path, 'episodes_trained')).to_i + + puts "Found model #{model_name} with episodes #{episode_count} trained!" + + file_name = "model-#{model_name}-#{episode_count}-#{Time.now.strftime('%Y%m%d-%H%M%S')}.tar.gz" + save_path = File.join(models_dir, 'saves', file_name) + puts "Saving to #{save_path}" + + system("tar", "-cvzf", save_path, "-C", models_dir, model_name) + + return true +end + +def train(model, episodes) + system("python3", "main.py", "--train", "--model", model, "--episodes", episodes.to_s) +end + +def evaluate(model, episodes, method) + system("python3", "main.py", "--eval" , "--model", model, "--episodes", episodes.to_s, "--eval-methods", method) +end + +model = ARGV[0] + +if model.nil? then raise "no model specified" end + +while true do + save model + train model, 1000 + save model + train model, 1000 + 3.times do + evaluate model, 250, "pubeval" + end + 3.times do + evaluate model, 250, "dumbeval" + end +end