diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py index dd03488..1d9a528 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py @@ -1,58 +1,58 @@ -from datetime import datetime -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from math import floor +from typing import TYPE_CHECKING, List, Optional, Tuple -import attrs import cv2 import numpy as np from ....crop import crop_xywh -from ....ocr import preprocess_hog -from ....types import Mat, XYWHRect, cv2_ml_KNearest +from ....ocr import ocr_digits_by_contour_knn, preprocess_hog +from ....sift_db import SIFTDatabase +from ....types import Mat, cv2_ml_KNearest from ....utils import construct_int_xywh_rect +from ...shared import B30OcrResultItem from .colors import * from .rois import ChieriBotV4Rois -if TYPE_CHECKING: - from paddleocr import PaddleOCR - - -@attrs.define -class ChieriBotV4OcrResultItem: - rating_class: int - title: str - score: int - pure: int - far: int - lost: int - date: Union[datetime, str] +# if TYPE_CHECKING: +# from paddleocr import PaddleOCR class ChieriBotV4Ocr: def __init__( self, - paddle_ocr: "PaddleOCR", - knn_digits_model: cv2_ml_KNearest, + score_knn: cv2_ml_KNearest, + pfl_knn: cv2_ml_KNearest, + sift_db: SIFTDatabase, factor: Optional[float] = 1.0, ): - self.__paddle_ocr = paddle_ocr - self.__knn_digits_model = knn_digits_model + self.__score_knn = score_knn + self.__pfl_knn = pfl_knn + self.__sift_db = sift_db self.__rois = ChieriBotV4Rois(factor) @property - def paddle_ocr(self): - return self.__paddle_ocr + def score_knn(self): + return self.__score_knn - @paddle_ocr.setter - def paddle_ocr(self, paddle_ocr: "PaddleOCR"): - self.__paddle_ocr = paddle_ocr + @score_knn.setter + def score_knn(self, knn_digits_model: Mat): + self.__score_knn = knn_digits_model @property - def knn_digits_model(self): - return self.__knn_digits_model + def pfl_knn(self): + return self.__pfl_knn - @knn_digits_model.setter - def knn_digits_model(self, knn_digits_model: Mat): - self.__knn_digits_model = knn_digits_model + @pfl_knn.setter + def pfl_knn(self, knn_digits_model: Mat): + self.__pfl_knn = knn_digits_model + + @property + def sift_db(self): + return self.__sift_db + + @sift_db.setter + def sift_db(self, sift_db: SIFTDatabase): + self.__sift_db = sift_db @property def rois(self): @@ -66,6 +66,9 @@ class ChieriBotV4Ocr: def factor(self, factor: float): self.__rois.factor = factor + def set_factor(self, img: Mat): + self.factor = img.shape[0] / 4400 + def ocr_component_rating_class(self, component_bgr: Mat) -> int: rating_class_rect = construct_int_xywh_rect( self.rois.component_rois.rating_class_rect @@ -83,15 +86,37 @@ class ChieriBotV4Ocr: else: return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 - def ocr_component_title(self, component_bgr: Mat) -> str: - # sourcery skip: inline-immediately-returned-variable - title_rect = construct_int_xywh_rect(self.rois.component_rois.title_rect) - title_roi = crop_xywh(component_bgr, title_rect) - ocr_result = self.paddle_ocr.ocr(title_roi, cls=False) - title = ocr_result[0][-1][1][0] if ocr_result and ocr_result[0] else "" - return title + # def ocr_component_title(self, component_bgr: Mat) -> str: + # # sourcery skip: inline-immediately-returned-variable + # title_rect = construct_int_xywh_rect(self.rois.component_rois.title_rect) + # title_roi = crop_xywh(component_bgr, title_rect) + # ocr_result = self.sift_db.ocr(title_roi, cls=False) + # title = ocr_result[0][-1][1][0] if ocr_result and ocr_result[0] else "" + # return title - def ocr_component_score(self, component_bgr: Mat) -> int: + def ocr_component_song_id(self, component_bgr: Mat): + jacket_rect = construct_int_xywh_rect( + self.rois.component_rois.jacket_rect, floor + ) + jacket_roi = cv2.cvtColor( + crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY + ) + return self.sift_db.lookup_img(jacket_roi)[0] + + # def ocr_component_score_paddle(self, component_bgr: Mat) -> int: + # # sourcery skip: inline-immediately-returned-variable + # score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) + # score_roi = cv2.cvtColor( + # crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY + # ) + # _, score_roi = cv2.threshold( + # score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU + # ) + # score_str = self.sift_db.ocr(score_roi, cls=False)[0][-1][1][0] + # score = int(score_str.replace("'", "").replace(" ", "")) + # return score + + def ocr_component_score_knn(self, component_bgr: Mat) -> int: # sourcery skip: inline-immediately-returned-variable score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) score_roi = cv2.cvtColor( @@ -100,9 +125,18 @@ class ChieriBotV4Ocr: _, score_roi = cv2.threshold( score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU ) - score_str = self.paddle_ocr.ocr(score_roi, cls=False)[0][-1][1][0] - score = int(score_str.replace("'", "").replace(" ", "")) - return score + if score_roi[1][1] == 255: + score_roi = 255 - score_roi + + contours, _ = cv2.findContours( + score_roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + for contour in contours: + rect = cv2.boundingRect(contour) + if rect[3] > score_roi.shape[0] * 0.5: + continue + score_roi = cv2.fillPoly(score_roi, [contour], 0) + return ocr_digits_by_contour_knn(score_roi, self.score_knn) def find_pfl_rects(self, component_pfl_processed: Mat) -> List[List[int]]: # sourcery skip: inline-immediately-returned-variable @@ -169,7 +203,9 @@ class ChieriBotV4Ocr: ) return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result - def ocr_component_pfl(self, component_bgr: Mat) -> Tuple[int, int, int]: + def ocr_component_pfl( + self, component_bgr: Mat + ) -> Tuple[Optional[int], Optional[int], Optional[int]]: try: pfl_roi = self.preprocess_component_pfl(component_bgr) pfl_rects = self.find_pfl_rects(pfl_roi) @@ -190,31 +226,43 @@ class ChieriBotV4Ocr: digits.append(digit) samples = preprocess_hog(digits) - _, results, _, _ = self.knn_digits_model.findNearest(samples, 4) + _, results, _, _ = self.pfl_knn.findNearest(samples, 4) results = [str(int(i)) for i in results.ravel()] pure_far_lost.append(int("".join(results))) return tuple(pure_far_lost) except Exception: - return (-1, -1, -1) + return (None, None, None) - def ocr_component(self, component_bgr: Mat) -> ChieriBotV4OcrResultItem: + # def ocr_component_date(self, component_bgr: Mat): + # date_rect = construct_int_xywh_rect(self.rois.component_rois.date_rect) + # date_roi = cv2.cvtColor(crop_xywh(component_bgr, date_rect), cv2.COLOR_BGR2GRAY) + # _, date_roi = cv2.threshold( + # date_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU + # ) + # date_str = self.sift_db.ocr(date_roi, cls=False)[0][-1][1][0] + # return date_str + + def ocr_component(self, component_bgr: Mat) -> B30OcrResultItem: component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0) rating_class = self.ocr_component_rating_class(component_blur) - title = self.ocr_component_title(component_blur) - score = self.ocr_component_score(component_blur) + song_id = self.ocr_component_song_id(component_bgr) + # title = self.ocr_component_title(component_blur) + # score = self.ocr_component_score(component_blur) + score = self.ocr_component_score_knn(component_bgr) pure, far, lost = self.ocr_component_pfl(component_bgr) - return ChieriBotV4OcrResultItem( + return B30OcrResultItem( + song_id=song_id, rating_class=rating_class, - title=title, + # title=title, score=score, pure=pure, far=far, lost=lost, - date="", + date=None, ) - def ocr(self, img_bgr: Mat) -> List[ChieriBotV4OcrResultItem]: - self.factor = img_bgr.shape[0] / 4400 + def ocr(self, img_bgr: Mat) -> List[B30OcrResultItem]: + self.set_factor(img_bgr) return [ self.ocr_component(component_bgr) for component_bgr in self.rois.components(img_bgr) diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/rois.py b/src/arcaea_offline_ocr/b30/chieri/v4/rois.py index 50c96a2..9926b8a 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/rois.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/rois.py @@ -37,6 +37,10 @@ class ChieriBotV4ComponentRois: def title_rect(self): return apply_factor((35, 10, 430, 50), self.factor) + @property + def jacket_rect(self): + return apply_factor((263, 0, 239, 239), self.factor) + @property def score_rect(self): return apply_factor((30, 60, 270, 55), self.factor) diff --git a/src/arcaea_offline_ocr/b30/shared.py b/src/arcaea_offline_ocr/b30/shared.py new file mode 100644 index 0000000..d76fea2 --- /dev/null +++ b/src/arcaea_offline_ocr/b30/shared.py @@ -0,0 +1,16 @@ +from datetime import datetime +from typing import Optional, Union + +import attrs + + +@attrs.define +class B30OcrResultItem: + rating_class: int + score: int + pure: Optional[int] = None + far: Optional[int] = None + lost: Optional[int] = None + date: Optional[datetime] = None + title: Optional[str] = None + song_id: Optional[str] = None diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/ocr.py index ce645c7..801befb 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/ocr.py @@ -11,7 +11,7 @@ from .types import Mat, cv2_ml_KNearest __all__ = [ "preprocess_hog", - "ocr_digits_by_contour_samples", + "ocr_digits_by_contour_get_samples", "ocr_digits_by_contour_knn", ] @@ -65,7 +65,14 @@ def preprocess_hog(digit_rois): return np.float32(samples) -def ocr_digits_by_contour_samples(__roi_gray: Mat, size: int): +def ocr_digit_samples_knn(__samples, knn_model: cv2_ml_KNearest, k: int = 4): + _, results, _, _ = knn_model.findNearest(__samples, k) + result_list = [int(r) for r in results.ravel()] + result_str = "".join(str(r) for r in result_list if r > -1) + return int(result_str) + + +def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int): roi = __roi_gray.copy() contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) rects = sorted([cv2.boundingRect(c) for c in contours], key=lambda r: r[0]) @@ -81,10 +88,8 @@ def ocr_digits_by_contour_knn( k=4, size: int = 20, ) -> int: - samples = ocr_digits_by_contour_samples(__roi_gray, size) - _, results, _, _ = knn_model.findNearest(samples, k) - results = [str(int(i)) for i in results.ravel()] - return int("".join(results)) + samples = ocr_digits_by_contour_get_samples(__roi_gray, size) + return ocr_digit_samples_knn(samples, knn_model, k) def ocr_rating_class(roi_hsv: Mat):