advancedskrald/util.py

118 lines
2.2 KiB
Python
Raw Normal View History

from __future__ import annotations
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import NamedTuple, Dict, Tuple, List
2019-04-16 21:29:41 +00:00
import cv2
import numpy as np
from sklearn.externals import joblib
2019-05-18 15:12:35 +00:00
# Directory containing this module; load_classifier resolves filenames against it.
here: Path = Path(__file__).parent
class COLOR(Enum):
    """Square/piece colour; ``str()`` yields the lowercase value, e.g. ``"white"``."""

    WHITE = "white"
    BLACK = "black"

    def __str__(self) -> str:
        # The member value is already the human-readable lowercase name.
        return f"{self.value}"
class PIECE(Enum):
    """Piece kinds with integer codes 0-6; ``str()`` yields the lowercase name.

    EMPTY (6) is the code Board.to_array uses for squares with no entry.
    """

    KNIGHT = 0
    ROOK = 1
    BISHOP = 2
    PAWN = 3
    QUEEN = 4
    KING = 5
    EMPTY = 6

    def __str__(self) -> str:
        # e.g. PIECE.KNIGHT -> "knight"
        member_name: str = self.name
        return member_name.lower()
2019-05-18 15:12:35 +00:00
class LESS_PIECE(Enum):
    """Reduced piece set (no PAWN, no EMPTY) with its own codes 0-4.

    Note the numbering differs from PIECE (e.g. ROOK is 0 here, 1 in PIECE).
    ``str()`` yields the lowercase member name.
    """

    ROOK = 0
    KNIGHT = 1
    BISHOP = 2
    KING = 3
    QUEEN = 4

    def __str__(self) -> str:
        # e.g. LESS_PIECE.ROOK -> "rook"
        member_name: str = self.name
        return member_name.lower()
# A (piece, colour) pair, e.g. (PIECE.ROOK, COLOR.WHITE).
PieceAndColor = Tuple[PIECE, COLOR]

# All LESS_PIECE members as a tuple, listed in the enum's declaration order.
OUR_PIECES = (LESS_PIECE.ROOK, LESS_PIECE.KNIGHT, LESS_PIECE.BISHOP,
              LESS_PIECE.KING, LESS_PIECE.QUEEN)
class FILE(int, Enum):
    """Board files A-H as int-valued members 1-8 (usable in int arithmetic)."""

    # Tuple-unpacking assigns each name in order, so A=1, B=2, ..., H=8.
    A, B, C, D, E, F, G, H = range(1, 9)
class RANK(int, Enum):
    """Board ranks as int-valued members, declared from EIGHT (8) down to ONE (1).

    Declaration order, and therefore iteration order, runs 8 -> 1.
    """

    # Tuple-unpacking assigns each name in order, so EIGHT=8, ..., ONE=1.
    EIGHT, SEVEN, SIX, FIVE, FOUR, THREE, TWO, ONE = range(8, 0, -1)
class _Position(NamedTuple):
    """A board square as a (file, rank) pair, e.g. (FILE.A, RANK.EIGHT)."""

    file: FILE
    rank: RANK

    def __str__(self) -> str:
        """Return the square name, e.g. ``"A8"``."""
        # Format the rank as a plain int: since Python 3.11, formatting an
        # (int, Enum) member follows Enum.__str__ and gives "RANK.EIGHT"
        # instead of "8", which would turn "A8" into "ARANK.EIGHT".
        return f"{self.file.name}{int(self.rank)}"

    @property
    def color(self) -> COLOR:
        """COLOR.WHITE when file + rank is odd, COLOR.BLACK when it is even."""
        if (self.file + self.rank) % 2:
            return COLOR.WHITE
        return COLOR.BLACK
# POSITION.{A8, A7, ..., H1}: one member per square, iterating files A-H and,
# within each file, ranks 8 down to 1. Each member's value is the matching
# _Position(FILE, RANK) tuple, so POSITION((file, rank)) looks a square up.
# Member names are spelled out from the FILE name and the RANK integer value,
# so they do not depend on how _Position formats itself. f-string formatting
# of (int, Enum) members changed in Python 3.11.
POSITION = Enum(
    "POSITION",
    {f"{f.name}{r.value}": _Position(f, r) for f in FILE for r in RANK},
    type=_Position,
)
# Squares maps each board POSITION to the image of that square (a numpy array).
# It is the board container used while image processing is still in progress.
Squares = Dict[POSITION, np.ndarray]
class Board(Dict[POSITION, PIECE]):
    """Mapping from board POSITION to the PIECE on it.

    This is the board configuration produced once all image processing is done.
    """

    @property
    def to_array(self) -> List[List[int]]:
        """The board as a nested list of PIECE integer codes.

        One row per RANK in RANK's iteration order, one column per FILE in
        FILE's iteration order. Squares without an entry yield PIECE.EMPTY.value.
        """
        grid: List[List[int]] = []
        for row_rank in RANK:
            row = [
                self.get(POSITION((col_file, row_rank)), PIECE.EMPTY).value
                for col_file in FILE
            ]
            grid.append(row)
        return grid
def imwrite(*args, **kwargs):
    """Write an image with ``cv2.imwrite``, creating missing parent directories first.

    Takes the same arguments as ``cv2.imwrite`` (filename, img[, params]).
    The filename may be passed positionally or as ``filename=`` and may be
    a ``str`` or a ``pathlib.Path``.

    Returns:
        Whatever ``cv2.imwrite`` returns.

    Raises:
        TypeError: if no filename is given.
    """
    # Accept the filename as a keyword too; the previous args[0] access
    # raised IndexError for imwrite(filename=..., img=...).
    if args:
        filename, rest = args[0], args[1:]
    elif "filename" in kwargs:
        filename, rest = kwargs.pop("filename"), ()
    else:
        raise TypeError("imwrite() missing required argument: 'filename'")

    path = Path(filename)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Some OpenCV Python bindings reject pathlib.Path, so always pass a str.
    return cv2.imwrite(str(path), *rest, **kwargs)
@lru_cache()
def load_classifier(filename):
    """Load a joblib-serialised object stored relative to this module's directory.

    ``filename`` is joined onto ``here``. An absolute path is used as-is,
    per pathlib joining rules. Results are memoised by ``lru_cache``
    (default maxsize 128), so repeated calls with the same ``filename``
    return the same shared object.

    NOTE(review): presumably the stored object is a trained classifier.
    Confirm against the callers.

    NOTE(review): joblib.load unpickles the file and can execute arbitrary
    code. Only load trusted files.
    """
    model_path = here / filename
    return joblib.load(str(model_path))