From ea31aaf00c40cee4bd7f7f3b32cfa59c05f526f0 Mon Sep 17 00:00:00 2001 From: 283375 Date: Sat, 10 Jun 2023 18:32:23 +0800 Subject: [PATCH] chore!: add recognize parts --- src/arcaea_offline_ocr/ocr.py | 6 ++- src/arcaea_offline_ocr/recognize.py | 60 ++++++++++++++++++++--------- 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/ocr.py index 452c97b..a705790 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/ocr.py @@ -88,7 +88,11 @@ def filter_digit_results( if possible_results: # Sort the results based on "max_val" in descending order and select the top result - result = sorted(possible_results, key=lambda d: d["max_val"], reverse=True)[0] + result = sorted( + possible_results, + key=lambda d: d["max_val"], + reverse=True, + )[0] result_digit = result["digit"] # Get the digit value from the result result.pop("digit", None) # Remove the digit key from the result diff --git a/src/arcaea_offline_ocr/recognize.py b/src/arcaea_offline_ocr/recognize.py index 3b7a001..ef4d5b1 100644 --- a/src/arcaea_offline_ocr/recognize.py +++ b/src/arcaea_offline_ocr/recognize.py @@ -9,16 +9,42 @@ from .mask import * from .ocr import * -def process_digit_ocr_img(img_hsv, mask=Callable[[Mat], Mat]): - img_hsv = mask(img_hsv) - img_hsv = GaussianBlur(img_hsv, (3, 3), 0) - return img_hsv +def process_digits_ocr_img(img_hsv_cropped: Mat, mask=Callable[[Mat], Mat]): + img_hsv_cropped = mask(img_hsv_cropped) + img_hsv_cropped = GaussianBlur(img_hsv_cropped, (3, 3), 0) + return img_hsv_cropped -def process_tesseract_ocr_img(img_hsv, mask=Callable[[Mat], Mat]): - img_hsv = mask(img_hsv) - img_hsv = GaussianBlur(img_hsv, (1, 1), 0) - return img_hsv +def process_tesseract_ocr_img(img_hsv_cropped: Mat, mask=Callable[[Mat], Mat]): + img_hsv_cropped = mask(img_hsv_cropped) + img_hsv_cropped = GaussianBlur(img_hsv_cropped, (1, 1), 0) + return img_hsv_cropped + + +def recognize_pure(img_hsv_cropped: Mat): + return ocr_pure(process_digits_ocr_img(img_hsv_cropped, mask=mask_gray)) + + +def recognize_far_lost(img_hsv_cropped: Mat): + return ocr_far_lost(process_digits_ocr_img(img_hsv_cropped, mask=mask_gray)) + + +def recognize_score(img_hsv_cropped: Mat): + return ocr_score(process_digits_ocr_img(img_hsv_cropped, mask=mask_white)) + + +def recognize_max_recall(img_hsv_cropped: Mat): + return ocr_max_recall(process_tesseract_ocr_img(img_hsv_cropped, mask=mask_gray)) + + +def recognize_rating_class(img_hsv_cropped: Mat): + return ocr_rating_class( + process_tesseract_ocr_img(img_hsv_cropped, mask=mask_rating_class) + ) + + +def recognize_title(img_hsv_cropped: Mat): + return ocr_title(process_tesseract_ocr_img(img_hsv_cropped, mask=mask_white)) @dataclass(kw_only=True) @@ -37,29 +63,25 @@ def recognize(img_filename: str, device: Device): img_hsv = cvtColor(img, COLOR_BGR2HSV) pure_roi = crop_to_pure(img_hsv, device) - pure = ocr_pure(process_digit_ocr_img(pure_roi, mask=mask_gray)) + pure = recognize_pure(pure_roi) far_roi = crop_to_far(img_hsv, device) - far = ocr_far_lost(process_digit_ocr_img(far_roi, mask=mask_gray)) + far = recognize_far_lost(far_roi) lost_roi = crop_to_lost(img_hsv, device) - lost = ocr_far_lost(process_digit_ocr_img(lost_roi, mask=mask_gray)) + lost = recognize_far_lost(lost_roi) score_roi = crop_to_score(img_hsv, device) - score = ocr_score(process_digit_ocr_img(score_roi, mask=mask_white)) + score = recognize_score(score_roi) max_recall_roi = crop_to_max_recall(img_hsv, device) - max_recall = ocr_max_recall( - process_tesseract_ocr_img(max_recall_roi, mask=mask_gray) - ) + max_recall = recognize_max_recall(max_recall_roi) rating_class_roi = crop_to_rating_class(img_hsv, device) - rating_class = ocr_rating_class( - process_tesseract_ocr_img(rating_class_roi, mask=mask_rating_class) - ) + rating_class = recognize_rating_class(rating_class_roi) title_roi = crop_to_title(img_hsv, device) - title = ocr_title(process_tesseract_ocr_img(title_roi, mask=mask_white)) + title = recognize_title(title_roi) return RecognizeResult( pure=pure,