advancedskrald/main.py
2019-05-18 17:12:35 +02:00

376 lines
12 KiB
Python

import glob
import random
import re
import sys
import warnings
import time
from typing import List, Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.exceptions import DataConversionWarning
import runner
from tensor_classifier import predict_board, predict_piece, predict_empty_nn
from util import load_classifier, PIECE, COLOR, POSITION, Board, Squares, PieceAndColor, OUR_PIECES, FILE, RANK, LESS_PIECE
# Suppress scikit-learn's DataConversionWarning for the whole process.
warnings.filterwarnings(action='ignore', category=DataConversionWarning)
# Raise NumPy's summarization threshold to sys.maxsize so printed arrays are
# shown in full instead of being abbreviated with "...".
np.set_printoptions(threshold=sys.maxsize)
def identify_piece(image: np.ndarray, position: POSITION, sift: cv2.xfeatures2d_SIFT) -> PieceAndColor:
    """Return the (piece, color) pair whose classifier scores ``image`` highest.

    One pickled classifier per (piece, color) combination is loaded from
    ``classifiers/classifier_{piece}/{color}.pkl``. Each is asked for
    ``predict_proba`` on the bag-of-words features of the image, and entry
    ``[0, 1]`` of the result is used as the score of that combination.

    :param image: Image of a single board square.
    :param position: Board position of the square. It is not consulted here;
        every color is tried regardless of the square it came from.
    :param sift: SIFT extractor handed to ``runner.generate_bag_of_words``.
    :return: The best scoring ``(piece, color)``, or ``(None, None)`` when no
        classifier scores strictly above 0.
    """
    centers = np.load("training_data/centers.npy")
    # The features depend only on the image, so they are computed once here
    # instead of once per (piece, color) classifier.
    features = runner.generate_bag_of_words(image, centers, sift)

    best_score = 0
    best_piece = best_color = None
    for piece in OUR_PIECES:
        for color in COLOR:
            classifier = load_classifier(f"classifiers/classifier_{piece}/{color}.pkl")
            score = classifier.predict_proba(features)[0, 1]
            if score > best_score:
                best_score = score
                best_piece, best_color = piece, color
    return best_piece, best_color
def pred_test(position: POSITION, mystery_image=None, empty_bias=False):
    """Run ``identify_piece`` on one image and return its result.

    :param position: Board position forwarded to ``identify_piece``.
    :param mystery_image: Image to classify. When ``None``, a fixed rook
        training image is read from disk instead.
    :param empty_bias: Accepted but not used by this function.
    :return: Whatever ``identify_piece`` returns for the image.
    """
    extractor = cv2.xfeatures2d.SIFT_create()
    image = (
        cv2.imread("training_images/rook/white/rook_training_D4_2.png")
        if mystery_image is None
        else mystery_image
    )
    return identify_piece(image, position, extractor)
def pre_process_and_train() -> None:
    """Run the runner's pre-processing step, then its piece SVM training step."""
    runner.do_pre_processing()
    runner.train_pieces_svm()
def build_board_from_squares(squares: Squares) -> Board:
    """Classify every square and collect the results in a ``Board``.

    :param squares: Mapping from board position to the image of that square
        (the same shape of data the other functions in this file iterate
        with ``.items()``).
    :return: A ``Board`` whose entry for each position is the
        ``(piece, color)`` result of ``identify_piece``.
    """
    sift = cv2.xfeatures2d.SIFT_create()
    board = Board()
    occupied = 0
    # Iterate key/value pairs: ``.values()`` yields only the images, which
    # cannot be unpacked into (position, square).
    for position, square in squares.items():
        likely_piece = identify_piece(square, position, sift)
        board[position] = likely_piece
        # ``likely_piece`` is a (piece, color) tuple, so the piece component is
        # what has to be compared; the tuple itself never equals PIECE.EMPTY.
        piece = likely_piece[0]
        if piece is not None and piece != PIECE.EMPTY:
            occupied += 1
    print(occupied)
    # Debug ratio kept from the original; skipped when it would divide by zero.
    if occupied != 1:
        print(64 / (occupied - 1))
    return board
