wip!: template rewrite & new ocr method for score

This commit is contained in:
283375 2023-06-21 02:53:53 +08:00
parent 0fe5f09de0
commit c76b656f3d
4 changed files with 219 additions and 69 deletions

View File

@ -1,25 +1,100 @@
import base64 import base64
import json import json
import pickle
import cv2 import cv2
from src.arcaea_offline_ocr.template import load_digit_template import imutils
import numpy
from imutils import contours
def load_template_image(filename: str) -> dict[int, cv2.Mat]:
"""
Arguments:
filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9 '" text.
"""
# https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
ref = cv2.imread(filename)
ref = cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)
ref = cv2.threshold(ref, 10, 255, cv2.THRESH_BINARY_INV)[1]
refCnts = cv2.findContours(ref.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
refCnts = imutils.grab_contours(refCnts)
refCnts = contours.sort_contours(refCnts, method="left-to-right")[0]
digits = {}
keys = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, "'"]
for key, cnt in zip(keys, refCnts):
(x, y, w, h) = cv2.boundingRect(cnt)
roi = ref[y : y + h, x : x + w]
digits[key] = roi
return list(digits.values())
def process_default(img_path: str):
template_res = load_template_image(img_path)
template_res_pickled = [
base64.b64encode(
pickle.dumps(template_arr, protocol=pickle.HIGHEST_PROTOCOL)
).decode("utf-8")
for template_arr in template_res
]
return json.dumps(template_res_pickled)
def process_eroded(img_path: str):
kernel = numpy.ones((5, 5), numpy.uint8)
template_res = load_template_image(img_path)
template_res_eroded = []
# cv2.imshow("orig", template_res[7])
for template in template_res:
# add borders
template = cv2.copyMakeBorder(
template, 10, 10, 10, 10, cv2.BORDER_CONSTANT, None, (0, 0, 0)
)
# erode
template = cv2.erode(template, kernel)
# remove borders
h, w = template.shape
template = template[10 : h - 10, 10 : w - 10]
template_res_eroded.append(template)
# cv2.imshow("erode", template_res_eroded[7])
# cv2.waitKey(0)
template_res_pickled = [
base64.b64encode(
pickle.dumps(template_arr, protocol=pickle.HIGHEST_PROTOCOL)
).decode("utf-8")
for template_arr in template_res_eroded
]
return json.dumps(template_res_pickled)
TEMPLATES = [ TEMPLATES = [
("GeoSansLight_Regular", "./assets/templates/GeoSansLightRegular.png"), (
("GeoSansLight_Italic", "./assets/templates/GeoSansLightItalic.png"), "DEFAULT_REGULAR",
"./assets/templates/GeoSansLightRegular.png",
process_default,
),
(
"DEFAULT_ITALIC",
"./assets/templates/GeoSansLightItalic.png",
process_default,
),
(
"DEFAULT_REGULAR_ERODED",
"./assets/templates/GeoSansLightRegular.png",
process_eroded,
),
(
"DEFAULT_ITALIC_ERODED",
"./assets/templates/GeoSansLightItalic.png",
process_eroded,
),
] ]
OUTPUT_FILE = "_builtin_templates.py" OUTPUT_FILE = "_builtin_templates.py"
output = "" output = ""
for name, file in TEMPLATES: for name, img_path, process_func in TEMPLATES:
template_res = load_digit_template(file) output += f"{name} = {process_func(img_path)}"
template_res_b64 = {
key: base64.b64encode(cv2.imencode(".png", template_img)[1]).decode("utf-8")
for key, template_img in template_res.items()
}
# jpg_as_text = base64.b64encode(buffer)
output += f"{name} = {json.dumps(template_res_b64)}"
output += "\n" output += "\n"
with open(OUTPUT_FILE, "w", encoding="utf-8") as of: with open(OUTPUT_FILE, "w", encoding="utf-8") as of:

File diff suppressed because one or more lines are too long

View File

