"""Thresholds and HSV bounds shared by the Chieri Bot v4 b30 OCR code."""

import numpy as np

__all__ = [
    "FONT_THRESHOLD",
    "PURE_BG_MIN_HSV",
    "PURE_BG_MAX_HSV",
    "FAR_BG_MIN_HSV",
    "FAR_BG_MAX_HSV",
    "LOST_BG_MIN_HSV",
    "LOST_BG_MAX_HSV",
    "BYD_MIN_HSV",
    "BYD_MAX_HSV",
    "FTR_MIN_HSV",
    "FTR_MAX_HSV",
    "PRS_MIN_HSV",
    "PRS_MAX_HSV",
]


def _hsv_u8(h, s, v):
    """Pack three channel values into a ``uint8`` array (not exported)."""
    return np.array([h, s, v], dtype=np.uint8)


FONT_THRESHOLD = 160

# Lower / upper bounds for the pure, far and lost background bars.
# These six are uint8 numpy arrays.
PURE_BG_MIN_HSV = _hsv_u8(95, 140, 150)
PURE_BG_MAX_HSV = _hsv_u8(110, 255, 255)

FAR_BG_MIN_HSV = _hsv_u8(15, 100, 150)
FAR_BG_MAX_HSV = _hsv_u8(20, 255, 255)

LOST_BG_MIN_HSV = _hsv_u8(115, 60, 150)
LOST_BG_MAX_HSV = _hsv_u8(140, 255, 255)

# Lower / upper bounds for the rating-class markers.
# Unlike the background bounds above, these are plain tuples.
BYD_MIN_HSV = (158, 120, 0)
BYD_MAX_HSV = (172, 255, 255)

FTR_MIN_HSV = (145, 70, 0)
FTR_MAX_HSV = (160, 255, 255)

PRS_MIN_HSV = (45, 60, 0)
PRS_MAX_HSV = (70, 255, 255)
# ---- src/arcaea_offline_ocr/b30/chieri/v4/ocr.py ----


@attrs.define
class ChieriBotV4OcrResultItem:
    """Values recognised from one component of a Chieri Bot v4 image."""

    rating_class: int
    title: str
    score: int
    pure: int
    far: int
    lost: int
    # ``ChieriBotV4Ocr.ocr_component`` always stores "" here.
    date: Union[datetime, str]


