This commit is contained in:
2023-06-03 20:26:53 +08:00
commit f9968ae8b3
14 changed files with 797 additions and 0 deletions

View File

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,42 @@
from typing import Tuple
from cv2 import Mat
from .device import Device
def crop_img(img: Mat, *, top: int, left: int, bottom: int, right: int):
    """Return the sub-image ``img[top:bottom, left:right]``.

    All four bounds are keyword-only. Rows are selected by ``top``/``bottom``
    and columns by ``left``/``right``, with the usual half-open slice
    semantics (``bottom`` and ``right`` are exclusive).
    """
    row_span = slice(top, bottom)
    col_span = slice(left, right)
    return img[row_span, col_span]
def crop_from_device_attr(img: Mat, rect: Tuple[int, int, int, int]):
    """Crop ``img`` to a rectangle given as ``(x, y, w, h)``.

    The rectangle is converted to edge coordinates (left = x, top = y,
    right = x + w, bottom = y + h) and handed to ``crop_img``.
    """
    left, top, width, height = rect
    return crop_img(
        img,
        top=top,
        left=left,
        bottom=top + height,
        right=left + width,
    )
def crop_to_pure(screenshot: Mat, device: Device):
    """Crop ``screenshot`` to the rectangle stored in ``device.pure``."""
    region = device.pure
    return crop_from_device_attr(screenshot, region)
def crop_to_far(screenshot: Mat, device: Device):
    """Crop ``screenshot`` to the rectangle stored in ``device.far``."""
    region = device.far
    return crop_from_device_attr(screenshot, region)
def crop_to_lost(screenshot: Mat, device: Device):
    """Crop ``screenshot`` to the rectangle stored in ``device.lost``."""
    region = device.lost
    return crop_from_device_attr(screenshot, region)
def crop_to_max_recall(screenshot: Mat, device: Device):
    """Crop ``screenshot`` to the rectangle stored in ``device.max_recall``."""
    region = device.max_recall
    return crop_from_device_attr(screenshot, region)
def crop_to_rating_class(screenshot: Mat, device: Device):
    """Crop ``screenshot`` to the rectangle stored in ``device.rating_class``."""
    region = device.rating_class
    return crop_from_device_attr(screenshot, region)
def crop_to_score(screenshot: Mat, device: Device):
    """Crop ``screenshot`` to the rectangle stored in ``device.score``."""
    region = device.score
    return crop_from_device_attr(screenshot, region)
def crop_to_title(screenshot: Mat, device: Device):
    """Crop ``screenshot`` to the rectangle stored in ``device.title``."""
    region = device.title
    return crop_from_device_attr(screenshot, region)

View File

@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Any, Dict, Tuple
@dataclass(kw_only=True)
class Device:
version: int
uuid: str
name: str
pure: Tuple[int, int, int, int]
far: Tuple[int, int, int, int]
lost: Tuple[int, int, int, int]
max_recall: Tuple[int, int, int, int]
rating_class: Tuple[int, int, int, int]
score: Tuple[int, int, int, int]
title: Tuple[int, int, int, int]
@classmethod
def from_json_object(cls, json_dict: Dict[str, Any]):
if json_dict["version"] == 1:
return cls(
version=1,
uuid=json_dict["uuid"],
name=json_dict["name"],
pure=json_dict["pure"],
far=json_dict["far"],
lost=json_dict["lost"],
max_recall=json_dict["max_recall"],
rating_class=json_dict["rating_class"],
score=json_dict["score"],
title=json_dict["title"],
)

View File

@ -0,0 +1,64 @@
from cv2 import BORDER_CONSTANT, BORDER_ISOLATED, Mat, bitwise_or, dilate, inRange
from numpy import array, uint8
# Lower / upper bounds (three uint8 components each) handed to cv2.inRange by
# the mask_* helpers in this module. The names say HSV; the helpers receive an
# image named ``img_hsv``.
# NOTE(review): OpenCV's 8-bit hue channel presumably spans 0-179, which would
# explain the 179 upper hue bounds below — confirm.

