# advancedskrald/tensor_classifier.py
# Snapshot dated 2019-05-18 17:12:35 +02:00 (originally 56 lines, 2.1 KiB, Python).

from pathlib import Path
import cv2
import numpy as np
from tensorflow.python.keras import models
from util import PIECE, Squares, Board, OUR_PIECES, LESS_PIECE
# Directory holding this module; the .h5 model files are looked up next to it.
here: Path = Path(__file__).parent


def _load_local_model(filename: str):
    """Load the Keras model stored in ``filename`` alongside this module."""
    return models.load_model(str(here.joinpath(filename)))


# Model used by predict_piece() and predict_board().
new_model = _load_local_model('5_piece_new_fuck_lel.h5')
# Model used by predict_empty_nn().
empty_class = _load_local_model('2_piece_new_fuck_lel.h5')
# Other model files that were tried here at some point (kept for reference):
# 'pls_model_new_fuck_lel.h5' and '6_piece_new_fuck_lel.h5'.
def predict_empty_nn(square, margin: int = 6, *, verbose: bool = True) -> int:
    """Classify one square image with the ``empty_class`` model.

    The image is cropped by ``margin`` pixels on every side, divided by 255
    and passed to ``empty_class.predict`` as a batch holding a single image.

    Args:
        square: image array indexed as (rows, cols, channels); the model is fed
            3 channels. Presumably a BGR crop produced with OpenCV -- TODO
            confirm the colour order matches what the model was trained on.
        margin: pixels trimmed from each border before prediction. Defaults to
            6, the value that used to be hard-coded; ``0`` keeps the full image.
        verbose: when true (the default, matching earlier behaviour) print the
            model's output scores rounded to two decimals.

    Returns:
        Index of the highest-scoring model output.

    Raises:
        ValueError: if ``margin`` is negative, ``square`` is not
            3-dimensional, or the crop leaves no pixels.
    """
    if margin < 0:
        raise ValueError(f"margin must be non-negative, got {margin}")
    # NumPy image shape is (rows, cols, channels), not (width, height, ...).
    rows, cols, _ = square.shape
    # Explicit end indices: ``square[m:-m]`` would be empty when m == 0.
    cropped = square[margin:rows - margin, margin:cols - margin]
    crop_rows, crop_cols = cropped.shape[:2]
    if crop_rows == 0 or crop_cols == 0:
        raise ValueError(
            f"square of shape {square.shape} is too small for margin {margin}"
        )
    # Division already yields a new float array, so no extra np.array copy.
    batch = (cropped / 255.0).reshape((-1, crop_rows, crop_cols, 3))
    scores = empty_class.predict(batch)
    if verbose:
        print([round(x, 2) for x in scores[0]])
    return int(np.argmax(scores))
def predict_piece(square, margin: int = 6, *, verbose: bool = True):
    """Classify one square image with ``new_model`` and wrap the result in ``LESS_PIECE``.

    The image is cropped by ``margin`` pixels on every side, divided by 255
    and passed to ``new_model.predict`` as a batch holding a single image.

    Args:
        square: image array indexed as (rows, cols, channels); the model is fed
            3 channels. Presumably a BGR crop produced with OpenCV -- TODO
            confirm the colour order matches what the model was trained on.
        margin: pixels trimmed from each border before prediction. Defaults to
            6, the value that used to be hard-coded; ``0`` keeps the full image.
        verbose: when true (the default, matching earlier behaviour) print the
            model's output scores rounded to two decimals.

    Returns:
        ``LESS_PIECE`` built from the index of the highest-scoring model output
        (presumably an enum member -- confirm against ``util``).

    Raises:
        ValueError: if ``margin`` is negative, ``square`` is not
            3-dimensional, or the crop leaves no pixels.
    """
    if margin < 0:
        raise ValueError(f"margin must be non-negative, got {margin}")
    # NumPy image shape is (rows, cols, channels), not (width, height, ...).
    rows, cols, _ = square.shape
    # Explicit end indices: ``square[m:-m]`` would be empty when m == 0.
    cropped = square[margin:rows - margin, margin:cols - margin]
    crop_rows, crop_cols = cropped.shape[:2]
    if crop_rows == 0 or crop_cols == 0:
        raise ValueError(
            f"square of shape {square.shape} is too small for margin {margin}"
        )
    # Division already yields a new float array, so no extra np.array copy.
    batch = (cropped / 255.0).reshape((-1, crop_rows, crop_cols, 3))
    scores = new_model.predict(batch)
    if verbose:
        print([round(x, 2) for x in scores[0]])
    return LESS_PIECE(int(np.argmax(scores)))
def predict_board(occupied_squares: Squares, margin: int = 6) -> Board:
    """Classify every occupied square image and collect the results in a ``Board``.

    Each image is cropped by ``margin`` pixels per side, divided by 255 and
    classified by ``new_model``; the argmax index is converted with
    ``LESS_PIECE`` and stored in a fresh ``Board`` under the same key the
    image had in ``occupied_squares``.

    Args:
        occupied_squares: mapping of board position -> square image (anything
            with ``.items()``). Images are assumed to be (rows, cols, 3)
            arrays -- TODO confirm against the ``Squares`` definition.
        margin: pixels trimmed from each border before prediction. Defaults to
            6, the value that used to be hard-coded; ``0`` keeps the full image.

    Returns:
        A new ``Board`` with one ``LESS_PIECE`` entry per input position.

    Raises:
        ValueError: if ``margin`` is negative, an image is not 3-dimensional,
            or a crop leaves no pixels.
    """
    if margin < 0:
        raise ValueError(f"margin must be non-negative, got {margin}")
    board = Board()
    # NOTE(review): squares are predicted one at a time; if every crop shares a
    # shape they could be stacked into a single predict() call for speed.
    for pos, square in occupied_squares.items():
        # NumPy image shape is (rows, cols, channels), not (width, height, ...).
        rows, cols, _ = square.shape
        # Explicit end indices: ``square[m:-m]`` would be empty when m == 0.
        cropped = square[margin:rows - margin, margin:cols - margin]
        crop_rows, crop_cols = cropped.shape[:2]
        if crop_rows == 0 or crop_cols == 0:
            raise ValueError(
                f"square at {pos!r} of shape {square.shape} is too small "
                f"for margin {margin}"
            )
        batch = (cropped / 255.0).reshape((-1, crop_rows, crop_cols, 3))
        scores = new_model.predict(batch)
        board[pos] = LESS_PIECE(int(np.argmax(scores)))
    return board