class ChieriBotV4Ocr:
    """Recognise the components of a Chieri Bot v4 image.

    Text (title, score) is read with the supplied PaddleOCR instance; the
    pure / far / lost counters are read with the supplied KNN digit model.
    Region geometry comes from ``ChieriBotV4Rois`` and is scaled by
    ``factor``.
    """

    def __init__(
        self,
        paddle_ocr: "PaddleOCR",
        knn_digits_model: cv2_ml_KNearest,
        factor: Optional[float] = 1.0,
    ):
        self.__paddle_ocr = paddle_ocr
        self.__knn_digits_model = knn_digits_model
        self.__rois = ChieriBotV4Rois(factor)

    @property
    def paddle_ocr(self):
        return self.__paddle_ocr

    @paddle_ocr.setter
    def paddle_ocr(self, paddle_ocr: "PaddleOCR"):
        self.__paddle_ocr = paddle_ocr

    @property
    def knn_digits_model(self):
        return self.__knn_digits_model

    @knn_digits_model.setter
    def knn_digits_model(self, knn_digits_model: cv2_ml_KNearest):
        # Annotated like the constructor argument (was ``Mat``).
        self.__knn_digits_model = knn_digits_model

    @property
    def rois(self):
        return self.__rois

    @property
    def factor(self):
        return self.__rois.factor

    @factor.setter
    def factor(self, factor: float):
        self.__rois.factor = factor

    def ocr_component_rating_class(self, component_bgr: Mat) -> int:
        """Classify the rating class from the colour of its marker region.

        Returns 1, 2 or 3 for the PRS, FTR or BYD mask with the most
        matching pixels, or 0 when no mask reaches 70 pixels.
        """
        rating_class_rect = construct_int_xywh_rect(
            self.rois.component_rois.rating_class_rect
        )
        rating_class_roi = crop_xywh(component_bgr, rating_class_rect)
        rating_class_roi = cv2.cvtColor(rating_class_roi, cv2.COLOR_BGR2HSV)
        rating_class_masks = [
            cv2.inRange(rating_class_roi, PRS_MIN_HSV, PRS_MAX_HSV),
            cv2.inRange(rating_class_roi, FTR_MIN_HSV, FTR_MAX_HSV),
            cv2.inRange(rating_class_roi, BYD_MIN_HSV, BYD_MAX_HSV),
        ]  # prs, ftr, byd only
        rating_class_results = [np.count_nonzero(m) for m in rating_class_masks]
        if max(rating_class_results) < 70:
            return 0
        # First maximum wins on ties; +1 because index 0 is PRS (class 1).
        return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1

    def ocr_component_title(self, component_bgr: Mat) -> str:
        """Read the title text; returns "" when PaddleOCR finds nothing."""
        # sourcery skip: inline-immediately-returned-variable
        title_rect = construct_int_xywh_rect(self.rois.component_rois.title_rect)
        title_roi = crop_xywh(component_bgr, title_rect)
        ocr_result = self.paddle_ocr.ocr(title_roi, cls=False)
        title = ocr_result[0][-1][1][0] if ocr_result and ocr_result[0] else ""
        return title

    def ocr_component_score(self, component_bgr: Mat) -> int:
        """Read the score as an integer.

        The score region is Otsu-thresholded before recognition, and
        apostrophes and spaces are stripped from the recognised text.

        NOTE(review): unlike ``ocr_component_title`` there is no guard for
        an empty OCR result, so this raises when nothing is detected or the
        text is not numeric.
        """
        # sourcery skip: inline-immediately-returned-variable
        score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect)
        score_roi = cv2.cvtColor(
            crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY
        )
        _, score_roi = cv2.threshold(
            score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        score_str = self.paddle_ocr.ocr(score_roi, cls=False)[0][-1][1][0]
        score = int(score_str.replace("'", "").replace(" ", ""))
        return score

    def find_pfl_rects(
        self, component_pfl_processed: Mat
    ) -> List[Tuple[int, int, int, int]]:
        """Locate the text rows in a preprocessed pure/far/lost region.

        Returns ``(x, y, w, h)`` tuples sorted top to bottom. Contours no
        taller than 10% of the region height are dropped.
        """
        # sourcery skip: inline-immediately-returned-variable
        # Close horizontally so the digits of one row merge into one contour.
        pfl_roi_find = cv2.morphologyEx(
            component_pfl_processed,
            cv2.MORPH_CLOSE,
            cv2.getStructuringElement(cv2.MORPH_RECT, (10, 1)),
        )
        pfl_contours, _ = cv2.findContours(
            pfl_roi_find, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        pfl_rects = [cv2.boundingRect(c) for c in pfl_contours]
        pfl_rects = [
            r for r in pfl_rects if r[3] > component_pfl_processed.shape[0] * 0.1
        ]
        pfl_rects = sorted(pfl_rects, key=lambda r: r[1])
        pfl_rects_adjusted = [
            (
                max(rect[0] - 2, 0),
                rect[1],
                min(rect[2] + 2, component_pfl_processed.shape[1]),
                rect[3],
            )
            for rect in pfl_rects
        ]
        return pfl_rects_adjusted

    def preprocess_component_pfl(self, component_bgr: Mat) -> Mat:
        """Binarise the pure/far/lost region (white text on black).

        The coloured bars are first painted over with the component's
        background colour. The result is the AND of two Otsu thresholds,
        one of the grey region and one of a blurred copy. An eroded variant
        is returned instead when it yields exactly three rows.
        """
        pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect)
        # The crop may be a view of the caller's image (crop_xywh is not
        # visible here); copy so the fill below cannot modify the input.
        pfl_roi = crop_xywh(component_bgr, pfl_rect).copy()
        pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV)

        # fill the pfl bg with background color
        bg_point = [round(i) for i in self.rois.component_rois.bg_point]
        bg_color = component_bgr[bg_point[1]][bg_point[0]]
        pure_bg_mask = cv2.inRange(pfl_roi_hsv, PURE_BG_MIN_HSV, PURE_BG_MAX_HSV)
        far_bg_mask = cv2.inRange(pfl_roi_hsv, FAR_BG_MIN_HSV, FAR_BG_MAX_HSV)
        lost_bg_mask = cv2.inRange(pfl_roi_hsv, LOST_BG_MIN_HSV, LOST_BG_MAX_HSV)
        pfl_roi[np.where(pure_bg_mask != 0)] = bg_color
        pfl_roi[np.where(far_bg_mask != 0)] = bg_color
        pfl_roi[np.where(lost_bg_mask != 0)] = bg_color

        # threshold
        pfl_roi = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2GRAY)
        # get threshold of blurred image, try ignoring the lines of bg bar
        pfl_roi_blurred = cv2.GaussianBlur(pfl_roi, (5, 5), 0)
        _, pfl_roi_blurred_threshold = cv2.threshold(
            pfl_roi_blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        # and a threshold of the original roi
        _, pfl_roi_threshold = cv2.threshold(
            pfl_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        # turn thresholds into black background; pixel (2, 2) is used as
        # the background sample
        if pfl_roi_blurred_threshold[2][2] == 255:
            pfl_roi_blurred_threshold = 255 - pfl_roi_blurred_threshold
        if pfl_roi_threshold[2][2] == 255:
            pfl_roi_threshold = 255 - pfl_roi_threshold
        # return a bitwise_and result
        result = cv2.bitwise_and(pfl_roi_blurred_threshold, pfl_roi_threshold)
        result_eroded = cv2.erode(
            result, cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2))
        )
        return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result

    def ocr_component_pfl(self, component_bgr: Mat) -> Tuple[int, int, int]:
        """Read the pure, far and lost counters with the KNN digit model.

        Returns ``(pure, far, lost)``. On any failure, including the case
        where three rows cannot be located, returns ``(-1, -1, -1)``.
        """
        try:
            pfl_roi = self.preprocess_component_pfl(component_bgr)
            pfl_rects = self.find_pfl_rects(pfl_roi)
            # Exactly three rows are needed. Any other count used to give a
            # wrong-length tuple that broke the caller's 3-way unpack.
            if len(pfl_rects) != 3:
                return (-1, -1, -1)

            pure_far_lost = []
            for pfl_roi_rect in pfl_rects:
                roi = crop_xywh(pfl_roi, pfl_roi_rect)
                digit_contours, _ = cv2.findContours(
                    roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )
                # Left-to-right order gives the digits of the number.
                digit_rects = sorted(
                    [cv2.boundingRect(c) for c in digit_contours],
                    key=lambda r: r[0],
                )
                # preprocess_hog splits each image at row/column 10, so
                # every digit is resized to 20x20 first.
                digits = [
                    cv2.resize(crop_xywh(roi, digit_rect), (20, 20))
                    for digit_rect in digit_rects
                ]
                samples = preprocess_hog(digits)

                _, results, _, _ = self.knn_digits_model.findNearest(samples, 4)
                results = [str(int(i)) for i in results.ravel()]
                pure_far_lost.append(int("".join(results)))
            return tuple(pure_far_lost)
        except Exception:
            # Deliberate best-effort: report the sentinel instead of failing.
            return (-1, -1, -1)

    def ocr_component(self, component_bgr: Mat) -> ChieriBotV4OcrResultItem:
        """Recognise every field of one component image.

        Rating class, title and score are read from a Gaussian-blurred
        copy; the counters are read from the unblurred image. ``date`` is
        always "".
        """
        component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0)
        rating_class = self.ocr_component_rating_class(component_blur)
        title = self.ocr_component_title(component_blur)
        score = self.ocr_component_score(component_blur)
        pure, far, lost = self.ocr_component_pfl(component_bgr)
        return ChieriBotV4OcrResultItem(
            rating_class=rating_class,
            title=title,
            score=score,
            pure=pure,
            far=far,
            lost=lost,
            date="",
        )

    def ocr(self, img_bgr: Mat) -> List[ChieriBotV4OcrResultItem]:
        """Recognise every component of a full image.

        Sets ``factor`` to ``image height / 4400`` before cropping.
        """
        self.factor = img_bgr.shape[0] / 4400
        return [
            self.ocr_component(component_bgr)
            for component_bgr in self.rois.components(img_bgr)
        ]


