chore: remove custom cv2 type annotations (#8)

This commit is contained in:
283375 2023-10-12 01:50:27 +08:00
parent 2895eb7233
commit 82229b8b5c
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
6 changed files with 32 additions and 84 deletions

View File

@ -8,7 +8,6 @@ from PIL import Image
from ....crop import crop_xywh from ....crop import crop_xywh
from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog
from ....phash_db import ImagePhashDatabase from ....phash_db import ImagePhashDatabase
from ....types import Mat, cv2_ml_KNearest
from ....utils import construct_int_xywh_rect from ....utils import construct_int_xywh_rect
from ...shared import B30OcrResultItem from ...shared import B30OcrResultItem
from .colors import * from .colors import *
@ -18,8 +17,8 @@ from .rois import ChieriBotV4Rois
class ChieriBotV4Ocr: class ChieriBotV4Ocr:
def __init__( def __init__(
self, self,
score_knn: cv2_ml_KNearest, score_knn: cv2.ml.KNearest,
pfl_knn: cv2_ml_KNearest, pfl_knn: cv2.ml.KNearest,
phash_db: ImagePhashDatabase, phash_db: ImagePhashDatabase,
factor: Optional[float] = 1.0, factor: Optional[float] = 1.0,
): ):
@ -33,7 +32,7 @@ class ChieriBotV4Ocr:
return self.__score_knn return self.__score_knn
@score_knn.setter @score_knn.setter
def score_knn(self, knn_digits_model: Mat): def score_knn(self, knn_digits_model: cv2.ml.KNearest):
self.__score_knn = knn_digits_model self.__score_knn = knn_digits_model
@property @property
@ -41,7 +40,7 @@ class ChieriBotV4Ocr:
return self.__pfl_knn return self.__pfl_knn
@pfl_knn.setter @pfl_knn.setter
def pfl_knn(self, knn_digits_model: Mat): def pfl_knn(self, knn_digits_model: cv2.ml.KNearest):
self.__pfl_knn = knn_digits_model self.__pfl_knn = knn_digits_model
@property @property
@ -64,10 +63,10 @@ class ChieriBotV4Ocr:
def factor(self, factor: float): def factor(self, factor: float):
self.__rois.factor = factor self.__rois.factor = factor
def set_factor(self, img: Mat): def set_factor(self, img: cv2.Mat):
self.factor = img.shape[0] / 4400 self.factor = img.shape[0] / 4400
def ocr_component_rating_class(self, component_bgr: Mat) -> int: def ocr_component_rating_class(self, component_bgr: cv2.Mat) -> int:
rating_class_rect = construct_int_xywh_rect( rating_class_rect = construct_int_xywh_rect(
self.rois.component_rois.rating_class_rect self.rois.component_rois.rating_class_rect
) )
@ -84,15 +83,7 @@ class ChieriBotV4Ocr:
else: else:
return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1
# def ocr_component_title(self, component_bgr: Mat) -> str: def ocr_component_song_id(self, component_bgr: cv2.Mat):
# # sourcery skip: inline-immediately-returned-variable
# title_rect = construct_int_xywh_rect(self.rois.component_rois.title_rect)
# title_roi = crop_xywh(component_bgr, title_rect)
# ocr_result = self.sift_db.ocr(title_roi, cls=False)
# title = ocr_result[0][-1][1][0] if ocr_result and ocr_result[0] else ""
# return title
def ocr_component_song_id(self, component_bgr: Mat):
jacket_rect = construct_int_xywh_rect( jacket_rect = construct_int_xywh_rect(
self.rois.component_rois.jacket_rect, floor self.rois.component_rois.jacket_rect, floor
) )
@ -101,20 +92,7 @@ class ChieriBotV4Ocr:
) )
return self.phash_db.lookup_image(Image.fromarray(jacket_roi))[0] return self.phash_db.lookup_image(Image.fromarray(jacket_roi))[0]
# def ocr_component_score_paddle(self, component_bgr: Mat) -> int: def ocr_component_score_knn(self, component_bgr: cv2.Mat) -> int:
# # sourcery skip: inline-immediately-returned-variable
# score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect)
# score_roi = cv2.cvtColor(
# crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY
# )
# _, score_roi = cv2.threshold(
# score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
# )
# score_str = self.sift_db.ocr(score_roi, cls=False)[0][-1][1][0]
# score = int(score_str.replace("'", "").replace(" ", ""))
# return score
def ocr_component_score_knn(self, component_bgr: Mat) -> int:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect)
score_roi = cv2.cvtColor( score_roi = cv2.cvtColor(
@ -136,7 +114,7 @@ class ChieriBotV4Ocr:
score_roi = cv2.fillPoly(score_roi, [contour], 0) score_roi = cv2.fillPoly(score_roi, [contour], 0)
return ocr_digits_by_contour_knn(score_roi, self.score_knn) return ocr_digits_by_contour_knn(score_roi, self.score_knn)
def find_pfl_rects(self, component_pfl_processed: Mat) -> List[List[int]]: def find_pfl_rects(self, component_pfl_processed: cv2.Mat) -> List[List[int]]:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
pfl_roi_find = cv2.morphologyEx( pfl_roi_find = cv2.morphologyEx(
component_pfl_processed, component_pfl_processed,
@ -162,7 +140,7 @@ class ChieriBotV4Ocr:
] ]
return pfl_rects_adjusted return pfl_rects_adjusted
def preprocess_component_pfl(self, component_bgr: Mat) -> Mat: def preprocess_component_pfl(self, component_bgr: cv2.Mat) -> cv2.Mat:
pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect) pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect)
pfl_roi = crop_xywh(component_bgr, pfl_rect) pfl_roi = crop_xywh(component_bgr, pfl_rect)
pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV) pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV)
@ -202,7 +180,7 @@ class ChieriBotV4Ocr:
return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result
def ocr_component_pfl( def ocr_component_pfl(
self, component_bgr: Mat self, component_bgr: cv2.Mat
) -> Tuple[Optional[int], Optional[int], Optional[int]]: ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
try: try:
pfl_roi = self.preprocess_component_pfl(component_bgr) pfl_roi = self.preprocess_component_pfl(component_bgr)
@ -233,16 +211,7 @@ class ChieriBotV4Ocr:
except Exception: except Exception:
return (None, None, None) return (None, None, None)
# def ocr_component_date(self, component_bgr: Mat): def ocr_component(self, component_bgr: cv2.Mat) -> B30OcrResultItem:
# date_rect = construct_int_xywh_rect(self.rois.component_rois.date_rect)
# date_roi = cv2.cvtColor(crop_xywh(component_bgr, date_rect), cv2.COLOR_BGR2GRAY)
# _, date_roi = cv2.threshold(
# date_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
# )
# date_str = self.sift_db.ocr(date_roi, cls=False)[0][-1][1][0]
# return date_str
def ocr_component(self, component_bgr: Mat) -> B30OcrResultItem:
component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0) component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0)
rating_class = self.ocr_component_rating_class(component_blur) rating_class = self.ocr_component_rating_class(component_blur)
song_id = self.ocr_component_song_id(component_bgr) song_id = self.ocr_component_song_id(component_bgr)
@ -261,7 +230,7 @@ class ChieriBotV4Ocr:
date=None, date=None,
) )
def ocr(self, img_bgr: Mat) -> List[B30OcrResultItem]: def ocr(self, img_bgr: cv2.Mat) -> List[B30OcrResultItem]:
self.set_factor(img_bgr) self.set_factor(img_bgr)
return [ return [
self.ocr_component(component_bgr) self.ocr_component(component_bgr)

View File

@ -1,7 +1,9 @@
from typing import List, Optional from typing import List, Optional
import cv2
from ....crop import crop_xywh from ....crop import crop_xywh
from ....types import Mat, XYWHRect from ....types import XYWHRect
from ....utils import apply_factor, construct_int_xywh_rect from ....utils import apply_factor, construct_int_xywh_rect
@ -108,7 +110,7 @@ class ChieriBotV4Rois:
def b33_vertical_gap(self): def b33_vertical_gap(self):
return apply_factor(121, self.factor) return apply_factor(121, self.factor)
def components(self, img_bgr: Mat) -> List[Mat]: def components(self, img_bgr: cv2.Mat) -> List[cv2.Mat]:
first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height) first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height)
results = [] results = []

