diff --git a/src/arcaea_offline_ocr/__init__.py b/src/arcaea_offline_ocr/__init__.py index c2e0b50..c227b3c 100644 --- a/src/arcaea_offline_ocr/__init__.py +++ b/src/arcaea_offline_ocr/__init__.py @@ -1,4 +1,3 @@ from .crop import * from .device import * -from .ocr import * from .utils import * diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py index 793adff..16425d3 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py @@ -4,12 +4,6 @@ import cv2 import numpy as np from ....crop import crop_xywh -from ....ocr import ( - FixRects, - ocr_digits_by_contour_knn, - preprocess_hog, - resize_fill_square, -) from ....phash_db import ImagePhashDatabase from ....types import Mat from ...shared import B30OcrResultItem @@ -28,36 +22,21 @@ from .colors import ( PURE_BG_MIN_HSV, ) from .rois import ChieriBotV4Rois +from ....providers.knn import OcrKNearestTextProvider class ChieriBotV4Ocr: def __init__( self, - score_knn: cv2.ml.KNearest, - pfl_knn: cv2.ml.KNearest, + score_knn_provider: OcrKNearestTextProvider, + pfl_knn_provider: OcrKNearestTextProvider, phash_db: ImagePhashDatabase, factor: float = 1.0, ): - self.__score_knn = score_knn - self.__pfl_knn = pfl_knn self.__phash_db = phash_db self.__rois = ChieriBotV4Rois(factor) - - @property - def score_knn(self): - return self.__score_knn - - @score_knn.setter - def score_knn(self, knn_digits_model: cv2.ml.KNearest): - self.__score_knn = knn_digits_model - - @property - def pfl_knn(self): - return self.__pfl_knn - - @pfl_knn.setter - def pfl_knn(self, knn_digits_model: cv2.ml.KNearest): - self.__pfl_knn = knn_digits_model + self.pfl_knn_provider = pfl_knn_provider + self.score_knn_provider = score_knn_provider @property def phash_db(self): @@ -125,7 +104,9 @@ class ChieriBotV4Ocr: if rect[3] > score_roi.shape[0] * 0.5: continue score_roi = cv2.fillPoly(score_roi, [contour], 0) - return ocr_digits_by_contour_knn(score_roi, self.score_knn) + + ocr_result = self.score_knn_provider.result(score_roi) + return int(ocr_result) if ocr_result else 0 def find_pfl_rects( self, component_pfl_processed: Mat @@ -203,25 +184,9 @@ class ChieriBotV4Ocr: pure_far_lost = [] for pfl_roi_rect in pfl_rects: roi = crop_xywh(pfl_roi, pfl_roi_rect) - digit_contours, _ = cv2.findContours( - roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE - ) - digit_rects = [cv2.boundingRect(c) for c in digit_contours] - digit_rects = FixRects.connect_broken( - digit_rects, roi.shape[1], roi.shape[0] - ) - digit_rects = FixRects.split_connected(roi, digit_rects) - digit_rects = sorted(digit_rects, key=lambda r: r[0]) - digits = [] - for digit_rect in digit_rects: - digit = crop_xywh(roi, digit_rect) - digit = resize_fill_square(digit, 20) - digits.append(digit) - samples = preprocess_hog(digits) + result = self.pfl_knn_provider.result(roi) + pure_far_lost.append(int(result) if result else None) - _, results, _, _ = self.pfl_knn.findNearest(samples, 4) - results = [str(int(i)) for i in results.ravel()] - pure_far_lost.append(int("".join(results))) return tuple(pure_far_lost) except Exception: return (None, None, None) diff --git a/src/arcaea_offline_ocr/device/common.py b/src/arcaea_offline_ocr/device/common.py index a67c924..e9bf09a 100644 --- a/src/arcaea_offline_ocr/device/common.py +++ b/src/arcaea_offline_ocr/device/common.py @@ -5,10 +5,10 @@ from typing import Optional @dataclass class DeviceOcrResult: rating_class: int - pure: int - far: int - lost: int score: int + pure: Optional[int] = None + far: Optional[int] = None + lost: Optional[int] = None max_recall: Optional[int] = None song_id: Optional[str] = None song_id_possibility: Optional[float] = None diff --git a/src/arcaea_offline_ocr/device/ocr.py b/src/arcaea_offline_ocr/device/ocr.py index 91d827e..1c58838 100644 --- a/src/arcaea_offline_ocr/device/ocr.py +++ b/src/arcaea_offline_ocr/device/ocr.py @@ -1,15 +1,8 @@ import cv2 import numpy as np -from ..crop import crop_xywh -from ..ocr import ( - FixRects, - ocr_digit_samples_knn, - ocr_digits_by_contour_knn, - preprocess_hog, - resize_fill_square, -) from ..phash_db import ImagePhashDatabase +from ..providers.knn import OcrKNearestTextProvider from ..types import Mat from .common import DeviceOcrResult from .rois.extractor import DeviceRoisExtractor @@ -21,38 +14,37 @@ class DeviceOcr: self, extractor: DeviceRoisExtractor, masker: DeviceRoisMasker, - knn_model: cv2.ml.KNearest, + knn_provider: OcrKNearestTextProvider, phash_db: ImagePhashDatabase, ): self.extractor = extractor self.masker = masker - self.knn_model = knn_model + self.knn_provider = knn_provider self.phash_db = phash_db def pfl(self, roi_gray: Mat, factor: float = 1.25): - contours, _ = cv2.findContours( - roi_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE - ) - filtered_contours = [c for c in contours if cv2.contourArea(c) >= 5 * factor] - rects = [cv2.boundingRect(c) for c in filtered_contours] - rects = FixRects.connect_broken(rects, roi_gray.shape[1], roi_gray.shape[0]) + def contour_filter(cnt): + return cv2.contourArea(cnt) >= 5 * factor - filtered_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor] - filtered_rects = FixRects.split_connected(roi_gray, filtered_rects) - filtered_rects = sorted(filtered_rects, key=lambda r: r[0]) + contours = self.knn_provider.contours(roi_gray) + contours_filtered = self.knn_provider.contours( + roi_gray, contours_filter=contour_filter + ) roi_ocr = roi_gray.copy() - filtered_contours_flattened = {tuple(c.flatten()) for c in filtered_contours} + contours_filtered_flattened = {tuple(c.flatten()) for c in contours_filtered} for contour in contours: - if tuple(contour.flatten()) in filtered_contours_flattened: + if tuple(contour.flatten()) in contours_filtered_flattened: continue roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0]) - digit_rois = [ - resize_fill_square(crop_xywh(roi_ocr, r), 20) for r in filtered_rects - ] - samples = preprocess_hog(digit_rois) - return ocr_digit_samples_knn(samples, self.knn_model) + ocr_result = self.knn_provider.result( + roi_ocr, + contours_filter=lambda cnt: cv2.contourArea(cnt) >= 5 * factor, + rects_filter=lambda rect: rect[2] >= 5 * factor and rect[3] >= 6 * factor, + ) + + return int(ocr_result) if ocr_result else 0 def pure(self): return self.pfl(self.masker.pure(self.extractor.pure)) @@ -65,13 +57,14 @@ class DeviceOcr: def score(self): roi = self.masker.score(self.extractor.score) - contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + contours = self.knn_provider.contours(roi) for contour in contours: if ( cv2.boundingRect(contour)[3] < roi.shape[0] * 0.6 ): # h < score_component_h * 0.6 roi = cv2.fillPoly(roi, [contour], [0]) - return ocr_digits_by_contour_knn(roi, self.knn_model) + ocr_result = self.knn_provider.result(roi) + return int(ocr_result) if ocr_result else 0 def rating_class(self): roi = self.extractor.rating_class @@ -85,9 +78,10 @@ class DeviceOcr: return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] def max_recall(self): - return ocr_digits_by_contour_knn( - self.masker.max_recall(self.extractor.max_recall), self.knn_model + ocr_result = self.knn_provider.result( + self.masker.max_recall(self.extractor.max_recall) ) + return int(ocr_result) if ocr_result else None def clear_status(self): roi = self.extractor.clear_status diff --git a/src/arcaea_offline_ocr/providers/__init__.py b/src/arcaea_offline_ocr/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/arcaea_offline_ocr/providers/base.py b/src/arcaea_offline_ocr/providers/base.py new file mode 100644 index 0000000..6e84be1 --- /dev/null +++ b/src/arcaea_offline_ocr/providers/base.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from ..types import Mat + + +class OcrTextProvider(ABC): + @abstractmethod + def result_raw(self, img: "Mat", /, *args, **kwargs) -> Any: ... + @abstractmethod + def result(self, img: "Mat", /, *args, **kwargs) -> Optional[str]: ... diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/providers/knn.py similarity index 60% rename from src/arcaea_offline_ocr/ocr.py rename to src/arcaea_offline_ocr/providers/knn.py index fa30635..3e8473f 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/providers/knn.py @@ -1,18 +1,19 @@ +import logging import math -from typing import Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple import cv2 import numpy as np -from .crop import crop_xywh -from .types import Mat +from ..crop import crop_xywh +from .base import OcrTextProvider -__all__ = [ - "FixRects", - "preprocess_hog", - "ocr_digits_by_contour_get_samples", - "ocr_digits_by_contour_knn", -] +if TYPE_CHECKING: + from cv2.ml import KNearest + + from ..types import Mat + +logger = logging.getLogger(__name__) class FixRects: @@ -68,7 +69,7 @@ class FixRects: @staticmethod def split_connected( - img_masked: Mat, + img_masked: "Mat", rects: Sequence[Tuple[int, int, int, int]], rect_wh_ratio: float = 1.05, width_range_ratio: float = 0.1, @@ -118,7 +119,7 @@ class FixRects: return return_rects -def resize_fill_square(img: Mat, target: int = 20): +def resize_fill_square(img: "Mat", target: int = 20): h, w = img.shape[:2] if h > w: new_h = target @@ -152,29 +153,88 @@ def preprocess_hog(digit_rois): def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4): _, results, _, _ = knn_model.findNearest(__samples, k) - result_list = [int(r) for r in results.ravel()] - result_str = "".join(str(r) for r in result_list if r > -1) - return int(result_str) if result_str else 0 + return [int(r) for r in results.ravel()] -def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int): - roi = __roi_gray.copy() - contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - rects = [cv2.boundingRect(c) for c in contours] - rects = FixRects.connect_broken(rects, roi.shape[1], roi.shape[0]) - rects = FixRects.split_connected(roi, rects) - rects = sorted(rects, key=lambda r: r[0]) - # digit_rois = [cv2.resize(crop_xywh(roi, rect), size) for rect in rects] - digit_rois = [resize_fill_square(crop_xywh(roi, rect), size) for rect in rects] - return preprocess_hog(digit_rois) +class OcrKNearestTextProvider(OcrTextProvider): + _ContourFilter = Callable[["Mat"], bool] + _RectsFilter = Callable[[Sequence[int]], bool] + def __init__(self, model: "KNearest"): + self.model = model -def ocr_digits_by_contour_knn( - __roi_gray: Mat, - knn_model: cv2.ml.KNearest, - *, - k=4, - size: int = 20, -) -> int: - samples = ocr_digits_by_contour_get_samples(__roi_gray, size) - return ocr_digit_samples_knn(samples, knn_model, k) + def contours( + self, img: "Mat", /, *, contours_filter: Optional[_ContourFilter] = None + ): + cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if contours_filter: + cnts = list(filter(contours_filter, cnts)) + + return cnts + + def result_raw( + self, + img: "Mat", + /, + *, + fix_rects: bool = True, + contours_filter: Optional[_ContourFilter] = None, + rects_filter: Optional[_RectsFilter] = None, + ): + """ + :param img: grayscaled roi + """ + + try: + cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if contours_filter: + cnts = list(filter(contours_filter, cnts)) + + rects = [cv2.boundingRect(cnt) for cnt in cnts] + if fix_rects and rects_filter: + rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0]) # type: ignore + rects = list(filter(rects_filter, rects)) + rects = FixRects.split_connected(img, rects) + elif fix_rects: + rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0]) # type: ignore + rects = FixRects.split_connected(img, rects) + elif rects_filter: + rects = list(filter(rects_filter, rects)) + + rects = sorted(rects, key=lambda r: r[0]) + + digits = [] + for rect in rects: + digit = crop_xywh(img, rect) + digit = resize_fill_square(digit, 20) + digits.append(digit) + samples = preprocess_hog(digits) + return ocr_digit_samples_knn(samples, self.model) + except Exception: + logger.exception("Error occurred during KNearest OCR") + return None + + def result( + self, + img: "Mat", + /, + *, + fix_rects: bool = True, + contours_filter: Optional[_ContourFilter] = None, + rects_filter: Optional[_RectsFilter] = None, + ): + """ + :param img: grayscaled roi + """ + + raw = self.result_raw( + img, + fix_rects=fix_rects, + contours_filter=contours_filter, + rects_filter=rects_filter, + ) + return ( + "".join(["".join(str(r) for r in raw if r > -1)]) + if raw is not None + else None + )