# Bounds used by mask_gray.
# NOTE(review): both hue bounds are 0, so only hue == 0 pixels pass — confirm
# this is intentional.
GRAY_MIN_HSV = array([0, 0, 70], uint8)
GRAY_MAX_HSV = array([0, 70, 200], uint8)
# Bounds used by mask_white.
WHITE_MIN_HSV = array([0, 0, 240], uint8)
WHITE_MAX_HSV = array([179, 10, 255], uint8)
# Bounds used by mask_pst / mask_prs / mask_ftr / mask_byd.
# NOTE(review): PST / PRS / FTR / BYD presumably stand for the Past / Present /
# Future / Beyond rating-class labels — confirm.
# NOTE(review): both PST hue bounds are 100 (single hue value) — confirm.
PST_MIN_HSV = array([100, 50, 80], uint8)
PST_MAX_HSV = array([100, 255, 255], uint8)
PRS_MIN_HSV = array([43, 40, 75], uint8)
PRS_MAX_HSV = array([50, 155, 190], uint8)
FTR_MIN_HSV = array([149, 30, 0], uint8)
FTR_MAX_HSV = array([155, 181, 150], uint8)
BYD_MIN_HSV = array([170, 50, 50], uint8)
BYD_MAX_HSV = array([179, 210, 198], uint8)
def mask_gray(img_hsv: Mat):
    """Return a dilated mask of the pixels inside the GRAY HSV bounds.

    NOTE(review): ``(2, 2)`` is passed as dilate's kernel argument itself, not
    as a kernel size — confirm this is the intended structuring element.
    """
    in_range = inRange(img_hsv, GRAY_MIN_HSV, GRAY_MAX_HSV)
    return dilate(in_range, (2, 2))
def mask_white(img_hsv: Mat):
    """Return a dilated mask of the pixels inside the WHITE HSV bounds.

    Dilation uses ``BORDER_CONSTANT | BORDER_ISOLATED`` as border type.

    NOTE(review): ``(5, 5)`` is passed as dilate's kernel argument itself, not
    as a kernel size — confirm this is the intended structuring element.
    """
    in_range = inRange(img_hsv, WHITE_MIN_HSV, WHITE_MAX_HSV)
    border = BORDER_CONSTANT | BORDER_ISOLATED
    return dilate(in_range, (5, 5), borderType=border)
def mask_pst(img_hsv: Mat):
    """Return a dilated mask of the pixels inside the PST HSV bounds.

    NOTE(review): ``(1, 1)`` is passed as dilate's kernel argument itself, not
    as a kernel size — confirm this is the intended structuring element.
    """
    in_range = inRange(img_hsv, PST_MIN_HSV, PST_MAX_HSV)
    return dilate(in_range, (1, 1))
def mask_prs(img_hsv: Mat):
    """Return a dilated mask of the pixels inside the PRS HSV bounds.

    NOTE(review): ``(1, 1)`` is passed as dilate's kernel argument itself, not
    as a kernel size — confirm this is the intended structuring element.
    """
    in_range = inRange(img_hsv, PRS_MIN_HSV, PRS_MAX_HSV)
    return dilate(in_range, (1, 1))
def mask_ftr(img_hsv: Mat):
    """Return a dilated mask of the pixels inside the FTR HSV bounds.

    NOTE(review): ``(1, 1)`` is passed as dilate's kernel argument itself, not
    as a kernel size — confirm this is the intended structuring element.
    """
    in_range = inRange(img_hsv, FTR_MIN_HSV, FTR_MAX_HSV)
    return dilate(in_range, (1, 1))
