mirror of https://github.com/283375/arcaea-offline-ocr.git
synced 2025-07-01 20:36:27 +00:00
init
src/arcaea_offline_ocr/__init__.py (new file, 0 lines)
src/arcaea_offline_ocr/_builtin_templates.py (new file, 2 lines)
File diff suppressed because one or more lines are too long
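The suppressed module is consumed by template.py further down (`from ._builtin_templates import GeoSansLight_Italic, GeoSansLight_Regular`, then `int(key)` / `b64decode(b64str)` over `.items()`), so it presumably holds two dicts mapping digit keys to base64-encoded glyph images. A hedged sketch of that expected shape only; the names come from template.py, the values here are placeholders, not the real data:

# hypothetical sketch of _builtin_templates.py; real base64 payloads omitted
GeoSansLight_Regular = {"0": "<base64-encoded glyph>", "1": "<base64-encoded glyph>"}  # ... up to "9"
GeoSansLight_Italic = {"0": "<base64-encoded glyph>", "1": "<base64-encoded glyph>"}   # ... up to "9"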
src/arcaea_offline_ocr/crop.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from typing import Tuple

from cv2 import Mat

from .device import Device


def crop_img(img: Mat, *, top: int, left: int, bottom: int, right: int):
    return img[top:bottom, left:right]


def crop_from_device_attr(img: Mat, rect: Tuple[int, int, int, int]):
    x, y, w, h = rect
    return crop_img(img, top=y, left=x, bottom=y + h, right=x + w)


def crop_to_pure(screenshot: Mat, device: Device):
    return crop_from_device_attr(screenshot, device.pure)


def crop_to_far(screenshot: Mat, device: Device):
    return crop_from_device_attr(screenshot, device.far)


def crop_to_lost(screenshot: Mat, device: Device):
    return crop_from_device_attr(screenshot, device.lost)


def crop_to_max_recall(screenshot: Mat, device: Device):
    return crop_from_device_attr(screenshot, device.max_recall)


def crop_to_rating_class(screenshot: Mat, device: Device):
    return crop_from_device_attr(screenshot, device.rating_class)


def crop_to_score(screenshot: Mat, device: Device):
    return crop_from_device_attr(screenshot, device.score)


def crop_to_title(screenshot: Mat, device: Device):
    return crop_from_device_attr(screenshot, device.title)
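Since crop_img is a plain array slice, it can be exercised without a real screenshot. A minimal usage sketch, not part of the commit; the array size and rectangle below are made up for illustration:

import numpy as np

from arcaea_offline_ocr.crop import crop_img

# stand-in for a screenshot loaded with cv2.imread(): a 1080x1920 BGR image
screenshot = np.zeros((1080, 1920, 3), dtype=np.uint8)
roi = crop_img(screenshot, top=100, left=200, bottom=150, right=600)
print(roi.shape)  # (50, 400, 3)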
src/arcaea_offline_ocr/device.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass(kw_only=True)
class Device:
    version: int
    uuid: str
    name: str
    pure: Tuple[int, int, int, int]
    far: Tuple[int, int, int, int]
    lost: Tuple[int, int, int, int]
    max_recall: Tuple[int, int, int, int]
    rating_class: Tuple[int, int, int, int]
    score: Tuple[int, int, int, int]
    title: Tuple[int, int, int, int]

    @classmethod
    def from_json_object(cls, json_dict: Dict[str, Any]):
        if json_dict["version"] == 1:
            return cls(
                version=1,
                uuid=json_dict["uuid"],
                name=json_dict["name"],
                pure=json_dict["pure"],
                far=json_dict["far"],
                lost=json_dict["lost"],
                max_recall=json_dict["max_recall"],
                rating_class=json_dict["rating_class"],
                score=json_dict["score"],
                title=json_dict["title"],
            )
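A minimal sketch of building a version-1 Device, not part of the commit; the uuid, name, and rectangle values below are hypothetical, each field being an (x, y, w, h) region in screenshot coordinates:

from arcaea_offline_ocr.device import Device

device_json = {
    "version": 1,
    "uuid": "example-uuid",
    "name": "my phone",
    "pure": [410, 800, 120, 40],
    "far": [410, 860, 120, 40],
    "lost": [410, 920, 120, 40],
    "max_recall": [810, 430, 150, 50],
    "rating_class": [120, 300, 300, 60],
    "score": [560, 650, 340, 70],
    "title": [330, 200, 620, 60],
}
device = Device.from_json_object(device_json)
print(device.name, device.score)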
src/arcaea_offline_ocr/mask.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from cv2 import BORDER_CONSTANT, BORDER_ISOLATED, Mat, bitwise_or, dilate, inRange
from numpy import array, uint8

GRAY_MIN_HSV = array([0, 0, 70], uint8)
GRAY_MAX_HSV = array([0, 70, 200], uint8)

WHITE_MIN_HSV = array([0, 0, 240], uint8)
WHITE_MAX_HSV = array([179, 10, 255], uint8)

PST_MIN_HSV = array([100, 50, 80], uint8)
PST_MAX_HSV = array([100, 255, 255], uint8)

PRS_MIN_HSV = array([43, 40, 75], uint8)
PRS_MAX_HSV = array([50, 155, 190], uint8)

FTR_MIN_HSV = array([149, 30, 0], uint8)
FTR_MAX_HSV = array([155, 181, 150], uint8)

BYD_MIN_HSV = array([170, 50, 50], uint8)
BYD_MAX_HSV = array([179, 210, 198], uint8)


def mask_gray(img_hsv: Mat):
    mask = inRange(img_hsv, GRAY_MIN_HSV, GRAY_MAX_HSV)
    mask = dilate(mask, (2, 2))
    return mask


def mask_white(img_hsv: Mat):
    mask = inRange(img_hsv, WHITE_MIN_HSV, WHITE_MAX_HSV)
    mask = dilate(mask, (5, 5), borderType=BORDER_CONSTANT | BORDER_ISOLATED)
    return mask


def mask_pst(img_hsv: Mat):
    mask = inRange(img_hsv, PST_MIN_HSV, PST_MAX_HSV)
    mask = dilate(mask, (1, 1))
    return mask


def mask_prs(img_hsv: Mat):
    mask = inRange(img_hsv, PRS_MIN_HSV, PRS_MAX_HSV)
    mask = dilate(mask, (1, 1))
    return mask


def mask_ftr(img_hsv: Mat):
    mask = inRange(img_hsv, FTR_MIN_HSV, FTR_MAX_HSV)
    mask = dilate(mask, (1, 1))
    return mask


def mask_byd(img_hsv: Mat):
    mask = inRange(img_hsv, BYD_MIN_HSV, BYD_MAX_HSV)
    mask = dilate(mask, (2, 2))
    return mask


def mask_rating_class(img_hsv: Mat):
    pst = mask_pst(img_hsv)
    prs = mask_prs(img_hsv)
    ftr = mask_ftr(img_hsv)
    byd = mask_byd(img_hsv)
    return bitwise_or(byd, bitwise_or(ftr, bitwise_or(pst, prs)))
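All of the mask_* helpers take an HSV image, so a BGR screenshot has to be converted first, as recognize.py does below. A minimal sketch, not part of the commit; the file name is hypothetical:

from cv2 import COLOR_BGR2HSV, cvtColor, imread

from arcaea_offline_ocr.mask import mask_gray, mask_rating_class

img_hsv = cvtColor(imread("result_screenshot.jpg"), COLOR_BGR2HSV)
counts_mask = mask_gray(img_hsv)          # pixels inside the grey range, used for the pure/far/lost digits
rating_mask = mask_rating_class(img_hsv)  # union of the PST / PRS / FTR / BYD colour ranges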
src/arcaea_offline_ocr/ocr.py (new file, 158 lines)
@ -0,0 +1,158 @@
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from cv2 import Mat
|
||||
from imutils import resize
|
||||
from pytesseract import image_to_string
|
||||
|
||||
from .template import (
|
||||
MatchTemplateMultipleResult,
|
||||
load_builtin_digit_template,
|
||||
matchTemplateMultiple,
|
||||
)
|
||||
|
||||
|
||||
def group_numbers(numbers: List[int], threshold: int) -> List[List[int]]:
|
||||
"""
|
||||
```
|
||||
numbers = [26, 189, 303, 348, 32, 195, 391, 145, 77]
|
||||
group_numbers(numbers, 10) -> [[26, 32], [77], [145], [189, 195], [303], [348], [391]]
|
||||
group_numbers(numbers, 5) -> [[26], [32], [77], [145], [189], [195], [303], [348], [391]]
|
||||
group_numbers(numbers, 50) -> [[26, 32, 77], [145, 189, 195], [303, 348, 391]]
|
||||
# from Bing AI
|
||||
```
|
||||
"""
|
||||
numbers.sort()
|
||||
# Initialize an empty list of groups
|
||||
groups = []
|
||||
# Initialize an empty list for the current group
|
||||
group = []
|
||||
# Loop through the numbers
|
||||
for number in numbers:
|
||||
# If the current group is empty or the number is within the threshold of the last number in the group
|
||||
if not group or number - group[-1] <= threshold:
|
||||
# Append the number to the current group
|
||||
group.append(number)
|
||||
# Otherwise
|
||||
else:
|
||||
# Append the current group to the list of groups
|
||||
groups.append(group)
|
||||
# Start a new group with the number
|
||||
group = [number]
|
||||
# Append the last group to the list of groups
|
||||
groups.append(group)
|
||||
# Return the list of groups
|
||||
return groups
|
||||
|
||||
|
||||
class FilterDigitResultDict(MatchTemplateMultipleResult):
|
||||
digit: int
|
||||
|
||||
|
||||
def filter_digit_results(
|
||||
results: Dict[int, List[MatchTemplateMultipleResult]], threshold: int
|
||||
):
|
||||
result_sorted_by_x_pos: Dict[
|
||||
int, List[FilterDigitResultDict]
|
||||
] = {} # dict[x_pos, dict[int, list[result]]]
|
||||
for digit, match_results in results.items():
|
||||
if match_results:
|
||||
for result in match_results:
|
||||
x_pos = result["xywh"][0]
|
||||
_dict = {**result, "digit": digit}
|
||||
if result_sorted_by_x_pos.get(x_pos) is None:
|
||||
result_sorted_by_x_pos[x_pos] = [_dict]
|
||||
else:
|
||||
result_sorted_by_x_pos[x_pos].append(_dict)
|
||||
|
||||
x_poses_grouped: List[List[int]] = group_numbers(
|
||||
list(result_sorted_by_x_pos), threshold
|
||||
)
|
||||
|
||||
final_result: Dict[
|
||||
int, List[MatchTemplateMultipleResult]
|
||||
] = {} # dict[digit, list[Results]]
|
||||
for x_poses in x_poses_grouped:
|
||||
possible_results = []
|
||||
for x_pos in x_poses:
|
||||
possible_results.extend(result_sorted_by_x_pos.get(x_pos, []))
|
||||
result = sorted(possible_results, key=lambda d: d["max_val"], reverse=True)[0]
|
||||
result_digit = result["digit"]
|
||||
result.pop("digit", None)
|
||||
if final_result.get(result_digit) is None:
|
||||
final_result[result_digit] = [result]
|
||||
else:
|
||||
final_result[result_digit].append(result)
|
||||
return final_result
|
||||
|
||||
|
||||
def ocr_digits(
|
||||
img: Mat,
|
||||
templates: Dict[int, Mat],
|
||||
template_threshold: float,
|
||||
filter_threshold: int,
|
||||
):
|
||||
results: Dict[int, List[MatchTemplateMultipleResult]] = {}
|
||||
for digit, template in templates.items():
|
||||
template = resize(template, height=img.shape[0])
|
||||
results[digit] = matchTemplateMultiple(img, template, template_threshold)
|
||||
results = filter_digit_results(results, filter_threshold)
|
||||
result_x_digit_map = {}
|
||||
for digit, match_results in results.items():
|
||||
if match_results:
|
||||
for result in match_results:
|
||||
result_x_digit_map[result["xywh"][0]] = digit
|
||||
digits_sorted_by_x = dict(sorted(result_x_digit_map.items()))
|
||||
joined_str = "".join([str(digit) for digit in digits_sorted_by_x.values()])
|
||||
return int(joined_str) if joined_str else None
|
||||
|
||||
|
||||
def ocr_pure(img_masked: Mat):
|
||||
templates = load_builtin_digit_template("GeoSansLight-Regular")
|
||||
return ocr_digits(img_masked, templates, template_threshold=0.6, filter_threshold=3)
|
||||
|
||||
|
||||
def ocr_far_lost(img_masked: Mat):
|
||||
templates = load_builtin_digit_template("GeoSansLight-Italic")
|
||||
return ocr_digits(img_masked, templates, template_threshold=0.6, filter_threshold=3)
|
||||
|
||||
|
||||
def ocr_score(img_cropped: Mat):
|
||||
templates = load_builtin_digit_template("GeoSansLight-Regular")
|
||||
return ocr_digits(
|
||||
img_cropped, templates, template_threshold=0.5, filter_threshold=10
|
||||
)
|
||||
|
||||
|
||||
def ocr_max_recall(img_cropped: Mat):
|
||||
try:
|
||||
texts = image_to_string(img_cropped).split(" ") # type: List[str]
|
||||
texts.reverse()
|
||||
for text in texts:
|
||||
if re.match(r"^[0-9]+$", text):
|
||||
return int(text)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def ocr_rating_class(img_cropped: Mat):
|
||||
try:
|
||||
text = image_to_string(img_cropped) # type: str
|
||||
text = text.lower()
|
||||
if "past" in text:
|
||||
return 0
|
||||
elif "present" in text:
|
||||
return 1
|
||||
elif "future" in text:
|
||||
return 2
|
||||
elif "beyond" in text:
|
||||
return 3
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def ocr_title(img_cropped: Mat):
|
||||
try:
|
||||
return image_to_string(img_cropped).replace("\n", "")
|
||||
except Exception as e:
|
||||
return ""
|
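group_numbers is what lets filter_digit_results collapse overlapping template hits into a single digit per x position, and it can be tried on its own. A minimal sketch, not part of the commit, reusing the numbers from the docstring above:

from arcaea_offline_ocr.ocr import group_numbers

xs = [26, 189, 303, 348, 32, 195, 391, 145, 77]
print(group_numbers(xs, 10))  # [[26, 32], [77], [145], [189, 195], [303], [348], [391]]
print(group_numbers(xs, 50))  # [[26, 32, 77], [145, 189, 195], [303, 348, 391]]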
src/arcaea_offline_ocr/recognize.py (new file, 67 lines)
@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Optional

from cv2 import COLOR_BGR2HSV, GaussianBlur, cvtColor, imread

from .crop import *
from .device import Device
from .mask import *
from .ocr import *


@dataclass(kw_only=True)
class RecognizeResult:
    pure: Optional[int]
    far: Optional[int]
    lost: Optional[int]
    score: Optional[int]
    max_recall: Optional[int]
    rating_class: Optional[int]
    title: str


def recognize(img_filename: str, device: Device):
    img = imread(img_filename)
    img_hsv = cvtColor(img, COLOR_BGR2HSV)

    pure_roi = crop_to_pure(img_hsv, device)
    pure_roi = mask_gray(pure_roi)
    pure_roi = GaussianBlur(pure_roi, (3, 3), 0)
    pure = ocr_pure(pure_roi)

    far_roi = crop_to_far(img_hsv, device)
    far_roi = mask_gray(far_roi)
    far_roi = GaussianBlur(far_roi, (3, 3), 0)
    far = ocr_far_lost(far_roi)

    lost_roi = crop_to_lost(img_hsv, device)
    lost_roi = mask_gray(lost_roi)
    lost_roi = GaussianBlur(lost_roi, (3, 3), 0)
    lost = ocr_far_lost(lost_roi)

    score_roi = crop_to_score(img_hsv, device)
    score_roi = mask_white(score_roi)
    score_roi = GaussianBlur(score_roi, (3, 3), 0)
    score = ocr_score(score_roi)

    max_recall_roi = crop_to_max_recall(img_hsv, device)
    max_recall_roi = mask_gray(max_recall_roi)
    max_recall = ocr_max_recall(max_recall_roi)

    rating_class_roi = crop_to_rating_class(img_hsv, device)
    rating_class_roi = mask_rating_class(rating_class_roi)
    rating_class = ocr_rating_class(rating_class_roi)

    title_roi = crop_to_title(img_hsv, device)
    title_roi = mask_white(title_roi)
    title = ocr_title(title_roi)

    return RecognizeResult(
        pure=pure,
        far=far,
        lost=lost,
        score=score,
        max_recall=max_recall,
        rating_class=rating_class,
        title=title,
    )
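An end-to-end usage sketch, not part of the commit; the device JSON path and screenshot path are hypothetical:

import json

from arcaea_offline_ocr.device import Device
from arcaea_offline_ocr.recognize import recognize

with open("devices/my_phone.json", encoding="utf-8") as f:
    device = Device.from_json_object(json.load(f))

result = recognize("result_screenshot.jpg", device)
print(result.score, result.pure, result.far, result.lost, result.max_recall)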
src/arcaea_offline_ocr/template.py (new file, 163 lines)
@@ -0,0 +1,163 @@
from base64 import b64decode
from time import sleep
from typing import Dict, List, Literal, Tuple, TypedDict

from cv2 import (
    CHAIN_APPROX_SIMPLE,
    COLOR_BGR2GRAY,
    COLOR_GRAY2BGR,
    FONT_HERSHEY_SIMPLEX,
    IMREAD_GRAYSCALE,
    RETR_EXTERNAL,
    THRESH_BINARY_INV,
    TM_CCOEFF_NORMED,
    Mat,
    boundingRect,
    cvtColor,
    destroyAllWindows,
    findContours,
    imdecode,
    imread,
    imshow,
    matchTemplate,
    minMaxLoc,
    putText,
    rectangle,
    threshold,
    waitKey,
)
from imutils import contours, grab_contours
from numpy import frombuffer as np_frombuffer
from numpy import uint8

from ._builtin_templates import GeoSansLight_Italic, GeoSansLight_Regular


def load_digit_template(filename: str) -> Dict[int, Mat]:
    """
    Arguments:
        filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9" text.

    Returns:
        dict[int, cv2.Mat]
    """
    # https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
    ref = imread(filename)
    ref = cvtColor(ref, COLOR_BGR2GRAY)
    ref = threshold(ref, 10, 255, THRESH_BINARY_INV)[1]
    refCnts = findContours(ref.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
    refCnts = grab_contours(refCnts)
    refCnts = contours.sort_contours(refCnts, method="left-to-right")[0]
    digits = {}
    for i, cnt in enumerate(refCnts):
        (x, y, w, h) = boundingRect(cnt)
        roi = ref[y : y + h, x : x + w]
        digits[i] = roi
    return digits


def load_builtin_digit_template(
    name: Literal["GeoSansLight-Regular", "GeoSansLight-Italic"]
):
    name_builtin_template_b64_map = {
        "GeoSansLight-Regular": GeoSansLight_Regular,
        "GeoSansLight-Italic": GeoSansLight_Italic,
    }
    template_b64 = name_builtin_template_b64_map[name]
    return {
        int(key): imdecode(np_frombuffer(b64decode(b64str), uint8), IMREAD_GRAYSCALE)
        for key, b64str in template_b64.items()
    }


class MatchTemplateMultipleResult(TypedDict):
    max_val: float
    xywh: Tuple[int, int, int, int]


def matchTemplateMultiple(
    src: Mat, template: Mat, threshold: float = 0.1
) -> List[MatchTemplateMultipleResult]:
    """
    Returns:
        A list of tuple[x, y, w, h] representing the matched rectangle
    """
    template_result = matchTemplate(src, template, TM_CCOEFF_NORMED)
    min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)
    template_h, template_w = template.shape[:2]
    results = []

    # debug
    # imshow("templ", template)
    # waitKey(750)
    # destroyAllWindows()

    # https://stackoverflow.com/a/66848923/16484891
    # CC BY-SA 4.0
    prev_min_val, prev_max_val, prev_min_loc, prev_max_loc = None, None, None, None
    while max_val > threshold:
        min_val, max_val, min_loc, max_loc = minMaxLoc(template_result)

        # Prevent infinite loop. If those 4 values are the same as previous ones, break the loop.
        if (
            prev_min_val == min_val
            and prev_max_val == max_val
            and prev_min_loc == min_loc
            and prev_max_loc == max_loc
        ):
            break
        else:
            prev_min_val, prev_max_val, prev_min_loc, prev_max_loc = (
                min_val,
                max_val,
                min_loc,
                max_loc,
            )

        if max_val > threshold:
            # Prevent start_row, end_row, start_col, end_col be out of range of image
            start_row = max(0, max_loc[1] - template_h // 2)
            start_col = max(0, max_loc[0] - template_w // 2)
            end_row = min(template_result.shape[0], max_loc[1] + template_h // 2 + 1)
            end_col = min(template_result.shape[1], max_loc[0] + template_w // 2 + 1)

            template_result[start_row:end_row, start_col:end_col] = 0
            results.append(
                {
                    "max_val": max_val,
                    "xywh": (
                        max_loc[0],
                        max_loc[1],
                        max_loc[0] + template_w + 1,
                        max_loc[1] + template_h + 1,
                    ),
                }
            )

            # debug
            # src_dbg = cvtColor(src, COLOR_GRAY2BGR)
            # src_dbg = rectangle(
            #     src_dbg,
            #     (max_loc[0], max_loc[1]),
            #     (
            #         max_loc[0] + template_w + 1,
            #         max_loc[1] + template_h + 1,
            #     ),
            #     (0, 255, 0),
            #     thickness=3,
            # )
            # src_dbg = putText(
            #     src_dbg,
            #     f"{max_val:.5f}",
            #     (5, src_dbg.shape[0] - 5),
            #     FONT_HERSHEY_SIMPLEX,
            #     1,
            #     (0, 255, 0),
            #     thickness=2,
            # )
            # imshow("src_rect", src_dbg)
            # imshow("templ", template)
            # waitKey(750)
            # destroyAllWindows()

    return results
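A minimal sketch, not part of the commit, of using the template helpers directly; the masked score image file is hypothetical, and the 0.5 threshold mirrors the one ocr_score passes above. Note that ocr_digits additionally resizes each template to the image height before matching, which this sketch skips:

from cv2 import IMREAD_GRAYSCALE, imread

from arcaea_offline_ocr.template import load_builtin_digit_template, matchTemplateMultiple

templates = load_builtin_digit_template("GeoSansLight-Regular")  # dict[int, Mat] of digit glyphs
score_img = imread("score_region_masked.png", IMREAD_GRAYSCALE)
hits = matchTemplateMultiple(score_img, templates[0], threshold=0.5)
for hit in hits:
    print(hit["max_val"], hit["xywh"])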