train and eval now outputs proper number of training episodes to log
This commit is contained in:
parent
bd459ba0ad
commit
fc88c64452
10
game.py
10
game.py
|
@ -61,11 +61,11 @@ class Game:
|
|||
self.board = p2.make_move(self.board, p2.get_sym(), roll)
|
||||
|
||||
|
||||
def train_model(self, episodes=1000, save_step_size = 100, init_ep = 0):
|
||||
def train_model(self, episodes=1000, save_step_size = 100, trained_eps = 0):
|
||||
sys.stderr.write("[TRAIN] Training {} episodes and save_step_size {}\n".format(episodes, save_step_size))
|
||||
outcomes = []
|
||||
for episode in range(episodes):
|
||||
sys.stderr.write("[TRAIN] Episode {}".format(episode + init_ep))
|
||||
sys.stderr.write("[TRAIN] Episode {}".format(episode + trained_eps))
|
||||
self.board = Board.initial_state
|
||||
|
||||
prev_board, prev_board_value = self.roll_and_find_best_for_bot()
|
||||
|
@ -109,8 +109,8 @@ class Game:
|
|||
print(self.board)
|
||||
print("--------------------------------")
|
||||
|
||||
def eval(self, init_ep = 0):
|
||||
def do_eval(method, episodes = 1000, init_ep = 0):
|
||||
def eval(self, trained_eps = 0):
|
||||
def do_eval(method, episodes = 1000, trained_eps = 0):
|
||||
sys.stderr.write("[EVAL ] Evaluating {eps} episode(s) with method '{method}'\n".format(eps=episodes, method=method))
|
||||
if method == 'random':
|
||||
outcomes = []
|
||||
|
@ -132,7 +132,7 @@ class Game:
|
|||
|
||||
return [ (method, do_eval(method,
|
||||
self.config['episode_count'],
|
||||
init_ep = init_ep))
|
||||
trained_eps = trained_eps))
|
||||
for method
|
||||
in self.config['eval_methods'] ]
|
||||
|
||||
|
|
40
main.py
40
main.py
|
@ -1,23 +1,27 @@
|
|||
import argparse
|
||||
import sys
|
||||
import time
|
||||
|
||||
def print_train_outcome(outcome, init_ep = 0):
|
||||
format_vars = { 'init_ep': init_ep,
|
||||
def print_train_outcome(outcome, trained_eps = 0):
|
||||
format_vars = { 'trained_eps': trained_eps,
|
||||
'count': len(train_outcome),
|
||||
'sum': sum(train_outcome),
|
||||
'mean': sum(train_outcome) / len(train_outcome)}
|
||||
print("train;{init_ep};{count};{sum};{mean}".format(**format_vars))
|
||||
'mean': sum(train_outcome) / len(train_outcome),
|
||||
'time': int(time.time())
|
||||
}
|
||||
print("train;{time};{trained_eps};{count};{sum};{mean}".format(**format_vars))
|
||||
|
||||
def print_eval_outcomes(outcomes, init_ep = 0):
|
||||
def print_eval_outcomes(outcomes, trained_eps = 0):
|
||||
for outcome in outcomes:
|
||||
scores = outcome[1]
|
||||
format_vars = { 'init_ep': init_ep,
|
||||
format_vars = { 'trained_eps': trained_eps,
|
||||
'method': outcome[0],
|
||||
'count': len(scores),
|
||||
'sum': sum(scores),
|
||||
'mean': sum(scores) / len(scores)
|
||||
'mean': sum(scores) / len(scores),
|
||||
'time': int(time.time())
|
||||
}
|
||||
print("eval;{method};{init_ep};{count};{sum};{mean}".format(**format_vars))
|
||||
print("eval;{time};{method};{trained_eps};{count};{sum};{mean}".format(**format_vars))
|
||||
|
||||
parser = argparse.ArgumentParser(description="Backgammon games")
|
||||
parser.add_argument('--episodes', action='store', dest='episode_count',
|
||||
|
@ -61,24 +65,16 @@ episode_count = args.episode_count
|
|||
if args.train:
|
||||
eps = 0
|
||||
while True:
|
||||
train_outcome = g.train_model(episodes = episode_count, init_ep = eps)
|
||||
print_train_outcome(train_outcome, init_ep = eps)
|
||||
if args.eval:
|
||||
eval_outcomes = g.eval(init_ep = eps)
|
||||
print_eval_outcomes(eval_outcomes, init_ep = eps)
|
||||
train_outcome = g.train_model(episodes = episode_count, trained_eps = eps)
|
||||
eps += episode_count
|
||||
print_train_outcome(train_outcome, trained_eps = eps)
|
||||
if args.eval:
|
||||
eval_outcomes = g.eval(trained_eps = eps)
|
||||
print_eval_outcomes(eval_outcomes, trained_eps = eps)
|
||||
sys.stdout.flush()
|
||||
elif args.eval:
|
||||
outcomes = g.eval()
|
||||
print_eval_outcomes(outcomes, init_ep = 0)
|
||||
print_eval_outcomes(outcomes, trained_eps = 0)
|
||||
#elif args.play:
|
||||
# g.play(episodes = episode_count)
|
||||
|
||||
#outcomes = g.play(2000)
|
||||
#print(outcomes)
|
||||
#print(sum(outcomes))
|
||||
#count = g.play()
|
||||
# highest = max(highest,count)
|
||||
# except KeyboardInterrupt:
|
||||
# break
|
||||
#print("\nHighest amount of turns is:",highest)
|
||||
|
|
Loading…
Reference in New Issue
Block a user