train-evaluate-save
This commit is contained in:
parent
bad870c27a
commit
b7708b3675
|
@ -1,30 +1,30 @@
|
||||||
#!/usr/bin/env ruby
|
#!/usr/bin/env ruby
|
||||||
|
MODELS_DIR = 'models'
|
||||||
|
|
||||||
def save(model_name)
|
def save(model_name)
|
||||||
require 'date'
|
require 'date'
|
||||||
|
|
||||||
models_dir = 'models'
|
model_path = File.join(MODELS_DIR, model_name)
|
||||||
model_path = File.join(models_dir, model_name)
|
|
||||||
if not File.exists? model_path then
|
|
||||||
return false
|
|
||||||
end
|
|
||||||
|
|
||||||
episode_count = (File.read File.join(model_path, 'episodes_trained')).to_i
|
episode_count = (File.read File.join(model_path, 'episodes_trained')).to_i
|
||||||
|
|
||||||
puts "Found model #{model_name} with episodes #{episode_count} trained!"
|
puts "Found model #{model_name} with episodes #{episode_count} trained!"
|
||||||
|
|
||||||
file_name = "model-#{model_name}-#{episode_count}-#{Time.now.strftime('%Y%m%d-%H%M%S')}.tar.gz"
|
file_name = "model-#{model_name}-#{episode_count}-#{Time.now.strftime('%Y%m%d-%H%M%S')}.tar.gz"
|
||||||
save_path = File.join(models_dir, 'saves', file_name)
|
save_path = File.join(MODELS_DIR, 'saves', file_name)
|
||||||
puts "Saving to #{save_path}"
|
puts "Saving to #{save_path}"
|
||||||
|
|
||||||
system("tar", "-cvzf", save_path, "-C", models_dir, model_name)
|
system("tar", "-cvzf", save_path, "-C", MODELS_DIR, model_name)
|
||||||
|
|
||||||
return true
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def train(model, episodes)
|
def train(model, episodes)
|
||||||
system("python3", "main.py", "--train", "--model", model, "--episodes", episodes.to_s)
|
system("python3", "main.py", "--train", "--model", model, "--episodes", episodes.to_s)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def force_train(model, episodes)
|
||||||
|
system("python3", "main.py", "--train", "--force-creation", "--model", model, "--episodes", episodes.to_s)
|
||||||
|
end
|
||||||
|
|
||||||
def evaluate(model, episodes, method)
|
def evaluate(model, episodes, method)
|
||||||
system("python3", "main.py", "--eval" , "--model", model, "--episodes", episodes.to_s, "--eval-methods", method)
|
system("python3", "main.py", "--eval" , "--model", model, "--episodes", episodes.to_s, "--eval-methods", method)
|
||||||
end
|
end
|
||||||
|
@ -33,11 +33,9 @@ model = ARGV[0]
|
||||||
|
|
||||||
if model.nil? then raise "no model specified" end
|
if model.nil? then raise "no model specified" end
|
||||||
|
|
||||||
while true do
|
if not File.exists? File.join(MODELS_DIR, model) then
|
||||||
|
force_train model, 10
|
||||||
save model
|
save model
|
||||||
train model, 1000
|
|
||||||
save model
|
|
||||||
train model, 1000
|
|
||||||
3.times do
|
3.times do
|
||||||
evaluate model, 250, "pubeval"
|
evaluate model, 250, "pubeval"
|
||||||
end
|
end
|
||||||
|
@ -45,3 +43,27 @@ while true do
|
||||||
evaluate model, 250, "dumbeval"
|
evaluate model, 250, "dumbeval"
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# while true do
|
||||||
|
# save model
|
||||||
|
# train model, 1000
|
||||||
|
# save model
|
||||||
|
# train model, 1000
|
||||||
|
# 3.times do
|
||||||
|
# evaluate model, 250, "pubeval"
|
||||||
|
# end
|
||||||
|
# 3.times do
|
||||||
|
# evaluate model, 250, "dumbeval"
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
|
while true do
|
||||||
|
save model
|
||||||
|
train model, 500
|
||||||
|
5.times do
|
||||||
|
evaluate model, 250, "pubeval"
|
||||||
|
end
|
||||||
|
5.times do
|
||||||
|
evaluate model, 250, "dumbeval"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
Loading…
Reference in New Issue
Block a user