wip!: template rewrite & new ocr method for score

This commit is contained in:
283375 2023-06-21 02:53:53 +08:00
parent 0fe5f09de0
commit c76b656f3d
4 changed files with 219 additions and 69 deletions

View File

@ -1,25 +1,100 @@
import base64
import json
import pickle
import cv2
from src.arcaea_offline_ocr.template import load_digit_template
import imutils
import numpy
from imutils import contours
def load_template_image(filename: str) -> list[cv2.Mat]:
    """
    Cut a template sheet into one binary image per character.

    Arguments:
        filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9 '" text.

    Returns:
        list[cv2.Mat] -- The cropped glyphs in left-to-right order (at most 11),
        taken from the inverse-thresholded sheet.

    Raises:
        FileNotFoundError -- If the image cannot be read.
    """
    # https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
    ref = cv2.imread(filename)
    if ref is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        raise FileNotFoundError(f"Cannot read template image: {filename}")
    ref = cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)
    ref = cv2.threshold(ref, 10, 255, cv2.THRESH_BINARY_INV)[1]

    ref_cnts = cv2.findContours(ref.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    ref_cnts = imutils.grab_contours(ref_cnts)
    ref_cnts = contours.sort_contours(ref_cnts, method="left-to-right")[0]

    # One slot per character on the sheet; contours beyond these are ignored.
    characters = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, "'")
    glyphs = []
    for _, cnt in zip(characters, ref_cnts):
        x, y, w, h = cv2.boundingRect(cnt)
        glyphs.append(ref[y : y + h, x : x + w])
    return glyphs
def process_default(img_path: str):
    """
    Serialise the glyphs cut from `img_path` without further processing.

    Returns:
        A JSON array of strings, each the base64 text of one pickled glyph.
    """
    encoded = []
    for glyph in load_template_image(img_path):
        raw = pickle.dumps(glyph, protocol=pickle.HIGHEST_PROTOCOL)
        encoded.append(base64.b64encode(raw).decode("utf-8"))
    return json.dumps(encoded)
def process_eroded(img_path: str):
    """
    Serialise an eroded copy of every glyph cut from `img_path`.

    Each glyph is padded with a 10px black border, eroded with a 5x5 kernel
    of ones, and cropped back to its original size.

    Returns:
        A JSON array of strings, each the base64 text of one pickled glyph.
    """
    pad = 10
    kernel = numpy.ones((5, 5), numpy.uint8)

    def erode_glyph(glyph):
        # Pad first so the erosion sees black pixels around the glyph.
        padded = cv2.copyMakeBorder(
            glyph, pad, pad, pad, pad, cv2.BORDER_CONSTANT, None, (0, 0, 0)
        )
        eroded = cv2.erode(padded, kernel)
        rows, cols = eroded.shape
        return eroded[pad : rows - pad, pad : cols - pad]

    eroded_glyphs = [erode_glyph(glyph) for glyph in load_template_image(img_path)]

    encoded = []
    for glyph in eroded_glyphs:
        raw = pickle.dumps(glyph, protocol=pickle.HIGHEST_PROTOCOL)
        encoded.append(base64.b64encode(raw).decode("utf-8"))
    return json.dumps(encoded)
# Each entry: (constant name written to the output module,
#              template sheet image, function that serialises the sheet).
TEMPLATES = [
    (
        "DEFAULT_REGULAR",
        "./assets/templates/GeoSansLightRegular.png",
        process_default,
    ),
    (
        "DEFAULT_ITALIC",
        "./assets/templates/GeoSansLightItalic.png",
        process_default,
    ),
    (
        "DEFAULT_REGULAR_ERODED",
        "./assets/templates/GeoSansLightRegular.png",
        process_eroded,
    ),
    (
        "DEFAULT_ITALIC_ERODED",
        "./assets/templates/GeoSansLightItalic.png",
        process_eroded,
    ),
]

OUTPUT_FILE = "_builtin_templates.py"

# One `NAME = <serialised templates>` assignment per entry, newline-terminated.
output = "".join(
    f"{name} = {process_func(img_path)}\n"
    for name, img_path, process_func in TEMPLATES
)
with open(OUTPUT_FILE, "w", encoding="utf-8") as of:

File diff suppressed because one or more lines are too long

View File

