diff --git a/test.py b/test.py index 90cea23..0c6d8a6 100644 --- a/test.py +++ b/test.py @@ -737,6 +737,23 @@ class TestBoardFlip(unittest.TestCase): self.assertTrue((Board.board_features_tesauro(board, 1) == np.array(expected).reshape(1, 198)).all()) + def test_pubeval_features(self): + board = Board.initial_state + + expected = (0, + 2, 0, 0, 0, 0, -5, + 0, -3, 0, 0, 0, 5, + -5, 0, 0, 0, 3, 0, + 5, 0, 0, 0, 0, -2, + 0, + 0, 0) + + import numpy as np + self.assertTrue((Board.board_features_to_pubeval(board, 1) == + np.array(expected).reshape(1, 28)).all()) + self.assertTrue((Board.board_features_to_pubeval(board, -1) == + np.array(expected).reshape(1, 28)).all()) + def test_tesauro_bars(self): board = list(Board.initial_state) board[1] = 0