4 Commits

Author SHA1 Message Date
8cc407b2bc chore: update README 2025-06-27 02:06:24 +08:00
673e45834d chore: apply ruff rules 2025-06-27 01:38:54 +08:00
57f430770e chore: update dependencies 2025-06-27 01:06:24 +08:00
d7ad85bdb0 ci: new build & publish workflow 2025-06-26 01:11:25 +08:00
24 changed files with 504 additions and 231 deletions

View File

@ -1,39 +0,0 @@
name: "Build and draft a release"
on:
workflow_dispatch:
push:
tags:
- '*.*.*'
permissions:
contents: write
discussions: write
jobs:
build-and-draft-release:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python environment
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build package
run: |
pip install build
python -m build
- name: Draft a release
uses: softprops/action-gh-release@v2
with:
discussion_category_name: New releases
draft: true
generate_release_notes: true
files: |
dist/*

103
.github/workflows/build-and-publish.yml vendored Normal file
View File

@ -0,0 +1,103 @@
name: Build, Release, Publish
on:
workflow_dispatch:
push:
tags:
- "*.*.*"
jobs:
build:
name: Build package
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python environment
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build package
run: |
pip install build
python -m build
- name: Store the distribution files
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: dist/
draft-release:
name: Draft a release
runs-on: ubuntu-latest
needs:
- build
permissions:
contents: write
discussions: write
steps:
- name: Download the distribution files
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Draft a release
uses: softprops/action-gh-release@v2
with:
discussion_category_name: New releases
draft: true
generate_release_notes: true
files: |
dist/*
publish-to-pypi:
name: Publish distribution to PyPI
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
needs:
- build
environment:
name: pypi
url: https://pypi.org/p/arcaea-offline-ocr
permissions:
id-token: write
steps:
- name: Download the distribution files
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
publish-to-testpypi:
name: Publish distribution to TestPyPI
runs-on: ubuntu-latest
needs:
- build
environment:
name: testpypi
url: https://test.pypi.org/p/arcaea-offline-ocr
permissions:
id-token: write
steps:
- name: Download the distribution files
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/

156
README.md
View File

@ -1,32 +1,152 @@
# Arcaea Offline OCR # Arcaea Offline OCR
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
## Example ## Example
> Results from `arcaea_offline_ocr 0.1.0a2`
### Build an image hash database (ihdb)
```py ```py
import sqlite3
from pathlib import Path
import cv2 import cv2
from arcaea_offline_ocr.device.ocr import DeviceOcr
from arcaea_offline_ocr.device.rois.definition import DeviceRoisAutoT2
from arcaea_offline_ocr.device.rois.extractor import DeviceRoisExtractor
from arcaea_offline_ocr.device.rois.masker import DeviceRoisMaskerAutoT2
from arcaea_offline_ocr.phash_db import ImagePhashDatabase
img_path = "/path/to/opencv/supported/image/formats.jpg" from arcaea_offline_ocr.builders.ihdb import (
img = cv2.imread(img_path, cv2.IMREAD_COLOR) ImageHashDatabaseBuildTask,
ImageHashesDatabaseBuilder,
)
from arcaea_offline_ocr.providers import ImageCategory, ImageHashDatabaseIdProvider
from arcaea_offline_ocr.scenarios.device import DeviceScenario
rois = DeviceRoisAutoT2(img.shape[1], img.shape[0]) def build():
extractor = DeviceRoisExtractor(img, rois) def _read_partner_icon(image_path: str):
masker = DeviceRoisMaskerAutoT2() return DeviceScenario.preprocess_char_icon(
cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY),
)
knn_model = cv2.ml.KNearest.load("/path/to/trained/knn/model.dat") builder = ImageHashesDatabaseBuilder()
phash_db = ImagePhashDatabase("/path/to/image/phash/database.db") tasks = [
ImageHashDatabaseBuildTask(
image_path=str(file),
image_id=file.stem,
category=ImageCategory.JACKET,
)
for file in Path("/path/to/some/jackets").glob("*.jpg")
]
ocr = DeviceOcr(extractor, masker, knn_model, phash_db) tasks.extend(
print(ocr.ocr()) [
ImageHashDatabaseBuildTask(
image_path=str(file),
image_id=file.stem,
category=ImageCategory.PARTNER_ICON,
imread_function=_read_partner_icon,
)
for file in Path("/path/to/some/partner_icons").glob("*.png")
],
)
with sqlite3.connect("/path/to/ihdb-X.Y.Z.db") as conn:
builder.build(conn, tasks)
``` ```
```sh ### Device OCR
$ python example.py
DeviceOcrResult(rating_class=2, pure=1135, far=11, lost=0, score=9953016, max_recall=1146, song_id='ringedgenesis', song_id_possibility=0.953125, clear_status=2, partner_id='45', partner_id_possibility=0.8046875) ```py
import json
import sqlite3
from dataclasses import asdict
import cv2
from arcaea_offline_ocr.providers import (
ImageHashDatabaseIdProvider,
OcrKNearestTextProvider,
)
from arcaea_offline_ocr.scenarios.device import (
DeviceRoisAutoT2,
DeviceRoisExtractor,
DeviceRoisMaskerAutoT2,
DeviceScenario,
)
with sqlite3.connect("/path/to/ihdb-X.Y.Z.db") as conn:
img = cv2.imread("/path/to/your/screenshot.jpg")
h, w = img.shape[:2]
r = DeviceRoisAutoT2(w, h)
m = DeviceRoisMaskerAutoT2()
e = DeviceRoisExtractor(img, r)
scenario = DeviceScenario(
extractor=e,
masker=m,
knn_provider=OcrKNearestTextProvider(
cv2.ml.KNearest.load("/path/to/knn_model.dat"),
),
image_id_provider=ImageHashDatabaseIdProvider(conn),
)
result = scenario.result()
with open("result.jsonc", "w", encoding="utf-8") as jf:
json.dump(asdict(result), jf, indent=2, ensure_ascii=False)
```
```jsonc
// result.json
{
"song_id": "vector",
"rating_class": 1,
"score": 9990996,
"song_id_results": [
{
"image_id": "vector",
"category": 0,
"confidence": 1.0,
"image_hash_type": 0
},
{
"image_id": "clotho",
"category": 0,
"confidence": 0.71875,
"image_hash_type": 0
}
// 28 more results omitted…
],
"partner_id_results": [
{
"image_id": "23",
"category": 1,
"confidence": 0.90625,
"image_hash_type": 0
},
{
"image_id": "45",
"category": 1,
"confidence": 0.8828125,
"image_hash_type": 0
}
// 28 more results omitted…
],
"pure": 1000,
"pure_inaccurate": null,
"pure_early": null,
"pure_late": null,
"far": 2,
"far_inaccurate": null,
"far_early": null,
"far_late": null,
"lost": 0,
"played_at": null,
"max_recall": 1002,
"clear_status": 2,
"clear_type": null,
"modifier": null
}
``` ```
## License ## License
@ -48,4 +168,4 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
## Credits ## Credits
[283375/image-phash-database](https://github.com/283375/image-phash-database) - [JohannesBuchner/imagehash](https://github.com/JohannesBuchner/imagehash): `arcaea_offline_ocr.core.hashers` implementations reference

View File

@ -10,32 +10,24 @@ authors = [{ name = "283375", email = "log_283375@163.com" }]
description = "Extract your Arcaea play result from screenshot." description = "Extract your Arcaea play result from screenshot."
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"
dependencies = ["attrs==23.1.0", "numpy==1.26.1", "opencv-python==4.8.1.78"] dependencies = ["numpy~=2.3", "opencv-python~=4.11"]
classifiers = [ classifiers = [
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
] ]
[project.optional-dependencies]
dev = ["ruff", "pre-commit"]
[project.urls] [project.urls]
"Homepage" = "https://github.com/ArcaeaOffline/core-ocr" "Homepage" = "https://github.com/ArcaeaOffline/core-ocr"
"Bug Tracker" = "https://github.com/ArcaeaOffline/core-ocr/issues" "Bug Tracker" = "https://github.com/ArcaeaOffline/core-ocr/issues"
[tool.setuptools_scm] [tool.setuptools_scm]
[tool.isort]
profile = "black"
src_paths = ["src/arcaea_offline_ocr"]
[tool.pyright] [tool.pyright]
ignore = ["**/__debug*.*"] ignore = ["**/__debug*.*"]
[tool.pylint.main] [tool.ruff.lint]
# extension-pkg-allow-list = ["cv2"] select = ["ALL"]
generated-members = ["cv2.*"] ignore = ["ANN", "D", "ERA", "PLR"]
[tool.pylint.logging]
disable = [
"missing-module-docstring",
"missing-class-docstring",
"missing-function-docstring",
]

