From bae1e73692dae23e0cb6631284d81576cdb58026 Mon Sep 17 00:00:00 2001 From: Alexander Munch-Hansen Date: Wed, 7 Mar 2018 14:44:17 +0100 Subject: [PATCH] Now only using one bot again. Also changed learning rate to 0.1 --- bot.py | 4 ++++ game.py | 11 ++++++----- network.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/bot.py b/bot.py index 05de7ca..eda20b1 100644 --- a/bot.py +++ b/bot.py @@ -27,6 +27,10 @@ class Bot: def switch(self,cur): return -1 if cur == 1 else 1 + def restore_model(self): + with self.graph.as_default(): + self.network.restore_model() + def get_session(self): return self.session diff --git a/game.py b/game.py index 45203e5..25e3fde 100644 --- a/game.py +++ b/game.py @@ -11,7 +11,7 @@ class Game: self.board = Board.initial_state self.p1 = Bot(1) - self.p2 = RestoreBot(1) + self.p2 = Bot(1) self.cup = Cup() def roll(self): @@ -26,14 +26,14 @@ class Game: def next_round(self): roll = self.roll() #print(roll) - self.board = Board.flip(self.p2.make_move(Board.flip(self.board), self.p2.get_sym(), roll)) + self.board = Board.flip(self.p2.make_move(Board.flip(self.board), self.p2.get_sym(), roll)[0]) return self.board def board_state(self): return self.board def train_model(self): - episodes = 8000 + episodes = 100 outcomes = [] for episode in range(episodes): self.board = Board.initial_state @@ -57,10 +57,11 @@ class Game: final_score = np.array([ Board.outcome(self.board)[1] ]).reshape((1, 1)) self.p1.get_network().train(prev_board, final_score) print("trained episode {}".format(episode)) - if episode % 100 == 0: + if episode % 10 == 0: print("Saving...") self.p1.get_network().save_model() self.p2.restore_model() + print(sum(outcomes)) print(outcomes) print(sum(outcomes)) @@ -95,7 +96,7 @@ class Game: roll = self.roll() print("{} rolled: {}".format(self.p2.get_sym(), roll)) - self.board = self.p2.make_move(self.board, self.p2.get_sym(), roll) + self.board = self.p2.make_move(self.board, self.p2.get_sym(), roll)[0] if Board.outcome(self.board)[1] > 0: diff --git a/network.py b/network.py index ca2a07b..a83fcaa 100644 --- a/network.py +++ b/network.py @@ -10,7 +10,7 @@ class Config(): input_size = 26 output_size = 1 # Can't remember the best learning_rate, look this up - learning_rate = 0.3 + learning_rate = 0.1 checkpoint_path = "/tmp/"