From fd10c68c49861a4034618d36a21d155cac65b71d Mon Sep 17 00:00:00 2001 From: 283375 Date: Sat, 5 Aug 2023 02:27:31 +0800 Subject: [PATCH] feat: support for ocr digits using `cv2.ml.KNearest` model --- src/arcaea_offline_ocr/ocr.py | 41 ++++++++++++++++++--------------- src/arcaea_offline_ocr/types.py | 18 ++++++++++++++- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/ocr.py index edffa20..ee24c88 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/ocr.py @@ -1,19 +1,8 @@ import re from typing import Dict, List -from cv2 import ( - CHAIN_APPROX_SIMPLE, - RETR_EXTERNAL, - TM_CCOEFF_NORMED, - boundingRect, - findContours, - imshow, - matchTemplate, - minMaxLoc, - rectangle, - resize, - waitKey, -) +import cv2 +import numpy as np from imutils import grab_contours from imutils import resize as imutils_resize from pytesseract import image_to_string @@ -24,7 +13,7 @@ from .template import ( load_builtin_digit_template, matchTemplateMultiple, ) -from .types import Mat +from .types import Mat, cv2_ml_KNearest __all__ = [ "group_numbers", @@ -155,6 +144,18 @@ def ocr_digits( return int(joined_str) if joined_str else None +def ocr_digits_knn_model(img_gray: Mat, knn_model: cv2_ml_KNearest): + if img_gray.shape[:2] != (20, 20): + img = cv2.resize(img_gray, [20, 20]) + else: + img = img_gray.copy() + + img = img.astype(np.float32) + img = img.reshape([1, -1]) + retval, _, _, _ = knn_model.findNearest(img, 10) + return int(retval) + + def ocr_pure(img_masked: Mat): template = load_builtin_digit_template("default") return ocr_digits( @@ -173,9 +174,11 @@ def ocr_score(img_cropped: Mat): templates = load_builtin_digit_template("default").regular templates_dict = dict(enumerate(templates[:10])) - cnts = findContours(img_cropped.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE) + cnts = cv2.findContours( + img_cropped.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) cnts = grab_contours(cnts) - rects = [boundingRect(cnt) for cnt in cnts] + rects = [cv2.boundingRect(cnt) for cnt in cnts] rects = sorted(rects, key=lambda r: r[0]) # debug @@ -190,9 +193,9 @@ def ocr_score(img_cropped: Mat): digit_results: Dict[int, float] = {} for digit, template in templates_dict.items(): - template = resize(template, roi.shape[::-1]) - template_result = matchTemplate(roi, template, TM_CCOEFF_NORMED) - min_val, max_val, min_loc, max_loc = minMaxLoc(template_result) + template = cv2.resize(template, roi.shape[::-1]) + template_result = cv2.matchTemplate(roi, template, cv2.TM_CCOEFF_NORMED) + min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(template_result) digit_results[digit] = max_val digit_results = {k: v for k, v in digit_results.items() if v > 0.5} diff --git a/src/arcaea_offline_ocr/types.py b/src/arcaea_offline_ocr/types.py index 8c0730d..3a3dc92 100644 --- a/src/arcaea_offline_ocr/types.py +++ b/src/arcaea_offline_ocr/types.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import NamedTuple, Tuple, Union +from typing import Any, NamedTuple, Protocol, Tuple, Union import numpy as np @@ -24,3 +24,19 @@ class XYWHRect(NamedTuple): raise ValueError() return self.__class__(*[a - b for a, b in zip(self, other)]) + + +class cv2_ml_StatModel(Protocol): + def predict(self, samples: np.ndarray, results: np.ndarray, flags: int = 0): + ... + + def train(self, samples: np.ndarray, layout: int, responses: np.ndarray): + ... + + +class cv2_ml_KNearest(cv2_ml_StatModel, Protocol): + def findNearest( + self, samples: np.ndarray, k: int + ) -> Tuple[Any, np.ndarray, np.ndarray, np.ndarray]: + """cv.ml.KNearest.findNearest(samples, k[, results[, neighborResponses[, dist]]]) -> retval, results, neighborResponses, dist""" + ...