View File

@ -1,3 +1,2 @@
black==23.7.0 ruff
isort==5.12.0 pre-commit
pre-commit==3.3.3

View File

@ -1,11 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Callable, List from typing import TYPE_CHECKING, Callable
import cv2 import cv2
from arcaea_offline_ocr.core import hashers from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.providers import ImageCategory
from arcaea_offline_ocr.providers.ihdb import ( from arcaea_offline_ocr.providers.ihdb import (
PROP_KEY_BUILT_AT, PROP_KEY_BUILT_AT,
PROP_KEY_HASH_SIZE, PROP_KEY_HASH_SIZE,
@ -17,6 +18,7 @@ from arcaea_offline_ocr.providers.ihdb import (
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlite3 import Connection from sqlite3 import Connection
from arcaea_offline_ocr.providers import ImageCategory
from arcaea_offline_ocr.types import Mat from arcaea_offline_ocr.types import Mat
@ -29,7 +31,7 @@ class ImageHashDatabaseBuildTask:
image_path: str image_path: str
image_id: str image_id: str
category: ImageCategory category: ImageCategory
imread_function: Callable[[str], "Mat"] = _default_imread_gray imread_function: Callable[[str], Mat] = _default_imread_gray
@dataclass @dataclass
@ -42,7 +44,7 @@ class _ImageHash:
class ImageHashesDatabaseBuilder: class ImageHashesDatabaseBuilder:
@staticmethod @staticmethod
def __insert_property(conn: "Connection", key: str, value: str): def __insert_property(conn: Connection, key: str, value: str):
return conn.execute( return conn.execute(
"INSERT INTO properties (key, value) VALUES (?, ?)", "INSERT INTO properties (key, value) VALUES (?, ?)",
(key, value), (key, value),
@ -51,13 +53,13 @@ class ImageHashesDatabaseBuilder:
@classmethod @classmethod
def build( def build(
cls, cls,
conn: "Connection", conn: Connection,
tasks: List[ImageHashDatabaseBuildTask], tasks: list[ImageHashDatabaseBuildTask],
*, *,
hash_size: int = 16, hash_size: int = 16,
high_freq_factor: int = 4, high_freq_factor: int = 4,
): ):
hashes: List[_ImageHash] = [] hashes: list[_ImageHash] = []
for task in tasks: for task in tasks:
img_gray = task.imread_function(task.image_path) img_gray = task.imread_function(task.image_path)
@ -82,7 +84,7 @@ class ImageHashesDatabaseBuilder:
image_hash_type=hash_type, image_hash_type=hash_type,
category=task.category, category=task.category,
hash=ImageHashDatabaseIdProvider.hash_mat_to_bytes(hash_mat), hash=ImageHashDatabaseIdProvider.hash_mat_to_bytes(hash_mat),
) ),
) )
conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)") conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
@ -92,7 +94,7 @@ class ImageHashesDatabaseBuilder:
`category` INTEGER, `category` INTEGER,
`hash_type` INTEGER, `hash_type` INTEGER,
`hash` BLOB `hash` BLOB
)""" )""",
) )
now = datetime.now(tz=timezone.utc) now = datetime.now(tz=timezone.utc)
@ -103,7 +105,8 @@ class ImageHashesDatabaseBuilder:
cls.__insert_property(conn, PROP_KEY_BUILT_AT, str(timestamp)) cls.__insert_property(conn, PROP_KEY_BUILT_AT, str(timestamp))
conn.executemany( conn.executemany(
"INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`) VALUES (?, ?, ?, ?)", """INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`)
VALUES (?, ?, ?, ?)""",
[ [
(it.image_id, it.category.value, it.image_hash_type.value, it.hash) (it.image_id, it.category.value, it.image_hash_type.value, it.hash)
for it in hashes for it in hashes

View File

@ -23,7 +23,7 @@ def difference(img_gray: Mat, hash_size: int) -> Mat:
def dct(img_gray: Mat, hash_size: int = 16, high_freq_factor: int = 4) -> Mat: def dct(img_gray: Mat, hash_size: int = 16, high_freq_factor: int = 4) -> Mat:
# TODO: consistency? # TODO: consistency? # noqa: FIX002, TD002, TD003
img_size_base = hash_size * high_freq_factor img_size_base = hash_size * high_freq_factor
img_size = (img_size_base, img_size_base) img_size = (img_size_base, img_size_base)

View File

@ -1,29 +1,32 @@
from __future__ import annotations
import math import math
from typing import Tuple from typing import TYPE_CHECKING
import cv2 import cv2
import numpy as np import numpy as np
from .types import Mat if TYPE_CHECKING:
from .types import Mat
__all__ = ["crop_xywh", "CropBlackEdges"] __all__ = ["CropBlackEdges", "crop_xywh"]
def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]): def crop_xywh(mat: Mat, rect: tuple[int, int, int, int]):
x, y, w, h = rect x, y, w, h = rect
return mat[y : y + h, x : x + w] return mat[y : y + h, x : x + w]
class CropBlackEdges: class CropBlackEdges:
@staticmethod @staticmethod
def is_black_edge(__img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6): def is_black_edge(img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6):
pixels_compared = __img_gray_slice < black_pixel pixels_compared = img_gray_slice < black_pixel
return np.count_nonzero(pixels_compared) > math.floor( return np.count_nonzero(pixels_compared) > math.floor(
__img_gray_slice.size * ratio img_gray_slice.size * ratio,
) )
@classmethod @classmethod
def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25): def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25): # noqa: C901
height, width = img_gray.shape[:2] height, width = img_gray.shape[:2]
left = 0 left = 0
right = width right = width
@ -54,13 +57,22 @@ class CropBlackEdges:
break break
bottom -= 1 bottom -= 1
assert right > left, "cropped width < 0" if right <= left:
assert bottom > top, "cropped height < 0" msg = "cropped width < 0"
raise ValueError(msg)
if bottom <= top:
msg = "cropped height < 0"
raise ValueError(msg)
return (left, top, right - left, bottom - top) return (left, top, right - left, bottom - top)
@classmethod @classmethod
def crop( def crop(
cls, img: Mat, convert_flag: cv2.COLOR_BGR2GRAY, black_threshold: int = 25 cls,
img: Mat,
convert_flag: cv2.COLOR_BGR2GRAY,
black_threshold: int = 25,
) -> Mat: ) -> Mat:
rect = cls.get_crop_rect(cv2.cvtColor(img, convert_flag), black_threshold) rect = cls.get_crop_rect(cv2.cvtColor(img, convert_flag), black_threshold)
return crop_xywh(img, rect) return crop_xywh(img, rect)

