diff --git a/README.md b/README.md index 59b9201..521985f 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Arcaea Offline OCR +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) + ## Example ```py diff --git a/pyproject.toml b/pyproject.toml index 96ac8e3..1d43f40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,26 +16,18 @@ classifiers = [ "Programming Language :: Python :: 3", ] +[project.optional-dependencies] +dev = ["ruff", "pre-commit"] + [project.urls] "Homepage" = "https://github.com/ArcaeaOffline/core-ocr" "Bug Tracker" = "https://github.com/ArcaeaOffline/core-ocr/issues" [tool.setuptools_scm] -[tool.isort] -profile = "black" -src_paths = ["src/arcaea_offline_ocr"] - [tool.pyright] ignore = ["**/__debug*.*"] -[tool.pylint.main] -# extension-pkg-allow-list = ["cv2"] -generated-members = ["cv2.*"] - -[tool.pylint.logging] -disable = [ - "missing-module-docstring", - "missing-class-docstring", - "missing-function-docstring", -] +[tool.ruff.lint] +select = ["ALL"] +ignore = ["ANN", "D", "ERA", "PLR"] diff --git a/requirements.dev.txt b/requirements.dev.txt index 4fb53e7..c6ef9af 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,3 +1,2 @@ -black==23.7.0 -isort==5.12.0 -pre-commit==3.3.3 +ruff +pre-commit diff --git a/src/arcaea_offline_ocr/builders/ihdb.py b/src/arcaea_offline_ocr/builders/ihdb.py index 6088684..e4ed2be 100644 --- a/src/arcaea_offline_ocr/builders/ihdb.py +++ b/src/arcaea_offline_ocr/builders/ihdb.py @@ -1,11 +1,12 @@ +from __future__ import annotations + from dataclasses import dataclass from datetime import datetime, timezone -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING, Callable import cv2 from arcaea_offline_ocr.core import hashers -from arcaea_offline_ocr.providers import ImageCategory from arcaea_offline_ocr.providers.ihdb import ( PROP_KEY_BUILT_AT, PROP_KEY_HASH_SIZE, @@ -17,6 +18,7 @@ from arcaea_offline_ocr.providers.ihdb import ( if TYPE_CHECKING: from sqlite3 import Connection + from arcaea_offline_ocr.providers import ImageCategory from arcaea_offline_ocr.types import Mat @@ -29,7 +31,7 @@ class ImageHashDatabaseBuildTask: image_path: str image_id: str category: ImageCategory - imread_function: Callable[[str], "Mat"] = _default_imread_gray + imread_function: Callable[[str], Mat] = _default_imread_gray @dataclass @@ -42,7 +44,7 @@ class _ImageHash: class ImageHashesDatabaseBuilder: @staticmethod - def __insert_property(conn: "Connection", key: str, value: str): + def __insert_property(conn: Connection, key: str, value: str): return conn.execute( "INSERT INTO properties (key, value) VALUES (?, ?)", (key, value), @@ -51,13 +53,13 @@ class ImageHashesDatabaseBuilder: @classmethod def build( cls, - conn: "Connection", - tasks: List[ImageHashDatabaseBuildTask], + conn: Connection, + tasks: list[ImageHashDatabaseBuildTask], *, hash_size: int = 16, high_freq_factor: int = 4, ): - hashes: List[_ImageHash] = [] + hashes: list[_ImageHash] = [] for task in tasks: img_gray = task.imread_function(task.image_path) @@ -82,7 +84,7 @@ class ImageHashesDatabaseBuilder: image_hash_type=hash_type, category=task.category, hash=ImageHashDatabaseIdProvider.hash_mat_to_bytes(hash_mat), - ) + ), ) conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)") @@ -92,7 +94,7 @@ class ImageHashesDatabaseBuilder: `category` INTEGER, `hash_type` INTEGER, `hash` BLOB -)""" +)""", ) now = datetime.now(tz=timezone.utc) @@ -103,7 +105,8 @@ class ImageHashesDatabaseBuilder: cls.__insert_property(conn, PROP_KEY_BUILT_AT, str(timestamp)) conn.executemany( - "INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`) VALUES (?, ?, ?, ?)", + """INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`) + VALUES (?, ?, ?, ?)""", [ (it.image_id, it.category.value, it.image_hash_type.value, it.hash) for it in hashes diff --git a/src/arcaea_offline_ocr/core/hashers/index.py b/src/arcaea_offline_ocr/core/hashers/index.py index 1d8c3fd..b100408 100644 --- a/src/arcaea_offline_ocr/core/hashers/index.py +++ b/src/arcaea_offline_ocr/core/hashers/index.py @@ -23,7 +23,7 @@ def difference(img_gray: Mat, hash_size: int) -> Mat: def dct(img_gray: Mat, hash_size: int = 16, high_freq_factor: int = 4) -> Mat: - # TODO: consistency? + # TODO: consistency? # noqa: FIX002, TD002, TD003 img_size_base = hash_size * high_freq_factor img_size = (img_size_base, img_size_base) diff --git a/src/arcaea_offline_ocr/crop.py b/src/arcaea_offline_ocr/crop.py index 12c531d..22a6e84 100644 --- a/src/arcaea_offline_ocr/crop.py +++ b/src/arcaea_offline_ocr/crop.py @@ -1,29 +1,32 @@ +from __future__ import annotations + import math -from typing import Tuple +from typing import TYPE_CHECKING import cv2 import numpy as np -from .types import Mat +if TYPE_CHECKING: + from .types import Mat -__all__ = ["crop_xywh", "CropBlackEdges"] +__all__ = ["CropBlackEdges", "crop_xywh"] -def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]): +def crop_xywh(mat: Mat, rect: tuple[int, int, int, int]): x, y, w, h = rect return mat[y : y + h, x : x + w] class CropBlackEdges: @staticmethod - def is_black_edge(__img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6): - pixels_compared = __img_gray_slice < black_pixel + def is_black_edge(img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6): + pixels_compared = img_gray_slice < black_pixel return np.count_nonzero(pixels_compared) > math.floor( - __img_gray_slice.size * ratio + img_gray_slice.size * ratio, ) @classmethod - def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25): + def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25): # noqa: C901 height, width = img_gray.shape[:2] left = 0 right = width @@ -54,13 +57,22 @@ class CropBlackEdges: break bottom -= 1 - assert right > left, "cropped width < 0" - assert bottom > top, "cropped height < 0" + if right <= left: + msg = "cropped width < 0" + raise ValueError(msg) + + if bottom <= top: + msg = "cropped height < 0" + raise ValueError(msg) + return (left, top, right - left, bottom - top) @classmethod def crop( - cls, img: Mat, convert_flag: cv2.COLOR_BGR2GRAY, black_threshold: int = 25 + cls, + img: Mat, + convert_flag: cv2.COLOR_BGR2GRAY, + black_threshold: int = 25, ) -> Mat: rect = cls.get_crop_rect(cv2.cvtColor(img, convert_flag), black_threshold) return crop_xywh(img, rect) diff --git a/src/arcaea_offline_ocr/providers/__init__.py b/src/arcaea_offline_ocr/providers/__init__.py index baa233b..3330663 100644 --- a/src/arcaea_offline_ocr/providers/__init__.py +++ b/src/arcaea_offline_ocr/providers/__init__.py @@ -5,8 +5,8 @@ from .knn import OcrKNearestTextProvider __all__ = [ "ImageCategory", "ImageHashDatabaseIdProvider", - "OcrKNearestTextProvider", "ImageIdProvider", - "OcrTextProvider", "ImageIdProviderResult", + "OcrKNearestTextProvider", + "OcrTextProvider", ] diff --git a/src/arcaea_offline_ocr/providers/base.py b/src/arcaea_offline_ocr/providers/base.py index b98c058..8a8e44b 100644 --- a/src/arcaea_offline_ocr/providers/base.py +++ b/src/arcaea_offline_ocr/providers/base.py @@ -1,17 +1,19 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from dataclasses import dataclass from enum import IntEnum -from typing import TYPE_CHECKING, Any, Sequence, Optional +from typing import TYPE_CHECKING, Any, Sequence if TYPE_CHECKING: - from ..types import Mat + from arcaea_offline_ocr.types import Mat class OcrTextProvider(ABC): @abstractmethod - def result_raw(self, img: "Mat", /, *args, **kwargs) -> Any: ... + def result_raw(self, img: Mat, /, *args, **kwargs) -> Any: ... @abstractmethod - def result(self, img: "Mat", /, *args, **kwargs) -> Optional[str]: ... + def result(self, img: Mat, /, *args, **kwargs) -> str | None: ... class ImageCategory(IntEnum): @@ -29,10 +31,20 @@ class ImageIdProviderResult: class ImageIdProvider(ABC): @abstractmethod def result( - self, img: "Mat", category: ImageCategory, /, *args, **kwargs + self, + img: Mat, + category: ImageCategory, + /, + *args, + **kwargs, ) -> ImageIdProviderResult: ... @abstractmethod def results( - self, img: "Mat", category: ImageCategory, /, *args, **kwargs + self, + img: Mat, + category: ImageCategory, + /, + *args, + **kwargs, ) -> Sequence[ImageIdProviderResult]: ... diff --git a/src/arcaea_offline_ocr/providers/ihdb.py b/src/arcaea_offline_ocr/providers/ihdb.py index 0539264..5ff534a 100644 --- a/src/arcaea_offline_ocr/providers/ihdb.py +++ b/src/arcaea_offline_ocr/providers/ihdb.py @@ -1,14 +1,17 @@ -import sqlite3 +from __future__ import annotations + from dataclasses import dataclass from datetime import datetime, timezone from enum import IntEnum -from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, TypeVar from arcaea_offline_ocr.core import hashers from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult if TYPE_CHECKING: + import sqlite3 + from arcaea_offline_ocr.types import Mat @@ -19,9 +22,11 @@ PROP_KEY_BUILT_AT = "built_at" def _sql_hamming_distance(hash1: bytes, hash2: bytes): - assert len(hash1) == len(hash2), "hash size does not match!" - count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2) - return count + if len(hash1) != len(hash2): + msg = "hash size does not match!" + raise ValueError(msg) + + return sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2) class ImageHashType(IntEnum): @@ -36,7 +41,7 @@ class ImageHashDatabaseIdProviderResult(ImageIdProviderResult): class MissingPropertiesError(Exception): - keys: List[str] + keys: list[str] def __init__(self, keys, *args): super().__init__(*args) @@ -72,7 +77,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider): return self.properties[PROP_KEY_HIGH_FREQ_FACTOR] @property - def built_at(self) -> Optional[datetime]: + def built_at(self) -> datetime | None: return self.properties.get(PROP_KEY_BUILT_AT) @property @@ -80,7 +85,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider): return self._hash_length def _initialize(self): - def get_property(key, converter: Callable[[Any], T]) -> Optional[T]: + def get_property(key, converter: Callable[[Any], T]) -> T | None: result = self.conn.execute( "SELECT value FROM properties WHERE key = ?", (key,), @@ -97,7 +102,8 @@ class ImageHashDatabaseIdProvider(ImageIdProvider): PROP_KEY_HASH_SIZE: lambda x: int(x), PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x), PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp( - int(ts) / 1000, tz=timezone.utc + int(ts) / 1000, + tz=timezone.utc, ), } required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR] @@ -122,8 +128,11 @@ class ImageHashDatabaseIdProvider(ImageIdProvider): self._hash_length = self.hash_size**2 def lookup_hash( - self, category: ImageCategory, hash_type: ImageHashType, hash: bytes - ) -> List[ImageHashDatabaseIdProviderResult]: + self, + category: ImageCategory, + hash_type: ImageHashType, + hash_data: bytes, + ) -> list[ImageHashDatabaseIdProviderResult]: cursor = self.conn.execute( """ SELECT @@ -132,7 +141,7 @@ SELECT FROM hashes WHERE category = ? AND hash_type = ? ORDER BY distance ASC LIMIT 10""", - (hash, category.value, hash_type.value), + (hash_data, category.value, hash_type.value), ) results = [] @@ -143,52 +152,52 @@ ORDER BY distance ASC LIMIT 10""", category=category, confidence=(self.hash_length - distance) / self.hash_length, image_hash_type=hash_type, - ) + ), ) return results @staticmethod - def hash_mat_to_bytes(hash: "Mat") -> bytes: - return bytes([255 if b else 0 for b in hash.flatten()]) + def hash_mat_to_bytes(hash_mat: Mat) -> bytes: + return bytes([255 if b else 0 for b in hash_mat.flatten()]) - def results(self, img: "Mat", category: ImageCategory, /): - results: List[ImageHashDatabaseIdProviderResult] = [] + def results(self, img: Mat, category: ImageCategory, /): + results: list[ImageHashDatabaseIdProviderResult] = [] results.extend( self.lookup_hash( category, ImageHashType.AVERAGE, self.hash_mat_to_bytes(hashers.average(img, self.hash_size)), - ) + ), ) results.extend( self.lookup_hash( category, ImageHashType.DIFFERENCE, self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)), - ) + ), ) results.extend( self.lookup_hash( category, ImageHashType.DCT, self.hash_mat_to_bytes( - hashers.dct(img, self.hash_size, self.high_freq_factor) + hashers.dct(img, self.hash_size, self.high_freq_factor), ), - ) + ), ) return results def result( self, - img: "Mat", + img: Mat, category: ImageCategory, /, *, hash_type: ImageHashType = ImageHashType.DCT, ): - return [ + return next( it for it in self.results(img, category) if it.image_hash_type == hash_type - ][0] + ) diff --git a/src/arcaea_offline_ocr/providers/knn.py b/src/arcaea_offline_ocr/providers/knn.py index 3e8473f..1af0435 100644 --- a/src/arcaea_offline_ocr/providers/knn.py +++ b/src/arcaea_offline_ocr/providers/knn.py @@ -1,17 +1,20 @@ +from __future__ import annotations + import logging import math -from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Callable, Sequence import cv2 import numpy as np -from ..crop import crop_xywh +from arcaea_offline_ocr.crop import crop_xywh + from .base import OcrTextProvider if TYPE_CHECKING: from cv2.ml import KNearest - from ..types import Mat + from arcaea_offline_ocr.types import Mat logger = logging.getLogger(__name__) @@ -19,10 +22,10 @@ logger = logging.getLogger(__name__) class FixRects: @staticmethod def connect_broken( - rects: Sequence[Tuple[int, int, int, int]], + rects: Sequence[tuple[int, int, int, int]], img_width: int, img_height: int, - tolerance: Optional[int] = None, + tolerance: int | None = None, ): # for a "broken" digit, please refer to # /assets/fix_rects/broken_masked.jpg @@ -69,8 +72,8 @@ class FixRects: @staticmethod def split_connected( - img_masked: "Mat", - rects: Sequence[Tuple[int, int, int, int]], + img_masked: Mat, + rects: Sequence[tuple[int, int, int, int]], rect_wh_ratio: float = 1.05, width_range_ratio: float = 0.1, ): @@ -111,7 +114,7 @@ class FixRects: # split the rect new_rects.extend( - [(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)] + [(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)], ) return_rects = [r for r in rects if r not in connected_rects] @@ -119,7 +122,7 @@ class FixRects: return return_rects -def resize_fill_square(img: "Mat", target: int = 20): +def resize_fill_square(img: Mat, target: int = 20): h, w = img.shape[:2] if h > w: new_h = target @@ -132,11 +135,21 @@ def resize_fill_square(img: "Mat", target: int = 20): border_size = math.ceil((max(new_w, new_h) - min(new_w, new_h)) / 2) if new_w < new_h: resized = cv2.copyMakeBorder( - resized, 0, 0, border_size, border_size, cv2.BORDER_CONSTANT + resized, + 0, + 0, + border_size, + border_size, + cv2.BORDER_CONSTANT, ) else: resized = cv2.copyMakeBorder( - resized, border_size, border_size, 0, 0, cv2.BORDER_CONSTANT + resized, + border_size, + border_size, + 0, + 0, + cv2.BORDER_CONSTANT, ) return cv2.resize(resized, (target, target)) @@ -151,8 +164,8 @@ def preprocess_hog(digit_rois): return np.float32(samples) -def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4): - _, results, _, _ = knn_model.findNearest(__samples, k) +def ocr_digit_samples_knn(samples, knn_model: cv2.ml.KNearest, k: int = 4): + _, results, _, _ = knn_model.findNearest(samples, k) return [int(r) for r in results.ravel()] @@ -160,11 +173,15 @@ class OcrKNearestTextProvider(OcrTextProvider): _ContourFilter = Callable[["Mat"], bool] _RectsFilter = Callable[[Sequence[int]], bool] - def __init__(self, model: "KNearest"): + def __init__(self, model: KNearest): self.model = model def contours( - self, img: "Mat", /, *, contours_filter: Optional[_ContourFilter] = None + self, + img: Mat, + /, + *, + contours_filter: _ContourFilter | None = None, ): cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) if contours_filter: @@ -174,12 +191,12 @@ class OcrKNearestTextProvider(OcrTextProvider): def result_raw( self, - img: "Mat", + img: Mat, /, *, fix_rects: bool = True, - contours_filter: Optional[_ContourFilter] = None, - rects_filter: Optional[_RectsFilter] = None, + contours_filter: _ContourFilter | None = None, + rects_filter: _RectsFilter | None = None, ): """ :param img: grayscaled roi @@ -192,11 +209,11 @@ class OcrKNearestTextProvider(OcrTextProvider): rects = [cv2.boundingRect(cnt) for cnt in cnts] if fix_rects and rects_filter: - rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0]) # type: ignore + rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0]) rects = list(filter(rects_filter, rects)) rects = FixRects.split_connected(img, rects) elif fix_rects: - rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0]) # type: ignore + rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0]) rects = FixRects.split_connected(img, rects) elif rects_filter: rects = list(filter(rects_filter, rects)) @@ -216,12 +233,12 @@ class OcrKNearestTextProvider(OcrTextProvider): def result( self, - img: "Mat", + img: Mat, /, *, fix_rects: bool = True, - contours_filter: Optional[_ContourFilter] = None, - rects_filter: Optional[_RectsFilter] = None, + contours_filter: _ContourFilter | None = None, + rects_filter: _RectsFilter | None = None, ): """ :param img: grayscaled roi diff --git a/src/arcaea_offline_ocr/scenarios/__init__.py b/src/arcaea_offline_ocr/scenarios/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/arcaea_offline_ocr/scenarios/b30/base.py b/src/arcaea_offline_ocr/scenarios/b30/base.py index 7d3492a..fe8bb23 100644 --- a/src/arcaea_offline_ocr/scenarios/b30/base.py +++ b/src/arcaea_offline_ocr/scenarios/b30/base.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from abc import abstractmethod -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING from arcaea_offline_ocr.scenarios.base import OcrScenario, OcrScenarioResult @@ -9,13 +11,13 @@ if TYPE_CHECKING: class Best30Scenario(OcrScenario): @abstractmethod - def components(self, img: "Mat", /) -> List["Mat"]: ... + def components(self, img: Mat, /) -> list[Mat]: ... @abstractmethod - def result(self, component_img: "Mat", /, *args, **kwargs) -> OcrScenarioResult: ... + def result(self, component_img: Mat, /, *args, **kwargs) -> OcrScenarioResult: ... @abstractmethod - def results(self, img: "Mat", /, *args, **kwargs) -> List[OcrScenarioResult]: + def results(self, img: Mat, /, *args, **kwargs) -> list[OcrScenarioResult]: """ Commonly a shorthand for `[self.result(comp) for comp in self.components(img)]` """ diff --git a/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/colors.py b/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/colors.py index 98d32f1..17d2a15 100644 --- a/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/colors.py +++ b/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/colors.py @@ -1,19 +1,19 @@ import numpy as np __all__ = [ - "FONT_THRESHOLD", - "PURE_BG_MIN_HSV", - "PURE_BG_MAX_HSV", - "FAR_BG_MIN_HSV", - "FAR_BG_MAX_HSV", - "LOST_BG_MIN_HSV", - "LOST_BG_MAX_HSV", - "BYD_MIN_HSV", "BYD_MAX_HSV", - "FTR_MIN_HSV", + "BYD_MIN_HSV", + "FAR_BG_MAX_HSV", + "FAR_BG_MIN_HSV", + "FONT_THRESHOLD", "FTR_MAX_HSV", - "PRS_MIN_HSV", + "FTR_MIN_HSV", + "LOST_BG_MAX_HSV", + "LOST_BG_MIN_HSV", "PRS_MAX_HSV", + "PRS_MIN_HSV", + "PURE_BG_MAX_HSV", + "PURE_BG_MIN_HSV", ] FONT_THRESHOLD = 160 diff --git a/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/impl.py b/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/impl.py index 53a23f5..e254411 100644 --- a/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/impl.py +++ b/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/impl.py @@ -1,4 +1,6 @@ -from typing import List, Optional, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING import cv2 import numpy as np @@ -11,7 +13,9 @@ from arcaea_offline_ocr.providers import ( ) from arcaea_offline_ocr.scenarios.b30.base import Best30Scenario from arcaea_offline_ocr.scenarios.base import OcrScenarioResult -from arcaea_offline_ocr.types import Mat + +if TYPE_CHECKING: + from arcaea_offline_ocr.types import Mat from .colors import ( BYD_MAX_HSV, @@ -71,13 +75,13 @@ class ChieriBotV4Best30Scenario(Best30Scenario): rating_class_results = [np.count_nonzero(m) for m in rating_class_masks] if max(rating_class_results) < 70: return 0 - else: - return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 + return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 def ocr_component_song_id_results(self, component_bgr: Mat): jacket_rect = self.rois.component_rois.jacket_rect.floored() jacket_roi = cv2.cvtColor( - crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY + crop_xywh(component_bgr, jacket_rect), + cv2.COLOR_BGR2GRAY, ) return self.image_id_provider.results(jacket_roi, ImageCategory.JACKET) @@ -85,16 +89,22 @@ class ChieriBotV4Best30Scenario(Best30Scenario): # sourcery skip: inline-immediately-returned-variable score_rect = self.rois.component_rois.score_rect.rounded() score_roi = cv2.cvtColor( - crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY + crop_xywh(component_bgr, score_rect), + cv2.COLOR_BGR2GRAY, ) _, score_roi = cv2.threshold( - score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU + score_roi, + 0, + 255, + cv2.THRESH_BINARY + cv2.THRESH_OTSU, ) if score_roi[1][1] == 255: score_roi = 255 - score_roi contours, _ = cv2.findContours( - score_roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + score_roi, + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE, ) for contour in contours: rect = cv2.boundingRect(contour) @@ -106,8 +116,9 @@ class ChieriBotV4Best30Scenario(Best30Scenario): return int(ocr_result) if ocr_result else 0 def find_pfl_rects( - self, component_pfl_processed: Mat - ) -> List[Tuple[int, int, int, int]]: + self, + component_pfl_processed: Mat, + ) -> list[tuple[int, int, int, int]]: # sourcery skip: inline-immediately-returned-variable pfl_roi_find = cv2.morphologyEx( component_pfl_processed, @@ -115,14 +126,16 @@ class ChieriBotV4Best30Scenario(Best30Scenario): cv2.getStructuringElement(cv2.MORPH_RECT, [10, 1]), ) pfl_contours, _ = cv2.findContours( - pfl_roi_find, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + pfl_roi_find, + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE, ) pfl_rects = [cv2.boundingRect(c) for c in pfl_contours] pfl_rects = [ r for r in pfl_rects if r[3] > component_pfl_processed.shape[0] * 0.1 ] pfl_rects = sorted(pfl_rects, key=lambda r: r[1]) - pfl_rects_adjusted = [ + return [ ( max(rect[0] - 2, 0), rect[1], @@ -131,7 +144,6 @@ class ChieriBotV4Best30Scenario(Best30Scenario): ) for rect in pfl_rects ] - return pfl_rects_adjusted def preprocess_component_pfl(self, component_bgr: Mat) -> Mat: pfl_rect = self.rois.component_rois.pfl_rect.rounded() @@ -154,11 +166,17 @@ class ChieriBotV4Best30Scenario(Best30Scenario): pfl_roi_blurred = cv2.GaussianBlur(pfl_roi, (5, 5), 0) # pfl_roi_blurred = cv2.medianBlur(pfl_roi, 3) _, pfl_roi_blurred_threshold = cv2.threshold( - pfl_roi_blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU + pfl_roi_blurred, + 0, + 255, + cv2.THRESH_BINARY + cv2.THRESH_OTSU, ) # and a threshold of the original roi _, pfl_roi_threshold = cv2.threshold( - pfl_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU + pfl_roi, + 0, + 255, + cv2.THRESH_BINARY + cv2.THRESH_OTSU, ) # turn thresholds into black background if pfl_roi_blurred_threshold[2][2] == 255: @@ -168,13 +186,15 @@ class ChieriBotV4Best30Scenario(Best30Scenario): # return a bitwise_and result result = cv2.bitwise_and(pfl_roi_blurred_threshold, pfl_roi_threshold) result_eroded = cv2.erode( - result, cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2)) + result, + cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2)), ) return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result def ocr_component_pfl( - self, component_bgr: Mat - ) -> Tuple[Optional[int], Optional[int], Optional[int]]: + self, + component_bgr: Mat, + ) -> tuple[int | None, int | None, int | None]: try: pfl_roi = self.preprocess_component_pfl(component_bgr) pfl_rects = self.find_pfl_rects(pfl_roi) @@ -185,7 +205,7 @@ class ChieriBotV4Best30Scenario(Best30Scenario): pure_far_lost.append(int(result) if result else None) return tuple(pure_far_lost) - except Exception: + except Exception: # noqa: BLE001 return (None, None, None) def ocr_component(self, component_bgr: Mat) -> OcrScenarioResult: @@ -216,7 +236,7 @@ class ChieriBotV4Best30Scenario(Best30Scenario): def result(self, component_img: Mat, /): return self.ocr_component(component_img) - def results(self, img: Mat, /) -> List[OcrScenarioResult]: + def results(self, img: Mat, /) -> list[OcrScenarioResult]: """ :param img: BGR format image """ diff --git a/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/rois.py b/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/rois.py index 239ec8a..4ea6584 100644 --- a/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/rois.py +++ b/src/arcaea_offline_ocr/scenarios/b30/chieri/v4/rois.py @@ -1,4 +1,4 @@ -from typing import List +from __future__ import annotations from arcaea_offline_ocr.crop import crop_xywh from arcaea_offline_ocr.types import Mat, XYWHRect @@ -105,7 +105,7 @@ class ChieriBotV4Rois: def b33_vertical_gap(self): return 121 * self.factor - def components(self, img_bgr: Mat) -> List[Mat]: + def components(self, img_bgr: Mat) -> list[Mat]: first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height) results = [] diff --git a/src/arcaea_offline_ocr/scenarios/base.py b/src/arcaea_offline_ocr/scenarios/base.py index 35e13c9..12244f6 100644 --- a/src/arcaea_offline_ocr/scenarios/base.py +++ b/src/arcaea_offline_ocr/scenarios/base.py @@ -1,9 +1,13 @@ +from __future__ import annotations + from abc import ABC from dataclasses import dataclass, field -from datetime import datetime -from typing import Sequence, Optional +from typing import TYPE_CHECKING, Sequence -from arcaea_offline_ocr.providers import ImageIdProviderResult +if TYPE_CHECKING: + from datetime import datetime + + from arcaea_offline_ocr.providers import ImageIdProviderResult @dataclass(kw_only=True) @@ -12,27 +16,27 @@ class OcrScenarioResult: rating_class: int score: int - song_id_results: Sequence[ImageIdProviderResult] = field(default_factory=lambda: []) + song_id_results: Sequence[ImageIdProviderResult] = field(default_factory=list) partner_id_results: Sequence[ImageIdProviderResult] = field( - default_factory=lambda: [] + default_factory=list, ) - pure: Optional[int] = None - pure_inaccurate: Optional[int] = None - pure_early: Optional[int] = None - pure_late: Optional[int] = None - far: Optional[int] = None - far_inaccurate: Optional[int] = None - far_early: Optional[int] = None - far_late: Optional[int] = None - lost: Optional[int] = None + pure: int | None = None + pure_inaccurate: int | None = None + pure_early: int | None = None + pure_late: int | None = None + far: int | None = None + far_inaccurate: int | None = None + far_early: int | None = None + far_late: int | None = None + lost: int | None = None - played_at: Optional[datetime] = None - max_recall: Optional[int] = None - clear_status: Optional[int] = None - clear_type: Optional[int] = None - modifier: Optional[int] = None + played_at: datetime | None = None + max_recall: int | None = None + clear_status: int | None = None + clear_type: int | None = None + modifier: int | None = None -class OcrScenario(ABC): +class OcrScenario(ABC): # noqa: B024 pass diff --git a/src/arcaea_offline_ocr/scenarios/device/__init__.py b/src/arcaea_offline_ocr/scenarios/device/__init__.py index 6ae9efb..1d1113a 100644 --- a/src/arcaea_offline_ocr/scenarios/device/__init__.py +++ b/src/arcaea_offline_ocr/scenarios/device/__init__.py @@ -4,10 +4,10 @@ from .masker import DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2 from .rois import DeviceRoisAutoT1, DeviceRoisAutoT2 __all__ = [ - "DeviceRoisMaskerAutoT1", - "DeviceRoisMaskerAutoT2", "DeviceRoisAutoT1", "DeviceRoisAutoT2", "DeviceRoisExtractor", + "DeviceRoisMaskerAutoT1", + "DeviceRoisMaskerAutoT2", "DeviceScenario", ] diff --git a/src/arcaea_offline_ocr/scenarios/device/extractor/base.py b/src/arcaea_offline_ocr/scenarios/device/extractor/base.py index 6705d8c..43a0974 100644 --- a/src/arcaea_offline_ocr/scenarios/device/extractor/base.py +++ b/src/arcaea_offline_ocr/scenarios/device/extractor/base.py @@ -1,8 +1,7 @@ from arcaea_offline_ocr.crop import crop_xywh +from arcaea_offline_ocr.scenarios.device.rois import DeviceRois from arcaea_offline_ocr.types import Mat -from ..rois.base import DeviceRois - class DeviceRoisExtractor: def __init__(self, img: Mat, rois: DeviceRois): diff --git a/src/arcaea_offline_ocr/scenarios/device/impl.py b/src/arcaea_offline_ocr/scenarios/device/impl.py index 0ace970..7c2df7a 100644 --- a/src/arcaea_offline_ocr/scenarios/device/impl.py +++ b/src/arcaea_offline_ocr/scenarios/device/impl.py @@ -33,7 +33,8 @@ class DeviceScenario(DeviceScenarioBase): contours = self.knn_provider.contours(roi_gray) contours_filtered = self.knn_provider.contours( - roi_gray, contours_filter=contour_filter + roi_gray, + contours_filter=contour_filter, ) roi_ocr = roi_gray.copy() @@ -84,7 +85,7 @@ class DeviceScenario(DeviceScenarioBase): def max_recall(self): ocr_result = self.knn_provider.result( - self.masker.max_recall(self.extractor.max_recall) + self.masker.max_recall(self.extractor.max_recall), ) return int(ocr_result) if ocr_result else None @@ -109,7 +110,7 @@ class DeviceScenario(DeviceScenarioBase): h, w = img_gray.shape[:2] img = cv2.copyMakeBorder(img_gray, max(w - h, 0), 0, 0, 0, cv2.BORDER_REPLICATE) h, w = img.shape[:2] - img = cv2.fillPoly( + return cv2.fillPoly( img, [ np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32), @@ -119,12 +120,11 @@ class DeviceScenario(DeviceScenarioBase): ], (128,), ) - return img def partner_id_results(self): return self.image_id_provider.results( self.preprocess_char_icon( - cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY) + cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY), ), ImageCategory.PARTNER_ICON, ) diff --git a/src/arcaea_offline_ocr/scenarios/device/masker/__init__.py b/src/arcaea_offline_ocr/scenarios/device/masker/__init__.py index e19d62e..1843207 100644 --- a/src/arcaea_offline_ocr/scenarios/device/masker/__init__.py +++ b/src/arcaea_offline_ocr/scenarios/device/masker/__init__.py @@ -2,8 +2,8 @@ from .auto import DeviceRoisMaskerAuto, DeviceRoisMaskerAutoT1, DeviceRoisMasker from .base import DeviceRoisMasker __all__ = [ + "DeviceRoisMasker", "DeviceRoisMaskerAuto", "DeviceRoisMaskerAutoT1", "DeviceRoisMaskerAutoT2", - "DeviceRoisMasker", ] diff --git a/src/arcaea_offline_ocr/scenarios/device/masker/auto.py b/src/arcaea_offline_ocr/scenarios/device/masker/auto.py index 04ace51..bf51e90 100644 --- a/src/arcaea_offline_ocr/scenarios/device/masker/auto.py +++ b/src/arcaea_offline_ocr/scenarios/device/masker/auto.py @@ -10,7 +10,9 @@ class DeviceRoisMaskerAuto(DeviceRoisMasker): @staticmethod def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat): return cv2.inRange( - cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), hsv_lower, hsv_upper + cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), + hsv_lower, + hsv_upper, ) @@ -100,25 +102,33 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto): @classmethod def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.TRACK_LOST_HSV_MIN, cls.TRACK_LOST_HSV_MAX + roi_bgr, + cls.TRACK_LOST_HSV_MIN, + cls.TRACK_LOST_HSV_MAX, ) @classmethod def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.TRACK_COMPLETE_HSV_MIN, cls.TRACK_COMPLETE_HSV_MAX + roi_bgr, + cls.TRACK_COMPLETE_HSV_MIN, + cls.TRACK_COMPLETE_HSV_MAX, ) @classmethod def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.FULL_RECALL_HSV_MIN, cls.FULL_RECALL_HSV_MAX + roi_bgr, + cls.FULL_RECALL_HSV_MIN, + cls.FULL_RECALL_HSV_MAX, ) @classmethod def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.PURE_MEMORY_HSV_MIN, cls.PURE_MEMORY_HSV_MAX + roi_bgr, + cls.PURE_MEMORY_HSV_MIN, + cls.PURE_MEMORY_HSV_MAX, ) @@ -202,29 +212,39 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): @classmethod def max_recall(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.MAX_RECALL_HSV_MIN, cls.MAX_RECALL_HSV_MAX + roi_bgr, + cls.MAX_RECALL_HSV_MIN, + cls.MAX_RECALL_HSV_MAX, ) @classmethod def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.TRACK_LOST_HSV_MIN, cls.TRACK_LOST_HSV_MAX + roi_bgr, + cls.TRACK_LOST_HSV_MIN, + cls.TRACK_LOST_HSV_MAX, ) @classmethod def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.TRACK_COMPLETE_HSV_MIN, cls.TRACK_COMPLETE_HSV_MAX + roi_bgr, + cls.TRACK_COMPLETE_HSV_MIN, + cls.TRACK_COMPLETE_HSV_MAX, ) @classmethod def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.FULL_RECALL_HSV_MIN, cls.FULL_RECALL_HSV_MAX + roi_bgr, + cls.FULL_RECALL_HSV_MIN, + cls.FULL_RECALL_HSV_MAX, ) @classmethod def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: return cls.mask_bgr_in_hsv( - roi_bgr, cls.PURE_MEMORY_HSV_MIN, cls.PURE_MEMORY_HSV_MAX + roi_bgr, + cls.PURE_MEMORY_HSV_MIN, + cls.PURE_MEMORY_HSV_MAX, ) diff --git a/src/arcaea_offline_ocr/types.py b/src/arcaea_offline_ocr/types.py index ccab234..a87e931 100644 --- a/src/arcaea_offline_ocr/types.py +++ b/src/arcaea_offline_ocr/types.py @@ -25,18 +25,18 @@ class XYWHRect(NamedTuple): def __add__(self, other): if not isinstance(other, (list, tuple)) or len(other) != 4: - raise TypeError() + raise TypeError return self.__class__(*[a + b for a, b in zip(self, other)]) def __sub__(self, other): if not isinstance(other, (list, tuple)) or len(other) != 4: - raise TypeError() + raise TypeError return self.__class__(*[a - b for a, b in zip(self, other)]) def __mul__(self, other): if not isinstance(other, (int, float)): - raise TypeError() + raise TypeError return self.__class__(*[v * other for v in self])