1
0
mirror of https://github.com/283375/arcaea-offline-ocr.git synced 2025-04-14 02:50:17 +00:00

chore: remove custom cv2 type annotations ()

This commit is contained in:
283375 2023-10-12 01:50:27 +08:00
parent 2895eb7233
commit 82229b8b5c
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
6 changed files with 32 additions and 84 deletions
src/arcaea_offline_ocr

@ -8,7 +8,6 @@ from PIL import Image
from ....crop import crop_xywh
from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog
from ....phash_db import ImagePhashDatabase
from ....types import Mat, cv2_ml_KNearest
from ....utils import construct_int_xywh_rect
from ...shared import B30OcrResultItem
from .colors import *
@ -18,8 +17,8 @@ from .rois import ChieriBotV4Rois
class ChieriBotV4Ocr:
def __init__(
self,
score_knn: cv2_ml_KNearest,
pfl_knn: cv2_ml_KNearest,
score_knn: cv2.ml.KNearest,
pfl_knn: cv2.ml.KNearest,
phash_db: ImagePhashDatabase,
factor: Optional[float] = 1.0,
):
@ -33,7 +32,7 @@ class ChieriBotV4Ocr:
return self.__score_knn
@score_knn.setter
def score_knn(self, knn_digits_model: Mat):
def score_knn(self, knn_digits_model: cv2.ml.KNearest):
self.__score_knn = knn_digits_model
@property
@ -41,7 +40,7 @@ class ChieriBotV4Ocr:
return self.__pfl_knn
@pfl_knn.setter
def pfl_knn(self, knn_digits_model: Mat):
def pfl_knn(self, knn_digits_model: cv2.ml.KNearest):
self.__pfl_knn = knn_digits_model
@property
@ -64,10 +63,10 @@ class ChieriBotV4Ocr:
def factor(self, factor: float):
self.__rois.factor = factor
def set_factor(self, img: Mat):
def set_factor(self, img: cv2.Mat):
self.factor = img.shape[0] / 4400
def ocr_component_rating_class(self, component_bgr: Mat) -> int:
def ocr_component_rating_class(self, component_bgr: cv2.Mat) -> int:
rating_class_rect = construct_int_xywh_rect(
self.rois.component_rois.rating_class_rect
)
@ -84,15 +83,7 @@ class ChieriBotV4Ocr:
else:
return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1
# def ocr_component_title(self, component_bgr: Mat) -> str:
# # sourcery skip: inline-immediately-returned-variable
# title_rect = construct_int_xywh_rect(self.rois.component_rois.title_rect)
# title_roi = crop_xywh(component_bgr, title_rect)
# ocr_result = self.sift_db.ocr(title_roi, cls=False)
# title = ocr_result[0][-1][1][0] if ocr_result and ocr_result[0] else ""
# return title
def ocr_component_song_id(self, component_bgr: Mat):
def ocr_component_song_id(self, component_bgr: cv2.Mat):
jacket_rect = construct_int_xywh_rect(
self.rois.component_rois.jacket_rect, floor
)
@ -101,20 +92,7 @@ class ChieriBotV4Ocr:
)
return self.phash_db.lookup_image(Image.fromarray(jacket_roi))[0]
# def ocr_component_score_paddle(self, component_bgr: Mat) -> int:
# # sourcery skip: inline-immediately-returned-variable
# score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect)
# score_roi = cv2.cvtColor(
# crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY
# )
# _, score_roi = cv2.threshold(
# score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
# )
# score_str = self.sift_db.ocr(score_roi, cls=False)[0][-1][1][0]
# score = int(score_str.replace("'", "").replace(" ", ""))
# return score
def ocr_component_score_knn(self, component_bgr: Mat) -> int:
def ocr_component_score_knn(self, component_bgr: cv2.Mat) -> int:
# sourcery skip: inline-immediately-returned-variable
score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect)
score_roi = cv2.cvtColor(
@ -136,7 +114,7 @@ class ChieriBotV4Ocr:
score_roi = cv2.fillPoly(score_roi, [contour], 0)
return ocr_digits_by_contour_knn(score_roi, self.score_knn)
def find_pfl_rects(self, component_pfl_processed: Mat) -> List[List[int]]:
def find_pfl_rects(self, component_pfl_processed: cv2.Mat) -> List[List[int]]:
# sourcery skip: inline-immediately-returned-variable
pfl_roi_find = cv2.morphologyEx(
component_pfl_processed,
@ -162,7 +140,7 @@ class ChieriBotV4Ocr:
]
return pfl_rects_adjusted
def preprocess_component_pfl(self, component_bgr: Mat) -> Mat:
def preprocess_component_pfl(self, component_bgr: cv2.Mat) -> cv2.Mat:
pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect)
pfl_roi = crop_xywh(component_bgr, pfl_rect)
pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV)
@ -202,7 +180,7 @@ class ChieriBotV4Ocr:
return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result
def ocr_component_pfl(
self, component_bgr: Mat
self, component_bgr: cv2.Mat
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
try:
pfl_roi = self.preprocess_component_pfl(component_bgr)
@ -233,16 +211,7 @@ class ChieriBotV4Ocr:
except Exception:
return (None, None, None)
# def ocr_component_date(self, component_bgr: Mat):
# date_rect = construct_int_xywh_rect(self.rois.component_rois.date_rect)
# date_roi = cv2.cvtColor(crop_xywh(component_bgr, date_rect), cv2.COLOR_BGR2GRAY)
# _, date_roi = cv2.threshold(
# date_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
# )
# date_str = self.sift_db.ocr(date_roi, cls=False)[0][-1][1][0]
# return date_str
def ocr_component(self, component_bgr: Mat) -> B30OcrResultItem:
def ocr_component(self, component_bgr: cv2.Mat) -> B30OcrResultItem:
component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0)
rating_class = self.ocr_component_rating_class(component_blur)
song_id = self.ocr_component_song_id(component_bgr)
@ -261,7 +230,7 @@ class ChieriBotV4Ocr:
date=None,
)
def ocr(self, img_bgr: Mat) -> List[B30OcrResultItem]:
def ocr(self, img_bgr: cv2.Mat) -> List[B30OcrResultItem]:
self.set_factor(img_bgr)
return [
self.ocr_component(component_bgr)

@ -1,7 +1,9 @@
from typing import List, Optional
import cv2
from ....crop import crop_xywh
from ....types import Mat, XYWHRect
from ....types import XYWHRect
from ....utils import apply_factor, construct_int_xywh_rect
@ -108,7 +110,7 @@ class ChieriBotV4Rois:
def b33_vertical_gap(self):
return apply_factor(121, self.factor)
def components(self, img_bgr: Mat) -> List[Mat]:
def components(self, img_bgr: cv2.Mat) -> List[cv2.Mat]:
first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height)
results = []

