class FixRects:
    """Helpers that post-process bounding rects before they are used further."""

    @staticmethod
    def connect_broken(
        rects: Sequence[Tuple[int, int, int, int]],
        img_width: int,
        img_height: int,
        tolerance: Optional[int] = None,
    ):
        """Merge the fragments of a "broken" digit into one bounding rect.

        For a "broken" digit, please refer to
        /assets/fix_rects/broken_masked.jpg
        (the larger "5" in the image is a "broken" digit).

        A rect whose height lies between 10% and 60% of ``img_height`` is
        treated as a candidate fragment. Every other rect whose left AND right
        borders are both strictly within ``tolerance`` pixels of the
        candidate's borders is grouped with it, and the group is replaced by
        its common bounding rect.

        Args:
            rects: ``(x, y, w, h)`` bounding rects.
            img_width: Width of the image the rects were taken from; only
                used to derive the default ``tolerance``.
            img_height: Height of the image the rects were taken from; used
                to decide which rects count as fragments.
            tolerance: Maximum horizontal border distance (exclusive) for two
                rects to be considered parts of the same digit. Defaults to
                ``ceil(img_width * 0.08)``.

        Returns:
            A new list holding the rects that were not merged (in their
            original order) followed by the merged rects. The input sequence
            is not modified.
        """
        if tolerance is None:
            tolerance = math.ceil(img_width * 0.08)

        min_fragment_h = img_height * 0.1
        max_fragment_h = img_height * 0.6

        new_rects = []
        consumed_rects = []
        for rect in rects:
            # A fragment that was already merged into an earlier group must
            # not start a group of its own; otherwise two mutually aligned
            # fragments would each emit the same merged rect.
            if rect in consumed_rects:
                continue

            x, _, w, h = rect
            # only small rects can be fragments of a broken digit
            if not min_fragment_h <= h <= max_fragment_h:
                continue

            right = x + w
            # other rects that have near left & right borders
            group = [
                other
                for other in rects
                if other != rect
                and abs(x - other[0]) < tolerance
                and abs(right - (other[0] + other[2])) < tolerance
            ]
            if not group:
                continue

            group.append(rect)
            consumed_rects.extend(group)

            # bounding rect of the whole group
            new_x = min(r[0] for r in group)
            new_y = min(r[1] for r in group)
            new_right = max(r[0] + r[2] for r in group)
            new_bottom = max(r[1] + r[3] for r in group)
            new_rects.append((new_x, new_y, new_right - new_x, new_bottom - new_y))

        return_rects = [r for r in rects if r not in consumed_rects]
        return_rects.extend(new_rects)
        return return_rects