mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-04-19 05:20:17 +00:00
wip!: template rewrite & new ocr method for score
This commit is contained in:
parent
0fe5f09de0
commit
c76b656f3d
@ -1,25 +1,100 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import pickle
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from src.arcaea_offline_ocr.template import load_digit_template
|
import imutils
|
||||||
|
import numpy
|
||||||
|
from imutils import contours
|
||||||
|
|
||||||
|
|
||||||
|
def load_template_image(filename: str) -> dict[int, cv2.Mat]:
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9 '" text.
|
||||||
|
"""
|
||||||
|
# https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
|
||||||
|
ref = cv2.imread(filename)
|
||||||
|
ref = cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)
|
||||||
|
ref = cv2.threshold(ref, 10, 255, cv2.THRESH_BINARY_INV)[1]
|
||||||
|
refCnts = cv2.findContours(ref.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
refCnts = imutils.grab_contours(refCnts)
|
||||||
|
refCnts = contours.sort_contours(refCnts, method="left-to-right")[0]
|
||||||
|
digits = {}
|
||||||
|
keys = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, "'"]
|
||||||
|
for key, cnt in zip(keys, refCnts):
|
||||||
|
(x, y, w, h) = cv2.boundingRect(cnt)
|
||||||
|
roi = ref[y : y + h, x : x + w]
|
||||||
|
digits[key] = roi
|
||||||
|
return list(digits.values())
|
||||||
|
|
||||||
|
|
||||||
|
def process_default(img_path: str):
|
||||||
|
template_res = load_template_image(img_path)
|
||||||
|
template_res_pickled = [
|
||||||
|
base64.b64encode(
|
||||||
|
pickle.dumps(template_arr, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
).decode("utf-8")
|
||||||
|
for template_arr in template_res
|
||||||
|
]
|
||||||
|
return json.dumps(template_res_pickled)
|
||||||
|
|
||||||
|
|
||||||
|
def process_eroded(img_path: str):
|
||||||
|
kernel = numpy.ones((5, 5), numpy.uint8)
|
||||||
|
template_res = load_template_image(img_path)
|
||||||
|
template_res_eroded = []
|
||||||
|
# cv2.imshow("orig", template_res[7])
|
||||||
|
for template in template_res:
|
||||||
|
# add borders
|
||||||
|
template = cv2.copyMakeBorder(
|
||||||
|
template, 10, 10, 10, 10, cv2.BORDER_CONSTANT, None, (0, 0, 0)
|
||||||
|
)
|
||||||
|
# erode
|
||||||
|
template = cv2.erode(template, kernel)
|
||||||
|
# remove borders
|
||||||
|
h, w = template.shape
|
||||||
|
template = template[10 : h - 10, 10 : w - 10]
|
||||||
|
template_res_eroded.append(template)
|
||||||
|
# cv2.imshow("erode", template_res_eroded[7])
|
||||||
|
# cv2.waitKey(0)
|
||||||
|
template_res_pickled = [
|
||||||
|
base64.b64encode(
|
||||||
|
pickle.dumps(template_arr, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
).decode("utf-8")
|
||||||
|
for template_arr in template_res_eroded
|
||||||
|
]
|
||||||
|
return json.dumps(template_res_pickled)
|
||||||
|
|
||||||
|
|
||||||
TEMPLATES = [
|
TEMPLATES = [
|
||||||
("GeoSansLight_Regular", "./assets/templates/GeoSansLightRegular.png"),
|
(
|
||||||
("GeoSansLight_Italic", "./assets/templates/GeoSansLightItalic.png"),
|
"DEFAULT_REGULAR",
|
||||||
|
"./assets/templates/GeoSansLightRegular.png",
|
||||||
|
process_default,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"DEFAULT_ITALIC",
|
||||||
|
"./assets/templates/GeoSansLightItalic.png",
|
||||||
|
process_default,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"DEFAULT_REGULAR_ERODED",
|
||||||
|
"./assets/templates/GeoSansLightRegular.png",
|
||||||
|
process_eroded,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"DEFAULT_ITALIC_ERODED",
|
||||||
|
"./assets/templates/GeoSansLightItalic.png",
|
||||||
|
process_eroded,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
OUTPUT_FILE = "_builtin_templates.py"
|
OUTPUT_FILE = "_builtin_templates.py"
|
||||||
output = ""
|
output = ""
|
||||||
|
|
||||||
for name, file in TEMPLATES:
|
for name, img_path, process_func in TEMPLATES:
|
||||||
template_res = load_digit_template(file)
|
output += f"{name} = {process_func(img_path)}"
|
||||||
template_res_b64 = {
|
|
||||||
key: base64.b64encode(cv2.imencode(".png", template_img)[1]).decode("utf-8")
|
|
||||||
for key, template_img in template_res.items()
|
|
||||||
}
|
|
||||||
# jpg_as_text = base64.b64encode(buffer)
|
|
||||||
output += f"{name} = {json.dumps(template_res_b64)}"
|
|
||||||
output += "\n"
|
output += "\n"
|
||||||
|
|
||||||
with open(OUTPUT_FILE, "w", encoding="utf-8") as of:
|
with open(OUTPUT_FILE, "w", encoding="utf-8") as of:
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1,12 +1,27 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from cv2 import Mat
|
from cv2 import (
|
||||||
from imutils import resize
|
CHAIN_APPROX_SIMPLE,
|
||||||
|
RETR_EXTERNAL,
|
||||||
|
TM_CCOEFF_NORMED,
|
||||||
|
Mat,
|
||||||
|
boundingRect,
|
||||||
|
findContours,
|
||||||
|
imshow,
|
||||||
|
matchTemplate,
|
||||||
|
minMaxLoc,
|
||||||
|
rectangle,
|
||||||
|
resize,
|
||||||
|
waitKey,
|
||||||
|
)
|
||||||
|
from imutils import grab_contours
|
||||||
|
from imutils import resize as imutils_resize
|
||||||
from pytesseract import image_to_string
|
from pytesseract import image_to_string
|
||||||
|
|
||||||
from .template import (
|
from .template import (
|
||||||
MatchTemplateMultipleResult,
|
MatchTemplateMultipleResult,
|
||||||
|
TemplateItem,
|
||||||
load_builtin_digit_template,
|
load_builtin_digit_template,
|
||||||
matchTemplateMultiple,
|
matchTemplateMultiple,
|
||||||
)
|
)
|
||||||
@ -120,13 +135,14 @@ def filter_digit_results(
|
|||||||
|
|
||||||
def ocr_digits(
|
def ocr_digits(
|
||||||
img: Mat,
|
img: Mat,
|
||||||
templates: Dict[int, Mat],
|
templates: TemplateItem,
|
||||||
template_threshold: float,
|
template_threshold: float,
|
||||||
filter_threshold: int,
|
filter_threshold: int,
|
||||||
):
|
):
|
||||||
|
templates = dict(enumerate(templates[:10]))
|
||||||
results: Dict[int, List[MatchTemplateMultipleResult]] = {}
|
results: Dict[int, List[MatchTemplateMultipleResult]] = {}
|
||||||
for digit, template in templates.items():
|
for digit, template in templates.items():
|
||||||
template = resize(template, height=img.shape[0])
|
template = imutils_resize(template, height=img.shape[0])
|
||||||
results[digit] = matchTemplateMultiple(img, template, template_threshold)
|
results[digit] = matchTemplateMultiple(img, template, template_threshold)
|
||||||
results = filter_digit_results(results, filter_threshold)
|
results = filter_digit_results(results, filter_threshold)
|
||||||
result_x_digit_map = {}
|
result_x_digit_map = {}
|
||||||
@ -140,20 +156,57 @@ def ocr_digits(
|
|||||||
|
|
||||||
|
|
||||||
def ocr_pure(img_masked: Mat):
|
def ocr_pure(img_masked: Mat):
|
||||||
templates = load_builtin_digit_template("GeoSansLight-Regular")
|
template = load_builtin_digit_template("default")
|
||||||
return ocr_digits(img_masked, templates, template_threshold=0.6, filter_threshold=3)
|
return ocr_digits(
|
||||||
|
img_masked, template.regular, template_threshold=0.6, filter_threshold=3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ocr_far_lost(img_masked: Mat):
|
def ocr_far_lost(img_masked: Mat):
|
||||||
templates = load_builtin_digit_template("GeoSansLight-Italic")
|
template = load_builtin_digit_template("default")
|
||||||
return ocr_digits(img_masked, templates, template_threshold=0.6, filter_threshold=3)
|
return ocr_digits(
|
||||||
|
img_masked, template.italic, template_threshold=0.6, filter_threshold=3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ocr_score(img_cropped: Mat):
|
def ocr_score(img_cropped: Mat):
|
||||||
templates = load_builtin_digit_template("GeoSansLight-Regular")
|
templates = load_builtin_digit_template("default").regular
|
||||||
return ocr_digits(
|
templates_dict = dict(enumerate(templates[:10]))
|
||||||
img_cropped, templates, template_threshold=0.5, filter_threshold=10
|
|
||||||
)
|
cnts = findContours(img_cropped.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
|
||||||
|
cnts = grab_contours(cnts)
|
||||||
|
rects = [boundingRect(cnt) for cnt in cnts]
|
||||||
|
rects = sorted(rects, key=lambda r: r[0])
|
||||||
|
|
||||||
|
# debug
|
||||||
|
# [rectangle(img, rect, (128, 128, 128), 2) for rect in rects]
|
||||||
|
# imshow("img", img)
|
||||||
|
# waitKey(0)
|
||||||
|
|
||||||
|
result = ""
|
||||||
|
for rect in rects:
|
||||||
|
x, y, w, h = rect
|
||||||
|
roi = img_cropped[y : y + h, x : x + w]
|
||||||
|
|
||||||
|
digit_results: Dict[int, float] = {}
|
||||||
|
for digit, template in templates_dict.items():
|
||||||
|
template = resize(template, roi.shape[::-1])
|
||||||
|
template_result = matchTemplate(roi, template, TM_CCOEFF_NORMED)
|
||||||
|
min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
|
||||||
|
digit_results[digit] = max_val
|
||||||
|
|
||||||
|
digit_results = {k: v for k, v in digit_results.items() if v > 0.5}
|
||||||
|
if digit_results:
|
||||||
|
best_match_digit = max(digit_results, key=digit_results.get)
|
||||||
|
result += str(best_match_digit)
|
||||||
|
return int(result) if result else None
|
||||||
|
|
||||||
|
# return ocr_digits(
|
||||||
|
# img_cropped,
|
||||||
|
# template.regular,
|
||||||
|
# template_threshold=0.5,
|
||||||
|
# filter_threshold=10,
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
def ocr_max_recall(img_cropped: Mat):
|
def ocr_max_recall(img_cropped: Mat):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
|
import pickle
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from time import sleep
|
from typing import Any, Dict, List, Literal, Tuple, TypedDict, Union
|
||||||
from typing import Dict, List, Literal, Tuple, TypedDict
|
|
||||||
|
|
||||||
from cv2 import (
|
from cv2 import (
|
||||||
CHAIN_APPROX_SIMPLE,
|
CHAIN_APPROX_SIMPLE,
|
||||||
@ -26,55 +26,79 @@ from cv2 import (
|
|||||||
threshold,
|
threshold,
|
||||||
waitKey,
|
waitKey,
|
||||||
)
|
)
|
||||||
from imutils import contours, grab_contours
|
from numpy import ndarray
|
||||||
from numpy import frombuffer as np_frombuffer
|
|
||||||
from numpy import uint8
|
|
||||||
|
|
||||||
from ._builtin_templates import GeoSansLight_Italic, GeoSansLight_Regular
|
from ._builtin_templates import (
|
||||||
|
DEFAULT_ITALIC,
|
||||||
|
DEFAULT_ITALIC_ERODED,
|
||||||
|
DEFAULT_REGULAR,
|
||||||
|
DEFAULT_REGULAR_ERODED,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_digit_template",
|
"TemplateItem",
|
||||||
|
"DigitTemplate",
|
||||||
"load_builtin_digit_template",
|
"load_builtin_digit_template",
|
||||||
"MatchTemplateMultipleResult",
|
"MatchTemplateMultipleResult",
|
||||||
"matchTemplateMultiple",
|
"matchTemplateMultiple",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# a list of Mat showing following characters:
|
||||||
def load_digit_template(filename: str) -> Dict[int, Mat]:
|
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ']
|
||||||
"""
|
TemplateItem = Union[List[Mat], Tuple[Mat]]
|
||||||
Arguments:
|
|
||||||
filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9" text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[int, cv2.Mat]
|
|
||||||
"""
|
|
||||||
# https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
|
|
||||||
ref = imread(filename)
|
|
||||||
ref = cvtColor(ref, COLOR_BGR2GRAY)
|
|
||||||
ref = threshold(ref, 10, 255, THRESH_BINARY_INV)[1]
|
|
||||||
refCnts = findContours(ref.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
|
|
||||||
refCnts = grab_contours(refCnts)
|
|
||||||
refCnts = contours.sort_contours(refCnts, method="left-to-right")[0]
|
|
||||||
digits = {}
|
|
||||||
for i, cnt in enumerate(refCnts):
|
|
||||||
(x, y, w, h) = boundingRect(cnt)
|
|
||||||
roi = ref[y : y + h, x : x + w]
|
|
||||||
digits[i] = roi
|
|
||||||
return digits
|
|
||||||
|
|
||||||
|
|
||||||
def load_builtin_digit_template(
|
class DigitTemplate:
|
||||||
name: Literal["GeoSansLight-Regular", "GeoSansLight-Italic"]
|
__slots__ = ["regular", "italic", "regular_eroded", "italic_eroded"]
|
||||||
):
|
|
||||||
name_builtin_template_b64_map = {
|
regular: TemplateItem
|
||||||
"GeoSansLight-Regular": GeoSansLight_Regular,
|
italic: TemplateItem
|
||||||
"GeoSansLight-Italic": GeoSansLight_Italic,
|
regular_eroded: TemplateItem
|
||||||
}
|
italic_eroded: TemplateItem
|
||||||
template_b64 = name_builtin_template_b64_map[name]
|
|
||||||
return {
|
def __ensure_template_item(self, item):
|
||||||
int(key): imdecode(np_frombuffer(b64decode(b64str), uint8), IMREAD_GRAYSCALE)
|
return (
|
||||||
for key, b64str in template_b64.items()
|
isinstance(item, (list, tuple))
|
||||||
|
and len(item) == 11
|
||||||
|
and all(isinstance(i, ndarray) for i in item)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, regular, italic, regular_eroded, italic_eroded):
|
||||||
|
self.regular = regular
|
||||||
|
self.italic = italic
|
||||||
|
self.regular_eroded = regular_eroded
|
||||||
|
self.italic_eroded = italic_eroded
|
||||||
|
|
||||||
|
def __setattr__(self, __name: str, __value: Any):
|
||||||
|
if __name in {
|
||||||
|
"regular",
|
||||||
|
"italic",
|
||||||
|
"regular_eroded",
|
||||||
|
"italic_eroded",
|
||||||
|
} and self.__ensure_template_item(__value):
|
||||||
|
super().__setattr__(__name, __value)
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid attribute set, expected type TemplateItem or invalid attribute name."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_builtin_digit_template(name: Literal["default"]) -> DigitTemplate:
|
||||||
|
CONSTANTS = {
|
||||||
|
"default": [
|
||||||
|
DEFAULT_REGULAR,
|
||||||
|
DEFAULT_ITALIC,
|
||||||
|
DEFAULT_REGULAR_ERODED,
|
||||||
|
DEFAULT_ITALIC_ERODED,
|
||||||
|
]
|
||||||
}
|
}
|
||||||
|
args = CONSTANTS[name]
|
||||||
|
args = [
|
||||||
|
[pickle.loads(b64decode(encoded_str)) for encoded_str in arg] for arg in args
|
||||||
|
]
|
||||||
|
|
||||||
|
return DigitTemplate(*args)
|
||||||
|
|
||||||
|
|
||||||
class MatchTemplateMultipleResult(TypedDict):
|
class MatchTemplateMultipleResult(TypedDict):
|
||||||
@ -85,10 +109,6 @@ class MatchTemplateMultipleResult(TypedDict):
|
|||||||
def matchTemplateMultiple(
|
def matchTemplateMultiple(
|
||||||
src: Mat, template: Mat, threshold: float = 0.1
|
src: Mat, template: Mat, threshold: float = 0.1
|
||||||
) -> List[MatchTemplateMultipleResult]:
|
) -> List[MatchTemplateMultipleResult]:
|
||||||
"""
|
|
||||||
Returns:
|
|
||||||
A list of tuple[x, y, w, h] representing the matched rectangle
|
|
||||||
"""
|
|
||||||
template_result = matchTemplate(src, template, TM_CCOEFF_NORMED)
|
template_result = matchTemplate(src, template, TM_CCOEFF_NORMED)
|
||||||
min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
|
min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
|
||||||
template_h, template_w = template.shape[:2]
|
template_h, template_w = template.shape[:2]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user