refactor!: device scenario

- Correct abstract class annotations
This commit is contained in:
2025-06-25 23:35:38 +08:00
parent 06156db9c2
commit 0055d9e8da
21 changed files with 227 additions and 205 deletions

View File

@ -1,3 +1,2 @@
from .crop import * from .crop import *
from .device import *
from .utils import * from .utils import *

View File

@ -1,2 +0,0 @@
from .common import DeviceOcrResult
from .ocr import DeviceOcr

View File

@ -1,17 +0,0 @@
from dataclasses import dataclass
from typing import Optional
@dataclass
class DeviceOcrResult:
rating_class: int
score: int
pure: Optional[int] = None
far: Optional[int] = None
lost: Optional[int] = None
max_recall: Optional[int] = None
song_id: Optional[str] = None
song_id_possibility: Optional[float] = None
clear_status: Optional[int] = None
partner_id: Optional[str] = None
partner_id_possibility: Optional[float] = None

View File

@ -1,3 +0,0 @@
from .definition import *
from .extractor import *
from .masker import *

View File

@ -1,2 +0,0 @@
from .auto import *
from .common import DeviceRois

View File

@ -1,15 +0,0 @@
from typing import Tuple
Rect = Tuple[int, int, int, int]
class DeviceRois:
pure: Rect
far: Rect
lost: Rect
score: Rect
rating_class: Rect
max_recall: Rect
jacket: Rect
clear_status: Rect
partner_icon: Rect

View File

@ -1 +0,0 @@
from .common import DeviceRoisExtractor

View File

@ -1,48 +0,0 @@
from ....crop import crop_xywh
from ....types import Mat
from ..definition.common import DeviceRois
class DeviceRoisExtractor:
def __init__(self, img: Mat, rois: DeviceRois):
self.img = img
self.sizes = rois
def __construct_int_rect(self, rect):
return tuple(round(r) for r in rect)
@property
def pure(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.pure))
@property
def far(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.far))
@property
def lost(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.lost))
@property
def score(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.score))
@property
def jacket(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.jacket))
@property
def rating_class(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.rating_class))
@property
def max_recall(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.max_recall))
@property
def clear_status(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.clear_status))
@property
def partner_icon(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.partner_icon))

View File

@ -1,2 +0,0 @@
from .auto import *
from .common import DeviceRoisMasker

View File

