tesauro fat and diffs in values
This commit is contained in:
parent
6e061171da
commit
f170bad9b1
12
board.py
12
board.py
|
@ -268,23 +268,13 @@ class Board:
|
|||
# print("Dice permuts:",dice_permutations)
|
||||
for roll in dice_permutations:
|
||||
# Calculate boards resulting from first move
|
||||
#print("initial board: ", board)
|
||||
#print("roll:", roll)
|
||||
#print("Rest of roll:",roll[1:])
|
||||
boards = calc_moves(board, roll[0])
|
||||
#print("Boards:",boards)
|
||||
#print("Roll:",roll[0])
|
||||
#print("boards after first die: ", boards)
|
||||
|
||||
for die in roll[1:]:
|
||||
# Calculate boards resulting from each remaining die in the roll
|
||||
nested_boards = [calc_moves(board, die) for board in boards]
|
||||
#print("nested boards: ", nested_boards)
|
||||
boards = [board for boards in nested_boards for board in boards]
|
||||
# Debugging aid (disabled): print each resulting board and its type
|
||||
#for board in boards:
|
||||
# print(board)
|
||||
# print("type__:",type(board))
|
||||
|
||||
# Add resulting unique boards to set of legal boards resulting from roll
|
||||
|
||||
#print("printing boards from calculate_legal_states: ", boards)
|
||||
|
|
12
network.py
12
network.py
|
@ -114,15 +114,14 @@ class Network:
|
|||
|
||||
with tf.GradientTape() as tape:
|
||||
value = self.model(prev_state.reshape(1,-1))
|
||||
|
||||
grads = tape.gradient(value, self.model.variables)
|
||||
|
||||
difference_in_values = tf.reshape(tf.subtract(value_next, value, name='difference_in_values'), [])
|
||||
tf.summary.scalar("difference_in_values", tf.abs(difference_in_values))
|
||||
|
||||
with tf.variable_scope('apply_gradients'):
|
||||
for grad, train_var in zip(grads, self.model.variables):
|
||||
backprop_calc = self.learning_rate * difference_in_values * grad
|
||||
train_var.assign_add(backprop_calc)
|
||||
for grad, train_var in zip(grads, self.model.variables):
|
||||
backprop_calc = self.learning_rate * difference_in_values * grad
|
||||
train_var.assign_add(backprop_calc)
|
||||
|
||||
|
||||
|
||||
|
@ -299,7 +298,7 @@ class Network:
|
|||
length_list = []
|
||||
test_list = []
|
||||
# Prepping of data
|
||||
start = time.time()
|
||||
# start = time.time()
|
||||
for board in boards:
|
||||
length = 0
|
||||
for roll in all_rolls:
|
||||
|
@ -478,7 +477,6 @@ class Network:
|
|||
for episode in range(1, episodes + 1):
|
||||
|
||||
sys.stderr.write("[TRAIN] Episode {}".format(episode + trained_eps))
|
||||
# TODO decide which player should be here
|
||||
|
||||
# player = 1
|
||||
player = random.choice([-1,1])
|
||||
|
|
Loading…
Reference in New Issue
Block a user