final
This commit is contained in:
parent
12ab541de6
commit
e97ec3fea9
BIN
2_piece_new_fuck_lel.h5
Normal file
BIN
2_piece_new_fuck_lel.h5
Normal file
Binary file not shown.
BIN
5_piece_new_fuck_lel.h5
Normal file
BIN
5_piece_new_fuck_lel.h5
Normal file
Binary file not shown.
BIN
classifiers/classifier_bishop/black.pkl
Normal file
BIN
classifiers/classifier_bishop/black.pkl
Normal file
Binary file not shown.
BIN
classifiers/classifier_bishop/white.pkl
Normal file
BIN
classifiers/classifier_bishop/white.pkl
Normal file
Binary file not shown.
BIN
classifiers/classifier_king/black.pkl
Normal file
BIN
classifiers/classifier_king/black.pkl
Normal file
Binary file not shown.
BIN
classifiers/classifier_king/white.pkl
Normal file
BIN
classifiers/classifier_king/white.pkl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
classifiers/classifier_queen/black.pkl
Normal file
BIN
classifiers/classifier_queen/black.pkl
Normal file
Binary file not shown.
BIN
classifiers/classifier_queen/white.pkl
Normal file
BIN
classifiers/classifier_queen/white.pkl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
205
main.py
205
main.py
|
@ -1,5 +1,10 @@
|
|||
import glob
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -7,7 +12,8 @@ import numpy as np
|
|||
from sklearn.exceptions import DataConversionWarning
|
||||
|
||||
import runner
|
||||
from util import load_classifier, PIECE, COLOR, POSITION, Board, Squares, PieceAndColor, OUR_PIECES
|
||||
from tensor_classifier import predict_board, predict_piece, predict_empty_nn
|
||||
from util import load_classifier, PIECE, COLOR, POSITION, Board, Squares, PieceAndColor, OUR_PIECES, FILE, RANK, LESS_PIECE
|
||||
|
||||
warnings.filterwarnings(action='ignore', category=DataConversionWarning)
|
||||
np.set_printoptions(threshold=sys.maxsize)
|
||||
|
@ -32,7 +38,7 @@ def identify_piece(image: np.ndarray, position: POSITION, sift: cv2.xfeatures2d_
|
|||
#prob = classifier.predict_proba(data)
|
||||
|
||||
probs[piece.name][color.name] = prob[0, 1]
|
||||
print(f"{piece}, {color}, {prob[0, 1]}")
|
||||
#print(f"{piece}, {color}, {prob[0, 1]}")
|
||||
#if prob[0, 1] > best and color == position.color: # can only be best if correct color. Iterating through both colors for debugging only
|
||||
if prob[0, 1] > best:
|
||||
best = prob[0, 1]
|
||||
|
@ -92,7 +98,23 @@ def predict_empty(square: np.ndarray, position: POSITION) -> bool:
|
|||
empty_classifier = load_classifier(f"classifiers/classifier_empty/white_piece_on_{position.color}_square.pkl")
|
||||
prob = empty_classifier.predict_proba(np.array(y).reshape(1, -1))
|
||||
#print(f"{position}, {position.color}: {prob[0, 1]}")
|
||||
return prob[0, 1] > 0.95
|
||||
|
||||
y, x = np.histogram(square.ravel(), bins=64, range=[0, 256])
|
||||
|
||||
|
||||
|
||||
lel = np.array(y).reshape(1, -1)
|
||||
#print(lel[lel > 5000])
|
||||
#print(np.array(y).reshape(1, -1))
|
||||
|
||||
|
||||
return prob[0, 1] > 0.75
|
||||
|
||||
if position.color == "white":
|
||||
return prob[0, 1] > 0.75 or len(lel[lel > 5000]) > 5
|
||||
else:
|
||||
return prob[0, 1] > 0.65 or len(lel[lel > 5000]) > 5
|
||||
|
||||
|
||||
|
||||
def remove_most_empties(warped):
|
||||
|
@ -133,6 +155,139 @@ def remove_most_empties(warped):
|
|||
return non_empties
|
||||
|
||||
|
||||
|
||||
def test_empties():
|
||||
pred_empty = 0
|
||||
actual_empty = 0
|
||||
import random
|
||||
for filename in glob.glob(f"training_images/*/*_square/*.png"):
|
||||
square = cv2.imread(filename)
|
||||
if square.shape != (172, 172, 3):
|
||||
continue
|
||||
if 'empty' in filename:
|
||||
actual_empty += 1
|
||||
if "black" in filename:
|
||||
pos = POSITION.A1
|
||||
else:
|
||||
pos = POSITION.A2
|
||||
|
||||
square = cv2.GaussianBlur(square, (7, 7), 0)
|
||||
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||
|
||||
pred_empty += predict_empty(square, pos)
|
||||
#pred_empty += 1 - predict_empty_nn(square)
|
||||
print(actual_empty)
|
||||
print(pred_empty)
|
||||
print(min(actual_empty,pred_empty)/max(actual_empty, pred_empty))
|
||||
|
||||
|
||||
def test_piece_recognition(svms = False):
|
||||
sift = cv2.xfeatures2d.SIFT_create()
|
||||
total = 0
|
||||
correct_guess = 0
|
||||
for filename in glob.glob(f"training_images/rook/*/*.png"):
|
||||
img = cv2.imread(filename)
|
||||
square = cv2.GaussianBlur(img, (9, 9), 0)
|
||||
square = cv2.multiply(square, np.array([(random.random() * 5) + 0.90]))
|
||||
|
||||
#cv2.imwrite("normal_square.png", img)
|
||||
#cv2.imwrite("modified_square.png", square)
|
||||
#cv2.imshow("normal", img)
|
||||
#cv2.imshow("modified", square)
|
||||
#cv2.waitKey(0)
|
||||
|
||||
if "black" in filename:
|
||||
pos = POSITION.A1
|
||||
else:
|
||||
pos = POSITION.A2
|
||||
|
||||
if (svms):
|
||||
res = identify_piece(square, pos, sift)[0]
|
||||
correct_guess += (res == LESS_PIECE.ROOK)
|
||||
else:
|
||||
res = predict_piece(square)
|
||||
correct_guess += (res == LESS_PIECE.ROOK)
|
||||
total += 1
|
||||
|
||||
for filename in glob.glob(f"training_images/knight/*/*.png"):
|
||||
img = cv2.imread(filename)
|
||||
square = cv2.GaussianBlur(img, (7, 7), 0)
|
||||
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||
|
||||
if "black" in filename:
|
||||
pos = POSITION.A1
|
||||
else:
|
||||
pos = POSITION.A2
|
||||
|
||||
if (svms):
|
||||
res = identify_piece(square, pos, sift)[0]
|
||||
correct_guess += (res == LESS_PIECE.KNIGHT)
|
||||
else:
|
||||
res = predict_piece(square)
|
||||
correct_guess += (res == LESS_PIECE.KNIGHT)
|
||||
total += 1
|
||||
|
||||
for filename in glob.glob(f"training_images/bishop/*/*.png"):
|
||||
img = cv2.imread(filename)
|
||||
square = cv2.GaussianBlur(img, (7, 7), 0)
|
||||
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||
|
||||
if "black" in filename:
|
||||
pos = POSITION.A1
|
||||
else:
|
||||
pos = POSITION.A2
|
||||
|
||||
if (svms):
|
||||
res = identify_piece(square, pos, sift)[0]
|
||||
correct_guess += (res == LESS_PIECE.BISHOP)
|
||||
else:
|
||||
res = predict_piece(square)
|
||||
correct_guess += (res == LESS_PIECE.BISHOP)
|
||||
total += 1
|
||||
|
||||
for filename in glob.glob(f"training_images/king/*/*.png"):
|
||||
img = cv2.imread(filename)
|
||||
square = cv2.GaussianBlur(img, (7, 7), 0)
|
||||
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||
|
||||
if "black" in filename:
|
||||
pos = POSITION.A1
|
||||
else:
|
||||
pos = POSITION.A2
|
||||
|
||||
if (svms):
|
||||
res = identify_piece(square, pos, sift)[0]
|
||||
correct_guess += (res == LESS_PIECE.KING)
|
||||
else:
|
||||
res = predict_piece(square)
|
||||
correct_guess += (res == LESS_PIECE.KING)
|
||||
total += 1
|
||||
|
||||
for filename in glob.glob(f"training_images/queen/*/*.png"):
|
||||
img = cv2.imread(filename)
|
||||
square = cv2.GaussianBlur(img, (7, 7), 0)
|
||||
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||
|
||||
if "black" in filename:
|
||||
pos = POSITION.A1
|
||||
else:
|
||||
pos = POSITION.A2
|
||||
|
||||
if (svms):
|
||||
res = identify_piece(square, pos, sift)[0]
|
||||
correct_guess += (res == LESS_PIECE.QUEEN)
|
||||
else:
|
||||
res = predict_piece(square)
|
||||
correct_guess += (res == LESS_PIECE.QUEEN)
|
||||
total += 1
|
||||
|
||||
print(total)
|
||||
print(correct_guess)
|
||||
print(min(total, correct_guess)/max(total, correct_guess))
|
||||
|
||||
|
||||
|
||||
|
||||
def find_occupied_squares(warped: np.ndarray) -> Squares:
|
||||
non_empties = remove_most_empties(warped)
|
||||
|
||||
|
@ -144,18 +299,54 @@ def find_occupied_squares(warped: np.ndarray) -> Squares:
|
|||
return completely_non_empties
|
||||
|
||||
|
||||
def find_occupied_using_nn(warped: np.ndarray) -> Squares:
|
||||
non_empties = runner.get_squares(warped)
|
||||
completely_non_empties = {}
|
||||
for (position, square) in non_empties.items():
|
||||
prediction = predict_empty_nn(square)
|
||||
if prediction:
|
||||
completely_non_empties[position] = square
|
||||
|
||||
return completely_non_empties
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
test_piece_recognition(svms=True)
|
||||
exit()
|
||||
#runner.train_pieces_svm()
|
||||
|
||||
board = cv2.imread("whole_boards/boards_for_empty/board_1554288891.129901_rank_8.png")
|
||||
#board = cv2.imread("whole_boards/boards_for_empty/board_1554286515.323962_rank_3.png")
|
||||
#board = cv2.imread("quality_check.png")
|
||||
#board = cv2.imread("whole_boards/boards_for_empty/board_1554286526.199486_rank_3.png")
|
||||
board = cv2.imread("whole_boards/boards_for_empty/lmao_xd_gg_v2.png")
|
||||
start = time.time()
|
||||
warped = runner.warp_board(board)
|
||||
print(time.time() - start)
|
||||
#cv2.imshow("warped", warped)
|
||||
#cv2.waitKey(0)
|
||||
# squares = runner.get_squares(warped)
|
||||
#squares = find_occupied_squares(warped)
|
||||
squares = find_occupied_using_nn(warped)
|
||||
|
||||
tmp = find_occupied_squares(warped)
|
||||
for pos, square in tmp:
|
||||
for pos, square in squares.items():
|
||||
piece = predict_piece(square)
|
||||
cv2.putText(square, f"{pos} {piece}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,)*3, thickness=3)
|
||||
cv2.imshow(f"{pos}", square)
|
||||
cv2.waitKey(0)
|
||||
|
||||
|
||||
exit()
|
||||
tmp = find_occupied_squares(warped)
|
||||
|
||||
#for pos, square in tmp:
|
||||
# cv2.imshow(f"{pos}", square)
|
||||
#cv2.waitKey(0)
|
||||
|
||||
|
||||
board = predict_board(tmp)
|
||||
|
||||
for pos, piece in board.items():
|
||||
print(f"{pos}, {piece}")
|
||||
|
||||
exit()
|
||||
|
||||
"""
|
||||
|
|
|
@ -6,15 +6,17 @@ import runner
|
|||
from util import FILE, RANK, PIECE, COLOR, imwrite, POSITION
|
||||
|
||||
cap = cv2.VideoCapture(0)
|
||||
#cap.set(cv2.CAP_PROP_FRAME_WIDTH, 4096)
|
||||
#cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 2160)
|
||||
|
||||
color = COLOR.BLACK
|
||||
rank = RANK.EIGHT
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
|
||||
|
||||
color = COLOR.WHITE
|
||||
rank = RANK.FIVE
|
||||
pieces = {
|
||||
PIECE.rook: [POSITION((FILE.A, rank)), POSITION((FILE.F, rank))],
|
||||
PIECE.knight: [POSITION((FILE.E, rank)), POSITION((FILE.H, rank))],
|
||||
PIECE.bishop: [POSITION((FILE.C, rank)), POSITION((FILE.D, rank))],
|
||||
PIECE.queen: [POSITION((FILE.B, rank))],
|
||||
PIECE.king: [POSITION((FILE.G, rank))],
|
||||
PIECE.PAWN: [POSITION((FILE.A, rank)), POSITION((FILE.B, rank)), POSITION((FILE.C, rank)), POSITION((FILE.D, rank)),
|
||||
POSITION((FILE.E, rank)), POSITION((FILE.F, rank)), POSITION((FILE.G, rank)), POSITION((FILE.H, rank))]
|
||||
}
|
||||
|
||||
while True:
|
||||
|
@ -26,16 +28,18 @@ while True:
|
|||
|
||||
if cv2.waitKey(100) & 0xFF == ord("c"):
|
||||
print(f"capturing frame")
|
||||
imwrite(f"whole_boards/boards_for_empty/board_{datetime.utcnow().timestamp()}_.png", frame)
|
||||
#imwrite(f"whole_boards/boards_for_empty/board_{datetime.utcnow().timestamp()}_.png", frame)
|
||||
imwrite("whole_boards/boards_for_empty/lol_gg_xD.png", frame)
|
||||
|
||||
break
|
||||
warped = runner.warp_board(frame)
|
||||
|
||||
runner.save_empty_fields(warped, skip_rank=rank)
|
||||
#runner.save_empty_fields(warped, skip_rank=rank)
|
||||
|
||||
for piece, positions in pieces.items():
|
||||
for position in positions:
|
||||
square = runner.get_square(warped, *position)
|
||||
x, y = position
|
||||
imwrite(f"training_images/{piece}/{position.color}_square/training_{x}{str(y)}_{datetime.utcnow().timestamp()}.png", square)
|
||||
square = runner.get_square(warped, position)
|
||||
imwrite(f"training_images/{piece}/{position.color}_square/training_{position}_{datetime.utcnow().timestamp()}.png", square)
|
||||
|
||||
|
||||
# When everything done, release the capture
|
||||
|
|
54
runner.py
54
runner.py
|
@ -1,9 +1,13 @@
|
|||
import glob
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import copyreg
|
||||
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from sklearn import cluster, metrics, svm, neural_network
|
||||
|
@ -14,6 +18,11 @@ from sklearn.preprocessing import StandardScaler
|
|||
from util import RANK, POSITION, imwrite, PIECE, COLOR, Squares, OUR_PIECES
|
||||
|
||||
here: Path = Path(__file__).parent
|
||||
BASELINE = cv2.imread(str(here.joinpath("new_baseline_board.png")))
|
||||
BASELINE_GRAY = cv2.cvtColor(BASELINE, cv2.COLOR_BGR2GRAY)
|
||||
SIFT = cv2.xfeatures2d.SIFT_create()
|
||||
BASELINE_KEYPOINTS = SIFT.detect(BASELINE_GRAY)
|
||||
BASELINE_KEYPOINTS, BASELINE_DES = SIFT.compute(BASELINE_GRAY, BASELINE_KEYPOINTS)
|
||||
|
||||
|
||||
def generate_centers(number_of_clusters, sift: cv2.xfeatures2d_SIFT):
|
||||
|
@ -112,7 +121,7 @@ def train_pieces_svm() -> None:
|
|||
current_weight = len(glob.glob(f"training_images/{piece}/{color}_square/*.png"))
|
||||
print(f"Training for piece: {piece}")
|
||||
X, Y = load_training_data(piece, color)
|
||||
classifier = svm.SVC(C=10, gamma=0.01, class_weight={0: 15, 1: 0.8}, probability=True)
|
||||
classifier = svm.SVC(gamma=0.01, class_weight={0: current_weight, 1: total_weights}, probability=True)
|
||||
classifier.fit(X, Y)
|
||||
joblib.dump(classifier, f"classifiers/classifier_{piece}/{color}.pkl")
|
||||
|
||||
|
@ -140,23 +149,30 @@ def find_keypoints(camera_image: np.ndarray, baseline: np.ndarray, debug=False)
|
|||
|
||||
:return: (src points, dest points)
|
||||
"""
|
||||
cv2.imwrite("camera_image.png", camera_image)
|
||||
camera_image_gray = cv2.cvtColor(camera_image, cv2.COLOR_BGR2GRAY)
|
||||
baseline_gray = cv2.cvtColor(baseline, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
sift = cv2.xfeatures2d.SIFT_create()
|
||||
camera_image_keypoints = sift.detect(camera_image_gray, None)
|
||||
baseline_keypoints = sift.detect(baseline_gray, None)
|
||||
#sift = cv2.xfeatures2d.SURF_create()
|
||||
|
||||
camera_image_keypoints, des = sift.compute(camera_image_gray, camera_image_keypoints)
|
||||
baseline_keypoints, des2 = sift.compute(baseline_gray, baseline_keypoints)
|
||||
kp_start = time.time()
|
||||
camera_image_keypoints = SIFT.detect(camera_image_gray, None)
|
||||
|
||||
|
||||
|
||||
camera_image_keypoints, des = SIFT.compute(camera_image_gray, camera_image_keypoints)
|
||||
#print("kp:",time.time() - kp_start)
|
||||
|
||||
def_flan = time.time()
|
||||
# FLANN parameters
|
||||
FLANN_INDEX_KDTREE = 0
|
||||
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=8)
|
||||
search_params = dict(checks=100) # or pass empty dictionary
|
||||
|
||||
flann_start = time.time()
|
||||
flann = cv2.FlannBasedMatcher(index_params, search_params)
|
||||
matches = flann.knnMatch(des, des2, k=2)
|
||||
#print("end_def:", time.time() - def_flan)
|
||||
matches = flann.knnMatch(des, BASELINE_DES, k=2)
|
||||
#print("flann:",time.time() - flann_start)
|
||||
|
||||
# Need to draw only good matches, so create a mask
|
||||
matchesMask = [[0, 0] for _ in range(len(matches))]
|
||||
|
@ -168,6 +184,8 @@ def find_keypoints(camera_image: np.ndarray, baseline: np.ndarray, debug=False)
|
|||
matchesMask[i] = [1, 0]
|
||||
good_matches.append([m, n])
|
||||
|
||||
#good_matches = list(filter(lambda x: x[0].distance < 0.55 * x[1].distance, matches))
|
||||
|
||||
if debug:
|
||||
# Save keypoints
|
||||
keypoints_image = camera_image.copy()
|
||||
|
@ -178,7 +196,7 @@ def find_keypoints(camera_image: np.ndarray, baseline: np.ndarray, debug=False)
|
|||
camera_image,
|
||||
camera_image_keypoints,
|
||||
baseline,
|
||||
baseline_keypoints,
|
||||
BASELINE_KEYPOINTS,
|
||||
matches,
|
||||
None,
|
||||
matchColor=(0, 255, 0),
|
||||
|
@ -192,15 +210,16 @@ def find_keypoints(camera_image: np.ndarray, baseline: np.ndarray, debug=False)
|
|||
src_points = np.zeros((len(good_matches), 2), dtype=np.float32)
|
||||
dst_points = np.zeros((len(good_matches), 2), dtype=np.float32)
|
||||
|
||||
|
||||
for i, (m, n) in enumerate(good_matches):
|
||||
src_points[i, :] = camera_image_keypoints[m.queryIdx].pt
|
||||
dst_points[i, :] = baseline_keypoints[m.trainIdx].pt
|
||||
dst_points[i, :] = BASELINE_KEYPOINTS[m.trainIdx].pt
|
||||
|
||||
return src_points, dst_points
|
||||
|
||||
|
||||
def find_homography(camera_image: np.ndarray,
|
||||
baseline: np.ndarray = cv2.imread(str(here.joinpath("new_baseline_board.png"))),
|
||||
baseline: np.ndarray = BASELINE,
|
||||
debug=False) -> np.ndarray:
|
||||
src_points, dst_points = find_keypoints(camera_image, baseline, debug=debug)
|
||||
h, mask = cv2.findHomography(src_points, dst_points, cv2.RANSAC)
|
||||
|
@ -209,11 +228,10 @@ def find_homography(camera_image: np.ndarray,
|
|||
|
||||
|
||||
def warp_board(camera_image: np.ndarray, homography: np.ndarray = None, debug=False) -> np.ndarray:
|
||||
baseline = cv2.imread(str(here.joinpath("new_baseline_board.png")))
|
||||
if homography is None:
|
||||
homography = find_homography(camera_image, baseline, debug=debug)
|
||||
homography = find_homography(camera_image, BASELINE, debug=debug)
|
||||
|
||||
height, width, channels = baseline.shape
|
||||
height, width, channels = BASELINE.shape
|
||||
return cv2.warpPerspective(camera_image, homography, (width, height))
|
||||
|
||||
|
||||
|
@ -240,11 +258,14 @@ def get_squares(warped_board: np.ndarray) -> Squares:
|
|||
for position in POSITION}
|
||||
|
||||
|
||||
def save_empty_fields(warped_board: np.ndarray, skip_rank: RANK = None) -> None:
|
||||
def save_empty_fields(warped_board: np.ndarray, skip_rank: RANK = None, fourk=False) -> None:
|
||||
for position in POSITION:
|
||||
if position.rank == skip_rank:
|
||||
continue
|
||||
square = get_square(warped_board, position)
|
||||
if fourk:
|
||||
imwrite(f"training_images/4k/empty/{position.color}_square/training_{position}_{datetime.utcnow().timestamp()}.png", square)
|
||||
else:
|
||||
imwrite(f"training_images/empty/{position.color}_square/training_{position}_{datetime.utcnow().timestamp()}.png", square)
|
||||
|
||||
|
||||
|
@ -292,4 +313,5 @@ def train_nn():
|
|||
|
||||
if __name__ == '__main__':
|
||||
#train_nn()
|
||||
train_empty_or_piece_hist()
|
||||
do_pre_processing()
|
||||
train_pieces_svm()
|
||||
|
|
|
@ -1,27 +1,55 @@
|
|||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tensorflow.python.keras import models
|
||||
|
||||
from util import PIECE, Squares, Board
|
||||
from util import PIECE, Squares, Board, OUR_PIECES, LESS_PIECE
|
||||
|
||||
here: Path = Path(__file__).parent
|
||||
|
||||
|
||||
#new_model = models.load_model(str(here.joinpath('pls_model_new_fuck_lel.h5')))
|
||||
new_model = models.load_model(str(here.joinpath('5_piece_new_fuck_lel.h5')))
|
||||
empty_class = models.load_model(str(here.joinpath('2_piece_new_fuck_lel.h5')))
|
||||
#all_piece_model = models.load_model(str(here.joinpath('6_piece_new_fuck_lel.h5')))
|
||||
|
||||
new_model = models.load_model('chess_model_3_pieces.h5')
|
||||
#new_model.summary()
|
||||
|
||||
#board = cv2.imread("whole_boards/boards_for_empty/board_1554286488.605142_rank_3.png")
|
||||
#board = cv2.imread("whole_boards/boards_for_empty/board_1554285167.655788_rank_5.png")
|
||||
board = cv2.imread("whole_boards/boards_for_empty/board_1554288891.129901_rank_8.png")
|
||||
#board = cv2.imread("whole_boards/boards_for_empty/board_1554288891.129901_rank_8.png")
|
||||
|
||||
def predict_empty_nn(square):
|
||||
square = square[6:-6, 6:-6]
|
||||
width, height, channels = square.shape
|
||||
square = square / 255.0
|
||||
test = empty_class.predict(np.array(square).reshape((-1, width, height, 3)))
|
||||
print([round(x, 2) for x in test[0]])
|
||||
return int(np.argmax(test))
|
||||
|
||||
def predict_piece(square):
|
||||
square = square[6:-6, 6:-6]
|
||||
width, height, channels = square.shape
|
||||
square = square / 255.0
|
||||
test = new_model.predict(np.array(square).reshape((-1, width, height, 3)))
|
||||
print([round(x, 2) for x in test[0]])
|
||||
return LESS_PIECE(int(np.argmax(test)))
|
||||
|
||||
def predict_board(occupied_squares: Squares) -> Board:
|
||||
board = Board()
|
||||
for pos, square in occupied_squares.items():
|
||||
square = cv2.cvtColor(square, cv2.COLOR_BGR2GRAY)
|
||||
width, height = square.shape
|
||||
#square = cv2.cvtColor(square, cv2.COLOR_BGR2GRAY)
|
||||
square = square[6:-6, 6:-6]
|
||||
|
||||
width, height, channels = square.shape
|
||||
square = square / 255.0
|
||||
test = new_model.predict(np.array(square).reshape((-1, width, height, 1)))
|
||||
#test = new_model.predict(np.array(square).reshape((-1, width, height, 3)))
|
||||
test = new_model.predict(np.array(square).reshape((-1, width, height, 3)))
|
||||
|
||||
#cv2.putText(square, f"{pos} {PIECE(int(np.argmax(test)))}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,) * 3, thickness=3)
|
||||
#cv2.imwrite("lel", square)
|
||||
board[pos] = PIECE(int(np.argmax(test)))
|
||||
#print(f"{pos}, {test}")
|
||||
board[pos] = LESS_PIECE(int(np.argmax(test)))
|
||||
|
||||
return board
|
||||
|
|
22
util.py
22
util.py
|
@ -9,6 +9,8 @@ import cv2
|
|||
import numpy as np
|
||||
from sklearn.externals import joblib
|
||||
|
||||
here: Path = Path(__file__).parent
|
||||
|
||||
|
||||
class COLOR(Enum):
|
||||
WHITE = "white"
|
||||
|
@ -30,13 +32,25 @@ class PIECE(Enum):
|
|||
def __str__(self) -> str:
|
||||
return self.name.lower()
|
||||
|
||||
class LESS_PIECE(Enum):
|
||||
ROOK = 0
|
||||
KNIGHT = 1
|
||||
BISHOP = 2
|
||||
KING = 3
|
||||
QUEEN = 4
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name.lower()
|
||||
|
||||
|
||||
PieceAndColor = Tuple[PIECE, COLOR]
|
||||
|
||||
OUR_PIECES = (
|
||||
PIECE.KNIGHT,
|
||||
PIECE.ROOK,
|
||||
PIECE.BISHOP,
|
||||
LESS_PIECE.ROOK,
|
||||
LESS_PIECE.KNIGHT,
|
||||
LESS_PIECE.BISHOP,
|
||||
LESS_PIECE.KING,
|
||||
LESS_PIECE.QUEEN
|
||||
)
|
||||
|
||||
|
||||
|
@ -100,4 +114,4 @@ def imwrite(*args, **kwargs):
|
|||
@lru_cache()
|
||||
def load_classifier(filename):
|
||||
# print(f"Loading classifier {filename}")
|
||||
return joblib.load(filename)
|
||||
return joblib.load(str(here.joinpath(filename)))
|
||||
|
|
14
web.py
14
web.py
|
@ -7,12 +7,14 @@ from flask import Flask, jsonify, request
|
|||
from main import find_occupied_squares
|
||||
from runner import find_homography, warp_board
|
||||
from tensor_classifier import predict_board
|
||||
from time import time
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def process():
|
||||
print("Received request")
|
||||
data = request.get_json(force=True)
|
||||
|
||||
decoded = base64.b64decode(data["img"])
|
||||
|
@ -21,11 +23,19 @@ def process():
|
|||
camera_img = cv2.cvtColor(camera_img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# def do_everything:
|
||||
homography = find_homography(camera_img)
|
||||
start = time()
|
||||
print("Finding keypoints")
|
||||
homography = find_homography(camera_img, debug=True)
|
||||
print("Computing homography")
|
||||
warped_board = warp_board(camera_img, homography)
|
||||
print("Warping board")
|
||||
cv2.imwrite("warped.png", warped_board)
|
||||
print("Removing empty squares")
|
||||
occupied_squares = find_occupied_squares(warped_board)
|
||||
print("Predicting board state")
|
||||
board = predict_board(occupied_squares)
|
||||
|
||||
print(f"The request took {round(time() - start, 3)} seconds")
|
||||
print("Returning board state")
|
||||
# Finally, output for unity to read
|
||||
return jsonify({
|
||||
"homography": homography.tolist(),
|
||||
|
|
Loading…
Reference in New Issue
Block a user