Training using a slightly revamped version of our own board rep. Not sure if it works yet.
This commit is contained in:
parent ab5d2aabb2
commit f43108c239
11  board.py
@@ -35,6 +35,17 @@ class Board:
         board.append(-15 - sum(negatives))
         return tuple(board)

+    @staticmethod
+    def board_features_to_own(board, player):
+        board = list(board)
+        positives = [x if x > 0 else 0 for x in board]
+        negatives = [x if x < 0 else 0 for x in board]
+        board.append(15 - sum(positives))
+        board.append(-15 - sum(negatives))
+        board += ([1, 0] if np.sign(player) > 0 else [0, 1])
+        return np.array(board).reshape(1,-1)
+
+
     @staticmethod
     def board_features_to_tesauro(board, cur_player):
         features = []
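For context, a minimal sketch (not part of the commit) of what the new representation produces. It assumes the board tuple passed in has 26 entries; under that assumption the two appended checker counts plus the two-element player indicator give the 30-wide feature vector that network.py switches its input_size to below. The dummy board is hypothetical.

# Sketch only -- mirrors board_features_to_own from the hunk above so its output
# shape can be checked in isolation. The 26-slot dummy board is a made-up stand-in.
import numpy as np

def board_features_to_own(board, player):
    board = list(board)
    positives = [x if x > 0 else 0 for x in board]
    negatives = [x if x < 0 else 0 for x in board]
    board.append(15 - sum(positives))     # player +1's checkers not on the board
    board.append(-15 - sum(negatives))    # player -1's checkers not on the board
    board += ([1, 0] if np.sign(player) > 0 else [0, 1])  # "player to move" indicator
    return np.array(board).reshape(1, -1)

dummy_board = (0,) * 26                   # hypothetical empty 26-slot board
print(board_features_to_own(dummy_board, 1).shape)   # (1, 30) == Network.input_size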
27  network.py
@@ -11,13 +11,11 @@ from eval import Eval

 class Network:
     hidden_size = 40
-    input_size = 198
+    input_size = 30
     output_size = 1
     # Can't remember the best learning_rate, look this up
     learning_rate = 0.01

     # TODO: Actually compile tensorflow properly
     # os.environ["TF_CPP_MIN_LOG_LEVEL"]="2"
+    board_rep = Board.board_features_to_own

     def custom_tanh(self, x, name=None):
         return tf.scalar_mul(tf.constant(2.00), tf.tanh(x, name))
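The board representation is now a class attribute, so switching encoders is a one-line change as long as input_size is kept in step with the encoder's output width. A self-contained sketch of that pattern follows; the stub encoders and the widths 30 and 198 are the only things taken from the diff, everything else is illustrative.

# Sketch of the pluggable-representation pattern (stub encoders, not the real ones).
import numpy as np

def features_own(board, player):        # stand-in for Board.board_features_to_own
    return np.zeros((1, 30))

def features_tesauro(board, player):    # stand-in for Board.board_features_to_tesauro
    return np.zeros((1, 198))

class Network:
    board_rep = features_own            # swap to features_tesauro to go back
    input_size = 30                     # must match board_rep's output width

features = Network.board_rep(None, +1)  # class-level access yields the plain function
assert features.shape == (1, Network.input_size)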
@@ -147,7 +145,7 @@ class Network:
     def make_move(self, sess, board, roll, player):
         # print(Board.pretty(board))
         legal_moves = Board.calculate_legal_states(board, player, roll)
-        moves_and_scores = [(move, self.eval_state(sess, Board.board_features_to_tesauro(move, player))) for move in legal_moves]
+        moves_and_scores = [(move, self.eval_state(sess, Network.board_rep(move, player))) for move in legal_moves]
         scores = [x[1] if np.sign(player) > 0 else 1-x[1] for x in moves_and_scores]
         best_score_index = np.array(scores).argmax()
         best_move_pair = moves_and_scores[best_score_index]
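The selection logic itself is unchanged: every legal successor state is scored by the network, and since the single output appears to be read from player +1's point of view, player -1 ranks moves by 1 - score. A stubbed sketch of that argmax; the toy evaluator and move names are made up.

# Sketch of the greedy selection in make_move, with the TensorFlow evaluation
# replaced by a made-up scoring function.
import numpy as np

def pick_move(legal_moves, eval_state, player):
    moves_and_scores = [(move, eval_state(move, player)) for move in legal_moves]
    scores = [s if np.sign(player) > 0 else 1 - s for _, s in moves_and_scores]
    return moves_and_scores[int(np.array(scores).argmax())]

toy_scores = {"a": 0.2, "b": 0.7, "c": 0.5}             # hypothetical state values
toy_eval = lambda move, player: toy_scores[move]
print(pick_move(["a", "b", "c"], toy_eval, player=-1))  # ('a', 0.2): worst for +1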
@@ -338,15 +336,6 @@ class Network:
             sys.stderr.write("[TRAIN] Episode {}".format(episode + trained_eps))
             # TODO decide which player should be here
-
-
-            # TEST
-            #if episode % 1000 == 0:
-            #    self.config['eval_methods'] = 'dumbeval'
-            #    self.config['episodes'] = 300
-            #    outcomes = self.eval(trained_eps)
-            #    self.log_eval_outcomes(outcomes, trained_eps=self.episodes_trained)
-
             #player = random.choice([-1, 1])
             player = 1

             prev_board = Board.initial_state
@@ -355,7 +344,6 @@ class Network:
             # first thing inside of the while loop and then call
             # best_move_and_score to get V_t+1

-            # i = 0
             while Board.outcome(prev_board) is None:

                 #print("PREEEV_BOOOOAAARD:",prev_board)
@@ -367,7 +355,7 @@ class Network:

                 # adjust weights
                 sess.run(self.training_op,
-                         feed_dict={self.x: Board.board_features_to_tesauro(prev_board, player),
+                         feed_dict={self.x: Network.board_rep(prev_board, player),
                                     self.value_next: cur_board_value})

                 player *= -1
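The weight update itself is untouched; only the feature source changes, with the previous position pushed toward the value computed for the board that follows it. For readers unfamiliar with the TF1 graph style used in this file, here is a minimal stand-in that performs the same kind of update. This is not the commit's network: the single sigmoid unit, the learning rate, and the tf.compat.v1 shim are assumptions made so the sketch runs.

# Minimal stand-in for the TD-style update: push V(prev_board) toward the value
# of the board that follows it. Uses tf.compat.v1 so it runs on TF2 installs.
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, (1, 30))           # Network.board_rep(prev_board, player)
value_next = tf.placeholder(tf.float32, (1, 1))   # value of the following board
w = tf.Variable(tf.zeros((30, 1)))
value = tf.sigmoid(tf.matmul(x, w))               # toy value head, not the real net
training_op = tf.train.GradientDescentOptimizer(0.01).minimize(
    tf.reduce_sum(tf.square(value_next - value)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feats = np.zeros((1, 30), dtype=np.float32)   # stand-in feature vector
    sess.run(training_op, feed_dict={x: feats, value_next: [[1.0]]})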
@@ -386,7 +374,7 @@ class Network:
             with tf.name_scope("final"):
                 merged = tf.summary.merge_all()
                 summary, _ = sess.run([merged, self.training_op],
-                                      feed_dict={self.x: Board.board_features_to_tesauro(prev_board, player),
+                                      feed_dict={self.x: Network.board_rep(prev_board, player),
                                                  self.value_next: scaled_final_score.reshape((1, 1))})
                 writer.add_summary(summary, episode + trained_eps)

@@ -415,8 +403,3 @@ class Network:

         # save the current state again, so we can continue running backprop based on the "previous" turn.

-        # NOTE: We need to make a method so that we can take a single turn or at least
-        # just pick the next best move, so we know how to evaluate according to TD-learning.
-        # Right now, our game just continues in a while loop without nothing to stop it!
-
-

@@ -5,6 +5,10 @@ static PyObject* DumbevalError;
 static float x[122];

 static const float wc[122] = {
+5.6477, 6.316649999999999, 7.05515, 6.65315, 9.3171, 17.9777, 2.0235499999999993, 5.1129500000000005, 7.599200000000001, 9.68525, 3.1762, 8.05335, 16.153499999999998, 8.02445, 10.55345, 15.489600000000001, 10.525199999999998, 16.438850000000002, 12.27405, 9.6362, 12.7152, 13.2859, 1.6932499999999995, 26.79045, 10.521899999999999, 6.79635, 5.28135, 6.2059, 10.2306, 10.5485, 3.6000500000000004, 4.07825, 6.951700000000001, 4.413749999999999, 11.271450000000002, 12.9361, 11.087299999999999, 13.10085, 10.411999999999999, 8.084050000000001, 12.4893, 5.96055, 4.69195, 18.9482, 9.0946, 9.1954, 6.2592, 16.180300000000003, 8.3376, 23.24915, 14.32525, -2.6699000000000006, 19.156, 5.81445, 4.7214, 7.63055, 7.039, 5.88075, 2.00765, 14.596800000000002, 11.5208, -3.79, -3.8541000000000003, 5.358499999999999, 14.4516, 2.49015, 11.284799999999999, 14.1066, 16.2306, 5.82875, 9.34505, 16.13685, 8.1893, 2.93145, 7.83185, 12.86765, 6.90115, 20.07255, 8.93355, -0.12434999999999974, 12.0587, 11.83985, 6.34155, 7.1963, 10.571200000000001, 22.38365, 6.50745, 8.94595, 12.0434, 10.79885, 14.055800000000001, 0.022100000000000453, 10.39255, 4.088850000000001, 3.6421499999999996, 38.1298, 6.8957, 0.9804999999999997, 5.9599, 13.16055, 11.55305, 10.65015, 4.6673, 15.770999999999999, 27.700050000000005, 4.4329, 12.6349, 7.037800000000001, 3.4897, 18.91945, 10.239899999999999, 5.4625, 10.29705, 10.492799999999999, 8.850900000000001, -10.575999999999999, 10.6893, 15.30845, 17.8083, 31.88275, 11.225000000000001, 4.4806};
+
+
+/*
 1.5790816238841092, 1.6374860177130541, -1.7131823639980923, -0.9286186784962336, -1.0732080528763888,
 -0.33851674519289876, 1.5798155080270462, 2.3161915581553414, 1.5625330782392322, 0.9397141260075461,
 0.8386342522957442, 1.2380864901133144, -2.803703105809909, -1.6033863837759044, -1.9297462408169208,
@@ -30,7 +34,7 @@ static const float wc[122] = {
 -0.3393405083020449, 2.787144781914554, -2.401723402781605, -1.1675562811241997, -1.1542961327714207,
 0.18253192955355502, -2.418436664206371, 0.7423935287565309, 2.9903418274144666, -1.3503112004693552,
 -2.649146174480099, -0.5447080156947952
-};
+};*/

 static const float wr[122] = {
 -0.7856, -0.50352, 0.12392, -1.00316, -2.46556, -0.1627, 0.18966, 0.0043, 0.0,