def test_entire_board() -> None:
    """Read a fixed board photo, warp it, split it into squares and print the board.

    The photo is run through ``runner.warp_board`` and ``runner.get_squares``
    and the resulting squares are classified by ``build_board_from_squares``.
    """
    photo = cv2.imread("homo_pls_fuck.jpg")
    squares = runner.get_squares(runner.warp_board(photo))
    print(build_board_from_squares(squares))
def predict_empty(square: np.ndarray, position: POSITION, threshold: float = 0.75) -> bool:
    """Decide from an intensity histogram whether the empty-square classifier fires.

    A 32-bin histogram over ``[0, 256]`` of all values in ``square`` is fed to
    the classifier stored at
    ``classifiers/classifier_empty/white_piece_on_{position.color}_square.pkl``.

    :param square: Image of a single board square.
    :param position: Board position; only ``position.color`` is used, to pick
        the classifier file.
    :param threshold: Score that entry ``[0, 1]`` of ``predict_proba`` must
        exceed. Defaults to 0.75, the value that was previously hard-coded.
    :return: ``True`` when the score is strictly greater than ``threshold``.
    """
    # Only the bin counts are used as features; the bin edges are not needed.
    counts, _ = np.histogram(square.ravel(), bins=32, range=[0, 256])
    empty_classifier = load_classifier(
        f"classifiers/classifier_empty/white_piece_on_{position.color}_square.pkl"
    )
    prob = empty_classifier.predict_proba(counts.reshape(1, -1))
    # Earlier revisions had per-color thresholds and a 64-bin peak heuristic
    # after this return; they were unreachable and have been removed together
    # with the never-shown matplotlib plot that accumulated on every call.
    return bool(prob[0, 1] > threshold)
