Fixed n_ply and actually added a comma in main.py. *clap Christoffer*
This commit is contained in:
parent
c530aa688d
commit
695a3d43db
2
main.py
2
main.py
|
@ -53,7 +53,7 @@ config = {
|
||||||
'train_perpetually': args.train_perpetually,
|
'train_perpetually': args.train_perpetually,
|
||||||
'model_storage_path': 'models',
|
'model_storage_path': 'models',
|
||||||
'bench_storage_path': 'bench',
|
'bench_storage_path': 'bench',
|
||||||
'board_representation': 'quack'
|
'board_representation': 'quack',
|
||||||
'force_creation': args.force_creation
|
'force_creation': args.force_creation
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -225,6 +225,9 @@ class Network:
|
||||||
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
|
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
|
||||||
(6, 6) ]
|
(6, 6) ]
|
||||||
|
|
||||||
|
# for roll in possible_rolls:
|
||||||
|
# print(len(Board.calculate_legal_states(board, player, roll)))
|
||||||
|
|
||||||
return [ Board.calculate_legal_states(board, player, roll)
|
return [ Board.calculate_legal_states(board, player, roll)
|
||||||
for roll
|
for roll
|
||||||
in possible_rolls ]
|
in possible_rolls ]
|
||||||
|
@ -284,7 +287,7 @@ class Network:
|
||||||
in boards ]
|
in boards ]
|
||||||
|
|
||||||
average_score_pairs = [
|
average_score_pairs = [
|
||||||
(inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1))
|
(inner_boards[0], average_plain([ average_for_score_pairs(ply(n - 1, inner_board, player * -1 if n == 1 else player))
|
||||||
for inner_board
|
for inner_board
|
||||||
in inner_boards[1] ]))
|
in inner_boards[1] ]))
|
||||||
for inner_boards
|
for inner_boards
|
||||||
|
|
|
@ -38,15 +38,58 @@ boards = {initial_state,
|
||||||
initial_state_2 }
|
initial_state_2 }
|
||||||
|
|
||||||
|
|
||||||
|
def gen_21_rolls():
|
||||||
|
"""
|
||||||
|
Calculate all possible rolls, [[1,1], [1,2] ..]
|
||||||
|
:return: All possible rolls
|
||||||
|
"""
|
||||||
|
a = []
|
||||||
|
for x in range(1, 7):
|
||||||
|
for y in range(1, 7):
|
||||||
|
if not [x, y] in a and not [y, x] in a:
|
||||||
|
a.append([x, y])
|
||||||
|
|
||||||
|
return a
|
||||||
|
|
||||||
|
def calc_all_scores(board, player):
|
||||||
|
scores = []
|
||||||
|
trans_board = network.board_trans_func(board, player)
|
||||||
|
rolls = gen_21_rolls()
|
||||||
|
for roll in rolls:
|
||||||
|
score = network.eval_state(session, trans_board)
|
||||||
|
scores.append(score)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_possible_states(board):
|
||||||
|
possible_rolls = [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
|
||||||
|
(1, 6), (2, 2), (2, 3), (2, 4), (2, 5),
|
||||||
|
(2, 6), (3, 3), (3, 4), (3, 5), (3, 6),
|
||||||
|
(4, 4), (4, 5), (4, 6), (5, 5), (5, 6),
|
||||||
|
(6, 6)]
|
||||||
|
|
||||||
|
for roll in possible_rolls:
|
||||||
|
meh = Board.calculate_legal_states(board, -1, roll)
|
||||||
|
print(len(meh))
|
||||||
|
return [Board.calculate_legal_states(board, -1, roll)
|
||||||
|
for roll
|
||||||
|
in possible_rolls]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#for board in boards:
|
||||||
|
# calculate_possible_states(board)
|
||||||
|
|
||||||
print("-"*30)
|
print("-"*30)
|
||||||
print(network.do_ply(session, boards, 1))
|
print(network.do_ply(session, boards, 1))
|
||||||
|
|
||||||
print(" "*10 + "network_test")
|
#print(" "*10 + "network_test")
|
||||||
print(" "*20 + "Depth 1")
|
#print(" "*20 + "Depth 1")
|
||||||
print(network.n_ply(1, session, boards, 1))
|
scores = network.n_ply(1, session, boards, 1)
|
||||||
|
|
||||||
print(" "*20 + "Depth 2")
|
|
||||||
print(network.n_ply(2, session, boards, 1))
|
#print(" "*20 + "Depth 2")
|
||||||
|
#print(network.n_ply(2, session, boards, 1))
|
||||||
|
|
||||||
# #print(x.shape)
|
# #print(x.shape)
|
||||||
# with graph_lol.as_default():
|
# with graph_lol.as_default():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user