backgammon/tensorflow_impl_tests/eager_main.py
Alexander Munch-Hansen 9a2d87516e Ongoing rewrite of network to use an eager model. We're now capable of
evaluating a list of states with network.py. We can also save and
restore models.
2018-05-09 00:33:05 +02:00

74 lines
1.8 KiB
Python

import time
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()

# Xavier initializer shared by both Dense layers' kernels.
xavier_init = tf.contrib.layers.xavier_initializer()
# NOTE(review): momentum=1 means the momentum accumulator never decays. The
# first step of MomentumOptimizer equals plain SGD, but repeated steps will
# accumulate without bound; 0.9 is the conventional choice — confirm intent.
opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=1)
output_size = 1
hidden_size = 40
input_size = 30
# Two-layer sigmoid network: input_size -> hidden_size -> output_size.
model = tf.keras.Sequential([
    # input_shape excludes the batch axis: each sample is a flat vector of
    # input_size features, so batches have shape (batch, input_size).
    # (1, input_size) made Keras expect rank-3 input and reject (N, 30) batches.
    tf.keras.layers.Dense(hidden_size, activation="sigmoid",
                          kernel_initializer=xavier_init,
                          input_shape=(input_size,)),
    tf.keras.layers.Dense(output_size, activation="sigmoid",
                          kernel_initializer=xavier_init),
])
# Set to True to resume from the checkpoint this script saves into "./".
RESTORE_CHECKPOINT = False
if RESTORE_CHECKPOINT:
    latest_ckpt = tf.train.latest_checkpoint("./")
    # latest_checkpoint returns None when no checkpoint exists yet.
    if latest_ckpt is not None:
        tfe.Saver(model.variables).restore(latest_ckpt)

# Sample 30-value state vector (presumably a backgammon board encoding — confirm).
# NOTE(review): `input` shadows the builtin; kept because later code reads it.
input = [0, 2, 0, 0, 0, 0, -5, 0, -3, 0, 0, 0, 5, -5, 0, 0, 0, 3, 0, 5, 0, 0, 0, 0, -2, 0, 0, 0, 1, 0]
n_states = 20
# Build float32 arrays explicitly so the model receives its expected dtype.
single_in = np.array(input, dtype=np.float32).reshape(1, -1)
all_input = np.tile(single_in, (n_states, 1))  # shape (n_states, len(input))

# Benchmark 1: evaluate all states in one batched call.
start = time.time()
all_predictions = model.predict_on_batch(all_input)
batched_elapsed = time.time() - start  # measured before printing results
print(all_predictions)
print("predict_on_batch, {} states: {:.6f}s".format(n_states, batched_elapsed))

# Benchmark 2: evaluate the same states one model call at a time.
start = time.time()
all_predictions = [model(single_in) for _ in range(n_states)]
print("{} single-state model calls: {:.6f}s".format(n_states, time.time() - start))
print("-" * 30)
# One TD-style update on the sample state: scale d(val)/d(weights) by the error
# against a placeholder target, apply it with the optimizer, then checkpoint.
state = np.array(input, dtype=np.float32).reshape(1, -1)
with tf.GradientTape() as tape:
    val = model(state)
grads = tape.gradient(val, model.variables)

# Placeholder target drawn once per update; a real TD step would use the value
# of the next state. NOTE(review): the sigmoid output lies in (0, 1) while this
# target lies in (-1, 1) — confirm the intended range.
target = float(np.random.uniform(-1, 1))
# Reduce val (shape (1, 1)) to a scalar so scaling keeps every gradient's shape;
# apply_gradients requires grad.shape == var.shape (the bias is (hidden_size,)).
td_error = tf.reshape(val, []) - target
# Gradient descent subtracts lr * grad, so scaling by (val - target) moves val
# toward target, i.e. w += lr * 0.1 * (target - val) * d(val)/dw.
grads = [0.1 * td_error * grad for grad in grads]

# Copy to NumPy: model.weights[0] is the live variable, so holding a reference
# would already reflect the update below.
weights_before = model.weights[0].numpy()
start = time.time()
opt.apply_gradients(zip(grads, model.variables))
print("apply_gradients: {:.6f}s".format(time.time() - start))
print("max |change| in first kernel:",
      np.abs(model.weights[0].numpy() - weights_before).max())
print(model(state))
tfe.Saver(model.variables).save("./tmp_ckpt")