def remove_most_empties(warped):
    """Return the squares that a graph segmentation does not consider empty.

    Every position in ``POSITION`` is cut out of ``warped`` with
    ``runner.get_square``, a thin strip is trimmed from the top and left, and
    the remainder is segmented. A square is kept when the segment sizes pass
    the size heuristics below.

    :param warped: Warped board image accepted by ``runner.get_square``.
    :return: List of ``[position, square_image]`` pairs (the untrimmed square
        image) for the squares that were kept.
    """
    # Pixel-count limits, expressed as fractions of a 164 x 164 area.
    small_limit = (164 ** 2) * 0.2
    large_limit = (164 ** 2) * 0.9469

    # The segmentator only carries its parameters, so one instance is reused
    # for all squares instead of being rebuilt per position.
    segmentator = cv2.ximgproc.segmentation.createGraphSegmentation(sigma=0.8, k=150, min_size=700)

    non_empties = []
    for position in POSITION:
        img_src = runner.get_square(warped, position)
        rows, cols, _ = img_src.shape
        src = img_src[rows // 25:, cols // 25:]
        segment = segmentator.processImage(src)

        max_label = np.max(segment)
        if max_label <= 0:
            # A single segment: treated as empty.
            continue

        # NOTE(review): as in the original, only labels 0 .. max_label - 1 are
        # measured (the highest label is skipped). The thresholds were tuned
        # against this behaviour, so it is deliberately left unchanged.
        sizes = [np.count_nonzero(segment == label) for label in range(max_label)]

        has_sizeable_segment = any(size >= small_limit for size in sizes)
        all_below_large = all(size < large_limit for size in sizes)
        if has_sizeable_segment and (max_label >= 3 or all_below_large):
            non_empties.append([position, img_src])
    return non_empties
def test_empties():
    """Compare ``predict_empty`` against the file names of the square training images.

    Every ``training_images/*/*_square/*.png`` of shape ``(172, 172, 3)`` is
    blurred, brightness-scaled by a random factor in ``[0.90, 1.40)`` and
    passed to ``predict_empty``. Three values are printed: the number of files
    with ``'empty'`` in their path, the number of positive predictions, and
    the ratio of the smaller count to the larger one.

    NOTE(review): the ratio compares totals only; it is not a per-image
    accuracy, since individual predictions are never matched to their labels.
    """
    pred_empty = 0
    actual_empty = 0
    for filename in glob.glob("training_images/*/*_square/*.png"):
        square = cv2.imread(filename)
        # cv2.imread returns None for files it cannot decode.
        if square is None or square.shape != (172, 172, 3):
            continue
        if 'empty' in filename:
            actual_empty += 1
        pos = POSITION.A1 if "black" in filename else POSITION.A2
        square = cv2.GaussianBlur(square, (7, 7), 0)
        square = cv2.multiply(square, np.array([(random.random() * 0.50) + 0.90]))
        pred_empty += int(predict_empty(square, pos))
    print(actual_empty)
    print(pred_empty)
    larger = max(actual_empty, pred_empty)
    # Both counts are zero when no image matched; avoid dividing by zero.
    print(min(actual_empty, pred_empty) / larger if larger else float("nan"))
def _score_piece_images(piece_dir, expected, blur_kernel, gain_span, svms, sift):
    """Classify the augmented images of one piece and count the correct answers.

    :param piece_dir: Directory name under ``training_images`` to scan.
    :param expected: Value a correct prediction must compare equal to.
    :param blur_kernel: Kernel size passed to ``cv2.GaussianBlur``.
    :param gain_span: Width of the random brightness factor range; the factor
        is drawn from ``[0.90, 0.90 + gain_span)``.
    :param svms: Use ``identify_piece`` when true, ``predict_piece`` otherwise.
    :param sift: SIFT extractor for the ``identify_piece`` path.
    :return: ``(images_classified, correct_predictions)``.
    """
    total = 0
    correct = 0
    for filename in glob.glob(f"training_images/{piece_dir}/*/*.png"):
        img = cv2.imread(filename)
        if img is None:
            # Unreadable file: skip it instead of crashing inside GaussianBlur.
            continue
        square = cv2.GaussianBlur(img, blur_kernel, 0)
        square = cv2.multiply(square, np.array([(random.random() * gain_span) + 0.90]))
        pos = POSITION.A1 if "black" in filename else POSITION.A2
        if svms:
            res = identify_piece(square, pos, sift)[0]
        else:
            res = predict_piece(square)
        correct += int(res == expected)
        total += 1
    return total, correct


def test_piece_recognition(svms=False):
    """Measure piece recognition on blurred, brightness-scaled training images.

    The rook, knight, bishop, king and queen image folders are processed in
    that order. The number of images, the number of correct predictions and
    their ratio are printed.

    :param svms: When true the SIFT/bag-of-words classifiers
        (``identify_piece``) are evaluated, otherwise ``predict_piece``.
    """
    sift = cv2.xfeatures2d.SIFT_create()
    # (folder, expected label, blur kernel, brightness span)
    # NOTE(review): the rook case uses a 9x9 blur and a span of 5 where every
    # other piece uses 7x7 and 0.50. This is kept as found; confirm whether the
    # difference is intentional.
    cases = (
        ("rook", LESS_PIECE.ROOK, (9, 9), 5),
        ("knight", LESS_PIECE.KNIGHT, (7, 7), 0.50),
        ("bishop", LESS_PIECE.BISHOP, (7, 7), 0.50),
        ("king", LESS_PIECE.KING, (7, 7), 0.50),
        ("queen", LESS_PIECE.QUEEN, (7, 7), 0.50),
    )
    total = 0
    correct_guess = 0
    for piece_dir, expected, blur_kernel, gain_span in cases:
        seen, hits = _score_piece_images(piece_dir, expected, blur_kernel, gain_span, svms, sift)
        total += seen
        correct_guess += hits
    print(total)
    print(correct_guess)
    # correct_guess never exceeds total, so this equals the original
    # min/max ratio while also coping with an empty image set.
    print(correct_guess / total if total else float("nan"))
def find_occupied_squares(warped: np.ndarray) -> Squares:
    """Return the squares that survive both emptiness filters.

    Candidates come from ``remove_most_empties``; each is kept only when
    ``predict_empty`` returns a falsy value for it.

    :param warped: Warped board image.
    :return: Dict mapping position to square image for the kept squares.
    """
    return {
        position: square
        for position, square in remove_most_empties(warped)
        if not predict_empty(square, position)
    }
def find_occupied_using_nn(warped: np.ndarray) -> Squares:
    """Return the squares for which ``predict_empty_nn`` gives a truthy result.

    NOTE(review): despite the helper's name, a truthy prediction is what keeps
    a square in the result here - confirm against ``predict_empty_nn``.

    :param warped: Warped board image, split with ``runner.get_squares``.
    :return: Dict mapping position to square image for the kept squares.
    """
    all_squares = runner.get_squares(warped)
    return {
        position: square
        for position, square in all_squares.items()
        if predict_empty_nn(square)
    }
def _load_warped_board(path):
    """Read the board image at ``path``, warp it and print how long warping took.

    :raises FileNotFoundError: when ``cv2.imread`` cannot read the file.
    """
    board = cv2.imread(path)
    if board is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read board image: {path}")
    start = time.perf_counter()
    warped = runner.warp_board(board)
    print(time.perf_counter() - start)
    return warped


def _show_nn_predictions(warped):
    """Label every square kept by ``find_occupied_using_nn`` and display it."""
    squares = find_occupied_using_nn(warped)
    for pos, square in squares.items():
        piece = predict_piece(square)
        cv2.putText(square, f"{pos} {piece}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255,)*3, thickness=3)
        cv2.imshow(f"{pos}", square)
    cv2.waitKey(0)


def _print_predicted_board(warped):
    """Print ``predict_board`` applied to the squares from ``find_occupied_squares``."""
    board = predict_board(find_occupied_squares(warped))
    for pos, piece in board.items():
        print(f"{pos}, {piece}")


def _show_canny_edges(warped):
    """Display squares H3 and D3 and two Canny edge maps of D3."""
    rook_square = runner.get_square(warped, POSITION.H3)
    knight_square = runner.get_square(warped, POSITION.D3)
    cv2.imshow("lel", rook_square)
    cv2.imshow("lil", knight_square)
    knight_out = cv2.Canny(knight_square, 50, 55, L2gradient=True)
    knight_out_l = cv2.Canny(knight_square, 50, 55, L2gradient=False)
    cv2.imshow("lal", knight_out)
    cv2.imshow("lul", knight_out_l)
    cv2.waitKey(0)


def _show_svm_predictions(warped):
    """Label every square from ``find_occupied_squares`` via ``identify_piece`` and display it."""
    occupied = find_occupied_squares(warped)
    sift = cv2.xfeatures2d.SIFT_create()
    # ``occupied`` is a dict, so the pairs come from ``.items()``.
    for position, square in occupied.items():
        print("---"*15)
        piece, color = identify_piece(square, position, sift)
        print(f"{piece} on {position}")
        text_color = 255 if color == COLOR.WHITE else 0
        # identify_piece may return (None, None) when nothing scores above 0.
        label = piece.name if piece is not None else "unknown"
        cv2.putText(square, f"{position} {label}", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(text_color,)*3, thickness=3)
        cv2.imshow(f"{position}", square)
    cv2.waitKey(0)


if __name__ == '__main__':
    # Only this call was reachable in the original script; everything that
    # followed its first exit() now lives in the helpers above.
    test_piece_recognition(svms=True)
    # To run one of the other experiments instead:
    #   warped = _load_warped_board("whole_boards/boards_for_empty/lmao_xd_gg_v2.png")
    #   _show_nn_predictions(warped)
    #   _print_predicted_board(warped)
    #   _show_canny_edges(warped)
    #   _show_svm_predictions(warped)