def mask_byd(img_hsv: Mat):
    """Return a dilated mask of the pixels inside the BYD HSV bounds.

    NOTE(review): ``(2, 2)`` is passed as dilate's kernel argument itself, not
    as a kernel size — confirm this is the intended structuring element.
    """
    in_range = inRange(img_hsv, BYD_MIN_HSV, BYD_MAX_HSV)
    return dilate(in_range, (2, 2))
def mask_rating_class(img_hsv: Mat):
    """Return the union of the PST, PRS, FTR and BYD masks of ``img_hsv``."""
    pst_mask = mask_pst(img_hsv)
    prs_mask = mask_prs(img_hsv)
    ftr_mask = mask_ftr(img_hsv)
    byd_mask = mask_byd(img_hsv)

    # Fold the four masks together pairwise, innermost pair first.
    combined = bitwise_or(pst_mask, prs_mask)
    combined = bitwise_or(ftr_mask, combined)
    combined = bitwise_or(byd_mask, combined)
    return combined

View File

@ -0,0 +1,158 @@
import re
from typing import Dict, List
from cv2 import Mat
from imutils import resize
from pytesseract import image_to_string
from .template import (
MatchTemplateMultipleResult,
load_builtin_digit_template,
matchTemplateMultiple,
)
def group_numbers(numbers: List[int], threshold: int) -> List[List[int]]:
    """Split numbers into groups of neighbouring values.

    The numbers are processed in ascending order. A number joins the current
    group when it is at most ``threshold`` larger than the previous number
    (the last one added to that group); otherwise it starts a new group.

    The input list is not modified. An empty input yields an empty list.

    ```
    numbers = [26, 189, 303, 348, 32, 195, 391, 145, 77]
    group_numbers(numbers, 10) -> [[26, 32], [77], [145], [189, 195], [303], [348], [391]]
    group_numbers(numbers, 5) -> [[26], [32], [77], [145], [189], [195], [303], [348], [391]]
    group_numbers(numbers, 50) -> [[26, 32, 77], [145, 189, 195], [303, 348, 391]]
    group_numbers([], 10) -> []
    # from Bing AI
    ```
    """
    groups: List[List[int]] = []
    # sorted() works on a copy, so the caller's list keeps its order.
    for number in sorted(numbers):
        if groups and number - groups[-1][-1] <= threshold:
            groups[-1].append(number)
        else:
            groups.append([number])
    return groups
class FilterDigitResultDict(MatchTemplateMultipleResult):
    """A template-match result tagged with the digit whose template produced it."""

    # Digit associated with this match.
    digit: int
def filter_digit_results(
    results: Dict[int, List[MatchTemplateMultipleResult]], threshold: int
):
    """Keep only the best-scoring match among matches at nearby x positions.

    Matches of all digits are indexed by their x position (``xywh[0]``), the
    x positions are clustered with ``group_numbers(..., threshold)``, and for
    each cluster only the match with the highest ``max_val`` survives.

    Arguments:
        results -- mapping of digit -> matches found for that digit's template.
        threshold -- maximum gap between neighbouring x positions that still
            belong to the same cluster.

    Returns:
        dict[digit, list[result]] holding the surviving matches. It is empty
        when ``results`` contains no matches at all.
    """
    candidates_by_x: Dict[int, List[FilterDigitResultDict]] = {}
    for digit, match_results in results.items():
        for match in match_results or ():
            # Copy the match so tagging it does not touch the caller's dict.
            tagged = {**match, "digit": digit}
            candidates_by_x.setdefault(match["xywh"][0], []).append(tagged)

    x_poses_grouped: List[List[int]] = group_numbers(
        list(candidates_by_x), threshold
    )

    final_result: Dict[int, List[MatchTemplateMultipleResult]] = {}
    for x_poses in x_poses_grouped:
        candidates = []
        for x_pos in x_poses:
            candidates.extend(candidates_by_x.get(x_pos, []))
        if not candidates:
            # An empty cluster (e.g. nothing matched at all) has no winner;
            # picking one would raise IndexError.
            continue
        best = max(candidates, key=lambda candidate: candidate["max_val"])
        best_digit = best.pop("digit")
        final_result.setdefault(best_digit, []).append(best)
    return final_result
