148 lines
5.3 KiB
Python

import math
from functools import lru_cache
from typing import Sequence
import cv2
import numpy as np
from ...crop import crop_xywh
from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white
from ...ocr import (
FixRects,
ocr_digit_samples_knn,
ocr_digits_by_contour_knn,
preprocess_hog,
resize_fill_square,
)
from ...sift_db import SIFTDatabase
from ...types import Mat, cv2_ml_KNearest
from ..shared import DeviceOcrResult
from .preprocess import find_digits_preprocess
from .rois import DeviceV2Rois
from .shared import MAX_RECALL_CLOSE_KERNEL
class DeviceV2Ocr:
    """Read play-result fields from the regions supplied by ``DeviceV2Rois``.

    Digits are classified by a KNN model (over HOG features for the
    pure/far/lost counts), and the song is identified by looking the jacket
    image up in a SIFT database.
    """

    def __init__(self, knn_model: cv2_ml_KNearest, sift_db: SIFTDatabase):
        self.__knn_model = knn_model
        self.__sift_db = sift_db

    @property
    def knn_model(self):
        """KNN digit classifier.

        Raises:
            ValueError: if the stored model is falsy (e.g. ``None``).
        """
        if not self.__knn_model:
            raise ValueError("`knn_model` unset.")
        return self.__knn_model

    @knn_model.setter
    def knn_model(self, value: cv2_ml_KNearest):
        self.__knn_model = value

    @property
    def sift_db(self):
        """SIFT database used to identify the song from its jacket.

        Raises:
            ValueError: if the stored database is falsy (e.g. ``None``).
        """
        if not self.__sift_db:
            raise ValueError("`sift_db` unset.")
        return self.__sift_db

    @sift_db.setter
    def sift_db(self, value: SIFTDatabase):
        self.__sift_db = value

    @staticmethod
    @lru_cache(maxsize=128)
    def _cached_digit_widths(num_list: tuple, factor: float) -> frozenset:
        """Cached core of ``_get_digit_widths``.

        Static (not bound to ``self``) so the cache does not keep instances
        alive. It is keyed on a hashable tuple and returns an immutable
        frozenset so callers cannot corrupt cached entries.
        """
        widths = set()
        for n in num_list:
            scaled = n * factor
            widths.update(range(math.floor(scaled), math.ceil(scaled) + 1))
        return frozenset(widths)

    def _get_digit_widths(self, num_list: Sequence[int], factor: float):
        """Return every integer width between floor and ceil of ``n * factor``.

        Args:
            num_list: base widths; any iterable of ints (lists are accepted).
            factor: scale factor applied to each width.

        Returns:
            A new ``set`` of ints, so the caller is free to mutate it.
        """
        return set(self._cached_digit_widths(tuple(num_list), factor))

    def _base_ocr_pfl(self, roi_masked: Mat, factor: float = 1.0):
        """OCR a pure/far/lost count from a binary mask.

        Args:
            roi_masked: single-channel mask of the digits region.
            factor: scale factor applied to the pixel-size thresholds
                (minimum contour area 5, minimum rect width 5 / height 6).

        Returns:
            Whatever ``ocr_digit_samples_knn`` produces for the digit crops.
        """
        contours, _ = cv2.findContours(
            roi_masked, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        # Decide once per contour whether it is big enough to be part of a digit.
        min_area = 5 * factor
        keep = [cv2.contourArea(c) >= min_area for c in contours]
        filtered_contours = [c for c, kept in zip(contours, keep) if kept]

        rects = [cv2.boundingRect(c) for c in filtered_contours]
        rects = FixRects.connect_broken(rects, roi_masked.shape[1], roi_masked.shape[0])
        filtered_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor]
        filtered_rects = FixRects.split_connected(roi_masked, filtered_rects)
        # Left-to-right order so the digits are read in reading order.
        filtered_rects = sorted(filtered_rects, key=lambda r: r[0])

        # Erase the small (noise) contours so they don't leak into digit crops.
        roi_ocr = roi_masked.copy()
        for contour, kept in zip(contours, keep):
            if not kept:
                roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0])

        digit_rois = [
            resize_fill_square(crop_xywh(roi_ocr, r), 20) for r in filtered_rects
        ]
        samples = preprocess_hog(digit_rois)
        return ocr_digit_samples_knn(samples, self.knn_model)

    def ocr_song_id(self, rois: DeviceV2Rois):
        """Identify the song by looking up the grayscale jacket in the SIFT db."""
        jacket = cv2.cvtColor(rois.jacket, cv2.COLOR_BGR2GRAY)
        return self.sift_db.lookup_img(jacket)[0]

    def ocr_rating_class(self, rois: DeviceV2Rois):
        """Return the index of the rating-class mask with the most pixels.

        The order is 0 = pst, 1 = prs, 2 = ftr, 3 = byd. On a tie, the lowest
        index wins.
        """
        roi = cv2.cvtColor(rois.max_recall_rating_class, cv2.COLOR_BGR2HSV)
        masks = (mask_pst(roi), mask_prs(roi), mask_ftr(roi), mask_byd(roi))
        counts = [np.count_nonzero(m) for m in masks]
        return counts.index(max(counts))

    def ocr_score(self, rois: DeviceV2Rois):
        """OCR the score from the white pixels of the score region."""
        roi = cv2.cvtColor(rois.score, cv2.COLOR_BGR2HSV)
        roi = mask_white(roi)
        contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        min_height = roi.shape[0] * 0.6
        for contour in contours:
            # Blobs shorter than 60% of the region height are treated as noise.
            _, _, _, h = cv2.boundingRect(contour)
            if h < min_height:
                roi = cv2.fillPoly(roi, [contour], [0])
        return ocr_digits_by_contour_knn(roi, self.knn_model)

    def ocr_pure(self, rois: DeviceV2Rois):
        """OCR the pure count."""
        roi = mask_gray(rois.pure)
        return self._base_ocr_pfl(roi, rois.sizes.factor)

    def ocr_far(self, rois: DeviceV2Rois):
        """OCR the far count."""
        roi = mask_gray(rois.far)
        return self._base_ocr_pfl(roi, rois.sizes.factor)

    def ocr_lost(self, rois: DeviceV2Rois):
        """OCR the lost count."""
        roi = mask_gray(rois.lost)
        return self._base_ocr_pfl(roi, rois.sizes.factor)

    def ocr_max_recall(self, rois: DeviceV2Rois):
        """OCR max recall from the right-most blob of the shared region.

        Raises:
            IndexError: if no contour is found in the closed mask.
        """
        roi = mask_gray(rois.max_recall_rating_class)
        # Closing merges the digits into one blob per text group.
        roi_closed = cv2.morphologyEx(roi, cv2.MORPH_CLOSE, MAX_RECALL_CLOSE_KERNEL)
        contours, _ = cv2.findContours(
            roi_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        rects = [cv2.boundingRect(c) for c in contours]
        # The max recall digits are taken to be the right-most blob.
        rightmost = sorted(rects, key=lambda r: r[0], reverse=True)[0]
        max_recall_roi = crop_xywh(roi, rightmost)
        return ocr_digits_by_contour_knn(max_recall_roi, self.knn_model)

    def ocr(self, rois: DeviceV2Rois):
        """Run every field OCR and bundle the results into a DeviceOcrResult."""
        song_id = self.ocr_song_id(rois)
        rating_class = self.ocr_rating_class(rois)
        score = self.ocr_score(rois)
        pure = self.ocr_pure(rois)
        far = self.ocr_far(rois)
        lost = self.ocr_lost(rois)
        max_recall = self.ocr_max_recall(rois)
        return DeviceOcrResult(
            rating_class=rating_class,
            pure=pure,
            far=far,
            lost=lost,
            score=score,
            max_recall=max_recall,
            song_id=song_id,
        )