@ -1,12 +1,27 @@
import re
from typing import Dict, List
from cv2 import Mat
from imutils import resize
from cv2 import (
CHAIN_APPROX_SIMPLE,
RETR_EXTERNAL,
TM_CCOEFF_NORMED,
Mat,
boundingRect,
findContours,
imshow,
matchTemplate,
minMaxLoc,
rectangle,
resize,
waitKey,
)
from imutils import grab_contours
from imutils import resize as imutils_resize
from pytesseract import image_to_string
from .template import (
MatchTemplateMultipleResult,
TemplateItem,
load_builtin_digit_template,
matchTemplateMultiple,
)
@ -120,13 +135,14 @@ def filter_digit_results(
def ocr_digits(
img: Mat,
templates: Dict[int, Mat],
templates: TemplateItem,
template_threshold: float,
filter_threshold: int,
):
templates = dict(enumerate(templates[:10]))
results: Dict[int, List[MatchTemplateMultipleResult]] = {}
for digit, template in templates.items():
template = resize(template, height=img.shape[0])
template = imutils_resize(template, height=img.shape[0])
results[digit] = matchTemplateMultiple(img, template, template_threshold)
results = filter_digit_results(results, filter_threshold)
result_x_digit_map = {}
@ -140,20 +156,57 @@ def ocr_digits(
def ocr_pure(img_masked: Mat):
    """
    Run `ocr_digits` on `img_masked` with the regular builtin templates.

    Uses template_threshold=0.6 and filter_threshold=3 and returns the
    result of `ocr_digits` unchanged.
    """
    template = load_builtin_digit_template("default")
    return ocr_digits(
        img_masked, template.regular, template_threshold=0.6, filter_threshold=3
    )
def ocr_far_lost(img_masked: Mat):
    """
    Run `ocr_digits` on `img_masked` with the italic builtin templates.

    Uses template_threshold=0.6 and filter_threshold=3 and returns the
    result of `ocr_digits` unchanged.
    """
    template = load_builtin_digit_template("default")
    return ocr_digits(
        img_masked, template.italic, template_threshold=0.6, filter_threshold=3
    )
def ocr_score(img_cropped: Mat):
    """
    Read a number from `img_cropped` by matching each contour against the
    regular builtin digit templates.

    The bounding box of every external contour is treated as one candidate
    character and processed left to right. Each digit template is stretched
    to the box size and compared with TM_CCOEFF_NORMED; the best-scoring
    digit is kept only when its score exceeds 0.5, otherwise the box
    contributes nothing.

    Arguments:
        img_cropped -- A 2-D (single-channel) image; its reversed shape is
            passed to `resize` as the (width, height) target.

    Returns:
        The recognised digits as an int, or None when no box matched a digit.
    """
    templates = load_builtin_digit_template("default").regular
    # Only the first ten templates are matched.
    digit_templates = dict(enumerate(templates[:10]))

    cnts = findContours(img_cropped.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
    cnts = grab_contours(cnts)
    # Sort the boxes by x so digits are read left to right.
    rects = sorted((boundingRect(cnt) for cnt in cnts), key=lambda r: r[0])

    recognised = []
    for x, y, w, h in rects:
        roi = img_cropped[y : y + h, x : x + w]

        scores: Dict[int, float] = {}
        for digit, template in digit_templates.items():
            # cv2.resize wants (width, height), the reverse of numpy's shape.
            stretched = resize(template, roi.shape[::-1])
            # Equal-sized inputs yield a 1x1 result, so max_val is the score.
            _, max_val, _, _ = minMaxLoc(
                matchTemplate(roi, stretched, TM_CCOEFF_NORMED)
            )
            scores[digit] = max_val

        scores = {digit: score for digit, score in scores.items() if score > 0.5}
        if scores:
            recognised.append(str(max(scores, key=scores.get)))

    result = "".join(recognised)
    return int(result) if result else None
def ocr_max_recall(img_cropped: Mat):

View File

@ -1,6 +1,6 @@
import pickle
from base64 import b64decode
from time import sleep
from typing import Dict, List, Literal, Tuple, TypedDict
from typing import Any, Dict, List, Literal, Tuple, TypedDict, Union
from cv2 import (
CHAIN_APPROX_SIMPLE,
@ -26,55 +26,79 @@ from cv2 import (
threshold,
waitKey,
)
from imutils import contours, grab_contours
from numpy import frombuffer as np_frombuffer
from numpy import uint8
from numpy import ndarray
from ._builtin_templates import GeoSansLight_Italic, GeoSansLight_Regular
from ._builtin_templates import (
DEFAULT_ITALIC,
DEFAULT_ITALIC_ERODED,
DEFAULT_REGULAR,
DEFAULT_REGULAR_ERODED,
)
# Names exported by a star-import of this module.
__all__ = [
"load_digit_template",
"TemplateItem",
"DigitTemplate",
"load_builtin_digit_template",
"MatchTemplateMultipleResult",
"matchTemplateMultiple",
]
def load_digit_template(filename: str) -> Dict[int, Mat]:
    """
    Split a digit sheet into one cropped binary image per contour.

    Arguments:
        filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9" text.

    Returns:
        dict[int, cv2.Mat] -- Left-to-right contour index mapped to its crop.
    """
    # https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
    gray = cvtColor(imread(filename), COLOR_BGR2GRAY)
    binary = threshold(gray, 10, 255, THRESH_BINARY_INV)[1]

    found = grab_contours(
        findContours(binary.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
    )
    ordered = contours.sort_contours(found, method="left-to-right")[0]

    def crop(cnt):
        x, y, w, h = boundingRect(cnt)
        return binary[y : y + h, x : x + w]

    return {index: crop(cnt) for index, cnt in enumerate(ordered)}
# A list or tuple of Mat, one per character, in this order:
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ']
# `Tuple[Mat, ...]` is the variable-length form; `Tuple[Mat]` would mean a 1-tuple.
TemplateItem = Union[List[Mat], Tuple[Mat, ...]]
def load_builtin_digit_template(
name: Literal["GeoSansLight-Regular", "GeoSansLight-Italic"]
):
name_builtin_template_b64_map = {
"GeoSansLight-Regular": GeoSansLight_Regular,
"GeoSansLight-Italic": GeoSansLight_Italic,
}
template_b64 = name_builtin_template_b64_map[name]
return {
int(key): imdecode(np_frombuffer(b64decode(b64str), uint8), IMREAD_GRAYSCALE)
for key, b64str in template_b64.items()
class DigitTemplate:
    """
    Holder for four template sets: regular, italic and their eroded variants.

    Every attribute must be a list or tuple of exactly 11 numpy arrays.
    Assigning anything else, or assigning to any other attribute name,
    raises ValueError.
    """

    __slots__ = ["regular", "italic", "regular_eroded", "italic_eroded"]

    # Quoted so the alias is only resolved by type checkers, not at class creation.
    regular: "TemplateItem"
    italic: "TemplateItem"
    regular_eroded: "TemplateItem"
    italic_eroded: "TemplateItem"

    def __ensure_template_item(self, item) -> bool:
        """Return True when `item` is a list/tuple of exactly 11 ndarrays."""
        return (
            isinstance(item, (list, tuple))
            and len(item) == 11
            and all(isinstance(i, ndarray) for i in item)
        )

    def __init__(self, regular, italic, regular_eroded, italic_eroded):
        # Each assignment goes through __setattr__ and is validated there.
        self.regular = regular
        self.italic = italic
        self.regular_eroded = regular_eroded
        self.italic_eroded = italic_eroded

    def __setattr__(self, name: str, value: Any):
        """
        Validate and store a template set.

        Raises:
            ValueError -- If `name` is not one of the four slots, or `value`
                is not a list/tuple of 11 numpy arrays.
        """
        if name not in DigitTemplate.__slots__:
            raise ValueError(
                f"Invalid attribute set, invalid attribute name {name!r}."
            )
        if not self.__ensure_template_item(value):
            raise ValueError(
                f"Invalid attribute set, expected type TemplateItem for {name!r} "
                "(a list or tuple of 11 numpy arrays)."
            )
        super().__setattr__(name, value)
def load_builtin_digit_template(name: Literal["default"]) -> DigitTemplate:
    """
    Decode one bundled template set into a `DigitTemplate`.

    Arguments:
        name -- Key of the bundled set; "default" is the only entry.

    Raises:
        KeyError -- If `name` is not a known set.
    """
    # Listed in DigitTemplate's positional order:
    # regular, italic, regular_eroded, italic_eroded.
    builtin_sets = {
        "default": [
            DEFAULT_REGULAR,
            DEFAULT_ITALIC,
            DEFAULT_REGULAR_ERODED,
            DEFAULT_ITALIC_ERODED,
        ]
    }
    encoded_sets = builtin_sets[name]

    decoded_sets = []
    for encoded_set in encoded_sets:
        # NOTE: pickle.loads runs on constants imported from ._builtin_templates;
        # it must never be pointed at data from outside the package.
        decoded_sets.append(
            [pickle.loads(b64decode(encoded_str)) for encoded_str in encoded_set]
        )
    return DigitTemplate(*decoded_sets)
class MatchTemplateMultipleResult(TypedDict):
@ -85,10 +109,6 @@ class MatchTemplateMultipleResult(TypedDict):
def matchTemplateMultiple(
src: Mat, template: Mat, threshold: float = 0.1
) -> List[MatchTemplateMultipleResult]:
"""
Returns:
A list of tuple[x, y, w, h] representing the matched rectangle
"""
template_result = matchTemplate(src, template, TM_CCOEFF_NORMED)
min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
template_h, template_w = template.shape[:2]