def ocr_digits(
    img: Mat,
    templates: Dict[int, Mat],
    template_threshold: float,
    filter_threshold: int,
):
    """Read a number from ``img`` by matching per-digit templates.

    Each template is resized to the height of ``img`` and matched with
    ``matchTemplateMultiple``; overlapping hits are reduced with
    ``filter_digit_results``; the surviving digits are concatenated in order
    of their x position.

    Returns:
        The recognized number as ``int``, or ``None`` if no digit survived.
    """
    raw_matches: Dict[int, List[MatchTemplateMultipleResult]] = {}
    for digit, template in templates.items():
        scaled = resize(template, height=img.shape[0])
        raw_matches[digit] = matchTemplateMultiple(img, scaled, template_threshold)

    kept = filter_digit_results(raw_matches, filter_threshold)

    # x position -> digit; a later entry at the same x replaces an earlier one.
    digit_at_x = {}
    for digit, matches in kept.items():
        for match in matches or ():
            digit_at_x[match["xywh"][0]] = digit

    text = "".join(str(digit_at_x[x_pos]) for x_pos in sorted(digit_at_x))
    if not text:
        return None
    return int(text)
def ocr_pure(img_masked: Mat):
    """Run ``ocr_digits`` on ``img_masked`` with the built-in
    "GeoSansLight-Regular" digit templates (template threshold 0.6, filter
    threshold 3)."""
    return ocr_digits(
        img_masked,
        load_builtin_digit_template("GeoSansLight-Regular"),
        template_threshold=0.6,
        filter_threshold=3,
    )
def ocr_far_lost(img_masked: Mat):
    """Run ``ocr_digits`` on ``img_masked`` with the built-in
    "GeoSansLight-Italic" digit templates (template threshold 0.6, filter
    threshold 3)."""
    return ocr_digits(
        img_masked,
        load_builtin_digit_template("GeoSansLight-Italic"),
        template_threshold=0.6,
        filter_threshold=3,
    )
def ocr_score(img_cropped: Mat):
    """Run ``ocr_digits`` on ``img_cropped`` with the built-in
    "GeoSansLight-Regular" digit templates (template threshold 0.5, filter
    threshold 10)."""
    digit_templates = load_builtin_digit_template("GeoSansLight-Regular")
    settings = {"template_threshold": 0.5, "filter_threshold": 10}
    return ocr_digits(img_cropped, digit_templates, **settings)
def ocr_max_recall(img_cropped: Mat):
    """Return the last all-digit token that Tesseract reads from ``img_cropped``.

    The recognized text is split on any whitespace (spaces, newlines, ...), so
    a number that sits on its own line is still found. Tokens are scanned from
    the end of the text towards the start.

    Returns:
        The number as ``int``, or ``None`` when no all-digit token exists or
        when recognition fails for any reason (best-effort by design).
    """
    try:
        tokens = image_to_string(img_cropped).split()  # type: List[str]
        for token in reversed(tokens):
            if re.fullmatch(r"[0-9]+", token):
                return int(token)
    except Exception:
        # Deliberate best-effort: any OCR failure means "value unknown".
        return None
    return None
def ocr_rating_class(img_cropped: Mat):
    """Map the text Tesseract reads from ``img_cropped`` to a rating-class index.

    The lower-cased text is searched for "past", "present", "future" and
    "beyond", in that order; the first keyword found gives 0, 1, 2 or 3.

    Returns:
        The index, or ``None`` when no keyword is present or recognition fails.
    """
    keywords = ("past", "present", "future", "beyond")
    try:
        recognized = image_to_string(img_cropped).lower()  # type: str
        for index, keyword in enumerate(keywords):
            if keyword in recognized:
                return index
    except Exception:
        return None
    return None