@ -1,59 +0,0 @@
from ....types import Mat
class DeviceRoisMasker:
@classmethod
def pure(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def far(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def lost(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def score(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_pst(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_prs(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_ftr(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_byd(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_etr(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def max_recall(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()

View File

@ -0,0 +1,13 @@
from .extractor import DeviceRoisExtractor
from .impl import DeviceScenario
from .masker import DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .rois import DeviceRoisAutoT1, DeviceRoisAutoT2
__all__ = [
"DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2",
"DeviceRoisAutoT1",
"DeviceRoisAutoT2",
"DeviceRoisExtractor",
"DeviceScenario",
]

View File

@ -0,0 +1,3 @@
from .base import DeviceRoisExtractor
__all__ = ["DeviceRoisExtractor"]

View File

@ -0,0 +1,46 @@
from arcaea_offline_ocr.crop import crop_xywh
from arcaea_offline_ocr.types import Mat
from ..rois.base import DeviceRois
class DeviceRoisExtractor:
def __init__(self, img: Mat, rois: DeviceRois):
self.img = img
self.sizes = rois
@property
def pure(self):
return crop_xywh(self.img, self.sizes.pure.rounded())
@property
def far(self):
return crop_xywh(self.img, self.sizes.far.rounded())
@property
def lost(self):
return crop_xywh(self.img, self.sizes.lost.rounded())
@property
def score(self):
return crop_xywh(self.img, self.sizes.score.rounded())
@property
def jacket(self):
return crop_xywh(self.img, self.sizes.jacket.rounded())
@property
def rating_class(self):
return crop_xywh(self.img, self.sizes.rating_class.rounded())
@property
def max_recall(self):
return crop_xywh(self.img, self.sizes.max_recall.rounded())
@property
def clear_status(self):
return crop_xywh(self.img, self.sizes.clear_status.rounded())
@property
def partner_icon(self):
return crop_xywh(self.img, self.sizes.partner_icon.rounded())

View File

@ -1,26 +1,31 @@
import cv2 import cv2
import numpy as np import numpy as np
from ..phash_db import ImagePhashDatabase from arcaea_offline_ocr.providers import (
from ..providers.knn import OcrKNearestTextProvider ImageCategory,
from ..types import Mat ImageIdProvider,
from .common import DeviceOcrResult OcrKNearestTextProvider,
from .rois.extractor import DeviceRoisExtractor )
from .rois.masker import DeviceRoisMasker from arcaea_offline_ocr.scenarios.base import OcrScenarioResult
from arcaea_offline_ocr.types import Mat
from .base import DeviceScenarioBase
from .extractor import DeviceRoisExtractor
from .masker import DeviceRoisMasker
class DeviceOcr: class DeviceScenario(DeviceScenarioBase):
def __init__( def __init__(
self, self,
extractor: DeviceRoisExtractor, extractor: DeviceRoisExtractor,
masker: DeviceRoisMasker, masker: DeviceRoisMasker,
knn_provider: OcrKNearestTextProvider, knn_provider: OcrKNearestTextProvider,
phash_db: ImagePhashDatabase, image_id_provider: ImageIdProvider,
): ):
self.extractor = extractor self.extractor = extractor
self.masker = masker self.masker = masker
self.knn_provider = knn_provider self.knn_provider = knn_provider
self.phash_db = phash_db self.image_id_provider = image_id_provider
def pfl(self, roi_gray: Mat, factor: float = 1.25): def pfl(self, roi_gray: Mat, factor: float = 1.25):
def contour_filter(cnt): def contour_filter(cnt):
@ -93,14 +98,12 @@ class DeviceOcr:
] ]
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
def lookup_song_id(self): def song_id_results(self):
return self.phash_db.lookup_jacket( return self.image_id_provider.results(
cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY) cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY),
ImageCategory.JACKET,
) )
def song_id(self):
return self.lookup_song_id()[0]
@staticmethod @staticmethod
def preprocess_char_icon(img_gray: Mat): def preprocess_char_icon(img_gray: Mat):
h, w = img_gray.shape[:2] h, w = img_gray.shape[:2]
@ -114,21 +117,19 @@ class DeviceOcr:
np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32), np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32),
np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32), np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32),
], ],
(128), (128,),
) )
return img return img
def lookup_partner_id(self): def partner_id_results(self):
return self.phash_db.lookup_partner_icon( return self.image_id_provider.results(
self.preprocess_char_icon( self.preprocess_char_icon(
cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY) cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY)
) ),
ImageCategory.PARTNER_ICON,
) )
def partner_id(self): def result(self):
return self.lookup_partner_id()[0]
def ocr(self) -> DeviceOcrResult:
rating_class = self.rating_class() rating_class = self.rating_class()
pure = self.pure() pure = self.pure()
far = self.far() far = self.far()
@ -137,20 +138,18 @@ class DeviceOcr:
max_recall = self.max_recall() max_recall = self.max_recall()
clear_status = self.clear_status() clear_status = self.clear_status()
hash_len = self.phash_db.hash_size**2 song_id_results = self.song_id_results()
song_id, song_id_distance = self.lookup_song_id() partner_id_results = self.partner_id_results()
partner_id, partner_id_distance = self.lookup_partner_id()
return DeviceOcrResult( return OcrScenarioResult(
song_id=song_id_results[0].image_id,
song_id_results=song_id_results,
rating_class=rating_class, rating_class=rating_class,
pure=pure, pure=pure,
far=far, far=far,
lost=lost, lost=lost,
score=score, score=score,
max_recall=max_recall, max_recall=max_recall,
song_id=song_id, partner_id_results=partner_id_results,
song_id_possibility=1 - song_id_distance / hash_len,
clear_status=clear_status, clear_status=clear_status,
partner_id=partner_id,
partner_id_possibility=1 - partner_id_distance / hash_len,
) )

View File

@ -0,0 +1,9 @@
from .auto import DeviceRoisMaskerAuto, DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .base import DeviceRoisMasker
__all__ = [
"DeviceRoisMaskerAuto",
"DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2",
"DeviceRoisMasker",
]

View File

@ -1,13 +1,12 @@
import cv2 import cv2
import numpy as np import numpy as np
from ....types import Mat from arcaea_offline_ocr.types import Mat
from .common import DeviceRoisMasker
from .base import DeviceRoisMasker
class DeviceRoisMaskerAuto(DeviceRoisMasker): class DeviceRoisMaskerAuto(DeviceRoisMasker):
# pylint: disable=abstract-method
@staticmethod @staticmethod
def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat): def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat):
return cv2.inRange( return cv2.inRange(

View File

@ -0,0 +1,61 @@
from abc import ABC, abstractmethod
from arcaea_offline_ocr.types import Mat
class DeviceRoisMasker(ABC):
@classmethod
@abstractmethod
def pure(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def far(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def lost(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def score(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_pst(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_prs(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_byd(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def rating_class_etr(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def max_recall(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: ...
@classmethod
@abstractmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: ...

View File

@ -0,0 +1,9 @@
from .auto import DeviceRoisAuto, DeviceRoisAutoT1, DeviceRoisAutoT2
from .base import DeviceRois
__all__ = [
"DeviceRois",
"DeviceRoisAuto",
"DeviceRoisAutoT1",
"DeviceRoisAutoT2",
]

View File

@ -1,6 +1,6 @@
from .common import DeviceRois from arcaea_offline_ocr.types import XYWHRect
__all__ = ["DeviceRoisAuto", "DeviceRoisAutoT1", "DeviceRoisAutoT2"] from .base import DeviceRois
class DeviceRoisAuto(DeviceRois): class DeviceRoisAuto(DeviceRois):
@ -50,7 +50,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def pure(self): def pure(self):
return ( return XYWHRect(
self.pfl_x, self.pfl_x,
self.layout_area_h_mid + 110 * self.factor, self.layout_area_h_mid + 110 * self.factor,
self.pfl_w, self.pfl_w,
@ -59,7 +59,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def far(self): def far(self):
return ( return XYWHRect(
self.pfl_x, self.pfl_x,
self.pure[1] + self.pure[3] + 12 * self.factor, self.pure[1] + self.pure[3] + 12 * self.factor,
self.pfl_w, self.pfl_w,
@ -68,7 +68,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def lost(self): def lost(self):
return ( return XYWHRect(
self.pfl_x, self.pfl_x,
self.far[1] + self.far[3] + 10 * self.factor, self.far[1] + self.far[3] + 10 * self.factor,
self.pfl_w, self.pfl_w,
@ -79,7 +79,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def score(self): def score(self):
w = 280 * self.factor w = 280 * self.factor
h = 45 * self.factor h = 45 * self.factor
return ( return XYWHRect(
self.w_mid - w / 2, self.w_mid - w / 2,
self.layout_area_h_mid - 75 * self.factor - h, self.layout_area_h_mid - 75 * self.factor - h,
w, w,
@ -88,7 +88,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def rating_class(self): def rating_class(self):
return ( return XYWHRect(
self.w_mid - 610 * self.factor, self.w_mid - 610 * self.factor,
self.layout_area_h_mid - 180 * self.factor, self.layout_area_h_mid - 180 * self.factor,
265 * self.factor, 265 * self.factor,
@ -97,7 +97,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def max_recall(self): def max_recall(self):
return ( return XYWHRect(
self.w_mid - 465 * self.factor, self.w_mid - 465 * self.factor,
self.layout_area_h_mid - 215 * self.factor, self.layout_area_h_mid - 215 * self.factor,
150 * self.factor, 150 * self.factor,
@ -106,7 +106,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def jacket(self): def jacket(self):
return ( return XYWHRect(
self.w_mid - 610 * self.factor, self.w_mid - 610 * self.factor,
self.layout_area_h_mid - 143 * self.factor, self.layout_area_h_mid - 143 * self.factor,
375 * self.factor, 375 * self.factor,
@ -117,7 +117,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def clear_status(self): def clear_status(self):
w = 550 * self.factor w = 550 * self.factor
h = 60 * self.factor h = 60 * self.factor
return ( return XYWHRect(
self.w_mid - w / 2, self.w_mid - w / 2,
self.layout_area_h_mid - 155 * self.factor - h, self.layout_area_h_mid - 155 * self.factor - h,
w * 0.4, w * 0.4,
@ -128,7 +128,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def partner_icon(self): def partner_icon(self):
w = 90 * self.factor w = 90 * self.factor
h = 75 * self.factor h = 75 * self.factor
return (self.w_mid - w / 2, 0, w, h) return XYWHRect(self.w_mid - w / 2, 0, w, h)
class DeviceRoisAutoT2(DeviceRoisAuto): class DeviceRoisAutoT2(DeviceRoisAuto):
@ -174,7 +174,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def pure(self): def pure(self):
return ( return XYWHRect(
self.pfl_x, self.pfl_x,
self.layout_area_h_mid + 175 * self.factor, self.layout_area_h_mid + 175 * self.factor,
self.pfl_w, self.pfl_w,
@ -183,7 +183,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def far(self): def far(self):
return ( return XYWHRect(
self.pfl_x, self.pfl_x,
self.pure[1] + self.pure[3] + 30 * self.factor, self.pure[1] + self.pure[3] + 30 * self.factor,
self.pfl_w, self.pfl_w,
@ -192,7 +192,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def lost(self): def lost(self):
return ( return XYWHRect(
self.pfl_x, self.pfl_x,
self.far[1] + self.far[3] + 35 * self.factor, self.far[1] + self.far[3] + 35 * self.factor,
self.pfl_w, self.pfl_w,
@ -203,7 +203,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def score(self): def score(self):
w = 420 * self.factor w = 420 * self.factor
h = 70 * self.factor h = 70 * self.factor
return ( return XYWHRect(
self.w_mid - w / 2, self.w_mid - w / 2,
self.layout_area_h_mid - 110 * self.factor - h, self.layout_area_h_mid - 110 * self.factor - h,
w, w,
@ -212,7 +212,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def rating_class(self): def rating_class(self):
return ( return XYWHRect(
max(0, self.w_mid - 965 * self.factor), max(0, self.w_mid - 965 * self.factor),
self.layout_area_h_mid - 330 * self.factor, self.layout_area_h_mid - 330 * self.factor,
350 * self.factor, 350 * self.factor,
@ -221,7 +221,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def max_recall(self): def max_recall(self):
return ( return XYWHRect(
self.w_mid - 625 * self.factor, self.w_mid - 625 * self.factor,
self.layout_area_h_mid - 275 * self.factor, self.layout_area_h_mid - 275 * self.factor,
150 * self.factor, 150 * self.factor,
@ -230,7 +230,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def jacket(self): def jacket(self):
return ( return XYWHRect(
self.w_mid - 915 * self.factor, self.w_mid - 915 * self.factor,
self.layout_area_h_mid - 215 * self.factor, self.layout_area_h_mid - 215 * self.factor,
565 * self.factor, 565 * self.factor,
@ -241,7 +241,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def clear_status(self): def clear_status(self):
w = 825 * self.factor w = 825 * self.factor
h = 90 * self.factor h = 90 * self.factor
return ( return XYWHRect(
self.w_mid - w / 2, self.w_mid - w / 2,
self.layout_area_h_mid - 235 * self.factor - h, self.layout_area_h_mid - 235 * self.factor - h,
w * 0.4, w * 0.4,
@ -252,4 +252,4 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def partner_icon(self): def partner_icon(self):
w = 135 * self.factor w = 135 * self.factor
h = 110 * self.factor h = 110 * self.factor
return (self.w_mid - w / 2, 0, w, h) return XYWHRect(self.w_mid - w / 2, 0, w, h)

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from arcaea_offline_ocr.types import XYWHRect
class DeviceRois(ABC):
@property
@abstractmethod
def pure(self) -> XYWHRect: ...
@property
@abstractmethod
def far(self) -> XYWHRect: ...
@property
@abstractmethod
def lost(self) -> XYWHRect: ...
@property
@abstractmethod
def score(self) -> XYWHRect: ...
@property
@abstractmethod
def rating_class(self) -> XYWHRect: ...
@property
@abstractmethod
def max_recall(self) -> XYWHRect: ...
@property
@abstractmethod
def jacket(self) -> XYWHRect: ...
@property
@abstractmethod
def clear_status(self) -> XYWHRect: ...
@property
@abstractmethod
def partner_icon(self) -> XYWHRect: ...