mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-04-19 05:20:17 +00:00
impr: digit preprocess
This commit is contained in:
parent
7f62fd20b1
commit
80ec1b203a
@ -1,9 +1,19 @@
|
|||||||
|
import math
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...crop import crop_xywh
|
from ...crop import crop_xywh
|
||||||
from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white
|
from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white
|
||||||
from ...ocr import ocr_digits_by_contour_knn
|
from ...ocr import (
|
||||||
|
FixRects,
|
||||||
|
ocr_digit_samples_knn,
|
||||||
|
ocr_digits_by_contour_knn,
|
||||||
|
preprocess_hog,
|
||||||
|
resize_fill_square,
|
||||||
|
)
|
||||||
from ...sift_db import SIFTDatabase
|
from ...sift_db import SIFTDatabase
|
||||||
from ...types import Mat, cv2_ml_KNearest
|
from ...types import Mat, cv2_ml_KNearest
|
||||||
from ..shared import DeviceOcrResult
|
from ..shared import DeviceOcrResult
|
||||||
@ -37,10 +47,42 @@ class DeviceV2Ocr:
|
|||||||
def sift_db(self, value: SIFTDatabase):
|
def sift_db(self, value: SIFTDatabase):
|
||||||
self.__sift_db = value
|
self.__sift_db = value
|
||||||
|
|
||||||
def _base_ocr_digits(self, roi_masked: Mat):
|
@lru_cache
|
||||||
return ocr_digits_by_contour_knn(
|
def _get_digit_widths(self, num_list: Sequence[int], factor: float):
|
||||||
find_digits_preprocess(roi_masked), self.knn_model
|
widths = set()
|
||||||
|
for n in num_list:
|
||||||
|
lower = math.floor(n * factor)
|
||||||
|
upper = math.ceil(n * factor)
|
||||||
|
widths.update(range(lower, upper + 1))
|
||||||
|
return widths
|
||||||
|
|
||||||
|
def _base_ocr_pfl(self, roi_masked: Mat, factor: float = 1.0):
    """Recognise the digits in a masked ROI via per-digit HOG samples and KNN.

    Steps, in order:

    1. Find external contours in ``roi_masked`` and split them into
       "kept" (area >= ``5 * factor``) and "noise" (everything else).
    2. Take the bounding rects of the kept contours and pass them through
       ``FixRects.connect_broken`` (given the ROI width and height).
    3. Drop rects whose width is below ``5 * factor`` or whose height is
       below ``6 * factor``, pass the rest through
       ``FixRects.split_connected`` and order them left to right.
    4. On a copy of the ROI, paint every noise contour with 0.
    5. Crop each rect from that copy, normalise it with
       ``resize_fill_square(..., 20)``, build HOG samples and classify
       them with ``self.knn_model``.

    Args:
        roi_masked: Single-channel mask image to read digits from.
            NOTE(review): presumably a binary uint8 mask, as required by
            ``cv2.findContours`` - confirm against the ``mask_*`` helpers.
        factor: Scale multiplier applied to the area/width/height
            thresholds.

    Returns:
        Whatever ``ocr_digit_samples_knn`` returns for the digit samples.
    """
    contours, _ = cv2.findContours(
        roi_masked, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )

    # Partition once on the area threshold instead of re-identifying the
    # rejected contours later by hashing their flattened point lists.
    min_area = 5 * factor
    kept_contours = []
    noise_contours = []
    for contour in contours:
        if cv2.contourArea(contour) >= min_area:
            kept_contours.append(contour)
        else:
            noise_contours.append(contour)

    rects = [cv2.boundingRect(c) for c in kept_contours]
    rects = FixRects.connect_broken(rects, roi_masked.shape[1], roi_masked.shape[0])

    # boundingRect tuples are (x, y, w, h): r[2] is width, r[3] is height.
    digit_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor]
    digit_rects = FixRects.split_connected(roi_masked, digit_rects)
    digit_rects = sorted(digit_rects, key=lambda r: r[0])

    # Work on a copy so the caller's mask is not modified by the clean-up.
    roi_ocr = roi_masked.copy()
    for contour in noise_contours:
        roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0])

    digit_rois = [resize_fill_square(crop_xywh(roi_ocr, r), 20) for r in digit_rects]
    samples = preprocess_hog(digit_rois)
    return ocr_digit_samples_knn(samples, self.knn_model)
