mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-04-16 03:50:18 +00:00
refactor!: remove matchTemplate
digit recognize functions
This commit is contained in:
parent
64598d0a84
commit
1326ab66a2
@ -1,101 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import pickle
|
||||
|
||||
import cv2
|
||||
import imutils
|
||||
import numpy
|
||||
from imutils import contours
|
||||
|
||||
|
||||
def load_template_image(filename: str) -> dict[int, cv2.Mat]:
|
||||
"""
|
||||
Arguments:
|
||||
filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9 '" text.
|
||||
"""
|
||||
# https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
|
||||
ref = cv2.imread(filename)
|
||||
ref = cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)
|
||||
ref = cv2.threshold(ref, 10, 255, cv2.THRESH_BINARY_INV)[1]
|
||||
refCnts = cv2.findContours(ref.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
refCnts = imutils.grab_contours(refCnts)
|
||||
refCnts = contours.sort_contours(refCnts, method="left-to-right")[0]
|
||||
digits = {}
|
||||
keys = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, "'"]
|
||||
for key, cnt in zip(keys, refCnts):
|
||||
(x, y, w, h) = cv2.boundingRect(cnt)
|
||||
roi = ref[y : y + h, x : x + w]
|
||||
digits[key] = roi
|
||||
return list(digits.values())
|
||||
|
||||
|
||||
def process_default(img_path: str):
|
||||
template_res = load_template_image(img_path)
|
||||
template_res_pickled = [
|
||||
base64.b64encode(
|
||||
pickle.dumps(template_arr, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
).decode("utf-8")
|
||||
for template_arr in template_res
|
||||
]
|
||||
return json.dumps(template_res_pickled)
|
||||
|
||||
|
||||
def process_eroded(img_path: str):
|
||||
kernel = numpy.ones((5, 5), numpy.uint8)
|
||||
template_res = load_template_image(img_path)
|
||||
template_res_eroded = []
|
||||
# cv2.imshow("orig", template_res[7])
|
||||
for template in template_res:
|
||||
# add borders
|
||||
template = cv2.copyMakeBorder(
|
||||
template, 10, 10, 10, 10, cv2.BORDER_CONSTANT, None, (0, 0, 0)
|
||||
)
|
||||
# erode
|
||||
template = cv2.erode(template, kernel)
|
||||
# remove borders
|
||||
h, w = template.shape
|
||||
template = template[10 : h - 10, 10 : w - 10]
|
||||
template_res_eroded.append(template)
|
||||
# cv2.imshow("erode", template_res_eroded[7])
|
||||
# cv2.waitKey(0)
|
||||
template_res_pickled = [
|
||||
base64.b64encode(
|
||||
pickle.dumps(template_arr, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
).decode("utf-8")
|
||||
for template_arr in template_res_eroded
|
||||
]
|
||||
return json.dumps(template_res_pickled)
|
||||
|
||||
|
||||
TEMPLATES = [
|
||||
(
|
||||
"DEFAULT_REGULAR",
|
||||
"./assets/templates/GeoSansLightRegular.png",
|
||||
process_default,
|
||||
),
|
||||
(
|
||||
"DEFAULT_ITALIC",
|
||||
"./assets/templates/GeoSansLightItalic.png",
|
||||
process_default,
|
||||
),
|
||||
(
|
||||
"DEFAULT_REGULAR_ERODED",
|
||||
"./assets/templates/GeoSansLightRegular.png",
|
||||
process_eroded,
|
||||
),
|
||||
(
|
||||
"DEFAULT_ITALIC_ERODED",
|
||||
"./assets/templates/GeoSansLightItalic.png",
|
||||
process_eroded,
|
||||
),
|
||||
]
|
||||
|
||||
OUTPUT_FILE = "_builtin_templates.py"
|
||||
output = ""
|
||||
|
||||
for name, img_path, process_func in TEMPLATES:
|
||||
output += f"{name} = {process_func(img_path)}"
|
||||
output += "\n"
|
||||
|
||||
with open(OUTPUT_FILE, "w", encoding="utf-8") as of:
|
||||
of.write(output)
|
File diff suppressed because one or more lines are too long
@ -1,137 +1,30 @@
|
||||
import re
|
||||
from typing import Dict, List
|
||||
from typing import Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from imutils import grab_contours
|
||||
from imutils import resize as imutils_resize
|
||||
from numpy.linalg import norm
|
||||
from pytesseract import image_to_string
|
||||
|
||||
from .template import (
|
||||
MatchTemplateMultipleResult,
|
||||
TemplateItem,
|
||||
load_builtin_digit_template,
|
||||
matchTemplateMultiple,
|
||||
)
|
||||
from .crop import crop_xywh
|
||||
from .mask import mask_byd, mask_ftr, mask_prs, mask_pst
|
||||
from .types import Mat, cv2_ml_KNearest
|
||||
|
||||
__all__ = [
|
||||
"group_numbers",
|
||||
"FilterDigitResultDict",
|
||||
"filter_digit_results",
|
||||
"ocr_digits",
|
||||
"ocr_pure",
|
||||
"ocr_far_lost",
|
||||
"ocr_score",
|
||||
"ocr_max_recall",
|
||||
"ocr_rating_class",
|
||||
"ocr_title",
|
||||
"preprocess_hog",
|
||||
"ocr_digits_by_contour_samples",
|
||||
"ocr_digits_by_contour_knn",
|
||||
]
|
||||
|
||||
|
||||
def group_numbers(numbers: List[int], threshold: int) -> List[List[int]]:
|
||||
"""
|
||||
```
|
||||
numbers = [26, 189, 303, 348, 32, 195, 391, 145, 77]
|
||||
group_numbers(numbers, 10) -> [[26, 32], [77], [145], [189, 195], [303], [348], [391]]
|
||||
group_numbers(numbers, 5) -> [[26], [32], [77], [145], [189], [195], [303], [348], [391]]
|
||||
group_numbers(numbers, 50) -> [[26, 32, 77], [145, 189, 195], [303, 348, 391]]
|
||||
# from Bing AI
|
||||
```
|
||||
"""
|
||||
numbers.sort()
|
||||
# Initialize an empty list of groups
|
||||
groups = []
|
||||
# Initialize an empty list for the current group
|
||||
group = []
|
||||
# Loop through the numbers
|
||||
for number in numbers:
|
||||
# If the current group is empty or the number is within the threshold of the last number in the group
|
||||
if not group or number - group[-1] <= threshold:
|
||||
# Append the number to the current group
|
||||
group.append(number)
|
||||
# Otherwise
|
||||
else:
|
||||
# Append the current group to the list of groups
|
||||
groups.append(group)
|
||||
# Start a new group with the number
|
||||
group = [number]
|
||||
# Append the last group to the list of groups
|
||||
groups.append(group)
|
||||
# Return the list of groups
|
||||
return groups
|
||||
|
||||
|
||||
class FilterDigitResultDict(MatchTemplateMultipleResult):
|
||||
digit: int
|
||||
|
||||
|
||||
def filter_digit_results(
|
||||
results: Dict[int, List[MatchTemplateMultipleResult]], threshold: int
|
||||
):
|
||||
result_sorted_by_x_pos: Dict[
|
||||
int, List[FilterDigitResultDict]
|
||||
] = {} # Dictionary to store results sorted by x-position
|
||||
|
||||
# Iterate over each digit and its match results
|
||||
for digit, match_results in results.items():
|
||||
if match_results:
|
||||
# Iterate over each match result
|
||||
for result in match_results:
|
||||
x_pos = result["xywh"][0] # Extract x-position from result
|
||||
_dict = {**result, "digit": digit} # Add digit information to result
|
||||
|
||||
# Store result in result_sorted_by_x_pos dictionary
|
||||
if result_sorted_by_x_pos.get(x_pos) is None:
|
||||
result_sorted_by_x_pos[x_pos] = [_dict]
|
||||
else:
|
||||
result_sorted_by_x_pos[x_pos].append(_dict)
|
||||
|
||||
x_poses_grouped: List[List[int]] = group_numbers(
|
||||
list(result_sorted_by_x_pos), threshold
|
||||
) # Group x-positions based on threshold
|
||||
|
||||
final_result: Dict[
|
||||
int, List[MatchTemplateMultipleResult]
|
||||
] = {} # Dictionary to store final filtered results
|
||||
|
||||
# Iterate over each group of x-positions
|
||||
for x_poses in x_poses_grouped:
|
||||
possible_results = []
|
||||
# Iterate over each x-position in the group
|
||||
for x_pos in x_poses:
|
||||
# Retrieve all results associated with the x-position
|
||||
possible_results.extend(result_sorted_by_x_pos.get(x_pos, []))
|
||||
|
||||
if possible_results:
|
||||
# Sort the results based on "max_val" in descending order and select the top result
|
||||
result = sorted(
|
||||
possible_results,
|
||||
key=lambda d: d["max_val"],
|
||||
reverse=True,
|
||||
)[0]
|
||||
result_digit = result["digit"] # Get the digit value from the result
|
||||
result.pop("digit", None) # Remove the digit key from the result
|
||||
|
||||
# Store the result in the final_result dictionary
|
||||
if final_result.get(result_digit) is None:
|
||||
final_result[result_digit] = [result]
|
||||
else:
|
||||
final_result[result_digit].append(result)
|
||||
|
||||
return final_result
|
||||
|
||||
|
||||
def preprocess_hog(digits):
|
||||
def preprocess_hog(digit_rois):
|
||||
# https://github.com/opencv/opencv/blob/f834736307c8328340aea48908484052170c9224/samples/python/digits.py
|
||||
samples = []
|
||||
for img in digits:
|
||||
gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
|
||||
gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
|
||||
for digit in digit_rois:
|
||||
gx = cv2.Sobel(digit, cv2.CV_32F, 1, 0)
|
||||
gy = cv2.Sobel(digit, cv2.CV_32F, 0, 1)
|
||||
mag, ang = cv2.cartToPolar(gx, gy)
|
||||
bin_n = 16
|
||||
bin = np.int32(bin_n * ang / (2 * np.pi))
|
||||
bin_cells = bin[:10, :10], bin[10:, :10], bin[:10, 10:], bin[10:, 10:]
|
||||
_bin = np.int32(bin_n * ang / (2 * np.pi))
|
||||
bin_cells = _bin[:10, :10], _bin[10:, :10], _bin[:10, 10:], _bin[10:, 10:]
|
||||
mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:]
|
||||
hists = [
|
||||
np.bincount(b.ravel(), m.ravel(), bin_n)
|
||||
@ -149,125 +42,32 @@ def preprocess_hog(digits):
|
||||
return np.float32(samples)
|
||||
|
||||
|
||||
def ocr_digits(
|
||||
img: Mat,
|
||||
templates: TemplateItem,
|
||||
template_threshold: float,
|
||||
filter_threshold: int,
|
||||
):
|
||||
templates = dict(enumerate(templates[:10]))
|
||||
results: Dict[int, List[MatchTemplateMultipleResult]] = {}
|
||||
for digit, template in templates.items():
|
||||
template = imutils_resize(template, height=img.shape[0])
|
||||
results[digit] = matchTemplateMultiple(img, template, template_threshold)
|
||||
results = filter_digit_results(results, filter_threshold)
|
||||
result_x_digit_map = {}
|
||||
for digit, match_results in results.items():
|
||||
if match_results:
|
||||
for result in match_results:
|
||||
result_x_digit_map[result["xywh"][0]] = digit
|
||||
digits_sorted_by_x = dict(sorted(result_x_digit_map.items()))
|
||||
joined_str = "".join([str(digit) for digit in digits_sorted_by_x.values()])
|
||||
return int(joined_str) if joined_str else None
|
||||
def ocr_digits_by_contour_samples(__roi_gray: Mat, size: Tuple[int, int]):
|
||||
roi = __roi_gray.copy()
|
||||
contours = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
||||
rects = sorted([cv2.boundingRect(c) for c in contours], key=lambda r: r[0])
|
||||
digit_rois = [cv2.resize(crop_xywh(roi, rect), size) for rect in rects]
|
||||
return preprocess_hog(digit_rois)
|
||||
|
||||
|
||||
def ocr_digits_knn_model(img_gray: Mat, knn_model: cv2_ml_KNearest):
|
||||
if img_gray.shape[:2] != (20, 20):
|
||||
img = cv2.resize(img_gray, [20, 20])
|
||||
else:
|
||||
img = img_gray.copy()
|
||||
|
||||
img = img.astype(np.float32)
|
||||
img = img.reshape([1, -1])
|
||||
retval, _, _, _ = knn_model.findNearest(img, 10)
|
||||
return int(retval)
|
||||
def ocr_digits_by_contour_knn(
|
||||
__roi_gray: Mat,
|
||||
knn_model: cv2_ml_KNearest,
|
||||
*,
|
||||
k=4,
|
||||
size: Tuple[int, int] = (20, 20),
|
||||
) -> int:
|
||||
samples = ocr_digits_by_contour_samples(__roi_gray, size)
|
||||
_, results, _, _ = knn_model.findNearest(samples, k)
|
||||
results = [str(int(i)) for i in results.ravel()]
|
||||
return int("".join(results))
|
||||
|
||||
|
||||
def ocr_pure(img_masked: Mat):
|
||||
template = load_builtin_digit_template("default")
|
||||
return ocr_digits(
|
||||
img_masked, template.regular, template_threshold=0.6, filter_threshold=3
|
||||
)
|
||||
|
||||
|
||||
def ocr_far_lost(img_masked: Mat):
|
||||
template = load_builtin_digit_template("default")
|
||||
return ocr_digits(
|
||||
img_masked, template.italic, template_threshold=0.6, filter_threshold=3
|
||||
)
|
||||
|
||||
|
||||
def ocr_score(img_cropped: Mat):
|
||||
templates = load_builtin_digit_template("default").regular
|
||||
templates_dict = dict(enumerate(templates[:10]))
|
||||
|
||||
cnts = cv2.findContours(
|
||||
img_cropped.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
cnts = grab_contours(cnts)
|
||||
rects = [cv2.boundingRect(cnt) for cnt in cnts]
|
||||
rects = sorted(rects, key=lambda r: r[0])
|
||||
|
||||
# debug
|
||||
# [rectangle(img, rect, (128, 128, 128), 2) for rect in rects]
|
||||
# imshow("img", img)
|
||||
# waitKey(0)
|
||||
|
||||
result = ""
|
||||
for rect in rects:
|
||||
x, y, w, h = rect
|
||||
roi = img_cropped[y : y + h, x : x + w]
|
||||
|
||||
digit_results: Dict[int, float] = {}
|
||||
for digit, template in templates_dict.items():
|
||||
template = cv2.resize(template, roi.shape[::-1])
|
||||
template_result = cv2.matchTemplate(roi, template, cv2.TM_CCOEFF_NORMED)
|
||||
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(template_result)
|
||||
digit_results[digit] = max_val
|
||||
|
||||
digit_results = {k: v for k, v in digit_results.items() if v > 0.5}
|
||||
if digit_results:
|
||||
best_match_digit = max(digit_results, key=digit_results.get)
|
||||
result += str(best_match_digit)
|
||||
return int(result) if result else None
|
||||
|
||||
# return ocr_digits(
|
||||
# img_cropped,
|
||||
# template.regular,
|
||||
# template_threshold=0.5,
|
||||
# filter_threshold=10,
|
||||
# )
|
||||
|
||||
|
||||
def ocr_max_recall(img_cropped: Mat):
|
||||
try:
|
||||
texts = image_to_string(img_cropped).split(" ") # type: List[str]
|
||||
texts.reverse()
|
||||
for text in texts:
|
||||
if re.match(r"^[0-9]+$", text):
|
||||
return int(text)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def ocr_rating_class(img_cropped: Mat):
|
||||
try:
|
||||
text = image_to_string(img_cropped) # type: str
|
||||
text = text.lower()
|
||||
if "past" in text:
|
||||
return 0
|
||||
elif "present" in text:
|
||||
return 1
|
||||
elif "future" in text:
|
||||
return 2
|
||||
elif "beyond" in text:
|
||||
return 3
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def ocr_title(img_cropped: Mat):
|
||||
try:
|
||||
return image_to_string(img_cropped).replace("\n", "")
|
||||
except Exception as e:
|
||||
return ""
|
||||
def ocr_rating_class(roi_hsv: Mat):
|
||||
mask_results = [
|
||||
mask_pst(roi_hsv),
|
||||
mask_prs(roi_hsv),
|
||||
mask_ftr(roi_hsv),
|
||||
mask_byd(roi_hsv),
|
||||
]
|
||||
return max(enumerate(mask_results), key=lambda e: np.count_nonzero(e[1]))[0]
|
||||
|
@ -1,168 +0,0 @@
|
||||
import pickle
|
||||
from base64 import b64decode
|
||||
from typing import Any, Dict, List, Literal, Tuple, TypedDict, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ._builtin_templates import (
|
||||
DEFAULT_ITALIC,
|
||||
DEFAULT_ITALIC_ERODED,
|
||||
DEFAULT_REGULAR,
|
||||
DEFAULT_REGULAR_ERODED,
|
||||
)
|
||||
from .types import Mat
|
||||
|
||||
__all__ = [
|
||||
"TemplateItem",
|
||||
"DigitTemplate",
|
||||
"load_builtin_digit_template",
|
||||
"MatchTemplateMultipleResult",
|
||||
"matchTemplateMultiple",
|
||||
]
|
||||
|
||||
# a list of Mat showing following characters:
|
||||
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ']
|
||||
TemplateItem = Union[List[Mat], Tuple[Mat]]
|
||||
|
||||
|
||||
class DigitTemplate:
|
||||
__slots__ = ["regular", "italic", "regular_eroded", "italic_eroded"]
|
||||
|
||||
regular: TemplateItem
|
||||
italic: TemplateItem
|
||||
regular_eroded: TemplateItem
|
||||
italic_eroded: TemplateItem
|
||||
|
||||
def __ensure_template_item(self, item):
|
||||
return (
|
||||
isinstance(item, (list, tuple))
|
||||
and len(item) == 11
|
||||
and all(isinstance(i, np.ndarray) for i in item)
|
||||
)
|
||||
|
||||
def __init__(self, regular, italic, regular_eroded, italic_eroded):
|
||||
self.regular = regular
|
||||
self.italic = italic
|
||||
self.regular_eroded = regular_eroded
|
||||
self.italic_eroded = italic_eroded
|
||||
|
||||
def __setattr__(self, __name: str, __value: Any):
|
||||
if __name in {
|
||||
"regular",
|
||||
"italic",
|
||||
"regular_eroded",
|
||||
"italic_eroded",
|
||||
} and self.__ensure_template_item(__value):
|
||||
super().__setattr__(__name, __value)
|
||||
return
|
||||
|
||||
raise ValueError(
|
||||
"Invalid attribute set, expected type TemplateItem or invalid attribute name."
|
||||
)
|
||||
|
||||
|
||||
def load_builtin_digit_template(name: Literal["default"]) -> DigitTemplate:
|
||||
CONSTANTS = {
|
||||
"default": [
|
||||
DEFAULT_REGULAR,
|
||||
DEFAULT_ITALIC,
|
||||
DEFAULT_REGULAR_ERODED,
|
||||
DEFAULT_ITALIC_ERODED,
|
||||
]
|
||||
}
|
||||
args = CONSTANTS[name]
|
||||
args = [
|
||||
[pickle.loads(b64decode(encoded_str)) for encoded_str in arg] for arg in args
|
||||
]
|
||||
|
||||
return DigitTemplate(*args)
|
||||
|
||||
|
||||
class MatchTemplateMultipleResult(TypedDict):
|
||||
max_val: float
|
||||
xywh: Tuple[int, int, int, int]
|
||||
|
||||
|
||||
def matchTemplateMultiple(
|
||||
src: Mat, template: Mat, threshold: float = 0.1
|
||||
) -> List[MatchTemplateMultipleResult]:
|
||||
template_result = cv2.matchTemplate(src, template, cv2.TM_CCOEFF_NORMED)
|
||||
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(template_result)
|
||||
template_h, template_w = template.shape[:2]
|
||||
results = []
|
||||
|
||||
# debug
|
||||
# imshow("templ", template)
|
||||
# waitKey(750)
|
||||
# destroyAllWindows()
|
||||
|
||||
# https://stackoverflow.com/a/66848923/16484891
|
||||
# CC BY-SA 4.0
|
||||
prev_min_val, prev_max_val, prev_min_loc, prev_max_loc = None, None, None, None
|
||||
while max_val > threshold:
|
||||
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(template_result)
|
||||
|
||||
# Prevent infinite loop. If those 4 values are the same as previous ones, break the loop.
|
||||
if (
|
||||
prev_min_val == min_val
|
||||
and prev_max_val == max_val
|
||||
and prev_min_loc == min_loc
|
||||
and prev_max_loc == max_loc
|
||||
):
|
||||
break
|
||||
else:
|
||||
prev_min_val, prev_max_val, prev_min_loc, prev_max_loc = (
|
||||
min_val,
|
||||
max_val,
|
||||
min_loc,
|
||||
max_loc,
|
||||
)
|
||||
|
||||
if max_val > threshold:
|
||||
# Prevent start_row, end_row, start_col, end_col be out of range of image
|
||||
start_row = max(0, max_loc[1] - template_h // 2)
|
||||
start_col = max(0, max_loc[0] - template_w // 2)
|
||||
end_row = min(template_result.shape[0], max_loc[1] + template_h // 2 + 1)
|
||||
end_col = min(template_result.shape[1], max_loc[0] + template_w // 2 + 1)
|
||||
|
||||
template_result[start_row:end_row, start_col:end_col] = 0
|
||||
results.append(
|
||||
{
|
||||
"max_val": max_val,
|
||||
"xywh": (
|
||||
max_loc[0],
|
||||
max_loc[1],
|
||||
max_loc[0] + template_w + 1,
|
||||
max_loc[1] + template_h + 1,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# debug
|
||||
# src_dbg = cvtColor(src, COLOR_GRAY2BGR)
|
||||
# src_dbg = rectangle(
|
||||
# src_dbg,
|
||||
# (max_loc[0], max_loc[1]),
|
||||
# (
|
||||
# max_loc[0] + template_w + 1,
|
||||
# max_loc[1] + template_h + 1,
|
||||
# ),
|
||||
# (0, 255, 0),
|
||||
# thickness=3,
|
||||
# )
|
||||
# src_dbg = putText(
|
||||
# src_dbg,
|
||||
# f"{max_val:.5f}",
|
||||
# (5, src_dbg.shape[0] - 5),
|
||||
# FONT_HERSHEY_SIMPLEX,
|
||||
# 1,
|
||||
# (0, 255, 0),
|
||||
# thickness=2,
|
||||
# )
|
||||
# imshow("src_rect", src_dbg)
|
||||
# imshow("templ", template)
|
||||
# waitKey(750)
|
||||
# destroyAllWindows()
|
||||
|
||||
return results
|
Loading…
x
Reference in New Issue
Block a user