advancedskrald/tensor_classifier.py

73 lines
2.2 KiB
Python
Raw Normal View History

2019-04-10 20:32:30 +00:00
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.keras import datasets, layers, models
import glob
import numpy as np
import cv2
import runner
from main import find_occupied_squares
from util import POSITION
# Load the trained Keras model from disk and print its layer summary.
new_model = models.load_model('test_chess_model.h5')
new_model.summary()

# Board image to classify. Other board images previously used here:
#   whole_boards/boards_for_empty/board_1554286488.605142_rank_3.png
#   whole_boards/boards_for_empty/board_1554285167.655788_rank_5.png
board_path = "whole_boards/boards_for_empty/board_1554288891.129901_rank_8.png"
board = cv2.imread(board_path)
# cv2.imread does not raise for a missing/unreadable file; it returns None,
# which would otherwise only fail later with an obscure error in warp_board.
if board is None:
    raise FileNotFoundError(f"Could not read board image: {board_path}")

# Rectify the board, then collect the (position, square image) pairs that
# find_occupied_squares reports as occupied; they are iterated below.
warped = runner.warp_board(board)
occupied = find_occupied_squares(warped)
# Four fixed squares on rank 5. They are only consumed by the disabled
# comparison code kept in the string literal further down.
pos_1, pos_2, pos_3, pos_4 = POSITION.C5, POSITION.D5, POSITION.G5, POSITION.H5

# Crop each square out of the warped board image.
square_1, square_2, square_3, square_4 = [
    runner.get_square(warped, p) for p in (pos_1, pos_2, pos_3, pos_4)
]
# BGR -> single-channel grayscale.
square_1, square_2, square_3, square_4 = [
    cv2.cvtColor(s, cv2.COLOR_BGR2GRAY) for s in (square_1, square_2, square_3, square_4)
]
# NOTE: a 2-D image's shape is (rows, cols); these names are reversed relative
# to that, but the code that uses them reshapes with (width, height) in this
# same order, so the model still receives (rows, cols).
width, height = square_1.shape
# Divide by 255 to scale into [0, 1] (assumes 8-bit input pixels).
square_1, square_2, square_3, square_4 = [
    s / 255.0 for s in (square_1, square_2, square_3, square_4)
]

# Class labels, indexed by the argmax of the model's output vector.
pieces = ['knight', 'rook']
# Classify every square reported as occupied and show it in its own window,
# labelled with its position and the predicted piece name.
for pos, square in occupied:
    # Model input: single-channel image scaled to [0, 1].
    square = cv2.cvtColor(square, cv2.COLOR_BGR2GRAY)
    # A 2-D image's shape is (rows, cols); the names are reversed, but the
    # reshape below uses them in this same order, so the layout is correct.
    width, height = square.shape
    square = square / 255.0
    # Add batch and channel axes -> (1, rows, cols, 1). `square` is already an
    # ndarray, so no extra np.array copy is needed.
    test = new_model.predict(square.reshape((-1, width, height, 1)))
    # The image is float in [0, 1], so 1.0 is full white; 255 would be far
    # outside the image's value range.
    text_color = 1.0
    cv2.putText(square, f"{pos} {pieces[int(np.argmax(test))]}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
                color=(text_color,) * 3, thickness=3)
    cv2.imshow(f"{pos}", square)

# Block until a key is pressed so every window opened above stays visible.
cv2.waitKey(0)
"""
for pos, square in [(pos_1, square_1), (pos_2, square_2), (pos_3, square_3), (pos_4, square_4)]:
test = new_model.predict(np.array(square).reshape((-1, width, height, 1)))
print(f"{pos}: {np.argmax(test)}")
text_color = 255
cv2.putText(square, f"{pos} {pieces[int(np.argmax(test))]}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
color=(text_color,) * 3, thickness=3)
cv2.imshow(f"{pos}", square)
"""