|
||||||
|
|
||||||
def ocr_song_id(self, rois: DeviceV2Rois):
|
def ocr_song_id(self, rois: DeviceV2Rois):
|
||||||
cover = cv2.cvtColor(rois.cover, cv2.COLOR_BGR2GRAY)
|
cover = cv2.cvtColor(rois.cover, cv2.COLOR_BGR2GRAY)
|
||||||
@ -54,19 +96,24 @@ class DeviceV2Ocr:
|
|||||||
def ocr_score(self, rois: DeviceV2Rois):
    """Read the digits in ``rois.score``.

    The ROI is converted from BGR to HSV and passed through ``mask_white``.
    Every external contour whose bounding-box height is less than 60% of
    the mask height is then painted with 0, and the resulting mask is
    handed to ``ocr_digits_by_contour_knn`` with ``self.knn_model``.
    """
    score_mask = mask_white(cv2.cvtColor(rois.score, cv2.COLOR_BGR2HSV))
    blobs, _hierarchy = cv2.findContours(
        score_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )

    # fillPoly draws in place and leaves the image shape unchanged, so the
    # threshold can be computed once up front.
    min_height = score_mask.shape[0] * 0.6
    for blob in blobs:
        blob_height = cv2.boundingRect(blob)[3]
        if blob_height >= min_height:
            continue
        score_mask = cv2.fillPoly(score_mask, [blob], [0])

    return ocr_digits_by_contour_knn(score_mask, self.knn_model)
|
||||||
|
|
||||||
def ocr_pure(self, rois: DeviceV2Rois):
    """Gray-mask ``rois.pure`` and return the ``_base_ocr_pfl`` result for it.

    ``rois.sizes.factor`` is forwarded as the threshold scale factor.
    """
    pure_mask = mask_gray(rois.pure)
    scale = rois.sizes.factor
    return self._base_ocr_pfl(pure_mask, scale)
|
||||||
|
|
||||||
def ocr_far(self, rois: DeviceV2Rois):
    """Gray-mask ``rois.far`` and return the ``_base_ocr_pfl`` result for it.

    ``rois.sizes.factor`` is forwarded as the threshold scale factor.
    """
    far_mask = mask_gray(rois.far)
    scale = rois.sizes.factor
    return self._base_ocr_pfl(far_mask, scale)
|
||||||
|
|
||||||
def ocr_lost(self, rois: DeviceV2Rois):
    """Gray-mask ``rois.lost`` and return the ``_base_ocr_pfl`` result for it.

    ``rois.sizes.factor`` is forwarded as the threshold scale factor.
    """
    lost_mask = mask_gray(rois.lost)
    scale = rois.sizes.factor
    return self._base_ocr_pfl(lost_mask, scale)
|
||||||
|
|
||||||
def ocr_max_recall(self, rois: DeviceV2Rois):
|
def ocr_max_recall(self, rois: DeviceV2Rois):
|
||||||
roi = mask_gray(rois.max_recall_rating_class)
|
roi = mask_gray(rois.max_recall_rating_class)
|
||||||
@ -78,7 +125,7 @@ class DeviceV2Ocr:
|
|||||||
[cv2.boundingRect(c) for c in contours], key=lambda r: r[0], reverse=True
|
[cv2.boundingRect(c) for c in contours], key=lambda r: r[0], reverse=True
|
||||||
)
|
)
|
||||||
max_recall_roi = crop_xywh(roi, rects[0])
|
max_recall_roi = crop_xywh(roi, rects[0])
|
||||||
return self._base_ocr_digits(max_recall_roi)
|
return ocr_digits_by_contour_knn(max_recall_roi, self.knn_model)
|
||||||
|
|
||||||
def ocr(self, rois: DeviceV2Rois):
|
def ocr(self, rois: DeviceV2Rois):
|
||||||
song_id = self.ocr_song_id(rois)
|
song_id = self.ocr_song_id(rois)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user