mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-04-18 21:10:17 +00:00
feat: support for ocr digits using cv2.ml.KNearest
model
This commit is contained in:
parent
95987f6a2b
commit
fd10c68c49
@ -1,19 +1,8 @@
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from cv2 import (
|
||||
CHAIN_APPROX_SIMPLE,
|
||||
RETR_EXTERNAL,
|
||||
TM_CCOEFF_NORMED,
|
||||
boundingRect,
|
||||
findContours,
|
||||
imshow,
|
||||
matchTemplate,
|
||||
minMaxLoc,
|
||||
rectangle,
|
||||
resize,
|
||||
waitKey,
|
||||
)
|
||||
import cv2
|
||||
import numpy as np
|
||||
from imutils import grab_contours
|
||||
from imutils import resize as imutils_resize
|
||||
from pytesseract import image_to_string
|
||||
@ -24,7 +13,7 @@ from .template import (
|
||||
load_builtin_digit_template,
|
||||
matchTemplateMultiple,
|
||||
)
|
||||
from .types import Mat
|
||||
from .types import Mat, cv2_ml_KNearest
|
||||
|
||||
__all__ = [
|
||||
"group_numbers",
|
||||
@ -155,6 +144,18 @@ def ocr_digits(
|
||||
return int(joined_str) if joined_str else None
|
||||
|
||||
|
||||
def ocr_digits_knn_model(img_gray: Mat, knn_model: cv2_ml_KNearest):
|
||||
if img_gray.shape[:2] != (20, 20):
|
||||
img = cv2.resize(img_gray, [20, 20])
|
||||
else:
|
||||
img = img_gray.copy()
|
||||
|
||||
img = img.astype(np.float32)
|
||||
img = img.reshape([1, -1])
|
||||
retval, _, _, _ = knn_model.findNearest(img, 10)
|
||||
return int(retval)
|
||||
|
||||
|
||||
def ocr_pure(img_masked: Mat):
|
||||
template = load_builtin_digit_template("default")
|
||||
return ocr_digits(
|
||||
@ -173,9 +174,11 @@ def ocr_score(img_cropped: Mat):
|
||||
templates = load_builtin_digit_template("default").regular
|
||||
templates_dict = dict(enumerate(templates[:10]))
|
||||
|
||||
cnts = findContours(img_cropped.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
|
||||
cnts = cv2.findContours(
|
||||
img_cropped.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
cnts = grab_contours(cnts)
|
||||
rects = [boundingRect(cnt) for cnt in cnts]
|
||||
rects = [cv2.boundingRect(cnt) for cnt in cnts]
|
||||
rects = sorted(rects, key=lambda r: r[0])
|
||||
|
||||
# debug
|
||||
@ -190,9 +193,9 @@ def ocr_score(img_cropped: Mat):
|
||||
|
||||
digit_results: Dict[int, float] = {}
|
||||
for digit, template in templates_dict.items():
|
||||
template = resize(template, roi.shape[::-1])
|
||||
template_result = matchTemplate(roi, template, TM_CCOEFF_NORMED)
|
||||
min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
|
||||
template = cv2.resize(template, roi.shape[::-1])
|
||||
template_result = cv2.matchTemplate(roi, template, cv2.TM_CCOEFF_NORMED)
|
||||
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(template_result)
|
||||
digit_results[digit] = max_val
|
||||
|
||||
digit_results = {k: v for k, v in digit_results.items() if v > 0.5}
|
||||
|
@ -1,5 +1,5 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import NamedTuple, Tuple, Union
|
||||
from typing import Any, NamedTuple, Protocol, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -24,3 +24,19 @@ class XYWHRect(NamedTuple):
|
||||
raise ValueError()
|
||||
|
||||
return self.__class__(*[a - b for a, b in zip(self, other)])
|
||||
|
||||
|
||||
class cv2_ml_StatModel(Protocol):
|
||||
def predict(self, samples: np.ndarray, results: np.ndarray, flags: int = 0):
|
||||
...
|
||||
|
||||
def train(self, samples: np.ndarray, layout: int, responses: np.ndarray):
|
||||
...
|
||||
|
||||
|
||||
class cv2_ml_KNearest(cv2_ml_StatModel, Protocol):
|
||||
def findNearest(
|
||||
self, samples: np.ndarray, k: int
|
||||
) -> Tuple[Any, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""cv.ml.KNearest.findNearest(samples, k[, results[, neighborResponses[, dist]]]) -> retval, results, neighborResponses, dist"""
|
||||
...
|
||||
|
Loading…
x
Reference in New Issue
Block a user