diff --git a/bot.py b/bot.py index a39fab1..05de7ca 100644 --- a/bot.py +++ b/bot.py @@ -14,12 +14,13 @@ class Bot: with self.graph.as_default(): self.session = tf.Session() self.network = Network(self.session) + self.network.restore_model() def roll(self): print("{} rolled: ".format(self.sym)) roll = self.cup.roll() - print(roll) +# print(roll) return roll diff --git a/game.py b/game.py index 4f3d6d5..45203e5 100644 --- a/game.py +++ b/game.py @@ -25,7 +25,7 @@ class Game: def next_round(self): roll = self.roll() - print(roll) + #print(roll) self.board = Board.flip(self.p2.make_move(Board.flip(self.board), self.p2.get_sym(), roll)) return self.board @@ -33,16 +33,22 @@ class Game: return self.board def train_model(self): - episodes = 100 + episodes = 8000 outcomes = [] for episode in range(episodes): self.board = Board.initial_state - prev_board = self.board +# prev_board = self.board + prev_board, prev_board_value = self.roll_and_find_best_for_bot() + # find the best move here, make this move, then change turn as the + # first thing inside of the while loop and then call + # roll_and_find_best_for_bot to get V_t+1 +# self.p1.make_move(prev_board, self.p1.get_sym(), self.roll()) while Board.outcome(self.board) is None: + self.next_round() cur_board, cur_board_value = self.roll_and_find_best_for_bot() self.p1.get_network().train(prev_board, cur_board_value) prev_board = cur_board - self.next_round() +# self.next_round() # print("-"*30) # print(Board.pretty(self.board)) # print("/"*30) @@ -51,11 +57,13 @@ class Game: final_score = np.array([ Board.outcome(self.board)[1] ]).reshape((1, 1)) self.p1.get_network().train(prev_board, final_score) print("trained episode {}".format(episode)) - if episode % 10 == 0: + if episode % 100 == 0: print("Saving...") self.p1.get_network().save_model() + self.p2.restore_model() print(outcomes) + print(sum(outcomes)) def next_round_test(self): print(self.board) diff --git a/network.py b/network.py index 5503b85..ca2a07b 100644 --- a/network.py +++ b/network.py @@ -100,10 +100,10 @@ class Network: self.saver.restore(self.session, latest_checkpoint) # Have a circular dependency, #fuck, need to rewrite something - def train(self, x, v_next): + def train(self, board, v_next): # print("lol") - x = np.array(x).reshape((1,26)) - self.session.run(self.training_op, feed_dict = {self.x:x, self.value_next: v_next}) + board = np.array(board).reshape((1,26)) + self.session.run(self.training_op, feed_dict = {self.x:board, self.value_next: v_next}) # while game isn't done: diff --git a/restore_bot.py b/restore_bot.py index 131f8c9..4f78132 100644 --- a/restore_bot.py +++ b/restore_bot.py @@ -20,6 +20,10 @@ class RestoreBot: def get_sym(self): return self.sym + def restore_model(self): + with self.graph.as_default(): + self.network.restore_model() + def make_move(self, board, sym, roll): # print(Board.pretty(board)) legal_moves = Board.calculate_legal_states(board, sym, roll)