def ocr_title(img_cropped: Mat):
    """Return the text Tesseract reads from ``img_cropped`` with every newline
    character removed, or ``""`` when recognition fails."""
    try:
        recognized = image_to_string(img_cropped)
        return recognized.replace("\n", "")
    except Exception:
        return ""

View File

@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Optional
from cv2 import COLOR_BGR2HSV, GaussianBlur, cvtColor, imread
from .crop import *
from .device import Device
from .mask import *
from .ocr import *
@dataclass(kw_only=True)
class RecognizeResult:
pure: Optional[int]
far: Optional[int]
lost: Optional[int]
score: Optional[int]
max_recall: Optional[int]
rating_class: Optional[int]
title: str
def recognize(img_filename: str, device: Device):
    """Recognize the values of a screenshot file.

    The image is read, converted from BGR to HSV, and each region described by
    ``device`` is cropped, masked and passed to its OCR routine. The regions
    read with digit templates (pure, far, lost, score) are additionally
    smoothed with a 3x3 Gaussian blur after masking.

    Arguments:
        img_filename -- path of the screenshot to read.
        device -- layout describing where each region is located.

    Returns:
        A ``RecognizeResult`` with all recognized values.
    """
    img_hsv = cvtColor(imread(img_filename), COLOR_BGR2HSV)

    def soften(mask):
        # 3x3 Gaussian blur, sigma derived from the kernel size.
        return GaussianBlur(mask, (3, 3), 0)

    # Gray-masked digit regions.
    pure = ocr_pure(soften(mask_gray(crop_to_pure(img_hsv, device))))
    far = ocr_far_lost(soften(mask_gray(crop_to_far(img_hsv, device))))
    lost = ocr_far_lost(soften(mask_gray(crop_to_lost(img_hsv, device))))

    # White-masked score digits.
    score = ocr_score(soften(mask_white(crop_to_score(img_hsv, device))))

    # Regions read as text; these masks are not blurred.
    max_recall = ocr_max_recall(mask_gray(crop_to_max_recall(img_hsv, device)))
    rating_class = ocr_rating_class(
        mask_rating_class(crop_to_rating_class(img_hsv, device))
    )
    title = ocr_title(mask_white(crop_to_title(img_hsv, device)))

    return RecognizeResult(
        pure=pure,
        far=far,
        lost=lost,
        score=score,
        max_recall=max_recall,
        rating_class=rating_class,
        title=title,
    )

View File

@ -0,0 +1,163 @@
from base64 import b64decode
from time import sleep
from typing import Dict, List, Literal, Tuple, TypedDict
from cv2 import (
CHAIN_APPROX_SIMPLE,
COLOR_BGR2GRAY,
COLOR_GRAY2BGR,
FONT_HERSHEY_SIMPLEX,
IMREAD_GRAYSCALE,
RETR_EXTERNAL,
THRESH_BINARY_INV,
TM_CCOEFF_NORMED,
Mat,
boundingRect,
cvtColor,
destroyAllWindows,
findContours,
imdecode,
imread,
imshow,
matchTemplate,
minMaxLoc,
putText,
rectangle,
threshold,
waitKey,
)
from imutils import contours, grab_contours
from numpy import frombuffer as np_frombuffer
from numpy import uint8
from ._builtin_templates import GeoSansLight_Italic, GeoSansLight_Regular
def load_digit_template(filename: str) -> Dict[int, Mat]:
    """
    Cut an image of digits into one template per digit.

    Arguments:
        filename -- An image with white background and black "0 1 2 3 4 5 6 7 8 9" text.
    Returns:
        dict[int, cv2.Mat] -- the n-th external contour from the left is
        stored under key n, cropped to its bounding rectangle from the
        inverse-thresholded image.
    """
    # https://pyimagesearch.com/2017/07/17/credit-card-ocr-with-opencv-and-python/
    gray = cvtColor(imread(filename), COLOR_BGR2GRAY)
    binary = threshold(gray, 10, 255, THRESH_BINARY_INV)[1]

    # findContours gets a copy so ``binary`` itself stays available for cropping.
    found = grab_contours(findContours(binary.copy(), RETR_EXTERNAL, CHAIN_APPROX_SIMPLE))
    ordered = contours.sort_contours(found, method="left-to-right")[0]

    templates = {}
    for index, contour in enumerate(ordered):
        left, top, width, height = boundingRect(contour)
        templates[index] = binary[top : top + height, left : left + width]
    return templates
