final
This commit is contained in:
parent
12ab541de6
commit
e97ec3fea9
BIN
2_piece_new_fuck_lel.h5
Normal file
BIN
2_piece_new_fuck_lel.h5
Normal file
Binary file not shown.
BIN
5_piece_new_fuck_lel.h5
Normal file
BIN
5_piece_new_fuck_lel.h5
Normal file
Binary file not shown.
BIN
classifiers/classifier_bishop/black.pkl
Normal file
BIN
classifiers/classifier_bishop/black.pkl
Normal file
Binary file not shown.
BIN
classifiers/classifier_bishop/white.pkl
Normal file
BIN
classifiers/classifier_bishop/white.pkl
Normal file
Binary file not shown.
BIN
classifiers/classifier_king/black.pkl
Normal file
BIN
classifiers/classifier_king/black.pkl
Normal file
Binary file not shown.
BIN
classifiers/classifier_king/white.pkl
Normal file
BIN
classifiers/classifier_king/white.pkl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
classifiers/classifier_queen/black.pkl
Normal file
BIN
classifiers/classifier_queen/black.pkl
Normal file
Binary file not shown.
BIN
classifiers/classifier_queen/white.pkl
Normal file
BIN
classifiers/classifier_queen/white.pkl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
205
main.py
205
main.py
|
@ -1,5 +1,10 @@
|
||||||
|
import glob
|
||||||
|
import random
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
import time
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -7,7 +12,8 @@ import numpy as np
|
||||||
from sklearn.exceptions import DataConversionWarning
|
from sklearn.exceptions import DataConversionWarning
|
||||||
|
|
||||||
import runner
|
import runner
|
||||||
from util import load_classifier, PIECE, COLOR, POSITION, Board, Squares, PieceAndColor, OUR_PIECES
|
from tensor_classifier import predict_board, predict_piece, predict_empty_nn
|
||||||
|
from util import load_classifier, PIECE, COLOR, POSITION, Board, Squares, PieceAndColor, OUR_PIECES, FILE, RANK, LESS_PIECE
|
||||||
|
|
||||||
warnings.filterwarnings(action='ignore', category=DataConversionWarning)
|
warnings.filterwarnings(action='ignore', category=DataConversionWarning)
|
||||||
np.set_printoptions(threshold=sys.maxsize)
|
np.set_printoptions(threshold=sys.maxsize)
|
||||||
|
@ -32,7 +38,7 @@ def identify_piece(image: np.ndarray, position: POSITION, sift: cv2.xfeatures2d_
|
||||||
#prob = classifier.predict_proba(data)
|
#prob = classifier.predict_proba(data)
|
||||||
|
|
||||||
probs[piece.name][color.name] = prob[0, 1]
|
probs[piece.name][color.name] = prob[0, 1]
|
||||||
print(f"{piece}, {color}, {prob[0, 1]}")
|
#print(f"{piece}, {color}, {prob[0, 1]}")
|
||||||
#if prob[0, 1] > best and color == position.color: # can only be best if correct color. Iterating through both colors for debugging only
|
#if prob[0, 1] > best and color == position.color: # can only be best if correct color. Iterating through both colors for debugging only
|
||||||
if prob[0, 1] > best:
|
if prob[0, 1] > best:
|
||||||
best = prob[0, 1]
|
best = prob[0, 1]
|
||||||
|
@ -92,7 +98,23 @@ def predict_empty(square: np.ndarray, position: POSITION) -> bool:
|
||||||
empty_classifier = load_classifier(f"classifiers/classifier_empty/white_piece_on_{position.color}_square.pkl")
|
empty_classifier = load_classifier(f"classifiers/classifier_empty/white_piece_on_{position.color}_square.pkl")
|
||||||
prob = empty_classifier.predict_proba(np.array(y).reshape(1, -1))
|
prob = empty_classifier.predict_proba(np.array(y).reshape(1, -1))
|
||||||
#print(f"{position}, {position.color}: {prob[0, 1]}")
|
#print(f"{position}, {position.color}: {prob[0, 1]}")
|
||||||
return prob[0, 1] > 0.95
|
|
||||||
|
y, x = np.histogram(square.ravel(), bins=64, range=[0, 256])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
lel = np.array(y).reshape(1, -1)
|
||||||
|
#print(lel[lel > 5000])
|
||||||
|
#print(np.array(y).reshape(1, -1))
|
||||||
|
|
||||||
|
|
||||||
|
return prob[0, 1] > 0.75
|
||||||
|
|
||||||
|
if position.color == "white":
|
||||||
|
return prob[0, 1] > 0.75 or len(lel[lel > 5000]) > 5
|
||||||
|
else:
|
||||||
|
return prob[0, 1] > 0.65 or len(lel[lel > 5000]) > 5
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def remove_most_empties(warped):
|
def remove_most_empties(warped):
|
||||||
|
@ -133,6 +155,139 @@ def remove_most_empties(warped):
|
||||||
return non_empties
|
return non_empties
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_empties():
|
||||||
|
pred_empty = 0
|
||||||
|
actual_empty = 0
|
||||||
|
import random
|
||||||
|
for filename in glob.glob(f"training_images/*/*_square/*.png"):
|
||||||
|
square = cv2.imread(filename)
|
||||||
|
if square.shape != (172, 172, 3):
|
||||||
|
continue
|
||||||
|
if 'empty' in filename:
|
||||||
|
actual_empty += 1
|
||||||
|
if "black" in filename:
|
||||||
|
pos = POSITION.A1
|
||||||
|
else:
|
||||||
|
pos = POSITION.A2
|
||||||
|
|
||||||
|
square = cv2.GaussianBlur(square, (7, 7), 0)
|
||||||
|
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||||
|
|
||||||
|
pred_empty += predict_empty(square, pos)
|
||||||
|
#pred_empty += 1 - predict_empty_nn(square)
|
||||||
|
print(actual_empty)
|
||||||
|
print(pred_empty)
|
||||||
|
print(min(actual_empty,pred_empty)/max(actual_empty, pred_empty))
|
||||||
|
|
||||||
|
|
||||||
|
def test_piece_recognition(svms = False):
|
||||||
|
sift = cv2.xfeatures2d.SIFT_create()
|
||||||
|
total = 0
|
||||||
|
correct_guess = 0
|
||||||
|
for filename in glob.glob(f"training_images/rook/*/*.png"):
|
||||||
|
img = cv2.imread(filename)
|
||||||
|
square = cv2.GaussianBlur(img, (9, 9), 0)
|
||||||
|
square = cv2.multiply(square, np.array([(random.random() * 5) + 0.90]))
|
||||||
|
|
||||||
|
#cv2.imwrite("normal_square.png", img)
|
||||||
|
#cv2.imwrite("modified_square.png", square)
|
||||||
|
#cv2.imshow("normal", img)
|
||||||
|
#cv2.imshow("modified", square)
|
||||||
|
#cv2.waitKey(0)
|
||||||
|
|
||||||
|
if "black" in filename:
|
||||||
|
pos = POSITION.A1
|
||||||
|
else:
|
||||||
|
pos = POSITION.A2
|
||||||
|
|
||||||
|
if (svms):
|
||||||
|
res = identify_piece(square, pos, sift)[0]
|
||||||
|
correct_guess += (res == LESS_PIECE.ROOK)
|
||||||
|
else:
|
||||||
|
res = predict_piece(square)
|
||||||
|
correct_guess += (res == LESS_PIECE.ROOK)
|
||||||
|
total += 1
|
||||||
|
|
||||||
|
for filename in glob.glob(f"training_images/knight/*/*.png"):
|
||||||
|
img = cv2.imread(filename)
|
||||||
|
square = cv2.GaussianBlur(img, (7, 7), 0)
|
||||||
|
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||||
|
|
||||||
|
if "black" in filename:
|
||||||
|
pos = POSITION.A1
|
||||||
|
else:
|
||||||
|
pos = POSITION.A2
|
||||||
|
|
||||||
|
if (svms):
|
||||||
|
res = identify_piece(square, pos, sift)[0]
|
||||||
|
correct_guess += (res == LESS_PIECE.KNIGHT)
|
||||||
|
else:
|
||||||
|
res = predict_piece(square)
|
||||||
|
correct_guess += (res == LESS_PIECE.KNIGHT)
|
||||||
|
total += 1
|
||||||
|
|
||||||
|
for filename in glob.glob(f"training_images/bishop/*/*.png"):
|
||||||
|
img = cv2.imread(filename)
|
||||||
|
square = cv2.GaussianBlur(img, (7, 7), 0)
|
||||||
|
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||||
|
|
||||||
|
if "black" in filename:
|
||||||
|
pos = POSITION.A1
|
||||||
|
else:
|
||||||
|
pos = POSITION.A2
|
||||||
|
|
||||||
|
if (svms):
|
||||||
|
res = identify_piece(square, pos, sift)[0]
|
||||||
|
correct_guess += (res == LESS_PIECE.BISHOP)
|
||||||
|
else:
|
||||||
|
res = predict_piece(square)
|
||||||
|
correct_guess += (res == LESS_PIECE.BISHOP)
|
||||||
|
total += 1
|
||||||
|
|
||||||
|
for filename in glob.glob(f"training_images/king/*/*.png"):
|
||||||
|
img = cv2.imread(filename)
|
||||||
|
square = cv2.GaussianBlur(img, (7, 7), 0)
|
||||||
|
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||||
|
|
||||||
|
if "black" in filename:
|
||||||
|
pos = POSITION.A1
|
||||||
|
else:
|
||||||
|
pos = POSITION.A2
|
||||||
|
|
||||||
|
if (svms):
|
||||||
|
res = identify_piece(square, pos, sift)[0]
|
||||||
|
correct_guess += (res == LESS_PIECE.KING)
|
||||||
|
else:
|
||||||
|
res = predict_piece(square)
|
||||||
|
correct_guess += (res == LESS_PIECE.KING)
|
||||||
|
total += 1
|
||||||
|
|
||||||
|
for filename in glob.glob(f"training_images/queen/*/*.png"):
|
||||||
|
img = cv2.imread(filename)
|
||||||
|
square = cv2.GaussianBlur(img, (7, 7), 0)
|
||||||
|
square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
|
||||||
|
|
||||||
|
if "black" in filename:
|
||||||
|
pos = POSITION.A1
|
||||||
|
else:
|
||||||
|
pos = POSITION.A2
|
||||||
|
|
||||||
|
if (svms):
|
||||||
|
res = identify_piece(square, pos, sift)[0]
|
||||||
|
correct_guess += (res == LESS_PIECE.QUEEN)
|
||||||
|
else:
|
||||||
|
res = predict_piece(square)
|
||||||
|
correct_guess += (res == LESS_PIECE.QUEEN)
|
||||||
|
total += 1
|
||||||
|
|
||||||
|
print(total)
|
||||||
|
print(correct_guess)
|
||||||
|
print(min(total, correct_guess)/max(total, correct_guess))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def find_occupied_squares(warped: np.ndarray) -> Squares:
|
def find_occupied_squares(warped: np.ndarray) -> Squares:
|
||||||
non_empties = remove_most_empties(warped)
|
non_empties = remove_most_empties(warped)
|
||||||
|
|
||||||
|
@ -144,18 +299,54 @@ def find_occupied_squares(warped: np.ndarray) -> Squares:
|
||||||
return completely_non_empties
|
return completely_non_empties
|
||||||
|
|
||||||
|
|
||||||
|
def find_occupied_using_nn(warped: np.ndarray) -> Squares:
|
||||||
|
non_empties = runner.get_squares(warped)
|
||||||
|
completely_non_empties = {}
|
||||||
|
for (position, square) in non_empties.items():
|
||||||
|
prediction = predict_empty_nn(square)
|
||||||
|
if prediction:
|
||||||
|
completely_non_empties[position] = square
|
||||||
|
|
||||||
|
return completely_non_empties
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
test_piece_recognition(svms=True)
|
||||||
|
exit()
|
||||||
#runner.train_pieces_svm()
|
#runner.train_pieces_svm()
|
||||||
|
|
||||||
board = cv2.imread("whole_boards/boards_for_empty/board_1554288891.129901_rank_8.png")
|
#board = cv2.imread("quality_check.png")
|
||||||
#board = cv2.imread("whole_boards/boards_for_empty/board_1554286515.323962_rank_3.png")
|
#board = cv2.imread("whole_boards/boards_for_empty/board_1554286526.199486_rank_3.png")
|
||||||
|
board = cv2.imread("whole_boards/boards_for_empty/lmao_xd_gg_v2.png")
|
||||||
|
start = time.time()
|
||||||
warped = runner.warp_board(board)
|
warped = runner.warp_board(board)
|
||||||
|
print(time.time() - start)
|
||||||
|
#cv2.imshow("warped", warped)
|
||||||
|
#cv2.waitKey(0)
|
||||||
|
# squares = runner.get_squares(warped)
|
||||||
|
#squares = find_occupied_squares(warped)
|
||||||
|
squares = find_occupied_using_nn(warped)
|
||||||
|
|
||||||
tmp = find_occupied_squares(warped)
|
for pos, square in squares.items():
|
||||||
for pos, square in tmp:
|
piece = predict_piece(square)
|
||||||
|
cv2.putText(square, f"{pos} {piece}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,)*3, thickness=3)
|
||||||
cv2.imshow(f"{pos}", square)
|
cv2.imshow(f"{pos}", square)
|
||||||
cv2.waitKey(0)
|
cv2.waitKey(0)
|
||||||
|
|
||||||
|
|
||||||
|
exit()
|
||||||
|
tmp = find_occupied_squares(warped)
|
||||||
|
|
||||||
|
#for pos, square in tmp:
|
||||||
|
# cv2.imshow(f"{pos}", square)
|
||||||
|
#cv2.waitKey(0)
|
||||||
|
|
||||||
|
|
||||||
|
board = predict_board(tmp)
|
||||||
|
|
||||||
|
for pos, piece in board.items():
|
||||||
|
print(f"{pos}, {piece}")
|
||||||
|
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -6,15 +6,17 @@ import runner
|
||||||
from util import FILE, RANK, PIECE, COLOR, imwrite, POSITION
|
from util import FILE, RANK, PIECE, COLOR, imwrite, POSITION
|
||||||
|
|
||||||
cap = cv2.VideoCapture(0)
|
cap = cv2.VideoCapture(0)
|
||||||
|
#cap.set(cv2.CAP_PROP_FRAME_WIDTH, 4096)
|
||||||
|
#cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 2160)
|
||||||
|
|
||||||
color = COLOR.BLACK
|
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
|
||||||
rank = RANK.EIGHT
|
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
|
||||||
|
|
||||||
|
color = COLOR.WHITE
|
||||||
|
rank = RANK.FIVE
|
||||||
pieces = {
|
pieces = {
|
||||||
PIECE.rook: [POSITION((FILE.A, rank)), POSITION((FILE.F, rank))],
|
PIECE.PAWN: [POSITION((FILE.A, rank)), POSITION((FILE.B, rank)), POSITION((FILE.C, rank)), POSITION((FILE.D, rank)),
|
||||||
PIECE.knight: [POSITION((FILE.E, rank)), POSITION((FILE.H, rank))],
|
POSITION((FILE.E, rank)), POSITION((FILE.F, rank)), POSITION((FILE.G, rank)), POSITION((FILE.H, rank))]
|
||||||
PIECE.bishop: [POSITION((FILE.C, rank)), POSITION((FILE.D, rank))],
|
|
||||||
PIECE.queen: [POSITION((FILE.B, rank))],
|
|
||||||
PIECE.king: [POSITION((FILE.G, rank))],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
@ -26,16 +28,18 @@ while True:
|
||||||
|
|
||||||
if cv2.waitKey(100) & 0xFF == ord("c"):
|
if cv2.waitKey(100) & 0xFF == ord("c"):
|
||||||
print(f"capturing frame")
|
print(f"capturing frame")
|
||||||
imwrite(f"whole_boards/boards_for_empty/board_{datetime.utcnow().timestamp()}_.png", frame)
|
#imwrite(f"whole_boards/boards_for_empty/board_{datetime.utcnow().timestamp()}_.png", frame)
|
||||||
|
imwrite("whole_boards/boards_for_empty/lol_gg_xD.png", frame)
|
||||||
|
|
||||||
|
break
|
||||||
warped = runner.warp_board(frame)
|
warped = runner.warp_board(frame)
|
||||||
|
|
||||||
runner.save_empty_fields(warped, skip_rank=rank)
|
#runner.save_empty_fields(warped, skip_rank=rank)
|
||||||
|
|
||||||
for piece, positions in pieces.items():
|
for piece, positions in pieces.items():
|
||||||
for position in positions:
|
for position in positions:
|
||||||
square = runner.get_square(warped, *position)
|
square = runner.get_square(warped, position)
|
||||||
x, y = position
|
imwrite(f"training_images/{piece}/{position.color}_square/training_{position}_{datetime.utcnow().timestamp()}.png", square)
|
||||||
imwrite(f"training_images/{piece}/{position.color}_square/training_{x}{str(y)}_{datetime.utcnow().timestamp()}.png", square)
|
|
||||||
|
|
||||||
|
|
||||||
# When everything done, release the capture
|
# When everything done, release the capture
|
||||||
|
|
56
runner.py
56
runner.py
|
@ -1,9 +1,13 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
import copyreg
|
||||||
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn import cluster, metrics, svm, neural_network
|
from sklearn import cluster, metrics, svm, neural_network
|
||||||
|
@ -14,6 +18,11 @@ from sklearn.preprocessing import StandardScaler
|
||||||
from util import RANK, POSITION, imwrite, PIECE, COLOR, Squares, OUR_PIECES
|
from util import RANK, POSITION, imwrite, PIECE, COLOR, Squares, OUR_PIECES
|
||||||
|
|
||||||
here: Path = Path(__file__).parent
|
here: Path = Path(__file__).parent
|
||||||
|
BASELINE = cv2.imread(str(here.joinpath("new_baseline_board.png")))
|
||||||
|
BASELINE_GRAY = cv2.cvtColor(BASELINE, cv2.COLOR_BGR2GRAY)
|
||||||
|
SIFT = cv2.xfeatures2d.SIFT_create()
|
||||||
|
BASELINE_KEYPOINTS = SIFT.detect(BASELINE_GRAY)
|
||||||
|
BASELINE_KEYPOINTS, BASELINE_DES = SIFT.compute(BASELINE_GRAY, BASELINE_KEYPOINTS)
|
||||||
|
|
||||||
|
|
||||||
def generate_centers(number_of_clusters, sift: cv2.xfeatures2d_SIFT):
|
def generate_centers(number_of_clusters, sift: cv2.xfeatures2d_SIFT):
|
||||||
|
@ -112,7 +121,7 @@ def train_pieces_svm() -> None:
|
||||||
current_weight = len(glob.glob(f"training_images/{piece}/{color}_square/*.png"))
|
current_weight = len(glob.glob(f"training_images/{piece}/{color}_square/*.png"))
|
||||||
print(f"Training for piece: {piece}")
|
print(f"Training for piece: {piece}")
|
||||||
X, Y = load_training_data(piece, color)
|
X, Y = load_training_data(piece, color)
|
||||||
classifier = svm.SVC(C=10, gamma=0.01, class_weight={0: 15, 1: 0.8}, probability=True)
|
classifier = svm.SVC(gamma=0.01, class_weight={0: current_weight, 1: total_weights}, probability=True)
|
||||||
classifier.fit(X, Y)
|
classifier.fit(X, Y)
|
||||||
joblib.dump(classifier, f"classifiers/classifier_{piece}/{color}.pkl")
|
joblib.dump(classifier, f"classifiers/classifier_{piece}/{color}.pkl")
|
||||||
|
|
||||||
|
@ -140,23 +149,30 @@ def find_keypoints(camera_image: np.ndarray, baseline: np.ndarray, debug=False)
|
||||||
|
|
||||||
:return: (src points, dest points)
|
:return: (src points, dest points)
|
||||||
"""
|
"""
|
||||||
|
cv2.imwrite("camera_image.png", camera_image)
|
||||||
camera_image_gray = cv2.cvtColor(camera_image, cv2.COLOR_BGR2GRAY)
|
camera_image_gray = cv2.cvtColor(camera_image, cv2.COLOR_BGR2GRAY)
|
||||||
baseline_gray = cv2.cvtColor(baseline, cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
sift = cv2.xfeatures2d.SIFT_create()
|
#sift = cv2.xfeatures2d.SURF_create()
|
||||||
camera_image_keypoints = sift.detect(camera_image_gray, None)
|
|
||||||
baseline_keypoints = sift.detect(baseline_gray, None)
|
|
||||||
|
|
||||||
camera_image_keypoints, des = sift.compute(camera_image_gray, camera_image_keypoints)
|
kp_start = time.time()
|
||||||
baseline_keypoints, des2 = sift.compute(baseline_gray, baseline_keypoints)
|
camera_image_keypoints = SIFT.detect(camera_image_gray, None)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
camera_image_keypoints, des = SIFT.compute(camera_image_gray, camera_image_keypoints)
|
||||||
|
#print("kp:",time.time() - kp_start)
|
||||||
|
|
||||||
|
def_flan = time.time()
|
||||||
# FLANN parameters
|
# FLANN parameters
|
||||||
FLANN_INDEX_KDTREE = 0
|
FLANN_INDEX_KDTREE = 0
|
||||||
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=8)
|
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=8)
|
||||||
search_params = dict(checks=100) # or pass empty dictionary
|
search_params = dict(checks=100) # or pass empty dictionary
|
||||||
|
|
||||||
|
flann_start = time.time()
|
||||||
flann = cv2.FlannBasedMatcher(index_params, search_params)
|
flann = cv2.FlannBasedMatcher(index_params, search_params)
|
||||||
matches = flann.knnMatch(des, des2, k=2)
|
#print("end_def:", time.time() - def_flan)
|
||||||
|
matches = flann.knnMatch(des, BASELINE_DES, k=2)
|
||||||
|
#print("flann:",time.time() - flann_start)
|
||||||
|
|
||||||
# Need to draw only good matches, so create a mask
|
# Need to draw only good matches, so create a mask
|
||||||
matchesMask = [[0, 0] for _ in range(len(matches))]
|
matchesMask = [[0, 0] for _ in range(len(matches))]
|
||||||
|
@ -168,6 +184,8 @@ def find_keypoints(camera_image: np.ndarray, baseline: np.ndarray, debug=False)
|
||||||
matchesMask[i] = [1, 0]
|
matchesMask[i] = [1, 0]
|
||||||
good_matches.append([m, n])
|
good_matches.append([m, n])
|
||||||
|
|
||||||
|
#good_matches = list(filter(lambda x: x[0].distance < 0.55 * x[1].distance, matches))
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
# Save keypoints
|
# Save keypoints
|
||||||
keypoints_image = camera_image.copy()
|
keypoints_image = camera_image.copy()
|
||||||
|
@ -178,7 +196,7 @@ def find_keypoints(camera_image: np.ndarray, baseline: np.ndarray, debug=False)
|
||||||
camera_image,
|
camera_image,
|
||||||
camera_image_keypoints,
|
camera_image_keypoints,
|
||||||
baseline,
|
baseline,
|
||||||
baseline_keypoints,
|
BASELINE_KEYPOINTS,
|
||||||
matches,
|
matches,
|
||||||
None,
|
None,
|
||||||
matchColor=(0, 255, 0),
|
matchColor=(0, 255, 0),
|
||||||
|
@ -192,15 +210,16 @@ def find_keypoints(camera_image: np.ndarray, baseline: np.ndarray, debug=False)
|
||||||
src_points = np.zeros((len(good_matches), 2), dtype=np.float32)
|
src_points = np.zeros((len(good_matches), 2), dtype=np.float32)
|
||||||
dst_points = np.zeros((len(good_matches), 2), dtype=np.float32)
|
dst_points = np.zeros((len(good_matches), 2), dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
for i, (m, n) in enumerate(good_matches):
|
for i, (m, n) in enumerate(good_matches):
|
||||||
src_points[i, :] = camera_image_keypoints[m.queryIdx].pt
|
src_points[i, :] = camera_image_keypoints[m.queryIdx].pt
|
||||||
dst_points[i, :] = baseline_keypoints[m.trainIdx].pt
|
dst_points[i, :] = BASELINE_KEYPOINTS[m.trainIdx].pt
|
||||||
|
|
||||||
return src_points, dst_points
|
return src_points, dst_points
|
||||||
|
|
||||||
|
|
||||||
def find_homography(camera_image: np.ndarray,
|
def find_homography(camera_image: np.ndarray,
|
||||||
baseline: np.ndarray = cv2.imread(str(here.joinpath("new_baseline_board.png"))),
|
baseline: np.ndarray = BASELINE,
|
||||||
debug=False) -> np.ndarray:
|
debug=False) -> np.ndarray:
|
||||||
src_points, dst_points = find_keypoints(camera_image, baseline, debug=debug)
|
src_points, dst_points = find_keypoints(camera_image, baseline, debug=debug)
|
||||||
h, mask = cv2.findHomography(src_points, dst_points, cv2.RANSAC)
|
h, mask = cv2.findHomography(src_points, dst_points, cv2.RANSAC)
|
||||||
|
@ -209,11 +228,10 @@ def find_homography(camera_image: np.ndarray,
|
||||||
|
|
||||||
|
|
||||||
def warp_board(camera_image: np.ndarray, homography: np.ndarray = None, debug=False) -> np.ndarray:
|
def warp_board(camera_image: np.ndarray, homography: np.ndarray = None, debug=False) -> np.ndarray:
|
||||||
baseline = cv2.imread(str(here.joinpath("new_baseline_board.png")))
|
|
||||||
if homography is None:
|
if homography is None:
|
||||||
homography = find_homography(camera_image, baseline, debug=debug)
|
homography = find_homography(camera_image, BASELINE, debug=debug)
|
||||||
|
|
||||||
height, width, channels = baseline.shape
|
height, width, channels = BASELINE.shape
|
||||||
return cv2.warpPerspective(camera_image, homography, (width, height))
|
return cv2.warpPerspective(camera_image, homography, (width, height))
|
||||||
|
|
||||||
|
|
||||||
|
@ -240,12 +258,15 @@ def get_squares(warped_board: np.ndarray) -> Squares:
|
||||||
for position in POSITION}
|
for position in POSITION}
|
||||||
|
|
||||||
|
|
||||||
def save_empty_fields(warped_board: np.ndarray, skip_rank: RANK = None) -> None:
|
def save_empty_fields(warped_board: np.ndarray, skip_rank: RANK = None, fourk=False) -> None:
|
||||||
for position in POSITION:
|
for position in POSITION:
|
||||||
if position.rank == skip_rank:
|
if position.rank == skip_rank:
|
||||||
continue
|
continue
|
||||||
square = get_square(warped_board, position)
|
square = get_square(warped_board, position)
|
||||||
imwrite(f"training_images/empty/{position.color}_square/training_{position}_{datetime.utcnow().timestamp()}.png", square)
|
if fourk:
|
||||||
|
imwrite(f"training_images/4k/empty/{position.color}_square/training_{position}_{datetime.utcnow().timestamp()}.png", square)
|
||||||
|
else:
|
||||||
|
imwrite(f"training_images/empty/{position.color}_square/training_{position}_{datetime.utcnow().timestamp()}.png", square)
|
||||||
|
|
||||||
|
|
||||||
def load_data_nn(spec_piece, color):
|
def load_data_nn(spec_piece, color):
|
||||||
|
@ -292,4 +313,5 @@ def train_nn():
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#train_nn()
|
#train_nn()
|
||||||
train_empty_or_piece_hist()
|
do_pre_processing()
|
||||||
|
train_pieces_svm()
|
||||||
|
|
|
@ -1,27 +1,55 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.python.keras import models
|
from tensorflow.python.keras import models
|
||||||
|
|
||||||
from util import PIECE, Squares, Board
|
from util import PIECE, Squares, Board, OUR_PIECES, LESS_PIECE
|
||||||
|
|
||||||
|
here: Path = Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
|
#new_model = models.load_model(str(here.joinpath('pls_model_new_fuck_lel.h5')))
|
||||||
|
new_model = models.load_model(str(here.joinpath('5_piece_new_fuck_lel.h5')))
|
||||||
|
empty_class = models.load_model(str(here.joinpath('2_piece_new_fuck_lel.h5')))
|
||||||
|
#all_piece_model = models.load_model(str(here.joinpath('6_piece_new_fuck_lel.h5')))
|
||||||
|
|
||||||
new_model = models.load_model('chess_model_3_pieces.h5')
|
|
||||||
#new_model.summary()
|
#new_model.summary()
|
||||||
|
|
||||||
#board = cv2.imread("whole_boards/boards_for_empty/board_1554286488.605142_rank_3.png")
|
#board = cv2.imread("whole_boards/boards_for_empty/board_1554286488.605142_rank_3.png")
|
||||||
#board = cv2.imread("whole_boards/boards_for_empty/board_1554285167.655788_rank_5.png")
|
#board = cv2.imread("whole_boards/boards_for_empty/board_1554285167.655788_rank_5.png")
|
||||||
board = cv2.imread("whole_boards/boards_for_empty/board_1554288891.129901_rank_8.png")
|
#board = cv2.imread("whole_boards/boards_for_empty/board_1554288891.129901_rank_8.png")
|
||||||
|
|
||||||
|
def predict_empty_nn(square):
|
||||||
|
square = square[6:-6, 6:-6]
|
||||||
|
width, height, channels = square.shape
|
||||||
|
square = square / 255.0
|
||||||
|
test = empty_class.predict(np.array(square).reshape((-1, width, height, 3)))
|
||||||
|
print([round(x, 2) for x in test[0]])
|
||||||
|
return int(np.argmax(test))
|
||||||
|
|
||||||
|
def predict_piece(square):
|
||||||
|
square = square[6:-6, 6:-6]
|
||||||
|
width, height, channels = square.shape
|
||||||
|
square = square / 255.0
|
||||||
|
test = new_model.predict(np.array(square).reshape((-1, width, height, 3)))
|
||||||
|
print([round(x, 2) for x in test[0]])
|
||||||
|
return LESS_PIECE(int(np.argmax(test)))
|
||||||
|
|
||||||
def predict_board(occupied_squares: Squares) -> Board:
|
def predict_board(occupied_squares: Squares) -> Board:
|
||||||
board = Board()
|
board = Board()
|
||||||
for pos, square in occupied_squares.items():
|
for pos, square in occupied_squares.items():
|
||||||
square = cv2.cvtColor(square, cv2.COLOR_BGR2GRAY)
|
#square = cv2.cvtColor(square, cv2.COLOR_BGR2GRAY)
|
||||||
width, height = square.shape
|
square = square[6:-6, 6:-6]
|
||||||
|
|
||||||
|
width, height, channels = square.shape
|
||||||
square = square / 255.0
|
square = square / 255.0
|
||||||
test = new_model.predict(np.array(square).reshape((-1, width, height, 1)))
|
#test = new_model.predict(np.array(square).reshape((-1, width, height, 3)))
|
||||||
|
test = new_model.predict(np.array(square).reshape((-1, width, height, 3)))
|
||||||
|
|
||||||
#cv2.putText(square, f"{pos} {PIECE(int(np.argmax(test)))}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,) * 3, thickness=3)
|
#cv2.putText(square, f"{pos} {PIECE(int(np.argmax(test)))}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,) * 3, thickness=3)
|
||||||
#cv2.imwrite("lel", square)
|
#cv2.imwrite("lel", square)
|
||||||
board[pos] = PIECE(int(np.argmax(test)))
|
#print(f"{pos}, {test}")
|
||||||
|
board[pos] = LESS_PIECE(int(np.argmax(test)))
|
||||||
|
|
||||||
return board
|
return board
|
||||||
|
|
22
util.py
22
util.py
|
@ -9,6 +9,8 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.externals import joblib
|
from sklearn.externals import joblib
|
||||||
|
|
||||||
|
here: Path = Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
class COLOR(Enum):
|
class COLOR(Enum):
|
||||||
WHITE = "white"
|
WHITE = "white"
|
||||||
|
@ -30,13 +32,25 @@ class PIECE(Enum):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.name.lower()
|
return self.name.lower()
|
||||||
|
|
||||||
|
class LESS_PIECE(Enum):
|
||||||
|
ROOK = 0
|
||||||
|
KNIGHT = 1
|
||||||
|
BISHOP = 2
|
||||||
|
KING = 3
|
||||||
|
QUEEN = 4
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name.lower()
|
||||||
|
|
||||||
|
|
||||||
PieceAndColor = Tuple[PIECE, COLOR]
|
PieceAndColor = Tuple[PIECE, COLOR]
|
||||||
|
|
||||||
OUR_PIECES = (
|
OUR_PIECES = (
|
||||||
PIECE.KNIGHT,
|
LESS_PIECE.ROOK,
|
||||||
PIECE.ROOK,
|
LESS_PIECE.KNIGHT,
|
||||||
PIECE.BISHOP,
|
LESS_PIECE.BISHOP,
|
||||||
|
LESS_PIECE.KING,
|
||||||
|
LESS_PIECE.QUEEN
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,4 +114,4 @@ def imwrite(*args, **kwargs):
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_classifier(filename):
|
def load_classifier(filename):
|
||||||
# print(f"Loading classifier {filename}")
|
# print(f"Loading classifier {filename}")
|
||||||
return joblib.load(filename)
|
return joblib.load(str(here.joinpath(filename)))
|
||||||
|
|
16
web.py
16
web.py
|
@ -7,12 +7,14 @@ from flask import Flask, jsonify, request
|
||||||
from main import find_occupied_squares
|
from main import find_occupied_squares
|
||||||
from runner import find_homography, warp_board
|
from runner import find_homography, warp_board
|
||||||
from tensor_classifier import predict_board
|
from tensor_classifier import predict_board
|
||||||
|
from time import time
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/", methods=["POST"])
|
@app.route("/", methods=["POST"])
|
||||||
def process():
|
def process():
|
||||||
|
print("Received request")
|
||||||
data = request.get_json(force=True)
|
data = request.get_json(force=True)
|
||||||
|
|
||||||
decoded = base64.b64decode(data["img"])
|
decoded = base64.b64decode(data["img"])
|
||||||
|
@ -21,16 +23,24 @@ def process():
|
||||||
camera_img = cv2.cvtColor(camera_img, cv2.COLOR_BGR2RGB)
|
camera_img = cv2.cvtColor(camera_img, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
# def do_everything:
|
# def do_everything:
|
||||||
homography = find_homography(camera_img)
|
start = time()
|
||||||
|
print("Finding keypoints")
|
||||||
|
homography = find_homography(camera_img, debug=True)
|
||||||
|
print("Computing homography")
|
||||||
warped_board = warp_board(camera_img, homography)
|
warped_board = warp_board(camera_img, homography)
|
||||||
|
print("Warping board")
|
||||||
|
cv2.imwrite("warped.png", warped_board)
|
||||||
|
print("Removing empty squares")
|
||||||
occupied_squares = find_occupied_squares(warped_board)
|
occupied_squares = find_occupied_squares(warped_board)
|
||||||
|
print("Predicting board state")
|
||||||
board = predict_board(occupied_squares)
|
board = predict_board(occupied_squares)
|
||||||
|
print(f"The request took {round(time() - start, 3)} seconds")
|
||||||
|
print("Returning board state")
|
||||||
# Finally, output for unity to read
|
# Finally, output for unity to read
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"homography": homography.tolist(),
|
"homography": homography.tolist(),
|
||||||
"board": board.to_array,
|
"board": board.to_array,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user