refactor!: device scenario

- Correct abstract class annotations
This commit is contained in:
2025-06-25 23:35:38 +08:00
parent 06156db9c2
commit 0055d9e8da
21 changed files with 227 additions and 205 deletions

View File

@ -1,3 +1,2 @@
from .crop import *
from .device import *
from .utils import *

View File

@ -1,2 +0,0 @@
from .common import DeviceOcrResult
from .ocr import DeviceOcr

View File

@ -1,17 +0,0 @@
from dataclasses import dataclass
from typing import Optional
@dataclass
class DeviceOcrResult:
rating_class: int
score: int
pure: Optional[int] = None
far: Optional[int] = None
lost: Optional[int] = None
max_recall: Optional[int] = None
song_id: Optional[str] = None
song_id_possibility: Optional[float] = None
clear_status: Optional[int] = None
partner_id: Optional[str] = None
partner_id_possibility: Optional[float] = None

View File

@ -1,3 +0,0 @@
from .definition import *
from .extractor import *
from .masker import *

View File

@ -1,2 +0,0 @@
from .auto import *
from .common import DeviceRois

View File

@ -1,15 +0,0 @@
from typing import Tuple
Rect = Tuple[int, int, int, int]
class DeviceRois:
pure: Rect
far: Rect
lost: Rect
score: Rect
rating_class: Rect
max_recall: Rect
jacket: Rect
clear_status: Rect
partner_icon: Rect

View File

@ -1 +0,0 @@
from .common import DeviceRoisExtractor

View File

@ -1,48 +0,0 @@
from ....crop import crop_xywh
from ....types import Mat
from ..definition.common import DeviceRois
class DeviceRoisExtractor:
def __init__(self, img: Mat, rois: DeviceRois):
self.img = img
self.sizes = rois
def __construct_int_rect(self, rect):
return tuple(round(r) for r in rect)
@property
def pure(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.pure))
@property
def far(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.far))
@property
def lost(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.lost))
@property
def score(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.score))
@property
def jacket(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.jacket))
@property
def rating_class(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.rating_class))
@property
def max_recall(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.max_recall))
@property
def clear_status(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.clear_status))
@property
def partner_icon(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.partner_icon))

View File

@ -1,2 +0,0 @@
from .auto import *
from .common import DeviceRoisMasker

View File

