fix wrongful mergings
This commit is contained in:
parent
28b82e8228
commit
2654006222
63
network.py
63
network.py
|
@ -152,8 +152,8 @@ class Network:
|
||||||
# print("Found the best state, being:", np.array(move_scores).argmax())
|
# print("Found the best state, being:", np.array(move_scores).argmax())
|
||||||
return best_move_pair
|
return best_move_pair
|
||||||
|
|
||||||
def eval(self, trained_eps=0):
|
def eval(self, episode_count, trained_eps = 0, tf_session = None):
|
||||||
def do_eval(sess, method, episodes=1000, trained_eps=trained_eps):
|
def do_eval(sess, method, episodes = 1000, trained_eps = 0):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
def print_time_estimate(eps_completed):
|
def print_time_estimate(eps_completed):
|
||||||
|
@ -291,15 +291,24 @@ class Network:
|
||||||
sys.stderr.write("[EVAL ] Evaluation method '{}' is not defined\n".format(method))
|
sys.stderr.write("[EVAL ] Evaluation method '{}' is not defined\n".format(method))
|
||||||
return [0]
|
return [0]
|
||||||
|
|
||||||
with tf.Session() as session:
|
if tf_session == None:
|
||||||
|
with tf.Session():
|
||||||
session.run(tf.global_variables_initializer())
|
session.run(tf.global_variables_initializer())
|
||||||
self.restore_model(session)
|
self.restore_model(session)
|
||||||
outcomes = [(method, do_eval(session,
|
outcomes = [ (method, do_eval(session,
|
||||||
method,
|
method,
|
||||||
self.config['episode_count'],
|
episode_count,
|
||||||
trained_eps=trained_eps))
|
trained_eps = trained_eps))
|
||||||
for method
|
for method
|
||||||
in self.config['eval_methods']]
|
in self.config['eval_methods'] ]
|
||||||
|
return outcomes
|
||||||
|
else:
|
||||||
|
outcomes = [ (method, do_eval(tf_session,
|
||||||
|
method,
|
||||||
|
episode_count,
|
||||||
|
trained_eps = trained_eps))
|
||||||
|
for method
|
||||||
|
in self.config['eval_methods'] ]
|
||||||
return outcomes
|
return outcomes
|
||||||
|
|
||||||
def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
|
def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
|
||||||
|
@ -401,43 +410,3 @@ class Network:
|
||||||
# save the current state again, so we can continue running backprop based on the "previous" turn.
|
# save the current state again, so we can continue running backprop based on the "previous" turn.
|
||||||
|
|
||||||
# NOTE: We need to make a method so that we can take a single turn or at least just pick the next best move, so we know how to evaluate according to TD-learning. Right now, our game just continues in a while loop with nothing to stop it!
|
# NOTE: We need to make a method so that we can take a single turn or at least just pick the next best move, so we know how to evaluate according to TD-learning. Right now, our game just continues in a while loop with nothing to stop it!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def eval(self, episode_count, trained_eps = 0, tf_session = None):
|
|
||||||
def do_eval(sess, method, episodes = 1000, trained_eps = 0):
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
writer.close()
|
|
||||||
|
|
||||||
return outcomes
|
|
||||||
|
|
||||||
# take turn, which finds the best state and picks it, based on the current network
|
|
||||||
# save current state
|
|
||||||
# run training operation (session.run(self.training_op, {x: x, value_next: value_next})),
|
|
||||||
# (something which does the backprop, based on the state after having taken a turn,
|
|
||||||
# found before, and the state we saved in the beginning and from now we'll
|
|
||||||
# save it at the end of the turn
|
|
||||||
|
|
||||||
# save the current state again, so we can continue running backprop based on the "previous" turn.
|
|
||||||
|
|
||||||
|
|
||||||
if tf_session == None:
|
|
||||||
with tf.Session():
|
|
||||||
session.run(tf.global_variables_initializer())
|
|
||||||
self.restore_model(session)
|
|
||||||
outcomes = [ (method, do_eval(session,
|
|
||||||
method,
|
|
||||||
episode_count,
|
|
||||||
trained_eps = trained_eps))
|
|
||||||
for method
|
|
||||||
in self.config['eval_methods'] ]
|
|
||||||
return outcomes
|
|
||||||
else:
|
|
||||||
outcomes = [ (method, do_eval(tf_session,
|
|
||||||
method,
|
|
||||||
episode_count,
|
|
||||||
trained_eps = trained_eps))
|
|
||||||
for method
|
|
||||||
in self.config['eval_methods'] ]
|
|
||||||
return outcomes
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user