backgammon/network_test.py

102 lines
2.5 KiB
Python
Raw Normal View History

2018-03-04 16:35:36 +00:00
from network import Network
import tensorflow as tf
import random
import numpy as np
from board import Board
import main

# Derive this script's configuration from a copy of main.config (the shared
# dict itself is not modified), pointing it at the model under test and
# setting force_creation.
config = main.config.copy()
config.update(model="tesauro_blah", force_creation=True)
network = Network(config, config['model'])

# One session for the whole script: initialise every graph variable first,
# then let the network restore its saved parameters into that session.
session = tf.Session()
session.run(tf.global_variables_initializer())
network.restore_model(session)

# Sample positions. Each hand-written tuple below has 26 entries and holds
# 15 checkers per sign (+/-). The 24 interior entries look like board points;
# the first/last entries are presumably bar/off slots -- confirm against Board.
initial_state = Board.initial_state

initial_state_1 = (0,
                   0, 0, 0, 2, 0, -5,
                   0, -3, 0, 0, 0, 0,
                   -5, 0, 0, 0, 3, 5,
                   0, 0, 0, 0, 5, -2,
                   0)

# Lopsided position: all positive checkers stacked on one point, the
# negative side spread over the first four points.
initial_state_2 = (0,
                   -5, -5, -3, -2, 0, 0,
                   0, 0, 0, 0, 0, 0,
                   0, 0, 0, 15, 0, 0,
                   0, 0, 0, 0, 0, 0,
                   0)

# NOTE(review): a set drops duplicate positions and its iteration order is
# not the order written here, so results from do_ply/n_ply cannot be matched
# to these names by position. Consider a list if Network accepts one.
boards = {initial_state, initial_state_1, initial_state_2}
def gen_21_rolls():
    """
    Calculate all distinct rolls of two six-sided dice, ignoring die order.

    Each unordered pair appears exactly once with the smaller die first, in
    lexicographic order: [[1, 1], [1, 2], ..., [1, 6], [2, 2], ..., [6, 6]].
    That is 6 doubles + 15 mixed rolls = 21 entries.

    :return: A new list of 21 two-element lists ``[low, high]``
    """
    # Starting the second die at the first die's value yields every unordered
    # pair once, without the quadratic "already in the list?" checks.
    return [[low, high] for low in range(1, 7) for high in range(low, 7)]
def calc_all_scores(board, player):
    """
    Score ``board`` with the module-level network, once per distinct roll.

    NOTE(review): the roll never influences the evaluation. The previous loop
    ignored its ``roll`` variable and evaluated the same transformed board 21
    times. The network is now queried once and that score is repeated, which
    keeps the return shape (one entry per roll from ``gen_21_rolls``) without
    20 redundant session runs. This assumes ``network.eval_state`` is
    deterministic for a fixed input (inference only) -- confirm. Every entry
    is now the same object, so in-place mutation of one entry shows up in
    all. If per-roll scores were intended, generate the legal states for each
    roll and evaluate those instead.

    :param board: Board state, passed to ``network.board_trans_func``
    :param player: Player, passed to ``network.board_trans_func``
    :return: List with one score per roll returned by ``gen_21_rolls()``
    """
    trans_board = network.board_trans_func(board, player)
    score = network.eval_state(session, trans_board)
    return [score for _ in gen_21_rolls()]
def calculate_possible_states(board, player=-1):
    """
    Compute the legal successor states of ``board`` for each of the 21
    distinct dice rolls, printing how many there are per roll.

    :param board: Board state, passed to ``Board.calculate_legal_states``
    :param player: Player to move, passed to ``Board.calculate_legal_states``.
        Defaults to -1, the value that used to be hard-coded here.
    :return: List with one entry per roll (in the order of ``possible_rolls``),
        each being what ``Board.calculate_legal_states`` returned for it
    """
    possible_rolls = [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
                      (1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
                      (2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
                      (4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
                      (6, 6)]
    all_states = []
    for roll in possible_rolls:
        # Compute each roll's states once. The count is printed as a
        # debugging aid and the same result is returned; previously every
        # roll was computed twice.
        states = Board.calculate_legal_states(board, player, roll)
        print(len(states))
        all_states.append(states)
    return all_states
# Optional: print the number of legal states per roll for each sample board.
# for board in boards:
#     calculate_possible_states(board)

print("-" * 30)

# One-ply look-ahead over the sample boards. The trailing argument 1 is
# presumably the player to move -- confirm in Network.do_ply.
print(network.do_ply(session, boards, 1))

# Same sample boards through the depth-parameterised search at depth 1; the
# result is kept in ``scores`` but not printed.
scores = network.n_ply(1, session, boards, 1)

# Optional: a deeper (depth 2) search.
# print(network.n_ply(2, session, boards, 1))