Added a verbosity flag, --verbose, which enables printing of model variables and per-move board values during training.
parent 6429e0732c
commit 9cfdd7e2b2
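Assuming main.py is invoked directly as the training entry point (the argument parser below suggests it is), the new switch is used like any other store_true flag, e.g.:

    python main.py --use-baseline --verbose

With --verbose set, the network prints its model variables after saving and restoring checkpoints, and dumps per-move board values during training, as the diffs below show.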
main.py (6 changed lines)

@@ -38,6 +38,8 @@ parser.add_argument('--board-rep', action='store', dest='board_rep',
                     help='name of board representation to use as input to neural network')
 parser.add_argument('--use-baseline', action='store_true',
                     help='use the baseline model, note, has size 28')
+parser.add_argument('--verbose', action='store_true',
+                    help='If set, a lot of stuff will be printed')
 
 args = parser.parse_args()
 
@@ -61,7 +63,9 @@ config = {
     'board_representation': args.board_rep,
     'force_creation': args.force_creation,
     'use_baseline': args.use_baseline,
-    'global_step': 0
+    'global_step': 0,
+    'verbose': args.verbose
+
 }
 
 # Create models folder
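A side note on the flag itself: argparse's action='store_true' defaults to False, so config['verbose'] is always a plain bool and existing invocations of main.py keep their old behavior. A tiny standalone sketch (not the project's actual parser, which has more options) illustrating this:

import argparse

# Minimal sketch: store_true defaults to False when the flag is absent,
# so the config entry is always a bool.
parser = argparse.ArgumentParser()
parser.add_argument('--verbose', action='store_true',
                    help='If set, a lot of stuff will be printed')

print(parser.parse_args([]).verbose)             # False
print(parser.parse_args(['--verbose']).verbose)  # True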
network.py (28 changed lines)

@@ -98,10 +98,20 @@ class Network:
 
 
 
+    def print_variables(self):
+        variables = self.model.variables
+
+        for k in variables:
+            print(k)
+
     def eval_state(self, state):
         return self.model(state.reshape(1,-1))
 
     def save_model(self, episode_count):
+        """
+        :param episode_count:
+        :return:
+        """
         tfe.Saver(self.model.variables).save(os.path.join(self.checkpoint_path, 'model.ckpt'))
         #self.saver.save(sess, os.path.join(self.checkpoint_path, 'model.ckpt'), global_step=global_step)
         with open(os.path.join(self.checkpoint_path, "episodes_trained"), 'w+') as f:
@@ -113,6 +123,8 @@ class Network:
             print("[NETWK] ({name}) Saving global step to:".format(name=self.name),
                   os.path.join(self.checkpoint_path, 'model.ckpt'))
             f.write(str(self.global_step) + "\n")
+        if self.config['verbose']:
+            self.print_variables()
 
 
     def calc_vals(self, states):
@@ -150,6 +162,8 @@ class Network:
         with open(global_step_path, 'r') as f:
             self.config['global_step'] = int(f.read())
 
+        if self.config['verbose']:
+            self.print_variables()
 
 
 
@@ -489,19 +503,13 @@ class Network:
                      in self.config['eval_methods'] ]
         return outcomes
 
 
     def train_model(self, episodes=1000, save_step_size=100, trained_eps=0):
         with tf.Session() as sess:
             difference_in_vals = 0
 
             self.restore_model()
 
-            #variables_names = [v.name for v in tf.trainable_variables()]
-            #values = sess.run(variables_names)
-            #for k, v in zip(variables_names, values):
-            #    print("Variable: ", k)
-            #    print("Shape: ", v.shape)
-            #    print(v)
             start_time = time.time()
 
             def print_time_estimate(eps_completed):
@@ -537,9 +545,13 @@ class Network:
 
                 difference_in_vals += abs((cur_board_value - self.eval_state(self.board_trans_func(prev_board, player))))
 
+                if self.config['verbose']:
+                    print("Difference in values:", difference_in_vals)
+                    print("Current board value :", cur_board_value)
+                    print("Current board is :\n",cur_board)
+
 
                 # adjust weights
-                #print(cur_board)
                 if Board.outcome(cur_board) is None:
                     self.do_backprop(self.board_trans_func(prev_board, player), cur_board_value)
                     player *= -1
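To summarize the network.py changes above in one self-contained place: the new print_variables() helper walks self.model.variables, and the 'verbose' config entry gates calls to it after saving and restoring, plus the extra value printing during training. Below is a minimal sketch of that gating pattern; the DummyModel class is a hypothetical stand-in for the real TensorFlow eager model so the snippet runs on its own, and the restore_model body is elided.

# Minimal sketch of the verbose-gating pattern added in network.py.
# DummyModel is an assumption for illustration, not part of the project.
class Network:
    def __init__(self, config, model):
        self.config = config
        self.model = model

    def print_variables(self):
        # Dump every variable of the underlying model, mirroring the new method.
        variables = self.model.variables
        for k in variables:
            print(k)

    def restore_model(self):
        # ... checkpoint and global_step loading elided ...
        if self.config['verbose']:
            self.print_variables()


class DummyModel:
    # Stand-in exposing a `variables` attribute like the real model.
    variables = ['w0 <2x4 float32>', 'b0 <4 float32>']


net = Network({'verbose': True}, DummyModel())
net.restore_model()  # prints the two dummy "variables"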