diff --git a/pyproject.toml b/pyproject.toml index 45e97f0..1390b88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,12 +4,12 @@ build-backend = "setuptools.build_meta" [project] name = "arcaea-offline-ocr" -version = "0.1.0" +version = "0.0.95" authors = [{ name = "283375", email = "log_283375@163.com" }] description = "Extract your Arcaea play result from screenshot." readme = "README.md" requires-python = ">=3.8" -dependencies = ["attrs==23.1.0", "numpy==1.25.2", "opencv-python==4.8.0.76"] +dependencies = ["attrs==23.1.0", "numpy==1.26.1", "opencv-python==4.8.1.78"] classifiers = [ "Development Status :: 3 - Alpha", "Programming Language :: Python :: 3", diff --git a/requirements.txt b/requirements.txt index b065292..b1b83a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ attrs==23.1.0 -numpy==1.25.2 -opencv-python==4.8.0.76 +numpy==1.26.1 +opencv-python==4.8.1.78 diff --git a/src/arcaea_offline_ocr/__init__.py b/src/arcaea_offline_ocr/__init__.py index 41057d6..c2e0b50 100644 --- a/src/arcaea_offline_ocr/__init__.py +++ b/src/arcaea_offline_ocr/__init__.py @@ -1,5 +1,4 @@ from .crop import * from .device import * -from .mask import * from .ocr import * from .utils import * diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py index c033088..845145a 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py @@ -3,13 +3,16 @@ from typing import List, Optional, Tuple import cv2 import numpy as np -from PIL import Image from ....crop import crop_xywh -from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog -from ....phash_db import ImagePHashDatabase -from ....sift_db import SIFTDatabase -from ....types import Mat, cv2_ml_KNearest +from ....ocr import ( + FixRects, + ocr_digits_by_contour_knn, + preprocess_hog, + resize_fill_square, +) +from ....phash_db import ImagePhashDatabase +from ....types import Mat from ....utils import construct_int_xywh_rect from ...shared import B30OcrResultItem from .colors import * @@ -19,9 +22,9 @@ from .rois import ChieriBotV4Rois class ChieriBotV4Ocr: def __init__( self, - score_knn: cv2_ml_KNearest, - pfl_knn: cv2_ml_KNearest, - phash_db: ImagePHashDatabase, + score_knn: cv2.ml.KNearest, + pfl_knn: cv2.ml.KNearest, + phash_db: ImagePhashDatabase, factor: Optional[float] = 1.0, ): self.__score_knn = score_knn @@ -34,7 +37,7 @@ class ChieriBotV4Ocr: return self.__score_knn @score_knn.setter - def score_knn(self, knn_digits_model: Mat): + def score_knn(self, knn_digits_model: cv2.ml.KNearest): self.__score_knn = knn_digits_model @property @@ -42,7 +45,7 @@ class ChieriBotV4Ocr: return self.__pfl_knn @pfl_knn.setter - def pfl_knn(self, knn_digits_model: Mat): + def pfl_knn(self, knn_digits_model: cv2.ml.KNearest): self.__pfl_knn = knn_digits_model @property @@ -50,7 +53,7 @@ class ChieriBotV4Ocr: return self.__phash_db @phash_db.setter - def phash_db(self, phash_db: ImagePHashDatabase): + def phash_db(self, phash_db: ImagePhashDatabase): self.__phash_db = phash_db @property @@ -85,14 +88,6 @@ class ChieriBotV4Ocr: else: return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 - # def ocr_component_title(self, component_bgr: Mat) -> str: - # # sourcery skip: inline-immediately-returned-variable - # title_rect = construct_int_xywh_rect(self.rois.component_rois.title_rect) - # title_roi = crop_xywh(component_bgr, title_rect) - # ocr_result = self.sift_db.ocr(title_roi, cls=False) - # title = ocr_result[0][-1][1][0] if ocr_result and ocr_result[0] else "" - # return title - def ocr_component_song_id(self, component_bgr: Mat): jacket_rect = construct_int_xywh_rect( self.rois.component_rois.jacket_rect, floor @@ -100,20 +95,7 @@ class ChieriBotV4Ocr: jacket_roi = cv2.cvtColor( crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY ) - return self.phash_db.lookup_image(Image.fromarray(jacket_roi))[0] - - # def ocr_component_score_paddle(self, component_bgr: Mat) -> int: - # # sourcery skip: inline-immediately-returned-variable - # score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) - # score_roi = cv2.cvtColor( - # crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY - # ) - # _, score_roi = cv2.threshold( - # score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU - # ) - # score_str = self.sift_db.ocr(score_roi, cls=False)[0][-1][1][0] - # score = int(score_str.replace("'", "").replace(" ", "")) - # return score + return self.phash_db.lookup_jacket(jacket_roi)[0] def ocr_component_score_knn(self, component_bgr: Mat) -> int: # sourcery skip: inline-immediately-returned-variable @@ -223,7 +205,7 @@ class ChieriBotV4Ocr: digits = [] for digit_rect in digit_rects: digit = crop_xywh(roi, digit_rect) - digit = cv2.resize(digit, (20, 20)) + digit = resize_fill_square(digit, 20) digits.append(digit) samples = preprocess_hog(digits) @@ -234,15 +216,6 @@ class ChieriBotV4Ocr: except Exception: return (None, None, None) - # def ocr_component_date(self, component_bgr: Mat): - # date_rect = construct_int_xywh_rect(self.rois.component_rois.date_rect) - # date_roi = cv2.cvtColor(crop_xywh(component_bgr, date_rect), cv2.COLOR_BGR2GRAY) - # _, date_roi = cv2.threshold( - # date_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU - # ) - # date_str = self.sift_db.ocr(date_roi, cls=False)[0][-1][1][0] - # return date_str - def ocr_component(self, component_bgr: Mat) -> B30OcrResultItem: component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0) rating_class = self.ocr_component_rating_class(component_blur) diff --git a/src/arcaea_offline_ocr/crop.py b/src/arcaea_offline_ocr/crop.py index a65b6ea..12c531d 100644 --- a/src/arcaea_offline_ocr/crop.py +++ b/src/arcaea_offline_ocr/crop.py @@ -1,11 +1,12 @@ -from math import floor +import math from typing import Tuple +import cv2 import numpy as np from .types import Mat -__all__ = ["crop_xywh", "crop_black_edges", "crop_black_edges_grayscale"] +__all__ = ["crop_xywh", "CropBlackEdges"] def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]): @@ -13,92 +14,53 @@ def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]): return mat[y : y + h, x : x + w] -def is_black_edge(list_of_pixels: Mat, black_pixel: Mat, ratio: float = 0.6): - pixels = list_of_pixels.reshape([-1, 3]) - return np.count_nonzero(np.all(pixels < black_pixel, axis=1)) > floor( - len(pixels) * ratio - ) +class CropBlackEdges: + @staticmethod + def is_black_edge(__img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6): + pixels_compared = __img_gray_slice < black_pixel + return np.count_nonzero(pixels_compared) > math.floor( + __img_gray_slice.size * ratio + ) + @classmethod + def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25): + height, width = img_gray.shape[:2] + left = 0 + right = width + top = 0 + bottom = height -def crop_black_edges(img_bgr: Mat, black_threshold: int = 50): - cropped = img_bgr.copy() - black_pixel = np.array([black_threshold] * 3, img_bgr.dtype) - height, width = img_bgr.shape[:2] - left = 0 - right = width - top = 0 - bottom = height + for i in range(width): + column = img_gray[:, i] + if not cls.is_black_edge(column, black_threshold): + break + left += 1 - for i in range(width): - column = cropped[:, i] - if not is_black_edge(column, black_pixel): - break - left += 1 + for i in sorted(range(width), reverse=True): + column = img_gray[:, i] + if i <= left + 1 or not cls.is_black_edge(column, black_threshold): + break + right -= 1 - for i in sorted(range(width), reverse=True): - column = cropped[:, i] - if i <= left + 1 or not is_black_edge(column, black_pixel): - break - right -= 1 + for i in range(height): + row = img_gray[i] + if not cls.is_black_edge(row, black_threshold): + break + top += 1 - for i in range(height): - row = cropped[i] - if not is_black_edge(row, black_pixel): - break - top += 1 + for i in sorted(range(height), reverse=True): + row = img_gray[i] + if i <= top + 1 or not cls.is_black_edge(row, black_threshold): + break + bottom -= 1 - for i in sorted(range(height), reverse=True): - row = cropped[i] - if i <= top + 1 or not is_black_edge(row, black_pixel): - break - bottom -= 1 + assert right > left, "cropped width < 0" + assert bottom > top, "cropped height < 0" + return (left, top, right - left, bottom - top) - return cropped[top:bottom, left:right] - - -def is_black_edge_grayscale( - gray_value_list: np.ndarray, black_threshold: int = 50, ratio: float = 0.6 -) -> bool: - return ( - np.count_nonzero(gray_value_list < black_threshold) - > len(gray_value_list) * ratio - ) - - -def crop_black_edges_grayscale( - img_gray: Mat, black_threshold: int = 50 -) -> Tuple[int, int, int, int]: - """Returns cropped rect""" - height, width = img_gray.shape[:2] - left = 0 - right = width - top = 0 - bottom = height - - for i in range(width): - column = img_gray[:, i] - if not is_black_edge_grayscale(column, black_threshold): - break - left += 1 - - for i in sorted(range(width), reverse=True): - column = img_gray[:, i] - if i <= left + 1 or not is_black_edge_grayscale(column, black_threshold): - break - right -= 1 - - for i in range(height): - row = img_gray[i] - if not is_black_edge_grayscale(row, black_threshold): - break - top += 1 - - for i in sorted(range(height), reverse=True): - row = img_gray[i] - if i <= top + 1 or not is_black_edge_grayscale(row, black_threshold): - break - bottom -= 1 - - assert right > left, "cropped width > 0" - assert bottom > top, "cropped height > 0" - return (left, top, right - left, bottom - top) + @classmethod + def crop( + cls, img: Mat, convert_flag: cv2.COLOR_BGR2GRAY, black_threshold: int = 25 + ) -> Mat: + rect = cls.get_crop_rect(cv2.cvtColor(img, convert_flag), black_threshold) + return crop_xywh(img, rect) diff --git a/src/arcaea_offline_ocr/device/__init__.py b/src/arcaea_offline_ocr/device/__init__.py index e69de29..bd6cd44 100644 --- a/src/arcaea_offline_ocr/device/__init__.py +++ b/src/arcaea_offline_ocr/device/__init__.py @@ -0,0 +1,2 @@ +from .common import DeviceOcrResult +from .ocr import DeviceOcr diff --git a/src/arcaea_offline_ocr/device/common.py b/src/arcaea_offline_ocr/device/common.py new file mode 100644 index 0000000..9cdc076 --- /dev/null +++ b/src/arcaea_offline_ocr/device/common.py @@ -0,0 +1,18 @@ +from typing import Optional + +import attrs + + +@attrs.define +class DeviceOcrResult: + rating_class: int + pure: int + far: int + lost: int + score: int + max_recall: Optional[int] = None + song_id: Optional[str] = None + song_id_possibility: Optional[float] = None + clear_status: Optional[int] = None + partner_id: Optional[str] = None + partner_id_possibility: Optional[float] = None diff --git a/src/arcaea_offline_ocr/device/ocr.py b/src/arcaea_offline_ocr/device/ocr.py new file mode 100644 index 0000000..d0cbf88 --- /dev/null +++ b/src/arcaea_offline_ocr/device/ocr.py @@ -0,0 +1,160 @@ +import cv2 +import numpy as np + +from ..crop import crop_xywh +from ..ocr import ( + FixRects, + ocr_digit_samples_knn, + ocr_digits_by_contour_knn, + preprocess_hog, + resize_fill_square, +) +from ..phash_db import ImagePhashDatabase +from ..types import Mat +from .common import DeviceOcrResult +from .rois.extractor import DeviceRoisExtractor +from .rois.masker import DeviceRoisMasker + + +class DeviceOcr: + def __init__( + self, + extractor: DeviceRoisExtractor, + masker: DeviceRoisMasker, + knn_model: cv2.ml.KNearest, + phash_db: ImagePhashDatabase, + ): + self.extractor = extractor + self.masker = masker + self.knn_model = knn_model + self.phash_db = phash_db + + def pfl(self, roi_gray: Mat, factor: float = 1.25): + contours, _ = cv2.findContours( + roi_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) + filtered_contours = [c for c in contours if cv2.contourArea(c) >= 5 * factor] + rects = [cv2.boundingRect(c) for c in filtered_contours] + rects = FixRects.connect_broken(rects, roi_gray.shape[1], roi_gray.shape[0]) + + filtered_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor] + filtered_rects = FixRects.split_connected(roi_gray, filtered_rects) + filtered_rects = sorted(filtered_rects, key=lambda r: r[0]) + + roi_ocr = roi_gray.copy() + filtered_contours_flattened = {tuple(c.flatten()) for c in filtered_contours} + for contour in contours: + if tuple(contour.flatten()) in filtered_contours_flattened: + continue + roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0]) + digit_rois = [ + resize_fill_square(crop_xywh(roi_ocr, r), 20) for r in filtered_rects + ] + + samples = preprocess_hog(digit_rois) + return ocr_digit_samples_knn(samples, self.knn_model) + + def pure(self): + return self.pfl(self.masker.pure(self.extractor.pure)) + + def far(self): + return self.pfl(self.masker.far(self.extractor.far)) + + def lost(self): + return self.pfl(self.masker.lost(self.extractor.lost)) + + def score(self): + roi = self.masker.score(self.extractor.score) + contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + if h < roi.shape[0] * 0.6: + roi = cv2.fillPoly(roi, [contour], [0]) + return ocr_digits_by_contour_knn(roi, self.knn_model) + + def rating_class(self): + roi = self.extractor.rating_class + results = [ + self.masker.rating_class_pst(roi), + self.masker.rating_class_prs(roi), + self.masker.rating_class_ftr(roi), + self.masker.rating_class_byd(roi), + ] + return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] + + def max_recall(self): + return ocr_digits_by_contour_knn( + self.masker.max_recall(self.extractor.max_recall), self.knn_model + ) + + def clear_status(self): + roi = self.extractor.clear_status + results = [ + self.masker.clear_status_track_lost(roi), + self.masker.clear_status_track_complete(roi), + self.masker.clear_status_full_recall(roi), + self.masker.clear_status_pure_memory(roi), + ] + return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] + + def lookup_song_id(self): + return self.phash_db.lookup_jacket( + cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY) + ) + + def song_id(self): + return self.lookup_song_id()[0] + + @staticmethod + def preprocess_char_icon(img_gray: Mat): + h, w = img_gray.shape[:2] + img = cv2.copyMakeBorder(img_gray, w - h, 0, 0, 0, cv2.BORDER_REPLICATE) + h, w = img.shape[:2] + img = cv2.fillPoly( + img, + [ + np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32), + np.array([[w, 0], [round(w / 2), 0], [w, round(h / 2)]], np.int32), + np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32), + np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32), + ], + (128), + ) + return img + + def lookup_partner_id(self): + return self.phash_db.lookup_partner_icon( + self.preprocess_char_icon( + cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY) + ) + ) + + def partner_id(self): + return self.lookup_partner_id()[0] + + def ocr(self) -> DeviceOcrResult: + rating_class = self.rating_class() + pure = self.pure() + far = self.far() + lost = self.lost() + score = self.score() + max_recall = self.max_recall() + clear_status = self.clear_status() + + hash_len = self.phash_db.hash_size**2 + song_id, song_id_distance = self.lookup_song_id() + partner_id, partner_id_distance = self.lookup_partner_id() + + return DeviceOcrResult( + rating_class=rating_class, + pure=pure, + far=far, + lost=lost, + score=score, + max_recall=max_recall, + song_id=song_id, + song_id_possibility=1 - song_id_distance / hash_len, + clear_status=clear_status, + partner_id=partner_id, + partner_id_possibility=1 - partner_id_distance / hash_len, + ) diff --git a/src/arcaea_offline_ocr/device/rois/__init__.py b/src/arcaea_offline_ocr/device/rois/__init__.py new file mode 100644 index 0000000..e73f32a --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/__init__.py @@ -0,0 +1,3 @@ +from .definition import * +from .extractor import * +from .masker import * diff --git a/src/arcaea_offline_ocr/device/rois/definition/__init__.py b/src/arcaea_offline_ocr/device/rois/definition/__init__.py new file mode 100644 index 0000000..37f6ef1 --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/definition/__init__.py @@ -0,0 +1,2 @@ +from .auto import * +from .common import DeviceRois diff --git a/src/arcaea_offline_ocr/device/rois/definition/auto.py b/src/arcaea_offline_ocr/device/rois/definition/auto.py new file mode 100644 index 0000000..66508c9 --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/definition/auto.py @@ -0,0 +1,255 @@ +from .common import DeviceRois + +__all__ = ["DeviceRoisAuto", "DeviceRoisAutoT1", "DeviceRoisAutoT2"] + + +class DeviceRoisAuto(DeviceRois): + def __init__(self, w: int, h: int): + self.w = w + self.h = h + + +class DeviceRoisAutoT1(DeviceRoisAuto): + @property + def factor(self): + return ( + ((self.w / 16) * 9) / 720 if (self.w / self.h) < (16 / 9) else self.h / 720 + ) + + @property + def w_mid(self): + return self.w / 2 + + @property + def h_mid(self): + return self.h / 2 + + @property + def top_bar(self): + return (0, 0, self.w, 50 * self.factor) + + @property + def layout_area_h_mid(self): + return self.h / 2 + self.top_bar[3] + + @property + def pfl_left_from_w_mid(self): + return 5 * self.factor + + @property + def pfl_x(self): + return self.w_mid + self.pfl_left_from_w_mid + + @property + def pfl_w(self): + return 76 * self.factor + + @property + def pfl_h(self): + return 26 * self.factor + + @property + def pure(self): + return ( + self.pfl_x, + self.layout_area_h_mid + 110 * self.factor, + self.pfl_w, + self.pfl_h, + ) + + @property + def far(self): + return ( + self.pfl_x, + self.pure[1] + self.pure[3] + 12 * self.factor, + self.pfl_w, + self.pfl_h, + ) + + @property + def lost(self): + return ( + self.pfl_x, + self.far[1] + self.far[3] + 10 * self.factor, + self.pfl_w, + self.pfl_h, + ) + + @property + def score(self): + w = 280 * self.factor + h = 45 * self.factor + return ( + self.w_mid - w / 2, + self.layout_area_h_mid - 75 * self.factor - h, + w, + h, + ) + + @property + def rating_class(self): + return ( + self.w_mid - 610 * self.factor, + self.layout_area_h_mid - 180 * self.factor, + 265 * self.factor, + 35 * self.factor, + ) + + @property + def max_recall(self): + return ( + self.w_mid - 465 * self.factor, + self.layout_area_h_mid - 215 * self.factor, + 150 * self.factor, + 35 * self.factor, + ) + + @property + def jacket(self): + return ( + self.w_mid - 610 * self.factor, + self.layout_area_h_mid - 143 * self.factor, + 375 * self.factor, + 375 * self.factor, + ) + + @property + def clear_status(self): + w = 550 * self.factor + h = 60 * self.factor + return ( + self.w_mid - w / 2, + self.layout_area_h_mid - 155 * self.factor - h, + w, + h, + ) + + @property + def partner_icon(self): + w = 90 * self.factor + h = 75 * self.factor + return (self.w_mid - w / 2, 0, w, h) + + +class DeviceRoisAutoT2(DeviceRoisAuto): + @property + def factor(self): + return ( + ((self.w / 16) * 9) / 1080 + if (self.w / self.h) < (16 / 9) + else self.h / 1080 + ) + + @property + def w_mid(self): + return self.w / 2 + + @property + def h_mid(self): + return self.h / 2 + + @property + def top_bar(self): + return (0, 0, self.w, 75 * self.factor) + + @property + def layout_area_h_mid(self): + return self.h / 2 + self.top_bar[3] + + @property + def pfl_mid_from_w_mid(self): + return 60 * self.factor + + @property + def pfl_x(self): + return self.w_mid + 10 * self.factor + + @property + def pfl_w(self): + return 100 * self.factor + + @property + def pfl_h(self): + return 24 * self.factor + + @property + def pure(self): + return ( + self.pfl_x, + self.layout_area_h_mid + 175 * self.factor, + self.pfl_w, + self.pfl_h, + ) + + @property + def far(self): + return ( + self.pfl_x, + self.pure[1] + self.pure[3] + 30 * self.factor, + self.pfl_w, + self.pfl_h, + ) + + @property + def lost(self): + return ( + self.pfl_x, + self.far[1] + self.far[3] + 35 * self.factor, + self.pfl_w, + self.pfl_h, + ) + + @property + def score(self): + w = 420 * self.factor + h = 70 * self.factor + return ( + self.w_mid - w / 2, + self.layout_area_h_mid - 110 * self.factor - h, + w, + h, + ) + + @property + def rating_class(self): + return ( + max(0, self.w_mid - 965 * self.factor), + self.layout_area_h_mid - 330 * self.factor, + 350 * self.factor, + 110 * self.factor, + ) + + @property + def max_recall(self): + return ( + self.w_mid - 625 * self.factor, + self.layout_area_h_mid - 275 * self.factor, + 150 * self.factor, + 50 * self.factor, + ) + + @property + def jacket(self): + return ( + self.w_mid - 915 * self.factor, + self.layout_area_h_mid - 215 * self.factor, + 565 * self.factor, + 565 * self.factor, + ) + + @property + def clear_status(self): + w = 825 * self.factor + h = 90 * self.factor + return ( + self.w_mid - w / 2, + self.layout_area_h_mid - 235 * self.factor - h, + w, + h, + ) + + @property + def partner_icon(self): + w = 135 * self.factor + h = 110 * self.factor + return (self.w_mid - w / 2, 0, w, h) diff --git a/src/arcaea_offline_ocr/device/rois/definition/common.py b/src/arcaea_offline_ocr/device/rois/definition/common.py new file mode 100644 index 0000000..96512c4 --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/definition/common.py @@ -0,0 +1,15 @@ +from typing import Tuple + +Rect = Tuple[int, int, int, int] + + +class DeviceRois: + pure: Rect + far: Rect + lost: Rect + score: Rect + rating_class: Rect + max_recall: Rect + jacket: Rect + clear_status: Rect + partner_icon: Rect diff --git a/src/arcaea_offline_ocr/device/v1/__init__.py b/src/arcaea_offline_ocr/device/rois/definition/custom.py similarity index 100% rename from src/arcaea_offline_ocr/device/v1/__init__.py rename to src/arcaea_offline_ocr/device/rois/definition/custom.py diff --git a/src/arcaea_offline_ocr/device/rois/extractor/__init__.py b/src/arcaea_offline_ocr/device/rois/extractor/__init__.py new file mode 100644 index 0000000..1b6ae1d --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/extractor/__init__.py @@ -0,0 +1 @@ +from .common import DeviceRoisExtractor diff --git a/src/arcaea_offline_ocr/device/rois/extractor/common.py b/src/arcaea_offline_ocr/device/rois/extractor/common.py new file mode 100644 index 0000000..671ae2c --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/extractor/common.py @@ -0,0 +1,48 @@ +from ....crop import crop_xywh +from ....types import Mat +from ..definition.common import DeviceRois + + +class DeviceRoisExtractor: + def __init__(self, img: Mat, rois: DeviceRois): + self.img = img + self.sizes = rois + + def __construct_int_rect(self, rect): + return tuple(round(r) for r in rect) + + @property + def pure(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.pure)) + + @property + def far(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.far)) + + @property + def lost(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.lost)) + + @property + def score(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.score)) + + @property + def jacket(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.jacket)) + + @property + def rating_class(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.rating_class)) + + @property + def max_recall(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.max_recall)) + + @property + def clear_status(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.clear_status)) + + @property + def partner_icon(self): + return crop_xywh(self.img, self.__construct_int_rect(self.sizes.partner_icon)) diff --git a/src/arcaea_offline_ocr/device/rois/masker/__init__.py b/src/arcaea_offline_ocr/device/rois/masker/__init__.py new file mode 100644 index 0000000..ced796d --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/masker/__init__.py @@ -0,0 +1,2 @@ +from .auto import * +from .common import DeviceRoisMasker diff --git a/src/arcaea_offline_ocr/device/rois/masker/auto.py b/src/arcaea_offline_ocr/device/rois/masker/auto.py new file mode 100644 index 0000000..ec92548 --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/masker/auto.py @@ -0,0 +1,255 @@ +import cv2 +import numpy as np + +from ....types import Mat +from .common import DeviceRoisMasker + + +class DeviceRoisMaskerAuto(DeviceRoisMasker): + ... + + +class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto): + GRAY_BGR_MIN = np.array([50] * 3, np.uint8) + GRAY_BGR_MAX = np.array([160] * 3, np.uint8) + + WHITE_HSV_MIN = np.array([0, 0, 240], np.uint8) + WHITE_HSV_MAX = np.array([179, 10, 255], np.uint8) + + PST_HSV_MIN = np.array([100, 50, 80], np.uint8) + PST_HSV_MAX = np.array([100, 255, 255], np.uint8) + + PRS_HSV_MIN = np.array([43, 40, 75], np.uint8) + PRS_HSV_MAX = np.array([50, 155, 190], np.uint8) + + FTR_HSV_MIN = np.array([149, 30, 0], np.uint8) + FTR_HSV_MAX = np.array([155, 181, 150], np.uint8) + + BYD_HSV_MIN = np.array([170, 50, 50], np.uint8) + BYD_HSV_MAX = np.array([179, 210, 198], np.uint8) + + TRACK_LOST_HSV_MIN = np.array([170, 75, 90], np.uint8) + TRACK_LOST_HSV_MAX = np.array([175, 170, 160], np.uint8) + + TRACK_COMPLETE_HSV_MIN = np.array([140, 0, 50], np.uint8) + TRACK_COMPLETE_HSV_MAX = np.array([145, 50, 130], np.uint8) + + FULL_RECALL_HSV_MIN = np.array([140, 60, 80], np.uint8) + FULL_RECALL_HSV_MAX = np.array([150, 130, 145], np.uint8) + + PURE_MEMORY_HSV_MIN = np.array([90, 70, 80], np.uint8) + PURE_MEMORY_HSV_MAX = np.array([110, 200, 175], np.uint8) + + @classmethod + def gray(cls, roi_bgr: Mat) -> Mat: + bgr_value_equal_mask = np.max(roi_bgr, axis=2) - np.min(roi_bgr, axis=2) <= 5 + img_bgr = roi_bgr.copy() + img_bgr[~bgr_value_equal_mask] = np.array([0, 0, 0], roi_bgr.dtype) + img_bgr = cv2.erode(img_bgr, cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))) + img_bgr = cv2.dilate(img_bgr, cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))) + return cv2.inRange(img_bgr, cls.GRAY_BGR_MIN, cls.GRAY_BGR_MAX) + + @classmethod + def pure(cls, roi_bgr: Mat) -> Mat: + return cls.gray(roi_bgr) + + @classmethod + def far(cls, roi_bgr: Mat) -> Mat: + return cls.gray(roi_bgr) + + @classmethod + def lost(cls, roi_bgr: Mat) -> Mat: + return cls.gray(roi_bgr) + + @classmethod + def score(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.WHITE_HSV_MIN, + cls.WHITE_HSV_MAX, + ) + + @classmethod + def rating_class_pst(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PST_HSV_MIN, cls.PST_HSV_MAX + ) + + @classmethod + def rating_class_prs(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PRS_HSV_MIN, cls.PRS_HSV_MAX + ) + + @classmethod + def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.FTR_HSV_MIN, cls.FTR_HSV_MAX + ) + + @classmethod + def rating_class_byd(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.BYD_HSV_MIN, cls.BYD_HSV_MAX + ) + + @classmethod + def max_recall(cls, roi_bgr: Mat) -> Mat: + return cls.gray(roi_bgr) + + @classmethod + def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.TRACK_LOST_HSV_MIN, + cls.TRACK_LOST_HSV_MAX, + ) + + @classmethod + def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.TRACK_COMPLETE_HSV_MIN, + cls.TRACK_COMPLETE_HSV_MAX, + ) + + @classmethod + def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.FULL_RECALL_HSV_MIN, + cls.FULL_RECALL_HSV_MAX, + ) + + @classmethod + def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.PURE_MEMORY_HSV_MIN, + cls.PURE_MEMORY_HSV_MAX, + ) + + +class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): + PFL_HSV_MIN = np.array([0, 0, 248], np.uint8) + PFL_HSV_MAX = np.array([179, 10, 255], np.uint8) + + WHITE_HSV_MIN = np.array([0, 0, 240], np.uint8) + WHITE_HSV_MAX = np.array([179, 10, 255], np.uint8) + + PST_HSV_MIN = np.array([100, 50, 80], np.uint8) + PST_HSV_MAX = np.array([100, 255, 255], np.uint8) + + PRS_HSV_MIN = np.array([43, 40, 75], np.uint8) + PRS_HSV_MAX = np.array([50, 155, 190], np.uint8) + + FTR_HSV_MIN = np.array([149, 30, 0], np.uint8) + FTR_HSV_MAX = np.array([155, 181, 150], np.uint8) + + BYD_HSV_MIN = np.array([170, 50, 50], np.uint8) + BYD_HSV_MAX = np.array([179, 210, 198], np.uint8) + + MAX_RECALL_HSV_MIN = np.array([125, 0, 0], np.uint8) + MAX_RECALL_HSV_MAX = np.array([145, 100, 150], np.uint8) + + TRACK_LOST_HSV_MIN = np.array([170, 75, 90], np.uint8) + TRACK_LOST_HSV_MAX = np.array([175, 170, 160], np.uint8) + + TRACK_COMPLETE_HSV_MIN = np.array([140, 0, 50], np.uint8) + TRACK_COMPLETE_HSV_MAX = np.array([145, 50, 130], np.uint8) + + FULL_RECALL_HSV_MIN = np.array([140, 60, 80], np.uint8) + FULL_RECALL_HSV_MAX = np.array([150, 130, 145], np.uint8) + + PURE_MEMORY_HSV_MIN = np.array([90, 70, 80], np.uint8) + PURE_MEMORY_HSV_MAX = np.array([110, 200, 175], np.uint8) + + @classmethod + def pfl(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PFL_HSV_MIN, cls.PFL_HSV_MAX + ) + + @classmethod + def pure(cls, roi_bgr: Mat) -> Mat: + return cls.pfl(roi_bgr) + + @classmethod + def far(cls, roi_bgr: Mat) -> Mat: + return cls.pfl(roi_bgr) + + @classmethod + def lost(cls, roi_bgr: Mat) -> Mat: + return cls.pfl(roi_bgr) + + @classmethod + def score(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.WHITE_HSV_MIN, + cls.WHITE_HSV_MAX, + ) + + @classmethod + def rating_class_pst(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PST_HSV_MIN, cls.PST_HSV_MAX + ) + + @classmethod + def rating_class_prs(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PRS_HSV_MIN, cls.PRS_HSV_MAX + ) + + @classmethod + def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.FTR_HSV_MIN, cls.FTR_HSV_MAX + ) + + @classmethod + def rating_class_byd(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.BYD_HSV_MIN, cls.BYD_HSV_MAX + ) + + @classmethod + def max_recall(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.MAX_RECALL_HSV_MIN, + cls.MAX_RECALL_HSV_MAX, + ) + + @classmethod + def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.TRACK_LOST_HSV_MIN, + cls.TRACK_LOST_HSV_MAX, + ) + + @classmethod + def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.TRACK_COMPLETE_HSV_MIN, + cls.TRACK_COMPLETE_HSV_MAX, + ) + + @classmethod + def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.FULL_RECALL_HSV_MIN, + cls.FULL_RECALL_HSV_MAX, + ) + + @classmethod + def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: + return cv2.inRange( + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + cls.PURE_MEMORY_HSV_MIN, + cls.PURE_MEMORY_HSV_MAX, + ) diff --git a/src/arcaea_offline_ocr/device/rois/masker/common.py b/src/arcaea_offline_ocr/device/rois/masker/common.py new file mode 100644 index 0000000..cb1223f --- /dev/null +++ b/src/arcaea_offline_ocr/device/rois/masker/common.py @@ -0,0 +1,55 @@ +from ....types import Mat + + +class DeviceRoisMasker: + @classmethod + def pure(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def far(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def lost(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def score(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def rating_class_pst(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def rating_class_prs(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def rating_class_byd(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def max_recall(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() + + @classmethod + def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: + raise NotImplementedError() diff --git a/src/arcaea_offline_ocr/device/shared.py b/src/arcaea_offline_ocr/device/shared.py deleted file mode 100644 index 5a48d37..0000000 --- a/src/arcaea_offline_ocr/device/shared.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Optional - -import attrs - - -@attrs.define -class DeviceOcrResult: - rating_class: int - pure: int - far: int - lost: int - score: int - max_recall: int - song_id: Optional[str] = None - title: Optional[str] = None - clear_type: Optional[str] = None diff --git a/src/arcaea_offline_ocr/device/v1/crop.py b/src/arcaea_offline_ocr/device/v1/crop.py deleted file mode 100644 index 2f7d6a2..0000000 --- a/src/arcaea_offline_ocr/device/v1/crop.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Tuple - -from ...types import Mat -from .definition import DeviceV1 - -__all__ = [ - "crop_img", - "crop_from_device_attr", - "crop_to_pure", - "crop_to_far", - "crop_to_lost", - "crop_to_max_recall", - "crop_to_rating_class", - "crop_to_score", - "crop_to_title", -] - - -def crop_img(img: Mat, *, top: int, left: int, bottom: int, right: int): - return img[top:bottom, left:right] - - -def crop_from_device_attr(img: Mat, rect: Tuple[int, int, int, int]): - x, y, w, h = rect - return crop_img(img, top=y, left=x, bottom=y + h, right=x + w) - - -def crop_to_pure(screenshot: Mat, device: DeviceV1): - return crop_from_device_attr(screenshot, device.pure) - - -def crop_to_far(screenshot: Mat, device: DeviceV1): - return crop_from_device_attr(screenshot, device.far) - - -def crop_to_lost(screenshot: Mat, device: DeviceV1): - return crop_from_device_attr(screenshot, device.lost) - - -def crop_to_max_recall(screenshot: Mat, device: DeviceV1): - return crop_from_device_attr(screenshot, device.max_recall) - - -def crop_to_rating_class(screenshot: Mat, device: DeviceV1): - return crop_from_device_attr(screenshot, device.rating_class) - - -def crop_to_score(screenshot: Mat, device: DeviceV1): - return crop_from_device_attr(screenshot, device.score) - - -def crop_to_title(screenshot: Mat, device: DeviceV1): - return crop_from_device_attr(screenshot, device.title) diff --git a/src/arcaea_offline_ocr/device/v1/definition.py b/src/arcaea_offline_ocr/device/v1/definition.py deleted file mode 100644 index 51ca29c..0000000 --- a/src/arcaea_offline_ocr/device/v1/definition.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, Tuple - -__all__ = ["DeviceV1"] - - -@dataclass(kw_only=True) -class DeviceV1: - version: int - uuid: str - name: str - pure: Tuple[int, int, int, int] - far: Tuple[int, int, int, int] - lost: Tuple[int, int, int, int] - max_recall: Tuple[int, int, int, int] - rating_class: Tuple[int, int, int, int] - score: Tuple[int, int, int, int] - title: Tuple[int, int, int, int] - - @classmethod - def from_json_object(cls, json_dict: Dict[str, Any]): - if json_dict["version"] == 1: - return cls( - version=1, - uuid=json_dict["uuid"], - name=json_dict["name"], - pure=json_dict["pure"], - far=json_dict["far"], - lost=json_dict["lost"], - max_recall=json_dict["max_recall"], - rating_class=json_dict["rating_class"], - score=json_dict["score"], - title=json_dict["title"], - ) - - def repr_info(self): - return f"Device(version={self.version}, uuid={repr(self.uuid)}, name={repr(self.name)})" diff --git a/src/arcaea_offline_ocr/device/v1/ocr.py b/src/arcaea_offline_ocr/device/v1/ocr.py deleted file mode 100644 index 2e3a4b1..0000000 --- a/src/arcaea_offline_ocr/device/v1/ocr.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import List - -import cv2 - -from ...crop import crop_xywh -from ...mask import mask_gray, mask_white -from ...ocr import ocr_digits_by_contour_knn, ocr_rating_class -from ...types import Mat, cv2_ml_KNearest -from ..shared import DeviceOcrResult -from .crop import * -from .definition import DeviceV1 - - -class DeviceV1Ocr: - def __init__(self, device: DeviceV1, knn_model: cv2_ml_KNearest): - self.__device = device - self.__knn_model = knn_model - - @property - def device(self): - return self.__device - - @device.setter - def device(self, value): - self.__device = value - - @property - def knn_model(self): - return self.__knn_model - - @knn_model.setter - def knn_model(self, value): - self.__knn_model = value - - def preprocess_score_roi(self, __roi_gray: Mat) -> List[Mat]: - roi_gray = __roi_gray.copy() - contours, _ = cv2.findContours( - roi_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE - ) - for contour in contours: - rect = cv2.boundingRect(contour) - if rect[3] > roi_gray.shape[0] * 0.6: - continue - roi_gray = cv2.fillPoly(roi_gray, [contour], 0) - return roi_gray - - def ocr(self, img_bgr: Mat): - rating_class_roi = crop_to_rating_class(img_bgr, self.device) - rating_class = ocr_rating_class(rating_class_roi) - - pfl_mr_roi = [ - crop_to_pure(img_bgr, self.device), - crop_to_far(img_bgr, self.device), - crop_to_lost(img_bgr, self.device), - crop_to_max_recall(img_bgr, self.device), - ] - pfl_mr_roi = [mask_gray(roi) for roi in pfl_mr_roi] - - pure, far, lost = [ - ocr_digits_by_contour_knn(roi, self.knn_model) for roi in pfl_mr_roi[:3] - ] - - max_recall_contours, _ = cv2.findContours( - pfl_mr_roi[3], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE - ) - max_recall_rects = [cv2.boundingRect(c) for c in max_recall_contours] - max_recall_rect = sorted(max_recall_rects, key=lambda r: r[0])[-1] - max_recall_roi = crop_xywh(img_bgr, max_recall_rect) - max_recall = ocr_digits_by_contour_knn(max_recall_roi, self.knn_model) - - score_roi = crop_to_score(img_bgr, self.device) - score_roi = mask_white(score_roi) - score_roi = self.preprocess_score_roi(score_roi) - score = ocr_digits_by_contour_knn(score_roi, self.knn_model) - - return DeviceOcrResult( - song_id=None, - title=None, - rating_class=rating_class, - pure=pure, - far=far, - lost=lost, - score=score, - max_recall=max_recall, - clear_type=None, - ) diff --git a/src/arcaea_offline_ocr/device/v2/__init__.py b/src/arcaea_offline_ocr/device/v2/__init__.py deleted file mode 100644 index 64db9c3..0000000 --- a/src/arcaea_offline_ocr/device/v2/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .definition import DeviceV2 -from .ocr import DeviceV2Ocr -from .rois import DeviceV2AutoRois, DeviceV2Rois -from .shared import MAX_RECALL_CLOSE_KERNEL diff --git a/src/arcaea_offline_ocr/device/v2/definition.py b/src/arcaea_offline_ocr/device/v2/definition.py deleted file mode 100644 index 31dd17a..0000000 --- a/src/arcaea_offline_ocr/device/v2/definition.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Iterable - -from attrs import define, field - -from ...types import XYWHRect - - -def iterable_to_xywh_rect(__iter: Iterable) -> XYWHRect: - return XYWHRect(*__iter) - - -@define(kw_only=True) -class DeviceV2: - version = field(type=int) - uuid = field(type=str) - name = field(type=str) - crop_black_edges = field(type=bool) - factor = field(type=float) - pure = field(converter=iterable_to_xywh_rect, default=[0, 0, 0, 0]) - far = field(converter=iterable_to_xywh_rect, default=[0, 0, 0, 0]) - lost = field(converter=iterable_to_xywh_rect, default=[0, 0, 0, 0]) - score = field(converter=iterable_to_xywh_rect, default=[0, 0, 0, 0]) - max_recall_rating_class = field( - converter=iterable_to_xywh_rect, default=[0, 0, 0, 0] - ) - title = field(converter=iterable_to_xywh_rect, default=[0, 0, 0, 0]) diff --git a/src/arcaea_offline_ocr/device/v2/ocr.py b/src/arcaea_offline_ocr/device/v2/ocr.py deleted file mode 100644 index 1dab23d..0000000 --- a/src/arcaea_offline_ocr/device/v2/ocr.py +++ /dev/null @@ -1,172 +0,0 @@ -import math -from functools import lru_cache -from typing import Sequence - -import cv2 -import numpy as np -from PIL import Image - -from ...crop import crop_xywh -from ...mask import ( - mask_byd, - mask_ftr, - mask_gray, - mask_max_recall_purple, - mask_pfl_white, - mask_prs, - mask_pst, - mask_white, -) -from ...ocr import ( - FixRects, - ocr_digit_samples_knn, - ocr_digits_by_contour_knn, - preprocess_hog, - resize_fill_square, -) -from ...phash_db import ImagePHashDatabase -from ...sift_db import SIFTDatabase -from ...types import Mat, cv2_ml_KNearest -from ..shared import DeviceOcrResult -from .preprocess import find_digits_preprocess -from .rois import DeviceV2Rois -from .shared import MAX_RECALL_CLOSE_KERNEL -from .sizes import SizesV2 - - -class DeviceV2Ocr: - def __init__(self, knn_model: cv2_ml_KNearest, phash_db: ImagePHashDatabase): - self.__knn_model = knn_model - self.__phash_db = phash_db - - @property - def knn_model(self): - if not self.__knn_model: - raise ValueError("`knn_model` unset.") - return self.__knn_model - - @knn_model.setter - def knn_model(self, value: cv2_ml_KNearest): - self.__knn_model = value - - @property - def phash_db(self): - if not self.__phash_db: - raise ValueError("`phash_db` unset.") - return self.__phash_db - - @phash_db.setter - def phash_db(self, value: SIFTDatabase): - self.__phash_db = value - - @lru_cache - def _get_digit_widths(self, num_list: Sequence[int], factor: float): - widths = set() - for n in num_list: - lower = math.floor(n * factor) - upper = math.ceil(n * factor) - widths.update(range(lower, upper + 1)) - return widths - - def _base_ocr_pfl(self, roi_masked: Mat, factor: float = 1.0): - contours, _ = cv2.findContours( - roi_masked, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE - ) - filtered_contours = [c for c in contours if cv2.contourArea(c) >= 5 * factor] - rects = [cv2.boundingRect(c) for c in filtered_contours] - rects = FixRects.connect_broken(rects, roi_masked.shape[1], roi_masked.shape[0]) - rect_contour_map = dict(zip(rects, filtered_contours)) - - filtered_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor] - filtered_rects = FixRects.split_connected(roi_masked, filtered_rects) - filtered_rects = sorted(filtered_rects, key=lambda r: r[0]) - - roi_ocr = roi_masked.copy() - filtered_contours_flattened = {tuple(c.flatten()) for c in filtered_contours} - for contour in contours: - if tuple(contour.flatten()) in filtered_contours_flattened: - continue - roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0]) - digit_rois = [ - resize_fill_square(crop_xywh(roi_ocr, r), 20) - for r in sorted(filtered_rects, key=lambda r: r[0]) - ] - # [cv2.imshow(f"r{i}", r) for i, r in enumerate(digit_rois)] - # cv2.waitKey(0) - samples = preprocess_hog(digit_rois) - return ocr_digit_samples_knn(samples, self.knn_model) - - def ocr_song_id(self, rois: DeviceV2Rois): - jacket = cv2.cvtColor(rois.jacket, cv2.COLOR_BGR2GRAY) - return self.phash_db.lookup_image(Image.fromarray(jacket))[0] - - def ocr_rating_class(self, rois: DeviceV2Rois): - roi = cv2.cvtColor(rois.max_recall_rating_class, cv2.COLOR_BGR2HSV) - results = [mask_pst(roi), mask_prs(roi), mask_ftr(roi), mask_byd(roi)] - return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] - - def ocr_score(self, rois: DeviceV2Rois): - roi = cv2.cvtColor(rois.score, cv2.COLOR_BGR2HSV) - roi = mask_white(roi) - contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - for contour in contours: - x, y, w, h = cv2.boundingRect(contour) - if h < roi.shape[0] * 0.6: - roi = cv2.fillPoly(roi, [contour], [0]) - return ocr_digits_by_contour_knn(roi, self.knn_model) - - def mask_pfl(self, pfl_roi: Mat, rois: DeviceV2Rois): - return ( - mask_pfl_white(cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV)) - if isinstance(rois.sizes, SizesV2) - else mask_gray(pfl_roi) - ) - - def ocr_pure(self, rois: DeviceV2Rois): - roi = self.mask_pfl(rois.pure, rois) - return self._base_ocr_pfl(roi, rois.sizes.factor) - - def ocr_far(self, rois: DeviceV2Rois): - roi = self.mask_pfl(rois.far, rois) - return self._base_ocr_pfl(roi, rois.sizes.factor) - - def ocr_lost(self, rois: DeviceV2Rois): - roi = self.mask_pfl(rois.lost, rois) - return self._base_ocr_pfl(roi, rois.sizes.factor) - - def ocr_max_recall(self, rois: DeviceV2Rois): - roi = ( - mask_max_recall_purple( - cv2.cvtColor(rois.max_recall_rating_class, cv2.COLOR_BGR2HSV) - ) - if isinstance(rois.sizes, SizesV2) - else mask_gray(rois.max_recall_rating_class) - ) - roi_closed = cv2.morphologyEx(roi, cv2.MORPH_CLOSE, MAX_RECALL_CLOSE_KERNEL) - contours, _ = cv2.findContours( - roi_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE - ) - rects = [cv2.boundingRect(c) for c in contours] - rects = [r for r in rects if r[2] > 5 and r[3] > 5] - rects = sorted(rects, key=lambda r: r[0], reverse=True) - max_recall_roi = crop_xywh(roi, rects[0]) - return ocr_digits_by_contour_knn(max_recall_roi, self.knn_model) - - def ocr(self, rois: DeviceV2Rois): - song_id = self.ocr_song_id(rois) - rating_class = self.ocr_rating_class(rois) - score = self.ocr_score(rois) - pure = self.ocr_pure(rois) - far = self.ocr_far(rois) - lost = self.ocr_lost(rois) - max_recall = self.ocr_max_recall(rois) - - return DeviceOcrResult( - rating_class=rating_class, - pure=pure, - far=far, - lost=lost, - score=score, - max_recall=max_recall, - song_id=song_id, - ) diff --git a/src/arcaea_offline_ocr/device/v2/preprocess.py b/src/arcaea_offline_ocr/device/v2/preprocess.py deleted file mode 100644 index deef8d5..0000000 --- a/src/arcaea_offline_ocr/device/v2/preprocess.py +++ /dev/null @@ -1,54 +0,0 @@ -import cv2 - -from ...types import Mat -from .shared import * - - -def find_digits_preprocess(__img_masked: Mat) -> Mat: - img = __img_masked.copy() - img_denoised = cv2.morphologyEx(img, cv2.MORPH_OPEN, PFL_DENOISE_KERNEL) - # img_denoised = cv2.bitwise_and(img, img_denoised) - - denoise_contours, _ = cv2.findContours( - img_denoised, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE - ) - # cv2.drawContours(img_denoised, contours, -1, [128], 2) - - # fill all contour.area < max(contour.area) * ratio with black pixels - # for denoise purposes - - # define threshold contour area - # we assume the smallest digit "1", is 80% height of the image, - # and at least 1.5 pixel wide, considering cv2.contourArea always - # returns a smaller value than the actual contour area. - max_contour_area = __img_masked.shape[0] * 0.8 * 1.5 - filtered_contours = list( - filter(lambda c: cv2.contourArea(c) >= max_contour_area, denoise_contours) - ) - - filtered_contours_flattened = {tuple(c.flatten()) for c in filtered_contours} - - for contour in denoise_contours: - if tuple(contour.flatten()) not in filtered_contours_flattened: - img_denoised = cv2.fillPoly(img_denoised, [contour], [0]) - - # old algorithm, finding the largest contour area - ## contour_area_tuples = [(contour, cv2.contourArea(contour)) for contour in contours] - ## contour_area_tuples = sorted( - ## contour_area_tuples, key=lambda item: item[1], reverse=True - ## ) - ## max_contour_area = contour_area_tuples[0][1] - ## print(max_contour_area, [item[1] for item in contour_area_tuples]) - ## contours_filter_end_index = len(contours) - ## for i, item in enumerate(contour_area_tuples): - ## contour, area = item - ## if area < max_contour_area * 0.15: - ## contours_filter_end_index = i - ## break - ## contours = [item[0] for item in contour_area_tuples] - ## for contour in contours[-contours_filter_end_index - 1:]: - ## img = cv2.fillPoly(img, [contour], [0]) - ## img_denoised = cv2.fillPoly(img_denoised, [contour], [0]) - ## contours = contours[:contours_filter_end_index] - - return img_denoised diff --git a/src/arcaea_offline_ocr/device/v2/rois.py b/src/arcaea_offline_ocr/device/v2/rois.py deleted file mode 100644 index 100aece..0000000 --- a/src/arcaea_offline_ocr/device/v2/rois.py +++ /dev/null @@ -1,199 +0,0 @@ -from typing import Union - -from ...crop import crop_black_edges, crop_xywh -from ...types import Mat, XYWHRect -from .definition import DeviceV2 -from .sizes import Sizes, SizesV1 - - -def to_int(num: Union[int, float]) -> int: - return round(num) - - -class DeviceV2Rois: - def __init__(self, device: DeviceV2, img: Mat): - self.device = device - self.sizes = SizesV1(self.device.factor) - self.__img = img - - @staticmethod - def construct_int_xywh_rect(x, y, w, h) -> XYWHRect: - return XYWHRect(*[to_int(item) for item in [x, y, w, h]]) - - @property - def img(self): - return self.__img - - @img.setter - def img(self, img: Mat): - self.__img = ( - crop_black_edges(img) if self.device.crop_black_edges else img.copy() - ) - - @property - def h(self): - return self.img.shape[0] - - @property - def vmid(self): - return self.h / 2 - - @property - def w(self): - return self.img.shape[1] - - @property - def hmid(self): - return self.w / 2 - - @property - def h_without_top_bar(self): - """img_height -= top_bar_height""" - return self.h - self.sizes.TOP_BAR_HEIGHT - - @property - def h_without_top_bar_mid(self): - return self.sizes.TOP_BAR_HEIGHT + self.h_without_top_bar / 2 - - @property - def pfl_top(self): - return self.h_without_top_bar_mid + self.sizes.PFL_TOP_FROM_VMID - - @property - def pfl_left(self): - return self.hmid + self.sizes.PFL_LEFT_FROM_HMID - - @property - def pure_rect(self): - return self.construct_int_xywh_rect( - x=self.pfl_left, - y=self.pfl_top, - w=self.sizes.PFL_WIDTH, - h=self.sizes.PFL_FONT_PX, - ) - - @property - def pure(self): - return crop_xywh(self.img, self.pure_rect) - - @property - def far_rect(self): - return self.construct_int_xywh_rect( - x=self.pfl_left, - y=self.pfl_top + self.sizes.PFL_FONT_PX + self.sizes.PURE_FAR_GAP, - w=self.sizes.PFL_WIDTH, - h=self.sizes.PFL_FONT_PX, - ) - - @property - def far(self): - return crop_xywh(self.img, self.far_rect) - - @property - def lost_rect(self): - return self.construct_int_xywh_rect( - x=self.pfl_left, - y=( - self.pfl_top - + self.sizes.PFL_FONT_PX * 2 - + self.sizes.PURE_FAR_GAP - + self.sizes.FAR_LOST_GAP - ), - w=self.sizes.PFL_WIDTH, - h=self.sizes.PFL_FONT_PX, - ) - - @property - def lost(self): - return crop_xywh(self.img, self.lost_rect) - - @property - def score_rect(self): - return self.construct_int_xywh_rect( - x=self.hmid - (self.sizes.SCORE_WIDTH / 2), - y=( - self.h_without_top_bar_mid - + self.sizes.SCORE_BOTTOM_FROM_VMID - - self.sizes.SCORE_FONT_PX - ), - w=self.sizes.SCORE_WIDTH, - h=self.sizes.SCORE_FONT_PX, - ) - - @property - def score(self): - return crop_xywh(self.img, self.score_rect) - - @property - def max_recall_rating_class_rect(self): - x = ( - self.hmid - + self.sizes.JACKET_RIGHT_FROM_HOR_MID - - self.sizes.JACKET_WIDTH - - 25 * self.sizes.factor - ) - return self.construct_int_xywh_rect( - x=x, - y=( - self.h_without_top_bar_mid - - self.sizes.SCORE_PANEL[1] / 2 - - self.sizes.MR_RT_HEIGHT - ), - w=self.sizes.MR_RT_WIDTH, - h=self.sizes.MR_RT_HEIGHT, - ) - - @property - def max_recall_rating_class(self): - return crop_xywh(self.img, self.max_recall_rating_class_rect) - - @property - def title_rect(self): - return self.construct_int_xywh_rect( - x=0, - y=self.h_without_top_bar_mid - + self.sizes.TITLE_BOTTOM_FROM_VMID - - self.sizes.TITLE_FONT_PX, - w=self.hmid + self.sizes.TITLE_WIDTH_RIGHT, - h=self.sizes.TITLE_FONT_PX, - ) - - @property - def title(self): - return crop_xywh(self.img, self.title_rect) - - @property - def jacket_rect(self): - return self.construct_int_xywh_rect( - x=self.hmid - + self.sizes.JACKET_RIGHT_FROM_HOR_MID - - self.sizes.JACKET_WIDTH, - y=self.h_without_top_bar_mid - self.sizes.SCORE_PANEL[1] / 2, - w=self.sizes.JACKET_WIDTH, - h=self.sizes.JACKET_WIDTH, - ) - - @property - def jacket(self): - return crop_xywh(self.img, self.jacket_rect) - - -class DeviceV2AutoRois(DeviceV2Rois): - @staticmethod - def get_factor(width: int, height: int): - ratio = width / height - return ((width / 16) * 9) / 720 if ratio < (16 / 9) else height / 720 - - def __init__(self, img: Mat): - factor = self.get_factor(img.shape[1], img.shape[0]) - self.sizes = SizesV1(factor) - self.__img = None - self.img = img - - @property - def img(self): - return self.__img - - @img.setter - def img(self, img: Mat): - self.__img = crop_black_edges(img) diff --git a/src/arcaea_offline_ocr/device/v2/shared.py b/src/arcaea_offline_ocr/device/v2/shared.py deleted file mode 100644 index ca511b1..0000000 --- a/src/arcaea_offline_ocr/device/v2/shared.py +++ /dev/null @@ -1,9 +0,0 @@ -from cv2 import MORPH_RECT, getStructuringElement - -PFL_DENOISE_KERNEL = getStructuringElement(MORPH_RECT, [2, 2]) -PFL_ERODE_KERNEL = getStructuringElement(MORPH_RECT, [3, 3]) -PFL_CLOSE_HORIZONTAL_KERNEL = getStructuringElement(MORPH_RECT, [10, 1]) - -MAX_RECALL_DENOISE_KERNEL = getStructuringElement(MORPH_RECT, [3, 3]) -MAX_RECALL_ERODE_KERNEL = getStructuringElement(MORPH_RECT, [2, 2]) -MAX_RECALL_CLOSE_KERNEL = getStructuringElement(MORPH_RECT, [20, 1]) diff --git a/src/arcaea_offline_ocr/device/v2/sizes.py b/src/arcaea_offline_ocr/device/v2/sizes.py deleted file mode 100644 index 3347cb2..0000000 --- a/src/arcaea_offline_ocr/device/v2/sizes.py +++ /dev/null @@ -1,254 +0,0 @@ -from typing import Tuple, Union - - -def apply_factor(num: Union[int, float], factor: float): - return num * factor - - -class Sizes: - def __init__(self, factor: float): - raise NotImplementedError() - - @property - def TOP_BAR_HEIGHT(self): - ... - - @property - def SCORE_PANEL(self) -> Tuple[int, int]: - ... - - @property - def PFL_TOP_FROM_VMID(self): - ... - - @property - def PFL_LEFT_FROM_HMID(self): - ... - - @property - def PFL_WIDTH(self): - ... - - @property - def PFL_FONT_PX(self): - ... - - @property - def PURE_FAR_GAP(self): - ... - - @property - def FAR_LOST_GAP(self): - ... - - @property - def SCORE_BOTTOM_FROM_VMID(self): - ... - - @property - def SCORE_FONT_PX(self): - ... - - @property - def SCORE_WIDTH(self): - ... - - @property - def JACKET_RIGHT_FROM_HOR_MID(self): - ... - - @property - def JACKET_WIDTH(self): - ... - - @property - def MR_RT_RIGHT_FROM_HMID(self): - ... - - @property - def MR_RT_WIDTH(self): - ... - - @property - def MR_RT_HEIGHT(self): - ... - - @property - def TITLE_BOTTOM_FROM_VMID(self): - ... - - @property - def TITLE_FONT_PX(self): - ... - - @property - def TITLE_WIDTH_RIGHT(self): - ... - - -class SizesV1(Sizes): - def __init__(self, factor: float): - self.factor = factor - - def apply_factor(self, num): - return apply_factor(num, self.factor) - - @property - def TOP_BAR_HEIGHT(self): - return self.apply_factor(50) - - @property - def SCORE_PANEL(self) -> Tuple[int, int]: - return tuple(self.apply_factor(num) for num in [485, 239]) - - @property - def PFL_TOP_FROM_VMID(self): - return self.apply_factor(135) - - @property - def PFL_LEFT_FROM_HMID(self): - return self.apply_factor(5) - - @property - def PFL_WIDTH(self): - return self.apply_factor(76) - - @property - def PFL_FONT_PX(self): - return self.apply_factor(26) - - @property - def PURE_FAR_GAP(self): - return self.apply_factor(12) - - @property - def FAR_LOST_GAP(self): - return self.apply_factor(10) - - @property - def SCORE_BOTTOM_FROM_VMID(self): - return self.apply_factor(-50) - - @property - def SCORE_FONT_PX(self): - return self.apply_factor(45) - - @property - def SCORE_WIDTH(self): - return self.apply_factor(280) - - @property - def JACKET_RIGHT_FROM_HOR_MID(self): - return self.apply_factor(-235) - - @property - def JACKET_WIDTH(self): - return self.apply_factor(375) - - @property - def MR_RT_RIGHT_FROM_HMID(self): - return self.apply_factor(-300) - - @property - def MR_RT_WIDTH(self): - return self.apply_factor(275) - - @property - def MR_RT_HEIGHT(self): - return self.apply_factor(75) - - @property - def TITLE_BOTTOM_FROM_VMID(self): - return self.apply_factor(-265) - - @property - def TITLE_FONT_PX(self): - return self.apply_factor(40) - - @property - def TITLE_WIDTH_RIGHT(self): - return self.apply_factor(275) - - -class SizesV2(Sizes): - def __init__(self, factor: float): - self.factor = factor - - def apply_factor(self, num): - return apply_factor(num, self.factor) - - @property - def TOP_BAR_HEIGHT(self): - return self.apply_factor(50) - - @property - def SCORE_PANEL(self) -> Tuple[int, int]: - return tuple(self.apply_factor(num) for num in [447, 233]) - - @property - def PFL_TOP_FROM_VMID(self): - return self.apply_factor(142) - - @property - def PFL_LEFT_FROM_HMID(self): - return self.apply_factor(10) - - @property - def PFL_WIDTH(self): - return self.apply_factor(60) - - @property - def PFL_FONT_PX(self): - return self.apply_factor(16) - - @property - def PURE_FAR_GAP(self): - return self.apply_factor(20) - - @property - def FAR_LOST_GAP(self): - return self.apply_factor(23) - - @property - def SCORE_BOTTOM_FROM_VMID(self): - return self.apply_factor(-50) - - @property - def SCORE_FONT_PX(self): - return self.apply_factor(45) - - @property - def SCORE_WIDTH(self): - return self.apply_factor(280) - - @property - def JACKET_RIGHT_FROM_HOR_MID(self): - return self.apply_factor(-235) - - @property - def JACKET_WIDTH(self): - return self.apply_factor(375) - - @property - def MR_RT_RIGHT_FROM_HMID(self): - return self.apply_factor(-330) - - @property - def MR_RT_WIDTH(self): - return self.apply_factor(330) - - @property - def MR_RT_HEIGHT(self): - return self.apply_factor(75) - - @property - def TITLE_BOTTOM_FROM_VMID(self): - return self.apply_factor(-265) - - @property - def TITLE_FONT_PX(self): - return self.apply_factor(40) - - @property - def TITLE_WIDTH_RIGHT(self): - return self.apply_factor(275) diff --git a/src/arcaea_offline_ocr/mask.py b/src/arcaea_offline_ocr/mask.py deleted file mode 100644 index 29e8560..0000000 --- a/src/arcaea_offline_ocr/mask.py +++ /dev/null @@ -1,119 +0,0 @@ -import cv2 -import numpy as np - -from .types import Mat - -__all__ = [ - "GRAY_MIN_HSV", - "GRAY_MAX_HSV", - "WHITE_MIN_HSV", - "WHITE_MAX_HSV", - "PFL_WHITE_MIN_HSV", - "PFL_WHITE_MAX_HSV", - "PST_MIN_HSV", - "PST_MAX_HSV", - "PRS_MIN_HSV", - "PRS_MAX_HSV", - "FTR_MIN_HSV", - "FTR_MAX_HSV", - "BYD_MIN_HSV", - "BYD_MAX_HSV", - "MAX_RECALL_PURPLE_MIN_HSV", - "MAX_RECALL_PURPLE_MAX_HSV", - "mask_gray", - "mask_white", - "mask_pfl_white", - "mask_pst", - "mask_prs", - "mask_ftr", - "mask_byd", - "mask_rating_class", - "mask_max_recall_purple", -] - -GRAY_MIN_HSV = np.array([0, 0, 70], np.uint8) -GRAY_MAX_HSV = np.array([0, 0, 200], np.uint8) - -GRAY_MIN_BGR = np.array([50] * 3, np.uint8) -GRAY_MAX_BGR = np.array([160] * 3, np.uint8) - -WHITE_MIN_HSV = np.array([0, 0, 240], np.uint8) -WHITE_MAX_HSV = np.array([179, 10, 255], np.uint8) - -PFL_WHITE_MIN_HSV = np.array([0, 0, 248], np.uint8) -PFL_WHITE_MAX_HSV = np.array([179, 10, 255], np.uint8) - -PST_MIN_HSV = np.array([100, 50, 80], np.uint8) -PST_MAX_HSV = np.array([100, 255, 255], np.uint8) - -PRS_MIN_HSV = np.array([43, 40, 75], np.uint8) -PRS_MAX_HSV = np.array([50, 155, 190], np.uint8) - -FTR_MIN_HSV = np.array([149, 30, 0], np.uint8) -FTR_MAX_HSV = np.array([155, 181, 150], np.uint8) - -BYD_MIN_HSV = np.array([170, 50, 50], np.uint8) -BYD_MAX_HSV = np.array([179, 210, 198], np.uint8) - -MAX_RECALL_PURPLE_MIN_HSV = np.array([125, 0, 0], np.uint8) -MAX_RECALL_PURPLE_MAX_HSV = np.array([145, 100, 150], np.uint8) - - -def mask_gray(__img_bgr: Mat): - # bgr_value_equal_mask = all(__img_bgr[:, 1:] == __img_bgr[:, :-1], axis=1) - bgr_value_equal_mask = np.max(__img_bgr, axis=2) - np.min(__img_bgr, axis=2) <= 5 - img_bgr = __img_bgr.copy() - img_bgr[~bgr_value_equal_mask] = np.array([0, 0, 0], __img_bgr.dtype) - img_bgr = cv2.erode(img_bgr, cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))) - img_bgr = cv2.dilate(img_bgr, cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))) - return cv2.inRange(img_bgr, GRAY_MIN_BGR, GRAY_MAX_BGR) - - -def mask_white(img_hsv: Mat): - mask = cv2.inRange(img_hsv, WHITE_MIN_HSV, WHITE_MAX_HSV) - mask = cv2.dilate(mask, cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))) - return mask - - -def mask_pfl_white(img_hsv: Mat): - mask = cv2.inRange(img_hsv, PFL_WHITE_MIN_HSV, PFL_WHITE_MAX_HSV) - mask = cv2.dilate(mask, cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))) - return mask - - -def mask_pst(img_hsv: Mat): - mask = cv2.inRange(img_hsv, PST_MIN_HSV, PST_MAX_HSV) - mask = cv2.dilate(mask, (1, 1)) - return mask - - -def mask_prs(img_hsv: Mat): - mask = cv2.inRange(img_hsv, PRS_MIN_HSV, PRS_MAX_HSV) - mask = cv2.dilate(mask, (1, 1)) - return mask - - -def mask_ftr(img_hsv: Mat): - mask = cv2.inRange(img_hsv, FTR_MIN_HSV, FTR_MAX_HSV) - mask = cv2.dilate(mask, (1, 1)) - return mask - - -def mask_byd(img_hsv: Mat): - mask = cv2.inRange(img_hsv, BYD_MIN_HSV, BYD_MAX_HSV) - mask = cv2.dilate(mask, (2, 2)) - return mask - - -def mask_rating_class(img_hsv: Mat): - pst = mask_pst(img_hsv) - prs = mask_prs(img_hsv) - ftr = mask_ftr(img_hsv) - byd = mask_byd(img_hsv) - return cv2.bitwise_or(byd, cv2.bitwise_or(ftr, cv2.bitwise_or(pst, prs))) - - -def mask_max_recall_purple(img_hsv: Mat): - mask = cv2.inRange(img_hsv, MAX_RECALL_PURPLE_MIN_HSV, MAX_RECALL_PURPLE_MAX_HSV) - mask = cv2.dilate(mask, (2, 2)) - return mask diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/ocr.py index 1c9c36e..44ca73a 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/ocr.py @@ -1,14 +1,11 @@ import math -from copy import deepcopy from typing import Optional, Sequence, Tuple import cv2 import numpy as np -from numpy.linalg import norm from .crop import crop_xywh -from .mask import mask_byd, mask_ftr, mask_prs, mask_pst -from .types import Mat, cv2_ml_KNearest +from .types import Mat __all__ = [ "FixRects", @@ -65,8 +62,7 @@ class FixRects: new_h = new_bottom - new_y new_rects.append((new_x, new_y, new_w, new_h)) - return_rects = deepcopy(rects) - return_rects = [r for r in return_rects if r not in consumed_rects] + return_rects = [r for r in rects if r not in consumed_rects] return_rects.extend(new_rects) return return_rects @@ -81,42 +77,42 @@ class FixRects: new_rects = [] for rect in rects: rx, ry, rw, rh = rect - if rw / rh > rect_wh_ratio: - # consider this is a connected contour - connected_rects.append(rect) + if rw / rh <= rect_wh_ratio: + continue - # find the thinnest part - border_ignore = round(rw * width_range_ratio) - img_cropped = crop_xywh( - img_masked, - (border_ignore, ry, rw - border_ignore, rh), - ) - white_pixels = {} # dict[x, white_pixel_number] - for i in range(img_cropped.shape[1]): - col = img_cropped[:, i] - white_pixels[rx + border_ignore + i] = np.count_nonzero(col > 200) - least_white_pixels = min(v for v in white_pixels.values() if v > 0) - x_values = [ - x - for x, pixel in white_pixels.items() - if pixel == least_white_pixels - ] - # select only middle values - x_mean = np.mean(x_values) - x_std = np.std(x_values) - x_values = [ - x - for x in x_values - if x_mean - x_std * 1.5 <= x <= x_mean + x_std * 1.5 - ] - x_mid = round(np.median(x_values)) + connected_rects.append(rect) - # split the rect - new_rects.extend( - [(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)] - ) + # find the thinnest part + border_ignore = round(rw * width_range_ratio) + img_cropped = crop_xywh( + img_masked, + (border_ignore, ry, rw - border_ignore, rh), + ) + white_pixels = {} # dict[x, white_pixel_number] + for i in range(img_cropped.shape[1]): + col = img_cropped[:, i] + white_pixels[rx + border_ignore + i] = np.count_nonzero(col > 200) + + if all(v == 0 for v in white_pixels.values()): + return rects + + least_white_pixels = min(v for v in white_pixels.values() if v > 0) + x_values = [ + x for x, pixel in white_pixels.items() if pixel == least_white_pixels + ] + # select only middle values + x_mean = np.mean(x_values) + x_std = np.std(x_values) + x_values = [ + x for x in x_values if x_mean - x_std * 1.5 <= x <= x_mean + x_std * 1.5 + ] + x_mid = round(np.median(x_values)) + + # split the rect + new_rects.extend( + [(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)] + ) - return_rects = deepcopy(rects) return_rects = [r for r in rects if r not in connected_rects] return_rects.extend(new_rects) return return_rects @@ -145,33 +141,16 @@ def resize_fill_square(img: Mat, target: int = 20): def preprocess_hog(digit_rois): - # https://github.com/opencv/opencv/blob/f834736307c8328340aea48908484052170c9224/samples/python/digits.py + # https://learnopencv.com/handwritten-digits-classification-an-opencv-c-python-tutorial/ samples = [] for digit in digit_rois: - gx = cv2.Sobel(digit, cv2.CV_32F, 1, 0) - gy = cv2.Sobel(digit, cv2.CV_32F, 0, 1) - mag, ang = cv2.cartToPolar(gx, gy) - bin_n = 16 - _bin = np.int32(bin_n * ang / (2 * np.pi)) - bin_cells = _bin[:10, :10], _bin[10:, :10], _bin[:10, 10:], _bin[10:, 10:] - mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:] - hists = [ - np.bincount(b.ravel(), m.ravel(), bin_n) - for b, m in zip(bin_cells, mag_cells) - ] - hist = np.hstack(hists) - - # transform to Hellinger kernel - eps = 1e-7 - hist /= hist.sum() + eps - hist = np.sqrt(hist) - hist /= norm(hist) + eps - + hog = cv2.HOGDescriptor((20, 20), (10, 10), (5, 5), (10, 10), 9) + hist = hog.compute(digit) samples.append(hist) return np.float32(samples) -def ocr_digit_samples_knn(__samples, knn_model: cv2_ml_KNearest, k: int = 4): +def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4): _, results, _, _ = knn_model.findNearest(__samples, k) result_list = [int(r) for r in results.ravel()] result_str = "".join(str(r) for r in result_list if r > -1) @@ -192,20 +171,10 @@ def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int): def ocr_digits_by_contour_knn( __roi_gray: Mat, - knn_model: cv2_ml_KNearest, + knn_model: cv2.ml.KNearest, *, k=4, size: int = 20, ) -> int: samples = ocr_digits_by_contour_get_samples(__roi_gray, size) return ocr_digit_samples_knn(samples, knn_model, k) - - -def ocr_rating_class(roi_hsv: Mat): - mask_results = [ - mask_pst(roi_hsv), - mask_prs(roi_hsv), - mask_ftr(roi_hsv), - mask_byd(roi_hsv), - ] - return max(enumerate(mask_results), key=lambda e: np.count_nonzero(e[1]))[0] diff --git a/src/arcaea_offline_ocr/phash_db.py b/src/arcaea_offline_ocr/phash_db.py index 6bbcd5b..dba7c04 100644 --- a/src/arcaea_offline_ocr/phash_db.py +++ b/src/arcaea_offline_ocr/phash_db.py @@ -1,8 +1,34 @@ import sqlite3 +from typing import List, Union -import imagehash +import cv2 import numpy as np -from PIL import Image + +from .types import Mat + + +def phash_opencv(img_gray, hash_size=8, highfreq_factor=4): + # type: (Union[Mat, np.ndarray], int, int) -> np.ndarray + """ + Perceptual Hash computation. + + Implementation follows http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html + + Adapted from `imagehash.phash`, pure opencv implementation + + The result is slightly different from `imagehash.phash`. + """ + if hash_size < 2: + raise ValueError("Hash size must be greater than or equal to 2") + + img_size = hash_size * highfreq_factor + image = cv2.resize(img_gray, (img_size, img_size), interpolation=cv2.INTER_LANCZOS4) + image = np.float32(image) + dct = cv2.dct(image) + dctlowfreq = dct[:hash_size, :hash_size] + med = np.median(dctlowfreq) + diff = dctlowfreq > med + return diff def hamming_distance_sql_function(user_input, db_entry) -> int: @@ -11,7 +37,7 @@ def hamming_distance_sql_function(user_input, db_entry) -> int: ) -class ImagePHashDatabase: +class ImagePhashDatabase: def __init__(self, db_path: str): with sqlite3.connect(db_path) as conn: self.hash_size = int( @@ -30,36 +56,63 @@ class ImagePHashDatabase: ).fetchone()[0] ) - # self.conn.create_function( - # "HAMMING_DISTANCE", - # 2, - # hamming_distance_sql_function, - # deterministic=True, - # ) - - self.ids = [i[0] for i in conn.execute("SELECT id FROM hashes").fetchall()] + self.ids: List[str] = [ + i[0] for i in conn.execute("SELECT id FROM hashes").fetchall() + ] self.hashes_byte = [ i[0] for i in conn.execute("SELECT hash FROM hashes").fetchall() ] self.hashes = [np.frombuffer(hb, bool) for hb in self.hashes_byte] - self.hashes_slice_size = round(len(self.hashes_byte[0]) * 0.25) - self.hashes_head = [h[: self.hashes_slice_size] for h in self.hashes] - self.hashes_tail = [h[-self.hashes_slice_size :] for h in self.hashes] - def lookup_hash(self, image_hash: imagehash.ImageHash, *, limit: int = 5): - image_hash = image_hash.hash.flatten() - # image_hash_head = image_hash[: self.hashes_slice_size] - # image_hash_tail = image_hash[-self.hashes_slice_size :] - # head_xor_results = [image_hash_head ^ h for h in self.hashes] - # tail_xor_results = [image_hash_head ^ h for h in self.hashes] + self.jacket_ids: List[str] = [] + self.jacket_hashes = [] + self.partner_icon_ids: List[str] = [] + self.partner_icon_hashes = [] + + for id, hash in zip(self.ids, self.hashes): + id_splitted = id.split("||") + if len(id_splitted) > 1 and id_splitted[0] == "partner_icon": + self.partner_icon_ids.append(id_splitted[1]) + self.partner_icon_hashes.append(hash) + else: + self.jacket_ids.append(id) + self.jacket_hashes.append(hash) + + def calculate_phash(self, img_gray: Mat): + return phash_opencv( + img_gray, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor + ) + + def lookup_hash(self, image_hash: np.ndarray, *, limit: int = 5): + image_hash = image_hash.flatten() xor_results = [ (id, np.count_nonzero(image_hash ^ h)) for id, h in zip(self.ids, self.hashes) ] return sorted(xor_results, key=lambda r: r[1])[:limit] - def lookup_image(self, pil_image: Image.Image): - image_hash = imagehash.phash( - pil_image, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor - ) + def lookup_image(self, img_gray: Mat): + image_hash = self.calculate_phash(img_gray) return self.lookup_hash(image_hash)[0] + + def lookup_jackets(self, img_gray: Mat, *, limit: int = 5): + image_hash = self.calculate_phash(img_gray).flatten() + xor_results = [ + (id, np.count_nonzero(image_hash ^ h)) + for id, h in zip(self.jacket_ids, self.jacket_hashes) + ] + return sorted(xor_results, key=lambda r: r[1])[:limit] + + def lookup_jacket(self, img_gray: Mat): + return self.lookup_jackets(img_gray)[0] + + def lookup_partner_icons(self, img_gray: Mat, *, limit: int = 5): + image_hash = self.calculate_phash(img_gray).flatten() + xor_results = [ + (id, np.count_nonzero(image_hash ^ h)) + for id, h in zip(self.partner_icon_ids, self.partner_icon_hashes) + ] + return sorted(xor_results, key=lambda r: r[1])[:limit] + + def lookup_partner_icon(self, img_gray: Mat): + return self.lookup_partner_icons(img_gray)[0] diff --git a/src/arcaea_offline_ocr/sift_db.py b/src/arcaea_offline_ocr/sift_db.py deleted file mode 100644 index d249dae..0000000 --- a/src/arcaea_offline_ocr/sift_db.py +++ /dev/null @@ -1,110 +0,0 @@ -import io -import sqlite3 -from gzip import GzipFile -from typing import Tuple - -import cv2 -import numpy as np - -from .types import Mat - - -class SIFTDatabase: - def __init__(self, db_path: str, load: bool = True): - self.__db_path = db_path - self.__tags = [] - self.__descriptors = [] - self.__size = None - - self.__sift = cv2.SIFT_create() - self.__bf_matcher = cv2.BFMatcher() - - if load: - self.load_db() - - @property - def db_path(self): - return self.__db_path - - @db_path.setter - def db_path(self, value): - self.__db_path = value - - @property - def tags(self): - return self.__tags - - @property - def descriptors(self): - return self.__descriptors - - @property - def size(self): - return self.__size - - @size.setter - def size(self, value: Tuple[int, int]): - self.__size = value - - @property - def sift(self): - return self.__sift - - @property - def bf_matcher(self): - return self.__bf_matcher - - def load_db(self): - conn = sqlite3.connect(self.db_path) - with conn: - cursor = conn.cursor() - - size_str = cursor.execute( - "SELECT value FROM properties WHERE id = 'size'" - ).fetchone()[0] - sizr_str_arr = size_str.split(", ") - self.size = tuple(int(s) for s in sizr_str_arr) - tag__descriptors_bytes = cursor.execute( - "SELECT tag, descriptors FROM sift" - ).fetchall() - - gzipped = int( - cursor.execute( - "SELECT value FROM properties WHERE id = 'gzip'" - ).fetchone()[0] - ) - for tag, descriptor_bytes in tag__descriptors_bytes: - buffer = io.BytesIO(descriptor_bytes) - self.tags.append(tag) - if gzipped == 0: - self.descriptors.append(np.load(buffer)) - else: - gzipped_buffer = GzipFile(None, "rb", fileobj=buffer) - self.descriptors.append(np.load(gzipped_buffer)) - - def lookup_img( - self, - __img: Mat, - *, - sift=None, - bf=None, - ) -> Tuple[str, float]: - sift = sift or self.sift - bf = bf or self.bf_matcher - - img = __img.copy() - if self.size is not None: - img = cv2.resize(img, self.size) - _, descriptors = sift.detectAndCompute(img, None) - - good_results = [] - for des in self.descriptors: - matches = bf.knnMatch(descriptors, des, k=2) - good = sum(m.distance < 0.75 * n.distance for m, n in matches) - good_results.append(good) - best_match_index = max(enumerate(good_results), key=lambda i: i[1])[0] - - return ( - self.tags[best_match_index], - good_results[best_match_index] / len(descriptors), - ) diff --git a/src/arcaea_offline_ocr/types.py b/src/arcaea_offline_ocr/types.py index 3a3dc92..7f1bc5b 100644 --- a/src/arcaea_offline_ocr/types.py +++ b/src/arcaea_offline_ocr/types.py @@ -1,10 +1,9 @@ from collections.abc import Iterable -from typing import Any, NamedTuple, Protocol, Tuple, Union +from typing import NamedTuple, Tuple, Union import numpy as np -# from pylance -Mat = np.ndarray[int, np.dtype[np.generic]] +Mat = np.ndarray class XYWHRect(NamedTuple): @@ -24,19 +23,3 @@ class XYWHRect(NamedTuple): raise ValueError() return self.__class__(*[a - b for a, b in zip(self, other)]) - - -class cv2_ml_StatModel(Protocol): - def predict(self, samples: np.ndarray, results: np.ndarray, flags: int = 0): - ... - - def train(self, samples: np.ndarray, layout: int, responses: np.ndarray): - ... - - -class cv2_ml_KNearest(cv2_ml_StatModel, Protocol): - def findNearest( - self, samples: np.ndarray, k: int - ) -> Tuple[Any, np.ndarray, np.ndarray, np.ndarray]: - """cv.ml.KNearest.findNearest(samples, k[, results[, neighborResponses[, dist]]]) -> retval, results, neighborResponses, dist""" - ... diff --git a/src/arcaea_offline_ocr/utils.py b/src/arcaea_offline_ocr/utils.py index e55ea0f..9fa9390 100644 --- a/src/arcaea_offline_ocr/utils.py +++ b/src/arcaea_offline_ocr/utils.py @@ -1,17 +1,15 @@ -import io from collections.abc import Iterable -from typing import Callable, Tuple, TypeVar, Union, overload +from typing import Callable, TypeVar, Union, overload import cv2 import numpy as np -from PIL import Image, ImageCms -from .types import Mat, XYWHRect +from .types import XYWHRect __all__ = ["imread_unicode"] -def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED) -> Mat: +def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED): # https://stackoverflow.com/a/57872297/16484891 # CC BY-SA 4.0 return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags) @@ -46,25 +44,3 @@ def apply_factor(item, factor: float): return item * factor elif isinstance(item, Iterable): return item.__class__([i * factor for i in item]) - - -def convert_to_srgb(pil_img: Image.Image): - """ - Convert PIL image to sRGB color space (if possible) - and save the converted file. - - https://stackoverflow.com/a/65667797/16484891 - - CC BY-SA 4.0 - """ - icc = pil_img.info.get("icc_profile", "") - icc_conv = "" - - if icc: - io_handle = io.BytesIO(icc) # virtual file - src_profile = ImageCms.ImageCmsProfile(io_handle) - dst_profile = ImageCms.createProfile("sRGB") - img_conv = ImageCms.profileToProfile(pil_img, src_profile, dst_profile) - icc_conv = img_conv.info.get("icc_profile", "") - - return img_conv if icc != icc_conv else pil_img