# ---- src/arcaea_offline_ocr/b30/chieri/v4/rois.py ----


class ChieriBotV4ComponentRois:
    """Regions inside a single component, scaled by ``factor``.

    Rect-like properties are 4-tuples passed to ``construct_int_xywh_rect``
    by the OCR code; ``bg_point`` is a 2-tuple indexed there as ``(x, y)``.
    """

    def __init__(self, factor: Optional[float] = 1.0):
        self.__factor = factor

    @property
    def factor(self):
        return self.__factor

    @factor.setter
    def factor(self, factor: float):
        self.__factor = factor

    @property
    def top_font_color_detect(self):
        return apply_factor((35, 10, 120, 100), self.factor)

    @property
    def bottom_font_color_detect(self):
        return apply_factor((30, 125, 175, 110), self.factor)

    @property
    def bg_point(self):
        return apply_factor((75, 10), self.factor)

    @property
    def rating_class_rect(self):
        return apply_factor((21, 40, 7, 20), self.factor)

    @property
    def title_rect(self):
        return apply_factor((35, 10, 430, 50), self.factor)

    @property
    def score_rect(self):
        return apply_factor((30, 60, 270, 55), self.factor)

    @property
    def pfl_rect(self):
        return apply_factor((50, 125, 80, 100), self.factor)

    @property
    def date_rect(self):
        return apply_factor((205, 200, 225, 25), self.factor)


