clean up
This commit is contained in:
parent
554e587ffd
commit
81f8db35f4
|
@ -42,6 +42,10 @@ The following examples describe commmon operations.
|
|||
|
||||
=python3 --train=
|
||||
|
||||
*** Train perpetually
|
||||
|
||||
=python3 --train --train-perpetually=
|
||||
|
||||
*** Train model named =quack=
|
||||
|
||||
=python3 --train --model=quack=
|
||||
|
|
15
game.py
15
game.py
|
@ -81,9 +81,20 @@ class Game:
|
|||
|
||||
|
||||
def train_model(self, episodes=1000, save_step_size = 100, trained_eps = 0):
|
||||
start_time = time.time()
|
||||
def print_time_estimate(eps_completed):
|
||||
cur_time = time.time()
|
||||
time_diff = cur_time - start_time
|
||||
eps_per_sec = eps_completed / time_diff
|
||||
secs_per_ep = time_diff / eps_completed
|
||||
eps_remaining = (episodes - eps_completed)
|
||||
sys.stderr.write("[TRAIN] Averaging {per_sec} episodes per second\n".format(per_sec = round(eps_per_sec, 2)))
|
||||
sys.stderr.write("[TRAIN] {eps_remaining} episodes remaining; approx. {time_remaining} seconds remaining\n".format(eps_remaining = eps_remaining, time_remaining = int(eps_remaining * secs_per_ep)))
|
||||
|
||||
|
||||
sys.stderr.write("[TRAIN] Training {} episodes and save_step_size {}\n".format(episodes, save_step_size))
|
||||
outcomes = []
|
||||
for episode in range(episodes):
|
||||
for episode in range(1, episodes + 1):
|
||||
sys.stderr.write("[TRAIN] Episode {}".format(episode + trained_eps))
|
||||
self.board = Board.initial_state
|
||||
|
||||
|
@ -114,6 +125,8 @@ class Game:
|
|||
sys.stderr.write("[TRAIN] Loading model for training opponent...\n")
|
||||
self.p2.restore_model()
|
||||
|
||||
if episode % 50 == 0:
|
||||
print_time_estimate(episode)
|
||||
|
||||
sys.stderr.write("[TRAIN] Saving model for final episode...\n")
|
||||
self.p1.get_network().save_model(episode+trained_eps)
|
||||
|
|
32
main.py
32
main.py
|
@ -3,11 +3,11 @@ import sys
|
|||
import os
|
||||
import time
|
||||
|
||||
models_storage_path = 'models'
|
||||
model_storage_path = 'models'
|
||||
|
||||
# Create models folder
|
||||
if not os.path.exists(models_storage_path):
|
||||
os.makedirs(models_storage_path)
|
||||
if not os.path.exists(model_storage_path):
|
||||
os.makedirs(model_storage_path)
|
||||
|
||||
# Define helper functions
|
||||
def log_train_outcome(outcome, trained_eps = 0):
|
||||
|
@ -57,18 +57,23 @@ parser.add_argument('--play', action='store_true',
|
|||
parser.add_argument('--start-episode', action='store', dest='start_episode',
|
||||
type=int, default=0,
|
||||
help='episode count to start at; purely for display purposes')
|
||||
parser.add_argument('--train-perpetually', action='store_true',
|
||||
help='start new training session as soon as the previous is finished')
|
||||
parser.add_argument('--list-models', action='store_true',
|
||||
help='list all known models')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = {
|
||||
'model_path': os.path.join(models_storage_path, args.model),
|
||||
'model_path': os.path.join(model_storage_path, args.model),
|
||||
'episode_count': args.episode_count,
|
||||
'eval_methods': args.eval_methods,
|
||||
'train': args.train,
|
||||
'play': args.play,
|
||||
'eval': args.eval,
|
||||
'eval_after_train': args.eval_after_train,
|
||||
'start_episode': args.start_episode
|
||||
'start_episode': args.start_episode,
|
||||
'train_perpetually': args.train_perpetually
|
||||
}
|
||||
|
||||
# Make sure directories exist
|
||||
|
@ -91,7 +96,20 @@ episode_count = config['episode_count']
|
|||
|
||||
|
||||
# Do actions specified by command-line
|
||||
if args.train:
|
||||
if args.list_models:
|
||||
def get_eps_trained(folder):
|
||||
with open(os.path.join(folder, 'episodes_trained'), 'r') as f:
|
||||
return int(f.read())
|
||||
model_folders = [ f.path
|
||||
for f
|
||||
in os.scandir(model_storage_path)
|
||||
if f.is_dir() ]
|
||||
models = [ (folder, get_eps_trained(folder)) for folder in model_folders ]
|
||||
sys.stderr.write("Found {} model(s)\n".format(len(models)))
|
||||
for model in models:
|
||||
sys.stderr.write(" {name}: {eps_trained}\n".format(name = model[0], eps_trained = model[1]))
|
||||
|
||||
elif args.train:
|
||||
eps = config['start_episode']
|
||||
while True:
|
||||
train_outcome = g.train_model(episodes = episode_count, trained_eps = eps)
|
||||
|
@ -100,6 +118,8 @@ if args.train:
|
|||
if config['eval_after_train']:
|
||||
eval_outcomes = g.eval(trained_eps = eps)
|
||||
log_eval_outcomes(eval_outcomes, trained_eps = eps)
|
||||
if not config['train_perpetually']:
|
||||
break
|
||||
elif args.eval:
|
||||
eps = config['start_episode']
|
||||
outcomes = g.eval()
|
||||
|
|
3
plot.py
3
plot.py
|
@ -44,8 +44,7 @@ if __name__ == '__main__':
|
|||
plt.show()
|
||||
|
||||
while True:
|
||||
df = pd.read_csv('models/c/logs/eval.log', sep=";", names=eval_headers)
|
||||
df['timestamp'] = df['timestamp'].map(lambda t: datetime.datetime.fromtimestamp(t))
|
||||
df = dataframes('default')['eval']
|
||||
|
||||
print(df)
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ grpcio==1.10.0
|
|||
html5lib==0.9999999
|
||||
Markdown==2.6.11
|
||||
numpy==1.14.1
|
||||
pkg-resources==0.0.0
|
||||
protobuf==3.5.1
|
||||
six==1.11.0
|
||||
tensorboard==1.6.0
|
||||
|
|
Loading…
Reference in New Issue
Block a user