@ -1,12 +1,27 @@
import re import re
from typing import Dict, List from typing import Dict, List
from cv2 import Mat from cv2 import (
from imutils import resize CHAIN_APPROX_SIMPLE,
RETR_EXTERNAL,
TM_CCOEFF_NORMED,
Mat,
boundingRect,
findContours,
imshow,
matchTemplate,
minMaxLoc,
rectangle,
resize,
waitKey,
)
from imutils import grab_contours
from imutils import resize as imutils_resize
from pytesseract import image_to_string from pytesseract import image_to_string
from .template import ( from .template import (
MatchTemplateMultipleResult, MatchTemplateMultipleResult,
TemplateItem,
load_builtin_digit_template, load_builtin_digit_template,
matchTemplateMultiple, matchTemplateMultiple,
) )
@ -120,13 +135,14 @@ def filter_digit_results(
def ocr_digits( def ocr_digits(
img: Mat, img: Mat,
templates: Dict[int, Mat], templates: TemplateItem,
template_threshold: float, template_threshold: float,
filter_threshold: int, filter_threshold: int,
): ):
templates = dict(enumerate(templates[:10]))
results: Dict[int, List[MatchTemplateMultipleResult]] = {} results: Dict[int, List[MatchTemplateMultipleResult]] = {}
for digit, template in templates.items(): for digit, template in templates.items():
template = resize(template, height=img.shape[0]) template = imutils_resize(template, height=img.shape[0])
results[digit] = matchTemplateMultiple(img, template, template_threshold) results[digit] = matchTemplateMultiple(img, template, template_threshold)
results = filter_digit_results(results, filter_threshold) results = filter_digit_results(results, filter_threshold)
result_x_digit_map = {} result_x_digit_map = {}
@ -140,20 +156,57 @@ def ocr_digits(
def ocr_pure(img_masked: Mat): def ocr_pure(img_masked: Mat):
templates = load_builtin_digit_template("GeoSansLight-Regular") template = load_builtin_digit_template("default")
return ocr_digits(img_masked, templates, template_threshold=0.6, filter_threshold=3) return ocr_digits(
img_masked, template.regular, template_threshold=0.6, filter_threshold=3
)
def ocr_far_lost(img_masked: Mat): def ocr_far_lost(img_masked: Mat):
templates = load_builtin_digit_template("GeoSansLight-Italic") template = load_builtin_digit_template("default")
return ocr_digits(img_masked, templates, template_threshold=0.6, filter_threshold=3) return ocr_digits(
img_masked, template.italic, template_threshold=0.6, filter_threshold=3
)
def ocr_score(img_cropped: Mat): def ocr_score(img_cropped: Mat):
templates = load_builtin_digit_template("GeoSansLight-Regular") templates = load_builtin_digit_template("default").regular
return ocr_digits( templates_dict = dict(enumerate(templates[:10]))
img_cropped, templates, template_threshold=0.5, filter_threshold=10
) cnts = findContours(img_cropped.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
cnts = grab_contours(cnts)
rects = [boundingRect(cnt) for cnt in cnts]
rects = sorted(rects, key=lambda r: r[0])
# debug
# [rectangle(img, rect, (128, 128, 128), 2) for rect in rects]
# imshow("img", img)
# waitKey(0)
result = ""
for rect in rects:
x, y, w, h = rect
roi = img_cropped[y : y + h, x : x + w]
digit_results: Dict[int, float] = {}
for digit, template in templates_dict.items():
template = resize(template, roi.shape[::-1])
template_result = matchTemplate(roi, template, TM_CCOEFF_NORMED)
min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
digit_results[digit] = max_val
digit_results = {k: v for k, v in digit_results.items() if v > 0.5}
if digit_results:
best_match_digit = max(digit_results, key=digit_results.get)
result += str(best_match_digit)
return int(result) if result else None
# return ocr_digits(
# img_cropped,
# template.regular,
# template_threshold=0.5,
# filter_threshold=10,
# )
def ocr_max_recall(img_cropped: Mat): def ocr_max_recall(img_cropped: Mat):

View File

@ -1,6 +1,6 @@
import pickle
from base64 import b64decode from base64 import b64decode
from time import sleep from typing import Any, Dict, List, Literal, Tuple, TypedDict, Union
from typing import Dict, List, Literal, Tuple, TypedDict
from cv2 import ( from cv2 import (
CHAIN_APPROX_SIMPLE, CHAIN_APPROX_SIMPLE,
@ -26,55 +26,79 @@ from cv2 import (
threshold, threshold,
waitKey, waitKey,
) )
from imutils import contours, grab_contours from numpy import ndarray
from numpy import frombuffer as np_frombuffer
from numpy import uint8
from ._builtin_templates import GeoSansLight_Italic, GeoSansLight_Regular from ._builtin_templates import (
DEFAULT_ITALIC,
DEFAULT_ITALIC_ERODED,
DEFAULT_REGULAR,
DEFAULT_REGULAR_ERODED,
)
__all__ = [ __all__ = [
"load_digit_template", "TemplateItem",
"DigitTemplate",
"load_builtin_digit_template", "load_builtin_digit_template",
"MatchTemplateMultipleResult", "MatchTemplateMultipleResult",
"matchTemplateMultiple", "matchTemplateMultiple",
] ]
# a list of Mat showing following characters:
def load_digit_template(filename: str) -> Dict[int, Mat]: # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ']
""" TemplateItem = Union[List[Mat], Tuple[Mat]]
Arguments:
filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9" text.
Returns:
dict[int, cv2.Mat]
"""
# https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
ref = imread(filename)
ref = cvtColor(ref, COLOR_BGR2GRAY)
ref = threshold(ref, 10, 255, THRESH_BINARY_INV)[1]
refCnts = findContours(ref.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
refCnts = grab_contours(refCnts)
refCnts = contours.sort_contours(refCnts, method="left-to-right")[0]
digits = {}
for i, cnt in enumerate(refCnts):
(x, y, w, h) = boundingRect(cnt)
roi = ref[y : y + h, x : x + w]
digits[i] = roi
return digits
def load_builtin_digit_template( class DigitTemplate:
name: Literal["GeoSansLight-Regular", "GeoSansLight-Italic"] __slots__ = ["regular", "italic", "regular_eroded", "italic_eroded"]
):
name_builtin_template_b64_map = { regular: TemplateItem
"GeoSansLight-Regular": GeoSansLight_Regular, italic: TemplateItem
"GeoSansLight-Italic": GeoSansLight_Italic, regular_eroded: TemplateItem
} italic_eroded: TemplateItem
template_b64 = name_builtin_template_b64_map[name]
return { def __ensure_template_item(self, item):
int(key): imdecode(np_frombuffer(b64decode(b64str), uint8), IMREAD_GRAYSCALE) return (
for key, b64str in template_b64.items() isinstance(item, (list, tuple))
and len(item) == 11
and all(isinstance(i, ndarray) for i in item)
)
def __init__(self, regular, italic, regular_eroded, italic_eroded):
self.regular = regular
self.italic = italic
self.regular_eroded = regular_eroded
self.italic_eroded = italic_eroded
def __setattr__(self, __name: str, __value: Any):
if __name in {
"regular",
"italic",
"regular_eroded",
"italic_eroded",
} and self.__ensure_template_item(__value):
super().__setattr__(__name, __value)
return
raise ValueError(
"Invalid attribute set, expected type TemplateItem or invalid attribute name."
)
def load_builtin_digit_template(name: Literal["default"]) -> DigitTemplate:
CONSTANTS = {
"default": [
DEFAULT_REGULAR,
DEFAULT_ITALIC,
DEFAULT_REGULAR_ERODED,
DEFAULT_ITALIC_ERODED,
]
} }
args = CONSTANTS[name]
args = [
[pickle.loads(b64decode(encoded_str)) for encoded_str in arg] for arg in args
]
return DigitTemplate(*args)
class MatchTemplateMultipleResult(TypedDict): class MatchTemplateMultipleResult(TypedDict):
@ -85,10 +109,6 @@ class MatchTemplateMultipleResult(TypedDict):
def matchTemplateMultiple( def matchTemplateMultiple(
src: Mat, template: Mat, threshold: float = 0.1 src: Mat, template: Mat, threshold: float = 0.1
) -> List[MatchTemplateMultipleResult]: ) -> List[MatchTemplateMultipleResult]:
"""
Returns:
A list of tuple[x, y, w, h] representing the matched rectangle
"""
template_result = matchTemplate(src, template, TM_CCOEFF_NORMED) template_result = matchTemplate(src, template, TM_CCOEFF_NORMED)
min_val, max_val, min_loc, max_loc = minMaxLoc(template_result) min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
template_h, template_w = template.shape[:2] template_h, template_w = template.shape[:2]