View File

@ -1,26 +1,25 @@
from math import floor from math import floor
from typing import Tuple from typing import Tuple
import cv2
import numpy as np import numpy as np
from .types import Mat
__all__ = ["crop_xywh", "crop_black_edges", "crop_black_edges_grayscale"] __all__ = ["crop_xywh", "crop_black_edges", "crop_black_edges_grayscale"]
def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]): def crop_xywh(mat: cv2.Mat, rect: Tuple[int, int, int, int]):
x, y, w, h = rect x, y, w, h = rect
return mat[y : y + h, x : x + w] return mat[y : y + h, x : x + w]
def is_black_edge(list_of_pixels: Mat, black_pixel: Mat, ratio: float = 0.6): def is_black_edge(list_of_pixels: cv2.Mat, black_pixel: cv2.Mat, ratio: float = 0.6):
pixels = list_of_pixels.reshape([-1, 3]) pixels = list_of_pixels.reshape([-1, 3])
return np.count_nonzero(np.all(pixels < black_pixel, axis=1)) > floor( return np.count_nonzero(np.all(pixels < black_pixel, axis=1)) > floor(
len(pixels) * ratio len(pixels) * ratio
) )
def crop_black_edges(img_bgr: Mat, black_threshold: int = 50): def crop_black_edges(img_bgr: cv2.Mat, black_threshold: int = 50):
cropped = img_bgr.copy() cropped = img_bgr.copy()
black_pixel = np.array([black_threshold] * 3, img_bgr.dtype) black_pixel = np.array([black_threshold] * 3, img_bgr.dtype)
height, width = img_bgr.shape[:2] height, width = img_bgr.shape[:2]
@ -66,7 +65,7 @@ def is_black_edge_grayscale(
def crop_black_edges_grayscale( def crop_black_edges_grayscale(
img_gray: Mat, black_threshold: int = 50 img_gray: cv2.Mat, black_threshold: int = 50
) -> Tuple[int, int, int, int]: ) -> Tuple[int, int, int, int]:
"""Returns cropped rect""" """Returns cropped rect"""
height, width = img_gray.shape[:2] height, width = img_gray.shape[:2]

View File

@ -5,7 +5,6 @@ import cv2
import numpy as np import numpy as np
from .crop import crop_xywh from .crop import crop_xywh
from .types import Mat, cv2_ml_KNearest
__all__ = [ __all__ = [
"FixRects", "FixRects",
@ -68,7 +67,7 @@ class FixRects:
@staticmethod @staticmethod
def split_connected( def split_connected(
img_masked: Mat, img_masked: cv2.Mat,
rects: Sequence[Tuple[int, int, int, int]], rects: Sequence[Tuple[int, int, int, int]],
rect_wh_ratio: float = 1.05, rect_wh_ratio: float = 1.05,
width_range_ratio: float = 0.1, width_range_ratio: float = 0.1,
@ -118,7 +117,7 @@ class FixRects:
return return_rects return return_rects
def resize_fill_square(img: Mat, target: int = 20): def resize_fill_square(img: cv2.Mat, target: int = 20):
h, w = img.shape[:2] h, w = img.shape[:2]
if h > w: if h > w:
new_h = target new_h = target
@ -150,14 +149,14 @@ def preprocess_hog(digit_rois):
return np.float32(samples) return np.float32(samples)
def ocr_digit_samples_knn(__samples, knn_model: cv2_ml_KNearest, k: int = 4): def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4):
_, results, _, _ = knn_model.findNearest(__samples, k) _, results, _, _ = knn_model.findNearest(__samples, k)
result_list = [int(r) for r in results.ravel()] result_list = [int(r) for r in results.ravel()]
result_str = "".join(str(r) for r in result_list if r > -1) result_str = "".join(str(r) for r in result_list if r > -1)
return int(result_str) if result_str else 0 return int(result_str) if result_str else 0
def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int): def ocr_digits_by_contour_get_samples(__roi_gray: cv2.Mat, size: int):
roi = __roi_gray.copy() roi = __roi_gray.copy()
contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
rects = [cv2.boundingRect(c) for c in contours] rects = [cv2.boundingRect(c) for c in contours]
@ -170,8 +169,8 @@ def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int):
def ocr_digits_by_contour_knn( def ocr_digits_by_contour_knn(
__roi_gray: Mat, __roi_gray: cv2.Mat,
knn_model: cv2_ml_KNearest, knn_model: cv2.ml.KNearest,
*, *,
k=4, k=4,
size: int = 20, size: int = 20,

View File

@ -1,10 +1,5 @@
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, NamedTuple, Protocol, Tuple, Union from typing import NamedTuple, Tuple, Union
import numpy as np
# from pylance
Mat = np.ndarray[int, np.dtype[np.generic]]
class XYWHRect(NamedTuple): class XYWHRect(NamedTuple):
@ -24,19 +19,3 @@ class XYWHRect(NamedTuple):
raise ValueError() raise ValueError()
return self.__class__(*[a - b for a, b in zip(self, other)]) return self.__class__(*[a - b for a, b in zip(self, other)])
class cv2_ml_StatModel(Protocol):
def predict(self, samples: np.ndarray, results: np.ndarray, flags: int = 0):
...
def train(self, samples: np.ndarray, layout: int, responses: np.ndarray):
...
class cv2_ml_KNearest(cv2_ml_StatModel, Protocol):
def findNearest(
self, samples: np.ndarray, k: int
) -> Tuple[Any, np.ndarray, np.ndarray, np.ndarray]:
"""cv.ml.KNearest.findNearest(samples, k[, results[, neighborResponses[, dist]]]) -> retval, results, neighborResponses, dist"""
...

View File

@ -1,17 +1,17 @@
import io import io
from collections.abc import Iterable from collections.abc import Iterable
from typing import Callable, Tuple, TypeVar, Union, overload from typing import Callable, TypeVar, Union, overload
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image, ImageCms from PIL import Image, ImageCms
from .types import Mat, XYWHRect from .types import XYWHRect
__all__ = ["imread_unicode"] __all__ = ["imread_unicode"]
def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED) -> Mat: def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED):
# https://stackoverflow.com/a/57872297/16484891 # https://stackoverflow.com/a/57872297/16484891
# CC BY-SA 4.0 # CC BY-SA 4.0
return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags) return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags)