feat: support for ocr digits using cv2.ml.KNearest model

This commit is contained in:
283375 2023-08-05 02:27:31 +08:00
parent 95987f6a2b
commit fd10c68c49
2 changed files with 39 additions and 20 deletions

View File

@ -1,19 +1,8 @@
import re
from typing import Dict, List
from cv2 import (
CHAIN_APPROX_SIMPLE,
RETR_EXTERNAL,
TM_CCOEFF_NORMED,
boundingRect,
findContours,
imshow,
matchTemplate,
minMaxLoc,
rectangle,
resize,
waitKey,
)
import cv2
import numpy as np
from imutils import grab_contours
from imutils import resize as imutils_resize
from pytesseract import image_to_string
@ -24,7 +13,7 @@ from .template import (
load_builtin_digit_template,
matchTemplateMultiple,
)
from .types import Mat
from .types import Mat, cv2_ml_KNearest
__all__ = [
"group_numbers",
@ -155,6 +144,18 @@ def ocr_digits(
return int(joined_str) if joined_str else None
def ocr_digits_knn_model(img_gray: Mat, knn_model: cv2_ml_KNearest):
if img_gray.shape[:2] != (20, 20):
img = cv2.resize(img_gray, [20, 20])
else:
img = img_gray.copy()
img = img.astype(np.float32)
img = img.reshape([1, -1])
retval, _, _, _ = knn_model.findNearest(img, 10)
return int(retval)
def ocr_pure(img_masked: Mat):
template = load_builtin_digit_template("default")
return ocr_digits(
@ -173,9 +174,11 @@ def ocr_score(img_cropped: Mat):
templates = load_builtin_digit_template("default").regular
templates_dict = dict(enumerate(templates[:10]))
cnts = findContours(img_cropped.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
cnts = cv2.findContours(
img_cropped.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
cnts = grab_contours(cnts)
rects = [boundingRect(cnt) for cnt in cnts]
rects = [cv2.boundingRect(cnt) for cnt in cnts]
rects = sorted(rects, key=lambda r: r[0])
# debug
@ -190,9 +193,9 @@ def ocr_score(img_cropped: Mat):
digit_results: Dict[int, float] = {}
for digit, template in templates_dict.items():
template = resize(template, roi.shape[::-1])
template_result = matchTemplate(roi, template, TM_CCOEFF_NORMED)
min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
template = cv2.resize(template, roi.shape[::-1])
template_result = cv2.matchTemplate(roi, template, cv2.TM_CCOEFF_NORMED)
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(template_result)
digit_results[digit] = max_val
digit_results = {k: v for k, v in digit_results.items() if v > 0.5}

View File

@ -1,5 +1,5 @@
from collections.abc import Iterable
from typing import NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Protocol, Tuple, Union
import numpy as np
@ -24,3 +24,19 @@ class XYWHRect(NamedTuple):
raise ValueError()
return self.__class__(*[a - b for a, b in zip(self, other)])
class cv2_ml_StatModel(Protocol):
def predict(self, samples: np.ndarray, results: np.ndarray, flags: int = 0):
...
def train(self, samples: np.ndarray, layout: int, responses: np.ndarray):
...
class cv2_ml_KNearest(cv2_ml_StatModel, Protocol):
def findNearest(
self, samples: np.ndarray, k: int
) -> Tuple[Any, np.ndarray, np.ndarray, np.ndarray]:
"""cv.ml.KNearest.findNearest(samples, k[, results[, neighborResponses[, dist]]]) -> retval, results, neighborResponses, dist"""
...