diff --git a/src/arcaea_offline_ocr/device/shared.py b/src/arcaea_offline_ocr/device/shared.py index 567061b..5a48d37 100644 --- a/src/arcaea_offline_ocr/device/shared.py +++ b/src/arcaea_offline_ocr/device/shared.py @@ -1,14 +1,16 @@ +from typing import Optional + import attrs @attrs.define class DeviceOcrResult: - song_id: None - title: None rating_class: int pure: int far: int lost: int score: int max_recall: int - clear_type: None + song_id: Optional[str] = None + title: Optional[str] = None + clear_type: Optional[str] = None diff --git a/src/arcaea_offline_ocr/device/v2/ocr.py b/src/arcaea_offline_ocr/device/v2/ocr.py index 28e7929..c43400d 100644 --- a/src/arcaea_offline_ocr/device/v2/ocr.py +++ b/src/arcaea_offline_ocr/device/v2/ocr.py @@ -1,28 +1,21 @@ import cv2 import numpy as np +from ...crop import crop_xywh from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white -from ...ocr import ocr_digits_knn_model +from ...ocr import ocr_digits_by_contour_knn +from ...sift_db import SIFTDatabase from ...types import Mat, cv2_ml_KNearest from ..shared import DeviceOcrResult -from .find import find_digits +from .find import find_digits_preprocess from .rois import DeviceV2Rois +from .shared import MAX_RECALL_CLOSE_KERNEL class DeviceV2Ocr: - def __init__(self, rois: DeviceV2Rois, knn_model: cv2_ml_KNearest): - self.__rois = rois + def __init__(self, knn_model: cv2_ml_KNearest, sift_db: SIFTDatabase): self.__knn_model = knn_model - - @property - def rois(self): - if not self.__rois: - raise ValueError("`rois` unset.") - return self.__rois - - @rois.setter - def rois(self, rois: DeviceV2Rois): - self.__rois = rois + self.__sift_db = sift_db @property def knn_model(self): @@ -31,46 +24,77 @@ class DeviceV2Ocr: return self.__knn_model @knn_model.setter - def knn_model(self, model: cv2_ml_KNearest): - self.__knn_model = model - - def _base_ocr_digits(self, roi_processed: Mat): - digits = find_digits(roi_processed) - result = "" - for digit in digits: - roi_result = ocr_digits_knn_model(digit, self.knn_model) - if roi_result is not None: - result += str(roi_result) - return int(result, base=10) + def knn_model(self, value: cv2_ml_KNearest): + self.__knn_model = value @property - def pure(self): - roi = mask_gray(self.rois.pure) - return self._base_ocr_digits(roi) + def sift_db(self): + if not self.__sift_db: + raise ValueError("`sift_db` unset.") + return self.__sift_db - @property - def far(self): - roi = mask_gray(self.rois.far) - return self._base_ocr_digits(roi) + @sift_db.setter + def sift_db(self, value: SIFTDatabase): + self.__sift_db = value - @property - def lost(self): - roi = mask_gray(self.rois.lost) - return self._base_ocr_digits(roi) + def _base_ocr_digits(self, roi_masked: Mat): + return ocr_digits_by_contour_knn( + find_digits_preprocess(roi_masked), self.knn_model + ) - @property - def score(self): - roi = cv2.cvtColor(self.rois.score, cv2.COLOR_BGR2HSV) + def ocr_song_id(self, rois: DeviceV2Rois): + cover = cv2.cvtColor(rois.cover, cv2.COLOR_BGR2GRAY) + return self.sift_db.lookup_img(cover)[0] + + def ocr_rating_class(self, rois: DeviceV2Rois): + roi = cv2.cvtColor(rois.max_recall_rating_class, cv2.COLOR_BGR2HSV) + results = [mask_pst(roi), mask_prs(roi), mask_ftr(roi), mask_byd(roi)] + return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] + + def ocr_score(self, rois: DeviceV2Rois): + roi = cv2.cvtColor(rois.score, cv2.COLOR_BGR2HSV) roi = mask_white(roi) return self._base_ocr_digits(roi) - @property - def rating_class(self): - roi = cv2.cvtColor(self.rois.max_recall_rating_class, cv2.COLOR_BGR2HSV) - results = [ - mask_pst(roi), - mask_prs(roi), - mask_ftr(roi), - mask_byd(roi), - ] - return max(enumerate(results), key=lambda e: np.count_nonzero(e[1]))[0] + def ocr_pure(self, rois: DeviceV2Rois): + roi = mask_gray(rois.pure) + return self._base_ocr_digits(roi) + + def ocr_far(self, rois: DeviceV2Rois): + roi = mask_gray(rois.far) + return self._base_ocr_digits(roi) + + def ocr_lost(self, rois: DeviceV2Rois): + roi = mask_gray(rois.lost) + return self._base_ocr_digits(roi) + + def ocr_max_recall(self, rois: DeviceV2Rois): + roi = mask_gray(rois.max_recall_rating_class) + roi_closed = cv2.morphologyEx(roi, cv2.MORPH_CLOSE, MAX_RECALL_CLOSE_KERNEL) + contours, _ = cv2.findContours( + roi_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) + rects = sorted( + [cv2.boundingRect(c) for c in contours], key=lambda r: r[0], reverse=True + ) + max_recall_roi = crop_xywh(roi, rects[0]) + return self._base_ocr_digits(max_recall_roi) + + def ocr(self, rois: DeviceV2Rois): + song_id = self.ocr_song_id(rois) + rating_class = self.ocr_rating_class(rois) + score = self.ocr_score(rois) + pure = self.ocr_pure(rois) + far = self.ocr_far(rois) + lost = self.ocr_lost(rois) + max_recall = self.ocr_max_recall(rois) + + return DeviceOcrResult( + rating_class=rating_class, + pure=pure, + far=far, + lost=lost, + score=score, + max_recall=max_recall, + song_id=song_id, + ) diff --git a/src/arcaea_offline_ocr/device/v2/rois.py b/src/arcaea_offline_ocr/device/v2/rois.py index f0804a5..1f1ba35 100644 --- a/src/arcaea_offline_ocr/device/v2/rois.py +++ b/src/arcaea_offline_ocr/device/v2/rois.py @@ -98,10 +98,10 @@ class Sizes: class DeviceV2Rois: - def __init__(self, device: DeviceV2): + def __init__(self, device: DeviceV2, img: Mat): self.device = device self.sizes = Sizes(self.device.factor) - self.__img = None + self.__img = img @staticmethod def construct_int_xywh_rect(x, y, w, h) -> XYWHRect: