clean up
This commit is contained in:
parent
554e587ffd
commit
81f8db35f4
|
@ -42,6 +42,10 @@ The following examples describe commmon operations.
|
||||||
|
|
||||||
=python3 --train=
|
=python3 --train=
|
||||||
|
|
||||||
|
*** Train perpetually
|
||||||
|
|
||||||
|
=python3 --train --train-perpetually=
|
||||||
|
|
||||||
*** Train model named =quack=
|
*** Train model named =quack=
|
||||||
|
|
||||||
=python3 --train --model=quack=
|
=python3 --train --model=quack=
|
||||||
|
|
15
game.py
15
game.py
|
@ -81,9 +81,20 @@ class Game:
|
||||||
|
|
||||||
|
|
||||||
def train_model(self, episodes=1000, save_step_size = 100, trained_eps = 0):
|
def train_model(self, episodes=1000, save_step_size = 100, trained_eps = 0):
|
||||||
|
start_time = time.time()
|
||||||
|
def print_time_estimate(eps_completed):
|
||||||
|
cur_time = time.time()
|
||||||
|
time_diff = cur_time - start_time
|
||||||
|
eps_per_sec = eps_completed / time_diff
|
||||||
|
secs_per_ep = time_diff / eps_completed
|
||||||
|
eps_remaining = (episodes - eps_completed)
|
||||||
|
sys.stderr.write("[TRAIN] Averaging {per_sec} episodes per second\n".format(per_sec = round(eps_per_sec, 2)))
|
||||||
|
sys.stderr.write("[TRAIN] {eps_remaining} episodes remaining; approx. {time_remaining} seconds remaining\n".format(eps_remaining = eps_remaining, time_remaining = int(eps_remaining * secs_per_ep)))
|
||||||
|
|
||||||
|
|
||||||
sys.stderr.write("[TRAIN] Training {} episodes and save_step_size {}\n".format(episodes, save_step_size))
|
sys.stderr.write("[TRAIN] Training {} episodes and save_step_size {}\n".format(episodes, save_step_size))
|
||||||
outcomes = []
|
outcomes = []
|
||||||
for episode in range(episodes):
|
for episode in range(1, episodes + 1):
|
||||||
sys.stderr.write("[TRAIN] Episode {}".format(episode + trained_eps))
|
sys.stderr.write("[TRAIN] Episode {}".format(episode + trained_eps))
|
||||||
self.board = Board.initial_state
|
self.board = Board.initial_state
|
||||||
|
|
||||||
|
@ -114,6 +125,8 @@ class Game:
|
||||||
sys.stderr.write("[TRAIN] Loading model for training opponent...\n")
|
sys.stderr.write("[TRAIN] Loading model for training opponent...\n")
|
||||||
self.p2.restore_model()
|
self.p2.restore_model()
|
||||||
|
|
||||||
|
if episode % 50 == 0:
|
||||||
|
print_time_estimate(episode)
|
||||||
|
|
||||||
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
||||||
self.p1.get_network().save_model(episode+trained_eps)
|
self.p1.get_network().save_model(episode+trained_eps)
|
||||||
|
|
32
main.py
32
main.py
|
@ -3,11 +3,11 @@ import sys
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
models_storage_path = 'models'
|
model_storage_path = 'models'
|
||||||
|
|
||||||
# Create models folder
|
# Create models folder
|
||||||
if not os.path.exists(models_storage_path):
|
if not os.path.exists(model_storage_path):
|
||||||
os.makedirs(models_storage_path)
|
os.makedirs(model_storage_path)
|
||||||
|
|
||||||
# Define helper functions
|
# Define helper functions
|
||||||
def log_train_outcome(outcome, trained_eps = 0):
|
def log_train_outcome(outcome, trained_eps = 0):
|
||||||
|
@ -57,18 +57,23 @@ parser.add_argument('--play', action='store_true',
|
||||||
parser.add_argument('--start-episode', action='store', dest='start_episode',
|
parser.add_argument('--start-episode', action='store', dest='start_episode',
|
||||||
type=int, default=0,
|
type=int, default=0,
|
||||||
help='episode count to start at; purely for display purposes')
|
help='episode count to start at; purely for display purposes')
|
||||||
|
parser.add_argument('--train-perpetually', action='store_true',
|
||||||
|
help='start new training session as soon as the previous is finished')
|
||||||
|
parser.add_argument('--list-models', action='store_true',
|
||||||
|
help='list all known models')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'model_path': os.path.join(models_storage_path, args.model),
|
'model_path': os.path.join(model_storage_path, args.model),
|
||||||
'episode_count': args.episode_count,
|
'episode_count': args.episode_count,
|
||||||
'eval_methods': args.eval_methods,
|
'eval_methods': args.eval_methods,
|
||||||
'train': args.train,
|
'train': args.train,
|
||||||
'play': args.play,
|
'play': args.play,
|
||||||
'eval': args.eval,
|
'eval': args.eval,
|
||||||
'eval_after_train': args.eval_after_train,
|
'eval_after_train': args.eval_after_train,
|
||||||
'start_episode': args.start_episode
|
'start_episode': args.start_episode,
|
||||||
|
'train_perpetually': args.train_perpetually
|
||||||
}
|
}
|
||||||
|
|
||||||
# Make sure directories exist
|
# Make sure directories exist
|
||||||
|
@ -91,7 +96,20 @@ episode_count = config['episode_count']
|
||||||
|
|
||||||
|
|
||||||
# Do actions specified by command-line
|
# Do actions specified by command-line
|
||||||
if args.train:
|
if args.list_models:
|
||||||
|
def get_eps_trained(folder):
|
||||||
|
with open(os.path.join(folder, 'episodes_trained'), 'r') as f:
|
||||||
|
return int(f.read())
|
||||||
|
model_folders = [ f.path
|
||||||
|
for f
|
||||||
|
in os.scandir(model_storage_path)
|
||||||
|
if f.is_dir() ]
|
||||||
|
models = [ (folder, get_eps_trained(folder)) for folder in model_folders ]
|
||||||
|
sys.stderr.write("Found {} model(s)\n".format(len(models)))
|
||||||
|
for model in models:
|
||||||
|
sys.stderr.write(" {name}: {eps_trained}\n".format(name = model[0], eps_trained = model[1]))
|
||||||
|
|
||||||
|
elif args.train:
|
||||||
eps = config['start_episode']
|
eps = config['start_episode']
|
||||||
while True:
|
while True:
|
||||||
train_outcome = g.train_model(episodes = episode_count, trained_eps = eps)
|
train_outcome = g.train_model(episodes = episode_count, trained_eps = eps)
|
||||||
|
@ -100,6 +118,8 @@ if args.train:
|
||||||
if config['eval_after_train']:
|
if config['eval_after_train']:
|
||||||
eval_outcomes = g.eval(trained_eps = eps)
|
eval_outcomes = g.eval(trained_eps = eps)
|
||||||
log_eval_outcomes(eval_outcomes, trained_eps = eps)
|
log_eval_outcomes(eval_outcomes, trained_eps = eps)
|
||||||
|
if not config['train_perpetually']:
|
||||||
|
break
|
||||||
elif args.eval:
|
elif args.eval:
|
||||||
eps = config['start_episode']
|
eps = config['start_episode']
|
||||||
outcomes = g.eval()
|
outcomes = g.eval()
|
||||||
|
|
3
plot.py
3
plot.py
|
@ -44,8 +44,7 @@ if __name__ == '__main__':
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
df = pd.read_csv('models/c/logs/eval.log', sep=";", names=eval_headers)
|
df = dataframes('default')['eval']
|
||||||
df['timestamp'] = df['timestamp'].map(lambda t: datetime.datetime.fromtimestamp(t))
|
|
||||||
|
|
||||||
print(df)
|
print(df)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ grpcio==1.10.0
|
||||||
html5lib==0.9999999
|
html5lib==0.9999999
|
||||||
Markdown==2.6.11
|
Markdown==2.6.11
|
||||||
numpy==1.14.1
|
numpy==1.14.1
|
||||||
pkg-resources==0.0.0
|
|
||||||
protobuf==3.5.1
|
protobuf==3.5.1
|
||||||
six==1.11.0
|
six==1.11.0
|
||||||
tensorboard==1.6.0
|
tensorboard==1.6.0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user