diff --git a/src/arcaea_offline_ocr/device.py b/src/arcaea_offline_ocr/device.py deleted file mode 100644 index 7f12de1..0000000 --- a/src/arcaea_offline_ocr/device.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, Tuple - -__all__ = ["Device"] - - -@dataclass(kw_only=True) -class Device: - version: int - uuid: str - name: str - pure: Tuple[int, int, int, int] - far: Tuple[int, int, int, int] - lost: Tuple[int, int, int, int] - max_recall: Tuple[int, int, int, int] - rating_class: Tuple[int, int, int, int] - score: Tuple[int, int, int, int] - title: Tuple[int, int, int, int] - - @classmethod - def from_json_object(cls, json_dict: Dict[str, Any]): - if json_dict["version"] == 1: - return cls( - version=1, - uuid=json_dict["uuid"], - name=json_dict["name"], - pure=json_dict["pure"], - far=json_dict["far"], - lost=json_dict["lost"], - max_recall=json_dict["max_recall"], - rating_class=json_dict["rating_class"], - score=json_dict["score"], - title=json_dict["title"], - ) - - def repr_info(self): - return f"Device(version={self.version}, uuid={repr(self.uuid)}, name={repr(self.name)})" diff --git a/src/arcaea_offline_ocr/device/shared.py b/src/arcaea_offline_ocr/device/shared.py new file mode 100644 index 0000000..567061b --- /dev/null +++ b/src/arcaea_offline_ocr/device/shared.py @@ -0,0 +1,14 @@ +import attrs + + +@attrs.define +class DeviceOcrResult: + song_id: None + title: None + rating_class: int + pure: int + far: int + lost: int + score: int + max_recall: int + clear_type: None diff --git a/src/arcaea_offline_ocr/device/v1/crop.py b/src/arcaea_offline_ocr/device/v1/crop.py index 15922f2..2f7d6a2 100644 --- a/src/arcaea_offline_ocr/device/v1/crop.py +++ b/src/arcaea_offline_ocr/device/v1/crop.py @@ -1,10 +1,7 @@ -from math import floor -from typing import Any, Tuple - -import numpy as np +from typing import Tuple from ...types import Mat -from .definition import Device +from .definition import DeviceV1 __all__ = [ "crop_img", @@ -16,7 +13,6 @@ __all__ = [ "crop_to_rating_class", "crop_to_score", "crop_to_title", - "crop_black_edges", ] @@ -29,38 +25,29 @@ def crop_from_device_attr(img: Mat, rect: Tuple[int, int, int, int]): return crop_img(img, top=y, left=x, bottom=y + h, right=x + w) -def crop_to_pure(screenshot: Mat, device: Device): +def crop_to_pure(screenshot: Mat, device: DeviceV1): return crop_from_device_attr(screenshot, device.pure) -def crop_to_far(screenshot: Mat, device: Device): +def crop_to_far(screenshot: Mat, device: DeviceV1): return crop_from_device_attr(screenshot, device.far) -def crop_to_lost(screenshot: Mat, device: Device): +def crop_to_lost(screenshot: Mat, device: DeviceV1): return crop_from_device_attr(screenshot, device.lost) -def crop_to_max_recall(screenshot: Mat, device: Device): +def crop_to_max_recall(screenshot: Mat, device: DeviceV1): return crop_from_device_attr(screenshot, device.max_recall) -def crop_to_rating_class(screenshot: Mat, device: Device): +def crop_to_rating_class(screenshot: Mat, device: DeviceV1): return crop_from_device_attr(screenshot, device.rating_class) -def crop_to_score(screenshot: Mat, device: Device): +def crop_to_score(screenshot: Mat, device: DeviceV1): return crop_from_device_attr(screenshot, device.score) -def crop_to_title(screenshot: Mat, device: Device): +def crop_to_title(screenshot: Mat, device: DeviceV1): return crop_from_device_attr(screenshot, device.title) - - -def is_black_edge(list_of_pixels: Mat, black_pixel=None): - if black_pixel is None: - black_pixel = np.array([0, 0, 0], list_of_pixels.dtype) - pixels = list_of_pixels.reshape([-1, 3]) - return np.count_nonzero(all(pixels < black_pixel, axis=1)) > floor( - len(pixels) * 0.6 - ) diff --git a/src/arcaea_offline_ocr/device/v1/definition.py b/src/arcaea_offline_ocr/device/v1/definition.py index 7f12de1..51ca29c 100644 --- a/src/arcaea_offline_ocr/device/v1/definition.py +++ b/src/arcaea_offline_ocr/device/v1/definition.py @@ -1,11 +1,11 @@ from dataclasses import dataclass from typing import Any, Dict, Tuple -__all__ = ["Device"] +__all__ = ["DeviceV1"] @dataclass(kw_only=True) -class Device: +class DeviceV1: version: int uuid: str name: str diff --git a/src/arcaea_offline_ocr/device/v1/ocr.py b/src/arcaea_offline_ocr/device/v1/ocr.py new file mode 100644 index 0000000..2e3a4b1 --- /dev/null +++ b/src/arcaea_offline_ocr/device/v1/ocr.py @@ -0,0 +1,86 @@ +from typing import List + +import cv2 + +from ...crop import crop_xywh +from ...mask import mask_gray, mask_white +from ...ocr import ocr_digits_by_contour_knn, ocr_rating_class +from ...types import Mat, cv2_ml_KNearest +from ..shared import DeviceOcrResult +from .crop import * +from .definition import DeviceV1 + + +class DeviceV1Ocr: + def __init__(self, device: DeviceV1, knn_model: cv2_ml_KNearest): + self.__device = device + self.__knn_model = knn_model + + @property + def device(self): + return self.__device + + @device.setter + def device(self, value): + self.__device = value + + @property + def knn_model(self): + return self.__knn_model + + @knn_model.setter + def knn_model(self, value): + self.__knn_model = value + + def preprocess_score_roi(self, __roi_gray: Mat) -> List[Mat]: + roi_gray = __roi_gray.copy() + contours, _ = cv2.findContours( + roi_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) + for contour in contours: + rect = cv2.boundingRect(contour) + if rect[3] > roi_gray.shape[0] * 0.6: + continue + roi_gray = cv2.fillPoly(roi_gray, [contour], 0) + return roi_gray + + def ocr(self, img_bgr: Mat): + rating_class_roi = crop_to_rating_class(img_bgr, self.device) + rating_class = ocr_rating_class(rating_class_roi) + + pfl_mr_roi = [ + crop_to_pure(img_bgr, self.device), + crop_to_far(img_bgr, self.device), + crop_to_lost(img_bgr, self.device), + crop_to_max_recall(img_bgr, self.device), + ] + pfl_mr_roi = [mask_gray(roi) for roi in pfl_mr_roi] + + pure, far, lost = [ + ocr_digits_by_contour_knn(roi, self.knn_model) for roi in pfl_mr_roi[:3] + ] + + max_recall_contours, _ = cv2.findContours( + pfl_mr_roi[3], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) + max_recall_rects = [cv2.boundingRect(c) for c in max_recall_contours] + max_recall_rect = sorted(max_recall_rects, key=lambda r: r[0])[-1] + max_recall_roi = crop_xywh(img_bgr, max_recall_rect) + max_recall = ocr_digits_by_contour_knn(max_recall_roi, self.knn_model) + + score_roi = crop_to_score(img_bgr, self.device) + score_roi = mask_white(score_roi) + score_roi = self.preprocess_score_roi(score_roi) + score = ocr_digits_by_contour_knn(score_roi, self.knn_model) + + return DeviceOcrResult( + song_id=None, + title=None, + rating_class=rating_class, + pure=pure, + far=far, + lost=lost, + score=score, + max_recall=max_recall, + clear_type=None, + ) diff --git a/src/arcaea_offline_ocr/device/v2/ocr.py b/src/arcaea_offline_ocr/device/v2/ocr.py index 5bc4634..28e7929 100644 --- a/src/arcaea_offline_ocr/device/v2/ocr.py +++ b/src/arcaea_offline_ocr/device/v2/ocr.py @@ -1,31 +1,18 @@ -from typing import Optional - -import attrs import cv2 import numpy as np from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white from ...ocr import ocr_digits_knn_model from ...types import Mat, cv2_ml_KNearest +from ..shared import DeviceOcrResult from .find import find_digits from .rois import DeviceV2Rois -@attrs.define -class DeviceV2OcrResult: - pure: int - far: int - lost: int - score: int - rating_class: int - max_recall: int - title: Optional[str] - - class DeviceV2Ocr: - def __init__(self): - self.__rois = None - self.__knn_model = None + def __init__(self, rois: DeviceV2Rois, knn_model: cv2_ml_KNearest): + self.__rois = rois + self.__knn_model = knn_model @property def rois(self): diff --git a/src/arcaea_offline_ocr/recognize.py b/src/arcaea_offline_ocr/recognize.py deleted file mode 100644 index 80f338c..0000000 --- a/src/arcaea_offline_ocr/recognize.py +++ /dev/null @@ -1,112 +0,0 @@ -from dataclasses import dataclass -from typing import Callable, Optional - -import cv2 - -from .crop import * - -# from .device import Device -from .mask import * -from .ocr import * -from .types import Mat -from .utils import imread_unicode - -Device = None - -__all__ = [ - "process_digits_ocr_img", - "process_tesseract_ocr_img", - "recognize_pure", - "recognize_far_lost", - "recognize_score", - "recognize_max_recall", - "recognize_rating_class", - "recognize_title", - "RecognizeResult", - "recognize", -] - - -def process_digits_ocr_img(img_hsv_cropped: Mat, mask=Callable[[Mat], Mat]): - img_hsv_cropped = mask(img_hsv_cropped) - img_hsv_cropped = cv2.GaussianBlur(img_hsv_cropped, (3, 3), 0) - return img_hsv_cropped - - -def process_tesseract_ocr_img(img_hsv_cropped: Mat, mask=Callable[[Mat], Mat]): - img_hsv_cropped = mask(img_hsv_cropped) - img_hsv_cropped = cv2.GaussianBlur(img_hsv_cropped, (1, 1), 0) - return img_hsv_cropped - - -def recognize_pure(img_hsv_cropped: Mat): - return ocr_pure(process_digits_ocr_img(img_hsv_cropped, mask=mask_gray)) - - -def recognize_far_lost(img_hsv_cropped: Mat): - return ocr_far_lost(process_digits_ocr_img(img_hsv_cropped, mask=mask_gray)) - - -def recognize_score(img_hsv_cropped: Mat): - return ocr_score(process_digits_ocr_img(img_hsv_cropped, mask=mask_white)) - - -def recognize_max_recall(img_hsv_cropped: Mat): - return ocr_max_recall(process_tesseract_ocr_img(img_hsv_cropped, mask=mask_gray)) - - -def recognize_rating_class(img_hsv_cropped: Mat): - return ocr_rating_class( - process_tesseract_ocr_img(img_hsv_cropped, mask=mask_rating_class) - ) - - -def recognize_title(img_hsv_cropped: Mat): - return ocr_title(process_tesseract_ocr_img(img_hsv_cropped, mask=mask_white)) - - -@dataclass(kw_only=True) -class RecognizeResult: - pure: Optional[int] - far: Optional[int] - lost: Optional[int] - score: Optional[int] - max_recall: Optional[int] - rating_class: Optional[int] - title: str - - -def recognize(img_filename: str, device: Device): - img = imread_unicode(img_filename) - img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) - - pure_roi = crop_to_pure(img_hsv, device) - pure = recognize_pure(pure_roi) - - far_roi = crop_to_far(img_hsv, device) - far = recognize_far_lost(far_roi) - - lost_roi = crop_to_lost(img_hsv, device) - lost = recognize_far_lost(lost_roi) - - score_roi = crop_to_score(img_hsv, device) - score = recognize_score(score_roi) - - max_recall_roi = crop_to_max_recall(img_hsv, device) - max_recall = recognize_max_recall(max_recall_roi) - - rating_class_roi = crop_to_rating_class(img_hsv, device) - rating_class = recognize_rating_class(rating_class_roi) - - title_roi = crop_to_title(img_hsv, device) - title = recognize_title(title_roi) - - return RecognizeResult( - pure=pure, - far=far, - lost=lost, - score=score, - max_recall=max_recall, - rating_class=rating_class, - title=title, - )