Might be able to learn now (?)
This commit is contained in:
parent
b6fdffd958
commit
11d25603cf
3
bot.py
3
bot.py
|
@ -14,12 +14,13 @@ class Bot:
|
||||||
with self.graph.as_default():
|
with self.graph.as_default():
|
||||||
self.session = tf.Session()
|
self.session = tf.Session()
|
||||||
self.network = Network(self.session)
|
self.network = Network(self.session)
|
||||||
|
self.network.restore_model()
|
||||||
|
|
||||||
|
|
||||||
def roll(self):
|
def roll(self):
|
||||||
print("{} rolled: ".format(self.sym))
|
print("{} rolled: ".format(self.sym))
|
||||||
roll = self.cup.roll()
|
roll = self.cup.roll()
|
||||||
print(roll)
|
# print(roll)
|
||||||
return roll
|
return roll
|
||||||
|
|
||||||
|
|
||||||
|
|
18
game.py
18
game.py
|
@ -25,7 +25,7 @@ class Game:
|
||||||
|
|
||||||
def next_round(self):
|
def next_round(self):
|
||||||
roll = self.roll()
|
roll = self.roll()
|
||||||
print(roll)
|
#print(roll)
|
||||||
self.board = Board.flip(self.p2.make_move(Board.flip(self.board), self.p2.get_sym(), roll))
|
self.board = Board.flip(self.p2.make_move(Board.flip(self.board), self.p2.get_sym(), roll))
|
||||||
return self.board
|
return self.board
|
||||||
|
|
||||||
|
@ -33,16 +33,22 @@ class Game:
|
||||||
return self.board
|
return self.board
|
||||||
|
|
||||||
def train_model(self):
|
def train_model(self):
|
||||||
episodes = 100
|
episodes = 8000
|
||||||
outcomes = []
|
outcomes = []
|
||||||
for episode in range(episodes):
|
for episode in range(episodes):
|
||||||
self.board = Board.initial_state
|
self.board = Board.initial_state
|
||||||
prev_board = self.board
|
# prev_board = self.board
|
||||||
|
prev_board, prev_board_value = self.roll_and_find_best_for_bot()
|
||||||
|
# find the best move here, make this move, then change turn as the
|
||||||
|
# first thing inside of the while loop and then call
|
||||||
|
# roll_and_find_best_for_bot to get V_t+1
|
||||||
|
# self.p1.make_move(prev_board, self.p1.get_sym(), self.roll())
|
||||||
while Board.outcome(self.board) is None:
|
while Board.outcome(self.board) is None:
|
||||||
|
self.next_round()
|
||||||
cur_board, cur_board_value = self.roll_and_find_best_for_bot()
|
cur_board, cur_board_value = self.roll_and_find_best_for_bot()
|
||||||
self.p1.get_network().train(prev_board, cur_board_value)
|
self.p1.get_network().train(prev_board, cur_board_value)
|
||||||
prev_board = cur_board
|
prev_board = cur_board
|
||||||
self.next_round()
|
# self.next_round()
|
||||||
# print("-"*30)
|
# print("-"*30)
|
||||||
# print(Board.pretty(self.board))
|
# print(Board.pretty(self.board))
|
||||||
# print("/"*30)
|
# print("/"*30)
|
||||||
|
@ -51,11 +57,13 @@ class Game:
|
||||||
final_score = np.array([ Board.outcome(self.board)[1] ]).reshape((1, 1))
|
final_score = np.array([ Board.outcome(self.board)[1] ]).reshape((1, 1))
|
||||||
self.p1.get_network().train(prev_board, final_score)
|
self.p1.get_network().train(prev_board, final_score)
|
||||||
print("trained episode {}".format(episode))
|
print("trained episode {}".format(episode))
|
||||||
if episode % 10 == 0:
|
if episode % 100 == 0:
|
||||||
print("Saving...")
|
print("Saving...")
|
||||||
self.p1.get_network().save_model()
|
self.p1.get_network().save_model()
|
||||||
|
self.p2.restore_model()
|
||||||
|
|
||||||
print(outcomes)
|
print(outcomes)
|
||||||
|
print(sum(outcomes))
|
||||||
|
|
||||||
def next_round_test(self):
|
def next_round_test(self):
|
||||||
print(self.board)
|
print(self.board)
|
||||||
|
|
|
@ -100,10 +100,10 @@ class Network:
|
||||||
self.saver.restore(self.session, latest_checkpoint)
|
self.saver.restore(self.session, latest_checkpoint)
|
||||||
|
|
||||||
# Have a circular dependency, #fuck, need to rewrite something
|
# Have a circular dependency, #fuck, need to rewrite something
|
||||||
def train(self, x, v_next):
|
def train(self, board, v_next):
|
||||||
# print("lol")
|
# print("lol")
|
||||||
x = np.array(x).reshape((1,26))
|
board = np.array(board).reshape((1,26))
|
||||||
self.session.run(self.training_op, feed_dict = {self.x:x, self.value_next: v_next})
|
self.session.run(self.training_op, feed_dict = {self.x:board, self.value_next: v_next})
|
||||||
|
|
||||||
|
|
||||||
# while game isn't done:
|
# while game isn't done:
|
||||||
|
|
|
@ -20,6 +20,10 @@ class RestoreBot:
|
||||||
def get_sym(self):
|
def get_sym(self):
|
||||||
return self.sym
|
return self.sym
|
||||||
|
|
||||||
|
def restore_model(self):
|
||||||
|
with self.graph.as_default():
|
||||||
|
self.network.restore_model()
|
||||||
|
|
||||||
def make_move(self, board, sym, roll):
|
def make_move(self, board, sym, roll):
|
||||||
# print(Board.pretty(board))
|
# print(Board.pretty(board))
|
||||||
legal_moves = Board.calculate_legal_states(board, sym, roll)
|
legal_moves = Board.calculate_legal_states(board, sym, roll)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user