@ -1,26 +1,25 @@
from math import floor
from typing import Tuple
import cv2
import numpy as np
from .types import Mat
__all__ = ["crop_xywh", "crop_black_edges", "crop_black_edges_grayscale"]
def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]):
def crop_xywh(mat: cv2.Mat, rect: Tuple[int, int, int, int]):
x, y, w, h = rect
return mat[y : y + h, x : x + w]
def is_black_edge(list_of_pixels: Mat, black_pixel: Mat, ratio: float = 0.6):
def is_black_edge(list_of_pixels: cv2.Mat, black_pixel: cv2.Mat, ratio: float = 0.6):
pixels = list_of_pixels.reshape([-1, 3])
return np.count_nonzero(np.all(pixels < black_pixel, axis=1)) > floor(
len(pixels) * ratio
)
def crop_black_edges(img_bgr: Mat, black_threshold: int = 50):
def crop_black_edges(img_bgr: cv2.Mat, black_threshold: int = 50):
cropped = img_bgr.copy()
black_pixel = np.array([black_threshold] * 3, img_bgr.dtype)
height, width = img_bgr.shape[:2]
@ -66,7 +65,7 @@ def is_black_edge_grayscale(
def crop_black_edges_grayscale(
img_gray: Mat, black_threshold: int = 50
img_gray: cv2.Mat, black_threshold: int = 50
) -> Tuple[int, int, int, int]:
"""Returns cropped rect"""
height, width = img_gray.shape[:2]

@ -5,7 +5,6 @@ import cv2
import numpy as np
from .crop import crop_xywh
from .types import Mat, cv2_ml_KNearest
__all__ = [
"FixRects",
@ -68,7 +67,7 @@ class FixRects:
@staticmethod
def split_connected(
img_masked: Mat,
img_masked: cv2.Mat,
rects: Sequence[Tuple[int, int, int, int]],
rect_wh_ratio: float = 1.05,
width_range_ratio: float = 0.1,
@ -118,7 +117,7 @@ class FixRects:
return return_rects
def resize_fill_square(img: Mat, target: int = 20):
def resize_fill_square(img: cv2.Mat, target: int = 20):
h, w = img.shape[:2]
if h > w:
new_h = target
@ -150,14 +149,14 @@ def preprocess_hog(digit_rois):
return np.float32(samples)
def ocr_digit_samples_knn(__samples, knn_model: cv2_ml_KNearest, k: int = 4):
def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4):
_, results, _, _ = knn_model.findNearest(__samples, k)
result_list = [int(r) for r in results.ravel()]
result_str = "".join(str(r) for r in result_list if r > -1)
return int(result_str) if result_str else 0
def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int):
def ocr_digits_by_contour_get_samples(__roi_gray: cv2.Mat, size: int):
roi = __roi_gray.copy()
contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
rects = [cv2.boundingRect(c) for c in contours]
@ -170,8 +169,8 @@ def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int):
def ocr_digits_by_contour_knn(
__roi_gray: Mat,
knn_model: cv2_ml_KNearest,
__roi_gray: cv2.Mat,
knn_model: cv2.ml.KNearest,
*,
k=4,
size: int = 20,

@ -1,10 +1,5 @@
from collections.abc import Iterable
from typing import Any, NamedTuple, Protocol, Tuple, Union
import numpy as np
# from pylance
Mat = np.ndarray[int, np.dtype[np.generic]]
from typing import NamedTuple, Tuple, Union
class XYWHRect(NamedTuple):
@ -24,19 +19,3 @@ class XYWHRect(NamedTuple):
raise ValueError()
return self.__class__(*[a - b for a, b in zip(self, other)])
class cv2_ml_StatModel(Protocol):
def predict(self, samples: np.ndarray, results: np.ndarray, flags: int = 0):
...
def train(self, samples: np.ndarray, layout: int, responses: np.ndarray):
...
class cv2_ml_KNearest(cv2_ml_StatModel, Protocol):
def findNearest(
self, samples: np.ndarray, k: int
) -> Tuple[Any, np.ndarray, np.ndarray, np.ndarray]:
"""cv.ml.KNearest.findNearest(samples, k[, results[, neighborResponses[, dist]]]) -> retval, results, neighborResponses, dist"""
...

@ -1,17 +1,17 @@
import io
from collections.abc import Iterable
from typing import Callable, Tuple, TypeVar, Union, overload
from typing import Callable, TypeVar, Union, overload
import cv2
import numpy as np
from PIL import Image, ImageCms
from .types import Mat, XYWHRect
from .types import XYWHRect
__all__ = ["imread_unicode"]
def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED) -> Mat:
def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED):
# https://stackoverflow.com/a/57872297/16484891
# CC BY-SA 4.0
return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags)