mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-04-18 21:10:17 +00:00
feat: DeviceV2 ocr API
This commit is contained in:
parent
7d885cfe46
commit
42979c67cb
@ -1,14 +1,16 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import attrs
|
import attrs
|
||||||
|
|
||||||
|
|
||||||
@attrs.define
|
@attrs.define
|
||||||
class DeviceOcrResult:
|
class DeviceOcrResult:
|
||||||
song_id: None
|
|
||||||
title: None
|
|
||||||
rating_class: int
|
rating_class: int
|
||||||
pure: int
|
pure: int
|
||||||
far: int
|
far: int
|
||||||
lost: int
|
lost: int
|
||||||
score: int
|
score: int
|
||||||
max_recall: int
|
max_recall: int
|
||||||
clear_type: None
|
song_id: Optional[str] = None
|
||||||
|
title: Optional[str] = None
|
||||||
|
clear_type: Optional[str] = None
|
||||||
|
@ -1,28 +1,21 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ...crop import crop_xywh
|
||||||
from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white
|
from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white
|
||||||
from ...ocr import ocr_digits_knn_model
|
from ...ocr import ocr_digits_by_contour_knn
|
||||||
|
from ...sift_db import SIFTDatabase
|
||||||
from ...types import Mat, cv2_ml_KNearest
|
from ...types import Mat, cv2_ml_KNearest
|
||||||
from ..shared import DeviceOcrResult
|
from ..shared import DeviceOcrResult
|
||||||
from .find import find_digits
|
from .find import find_digits_preprocess
|
||||||
from .rois import DeviceV2Rois
|
from .rois import DeviceV2Rois
|
||||||
|
from .shared import MAX_RECALL_CLOSE_KERNEL
|
||||||
|
|
||||||
|
|
||||||
class DeviceV2Ocr:
|
class DeviceV2Ocr:
|
||||||
def __init__(self, rois: DeviceV2Rois, knn_model: cv2_ml_KNearest):
|
def __init__(self, knn_model: cv2_ml_KNearest, sift_db: SIFTDatabase):
|
||||||
self.__rois = rois
|
|
||||||
self.__knn_model = knn_model
|
self.__knn_model = knn_model
|
||||||
|
self.__sift_db = sift_db
|
||||||
@property
|
|
||||||
def rois(self):
|
|
||||||
if not self.__rois:
|
|
||||||
raise ValueError("`rois` unset.")
|
|
||||||
return self.__rois
|
|
||||||
|
|
||||||
@rois.setter
|
|
||||||
def rois(self, rois: DeviceV2Rois):
|
|
||||||
self.__rois = rois
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def knn_model(self):
|
def knn_model(self):
|
||||||
@ -31,46 +24,77 @@ class DeviceV2Ocr:
|
|||||||
return self.__knn_model
|
return self.__knn_model
|
||||||
|
|
||||||
@knn_model.setter
|
@knn_model.setter
|
||||||
def knn_model(self, model: cv2_ml_KNearest):
|
def knn_model(self, value: cv2_ml_KNearest):
|
||||||
self.__knn_model = model
|
self.__knn_model = value
|
||||||
|
|
||||||
def _base_ocr_digits(self, roi_processed: Mat):
|
|
||||||
digits = find_digits(roi_processed)
|
|
||||||
result = ""
|
|
||||||
for digit in digits:
|
|
||||||
roi_result = ocr_digits_knn_model(digit, self.knn_model)
|
|
||||||
if roi_result is not None:
|
|
||||||
result += str(roi_result)
|
|
||||||
return int(result, base=10)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pure(self):
|
def sift_db(self):
|
||||||
roi = mask_gray(self.rois.pure)
|
if not self.__sift_db:
|
||||||
return self._base_ocr_digits(roi)
|
raise ValueError("`sift_db` unset.")
|
||||||
|
return self.__sift_db
|
||||||
|
|
||||||
@property
|
@sift_db.setter
|
||||||
def far(self):
|
def sift_db(self, value: SIFTDatabase):
|
||||||
roi = mask_gray(self.rois.far)
|
self.__sift_db = value
|
||||||
return self._base_ocr_digits(roi)
|
|
||||||
|
|
||||||
@property
|
def _base_ocr_digits(self, roi_masked: Mat):
|
||||||
def lost(self):
|
return ocr_digits_by_contour_knn(
|
||||||
roi = mask_gray(self.rois.lost)
|
find_digits_preprocess(roi_masked), self.knn_model
|
||||||
return self._base_ocr_digits(roi)
|
)
|
||||||
|
|
||||||
@property
|
def ocr_song_id(self, rois: DeviceV2Rois):
|
||||||
def score(self):
|
cover = cv2.cvtColor(rois.cover, cv2.COLOR_BGR2GRAY)
|
||||||
roi = cv2.cvtColor(self.rois.score, cv2.COLOR_BGR2HSV)
|
return self.sift_db.lookup_img(cover)[0]
|
||||||
|
|
||||||
|
def ocr_rating_class(self, rois: DeviceV2Rois):
|
||||||
|
roi = cv2.cvtColor(rois.max_recall_rating_class, cv2.COLOR_BGR2HSV)
|
||||||
|
results = [mask_pst(roi), mask_prs(roi), mask_ftr(roi), mask_byd(roi)]
|
||||||
|
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
|
||||||
|
|
||||||
|
def ocr_score(self, rois: DeviceV2Rois):
|
||||||
|
roi = cv2.cvtColor(rois.score, cv2.COLOR_BGR2HSV)
|
||||||
roi = mask_white(roi)
|
roi = mask_white(roi)
|
||||||
return self._base_ocr_digits(roi)
|
return self._base_ocr_digits(roi)
|
||||||
|
|
||||||
@property
|
def ocr_pure(self, rois: DeviceV2Rois):
|
||||||
def rating_class(self):
|
roi = mask_gray(rois.pure)
|
||||||
roi = cv2.cvtColor(self.rois.max_recall_rating_class, cv2.COLOR_BGR2HSV)
|
return self._base_ocr_digits(roi)
|
||||||
results = [
|
|
||||||
mask_pst(roi),
|
def ocr_far(self, rois: DeviceV2Rois):
|
||||||
mask_prs(roi),
|
roi = mask_gray(rois.far)
|
||||||
mask_ftr(roi),
|
return self._base_ocr_digits(roi)
|
||||||
mask_byd(roi),
|
|
||||||
]
|
def ocr_lost(self, rois: DeviceV2Rois):
|
||||||
return max(enumerate(results), key=lambda e: np.count_nonzero(e[1]))[0]
|
roi = mask_gray(rois.lost)
|
||||||
|
return self._base_ocr_digits(roi)
|
||||||
|
|
||||||
|
def ocr_max_recall(self, rois: DeviceV2Rois):
|
||||||
|
roi = mask_gray(rois.max_recall_rating_class)
|
||||||
|
roi_closed = cv2.morphologyEx(roi, cv2.MORPH_CLOSE, MAX_RECALL_CLOSE_KERNEL)
|
||||||
|
contours, _ = cv2.findContours(
|
||||||
|
roi_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
|
||||||
|
)
|
||||||
|
rects = sorted(
|
||||||
|
[cv2.boundingRect(c) for c in contours], key=lambda r: r[0], reverse=True
|
||||||
|
)
|
||||||
|
max_recall_roi = crop_xywh(roi, rects[0])
|
||||||
|
return self._base_ocr_digits(max_recall_roi)
|
||||||
|
|
||||||
|
def ocr(self, rois: DeviceV2Rois):
|
||||||
|
song_id = self.ocr_song_id(rois)
|
||||||
|
rating_class = self.ocr_rating_class(rois)
|
||||||
|
score = self.ocr_score(rois)
|
||||||
|
pure = self.ocr_pure(rois)
|
||||||
|
far = self.ocr_far(rois)
|
||||||
|
lost = self.ocr_lost(rois)
|
||||||
|
max_recall = self.ocr_max_recall(rois)
|
||||||
|
|
||||||
|
return DeviceOcrResult(
|
||||||
|
rating_class=rating_class,
|
||||||
|
pure=pure,
|
||||||
|
far=far,
|
||||||
|
lost=lost,
|
||||||
|
score=score,
|
||||||
|
max_recall=max_recall,
|
||||||
|
song_id=song_id,
|
||||||
|
)
|
||||||
|
@ -98,10 +98,10 @@ class Sizes:
|
|||||||
|
|
||||||
|
|
||||||
class DeviceV2Rois:
|
class DeviceV2Rois:
|
||||||
def __init__(self, device: DeviceV2):
|
def __init__(self, device: DeviceV2, img: Mat):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.sizes = Sizes(self.device.factor)
|
self.sizes = Sizes(self.device.factor)
|
||||||
self.__img = None
|
self.__img = img
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def construct_int_xywh_rect(x, y, w, h) -> XYWHRect:
|
def construct_int_xywh_rect(x, y, w, h) -> XYWHRect:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user