diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py index 98e1292..845145a 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py @@ -12,6 +12,7 @@ from ....ocr import ( resize_fill_square, ) from ....phash_db import ImagePhashDatabase +from ....types import Mat from ....utils import construct_int_xywh_rect from ...shared import B30OcrResultItem from .colors import * @@ -67,10 +68,10 @@ class ChieriBotV4Ocr: def factor(self, factor: float): self.__rois.factor = factor - def set_factor(self, img: cv2.Mat): + def set_factor(self, img: Mat): self.factor = img.shape[0] / 4400 - def ocr_component_rating_class(self, component_bgr: cv2.Mat) -> int: + def ocr_component_rating_class(self, component_bgr: Mat) -> int: rating_class_rect = construct_int_xywh_rect( self.rois.component_rois.rating_class_rect ) @@ -87,7 +88,7 @@ class ChieriBotV4Ocr: else: return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 - def ocr_component_song_id(self, component_bgr: cv2.Mat): + def ocr_component_song_id(self, component_bgr: Mat): jacket_rect = construct_int_xywh_rect( self.rois.component_rois.jacket_rect, floor ) @@ -96,7 +97,7 @@ class ChieriBotV4Ocr: ) return self.phash_db.lookup_jacket(jacket_roi)[0] - def ocr_component_score_knn(self, component_bgr: cv2.Mat) -> int: + def ocr_component_score_knn(self, component_bgr: Mat) -> int: # sourcery skip: inline-immediately-returned-variable score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) score_roi = cv2.cvtColor( @@ -118,7 +119,7 @@ class ChieriBotV4Ocr: score_roi = cv2.fillPoly(score_roi, [contour], 0) return ocr_digits_by_contour_knn(score_roi, self.score_knn) - def find_pfl_rects(self, component_pfl_processed: cv2.Mat) -> List[List[int]]: + def find_pfl_rects(self, component_pfl_processed: Mat) -> List[List[int]]: # sourcery skip: inline-immediately-returned-variable pfl_roi_find = cv2.morphologyEx( component_pfl_processed, @@ -144,7 +145,7 @@ class ChieriBotV4Ocr: ] return pfl_rects_adjusted - def preprocess_component_pfl(self, component_bgr: cv2.Mat) -> cv2.Mat: + def preprocess_component_pfl(self, component_bgr: Mat) -> Mat: pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect) pfl_roi = crop_xywh(component_bgr, pfl_rect) pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV) @@ -184,7 +185,7 @@ class ChieriBotV4Ocr: return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result def ocr_component_pfl( - self, component_bgr: cv2.Mat + self, component_bgr: Mat ) -> Tuple[Optional[int], Optional[int], Optional[int]]: try: pfl_roi = self.preprocess_component_pfl(component_bgr) @@ -215,7 +216,7 @@ class ChieriBotV4Ocr: except Exception: return (None, None, None) - def ocr_component(self, component_bgr: cv2.Mat) -> B30OcrResultItem: + def ocr_component(self, component_bgr: Mat) -> B30OcrResultItem: component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0) rating_class = self.ocr_component_rating_class(component_blur) song_id = self.ocr_component_song_id(component_bgr) @@ -234,7 +235,7 @@ class ChieriBotV4Ocr: date=None, ) - def ocr(self, img_bgr: cv2.Mat) -> List[B30OcrResultItem]: + def ocr(self, img_bgr: Mat) -> List[B30OcrResultItem]: self.set_factor(img_bgr) return [ self.ocr_component(component_bgr) diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/rois.py b/src/arcaea_offline_ocr/b30/chieri/v4/rois.py index 6bc5c22..9926b8a 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/rois.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/rois.py @@ -1,9 +1,7 @@ from typing import List, Optional -import cv2 - from ....crop import crop_xywh -from ....types import XYWHRect +from ....types import Mat, XYWHRect from ....utils import apply_factor, construct_int_xywh_rect @@ -110,7 +108,7 @@ class ChieriBotV4Rois: def b33_vertical_gap(self): return apply_factor(121, self.factor) - def components(self, img_bgr: cv2.Mat) -> List[cv2.Mat]: + def components(self, img_bgr: Mat) -> List[Mat]: first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height) results = [] diff --git a/src/arcaea_offline_ocr/crop.py b/src/arcaea_offline_ocr/crop.py index aa9d1f2..12c531d 100644 --- a/src/arcaea_offline_ocr/crop.py +++ b/src/arcaea_offline_ocr/crop.py @@ -4,24 +4,26 @@ from typing import Tuple import cv2 import numpy as np +from .types import Mat + __all__ = ["crop_xywh", "CropBlackEdges"] -def crop_xywh(mat: cv2.Mat, rect: Tuple[int, int, int, int]): +def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]): x, y, w, h = rect return mat[y : y + h, x : x + w] class CropBlackEdges: @staticmethod - def is_black_edge(__img_gray_slice: cv2.Mat, black_pixel: int, ratio: float = 0.6): + def is_black_edge(__img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6): pixels_compared = __img_gray_slice < black_pixel return np.count_nonzero(pixels_compared) > math.floor( __img_gray_slice.size * ratio ) @classmethod - def get_crop_rect(cls, img_gray: cv2.Mat, black_threshold: int = 25): + def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25): height, width = img_gray.shape[:2] left = 0 right = width @@ -58,7 +60,7 @@ class CropBlackEdges: @classmethod def crop( - cls, img: cv2.Mat, convert_flag: cv2.COLOR_BGR2GRAY, black_threshold: int = 25 - ) -> cv2.Mat: + cls, img: Mat, convert_flag: cv2.COLOR_BGR2GRAY, black_threshold: int = 25 + ) -> Mat: rect = cls.get_crop_rect(cv2.cvtColor(img, convert_flag), black_threshold) return crop_xywh(img, rect) diff --git a/src/arcaea_offline_ocr/device/ocr.py b/src/arcaea_offline_ocr/device/ocr.py index 1b87a45..d0cbf88 100644 --- a/src/arcaea_offline_ocr/device/ocr.py +++ b/src/arcaea_offline_ocr/device/ocr.py @@ -10,6 +10,7 @@ from ..ocr import ( resize_fill_square, ) from ..phash_db import ImagePhashDatabase +from ..types import Mat from .common import DeviceOcrResult from .rois.extractor import DeviceRoisExtractor from .rois.masker import DeviceRoisMasker @@ -28,7 +29,7 @@ class DeviceOcr: self.knn_model = knn_model self.phash_db = phash_db - def pfl(self, roi_gray: cv2.Mat, factor: float = 1.25): + def pfl(self, roi_gray: Mat, factor: float = 1.25): contours, _ = cv2.findContours( roi_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE ) @@ -105,7 +106,7 @@ class DeviceOcr: return self.lookup_song_id()[0] @staticmethod - def preprocess_char_icon(img_gray: cv2.Mat): + def preprocess_char_icon(img_gray: Mat): h, w = img_gray.shape[:2] img = cv2.copyMakeBorder(img_gray, w - h, 0, 0, 0, cv2.BORDER_REPLICATE) h, w = img.shape[:2] diff --git a/src/arcaea_offline_ocr/device/rois/extractor/common.py b/src/arcaea_offline_ocr/device/rois/extractor/common.py index a90a0a4..671ae2c 100644 --- a/src/arcaea_offline_ocr/device/rois/extractor/common.py +++ b/src/arcaea_offline_ocr/device/rois/extractor/common.py @@ -1,11 +1,10 @@ -import cv2 - from ....crop import crop_xywh +from ....types import Mat from ..definition.common import DeviceRois class DeviceRoisExtractor: - def __init__(self, img: cv2.Mat, rois: DeviceRois): + def __init__(self, img: Mat, rois: DeviceRois): self.img = img self.sizes = rois diff --git a/src/arcaea_offline_ocr/device/rois/masker/auto.py b/src/arcaea_offline_ocr/device/rois/masker/auto.py index 164fe2b..ec92548 100644 --- a/src/arcaea_offline_ocr/device/rois/masker/auto.py +++ b/src/arcaea_offline_ocr/device/rois/masker/auto.py @@ -1,6 +1,7 @@ import cv2 import numpy as np +from ....types import Mat from .common import DeviceRoisMasker @@ -40,7 +41,7 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto): PURE_MEMORY_HSV_MAX = np.array([110, 200, 175], np.uint8) @classmethod - def gray(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def gray(cls, roi_bgr: Mat) -> Mat: bgr_value_equal_mask = np.max(roi_bgr, axis=2) - np.min(roi_bgr, axis=2) <= 5 img_bgr = roi_bgr.copy() img_bgr[~bgr_value_equal_mask] = np.array([0, 0, 0], roi_bgr.dtype) @@ -49,19 +50,19 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto): return cv2.inRange(img_bgr, cls.GRAY_BGR_MIN, cls.GRAY_BGR_MAX) @classmethod - def pure(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def pure(cls, roi_bgr: Mat) -> Mat: return cls.gray(roi_bgr) @classmethod - def far(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def far(cls, roi_bgr: Mat) -> Mat: return cls.gray(roi_bgr) @classmethod - def lost(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def lost(cls, roi_bgr: Mat) -> Mat: return cls.gray(roi_bgr) @classmethod - def score(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def score(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.WHITE_HSV_MIN, @@ -69,35 +70,35 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto): ) @classmethod - def rating_class_pst(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_pst(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PST_HSV_MIN, cls.PST_HSV_MAX ) @classmethod - def rating_class_prs(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_prs(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PRS_HSV_MIN, cls.PRS_HSV_MAX ) @classmethod - def rating_class_ftr(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.FTR_HSV_MIN, cls.FTR_HSV_MAX ) @classmethod - def rating_class_byd(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_byd(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.BYD_HSV_MIN, cls.BYD_HSV_MAX ) @classmethod - def max_recall(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def max_recall(cls, roi_bgr: Mat) -> Mat: return cls.gray(roi_bgr) @classmethod - def clear_status_track_lost(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.TRACK_LOST_HSV_MIN, @@ -105,7 +106,7 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto): ) @classmethod - def clear_status_track_complete(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.TRACK_COMPLETE_HSV_MIN, @@ -113,7 +114,7 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto): ) @classmethod - def clear_status_full_recall(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.FULL_RECALL_HSV_MIN, @@ -121,7 +122,7 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto): ) @classmethod - def clear_status_pure_memory(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PURE_MEMORY_HSV_MIN, @@ -164,25 +165,25 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): PURE_MEMORY_HSV_MAX = np.array([110, 200, 175], np.uint8) @classmethod - def pfl(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def pfl(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PFL_HSV_MIN, cls.PFL_HSV_MAX ) @classmethod - def pure(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def pure(cls, roi_bgr: Mat) -> Mat: return cls.pfl(roi_bgr) @classmethod - def far(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def far(cls, roi_bgr: Mat) -> Mat: return cls.pfl(roi_bgr) @classmethod - def lost(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def lost(cls, roi_bgr: Mat) -> Mat: return cls.pfl(roi_bgr) @classmethod - def score(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def score(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.WHITE_HSV_MIN, @@ -190,31 +191,31 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): ) @classmethod - def rating_class_pst(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_pst(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PST_HSV_MIN, cls.PST_HSV_MAX ) @classmethod - def rating_class_prs(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_prs(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PRS_HSV_MIN, cls.PRS_HSV_MAX ) @classmethod - def rating_class_ftr(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.FTR_HSV_MIN, cls.FTR_HSV_MAX ) @classmethod - def rating_class_byd(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_byd(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.BYD_HSV_MIN, cls.BYD_HSV_MAX ) @classmethod - def max_recall(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def max_recall(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.MAX_RECALL_HSV_MIN, @@ -222,7 +223,7 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): ) @classmethod - def clear_status_track_lost(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.TRACK_LOST_HSV_MIN, @@ -230,7 +231,7 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): ) @classmethod - def clear_status_track_complete(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.TRACK_COMPLETE_HSV_MIN, @@ -238,7 +239,7 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): ) @classmethod - def clear_status_full_recall(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.FULL_RECALL_HSV_MIN, @@ -246,7 +247,7 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): ) @classmethod - def clear_status_pure_memory(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: return cv2.inRange( cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cls.PURE_MEMORY_HSV_MIN, diff --git a/src/arcaea_offline_ocr/device/rois/masker/common.py b/src/arcaea_offline_ocr/device/rois/masker/common.py index f877e2c..cb1223f 100644 --- a/src/arcaea_offline_ocr/device/rois/masker/common.py +++ b/src/arcaea_offline_ocr/device/rois/masker/common.py @@ -1,55 +1,55 @@ -import cv2 +from ....types import Mat class DeviceRoisMasker: @classmethod - def pure(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def pure(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def far(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def far(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def lost(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def lost(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def score(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def score(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def rating_class_pst(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_pst(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def rating_class_prs(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_prs(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def rating_class_ftr(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def rating_class_byd(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def rating_class_byd(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def max_recall(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def max_recall(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def clear_status_track_lost(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def clear_status_track_complete(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def clear_status_full_recall(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() @classmethod - def clear_status_pure_memory(cls, roi_bgr: cv2.Mat) -> cv2.Mat: + def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: raise NotImplementedError() diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/ocr.py index cfa96e3..44ca73a 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/ocr.py @@ -5,6 +5,7 @@ import cv2 import numpy as np from .crop import crop_xywh +from .types import Mat __all__ = [ "FixRects", @@ -67,7 +68,7 @@ class FixRects: @staticmethod def split_connected( - img_masked: cv2.Mat, + img_masked: Mat, rects: Sequence[Tuple[int, int, int, int]], rect_wh_ratio: float = 1.05, width_range_ratio: float = 0.1, @@ -117,7 +118,7 @@ class FixRects: return return_rects -def resize_fill_square(img: cv2.Mat, target: int = 20): +def resize_fill_square(img: Mat, target: int = 20): h, w = img.shape[:2] if h > w: new_h = target @@ -156,7 +157,7 @@ def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4): return int(result_str) if result_str else 0 -def ocr_digits_by_contour_get_samples(__roi_gray: cv2.Mat, size: int): +def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int): roi = __roi_gray.copy() contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) rects = [cv2.boundingRect(c) for c in contours] @@ -169,7 +170,7 @@ def ocr_digits_by_contour_get_samples(__roi_gray: cv2.Mat, size: int): def ocr_digits_by_contour_knn( - __roi_gray: cv2.Mat, + __roi_gray: Mat, knn_model: cv2.ml.KNearest, *, k=4, diff --git a/src/arcaea_offline_ocr/phash_db.py b/src/arcaea_offline_ocr/phash_db.py index 8d95b6b..dba7c04 100644 --- a/src/arcaea_offline_ocr/phash_db.py +++ b/src/arcaea_offline_ocr/phash_db.py @@ -4,9 +4,11 @@ from typing import List, Union import cv2 import numpy as np +from .types import Mat + def phash_opencv(img_gray, hash_size=8, highfreq_factor=4): - # type: (Union[cv2.Mat, np.ndarray], int, int) -> np.ndarray + # type: (Union[Mat, np.ndarray], int, int) -> np.ndarray """ Perceptual Hash computation. @@ -76,7 +78,7 @@ class ImagePhashDatabase: self.jacket_ids.append(id) self.jacket_hashes.append(hash) - def calculate_phash(self, img_gray: cv2.Mat): + def calculate_phash(self, img_gray: Mat): return phash_opencv( img_gray, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor ) @@ -89,11 +91,11 @@ class ImagePhashDatabase: ] return sorted(xor_results, key=lambda r: r[1])[:limit] - def lookup_image(self, img_gray: cv2.Mat): + def lookup_image(self, img_gray: Mat): image_hash = self.calculate_phash(img_gray) return self.lookup_hash(image_hash)[0] - def lookup_jackets(self, img_gray: cv2.Mat, *, limit: int = 5): + def lookup_jackets(self, img_gray: Mat, *, limit: int = 5): image_hash = self.calculate_phash(img_gray).flatten() xor_results = [ (id, np.count_nonzero(image_hash ^ h)) @@ -101,10 +103,10 @@ class ImagePhashDatabase: ] return sorted(xor_results, key=lambda r: r[1])[:limit] - def lookup_jacket(self, img_gray: cv2.Mat): + def lookup_jacket(self, img_gray: Mat): return self.lookup_jackets(img_gray)[0] - def lookup_partner_icons(self, img_gray: cv2.Mat, *, limit: int = 5): + def lookup_partner_icons(self, img_gray: Mat, *, limit: int = 5): image_hash = self.calculate_phash(img_gray).flatten() xor_results = [ (id, np.count_nonzero(image_hash ^ h)) @@ -112,5 +114,5 @@ class ImagePhashDatabase: ] return sorted(xor_results, key=lambda r: r[1])[:limit] - def lookup_partner_icon(self, img_gray: cv2.Mat): + def lookup_partner_icon(self, img_gray: Mat): return self.lookup_partner_icons(img_gray)[0] diff --git a/src/arcaea_offline_ocr/types.py b/src/arcaea_offline_ocr/types.py index dc0bd76..7f1bc5b 100644 --- a/src/arcaea_offline_ocr/types.py +++ b/src/arcaea_offline_ocr/types.py @@ -1,6 +1,10 @@ from collections.abc import Iterable from typing import NamedTuple, Tuple, Union +import numpy as np + +Mat = np.ndarray + class XYWHRect(NamedTuple): x: int