diff --git a/network.py b/network.py index 65d7e9e..bc8d601 100644 --- a/network.py +++ b/network.py @@ -297,9 +297,9 @@ class Network: if n_init < 1: print("Unexpected argument n = {}".format(n_init)); exit() - boards_with_scores = ply(n_init, boards_init, player_init) + boards_with_scores = ply(n_init, boards_init, -1 * player_init) print(boards_with_scores) - scores = [ pair[1] + scores = [ ( pair[1] if player_init == 1 else (1 - pair[1]) ) for pair in boards_with_scores ] best_score_pair = boards_with_scores[np.array(scores).argmax()]