refactor!: OCR text result provider

2025-06-21 21:49:34 +08:00
parent 3ebb058cdf
commit abfd37dbef
7 changed files with 142 additions and 112 deletions

View File

@@ -1,4 +1,3 @@
 from .crop import *
 from .device import *
-from .ocr import *
 from .utils import *

View File

@@ -4,12 +4,6 @@ import cv2
 import numpy as np

 from ....crop import crop_xywh
-from ....ocr import (
-    FixRects,
-    ocr_digits_by_contour_knn,
-    preprocess_hog,
-    resize_fill_square,
-)
 from ....phash_db import ImagePhashDatabase
 from ....types import Mat
 from ...shared import B30OcrResultItem
@@ -28,36 +22,21 @@ from .colors import (
     PURE_BG_MIN_HSV,
 )
 from .rois import ChieriBotV4Rois
+from ....providers.knn import OcrKNearestTextProvider


 class ChieriBotV4Ocr:
     def __init__(
         self,
-        score_knn: cv2.ml.KNearest,
-        pfl_knn: cv2.ml.KNearest,
+        score_knn_provider: OcrKNearestTextProvider,
+        pfl_knn_provider: OcrKNearestTextProvider,
         phash_db: ImagePhashDatabase,
         factor: float = 1.0,
     ):
-        self.__score_knn = score_knn
-        self.__pfl_knn = pfl_knn
         self.__phash_db = phash_db
         self.__rois = ChieriBotV4Rois(factor)
+        self.pfl_knn_provider = pfl_knn_provider
+        self.score_knn_provider = score_knn_provider

-    @property
-    def score_knn(self):
-        return self.__score_knn
-
-    @score_knn.setter
-    def score_knn(self, knn_digits_model: cv2.ml.KNearest):
-        self.__score_knn = knn_digits_model
-
-    @property
-    def pfl_knn(self):
-        return self.__pfl_knn
-
-    @pfl_knn.setter
-    def pfl_knn(self, knn_digits_model: cv2.ml.KNearest):
-        self.__pfl_knn = knn_digits_model
-
     @property
     def phash_db(self):
@@ -125,7 +104,9 @@ class ChieriBotV4Ocr:
             if rect[3] > score_roi.shape[0] * 0.5:
                 continue
             score_roi = cv2.fillPoly(score_roi, [contour], 0)
-        return ocr_digits_by_contour_knn(score_roi, self.score_knn)
+        ocr_result = self.score_knn_provider.result(score_roi)
+        return int(ocr_result) if ocr_result else 0

     def find_pfl_rects(
         self, component_pfl_processed: Mat
@@ -203,25 +184,9 @@
             pure_far_lost = []
             for pfl_roi_rect in pfl_rects:
                 roi = crop_xywh(pfl_roi, pfl_roi_rect)
-                digit_contours, _ = cv2.findContours(
-                    roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
-                )
-                digit_rects = [cv2.boundingRect(c) for c in digit_contours]
-                digit_rects = FixRects.connect_broken(
-                    digit_rects, roi.shape[1], roi.shape[0]
-                )
-                digit_rects = FixRects.split_connected(roi, digit_rects)
-                digit_rects = sorted(digit_rects, key=lambda r: r[0])
-                digits = []
-                for digit_rect in digit_rects:
-                    digit = crop_xywh(roi, digit_rect)
-                    digit = resize_fill_square(digit, 20)
-                    digits.append(digit)
-                samples = preprocess_hog(digits)
-                _, results, _, _ = self.pfl_knn.findNearest(samples, 4)
-                results = [str(int(i)) for i in results.ravel()]
-                pure_far_lost.append(int("".join(results)))
+                result = self.pfl_knn_provider.result(roi)
+                pure_far_lost.append(int(result) if result else None)
             return tuple(pure_far_lost)
         except Exception:
             return (None, None, None)
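
With this change the B30 OCR class is wired from providers instead of raw cv2.ml.KNearest models. A minimal sketch of the new construction, assuming the top-level package name arcaea_offline_ocr, module paths inferred from the relative imports above, an ImagePhashDatabase constructor taking a file path, and KNearest models trained elsewhere and saved with cv2.ml.StatModel.save():

import cv2

from arcaea_offline_ocr.b30.chieri.v4.ocr import ChieriBotV4Ocr        # module path assumed
from arcaea_offline_ocr.phash_db import ImagePhashDatabase             # package name assumed
from arcaea_offline_ocr.providers.knn import OcrKNearestTextProvider   # package name assumed

# Illustrative asset paths; both digit models were trained and saved elsewhere.
score_provider = OcrKNearestTextProvider(cv2.ml.KNearest_load("b30_score_digits.dat"))
pfl_provider = OcrKNearestTextProvider(cv2.ml.KNearest_load("b30_pfl_digits.dat"))
phash_db = ImagePhashDatabase("phashes.db")  # constructor signature assumed

ocr = ChieriBotV4Ocr(
    score_knn_provider=score_provider,
    pfl_knn_provider=pfl_provider,
    phash_db=phash_db,
)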

View File

@@ -5,10 +5,10 @@ from typing import Optional
 @dataclass
 class DeviceOcrResult:
     rating_class: int
-    pure: int
-    far: int
-    lost: int
     score: int
+    pure: Optional[int] = None
+    far: Optional[int] = None
+    lost: Optional[int] = None
     max_recall: Optional[int] = None
     song_id: Optional[str] = None
     song_id_possibility: Optional[float] = None
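
Since pure, far and lost are now Optional, downstream code has to tolerate None when the provider fails to read a ROI. A hypothetical formatting helper illustrating that (not part of this commit; the import path assumes the top-level package name):

from typing import Optional

from arcaea_offline_ocr.device.common import DeviceOcrResult  # package name assumed


def format_pfl(result: DeviceOcrResult) -> str:
    """Render pure/far/lost, falling back to a placeholder when OCR returned None."""

    def fmt(value: Optional[int]) -> str:
        return str(value) if value is not None else "?"

    return f"PURE {fmt(result.pure)} / FAR {fmt(result.far)} / LOST {fmt(result.lost)}"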

View File

@ -1,15 +1,8 @@
import cv2 import cv2
import numpy as np import numpy as np
from ..crop import crop_xywh
from ..ocr import (
FixRects,
ocr_digit_samples_knn,
ocr_digits_by_contour_knn,
preprocess_hog,
resize_fill_square,
)
from ..phash_db import ImagePhashDatabase from ..phash_db import ImagePhashDatabase
from ..providers.knn import OcrKNearestTextProvider
from ..types import Mat from ..types import Mat
from .common import DeviceOcrResult from .common import DeviceOcrResult
from .rois.extractor import DeviceRoisExtractor from .rois.extractor import DeviceRoisExtractor
@@ -21,38 +14,37 @@ class DeviceOcr:
         self,
         extractor: DeviceRoisExtractor,
         masker: DeviceRoisMasker,
-        knn_model: cv2.ml.KNearest,
+        knn_provider: OcrKNearestTextProvider,
         phash_db: ImagePhashDatabase,
     ):
         self.extractor = extractor
         self.masker = masker
-        self.knn_model = knn_model
+        self.knn_provider = knn_provider
         self.phash_db = phash_db

     def pfl(self, roi_gray: Mat, factor: float = 1.25):
-        contours, _ = cv2.findContours(
-            roi_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
-        )
-        filtered_contours = [c for c in contours if cv2.contourArea(c) >= 5 * factor]
-        rects = [cv2.boundingRect(c) for c in filtered_contours]
-        rects = FixRects.connect_broken(rects, roi_gray.shape[1], roi_gray.shape[0])
-        filtered_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor]
-        filtered_rects = FixRects.split_connected(roi_gray, filtered_rects)
-        filtered_rects = sorted(filtered_rects, key=lambda r: r[0])
+        def contour_filter(cnt):
+            return cv2.contourArea(cnt) >= 5 * factor
+
+        contours = self.knn_provider.contours(roi_gray)
+        contours_filtered = self.knn_provider.contours(
+            roi_gray, contours_filter=contour_filter
+        )

         roi_ocr = roi_gray.copy()
-        filtered_contours_flattened = {tuple(c.flatten()) for c in filtered_contours}
+        contours_filtered_flattened = {tuple(c.flatten()) for c in contours_filtered}
         for contour in contours:
-            if tuple(contour.flatten()) in filtered_contours_flattened:
+            if tuple(contour.flatten()) in contours_filtered_flattened:
                 continue
             roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0])
-        digit_rois = [
-            resize_fill_square(crop_xywh(roi_ocr, r), 20) for r in filtered_rects
-        ]
-        samples = preprocess_hog(digit_rois)
-        return ocr_digit_samples_knn(samples, self.knn_model)
+
+        ocr_result = self.knn_provider.result(
+            roi_ocr,
+            contours_filter=lambda cnt: cv2.contourArea(cnt) >= 5 * factor,
+            rects_filter=lambda rect: rect[2] >= 5 * factor and rect[3] >= 6 * factor,
+        )
+        return int(ocr_result) if ocr_result else 0

     def pure(self):
         return self.pfl(self.masker.pure(self.extractor.pure))
@@ -65,13 +57,14 @@ class DeviceOcr:
     def score(self):
         roi = self.masker.score(self.extractor.score)
-        contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+        contours = self.knn_provider.contours(roi)
         for contour in contours:
             if (
                 cv2.boundingRect(contour)[3] < roi.shape[0] * 0.6
             ):  # h < score_component_h * 0.6
                 roi = cv2.fillPoly(roi, [contour], [0])
-        return ocr_digits_by_contour_knn(roi, self.knn_model)
+        ocr_result = self.knn_provider.result(roi)
+        return int(ocr_result) if ocr_result else 0

     def rating_class(self):
         roi = self.extractor.rating_class
@@ -85,9 +78,10 @@ class DeviceOcr:
         return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]

     def max_recall(self):
-        return ocr_digits_by_contour_knn(
-            self.masker.max_recall(self.extractor.max_recall), self.knn_model
-        )
+        ocr_result = self.knn_provider.result(
+            self.masker.max_recall(self.extractor.max_recall)
+        )
+        return int(ocr_result) if ocr_result else None

     def clear_status(self):
         roi = self.extractor.clear_status
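
DeviceOcr now delegates contour extraction and digit recognition to the injected provider. A minimal wiring sketch, assuming the top-level package name, that concrete extractor/masker/phash-db objects are built elsewhere, and a KNearest model saved with cv2.ml.StatModel.save():

import cv2

from arcaea_offline_ocr.device.ocr import DeviceOcr                    # module path assumed
from arcaea_offline_ocr.providers.knn import OcrKNearestTextProvider   # module path assumed


def build_device_ocr(extractor, masker, phash_db, knn_model_path: str) -> DeviceOcr:
    """Hypothetical factory: extractor, masker and phash_db are constructed elsewhere."""
    knn_provider = OcrKNearestTextProvider(cv2.ml.KNearest_load(knn_model_path))
    return DeviceOcr(extractor, masker, knn_provider, phash_db)


# score() and pfl() coerce the provider's string result to int (0 on failure),
# while max_recall() now returns None when nothing could be read:
# ocr = build_device_ocr(extractor, masker, phash_db, "digits.dat")
# print(ocr.score(), ocr.pure(), ocr.max_recall())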

View File

@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+    from ..types import Mat
+
+
+class OcrTextProvider(ABC):
+    @abstractmethod
+    def result_raw(self, img: "Mat", /, *args, **kwargs) -> Any: ...
+
+    @abstractmethod
+    def result(self, img: "Mat", /, *args, **kwargs) -> Optional[str]: ...
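
The new base class fixes the provider contract used above: result_raw returns whatever the backend produces, while result normalizes it to Optional[str]. A toy implementation sketch; FixedTextProvider is hypothetical and only illustrates the interface, and the import path assumes the top-level package name:

from typing import Any, Optional

import numpy as np

from arcaea_offline_ocr.providers.base import OcrTextProvider  # module path assumed


class FixedTextProvider(OcrTextProvider):
    """Toy provider that ignores the image and always reports a fixed string."""

    def __init__(self, text: str):
        self.text = text

    def result_raw(self, img: np.ndarray, /, *args, **kwargs) -> Any:
        # Backend-specific payload; a real provider would run OCR on `img` here.
        return self.text

    def result(self, img: np.ndarray, /, *args, **kwargs) -> Optional[str]:
        raw = self.result_raw(img, *args, **kwargs)
        return str(raw) if raw else None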

View File

@ -1,18 +1,19 @@
import logging
import math import math
from typing import Optional, Sequence, Tuple from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple
import cv2 import cv2
import numpy as np import numpy as np
from .crop import crop_xywh from ..crop import crop_xywh
from .types import Mat from .base import OcrTextProvider
__all__ = [ if TYPE_CHECKING:
"FixRects", from cv2.ml import KNearest
"preprocess_hog",
"ocr_digits_by_contour_get_samples", from ..types import Mat
"ocr_digits_by_contour_knn",
] logger = logging.getLogger(__name__)
class FixRects: class FixRects:
@ -68,7 +69,7 @@ class FixRects:
@staticmethod @staticmethod
def split_connected( def split_connected(
img_masked: Mat, img_masked: "Mat",
rects: Sequence[Tuple[int, int, int, int]], rects: Sequence[Tuple[int, int, int, int]],
rect_wh_ratio: float = 1.05, rect_wh_ratio: float = 1.05,
width_range_ratio: float = 0.1, width_range_ratio: float = 0.1,
@@ -118,7 +119,7 @@
     return return_rects


-def resize_fill_square(img: Mat, target: int = 20):
+def resize_fill_square(img: "Mat", target: int = 20):
     h, w = img.shape[:2]
     if h > w:
         new_h = target
@@ -152,29 +153,88 @@ def preprocess_hog(digit_rois):

 def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4):
     _, results, _, _ = knn_model.findNearest(__samples, k)
-    result_list = [int(r) for r in results.ravel()]
-    result_str = "".join(str(r) for r in result_list if r > -1)
-    return int(result_str) if result_str else 0
+    return [int(r) for r in results.ravel()]


-def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int):
-    roi = __roi_gray.copy()
-    contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-    rects = [cv2.boundingRect(c) for c in contours]
-    rects = FixRects.connect_broken(rects, roi.shape[1], roi.shape[0])
-    rects = FixRects.split_connected(roi, rects)
-    rects = sorted(rects, key=lambda r: r[0])
-    # digit_rois = [cv2.resize(crop_xywh(roi, rect), size) for rect in rects]
-    digit_rois = [resize_fill_square(crop_xywh(roi, rect), size) for rect in rects]
-    return preprocess_hog(digit_rois)
-
-
-def ocr_digits_by_contour_knn(
-    __roi_gray: Mat,
-    knn_model: cv2.ml.KNearest,
-    *,
-    k=4,
-    size: int = 20,
-) -> int:
-    samples = ocr_digits_by_contour_get_samples(__roi_gray, size)
-    return ocr_digit_samples_knn(samples, knn_model, k)
+class OcrKNearestTextProvider(OcrTextProvider):
+    _ContourFilter = Callable[["Mat"], bool]
+    _RectsFilter = Callable[[Sequence[int]], bool]
+
+    def __init__(self, model: "KNearest"):
+        self.model = model
+
+    def contours(
+        self, img: "Mat", /, *, contours_filter: Optional[_ContourFilter] = None
+    ):
+        cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+        if contours_filter:
+            cnts = list(filter(contours_filter, cnts))
+        return cnts
+
+    def result_raw(
+        self,
+        img: "Mat",
+        /,
+        *,
+        fix_rects: bool = True,
+        contours_filter: Optional[_ContourFilter] = None,
+        rects_filter: Optional[_RectsFilter] = None,
+    ):
+        """
+        :param img: grayscaled roi
+        """
+        try:
+            cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            if contours_filter:
+                cnts = list(filter(contours_filter, cnts))
+
+            rects = [cv2.boundingRect(cnt) for cnt in cnts]
+            if fix_rects and rects_filter:
+                rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])  # type: ignore
+                rects = list(filter(rects_filter, rects))
+                rects = FixRects.split_connected(img, rects)
+            elif fix_rects:
+                rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])  # type: ignore
+                rects = FixRects.split_connected(img, rects)
+            elif rects_filter:
+                rects = list(filter(rects_filter, rects))
+
+            rects = sorted(rects, key=lambda r: r[0])
+            digits = []
+            for rect in rects:
+                digit = crop_xywh(img, rect)
+                digit = resize_fill_square(digit, 20)
+                digits.append(digit)
+            samples = preprocess_hog(digits)
+            return ocr_digit_samples_knn(samples, self.model)
+        except Exception:
+            logger.exception("Error occurred during KNearest OCR")
+            return None
+
+    def result(
+        self,
+        img: "Mat",
+        /,
+        *,
+        fix_rects: bool = True,
+        contours_filter: Optional[_ContourFilter] = None,
+        rects_filter: Optional[_RectsFilter] = None,
+    ):
+        """
+        :param img: grayscaled roi
+        """
+        raw = self.result_raw(
+            img,
+            fix_rects=fix_rects,
+            contours_filter=contours_filter,
+            rects_filter=rects_filter,
+        )
+        return (
+            "".join(str(r) for r in raw if r > -1)
+            if raw is not None
+            else None
+        )
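
Used standalone, the provider wraps a trained cv2.ml.KNearest model and returns the recognized digit string, or None when recognition fails internally. A sketch under assumed module paths, with a model saved via cv2.ml.StatModel.save() and an illustrative ROI image:

import cv2

from arcaea_offline_ocr.providers.knn import OcrKNearestTextProvider  # module path assumed

model = cv2.ml.KNearest_load("digits.dat")  # digits model trained and saved elsewhere
provider = OcrKNearestTextProvider(model)

# result() expects a grayscaled, masked ROI; the rects_filter below is illustrative
# and drops bounding boxes shorter than half of the ROI height.
roi = cv2.imread("score_roi.png", cv2.IMREAD_GRAYSCALE)
text = provider.result(roi, rects_filter=lambda rect: rect[3] >= roi.shape[0] * 0.5)
score = int(text) if text else 0  # callers coerce the Optional[str] themselves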