def load_builtin_digit_template(
    name: Literal["GeoSansLight-Regular", "GeoSansLight-Italic"]
):
    """Decode one of the bundled digit template sets.

    Arguments:
        name -- which built-in template set to load.
    Returns:
        dict[int, image] -- every key of the bundled mapping converted to
        ``int``, every base64 value decoded as a grayscale image.
    Raises:
        KeyError -- if ``name`` is not one of the built-in sets.
    """
    bundled_sets = {
        "GeoSansLight-Regular": GeoSansLight_Regular,
        "GeoSansLight-Italic": GeoSansLight_Italic,
    }
    encoded_digits = bundled_sets[name]

    decoded = {}
    for key, b64str in encoded_digits.items():
        digit = int(key)
        raw_bytes = np_frombuffer(b64decode(b64str), uint8)
        decoded[digit] = imdecode(raw_bytes, IMREAD_GRAYSCALE)
    return decoded
class MatchTemplateMultipleResult(TypedDict):
    """One hit reported by ``matchTemplateMultiple``."""

    # Match score of the hit.
    max_val: float
    # First two items: x and y of the hit location.
    # NOTE(review): despite the name, matchTemplateMultiple fills the last two
    # items with x + template width + 1 and y + template height + 1 (a far
    # corner), not with a width and height — confirm before relying on them.
    xywh: Tuple[int, int, int, int]
def matchTemplateMultiple(
    src: Mat, template: Mat, threshold: float = 0.1
) -> List[MatchTemplateMultipleResult]:
    """
    Find every location where ``template`` matches ``src`` with a score above
    ``threshold`` (method: TM_CCOEFF_NORMED).

    The best remaining score is taken repeatedly; after each hit, a window of
    roughly one template size centred on the hit is zeroed in the score map so
    the same spot is not reported again.

    Returns:
        A list of dicts, best hits first, each with
        "max_val" -- the match score, and
        "xywh" -- (x, y, x + template width + 1, y + template height + 1).
    """
    scores = matchTemplate(src, template, TM_CCOEFF_NORMED)
    template_h, template_w = template.shape[:2]
    half_h, half_w = template_h // 2, template_w // 2
    hits = []

    # https://stackoverflow.com/a/66848923/16484891
    # CC BY-SA 4.0
    previous = None
    while True:
        current = minMaxLoc(scores)
        _, max_val, _, max_loc = current
        if not max_val > threshold:
            break
        # Prevent infinite loop. If the extrema are exactly those of the
        # previous round, suppression made no progress, so stop.
        if previous is not None and all(
            before == now for before, now in zip(previous, current)
        ):
            break
        previous = current

        peak_x, peak_y = max_loc[0], max_loc[1]
        # Clamp the suppression window to the bounds of the score map.
        row_from = max(0, peak_y - half_h)
        col_from = max(0, peak_x - half_w)
        row_to = min(scores.shape[0], peak_y + half_h + 1)
        col_to = min(scores.shape[1], peak_x + half_w + 1)
        scores[row_from:row_to, col_from:col_to] = 0

        hits.append(
            {
                "max_val": max_val,
                "xywh": (
                    peak_x,
                    peak_y,
                    peak_x + template_w + 1,
                    peak_y + template_h + 1,
                ),
            }
        )
    return hits