View File

@ -5,8 +5,8 @@ from .knn import OcrKNearestTextProvider
__all__ = [ __all__ = [
"ImageCategory", "ImageCategory",
"ImageHashDatabaseIdProvider", "ImageHashDatabaseIdProvider",
"OcrKNearestTextProvider",
"ImageIdProvider", "ImageIdProvider",
"OcrTextProvider",
"ImageIdProviderResult", "ImageIdProviderResult",
"OcrKNearestTextProvider",
"OcrTextProvider",
] ]

View File

@ -1,17 +1,19 @@
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from typing import TYPE_CHECKING, Any, Sequence, Optional from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING: if TYPE_CHECKING:
from ..types import Mat from arcaea_offline_ocr.types import Mat
class OcrTextProvider(ABC): class OcrTextProvider(ABC):
@abstractmethod @abstractmethod
def result_raw(self, img: "Mat", /, *args, **kwargs) -> Any: ... def result_raw(self, img: Mat, /, *args, **kwargs) -> Any: ...
@abstractmethod @abstractmethod
def result(self, img: "Mat", /, *args, **kwargs) -> Optional[str]: ... def result(self, img: Mat, /, *args, **kwargs) -> str | None: ...
class ImageCategory(IntEnum): class ImageCategory(IntEnum):
@ -29,10 +31,20 @@ class ImageIdProviderResult:
class ImageIdProvider(ABC): class ImageIdProvider(ABC):
@abstractmethod @abstractmethod
def result( def result(
self, img: "Mat", category: ImageCategory, /, *args, **kwargs self,
img: Mat,
category: ImageCategory,
/,
*args,
**kwargs,
) -> ImageIdProviderResult: ... ) -> ImageIdProviderResult: ...
@abstractmethod @abstractmethod
def results( def results(
self, img: "Mat", category: ImageCategory, /, *args, **kwargs self,
img: Mat,
category: ImageCategory,
/,
*args,
**kwargs,
) -> Sequence[ImageIdProviderResult]: ... ) -> Sequence[ImageIdProviderResult]: ...

View File