class ChieriBotV4Rois:
    """Grid layout of the components on the full image, scaled by ``factor``.

    Setting ``factor`` also updates the nested ``component_rois``.
    """

    def __init__(self, factor: Optional[float] = 1.0):
        self.__factor = factor
        self.__component_rois = ChieriBotV4ComponentRois(factor)

    @property
    def component_rois(self):
        return self.__component_rois

    @property
    def factor(self):
        return self.__factor

    @factor.setter
    def factor(self, factor: float):
        self.__factor = factor
        self.__component_rois.factor = factor

    @property
    def top(self):
        return apply_factor(823, self.factor)

    @property
    def left(self):
        return apply_factor(107, self.factor)

    @property
    def width(self):
        return apply_factor(502, self.factor)

    @property
    def height(self):
        return apply_factor(240, self.factor)

    @property
    def vertical_gap(self):
        return apply_factor(74, self.factor)

    @property
    def horizontal_gap(self):
        return apply_factor(40, self.factor)

    @property
    def horizontal_items(self):
        return 3

    @property
    def vertical_items(self):
        return 10

    @property
    def b33_vertical_gap(self):
        return apply_factor(121, self.factor)

    def _crop_row(self, img_bgr: Mat, rect):
        """Crop one row of ``horizontal_items`` components starting at *rect*.

        Returns ``(crops, last_rect)`` where *last_rect* is the un-rounded
        rect of the right-most component.
        """
        crops = []
        for hi in range(self.horizontal_items):
            if hi > 0:
                rect += ((self.width + self.horizontal_gap), 0, 0, 0)
            int_rect = construct_int_xywh_rect(rect)
            crops.append(crop_xywh(img_bgr, int_rect))
        return crops, rect

    def components(self, img_bgr: Mat) -> List[Mat]:
        """Crop every component from *img_bgr*, row by row.

        Yields ``vertical_items`` rows of ``horizontal_items`` crops, then
        one extra row placed ``height + b33_vertical_gap`` below the last
        regular row.
        """
        first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height)
        results = []

        for vi in range(self.vertical_items):
            rect = XYWHRect(*first_rect)
            rect += (0, (self.vertical_gap + self.height) * vi, 0, 0)
            row, rect = self._crop_row(img_bgr, rect)
            results.extend(row)

        # Step back from the right-most column to the first one (this was a
        # hard-coded "* 2", i.e. horizontal_items - 1) and down to the
        # extra row.
        rect += (
            -(self.width + self.horizontal_gap) * (self.horizontal_items - 1),
            self.height + self.b33_vertical_gap,
            0,
            0,
        )
        row, _ = self._crop_row(img_bgr, rect)
        results.extend(row)

        return results