@ -1,59 +0,0 @@
from ....types import Mat
class DeviceRoisMasker:
@classmethod
def pure(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def far(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def lost(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def score(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_pst(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_prs(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_ftr(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_byd(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_etr(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def max_recall(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()

View File

@ -0,0 +1,13 @@
from .extractor import DeviceRoisExtractor
from .impl import DeviceScenario
from .masker import DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .rois import DeviceRoisAutoT1, DeviceRoisAutoT2
__all__ = [
"DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2",
"DeviceRoisAutoT1",
"DeviceRoisAutoT2",
"DeviceRoisExtractor",
"DeviceScenario",
]

View File

@ -0,0 +1,3 @@
from .base import DeviceRoisExtractor
__all__ = ["DeviceRoisExtractor"]

View File

@ -0,0 +1,46 @@
from arcaea_offline_ocr.crop import crop_xywh
from arcaea_offline_ocr.types import Mat
from ..rois.base import DeviceRois
class DeviceRoisExtractor:
def __init__(self, img: Mat, rois: DeviceRois):
self.img = img
self.sizes = rois
@property
def pure(self):
return crop_xywh(self.img, self.sizes.pure.rounded())
@property
def far(self):
return crop_xywh(self.img, self.sizes.far.rounded())
@property
def lost(self):
return crop_xywh(self.img, self.sizes.lost.rounded())
@property
def score(self):
return crop_xywh(self.img, self.sizes.score.rounded())
@property
def jacket(self):
return crop_xywh(self.img, self.sizes.jacket.rounded())
@property
def rating_class(self):
return crop_xywh(self.img, self.sizes.rating_class.rounded())
@property
def max_recall(self):
return crop_xywh(self.img, self.sizes.max_recall.rounded())
@property
def clear_status(self):
return crop_xywh(self.img, self.sizes.clear_status.rounded())
@property
def partner_icon(self):
return crop_xywh(self.img, self.sizes.partner_icon.rounded())

View File

@ -1,26 +1,31 @@
import cv2
import numpy as np
from ..phash_db import ImagePhashDatabase
from ..providers.knn import OcrKNearestTextProvider
from ..types import Mat
from .common import DeviceOcrResult
from .rois.extractor import DeviceRoisExtractor
from .rois.masker import DeviceRoisMasker
from arcaea_offline_ocr.providers import (
ImageCategory,
ImageIdProvider,
OcrKNearestTextProvider,
)
from arcaea_offline_ocr.scenarios.base import OcrScenarioResult
from arcaea_offline_ocr.types import Mat
from .base import DeviceScenarioBase
from .extractor import DeviceRoisExtractor
from .masker import DeviceRoisMasker
class DeviceOcr:
class DeviceScenario(DeviceScenarioBase):
def __init__(
self,
extractor: DeviceRoisExtractor,
masker: DeviceRoisMasker,
knn_provider: OcrKNearestTextProvider,
phash_db: ImagePhashDatabase,
image_id_provider: ImageIdProvider,
):
self.extractor = extractor
self.masker = masker
self.knn_provider = knn_provider
self.phash_db = phash_db
self.image_id_provider = image_id_provider
def pfl(self, roi_gray: Mat, factor: float = 1.25):
def contour_filter(cnt):
@ -93,14 +98,12 @@ class DeviceOcr:
]
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
def lookup_song_id(self):
return self.phash_db.lookup_jacket(
cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY)
def song_id_results(self):
return self.image_id_provider.results(
cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY),
ImageCategory.JACKET,
)
def song_id(self):
return self.lookup_song_id()[0]
@staticmethod
def preprocess_char_icon(img_gray: Mat):
h, w = img_gray.shape[:2]
@ -114,21 +117,19 @@ class DeviceOcr:
np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32),
np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32),
],
(128),
(128,),
)
return img
def lookup_partner_id(self):
return self.phash_db.lookup_partner_icon(
def partner_id_results(self):
return self.image_id_provider.results(
self.preprocess_char_icon(
cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY)
)
),
ImageCategory.PARTNER_ICON,
)
def partner_id(self):
return self.lookup_partner_id()[0]
def ocr(self) -> DeviceOcrResult:
def result(self):
rating_class = self.rating_class()
pure = self.pure()
far = self.far()
@ -137,20 +138,18 @@ class DeviceOcr:
max_recall = self.max_recall()
clear_status = self.clear_status()
hash_len = self.phash_db.hash_size**2
song_id, song_id_distance = self.lookup_song_id()
partner_id, partner_id_distance = self.lookup_partner_id()
song_id_results = self.song_id_results()
partner_id_results = self.partner_id_results()
return DeviceOcrResult(
return OcrScenarioResult(
song_id=song_id_results[0].image_id,
song_id_results=song_id_results,
rating_class=rating_class,
pure=pure,
far=far,
lost=lost,
score=score,
max_recall=max_recall,
song_id=song_id,
song_id_possibility=1 - song_id_distance / hash_len,
partner_id_results=partner_id_results,
clear_status=clear_status,
partner_id=partner_id,
partner_id_possibility=1 - partner_id_distance / hash_len,
)

View File

@ -0,0 +1,9 @@
from .auto import DeviceRoisMaskerAuto, DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .base import DeviceRoisMasker
__all__ = [
"DeviceRoisMaskerAuto",
"DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2",
"DeviceRoisMasker",
]

View File

@ -1,13 +1,12 @@
import cv2
import numpy as np
from ....types import Mat
from .common import DeviceRoisMasker
from arcaea_offline_ocr.types import Mat
from .base import DeviceRoisMasker
class DeviceRoisMaskerAuto(DeviceRoisMasker):
# pylint: disable=abstract-method
@staticmethod
def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat):
return cv2.inRange(

View File

@ -0,0 +1,61 @@
from abc import ABC, abstractmethod
from arcaea_offline_ocr.types import Mat
class DeviceRoisMasker(ABC):
@classmethod
@abstractmethod
def pure(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def far(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def lost(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def score(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_pst(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_prs(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_byd(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_etr(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def max_recall(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: ...

View File

@ -0,0 +1,9 @@
from .auto import DeviceRoisAuto, DeviceRoisAutoT1, DeviceRoisAutoT2
from .base import DeviceRois
__all__ = [
"DeviceRois",
"DeviceRoisAuto",
"DeviceRoisAutoT1",
"DeviceRoisAutoT2",
]

View File

@ -1,6 +1,6 @@
from .common import DeviceRois
from arcaea_offline_ocr.types import XYWHRect
__all__ = ["DeviceRoisAuto", "DeviceRoisAutoT1", "DeviceRoisAutoT2"]
from .base import DeviceRois
class DeviceRoisAuto(DeviceRois):
@ -50,7 +50,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property
def pure(self):
return (
return XYWHRect(
self.pfl_x,
self.layout_area_h_mid + 110 * self.factor,
self.pfl_w,
@ -59,7 +59,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property
def far(self):
return (
return XYWHRect(
self.pfl_x,
self.pure[1] + self.pure[3] + 12 * self.factor,
self.pfl_w,
@ -68,7 +68,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property
def lost(self):
return (
return XYWHRect(
self.pfl_x,
self.far[1] + self.far[3] + 10 * self.factor,
self.pfl_w,
@ -79,7 +79,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def score(self):
w = 280 * self.factor
h = 45 * self.factor
return (
return XYWHRect(
self.w_mid - w / 2,
self.layout_area_h_mid - 75 * self.factor - h,
w,
@ -88,7 +88,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property
def rating_class(self):
return (
return XYWHRect(
self.w_mid - 610 * self.factor,
self.layout_area_h_mid - 180 * self.factor,
265 * self.factor,
@ -97,7 +97,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property
def max_recall(self):
return (
return XYWHRect(
self.w_mid - 465 * self.factor,
self.layout_area_h_mid - 215 * self.factor,
150 * self.factor,
@ -106,7 +106,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property
def jacket(self):
return (
return XYWHRect(
self.w_mid - 610 * self.factor,
self.layout_area_h_mid - 143 * self.factor,
375 * self.factor,
@ -117,7 +117,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def clear_status(self):
w = 550 * self.factor
h = 60 * self.factor
return (
return XYWHRect(
self.w_mid - w / 2,
self.layout_area_h_mid - 155 * self.factor - h,
w * 0.4,
@ -128,7 +128,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def partner_icon(self):
w = 90 * self.factor
h = 75 * self.factor
return (self.w_mid - w / 2, 0, w, h)
return XYWHRect(self.w_mid - w / 2, 0, w, h)
class DeviceRoisAutoT2(DeviceRoisAuto):
@ -174,7 +174,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property
def pure(self):
return (
return XYWHRect(
self.pfl_x,
self.layout_area_h_mid + 175 * self.factor,
self.pfl_w,
@ -183,7 +183,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property
def far(self):
return (
return XYWHRect(
self.pfl_x,
self.pure[1] + self.pure[3] + 30 * self.factor,
self.pfl_w,
@ -192,7 +192,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property
def lost(self):
return (
return XYWHRect(
self.pfl_x,
self.far[1] + self.far[3] + 35 * self.factor,
self.pfl_w,
@ -203,7 +203,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def score(self):
w = 420 * self.factor
h = 70 * self.factor
return (
return XYWHRect(
self.w_mid - w / 2,
self.layout_area_h_mid - 110 * self.factor - h,
w,
@ -212,7 +212,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property
def rating_class(self):
return (
return XYWHRect(
max(0, self.w_mid - 965 * self.factor),
self.layout_area_h_mid - 330 * self.factor,
350 * self.factor,
@ -221,7 +221,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property
def max_recall(self):
return (
return XYWHRect(
self.w_mid - 625 * self.factor,
self.layout_area_h_mid - 275 * self.factor,
150 * self.factor,
@ -230,7 +230,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property
def jacket(self):
return (
return XYWHRect(
self.w_mid - 915 * self.factor,
self.layout_area_h_mid - 215 * self.factor,
565 * self.factor,
@ -241,7 +241,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def clear_status(self):
w = 825 * self.factor
h = 90 * self.factor
return (
return XYWHRect(
self.w_mid - w / 2,
self.layout_area_h_mid - 235 * self.factor - h,
w * 0.4,
@ -252,4 +252,4 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def partner_icon(self):
w = 135 * self.factor
h = 110 * self.factor
return (self.w_mid - w / 2, 0, w, h)
return XYWHRect(self.w_mid - w / 2, 0, w, h)

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from arcaea_offline_ocr.types import XYWHRect
class DeviceRois(ABC):
@property
@abstractmethod
def pure(self) -> XYWHRect: ...
@property
@abstractmethod
def far(self) -> XYWHRect: ...
@property
@abstractmethod
def lost(self) -> XYWHRect: ...
@property
@abstractmethod
def score(self) -> XYWHRect: ...
@property
@abstractmethod
def rating_class(self) -> XYWHRect: ...
@property
@abstractmethod
def max_recall(self) -> XYWHRect: ...
@property
@abstractmethod
def jacket(self) -> XYWHRect: ...
@property
@abstractmethod
def clear_status(self) -> XYWHRect: ...
@property
@abstractmethod
def partner_icon(self) -> XYWHRect: ...