@ -1,14 +1,17 @@
import sqlite3 from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import IntEnum from enum import IntEnum
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar from typing import TYPE_CHECKING, Any, Callable, TypeVar
from arcaea_offline_ocr.core import hashers from arcaea_offline_ocr.core import hashers
from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult
if TYPE_CHECKING: if TYPE_CHECKING:
import sqlite3
from arcaea_offline_ocr.types import Mat from arcaea_offline_ocr.types import Mat
@ -19,9 +22,11 @@ PROP_KEY_BUILT_AT = "built_at"
def _sql_hamming_distance(hash1: bytes, hash2: bytes): def _sql_hamming_distance(hash1: bytes, hash2: bytes):
assert len(hash1) == len(hash2), "hash size does not match!" if len(hash1) != len(hash2):
count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2) msg = "hash size does not match!"
return count raise ValueError(msg)
return sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
class ImageHashType(IntEnum): class ImageHashType(IntEnum):
@ -36,7 +41,7 @@ class ImageHashDatabaseIdProviderResult(ImageIdProviderResult):
class MissingPropertiesError(Exception): class MissingPropertiesError(Exception):
keys: List[str] keys: list[str]
def __init__(self, keys, *args): def __init__(self, keys, *args):
super().__init__(*args) super().__init__(*args)
@ -72,7 +77,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
return self.properties[PROP_KEY_HIGH_FREQ_FACTOR] return self.properties[PROP_KEY_HIGH_FREQ_FACTOR]
@property @property
def built_at(self) -> Optional[datetime]: def built_at(self) -> datetime | None:
return self.properties.get(PROP_KEY_BUILT_AT) return self.properties.get(PROP_KEY_BUILT_AT)
@property @property
@ -80,7 +85,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
return self._hash_length return self._hash_length
def _initialize(self): def _initialize(self):
def get_property(key, converter: Callable[[Any], T]) -> Optional[T]: def get_property(key, converter: Callable[[Any], T]) -> T | None:
result = self.conn.execute( result = self.conn.execute(
"SELECT value FROM properties WHERE key = ?", "SELECT value FROM properties WHERE key = ?",
(key,), (key,),
@ -97,7 +102,8 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
PROP_KEY_HASH_SIZE: lambda x: int(x), PROP_KEY_HASH_SIZE: lambda x: int(x),
PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x), PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x),
PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp( PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp(
int(ts) / 1000, tz=timezone.utc int(ts) / 1000,
tz=timezone.utc,
), ),
} }
required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR] required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR]
@ -122,8 +128,11 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
self._hash_length = self.hash_size**2 self._hash_length = self.hash_size**2
def lookup_hash( def lookup_hash(
self, category: ImageCategory, hash_type: ImageHashType, hash: bytes self,
) -> List[ImageHashDatabaseIdProviderResult]: category: ImageCategory,
hash_type: ImageHashType,
hash_data: bytes,
) -> list[ImageHashDatabaseIdProviderResult]:
cursor = self.conn.execute( cursor = self.conn.execute(
""" """
SELECT SELECT
@ -132,7 +141,7 @@ SELECT
FROM hashes FROM hashes
WHERE category = ? AND hash_type = ? WHERE category = ? AND hash_type = ?
ORDER BY distance ASC LIMIT 10""", ORDER BY distance ASC LIMIT 10""",
(hash, category.value, hash_type.value), (hash_data, category.value, hash_type.value),
) )
results = [] results = []
@ -143,52 +152,52 @@ ORDER BY distance ASC LIMIT 10""",
category=category, category=category,
confidence=(self.hash_length - distance) / self.hash_length, confidence=(self.hash_length - distance) / self.hash_length,
image_hash_type=hash_type, image_hash_type=hash_type,
) ),
) )
return results return results
@staticmethod @staticmethod
def hash_mat_to_bytes(hash: "Mat") -> bytes: def hash_mat_to_bytes(hash_mat: Mat) -> bytes:
return bytes([255 if b else 0 for b in hash.flatten()]) return bytes([255 if b else 0 for b in hash_mat.flatten()])
def results(self, img: "Mat", category: ImageCategory, /): def results(self, img: Mat, category: ImageCategory, /):
results: List[ImageHashDatabaseIdProviderResult] = [] results: list[ImageHashDatabaseIdProviderResult] = []
results.extend( results.extend(
self.lookup_hash( self.lookup_hash(
category, category,
ImageHashType.AVERAGE, ImageHashType.AVERAGE,
self.hash_mat_to_bytes(hashers.average(img, self.hash_size)), self.hash_mat_to_bytes(hashers.average(img, self.hash_size)),
) ),
) )
results.extend( results.extend(
self.lookup_hash( self.lookup_hash(
category, category,
ImageHashType.DIFFERENCE, ImageHashType.DIFFERENCE,
self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)), self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)),
) ),
) )
results.extend( results.extend(
self.lookup_hash( self.lookup_hash(
category, category,
ImageHashType.DCT, ImageHashType.DCT,
self.hash_mat_to_bytes( self.hash_mat_to_bytes(
hashers.dct(img, self.hash_size, self.high_freq_factor) hashers.dct(img, self.hash_size, self.high_freq_factor),
),
), ),
)
) )
return results return results
def result( def result(
self, self,
img: "Mat", img: Mat,
category: ImageCategory, category: ImageCategory,
/, /,
*, *,
hash_type: ImageHashType = ImageHashType.DCT, hash_type: ImageHashType = ImageHashType.DCT,
): ):
return [ return next(
it for it in self.results(img, category) if it.image_hash_type == hash_type it for it in self.results(img, category) if it.image_hash_type == hash_type
][0] )

View File

@ -1,17 +1,20 @@
from __future__ import annotations
import logging import logging
import math import math
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Callable, Sequence
import cv2 import cv2
import numpy as np import numpy as np
from ..crop import crop_xywh from arcaea_offline_ocr.crop import crop_xywh
from .base import OcrTextProvider from .base import OcrTextProvider
if TYPE_CHECKING: if TYPE_CHECKING:
from cv2.ml import KNearest from cv2.ml import KNearest
from ..types import Mat from arcaea_offline_ocr.types import Mat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,10 +22,10 @@ logger = logging.getLogger(__name__)
class FixRects: class FixRects:
@staticmethod @staticmethod
def connect_broken( def connect_broken(
rects: Sequence[Tuple[int, int, int, int]], rects: Sequence[tuple[int, int, int, int]],
img_width: int, img_width: int,
img_height: int, img_height: int,
tolerance: Optional[int] = None, tolerance: int | None = None,
): ):
# for a "broken" digit, please refer to # for a "broken" digit, please refer to
# /assets/fix_rects/broken_masked.jpg # /assets/fix_rects/broken_masked.jpg
@ -69,8 +72,8 @@ class FixRects:
@staticmethod @staticmethod
def split_connected( def split_connected(
img_masked: "Mat", img_masked: Mat,
rects: Sequence[Tuple[int, int, int, int]], rects: Sequence[tuple[int, int, int, int]],
rect_wh_ratio: float = 1.05, rect_wh_ratio: float = 1.05,
width_range_ratio: float = 0.1, width_range_ratio: float = 0.1,
): ):
@ -111,7 +114,7 @@ class FixRects:
# split the rect # split the rect
new_rects.extend( new_rects.extend(
[(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)] [(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)],
) )
return_rects = [r for r in rects if r not in connected_rects] return_rects = [r for r in rects if r not in connected_rects]
@ -119,7 +122,7 @@ class FixRects:
return return_rects return return_rects
def resize_fill_square(img: "Mat", target: int = 20): def resize_fill_square(img: Mat, target: int = 20):
h, w = img.shape[:2] h, w = img.shape[:2]
if h > w: if h > w:
new_h = target new_h = target
@ -132,11 +135,21 @@ def resize_fill_square(img: "Mat", target: int = 20):
border_size = math.ceil((max(new_w, new_h) - min(new_w, new_h)) / 2) border_size = math.ceil((max(new_w, new_h) - min(new_w, new_h)) / 2)
if new_w < new_h: if new_w < new_h:
resized = cv2.copyMakeBorder( resized = cv2.copyMakeBorder(
resized, 0, 0, border_size, border_size, cv2.BORDER_CONSTANT resized,
0,
0,
border_size,
border_size,
cv2.BORDER_CONSTANT,
) )
else: else:
resized = cv2.copyMakeBorder( resized = cv2.copyMakeBorder(
resized, border_size, border_size, 0, 0, cv2.BORDER_CONSTANT resized,
border_size,
border_size,
0,
0,
cv2.BORDER_CONSTANT,
) )
return cv2.resize(resized, (target, target)) return cv2.resize(resized, (target, target))
@ -151,8 +164,8 @@ def preprocess_hog(digit_rois):
return np.float32(samples) return np.float32(samples)
def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4): def ocr_digit_samples_knn(samples, knn_model: cv2.ml.KNearest, k: int = 4):
_, results, _, _ = knn_model.findNearest(__samples, k) _, results, _, _ = knn_model.findNearest(samples, k)
return [int(r) for r in results.ravel()] return [int(r) for r in results.ravel()]
@ -160,11 +173,15 @@ class OcrKNearestTextProvider(OcrTextProvider):
_ContourFilter = Callable[["Mat"], bool] _ContourFilter = Callable[["Mat"], bool]
_RectsFilter = Callable[[Sequence[int]], bool] _RectsFilter = Callable[[Sequence[int]], bool]
def __init__(self, model: "KNearest"): def __init__(self, model: KNearest):
self.model = model self.model = model
def contours( def contours(
self, img: "Mat", /, *, contours_filter: Optional[_ContourFilter] = None self,
img: Mat,
/,
*,
contours_filter: _ContourFilter | None = None,
): ):
cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if contours_filter: if contours_filter:
@ -174,12 +191,12 @@ class OcrKNearestTextProvider(OcrTextProvider):
def result_raw( def result_raw(
self, self,
img: "Mat", img: Mat,
/, /,
*, *,
fix_rects: bool = True, fix_rects: bool = True,
contours_filter: Optional[_ContourFilter] = None, contours_filter: _ContourFilter | None = None,
rects_filter: Optional[_RectsFilter] = None, rects_filter: _RectsFilter | None = None,
): ):
""" """
:param img: grayscaled roi :param img: grayscaled roi
@ -192,11 +209,11 @@ class OcrKNearestTextProvider(OcrTextProvider):
rects = [cv2.boundingRect(cnt) for cnt in cnts] rects = [cv2.boundingRect(cnt) for cnt in cnts]
if fix_rects and rects_filter: if fix_rects and rects_filter:
rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0]) # type: ignore rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])
rects = list(filter(rects_filter, rects)) rects = list(filter(rects_filter, rects))
rects = FixRects.split_connected(img, rects) rects = FixRects.split_connected(img, rects)
elif fix_rects: elif fix_rects:
rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0]) # type: ignore rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])
rects = FixRects.split_connected(img, rects) rects = FixRects.split_connected(img, rects)
elif rects_filter: elif rects_filter:
rects = list(filter(rects_filter, rects)) rects = list(filter(rects_filter, rects))
@ -216,12 +233,12 @@ class OcrKNearestTextProvider(OcrTextProvider):
def result( def result(
self, self,
img: "Mat", img: Mat,
/, /,
*, *,
fix_rects: bool = True, fix_rects: bool = True,
contours_filter: Optional[_ContourFilter] = None, contours_filter: _ContourFilter | None = None,
rects_filter: Optional[_RectsFilter] = None, rects_filter: _RectsFilter | None = None,
): ):
""" """
:param img: grayscaled roi :param img: grayscaled roi

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING
from arcaea_offline_ocr.scenarios.base import OcrScenario, OcrScenarioResult from arcaea_offline_ocr.scenarios.base import OcrScenario, OcrScenarioResult
@ -9,13 +11,13 @@ if TYPE_CHECKING:
class Best30Scenario(OcrScenario): class Best30Scenario(OcrScenario):
@abstractmethod @abstractmethod
def components(self, img: "Mat", /) -> List["Mat"]: ... def components(self, img: Mat, /) -> list[Mat]: ...
@abstractmethod @abstractmethod
def result(self, component_img: "Mat", /, *args, **kwargs) -> OcrScenarioResult: ... def result(self, component_img: Mat, /, *args, **kwargs) -> OcrScenarioResult: ...
@abstractmethod @abstractmethod
def results(self, img: "Mat", /, *args, **kwargs) -> List[OcrScenarioResult]: def results(self, img: Mat, /, *args, **kwargs) -> list[OcrScenarioResult]:
""" """
Commonly a shorthand for `[self.result(comp) for comp in self.components(img)]` Commonly a shorthand for `[self.result(comp) for comp in self.components(img)]`
""" """

View File

@ -1,19 +1,19 @@
import numpy as np import numpy as np
__all__ = [ __all__ = [
"FONT_THRESHOLD",
"PURE_BG_MIN_HSV",
"PURE_BG_MAX_HSV",
"FAR_BG_MIN_HSV",
"FAR_BG_MAX_HSV",
"LOST_BG_MIN_HSV",
"LOST_BG_MAX_HSV",
"BYD_MIN_HSV",
"BYD_MAX_HSV", "BYD_MAX_HSV",
"FTR_MIN_HSV", "BYD_MIN_HSV",
"FAR_BG_MAX_HSV",
"FAR_BG_MIN_HSV",
"FONT_THRESHOLD",
"FTR_MAX_HSV", "FTR_MAX_HSV",
"PRS_MIN_HSV", "FTR_MIN_HSV",
"LOST_BG_MAX_HSV",
"LOST_BG_MIN_HSV",
"PRS_MAX_HSV", "PRS_MAX_HSV",
"PRS_MIN_HSV",
"PURE_BG_MAX_HSV",
"PURE_BG_MIN_HSV",
] ]
FONT_THRESHOLD = 160 FONT_THRESHOLD = 160

View File

@ -1,4 +1,6 @@
from typing import List, Optional, Tuple from __future__ import annotations
from typing import TYPE_CHECKING
import cv2 import cv2
import numpy as np import numpy as np
@ -11,7 +13,9 @@ from arcaea_offline_ocr.providers import (
) )
from arcaea_offline_ocr.scenarios.b30.base import Best30Scenario from arcaea_offline_ocr.scenarios.b30.base import Best30Scenario
from arcaea_offline_ocr.scenarios.base import OcrScenarioResult from arcaea_offline_ocr.scenarios.base import OcrScenarioResult
from arcaea_offline_ocr.types import Mat
if TYPE_CHECKING:
from arcaea_offline_ocr.types import Mat
from .colors import ( from .colors import (
BYD_MAX_HSV, BYD_MAX_HSV,
@ -71,13 +75,13 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
rating_class_results = [np.count_nonzero(m) for m in rating_class_masks] rating_class_results = [np.count_nonzero(m) for m in rating_class_masks]
if max(rating_class_results) < 70: if max(rating_class_results) < 70:
return 0 return 0
else:
return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1
def ocr_component_song_id_results(self, component_bgr: Mat): def ocr_component_song_id_results(self, component_bgr: Mat):
jacket_rect = self.rois.component_rois.jacket_rect.floored() jacket_rect = self.rois.component_rois.jacket_rect.floored()
jacket_roi = cv2.cvtColor( jacket_roi = cv2.cvtColor(
crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY crop_xywh(component_bgr, jacket_rect),
cv2.COLOR_BGR2GRAY,
) )
return self.image_id_provider.results(jacket_roi, ImageCategory.JACKET) return self.image_id_provider.results(jacket_roi, ImageCategory.JACKET)
@ -85,16 +89,22 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
score_rect = self.rois.component_rois.score_rect.rounded() score_rect = self.rois.component_rois.score_rect.rounded()
score_roi = cv2.cvtColor( score_roi = cv2.cvtColor(
crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY crop_xywh(component_bgr, score_rect),
cv2.COLOR_BGR2GRAY,
) )
_, score_roi = cv2.threshold( _, score_roi = cv2.threshold(
score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU score_roi,
0,
255,
cv2.THRESH_BINARY + cv2.THRESH_OTSU,
) )
if score_roi[1][1] == 255: if score_roi[1][1] == 255:
score_roi = 255 - score_roi score_roi = 255 - score_roi
contours, _ = cv2.findContours( contours, _ = cv2.findContours(
score_roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE score_roi,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE,
) )
for contour in contours: for contour in contours:
rect = cv2.boundingRect(contour) rect = cv2.boundingRect(contour)
@ -106,8 +116,9 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
return int(ocr_result) if ocr_result else 0 return int(ocr_result) if ocr_result else 0
def find_pfl_rects( def find_pfl_rects(
self, component_pfl_processed: Mat self,
) -> List[Tuple[int, int, int, int]]: component_pfl_processed: Mat,
) -> list[tuple[int, int, int, int]]:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
pfl_roi_find = cv2.morphologyEx( pfl_roi_find = cv2.morphologyEx(
component_pfl_processed, component_pfl_processed,
@ -115,14 +126,16 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
cv2.getStructuringElement(cv2.MORPH_RECT, [10, 1]), cv2.getStructuringElement(cv2.MORPH_RECT, [10, 1]),
) )
pfl_contours, _ = cv2.findContours( pfl_contours, _ = cv2.findContours(
pfl_roi_find, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE pfl_roi_find,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_NONE,
) )
pfl_rects = [cv2.boundingRect(c) for c in pfl_contours] pfl_rects = [cv2.boundingRect(c) for c in pfl_contours]
pfl_rects = [ pfl_rects = [
r for r in pfl_rects if r[3] > component_pfl_processed.shape[0] * 0.1 r for r in pfl_rects if r[3] > component_pfl_processed.shape[0] * 0.1
] ]
pfl_rects = sorted(pfl_rects, key=lambda r: r[1]) pfl_rects = sorted(pfl_rects, key=lambda r: r[1])
pfl_rects_adjusted = [ return [
( (
max(rect[0] - 2, 0), max(rect[0] - 2, 0),
rect[1], rect[1],
@ -131,7 +144,6 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
) )
for rect in pfl_rects for rect in pfl_rects
] ]
return pfl_rects_adjusted
def preprocess_component_pfl(self, component_bgr: Mat) -> Mat: def preprocess_component_pfl(self, component_bgr: Mat) -> Mat:
pfl_rect = self.rois.component_rois.pfl_rect.rounded() pfl_rect = self.rois.component_rois.pfl_rect.rounded()
@ -154,11 +166,17 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
pfl_roi_blurred = cv2.GaussianBlur(pfl_roi, (5, 5), 0) pfl_roi_blurred = cv2.GaussianBlur(pfl_roi, (5, 5), 0)
# pfl_roi_blurred = cv2.medianBlur(pfl_roi, 3) # pfl_roi_blurred = cv2.medianBlur(pfl_roi, 3)
_, pfl_roi_blurred_threshold = cv2.threshold( _, pfl_roi_blurred_threshold = cv2.threshold(
pfl_roi_blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU pfl_roi_blurred,
0,
255,
cv2.THRESH_BINARY + cv2.THRESH_OTSU,
) )
# and a threshold of the original roi # and a threshold of the original roi
_, pfl_roi_threshold = cv2.threshold( _, pfl_roi_threshold = cv2.threshold(
pfl_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU pfl_roi,
0,
255,
cv2.THRESH_BINARY + cv2.THRESH_OTSU,
) )
# turn thresholds into black background # turn thresholds into black background
if pfl_roi_blurred_threshold[2][2] == 255: if pfl_roi_blurred_threshold[2][2] == 255:
@ -168,13 +186,15 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
# return a bitwise_and result # return a bitwise_and result
result = cv2.bitwise_and(pfl_roi_blurred_threshold, pfl_roi_threshold) result = cv2.bitwise_and(pfl_roi_blurred_threshold, pfl_roi_threshold)
result_eroded = cv2.erode( result_eroded = cv2.erode(
result, cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2)) result,
cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2)),
) )
return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result
def ocr_component_pfl( def ocr_component_pfl(
self, component_bgr: Mat self,
) -> Tuple[Optional[int], Optional[int], Optional[int]]: component_bgr: Mat,
) -> tuple[int | None, int | None, int | None]:
try: try:
pfl_roi = self.preprocess_component_pfl(component_bgr) pfl_roi = self.preprocess_component_pfl(component_bgr)
pfl_rects = self.find_pfl_rects(pfl_roi) pfl_rects = self.find_pfl_rects(pfl_roi)
@ -185,7 +205,7 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
pure_far_lost.append(int(result) if result else None) pure_far_lost.append(int(result) if result else None)
return tuple(pure_far_lost) return tuple(pure_far_lost)
except Exception: except Exception: # noqa: BLE001
return (None, None, None) return (None, None, None)
def ocr_component(self, component_bgr: Mat) -> OcrScenarioResult: def ocr_component(self, component_bgr: Mat) -> OcrScenarioResult:
@ -216,7 +236,7 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
def result(self, component_img: Mat, /): def result(self, component_img: Mat, /):
return self.ocr_component(component_img) return self.ocr_component(component_img)
def results(self, img: Mat, /) -> List[OcrScenarioResult]: def results(self, img: Mat, /) -> list[OcrScenarioResult]:
""" """
:param img: BGR format image :param img: BGR format image
""" """

View File

@ -1,4 +1,4 @@
from typing import List from __future__ import annotations
from arcaea_offline_ocr.crop import crop_xywh from arcaea_offline_ocr.crop import crop_xywh
from arcaea_offline_ocr.types import Mat, XYWHRect from arcaea_offline_ocr.types import Mat, XYWHRect
@ -105,7 +105,7 @@ class ChieriBotV4Rois:
def b33_vertical_gap(self): def b33_vertical_gap(self):
return 121 * self.factor return 121 * self.factor
def components(self, img_bgr: Mat) -> List[Mat]: def components(self, img_bgr: Mat) -> list[Mat]:
first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height) first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height)
results = [] results = []

View File

@ -1,9 +1,13 @@
from __future__ import annotations
from abc import ABC from abc import ABC
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from typing import TYPE_CHECKING, Sequence
from typing import Sequence, Optional
from arcaea_offline_ocr.providers import ImageIdProviderResult if TYPE_CHECKING:
from datetime import datetime
from arcaea_offline_ocr.providers import ImageIdProviderResult
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -12,27 +16,27 @@ class OcrScenarioResult:
rating_class: int rating_class: int
score: int score: int
song_id_results: Sequence[ImageIdProviderResult] = field(default_factory=lambda: []) song_id_results: Sequence[ImageIdProviderResult] = field(default_factory=list)
partner_id_results: Sequence[ImageIdProviderResult] = field( partner_id_results: Sequence[ImageIdProviderResult] = field(
default_factory=lambda: [] default_factory=list,
) )
pure: Optional[int] = None pure: int | None = None
pure_inaccurate: Optional[int] = None pure_inaccurate: int | None = None
pure_early: Optional[int] = None pure_early: int | None = None
pure_late: Optional[int] = None pure_late: int | None = None
far: Optional[int] = None far: int | None = None
far_inaccurate: Optional[int] = None far_inaccurate: int | None = None
far_early: Optional[int] = None far_early: int | None = None
far_late: Optional[int] = None far_late: int | None = None
lost: Optional[int] = None lost: int | None = None
played_at: Optional[datetime] = None played_at: datetime | None = None
max_recall: Optional[int] = None max_recall: int | None = None
clear_status: Optional[int] = None clear_status: int | None = None
clear_type: Optional[int] = None clear_type: int | None = None
modifier: Optional[int] = None modifier: int | None = None
class OcrScenario(ABC): class OcrScenario(ABC): # noqa: B024
pass pass

View File

@ -4,10 +4,10 @@ from .masker import DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .rois import DeviceRoisAutoT1, DeviceRoisAutoT2 from .rois import DeviceRoisAutoT1, DeviceRoisAutoT2
__all__ = [ __all__ = [
"DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2",
"DeviceRoisAutoT1", "DeviceRoisAutoT1",
"DeviceRoisAutoT2", "DeviceRoisAutoT2",
"DeviceRoisExtractor", "DeviceRoisExtractor",
"DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2",
"DeviceScenario", "DeviceScenario",
] ]

View File

@ -1,8 +1,7 @@
from arcaea_offline_ocr.crop import crop_xywh from arcaea_offline_ocr.crop import crop_xywh
from arcaea_offline_ocr.scenarios.device.rois import DeviceRois
from arcaea_offline_ocr.types import Mat from arcaea_offline_ocr.types import Mat
from ..rois.base import DeviceRois
class DeviceRoisExtractor: class DeviceRoisExtractor:
def __init__(self, img: Mat, rois: DeviceRois): def __init__(self, img: Mat, rois: DeviceRois):

View File

@ -33,7 +33,8 @@ class DeviceScenario(DeviceScenarioBase):
contours = self.knn_provider.contours(roi_gray) contours = self.knn_provider.contours(roi_gray)
contours_filtered = self.knn_provider.contours( contours_filtered = self.knn_provider.contours(
roi_gray, contours_filter=contour_filter roi_gray,
contours_filter=contour_filter,
) )
roi_ocr = roi_gray.copy() roi_ocr = roi_gray.copy()
@ -84,7 +85,7 @@ class DeviceScenario(DeviceScenarioBase):
def max_recall(self): def max_recall(self):
ocr_result = self.knn_provider.result( ocr_result = self.knn_provider.result(
self.masker.max_recall(self.extractor.max_recall) self.masker.max_recall(self.extractor.max_recall),
) )
return int(ocr_result) if ocr_result else None return int(ocr_result) if ocr_result else None
@ -109,7 +110,7 @@ class DeviceScenario(DeviceScenarioBase):
h, w = img_gray.shape[:2] h, w = img_gray.shape[:2]
img = cv2.copyMakeBorder(img_gray, max(w - h, 0), 0, 0, 0, cv2.BORDER_REPLICATE) img = cv2.copyMakeBorder(img_gray, max(w - h, 0), 0, 0, 0, cv2.BORDER_REPLICATE)
h, w = img.shape[:2] h, w = img.shape[:2]
img = cv2.fillPoly( return cv2.fillPoly(
img, img,
[ [
np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32), np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32),
@ -119,12 +120,11 @@ class DeviceScenario(DeviceScenarioBase):
], ],
(128,), (128,),
) )
return img
def partner_id_results(self): def partner_id_results(self):
return self.image_id_provider.results( return self.image_id_provider.results(
self.preprocess_char_icon( self.preprocess_char_icon(
cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY) cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY),
), ),
ImageCategory.PARTNER_ICON, ImageCategory.PARTNER_ICON,
) )

View File

@ -2,8 +2,8 @@ from .auto import DeviceRoisMaskerAuto, DeviceRoisMaskerAutoT1, DeviceRoisMasker
from .base import DeviceRoisMasker from .base import DeviceRoisMasker
__all__ = [ __all__ = [
"DeviceRoisMasker",
"DeviceRoisMaskerAuto", "DeviceRoisMaskerAuto",
"DeviceRoisMaskerAutoT1", "DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2", "DeviceRoisMaskerAutoT2",
"DeviceRoisMasker",
] ]

View File

@ -10,7 +10,9 @@ class DeviceRoisMaskerAuto(DeviceRoisMasker):
@staticmethod @staticmethod
def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat): def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat):
return cv2.inRange( return cv2.inRange(
cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), hsv_lower, hsv_upper cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV),
hsv_lower,
hsv_upper,
) )
@ -100,25 +102,33 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto):
@classmethod @classmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.TRACK_LOST_HSV_MIN, cls.TRACK_LOST_HSV_MAX roi_bgr,
cls.TRACK_LOST_HSV_MIN,
cls.TRACK_LOST_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.TRACK_COMPLETE_HSV_MIN, cls.TRACK_COMPLETE_HSV_MAX roi_bgr,
cls.TRACK_COMPLETE_HSV_MIN,
cls.TRACK_COMPLETE_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.FULL_RECALL_HSV_MIN, cls.FULL_RECALL_HSV_MAX roi_bgr,
cls.FULL_RECALL_HSV_MIN,
cls.FULL_RECALL_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.PURE_MEMORY_HSV_MIN, cls.PURE_MEMORY_HSV_MAX roi_bgr,
cls.PURE_MEMORY_HSV_MIN,
cls.PURE_MEMORY_HSV_MAX,
) )
@ -202,29 +212,39 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto):
@classmethod @classmethod
def max_recall(cls, roi_bgr: Mat) -> Mat: def max_recall(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.MAX_RECALL_HSV_MIN, cls.MAX_RECALL_HSV_MAX roi_bgr,
cls.MAX_RECALL_HSV_MIN,
cls.MAX_RECALL_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.TRACK_LOST_HSV_MIN, cls.TRACK_LOST_HSV_MAX roi_bgr,
cls.TRACK_LOST_HSV_MIN,
cls.TRACK_LOST_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.TRACK_COMPLETE_HSV_MIN, cls.TRACK_COMPLETE_HSV_MAX roi_bgr,
cls.TRACK_COMPLETE_HSV_MIN,
cls.TRACK_COMPLETE_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.FULL_RECALL_HSV_MIN, cls.FULL_RECALL_HSV_MAX roi_bgr,
cls.FULL_RECALL_HSV_MIN,
cls.FULL_RECALL_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, cls.PURE_MEMORY_HSV_MIN, cls.PURE_MEMORY_HSV_MAX roi_bgr,
cls.PURE_MEMORY_HSV_MIN,
cls.PURE_MEMORY_HSV_MAX,
) )

View File

@ -25,18 +25,18 @@ class XYWHRect(NamedTuple):
def __add__(self, other): def __add__(self, other):
if not isinstance(other, (list, tuple)) or len(other) != 4: if not isinstance(other, (list, tuple)) or len(other) != 4:
raise TypeError() raise TypeError
return self.__class__(*[a + b for a, b in zip(self, other)]) return self.__class__(*[a + b for a, b in zip(self, other)])
def __sub__(self, other): def __sub__(self, other):
if not isinstance(other, (list, tuple)) or len(other) != 4: if not isinstance(other, (list, tuple)) or len(other) != 4:
raise TypeError() raise TypeError
return self.__class__(*[a - b for a, b in zip(self, other)]) return self.__class__(*[a - b for a, b in zip(self, other)])
def __mul__(self, other): def __mul__(self, other):
if not isinstance(other, (int, float)): if not isinstance(other, (int, float)):
raise TypeError() raise TypeError
return self.__class__(*[v * other for v in self]) return self.__class__(*[v * other for v in self])