# ---- src/arcaea_offline_ocr/ocr.py ----


def preprocess_hog(digits):
    """Turn digit images into histogram-of-gradient feature vectors.

    Each image is split into four cells at row/column 10 (callers in this
    package pass 20x20 images). Every cell gives a 16-bin histogram of
    gradient direction weighted by gradient magnitude. The concatenated
    histogram then goes through the Hellinger-kernel transform.

    Returns a ``float32`` array with one row per input image.
    """
    bin_n = 16
    eps = 1e-7

    samples = []
    for img in digits:
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
        mag, ang = cv2.cartToPolar(gx, gy)
        # Quantise the gradient angle into bin indices; named ``bins`` to
        # avoid shadowing the builtin ``bin``.
        bins = np.int32(bin_n * ang / (2 * np.pi))
        bin_cells = bins[:10, :10], bins[10:, :10], bins[:10, 10:], bins[10:, 10:]
        mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:]
        hists = [
            np.bincount(b.ravel(), m.ravel(), bin_n)
            for b, m in zip(bin_cells, mag_cells)
        ]
        hist = np.hstack(hists)

        # transform to Hellinger kernel
        hist /= hist.sum() + eps
        hist = np.sqrt(hist)
        hist /= norm(hist) + eps

        samples.append(hist)
    return np.float32(samples)
def construct_int_xywh_rect(
    rect: "XYWHRect", func: Callable[[Union[int, float]], int] = round
) -> "XYWHRect":
    """Return a new ``XYWHRect`` with *func* applied to every field of *rect*.

    *func* defaults to ``round``; pass e.g. ``math.floor`` to change how
    the values are converted to integers.
    """
    return XYWHRect(*[func(num) for num in rect])


@overload
def apply_factor(item: int, factor: float) -> float:
    ...


@overload
def apply_factor(item: float, factor: float) -> float:
    ...


T = TypeVar("T", bound=Iterable)


@overload
def apply_factor(item: T, factor: float) -> T:
    ...


def apply_factor(item, factor: float):
    """Multiply *item* by *factor*.

    A number is returned scaled. An iterable is returned as a new
    container of the same class with every element scaled.

    Raises:
        TypeError: if *item* is neither a number nor an iterable. The
            previous code silently returned ``None`` in that case.
    """
    if isinstance(item, (int, float)):
        return item * factor
    if isinstance(item, Iterable):
        scaled = [i * factor for i in item]
        cls = item.__class__
        # namedtuple-style classes take their fields positionally, so
        # ``cls(list)`` would fail; they expose ``_make`` for building from
        # an iterable.
        make = getattr(cls, "_make", None)
        if callable(make):
            return make(scaled)
        return cls(scaled)
    raise TypeError(
        f"apply_factor() expects a number or an iterable, got {type(item).__name__}"
    )