diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py index 302b9a5..a68a9c1 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py @@ -8,7 +8,6 @@ from PIL import Image from ....crop import crop_xywh from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog from ....phash_db import ImagePhashDatabase -from ....types import Mat, cv2_ml_KNearest from ....utils import construct_int_xywh_rect from ...shared import B30OcrResultItem from .colors import * @@ -18,8 +17,8 @@ from .rois import ChieriBotV4Rois class ChieriBotV4Ocr: def __init__( self, - score_knn: cv2_ml_KNearest, - pfl_knn: cv2_ml_KNearest, + score_knn: cv2.ml.KNearest, + pfl_knn: cv2.ml.KNearest, phash_db: ImagePhashDatabase, factor: Optional[float] = 1.0, ): @@ -33,7 +32,7 @@ class ChieriBotV4Ocr: return self.__score_knn @score_knn.setter - def score_knn(self, knn_digits_model: Mat): + def score_knn(self, knn_digits_model: cv2.ml.KNearest): self.__score_knn = knn_digits_model @property @@ -41,7 +40,7 @@ class ChieriBotV4Ocr: return self.__pfl_knn @pfl_knn.setter - def pfl_knn(self, knn_digits_model: Mat): + def pfl_knn(self, knn_digits_model: cv2.ml.KNearest): self.__pfl_knn = knn_digits_model @property @@ -64,10 +63,10 @@ class ChieriBotV4Ocr: def factor(self, factor: float): self.__rois.factor = factor - def set_factor(self, img: Mat): + def set_factor(self, img: cv2.Mat): self.factor = img.shape[0] / 4400 - def ocr_component_rating_class(self, component_bgr: Mat) -> int: + def ocr_component_rating_class(self, component_bgr: cv2.Mat) -> int: rating_class_rect = construct_int_xywh_rect( self.rois.component_rois.rating_class_rect ) @@ -84,15 +83,7 @@ class ChieriBotV4Ocr: else: return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 - # def ocr_component_title(self, component_bgr: Mat) -> str: - # # sourcery skip: inline-immediately-returned-variable - # title_rect = construct_int_xywh_rect(self.rois.component_rois.title_rect) - # title_roi = crop_xywh(component_bgr, title_rect) - # ocr_result = self.sift_db.ocr(title_roi, cls=False) - # title = ocr_result[0][-1][1][0] if ocr_result and ocr_result[0] else "" - # return title - - def ocr_component_song_id(self, component_bgr: Mat): + def ocr_component_song_id(self, component_bgr: cv2.Mat): jacket_rect = construct_int_xywh_rect( self.rois.component_rois.jacket_rect, floor ) @@ -101,20 +92,7 @@ class ChieriBotV4Ocr: ) return self.phash_db.lookup_image(Image.fromarray(jacket_roi))[0] - # def ocr_component_score_paddle(self, component_bgr: Mat) -> int: - # # sourcery skip: inline-immediately-returned-variable - # score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) - # score_roi = cv2.cvtColor( - # crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY - # ) - # _, score_roi = cv2.threshold( - # score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU - # ) - # score_str = self.sift_db.ocr(score_roi, cls=False)[0][-1][1][0] - # score = int(score_str.replace("'", "").replace(" ", "")) - # return score - - def ocr_component_score_knn(self, component_bgr: Mat) -> int: + def ocr_component_score_knn(self, component_bgr: cv2.Mat) -> int: # sourcery skip: inline-immediately-returned-variable score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) score_roi = cv2.cvtColor( @@ -136,7 +114,7 @@ class ChieriBotV4Ocr: score_roi = cv2.fillPoly(score_roi, [contour], 0) return ocr_digits_by_contour_knn(score_roi, self.score_knn) - def find_pfl_rects(self, component_pfl_processed: Mat) -> List[List[int]]: + def find_pfl_rects(self, component_pfl_processed: cv2.Mat) -> List[List[int]]: # sourcery skip: inline-immediately-returned-variable pfl_roi_find = cv2.morphologyEx( component_pfl_processed, @@ -162,7 +140,7 @@ class ChieriBotV4Ocr: ] return pfl_rects_adjusted - def preprocess_component_pfl(self, component_bgr: Mat) -> Mat: + def preprocess_component_pfl(self, component_bgr: cv2.Mat) -> cv2.Mat: pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect) pfl_roi = crop_xywh(component_bgr, pfl_rect) pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV) @@ -202,7 +180,7 @@ class ChieriBotV4Ocr: return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result def ocr_component_pfl( - self, component_bgr: Mat + self, component_bgr: cv2.Mat ) -> Tuple[Optional[int], Optional[int], Optional[int]]: try: pfl_roi = self.preprocess_component_pfl(component_bgr) @@ -233,16 +211,7 @@ class ChieriBotV4Ocr: except Exception: return (None, None, None) - # def ocr_component_date(self, component_bgr: Mat): - # date_rect = construct_int_xywh_rect(self.rois.component_rois.date_rect) - # date_roi = cv2.cvtColor(crop_xywh(component_bgr, date_rect), cv2.COLOR_BGR2GRAY) - # _, date_roi = cv2.threshold( - # date_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU - # ) - # date_str = self.sift_db.ocr(date_roi, cls=False)[0][-1][1][0] - # return date_str - - def ocr_component(self, component_bgr: Mat) -> B30OcrResultItem: + def ocr_component(self, component_bgr: cv2.Mat) -> B30OcrResultItem: component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0) rating_class = self.ocr_component_rating_class(component_blur) song_id = self.ocr_component_song_id(component_bgr) @@ -261,7 +230,7 @@ class ChieriBotV4Ocr: date=None, ) - def ocr(self, img_bgr: Mat) -> List[B30OcrResultItem]: + def ocr(self, img_bgr: cv2.Mat) -> List[B30OcrResultItem]: self.set_factor(img_bgr) return [ self.ocr_component(component_bgr) diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/rois.py b/src/arcaea_offline_ocr/b30/chieri/v4/rois.py index 9926b8a..6bc5c22 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/rois.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/rois.py @@ -1,7 +1,9 @@ from typing import List, Optional +import cv2 + from ....crop import crop_xywh -from ....types import Mat, XYWHRect +from ....types import XYWHRect from ....utils import apply_factor, construct_int_xywh_rect @@ -108,7 +110,7 @@ class ChieriBotV4Rois: def b33_vertical_gap(self): return apply_factor(121, self.factor) - def components(self, img_bgr: Mat) -> List[Mat]: + def components(self, img_bgr: cv2.Mat) -> List[cv2.Mat]: first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height) results = [] diff --git a/src/arcaea_offline_ocr/crop.py b/src/arcaea_offline_ocr/crop.py index a65b6ea..5499acf 100644 --- a/src/arcaea_offline_ocr/crop.py +++ b/src/arcaea_offline_ocr/crop.py @@ -1,26 +1,25 @@ from math import floor from typing import Tuple +import cv2 import numpy as np -from .types import Mat - __all__ = ["crop_xywh", "crop_black_edges", "crop_black_edges_grayscale"] -def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]): +def crop_xywh(mat: cv2.Mat, rect: Tuple[int, int, int, int]): x, y, w, h = rect return mat[y : y + h, x : x + w] -def is_black_edge(list_of_pixels: Mat, black_pixel: Mat, ratio: float = 0.6): +def is_black_edge(list_of_pixels: cv2.Mat, black_pixel: cv2.Mat, ratio: float = 0.6): pixels = list_of_pixels.reshape([-1, 3]) return np.count_nonzero(np.all(pixels < black_pixel, axis=1)) > floor( len(pixels) * ratio ) -def crop_black_edges(img_bgr: Mat, black_threshold: int = 50): +def crop_black_edges(img_bgr: cv2.Mat, black_threshold: int = 50): cropped = img_bgr.copy() black_pixel = np.array([black_threshold] * 3, img_bgr.dtype) height, width = img_bgr.shape[:2] @@ -66,7 +65,7 @@ def is_black_edge_grayscale( def crop_black_edges_grayscale( - img_gray: Mat, black_threshold: int = 50 + img_gray: cv2.Mat, black_threshold: int = 50 ) -> Tuple[int, int, int, int]: """Returns cropped rect""" height, width = img_gray.shape[:2] diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/ocr.py index 83a30b4..cfa96e3 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/ocr.py @@ -5,7 +5,6 @@ import cv2 import numpy as np from .crop import crop_xywh -from .types import Mat, cv2_ml_KNearest __all__ = [ "FixRects", @@ -68,7 +67,7 @@ class FixRects: @staticmethod def split_connected( - img_masked: Mat, + img_masked: cv2.Mat, rects: Sequence[Tuple[int, int, int, int]], rect_wh_ratio: float = 1.05, width_range_ratio: float = 0.1, @@ -118,7 +117,7 @@ class FixRects: return return_rects -def resize_fill_square(img: Mat, target: int = 20): +def resize_fill_square(img: cv2.Mat, target: int = 20): h, w = img.shape[:2] if h > w: new_h = target @@ -150,14 +149,14 @@ def preprocess_hog(digit_rois): return np.float32(samples) -def ocr_digit_samples_knn(__samples, knn_model: cv2_ml_KNearest, k: int = 4): +def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4): _, results, _, _ = knn_model.findNearest(__samples, k) result_list = [int(r) for r in results.ravel()] result_str = "".join(str(r) for r in result_list if r > -1) return int(result_str) if result_str else 0 -def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int): +def ocr_digits_by_contour_get_samples(__roi_gray: cv2.Mat, size: int): roi = __roi_gray.copy() contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) rects = [cv2.boundingRect(c) for c in contours] @@ -170,8 +169,8 @@ def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int): def ocr_digits_by_contour_knn( - __roi_gray: Mat, - knn_model: cv2_ml_KNearest, + __roi_gray: cv2.Mat, + knn_model: cv2.ml.KNearest, *, k=4, size: int = 20, diff --git a/src/arcaea_offline_ocr/types.py b/src/arcaea_offline_ocr/types.py index 3a3dc92..dc0bd76 100644 --- a/src/arcaea_offline_ocr/types.py +++ b/src/arcaea_offline_ocr/types.py @@ -1,10 +1,5 @@ from collections.abc import Iterable -from typing import Any, NamedTuple, Protocol, Tuple, Union - -import numpy as np - -# from pylance -Mat = np.ndarray[int, np.dtype[np.generic]] +from typing import NamedTuple, Tuple, Union class XYWHRect(NamedTuple): @@ -24,19 +19,3 @@ class XYWHRect(NamedTuple): raise ValueError() return self.__class__(*[a - b for a, b in zip(self, other)]) - - -class cv2_ml_StatModel(Protocol): - def predict(self, samples: np.ndarray, results: np.ndarray, flags: int = 0): - ... - - def train(self, samples: np.ndarray, layout: int, responses: np.ndarray): - ... - - -class cv2_ml_KNearest(cv2_ml_StatModel, Protocol): - def findNearest( - self, samples: np.ndarray, k: int - ) -> Tuple[Any, np.ndarray, np.ndarray, np.ndarray]: - """cv.ml.KNearest.findNearest(samples, k[, results[, neighborResponses[, dist]]]) -> retval, results, neighborResponses, dist""" - ... diff --git a/src/arcaea_offline_ocr/utils.py b/src/arcaea_offline_ocr/utils.py index e55ea0f..d7466d2 100644 --- a/src/arcaea_offline_ocr/utils.py +++ b/src/arcaea_offline_ocr/utils.py @@ -1,17 +1,17 @@ import io from collections.abc import Iterable -from typing import Callable, Tuple, TypeVar, Union, overload +from typing import Callable, TypeVar, Union, overload import cv2 import numpy as np from PIL import Image, ImageCms -from .types import Mat, XYWHRect +from .types import XYWHRect __all__ = ["imread_unicode"] -def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED) -> Mat: +def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED): # https://stackoverflow.com/a/57872297/16484891 # CC BY-SA 4.0 return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags)