From 80ec1b203a1129a01d039c48fd6d9b9cca9a4a05 Mon Sep 17 00:00:00 2001 From: 283375 Date: Tue, 5 Sep 2023 01:30:44 +0800 Subject: [PATCH] impr: digit preprocess --- src/arcaea_offline_ocr/device/v2/ocr.py | 65 +++++++++++++++++++++---- 1 file changed, 56 insertions(+), 9 deletions(-) diff --git a/src/arcaea_offline_ocr/device/v2/ocr.py b/src/arcaea_offline_ocr/device/v2/ocr.py index 102cfd4..3a9b93a 100644 --- a/src/arcaea_offline_ocr/device/v2/ocr.py +++ b/src/arcaea_offline_ocr/device/v2/ocr.py @@ -1,9 +1,19 @@ +import math +from functools import lru_cache +from typing import Sequence + import cv2 import numpy as np from ...crop import crop_xywh from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white -from ...ocr import ocr_digits_by_contour_knn +from ...ocr import ( + FixRects, + ocr_digit_samples_knn, + ocr_digits_by_contour_knn, + preprocess_hog, + resize_fill_square, +) from ...sift_db import SIFTDatabase from ...types import Mat, cv2_ml_KNearest from ..shared import DeviceOcrResult @@ -37,10 +47,42 @@ class DeviceV2Ocr: def sift_db(self, value: SIFTDatabase): self.__sift_db = value - def _base_ocr_digits(self, roi_masked: Mat): - return ocr_digits_by_contour_knn( - find_digits_preprocess(roi_masked), self.knn_model + @lru_cache + def _get_digit_widths(self, num_list: Sequence[int], factor: float): + widths = set() + for n in num_list: + lower = math.floor(n * factor) + upper = math.ceil(n * factor) + widths.update(range(lower, upper + 1)) + return widths + + def _base_ocr_pfl(self, roi_masked: Mat, factor: float = 1.0): + contours, _ = cv2.findContours( + roi_masked, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE ) + filtered_contours = [c for c in contours if cv2.contourArea(c) >= 5 * factor] + rects = [cv2.boundingRect(c) for c in filtered_contours] + rects = FixRects.connect_broken(rects, roi_masked.shape[1], roi_masked.shape[0]) + rect_contour_map = dict(zip(rects, filtered_contours)) + + filtered_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor] + filtered_rects = FixRects.split_connected(roi_masked, filtered_rects) + filtered_rects = sorted(filtered_rects, key=lambda r: r[0]) + + roi_ocr = roi_masked.copy() + filtered_contours_flattened = {tuple(c.flatten()) for c in filtered_contours} + for contour in contours: + if tuple(contour.flatten()) in filtered_contours_flattened: + continue + roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0]) + digit_rois = [ + resize_fill_square(crop_xywh(roi_ocr, r), 20) + for r in sorted(filtered_rects, key=lambda r: r[0]) + ] + # [cv2.imshow(f"r{i}", r) for i, r in enumerate(digit_rois)] + # cv2.waitKey(0) + samples = preprocess_hog(digit_rois) + return ocr_digit_samples_knn(samples, self.knn_model) def ocr_song_id(self, rois: DeviceV2Rois): cover = cv2.cvtColor(rois.cover, cv2.COLOR_BGR2GRAY) @@ -54,19 +96,24 @@ class DeviceV2Ocr: def ocr_score(self, rois: DeviceV2Rois): roi = cv2.cvtColor(rois.score, cv2.COLOR_BGR2HSV) roi = mask_white(roi) - return self._base_ocr_digits(roi) + contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + if h < roi.shape[0] * 0.6: + roi = cv2.fillPoly(roi, [contour], [0]) + return ocr_digits_by_contour_knn(roi, self.knn_model) def ocr_pure(self, rois: DeviceV2Rois): roi = mask_gray(rois.pure) - return self._base_ocr_digits(roi) + return self._base_ocr_pfl(roi, rois.sizes.factor) def ocr_far(self, rois: DeviceV2Rois): roi = mask_gray(rois.far) - return self._base_ocr_digits(roi) + return self._base_ocr_pfl(roi, rois.sizes.factor) def ocr_lost(self, rois: DeviceV2Rois): roi = mask_gray(rois.lost) - return self._base_ocr_digits(roi) + return self._base_ocr_pfl(roi, rois.sizes.factor) def ocr_max_recall(self, rois: DeviceV2Rois): roi = mask_gray(rois.max_recall_rating_class) @@ -78,7 +125,7 @@ class DeviceV2Ocr: [cv2.boundingRect(c) for c in contours], key=lambda r: r[0], reverse=True ) max_recall_roi = crop_xywh(roi, rects[0]) - return self._base_ocr_digits(max_recall_roi) + return ocr_digits_by_contour_knn(max_recall_roi, self.knn_model) def ocr(self, rois: DeviceV2Rois): song_id = self.ocr_song_id(rois)