12 Commits

Author SHA1 Message Date
8cc407b2bc chore: update README 2025-06-27 02:06:24 +08:00
673e45834d chore: apply ruff rules 2025-06-27 01:38:54 +08:00
57f430770e chore: update dependencies 2025-06-27 01:06:24 +08:00
d7ad85bdb0 ci: new build & publish workflow 2025-06-26 01:11:25 +08:00
0b53682398 ci: fix tag condition
- ok i didnt expect github actions does not support regex wtf
2025-06-26 00:15:02 +08:00
b117346b46 chore: setuptools-scm integration 2025-06-26 00:07:08 +08:00
ad0a33daad ci: tag regex 2025-06-25 23:43:46 +08:00
c08a1332a7 chore!: removing unused code 2025-06-25 23:37:21 +08:00
0055d9e8da refactor!: device scenario
- Correct abstract class annotations
2025-06-25 23:35:38 +08:00
06156db9c2 refactor!: chieri v4 b30 scenario
- Remove useless `.utils` code
2025-06-25 23:27:15 +08:00
c65798a02d feat: XYWHRect __mul__ 2025-06-25 23:19:52 +08:00
f11dc6e38f refactor: scenario base 2025-06-25 23:11:45 +08:00
49 changed files with 846 additions and 635 deletions

View File

@ -1,48 +0,0 @@
name: "Build and draft a release"
on:
workflow_dispatch:
push:
tags:
- "v[0-9]+.[0-9]+.[0-9]+"
permissions:
contents: write
discussions: write
jobs:
build-and-draft-release:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python environment
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build package
run: |
pip install build
python -m build
- name: Remove `v` in tag name
uses: mad9000/actions-find-and-replace-string@5
id: tagNameReplaced
with:
source: ${{ github.ref_name }}
find: "v"
replace: ""
- name: Draft a release
uses: softprops/action-gh-release@v2
with:
discussion_category_name: New releases
draft: true
generate_release_notes: true
files: |
dist/arcaea_offline_ocr-${{ steps.tagNameReplaced.outputs.value }}*.whl
dist/arcaea-offline-ocr-${{ steps.tagNameReplaced.outputs.value }}.tar.gz

.github/workflows/build-and-publish.yml
View File

@ -0,0 +1,103 @@
name: Build, Release, Publish
on:
  workflow_dispatch:
  push:
    tags:
      - "*.*.*"
jobs:
  build:
    name: Build package
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up Python environment
        uses: actions/setup-python@v5
        with:
          python-version: "3.x"
      - name: Build package
        run: |
          pip install build
          python -m build
      - name: Store the distribution files
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
  draft-release:
    name: Draft a release
    runs-on: ubuntu-latest
    needs:
      - build
    permissions:
      contents: write
      discussions: write
    steps:
      - name: Download the distribution files
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Draft a release
        uses: softprops/action-gh-release@v2
        with:
          discussion_category_name: New releases
          draft: true
          generate_release_notes: true
          files: |
            dist/*
  publish-to-pypi:
    name: Publish distribution to PyPI
    runs-on: ubuntu-latest
    if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
    needs:
      - build
    environment:
      name: pypi
      url: https://pypi.org/p/arcaea-offline-ocr
    permissions:
      id-token: write
    steps:
      - name: Download the distribution files
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
  publish-to-testpypi:
    name: Publish distribution to TestPyPI
    runs-on: ubuntu-latest
    needs:
      - build
    environment:
      name: testpypi
      url: https://test.pypi.org/p/arcaea-offline-ocr
    permissions:
      id-token: write
    steps:
      - name: Download the distribution files
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution to TestPyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://test.pypi.org/legacy/

README.md
View File

@ -1,32 +1,152 @@
# Arcaea Offline OCR
+[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
## Example
-> Results from `arcaea_offline_ocr 0.1.0a2`
+### Build an image hash database (ihdb)
```py
+import sqlite3
+from pathlib import Path
import cv2
-from arcaea_offline_ocr.device.ocr import DeviceOcr
-from arcaea_offline_ocr.device.rois.definition import DeviceRoisAutoT2
-from arcaea_offline_ocr.device.rois.extractor import DeviceRoisExtractor
-from arcaea_offline_ocr.device.rois.masker import DeviceRoisMaskerAutoT2
-from arcaea_offline_ocr.phash_db import ImagePhashDatabase
-img_path = "/path/to/opencv/supported/image/formats.jpg"
-img = cv2.imread(img_path, cv2.IMREAD_COLOR)
-rois = DeviceRoisAutoT2(img.shape[1], img.shape[0])
-extractor = DeviceRoisExtractor(img, rois)
-masker = DeviceRoisMaskerAutoT2()
-knn_model = cv2.ml.KNearest.load("/path/to/trained/knn/model.dat")
-phash_db = ImagePhashDatabase("/path/to/image/phash/database.db")
-ocr = DeviceOcr(extractor, masker, knn_model, phash_db)
-print(ocr.ocr())
+from arcaea_offline_ocr.builders.ihdb import (
+    ImageHashDatabaseBuildTask,
+    ImageHashesDatabaseBuilder,
+)
+from arcaea_offline_ocr.providers import ImageCategory, ImageHashDatabaseIdProvider
+from arcaea_offline_ocr.scenarios.device import DeviceScenario
+def build():
+    def _read_partner_icon(image_path: str):
+        return DeviceScenario.preprocess_char_icon(
+            cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY),
+        )
+    builder = ImageHashesDatabaseBuilder()
+    tasks = [
+        ImageHashDatabaseBuildTask(
+            image_path=str(file),
+            image_id=file.stem,
+            category=ImageCategory.JACKET,
+        )
+        for file in Path("/path/to/some/jackets").glob("*.jpg")
+    ]
+    tasks.extend(
+        [
+            ImageHashDatabaseBuildTask(
+                image_path=str(file),
+                image_id=file.stem,
+                category=ImageCategory.PARTNER_ICON,
+                imread_function=_read_partner_icon,
+            )
+            for file in Path("/path/to/some/partner_icons").glob("*.png")
+        ],
+    )
+    with sqlite3.connect("/path/to/ihdb-X.Y.Z.db") as conn:
+        builder.build(conn, tasks)
```
-```sh
-$ python example.py
-DeviceOcrResult(rating_class=2, pure=1135, far=11, lost=0, score=9953016, max_recall=1146, song_id='ringedgenesis', song_id_possibility=0.953125, clear_status=2, partner_id='45', partner_id_possibility=0.8046875)
+### Device OCR
+```py
+import json
+import sqlite3
+from dataclasses import asdict
+import cv2
+from arcaea_offline_ocr.providers import (
+    ImageHashDatabaseIdProvider,
+    OcrKNearestTextProvider,
+)
+from arcaea_offline_ocr.scenarios.device import (
+    DeviceRoisAutoT2,
+    DeviceRoisExtractor,
+    DeviceRoisMaskerAutoT2,
+    DeviceScenario,
+)
+with sqlite3.connect("/path/to/ihdb-X.Y.Z.db") as conn:
+    img = cv2.imread("/path/to/your/screenshot.jpg")
+    h, w = img.shape[:2]
+    r = DeviceRoisAutoT2(w, h)
+    m = DeviceRoisMaskerAutoT2()
+    e = DeviceRoisExtractor(img, r)
+    scenario = DeviceScenario(
+        extractor=e,
+        masker=m,
+        knn_provider=OcrKNearestTextProvider(
+            cv2.ml.KNearest.load("/path/to/knn_model.dat"),
+        ),
+        image_id_provider=ImageHashDatabaseIdProvider(conn),
+    )
+    result = scenario.result()
+    with open("result.jsonc", "w", encoding="utf-8") as jf:
+        json.dump(asdict(result), jf, indent=2, ensure_ascii=False)
+```
+```jsonc
+// result.json
+{
+  "song_id": "vector",
+  "rating_class": 1,
+  "score": 9990996,
+  "song_id_results": [
+    {
+      "image_id": "vector",
+      "category": 0,
+      "confidence": 1.0,
+      "image_hash_type": 0
+    },
+    {
+      "image_id": "clotho",
+      "category": 0,
+      "confidence": 0.71875,
+      "image_hash_type": 0
+    }
+    // 28 more results omitted…
+  ],
+  "partner_id_results": [
+    {
+      "image_id": "23",
+      "category": 1,
+      "confidence": 0.90625,
+      "image_hash_type": 0
+    },
+    {
+      "image_id": "45",
+      "category": 1,
+      "confidence": 0.8828125,
+      "image_hash_type": 0
+    }
+    // 28 more results omitted…
+  ],
+  "pure": 1000,
+  "pure_inaccurate": null,
+  "pure_early": null,
+  "pure_late": null,
+  "far": 2,
+  "far_inaccurate": null,
+  "far_early": null,
+  "far_late": null,
+  "lost": 0,
+  "played_at": null,
+  "max_recall": 1002,
+  "clear_status": 2,
+  "clear_type": null,
+  "modifier": null
+}
```
## License
@ -48,4 +168,4 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
## Credits
-[283375/image-phash-database](https://github.com/283375/image-phash-database)
+- [JohannesBuchner/imagehash](https://github.com/JohannesBuchner/imagehash): `arcaea_offline_ocr.core.hashers` implementations reference

View File

@ -1,38 +1,33 @@
[build-system]
-requires = ["setuptools>=61.0"]
+requires = ["setuptools>=64", "setuptools-scm>=8"]
build-backend = "setuptools.build_meta"
[project]
+dynamic = ["version"]
name = "arcaea-offline-ocr"
-version = "0.0.99"
authors = [{ name = "283375", email = "log_283375@163.com" }]
description = "Extract your Arcaea play result from screenshot."
readme = "README.md"
requires-python = ">=3.8"
-dependencies = ["attrs==23.1.0", "numpy==1.26.1", "opencv-python==4.8.1.78"]
+dependencies = ["numpy~=2.3", "opencv-python~=4.11"]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Programming Language :: Python :: 3",
]
+[project.optional-dependencies]
+dev = ["ruff", "pre-commit"]
[project.urls]
"Homepage" = "https://github.com/ArcaeaOffline/core-ocr"
"Bug Tracker" = "https://github.com/ArcaeaOffline/core-ocr/issues"
-[tool.isort]
-profile = "black"
-src_paths = ["src/arcaea_offline_ocr"]
+[tool.setuptools_scm]
[tool.pyright]
ignore = ["**/__debug*.*"]
-[tool.pylint.main]
-# extension-pkg-allow-list = ["cv2"]
-generated-members = ["cv2.*"]
-[tool.pylint.logging]
-disable = [
-    "missing-module-docstring",
-    "missing-class-docstring",
-    "missing-function-docstring"
-]
+[tool.ruff.lint]
+select = ["ALL"]
+ignore = ["ANN", "D", "ERA", "PLR"]
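With `version` declared dynamic and the `[tool.setuptools_scm]` table enabled, the package version is now derived from git tags at build time instead of the hard-coded `0.0.99`. A minimal sketch of checking the derived version locally (assumes `setuptools-scm` is installed; this snippet is not part of the diff):

```py
# Hypothetical local check, mirroring what `python -m build` will embed.
from setuptools_scm import get_version

# Reads the nearest git tag (e.g. "0.1.0") plus any dev/dirty suffix.
print(get_version())
```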

View File

@ -1,3 +1,2 @@
-black==23.7.0
-isort==5.12.0
-pre-commit==3.3.3
+ruff
+pre-commit

View File

@ -1,3 +0,0 @@
from .crop import *
from .device import *
from .utils import *

View File

@ -1,15 +0,0 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
@dataclass
class B30OcrResultItem:
    rating_class: int
    score: int
    pure: Optional[int] = None
    far: Optional[int] = None
    lost: Optional[int] = None
    date: Optional[datetime] = None
    title: Optional[str] = None
    song_id: Optional[str] = None

View File

@ -1,11 +1,12 @@
+from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Callable, List
+from typing import TYPE_CHECKING, Callable
import cv2
from arcaea_offline_ocr.core import hashers
-from arcaea_offline_ocr.providers import ImageCategory
from arcaea_offline_ocr.providers.ihdb import (
    PROP_KEY_BUILT_AT,
    PROP_KEY_HASH_SIZE,
@ -17,6 +18,7 @@ from arcaea_offline_ocr.providers.ihdb import (
if TYPE_CHECKING:
    from sqlite3 import Connection
+    from arcaea_offline_ocr.providers import ImageCategory
    from arcaea_offline_ocr.types import Mat
@ -29,7 +31,7 @@ class ImageHashDatabaseBuildTask:
    image_path: str
    image_id: str
    category: ImageCategory
-    imread_function: Callable[[str], "Mat"] = _default_imread_gray
+    imread_function: Callable[[str], Mat] = _default_imread_gray
@dataclass
@ -42,7 +44,7 @@ class _ImageHash:
class ImageHashesDatabaseBuilder:
    @staticmethod
-    def __insert_property(conn: "Connection", key: str, value: str):
+    def __insert_property(conn: Connection, key: str, value: str):
        return conn.execute(
            "INSERT INTO properties (key, value) VALUES (?, ?)",
            (key, value),
@ -51,13 +53,13 @@ class ImageHashesDatabaseBuilder:
    @classmethod
    def build(
        cls,
-        conn: "Connection",
-        tasks: List[ImageHashDatabaseBuildTask],
+        conn: Connection,
+        tasks: list[ImageHashDatabaseBuildTask],
        *,
        hash_size: int = 16,
        high_freq_factor: int = 4,
    ):
-        hashes: List[_ImageHash] = []
+        hashes: list[_ImageHash] = []
        for task in tasks:
            img_gray = task.imread_function(task.image_path)
@ -82,7 +84,7 @@ class ImageHashesDatabaseBuilder:
                        image_hash_type=hash_type,
                        category=task.category,
                        hash=ImageHashDatabaseIdProvider.hash_mat_to_bytes(hash_mat),
-                    )
+                    ),
                )
        conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
@ -92,7 +94,7 @@ class ImageHashesDatabaseBuilder:
            `category` INTEGER,
            `hash_type` INTEGER,
            `hash` BLOB
-            )"""
+            )""",
        )
        now = datetime.now(tz=timezone.utc)
@ -103,7 +105,8 @@ class ImageHashesDatabaseBuilder:
        cls.__insert_property(conn, PROP_KEY_BUILT_AT, str(timestamp))
        conn.executemany(
-            "INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`) VALUES (?, ?, ?, ?)",
+            """INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`)
+            VALUES (?, ?, ?, ?)""",
            [
                (it.image_id, it.category.value, it.image_hash_type.value, it.hash)
                for it in hashes

View File

@ -23,7 +23,7 @@ def difference(img_gray: Mat, hash_size: int) -> Mat:
def dct(img_gray: Mat, hash_size: int = 16, high_freq_factor: int = 4) -> Mat:
-    # TODO: consistency?
+    # TODO: consistency?  # noqa: FIX002, TD002, TD003
    img_size_base = hash_size * high_freq_factor
    img_size = (img_size_base, img_size_base)

View File

@ -1,29 +1,32 @@
+from __future__ import annotations
import math
-from typing import Tuple
+from typing import TYPE_CHECKING
import cv2
import numpy as np
-from .types import Mat
+if TYPE_CHECKING:
+    from .types import Mat
-__all__ = ["crop_xywh", "CropBlackEdges"]
+__all__ = ["CropBlackEdges", "crop_xywh"]
-def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]):
+def crop_xywh(mat: Mat, rect: tuple[int, int, int, int]):
    x, y, w, h = rect
    return mat[y : y + h, x : x + w]
class CropBlackEdges:
    @staticmethod
-    def is_black_edge(__img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6):
-        pixels_compared = __img_gray_slice < black_pixel
+    def is_black_edge(img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6):
+        pixels_compared = img_gray_slice < black_pixel
        return np.count_nonzero(pixels_compared) > math.floor(
-            __img_gray_slice.size * ratio
+            img_gray_slice.size * ratio,
        )
    @classmethod
-    def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25):
+    def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25):  # noqa: C901
        height, width = img_gray.shape[:2]
        left = 0
        right = width
@ -54,13 +57,22 @@ class CropBlackEdges:
                break
            bottom -= 1
-        assert right > left, "cropped width < 0"
-        assert bottom > top, "cropped height < 0"
+        if right <= left:
+            msg = "cropped width < 0"
+            raise ValueError(msg)
+        if bottom <= top:
+            msg = "cropped height < 0"
+            raise ValueError(msg)
        return (left, top, right - left, bottom - top)
    @classmethod
    def crop(
-        cls, img: Mat, convert_flag: cv2.COLOR_BGR2GRAY, black_threshold: int = 25
+        cls,
+        img: Mat,
+        convert_flag: cv2.COLOR_BGR2GRAY,
+        black_threshold: int = 25,
    ) -> Mat:
        rect = cls.get_crop_rect(cv2.cvtColor(img, convert_flag), black_threshold)
        return crop_xywh(img, rect)
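A quick usage sketch of the updated `CropBlackEdges` API shown above (the screenshot path is a placeholder):

```py
import cv2

from arcaea_offline_ocr.crop import CropBlackEdges

img = cv2.imread("/path/to/screenshot.jpg", cv2.IMREAD_COLOR)
# Now raises ValueError (instead of failing an assert) when nothing is left to crop.
cropped = CropBlackEdges.crop(img, cv2.COLOR_BGR2GRAY, black_threshold=25)
```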

View File

@ -1,2 +0,0 @@
from .common import DeviceOcrResult
from .ocr import DeviceOcr

View File

@ -1,17 +0,0 @@
from dataclasses import dataclass
from typing import Optional
@dataclass
class DeviceOcrResult:
    rating_class: int
    score: int
    pure: Optional[int] = None
    far: Optional[int] = None
    lost: Optional[int] = None
    max_recall: Optional[int] = None
    song_id: Optional[str] = None
    song_id_possibility: Optional[float] = None
    clear_status: Optional[int] = None
    partner_id: Optional[str] = None
    partner_id_possibility: Optional[float] = None

View File

@ -1,3 +0,0 @@
from .definition import *
from .extractor import *
from .masker import *

View File

@ -1,2 +0,0 @@
from .auto import *
from .common import DeviceRois

View File

@ -1,15 +0,0 @@
from typing import Tuple
Rect = Tuple[int, int, int, int]
class DeviceRois:
    pure: Rect
    far: Rect
    lost: Rect
    score: Rect
    rating_class: Rect
    max_recall: Rect
    jacket: Rect
    clear_status: Rect
    partner_icon: Rect

View File

@ -1 +0,0 @@
from .common import DeviceRoisExtractor

View File

@ -1,48 +0,0 @@
from ....crop import crop_xywh
from ....types import Mat
from ..definition.common import DeviceRois
class DeviceRoisExtractor:
def __init__(self, img: Mat, rois: DeviceRois):
self.img = img
self.sizes = rois
def __construct_int_rect(self, rect):
return tuple(round(r) for r in rect)
@property
def pure(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.pure))
@property
def far(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.far))
@property
def lost(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.lost))
@property
def score(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.score))
@property
def jacket(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.jacket))
@property
def rating_class(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.rating_class))
@property
def max_recall(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.max_recall))
@property
def clear_status(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.clear_status))
@property
def partner_icon(self):
return crop_xywh(self.img, self.__construct_int_rect(self.sizes.partner_icon))

View File

@ -1,2 +0,0 @@
from .auto import *
from .common import DeviceRoisMasker

View File

@ -1,59 +0,0 @@
from ....types import Mat
class DeviceRoisMasker:
@classmethod
def pure(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def far(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def lost(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def score(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_pst(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_prs(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_ftr(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_byd(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def rating_class_etr(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def max_recall(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()

View File

@ -1,119 +0,0 @@
import sqlite3
from typing import List, Union
import cv2
import numpy as np
from .types import Mat
def phash_opencv(img_gray, hash_size=8, highfreq_factor=4):
# type: (Union[Mat, np.ndarray], int, int) -> np.ndarray
"""
Perceptual Hash computation.
Implementation follows
http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
Adapted from `imagehash.phash`, pure opencv implementation
The result is slightly different from `imagehash.phash`.
"""
if hash_size < 2:
raise ValueError("Hash size must be greater than or equal to 2")
img_size = hash_size * highfreq_factor
image = cv2.resize(img_gray, (img_size, img_size), interpolation=cv2.INTER_LANCZOS4)
image = np.float32(image)
dct = cv2.dct(image)
dctlowfreq = dct[:hash_size, :hash_size]
med = np.median(dctlowfreq)
diff = dctlowfreq > med
return diff
def hamming_distance_sql_function(user_input, db_entry) -> int:
return np.count_nonzero(
np.frombuffer(user_input, bool) ^ np.frombuffer(db_entry, bool)
)
class ImagePhashDatabase:
def __init__(self, db_path: str):
with sqlite3.connect(db_path) as conn:
self.hash_size = int(
conn.execute(
"SELECT value FROM properties WHERE key = 'hash_size'"
).fetchone()[0]
)
self.highfreq_factor = int(
conn.execute(
"SELECT value FROM properties WHERE key = 'highfreq_factor'"
).fetchone()[0]
)
self.built_timestamp = int(
conn.execute(
"SELECT value FROM properties WHERE key = 'built_timestamp'"
).fetchone()[0]
)
self.ids: List[str] = [
i[0] for i in conn.execute("SELECT id FROM hashes").fetchall()
]
self.hashes_byte = [
i[0] for i in conn.execute("SELECT hash FROM hashes").fetchall()
]
self.hashes = [np.frombuffer(hb, bool) for hb in self.hashes_byte]
self.jacket_ids: List[str] = []
self.jacket_hashes = []
self.partner_icon_ids: List[str] = []
self.partner_icon_hashes = []
for _id, _hash in zip(self.ids, self.hashes):
id_splitted = _id.split("||")
if len(id_splitted) > 1 and id_splitted[0] == "partner_icon":
self.partner_icon_ids.append(id_splitted[1])
self.partner_icon_hashes.append(_hash)
else:
self.jacket_ids.append(_id)
self.jacket_hashes.append(_hash)
def calculate_phash(self, img_gray: Mat):
return phash_opencv(
img_gray, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor
)
def lookup_hash(self, image_hash: np.ndarray, *, limit: int = 5):
image_hash = image_hash.flatten()
xor_results = [
(id, np.count_nonzero(image_hash ^ h))
for id, h in zip(self.ids, self.hashes)
]
return sorted(xor_results, key=lambda r: r[1])[:limit]
def lookup_image(self, img_gray: Mat):
image_hash = self.calculate_phash(img_gray)
return self.lookup_hash(image_hash)[0]
def lookup_jackets(self, img_gray: Mat, *, limit: int = 5):
image_hash = self.calculate_phash(img_gray).flatten()
xor_results = [
(id, np.count_nonzero(image_hash ^ h))
for id, h in zip(self.jacket_ids, self.jacket_hashes)
]
return sorted(xor_results, key=lambda r: r[1])[:limit]
def lookup_jacket(self, img_gray: Mat):
return self.lookup_jackets(img_gray)[0]
def lookup_partner_icons(self, img_gray: Mat, *, limit: int = 5):
image_hash = self.calculate_phash(img_gray).flatten()
xor_results = [
(id, np.count_nonzero(image_hash ^ h))
for id, h in zip(self.partner_icon_ids, self.partner_icon_hashes)
]
return sorted(xor_results, key=lambda r: r[1])[:limit]
def lookup_partner_icon(self, img_gray: Mat):
return self.lookup_partner_icons(img_gray)[0]

View File

@ -5,8 +5,8 @@ from .knn import OcrKNearestTextProvider
__all__ = [
    "ImageCategory",
    "ImageHashDatabaseIdProvider",
-    "OcrKNearestTextProvider",
    "ImageIdProvider",
-    "OcrTextProvider",
    "ImageIdProviderResult",
+    "OcrKNearestTextProvider",
+    "OcrTextProvider",
]

View File

@ -1,17 +1,19 @@
+from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
-from typing import TYPE_CHECKING, Any, Sequence, Optional
+from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING:
-    from ..types import Mat
+    from arcaea_offline_ocr.types import Mat
class OcrTextProvider(ABC):
    @abstractmethod
-    def result_raw(self, img: "Mat", /, *args, **kwargs) -> Any: ...
+    def result_raw(self, img: Mat, /, *args, **kwargs) -> Any: ...
    @abstractmethod
-    def result(self, img: "Mat", /, *args, **kwargs) -> Optional[str]: ...
+    def result(self, img: Mat, /, *args, **kwargs) -> str | None: ...
class ImageCategory(IntEnum):
@ -29,10 +31,20 @@ class ImageIdProviderResult:
class ImageIdProvider(ABC):
    @abstractmethod
    def result(
-        self, img: "Mat", category: ImageCategory, /, *args, **kwargs
+        self,
+        img: Mat,
+        category: ImageCategory,
+        /,
+        *args,
+        **kwargs,
    ) -> ImageIdProviderResult: ...
    @abstractmethod
    def results(
-        self, img: "Mat", category: ImageCategory, /, *args, **kwargs
+        self,
+        img: Mat,
+        category: ImageCategory,
+        /,
+        *args,
+        **kwargs,
    ) -> Sequence[ImageIdProviderResult]: ...

View File

@ -1,14 +1,17 @@
-import sqlite3
+from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import IntEnum
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, TypeVar
from arcaea_offline_ocr.core import hashers
from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult
if TYPE_CHECKING:
+    import sqlite3
    from arcaea_offline_ocr.types import Mat
@ -19,9 +22,11 @@ PROP_KEY_BUILT_AT = "built_at"
def _sql_hamming_distance(hash1: bytes, hash2: bytes):
-    assert len(hash1) == len(hash2), "hash size does not match!"
-    count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
-    return count
+    if len(hash1) != len(hash2):
+        msg = "hash size does not match!"
+        raise ValueError(msg)
+    return sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
class ImageHashType(IntEnum):
@ -36,7 +41,7 @@ class ImageHashDatabaseIdProviderResult(ImageIdProviderResult):
class MissingPropertiesError(Exception):
-    keys: List[str]
+    keys: list[str]
    def __init__(self, keys, *args):
        super().__init__(*args)
@ -72,7 +77,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
        return self.properties[PROP_KEY_HIGH_FREQ_FACTOR]
    @property
-    def built_at(self) -> Optional[datetime]:
+    def built_at(self) -> datetime | None:
        return self.properties.get(PROP_KEY_BUILT_AT)
    @property
@ -80,7 +85,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
        return self._hash_length
    def _initialize(self):
-        def get_property(key, converter: Callable[[Any], T]) -> Optional[T]:
+        def get_property(key, converter: Callable[[Any], T]) -> T | None:
            result = self.conn.execute(
                "SELECT value FROM properties WHERE key = ?",
                (key,),
@ -97,7 +102,8 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
            PROP_KEY_HASH_SIZE: lambda x: int(x),
            PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x),
            PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp(
-                int(ts) / 1000, tz=timezone.utc
+                int(ts) / 1000,
+                tz=timezone.utc,
            ),
        }
        required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR]
@ -122,8 +128,11 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
        self._hash_length = self.hash_size**2
    def lookup_hash(
-        self, category: ImageCategory, hash_type: ImageHashType, hash: bytes
-    ) -> List[ImageHashDatabaseIdProviderResult]:
+        self,
+        category: ImageCategory,
+        hash_type: ImageHashType,
+        hash_data: bytes,
+    ) -> list[ImageHashDatabaseIdProviderResult]:
        cursor = self.conn.execute(
            """
SELECT
@ -132,7 +141,7 @@
FROM hashes
WHERE category = ? AND hash_type = ?
ORDER BY distance ASC LIMIT 10""",
-            (hash, category.value, hash_type.value),
+            (hash_data, category.value, hash_type.value),
        )
        results = []
@ -143,52 +152,52 @@ ORDER BY distance ASC LIMIT 10""",
                    category=category,
                    confidence=(self.hash_length - distance) / self.hash_length,
                    image_hash_type=hash_type,
-                )
+                ),
            )
        return results
    @staticmethod
-    def hash_mat_to_bytes(hash: "Mat") -> bytes:
-        return bytes([255 if b else 0 for b in hash.flatten()])
+    def hash_mat_to_bytes(hash_mat: Mat) -> bytes:
+        return bytes([255 if b else 0 for b in hash_mat.flatten()])
-    def results(self, img: "Mat", category: ImageCategory, /):
-        results: List[ImageHashDatabaseIdProviderResult] = []
+    def results(self, img: Mat, category: ImageCategory, /):
+        results: list[ImageHashDatabaseIdProviderResult] = []
        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.AVERAGE,
                self.hash_mat_to_bytes(hashers.average(img, self.hash_size)),
-            )
+            ),
        )
        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.DIFFERENCE,
                self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)),
-            )
+            ),
        )
        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.DCT,
                self.hash_mat_to_bytes(
-                    hashers.dct(img, self.hash_size, self.high_freq_factor)
+                    hashers.dct(img, self.hash_size, self.high_freq_factor),
                ),
-            )
+            ),
        )
        return results
    def result(
        self,
-        img: "Mat",
+        img: Mat,
        category: ImageCategory,
        /,
        *,
        hash_type: ImageHashType = ImageHashType.DCT,
    ):
-        return [
+        return next(
            it for it in self.results(img, category) if it.image_hash_type == hash_type
-        ][0]
+        )

View File

@ -1,17 +1,20 @@
+from __future__ import annotations
import logging
import math
-from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Callable, Sequence
import cv2
import numpy as np
-from ..crop import crop_xywh
+from arcaea_offline_ocr.crop import crop_xywh
from .base import OcrTextProvider
if TYPE_CHECKING:
    from cv2.ml import KNearest
-    from ..types import Mat
+    from arcaea_offline_ocr.types import Mat
logger = logging.getLogger(__name__)
@ -19,10 +22,10 @@ logger = logging.getLogger(__name__)
class FixRects:
    @staticmethod
    def connect_broken(
-        rects: Sequence[Tuple[int, int, int, int]],
+        rects: Sequence[tuple[int, int, int, int]],
        img_width: int,
        img_height: int,
-        tolerance: Optional[int] = None,
+        tolerance: int | None = None,
    ):
        # for a "broken" digit, please refer to
        # /assets/fix_rects/broken_masked.jpg
@ -69,8 +72,8 @@ class FixRects:
    @staticmethod
    def split_connected(
-        img_masked: "Mat",
-        rects: Sequence[Tuple[int, int, int, int]],
+        img_masked: Mat,
+        rects: Sequence[tuple[int, int, int, int]],
        rect_wh_ratio: float = 1.05,
        width_range_ratio: float = 0.1,
    ):
@ -111,7 +114,7 @@ class FixRects:
            # split the rect
            new_rects.extend(
-                [(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)]
+                [(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)],
            )
        return_rects = [r for r in rects if r not in connected_rects]
@ -119,7 +122,7 @@ class FixRects:
        return return_rects
-def resize_fill_square(img: "Mat", target: int = 20):
+def resize_fill_square(img: Mat, target: int = 20):
    h, w = img.shape[:2]
    if h > w:
        new_h = target
@ -132,11 +135,21 @@ def resize_fill_square(img: "Mat", target: int = 20):
    border_size = math.ceil((max(new_w, new_h) - min(new_w, new_h)) / 2)
    if new_w < new_h:
        resized = cv2.copyMakeBorder(
-            resized, 0, 0, border_size, border_size, cv2.BORDER_CONSTANT
+            resized,
+            0,
+            0,
+            border_size,
+            border_size,
+            cv2.BORDER_CONSTANT,
        )
    else:
        resized = cv2.copyMakeBorder(
-            resized, border_size, border_size, 0, 0, cv2.BORDER_CONSTANT
+            resized,
+            border_size,
+            border_size,
+            0,
+            0,
+            cv2.BORDER_CONSTANT,
        )
    return cv2.resize(resized, (target, target))
@ -151,8 +164,8 @@ def preprocess_hog(digit_rois):
    return np.float32(samples)
-def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4):
-    _, results, _, _ = knn_model.findNearest(__samples, k)
+def ocr_digit_samples_knn(samples, knn_model: cv2.ml.KNearest, k: int = 4):
+    _, results, _, _ = knn_model.findNearest(samples, k)
    return [int(r) for r in results.ravel()]
@ -160,11 +173,15 @@ class OcrKNearestTextProvider(OcrTextProvider):
    _ContourFilter = Callable[["Mat"], bool]
    _RectsFilter = Callable[[Sequence[int]], bool]
-    def __init__(self, model: "KNearest"):
+    def __init__(self, model: KNearest):
        self.model = model
    def contours(
-        self, img: "Mat", /, *, contours_filter: Optional[_ContourFilter] = None
+        self,
+        img: Mat,
+        /,
+        *,
+        contours_filter: _ContourFilter | None = None,
    ):
        cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        if contours_filter:
@ -174,12 +191,12 @@ class OcrKNearestTextProvider(OcrTextProvider):
    def result_raw(
        self,
-        img: "Mat",
+        img: Mat,
        /,
        *,
        fix_rects: bool = True,
-        contours_filter: Optional[_ContourFilter] = None,
-        rects_filter: Optional[_RectsFilter] = None,
+        contours_filter: _ContourFilter | None = None,
+        rects_filter: _RectsFilter | None = None,
    ):
        """
        :param img: grayscaled roi
@ -192,11 +209,11 @@ class OcrKNearestTextProvider(OcrTextProvider):
        rects = [cv2.boundingRect(cnt) for cnt in cnts]
        if fix_rects and rects_filter:
-            rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])  # type: ignore
+            rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])
            rects = list(filter(rects_filter, rects))
            rects = FixRects.split_connected(img, rects)
        elif fix_rects:
-            rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])  # type: ignore
+            rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])
            rects = FixRects.split_connected(img, rects)
        elif rects_filter:
            rects = list(filter(rects_filter, rects))
@ -216,12 +233,12 @@ class OcrKNearestTextProvider(OcrTextProvider):
    def result(
        self,
-        img: "Mat",
+        img: Mat,
        /,
        *,
        fix_rects: bool = True,
-        contours_filter: Optional[_ContourFilter] = None,
-        rects_filter: Optional[_RectsFilter] = None,
+        contours_filter: _ContourFilter | None = None,
+        rects_filter: _RectsFilter | None = None,
    ):
        """
        :param img: grayscaled roi

View File

@ -0,0 +1,3 @@
from .chieri import ChieriBotV4Best30Scenario
__all__ = ["ChieriBotV4Best30Scenario"]

View File

@ -0,0 +1,24 @@
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING
from arcaea_offline_ocr.scenarios.base import OcrScenario, OcrScenarioResult
if TYPE_CHECKING:
    from arcaea_offline_ocr.types import Mat
class Best30Scenario(OcrScenario):
    @abstractmethod
    def components(self, img: Mat, /) -> list[Mat]: ...
    @abstractmethod
    def result(self, component_img: Mat, /, *args, **kwargs) -> OcrScenarioResult: ...
    @abstractmethod
    def results(self, img: Mat, /, *args, **kwargs) -> list[OcrScenarioResult]:
        """
        Commonly a shorthand for `[self.result(comp) for comp in self.components(img)]`
        """
        ...

View File

@ -0,0 +1,3 @@
from .v4 import ChieriBotV4Best30Scenario
__all__ = ["ChieriBotV4Best30Scenario"]

View File

@ -0,0 +1,3 @@
from .impl import ChieriBotV4Best30Scenario
__all__ = ["ChieriBotV4Best30Scenario"]

View File

@ -1,19 +1,19 @@
import numpy as np
__all__ = [
-    "FONT_THRESHOLD",
-    "PURE_BG_MIN_HSV",
-    "PURE_BG_MAX_HSV",
-    "FAR_BG_MIN_HSV",
-    "FAR_BG_MAX_HSV",
-    "LOST_BG_MIN_HSV",
-    "LOST_BG_MAX_HSV",
-    "BYD_MIN_HSV",
    "BYD_MAX_HSV",
-    "FTR_MIN_HSV",
+    "BYD_MIN_HSV",
+    "FAR_BG_MAX_HSV",
+    "FAR_BG_MIN_HSV",
+    "FONT_THRESHOLD",
    "FTR_MAX_HSV",
-    "PRS_MIN_HSV",
+    "FTR_MIN_HSV",
+    "LOST_BG_MAX_HSV",
+    "LOST_BG_MIN_HSV",
    "PRS_MAX_HSV",
+    "PRS_MIN_HSV",
+    "PURE_BG_MAX_HSV",
+    "PURE_BG_MIN_HSV",
]
FONT_THRESHOLD = 160
@ -27,11 +27,11 @@ FAR_BG_MAX_HSV = np.array([20, 255, 255], np.uint8)
LOST_BG_MIN_HSV = np.array([115, 60, 150], np.uint8)
LOST_BG_MAX_HSV = np.array([140, 255, 255], np.uint8)
-BYD_MIN_HSV = (158, 120, 0)
-BYD_MAX_HSV = (172, 255, 255)
-FTR_MIN_HSV = (145, 70, 0)
-FTR_MAX_HSV = (160, 255, 255)
-PRS_MIN_HSV = (45, 60, 0)
-PRS_MAX_HSV = (70, 255, 255)
+BYD_MIN_HSV = np.array([158, 120, 0], np.uint8)
+BYD_MAX_HSV = np.array([172, 255, 255], np.uint8)
+FTR_MIN_HSV = np.array([145, 70, 0], np.uint8)
+FTR_MAX_HSV = np.array([160, 255, 255], np.uint8)
+PRS_MIN_HSV = np.array([45, 60, 0], np.uint8)
+PRS_MAX_HSV = np.array([70, 255, 255], np.uint8)

View File

@ -1,12 +1,22 @@
-from typing import List, Optional, Tuple
+from __future__ import annotations
+from typing import TYPE_CHECKING
import cv2
import numpy as np
-from ....crop import crop_xywh
-from ....phash_db import ImagePhashDatabase
-from ....types import Mat
-from ...shared import B30OcrResultItem
+from arcaea_offline_ocr.crop import crop_xywh
+from arcaea_offline_ocr.providers import (
+    ImageCategory,
+    ImageIdProvider,
+    OcrKNearestTextProvider,
+)
+from arcaea_offline_ocr.scenarios.b30.base import Best30Scenario
+from arcaea_offline_ocr.scenarios.base import OcrScenarioResult
+if TYPE_CHECKING:
+    from arcaea_offline_ocr.types import Mat
from .colors import (
    BYD_MAX_HSV,
    BYD_MIN_HSV,
@ -22,29 +32,20 @@ from .colors import (
    PURE_BG_MIN_HSV,
)
from .rois import ChieriBotV4Rois
-from ....providers.knn import OcrKNearestTextProvider
-class ChieriBotV4Ocr:
+class ChieriBotV4Best30Scenario(Best30Scenario):
    def __init__(
        self,
        score_knn_provider: OcrKNearestTextProvider,
        pfl_knn_provider: OcrKNearestTextProvider,
-        phash_db: ImagePhashDatabase,
+        image_id_provider: ImageIdProvider,
        factor: float = 1.0,
    ):
-        self.__phash_db = phash_db
        self.__rois = ChieriBotV4Rois(factor)
        self.pfl_knn_provider = pfl_knn_provider
        self.score_knn_provider = score_knn_provider
+        self.image_id_provider = image_id_provider
-    @property
-    def phash_db(self):
-        return self.__phash_db
-    @phash_db.setter
-    def phash_db(self, phash_db: ImagePhashDatabase):
-        self.__phash_db = phash_db
    @property
    def rois(self):
@ -74,30 +75,36 @@ class ChieriBotV4Ocr:
        rating_class_results = [np.count_nonzero(m) for m in rating_class_masks]
        if max(rating_class_results) < 70:
            return 0
-        else:
-            return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1
+        return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1
-    def ocr_component_song_id(self, component_bgr: Mat):
+    def ocr_component_song_id_results(self, component_bgr: Mat):
        jacket_rect = self.rois.component_rois.jacket_rect.floored()
        jacket_roi = cv2.cvtColor(
-            crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY
+            crop_xywh(component_bgr, jacket_rect),
+            cv2.COLOR_BGR2GRAY,
        )
-        return self.phash_db.lookup_jacket(jacket_roi)[0]
+        return self.image_id_provider.results(jacket_roi, ImageCategory.JACKET)
    def ocr_component_score_knn(self, component_bgr: Mat) -> int:
        # sourcery skip: inline-immediately-returned-variable
        score_rect = self.rois.component_rois.score_rect.rounded()
        score_roi = cv2.cvtColor(
-            crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY
+            crop_xywh(component_bgr, score_rect),
+            cv2.COLOR_BGR2GRAY,
        )
        _, score_roi = cv2.threshold(
-            score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
+            score_roi,
+            0,
+            255,
+            cv2.THRESH_BINARY + cv2.THRESH_OTSU,
        )
        if score_roi[1][1] == 255:
            score_roi = 255 - score_roi
        contours, _ = cv2.findContours(
-            score_roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+            score_roi,
+            cv2.RETR_EXTERNAL,
+            cv2.CHAIN_APPROX_SIMPLE,
        )
        for contour in contours:
            rect = cv2.boundingRect(contour)
@ -109,8 +116,9 @@ class ChieriBotV4Ocr:
        return int(ocr_result) if ocr_result else 0
    def find_pfl_rects(
-        self, component_pfl_processed: Mat
-    ) -> List[Tuple[int, int, int, int]]:
+        self,
+        component_pfl_processed: Mat,
+    ) -> list[tuple[int, int, int, int]]:
        # sourcery skip: inline-immediately-returned-variable
        pfl_roi_find = cv2.morphologyEx(
            component_pfl_processed,
@ -118,14 +126,16 @@ class ChieriBotV4Ocr:
            cv2.getStructuringElement(cv2.MORPH_RECT, [10, 1]),
        )
        pfl_contours, _ = cv2.findContours(
-            pfl_roi_find, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
+            pfl_roi_find,
+            cv2.RETR_EXTERNAL,
+            cv2.CHAIN_APPROX_NONE,
        )
        pfl_rects = [cv2.boundingRect(c) for c in pfl_contours]
        pfl_rects = [
            r for r in pfl_rects if r[3] > component_pfl_processed.shape[0] * 0.1
        ]
        pfl_rects = sorted(pfl_rects, key=lambda r: r[1])
-        pfl_rects_adjusted = [
+        return [
            (
                max(rect[0] - 2, 0),
                rect[1],
@ -134,7 +144,6 @@ class ChieriBotV4Ocr:
            )
            for rect in pfl_rects
        ]
-        return pfl_rects_adjusted
    def preprocess_component_pfl(self, component_bgr: Mat) -> Mat:
        pfl_rect = self.rois.component_rois.pfl_rect.rounded()
@ -157,11 +166,17 @@ class ChieriBotV4Ocr:
        pfl_roi_blurred = cv2.GaussianBlur(pfl_roi, (5, 5), 0)
        # pfl_roi_blurred = cv2.medianBlur(pfl_roi, 3)
        _, pfl_roi_blurred_threshold = cv2.threshold(
-            pfl_roi_blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
+            pfl_roi_blurred,
+            0,
+            255,
+            cv2.THRESH_BINARY + cv2.THRESH_OTSU,
        )
        # and a threshold of the original roi
        _, pfl_roi_threshold = cv2.threshold(
-            pfl_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
+            pfl_roi,
+            0,
+            255,
+            cv2.THRESH_BINARY + cv2.THRESH_OTSU,
        )
        # turn thresholds into black background
        if pfl_roi_blurred_threshold[2][2] == 255:
@ -171,13 +186,15 @@ class ChieriBotV4Ocr:
        # return a bitwise_and result
        result = cv2.bitwise_and(pfl_roi_blurred_threshold, pfl_roi_threshold)
        result_eroded = cv2.erode(
-            result, cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2))
+            result,
+            cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2)),
        )
        return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result
    def ocr_component_pfl(
-        self, component_bgr: Mat
-    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+        self,
+        component_bgr: Mat,
+    ) -> tuple[int | None, int | None, int | None]:
        try:
            pfl_roi = self.preprocess_component_pfl(component_bgr)
            pfl_rects = self.find_pfl_rects(pfl_roi)
@ -188,31 +205,39 @@ class ChieriBotV4Ocr:
                pure_far_lost.append(int(result) if result else None)
            return tuple(pure_far_lost)
-        except Exception:
+        except Exception:  # noqa: BLE001
            return (None, None, None)
-    def ocr_component(self, component_bgr: Mat) -> B30OcrResultItem:
+    def ocr_component(self, component_bgr: Mat) -> OcrScenarioResult:
        component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0)
        rating_class = self.ocr_component_rating_class(component_blur)
-        song_id = self.ocr_component_song_id(component_bgr)
-        # title = self.ocr_component_title(component_blur)
+        song_id_results = self.ocr_component_song_id_results(component_bgr)
        # score = self.ocr_component_score(component_blur)
        score = self.ocr_component_score_knn(component_bgr)
        pure, far, lost = self.ocr_component_pfl(component_bgr)
-        return B30OcrResultItem(
-            song_id=song_id,
+        return OcrScenarioResult(
+            song_id=song_id_results[0].image_id,
+            song_id_results=song_id_results,
            rating_class=rating_class,
-            # title=title,
            score=score,
            pure=pure,
            far=far,
            lost=lost,
-            date=None,
+            played_at=None,
        )
-    def ocr(self, img_bgr: Mat) -> List[B30OcrResultItem]:
-        self.set_factor(img_bgr)
-        return [
-            self.ocr_component(component_bgr)
-            for component_bgr in self.rois.components(img_bgr)
-        ]
+    def components(self, img: Mat, /):
+        """
+        :param img: BGR format image
+        """
+        self.set_factor(img)
+        return self.rois.components(img)
+    def result(self, component_img: Mat, /):
+        return self.ocr_component(component_img)
+    def results(self, img: Mat, /) -> list[OcrScenarioResult]:
+        """
+        :param img: BGR format image
+        """
+        return [self.ocr_component(component) for component in self.components(img)]

View File

@ -1,8 +1,7 @@
-from typing import List
+from __future__ import annotations
-from ....crop import crop_xywh
-from ....types import Mat, XYWHRect
-from ....utils import apply_factor
+from arcaea_offline_ocr.crop import crop_xywh
+from arcaea_offline_ocr.types import Mat, XYWHRect
class ChieriBotV4ComponentRois:
@ -19,39 +18,39 @@ class ChieriBotV4ComponentRois:
    @property
    def top_font_color_detect(self):
-        return apply_factor(XYWHRect(35, 10, 120, 100), self.factor)
+        return XYWHRect(35, 10, 120, 100), self.factor
    @property
    def bottom_font_color_detect(self):
-        return apply_factor(XYWHRect(30, 125, 175, 110), self.factor)
+        return XYWHRect(30, 125, 175, 110) * self.factor
    @property
    def bg_point(self):
-        return apply_factor((75, 10), self.factor)
+        return (75 * self.factor, 10 * self.factor)
    @property
    def rating_class_rect(self):
-        return apply_factor(XYWHRect(21, 40, 7, 20), self.factor)
+        return XYWHRect(21, 40, 7, 20) * self.factor
    @property
    def title_rect(self):
-        return apply_factor(XYWHRect(35, 10, 430, 50), self.factor)
+        return XYWHRect(35, 10, 430, 50) * self.factor
    @property
    def jacket_rect(self):
-        return apply_factor(XYWHRect(263, 0, 239, 239), self.factor)
+        return XYWHRect(263, 0, 239, 239) * self.factor
    @property
    def score_rect(self):
-        return apply_factor(XYWHRect(30, 60, 270, 55), self.factor)
+        return XYWHRect(30, 60, 270, 55) * self.factor
    @property
    def pfl_rect(self):
-        return apply_factor(XYWHRect(50, 125, 80, 100), self.factor)
+        return XYWHRect(50, 125, 80, 100) * self.factor
    @property
    def date_rect(self):
-        return apply_factor(XYWHRect(205, 200, 225, 25), self.factor)
+        return XYWHRect(205, 200, 225, 25) * self.factor
class ChieriBotV4Rois:
@ -74,27 +73,27 @@ class ChieriBotV4Rois:
    @property
    def top(self):
-        return apply_factor(823, self.factor)
+        return 823 * self.factor
    @property
    def left(self):
-        return apply_factor(107, self.factor)
+        return 107 * self.factor
    @property
    def width(self):
-        return apply_factor(502, self.factor)
+        return 502 * self.factor
    @property
    def height(self):
-        return apply_factor(240, self.factor)
+        return 240 * self.factor
    @property
    def vertical_gap(self):
-        return apply_factor(74, self.factor)
+        return 74 * self.factor
    @property
    def horizontal_gap(self):
-        return apply_factor(40, self.factor)
+        return 40 * self.factor
    @property
    def horizontal_items(self):
@ -104,9 +103,9 @@ class ChieriBotV4Rois:
    @property
    def b33_vertical_gap(self):
-        return apply_factor(121, self.factor)
+        return 121 * self.factor
-    def components(self, img_bgr: Mat) -> List[Mat]:
+    def components(self, img_bgr: Mat) -> list[Mat]:
        first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height)
        results = []
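The `apply_factor` helper is gone; per the `feat: XYWHRect __mul__` commit above, ROI rects are now scaled by multiplying the rect itself. A sketch of the assumed semantics (inferred from this diff, not from the `XYWHRect` source):

```py
from arcaea_offline_ocr.types import XYWHRect

rect = XYWHRect(x=21, y=40, w=7, h=20)
scaled = rect * 1.5      # presumably scales x, y, w and h by the factor
print(scaled.rounded())  # extractor code above calls .rounded() / .floored() on the result
```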

View File

@ -0,0 +1,42 @@
from __future__ import annotations
from abc import ABC
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Sequence
if TYPE_CHECKING:
    from datetime import datetime
    from arcaea_offline_ocr.providers import ImageIdProviderResult
@dataclass(kw_only=True)
class OcrScenarioResult:
    song_id: str
    rating_class: int
    score: int
    song_id_results: Sequence[ImageIdProviderResult] = field(default_factory=list)
    partner_id_results: Sequence[ImageIdProviderResult] = field(
        default_factory=list,
    )
    pure: int | None = None
    pure_inaccurate: int | None = None
    pure_early: int | None = None
    pure_late: int | None = None
    far: int | None = None
    far_inaccurate: int | None = None
    far_early: int | None = None
    far_late: int | None = None
    lost: int | None = None
    played_at: datetime | None = None
    max_recall: int | None = None
    clear_status: int | None = None
    clear_type: int | None = None
    modifier: int | None = None
class OcrScenario(ABC):  # noqa: B024
    pass

View File

@ -0,0 +1,13 @@
from .extractor import DeviceRoisExtractor
from .impl import DeviceScenario
from .masker import DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .rois import DeviceRoisAutoT1, DeviceRoisAutoT2
__all__ = [
    "DeviceRoisAutoT1",
    "DeviceRoisAutoT2",
    "DeviceRoisExtractor",
    "DeviceRoisMaskerAutoT1",
    "DeviceRoisMaskerAutoT2",
    "DeviceScenario",
]

View File

@ -0,0 +1,8 @@
from abc import abstractmethod
from arcaea_offline_ocr.scenarios.base import OcrScenario, OcrScenarioResult
class DeviceScenarioBase(OcrScenario):
    @abstractmethod
    def result(self) -> OcrScenarioResult: ...

View File

@ -0,0 +1,3 @@
from .base import DeviceRoisExtractor
__all__ = ["DeviceRoisExtractor"]

View File

@ -0,0 +1,45 @@
from arcaea_offline_ocr.crop import crop_xywh
from arcaea_offline_ocr.scenarios.device.rois import DeviceRois
from arcaea_offline_ocr.types import Mat
class DeviceRoisExtractor:
    def __init__(self, img: Mat, rois: DeviceRois):
        self.img = img
        self.sizes = rois
    @property
    def pure(self):
        return crop_xywh(self.img, self.sizes.pure.rounded())
    @property
    def far(self):
        return crop_xywh(self.img, self.sizes.far.rounded())
    @property
    def lost(self):
        return crop_xywh(self.img, self.sizes.lost.rounded())
    @property
    def score(self):
        return crop_xywh(self.img, self.sizes.score.rounded())
    @property
    def jacket(self):
        return crop_xywh(self.img, self.sizes.jacket.rounded())
    @property
    def rating_class(self):
        return crop_xywh(self.img, self.sizes.rating_class.rounded())
    @property
    def max_recall(self):
        return crop_xywh(self.img, self.sizes.max_recall.rounded())
    @property
    def clear_status(self):
        return crop_xywh(self.img, self.sizes.clear_status.rounded())
    @property
    def partner_icon(self):
        return crop_xywh(self.img, self.sizes.partner_icon.rounded())

View File

@ -1,26 +1,31 @@
 import cv2
 import numpy as np

-from ..phash_db import ImagePhashDatabase
-from ..providers.knn import OcrKNearestTextProvider
-from ..types import Mat
-from .common import DeviceOcrResult
-from .rois.extractor import DeviceRoisExtractor
-from .rois.masker import DeviceRoisMasker
+from arcaea_offline_ocr.providers import (
+    ImageCategory,
+    ImageIdProvider,
+    OcrKNearestTextProvider,
+)
+from arcaea_offline_ocr.scenarios.base import OcrScenarioResult
+from arcaea_offline_ocr.types import Mat
+
+from .base import DeviceScenarioBase
+from .extractor import DeviceRoisExtractor
+from .masker import DeviceRoisMasker


-class DeviceOcr:
+class DeviceScenario(DeviceScenarioBase):
     def __init__(
         self,
         extractor: DeviceRoisExtractor,
         masker: DeviceRoisMasker,
         knn_provider: OcrKNearestTextProvider,
-        phash_db: ImagePhashDatabase,
+        image_id_provider: ImageIdProvider,
     ):
         self.extractor = extractor
         self.masker = masker
         self.knn_provider = knn_provider
-        self.phash_db = phash_db
+        self.image_id_provider = image_id_provider

     def pfl(self, roi_gray: Mat, factor: float = 1.25):
         def contour_filter(cnt):
@ -28,7 +33,8 @@ class DeviceOcr:

         contours = self.knn_provider.contours(roi_gray)
         contours_filtered = self.knn_provider.contours(
-            roi_gray, contours_filter=contour_filter
+            roi_gray,
+            contours_filter=contour_filter,
         )

         roi_ocr = roi_gray.copy()
@ -79,7 +85,7 @@ class DeviceOcr:

     def max_recall(self):
         ocr_result = self.knn_provider.result(
-            self.masker.max_recall(self.extractor.max_recall)
+            self.masker.max_recall(self.extractor.max_recall),
         )
         return int(ocr_result) if ocr_result else None
@ -93,20 +99,18 @@ class DeviceOcr:
         ]
         return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]

-    def lookup_song_id(self):
-        return self.phash_db.lookup_jacket(
-            cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY)
+    def song_id_results(self):
+        return self.image_id_provider.results(
+            cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY),
+            ImageCategory.JACKET,
         )

-    def song_id(self):
-        return self.lookup_song_id()[0]
-
     @staticmethod
     def preprocess_char_icon(img_gray: Mat):
         h, w = img_gray.shape[:2]
         img = cv2.copyMakeBorder(img_gray, max(w - h, 0), 0, 0, 0, cv2.BORDER_REPLICATE)
         h, w = img.shape[:2]
-        img = cv2.fillPoly(
+        return cv2.fillPoly(
             img,
             [
                 np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32),
@ -114,21 +118,18 @@ class DeviceOcr:
                 np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32),
                 np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32),
             ],
-            (128),
+            (128,),
         )
-        return img

-    def lookup_partner_id(self):
-        return self.phash_db.lookup_partner_icon(
+    def partner_id_results(self):
+        return self.image_id_provider.results(
             self.preprocess_char_icon(
-                cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY)
-            )
+                cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY),
+            ),
+            ImageCategory.PARTNER_ICON,
         )

-    def partner_id(self):
-        return self.lookup_partner_id()[0]
-
-    def ocr(self) -> DeviceOcrResult:
+    def result(self):
         rating_class = self.rating_class()
         pure = self.pure()
         far = self.far()
@ -137,20 +138,18 @@ class DeviceOcr:
         max_recall = self.max_recall()
         clear_status = self.clear_status()

-        hash_len = self.phash_db.hash_size**2
-        song_id, song_id_distance = self.lookup_song_id()
-        partner_id, partner_id_distance = self.lookup_partner_id()
+        song_id_results = self.song_id_results()
+        partner_id_results = self.partner_id_results()

-        return DeviceOcrResult(
+        return OcrScenarioResult(
+            song_id=song_id_results[0].image_id,
+            song_id_results=song_id_results,
             rating_class=rating_class,
             pure=pure,
             far=far,
             lost=lost,
             score=score,
             max_recall=max_recall,
-            song_id=song_id,
-            song_id_possibility=1 - song_id_distance / hash_len,
+            partner_id_results=partner_id_results,
             clear_status=clear_status,
-            partner_id=partner_id,
-            partner_id_possibility=1 - partner_id_distance / hash_len,
         )
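
For orientation, a hypothetical wiring of the refactored scenario follows. Only the DeviceScenario constructor parameters and the result() call are taken from this diff; the ROI/masker construction and the concrete providers are assumptions, since OcrKNearestTextProvider and the ImageIdProvider implementations are not defined in this changeset.

# Hypothetical wiring sketch; only DeviceScenario's __init__ parameters and
# result() come from the diff above, the rest is assumed.
import cv2

from arcaea_offline_ocr.providers import ImageIdProvider, OcrKNearestTextProvider
from arcaea_offline_ocr.scenarios.device import (
    DeviceRoisAutoT2,
    DeviceRoisExtractor,
    DeviceRoisMaskerAutoT2,
    DeviceScenario,
)

img = cv2.imread("screenshot.jpg")          # BGR result-screen screenshot
h, w = img.shape[:2]

rois = DeviceRoisAutoT2(w, h)               # assumed constructor: screenshot width/height
extractor = DeviceRoisExtractor(img, rois)  # matches extractor/base.py above
masker = DeviceRoisMaskerAutoT2()           # assumed no-argument constructor

knn_provider = OcrKNearestTextProvider(...)     # construction not shown in this changeset
image_id_provider: ImageIdProvider = ...        # e.g. a phash-backed implementation

scenario = DeviceScenario(
    extractor=extractor,
    masker=masker,
    knn_provider=knn_provider,
    image_id_provider=image_id_provider,    # replaces the old phash_db argument
)
result = scenario.result()                  # OcrScenarioResult instead of DeviceOcrResult
print(result.score, result.song_id)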

View File

@ -0,0 +1,9 @@
from .auto import DeviceRoisMaskerAuto, DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .base import DeviceRoisMasker

__all__ = [
    "DeviceRoisMasker",
    "DeviceRoisMaskerAuto",
    "DeviceRoisMaskerAutoT1",
    "DeviceRoisMaskerAutoT2",
]

View File

@ -1,17 +1,18 @@
 import cv2
 import numpy as np

-from ....types import Mat
-from .common import DeviceRoisMasker
+from arcaea_offline_ocr.types import Mat
+from .base import DeviceRoisMasker


 class DeviceRoisMaskerAuto(DeviceRoisMasker):
-    # pylint: disable=abstract-method
     @staticmethod
     def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat):
         return cv2.inRange(
-            cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), hsv_lower, hsv_upper
+            cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV),
+            hsv_lower,
+            hsv_upper,
         )
@ -101,25 +102,33 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto):

     @classmethod
     def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.TRACK_LOST_HSV_MIN, cls.TRACK_LOST_HSV_MAX
+            roi_bgr,
+            cls.TRACK_LOST_HSV_MIN,
+            cls.TRACK_LOST_HSV_MAX,
         )

     @classmethod
     def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.TRACK_COMPLETE_HSV_MIN, cls.TRACK_COMPLETE_HSV_MAX
+            roi_bgr,
+            cls.TRACK_COMPLETE_HSV_MIN,
+            cls.TRACK_COMPLETE_HSV_MAX,
         )

     @classmethod
     def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.FULL_RECALL_HSV_MIN, cls.FULL_RECALL_HSV_MAX
+            roi_bgr,
+            cls.FULL_RECALL_HSV_MIN,
+            cls.FULL_RECALL_HSV_MAX,
         )

     @classmethod
     def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.PURE_MEMORY_HSV_MIN, cls.PURE_MEMORY_HSV_MAX
+            roi_bgr,
+            cls.PURE_MEMORY_HSV_MIN,
+            cls.PURE_MEMORY_HSV_MAX,
         )
@ -203,29 +212,39 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto):
     @classmethod
     def max_recall(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.MAX_RECALL_HSV_MIN, cls.MAX_RECALL_HSV_MAX
+            roi_bgr,
+            cls.MAX_RECALL_HSV_MIN,
+            cls.MAX_RECALL_HSV_MAX,
         )

     @classmethod
     def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.TRACK_LOST_HSV_MIN, cls.TRACK_LOST_HSV_MAX
+            roi_bgr,
+            cls.TRACK_LOST_HSV_MIN,
+            cls.TRACK_LOST_HSV_MAX,
         )

     @classmethod
     def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.TRACK_COMPLETE_HSV_MIN, cls.TRACK_COMPLETE_HSV_MAX
+            roi_bgr,
+            cls.TRACK_COMPLETE_HSV_MIN,
+            cls.TRACK_COMPLETE_HSV_MAX,
         )

     @classmethod
     def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.FULL_RECALL_HSV_MIN, cls.FULL_RECALL_HSV_MAX
+            roi_bgr,
+            cls.FULL_RECALL_HSV_MIN,
+            cls.FULL_RECALL_HSV_MAX,
         )

     @classmethod
     def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
         return cls.mask_bgr_in_hsv(
-            roi_bgr, cls.PURE_MEMORY_HSV_MIN, cls.PURE_MEMORY_HSV_MAX
+            roi_bgr,
+            cls.PURE_MEMORY_HSV_MIN,
+            cls.PURE_MEMORY_HSV_MAX,
        )
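
All of these maskers reduce to the same call shown in mask_bgr_in_hsv: convert the BGR crop to HSV and keep the pixels between a lower and an upper bound. A standalone sketch follows, with placeholder bounds rather than the *_HSV_MIN / *_HSV_MAX constants of this module.

# Standalone equivalent of DeviceRoisMaskerAuto.mask_bgr_in_hsv; the HSV bounds
# below are placeholders, not constants from this changeset.
import cv2
import numpy as np


def mask_bgr_in_hsv(roi_bgr, hsv_lower, hsv_upper):
    # Pixels inside [hsv_lower, hsv_upper] become 255 in the returned
    # single-channel mask, everything else 0.
    return cv2.inRange(cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), hsv_lower, hsv_upper)


roi_bgr = np.zeros((40, 120, 3), dtype=np.uint8)   # stand-in for an extractor crop
lower = np.array([25, 130, 130], dtype=np.uint8)   # placeholder lower bound
upper = np.array([35, 255, 255], dtype=np.uint8)   # placeholder upper bound
mask = mask_bgr_in_hsv(roi_bgr, lower, upper)      # fed to the KNN OCR step downstream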

View File

@ -0,0 +1,61 @@
from abc import ABC, abstractmethod

from arcaea_offline_ocr.types import Mat


class DeviceRoisMasker(ABC):
    @classmethod
    @abstractmethod
    def pure(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def far(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def lost(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def score(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_pst(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_prs(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_byd(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_etr(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def max_recall(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: ...

View File

@ -0,0 +1,9 @@
from .auto import DeviceRoisAuto, DeviceRoisAutoT1, DeviceRoisAutoT2
from .base import DeviceRois

__all__ = [
    "DeviceRois",
    "DeviceRoisAuto",
    "DeviceRoisAutoT1",
    "DeviceRoisAutoT2",
]

View File

@ -1,6 +1,6 @@
-from .common import DeviceRois
+from arcaea_offline_ocr.types import XYWHRect

-__all__ = ["DeviceRoisAuto", "DeviceRoisAutoT1", "DeviceRoisAutoT2"]
+from .base import DeviceRois


 class DeviceRoisAuto(DeviceRois):
@ -50,7 +50,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):

     @property
     def pure(self):
-        return (
+        return XYWHRect(
             self.pfl_x,
             self.layout_area_h_mid + 110 * self.factor,
             self.pfl_w,
@ -59,7 +59,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):

     @property
     def far(self):
-        return (
+        return XYWHRect(
             self.pfl_x,
             self.pure[1] + self.pure[3] + 12 * self.factor,
             self.pfl_w,
@ -68,7 +68,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):

     @property
     def lost(self):
-        return (
+        return XYWHRect(
             self.pfl_x,
             self.far[1] + self.far[3] + 10 * self.factor,
             self.pfl_w,
@ -79,7 +79,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
     def score(self):
         w = 280 * self.factor
         h = 45 * self.factor
-        return (
+        return XYWHRect(
             self.w_mid - w / 2,
             self.layout_area_h_mid - 75 * self.factor - h,
             w,
@ -88,7 +88,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):

     @property
     def rating_class(self):
-        return (
+        return XYWHRect(
             self.w_mid - 610 * self.factor,
             self.layout_area_h_mid - 180 * self.factor,
             265 * self.factor,
@ -97,7 +97,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):

     @property
     def max_recall(self):
-        return (
+        return XYWHRect(
             self.w_mid - 465 * self.factor,
             self.layout_area_h_mid - 215 * self.factor,
             150 * self.factor,
@ -106,7 +106,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):

     @property
     def jacket(self):
-        return (
+        return XYWHRect(
             self.w_mid - 610 * self.factor,
             self.layout_area_h_mid - 143 * self.factor,
             375 * self.factor,
@ -117,7 +117,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
     def clear_status(self):
         w = 550 * self.factor
         h = 60 * self.factor
-        return (
+        return XYWHRect(
             self.w_mid - w / 2,
             self.layout_area_h_mid - 155 * self.factor - h,
             w * 0.4,
@ -128,7 +128,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
     def partner_icon(self):
         w = 90 * self.factor
         h = 75 * self.factor
-        return (self.w_mid - w / 2, 0, w, h)
+        return XYWHRect(self.w_mid - w / 2, 0, w, h)


 class DeviceRoisAutoT2(DeviceRoisAuto):
@ -174,7 +174,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):

     @property
     def pure(self):
-        return (
+        return XYWHRect(
             self.pfl_x,
             self.layout_area_h_mid + 175 * self.factor,
             self.pfl_w,
@ -183,7 +183,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):

     @property
     def far(self):
-        return (
+        return XYWHRect(
             self.pfl_x,
             self.pure[1] + self.pure[3] + 30 * self.factor,
             self.pfl_w,
@ -192,7 +192,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):

     @property
     def lost(self):
-        return (
+        return XYWHRect(
             self.pfl_x,
             self.far[1] + self.far[3] + 35 * self.factor,
             self.pfl_w,
@ -203,7 +203,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
     def score(self):
         w = 420 * self.factor
         h = 70 * self.factor
-        return (
+        return XYWHRect(
             self.w_mid - w / 2,
             self.layout_area_h_mid - 110 * self.factor - h,
             w,
@ -212,7 +212,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):

     @property
     def rating_class(self):
-        return (
+        return XYWHRect(
             max(0, self.w_mid - 965 * self.factor),
             self.layout_area_h_mid - 330 * self.factor,
             350 * self.factor,
@ -221,7 +221,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):

     @property
     def max_recall(self):
-        return (
+        return XYWHRect(
             self.w_mid - 625 * self.factor,
             self.layout_area_h_mid - 275 * self.factor,
             150 * self.factor,
@ -230,7 +230,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):

     @property
     def jacket(self):
-        return (
+        return XYWHRect(
             self.w_mid - 915 * self.factor,
             self.layout_area_h_mid - 215 * self.factor,
             565 * self.factor,
@ -241,7 +241,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
     def clear_status(self):
         w = 825 * self.factor
         h = 90 * self.factor
-        return (
+        return XYWHRect(
             self.w_mid - w / 2,
             self.layout_area_h_mid - 235 * self.factor - h,
             w * 0.4,
@ -252,4 +252,4 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
     def partner_icon(self):
         w = 135 * self.factor
         h = 110 * self.factor
-        return (self.w_mid - w / 2, 0, w, h)
+        return XYWHRect(self.w_mid - w / 2, 0, w, h)

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod

from arcaea_offline_ocr.types import XYWHRect


class DeviceRois(ABC):
    @property
    @abstractmethod
    def pure(self) -> XYWHRect: ...
    @property
    @abstractmethod
    def far(self) -> XYWHRect: ...
    @property
    @abstractmethod
    def lost(self) -> XYWHRect: ...
    @property
    @abstractmethod
    def score(self) -> XYWHRect: ...
    @property
    @abstractmethod
    def rating_class(self) -> XYWHRect: ...
    @property
    @abstractmethod
    def max_recall(self) -> XYWHRect: ...
    @property
    @abstractmethod
    def jacket(self) -> XYWHRect: ...
    @property
    @abstractmethod
    def clear_status(self) -> XYWHRect: ...
    @property
    @abstractmethod
    def partner_icon(self) -> XYWHRect: ...
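
Since DeviceRois is an ABC over XYWHRect-valued properties, a custom layout can be plugged into DeviceRoisExtractor by overriding those nine members. A minimal sketch follows; the pixel rectangles are made-up placeholders, not values taken from this changeset.

# Hand-measured ROI layout for a fixed-resolution screenshot; every rectangle
# below is a placeholder, not something defined in this diff.
import cv2

from arcaea_offline_ocr.scenarios.device import DeviceRoisExtractor
from arcaea_offline_ocr.scenarios.device.rois import DeviceRois
from arcaea_offline_ocr.types import XYWHRect


class FixedDeviceRois(DeviceRois):
    # Class attributes satisfy the abstract properties for this simple case.
    pure = XYWHRect(585, 615, 115, 35)
    far = XYWHRect(585, 660, 115, 35)
    lost = XYWHRect(585, 705, 115, 35)
    score = XYWHRect(500, 420, 280, 45)
    rating_class = XYWHRect(60, 340, 265, 35)
    max_recall = XYWHRect(200, 305, 150, 35)
    jacket = XYWHRect(60, 380, 375, 375)
    clear_status = XYWHRect(400, 250, 220, 60)
    partner_icon = XYWHRect(595, 0, 90, 75)


img = cv2.imread("screenshot.jpg")
extractor = DeviceRoisExtractor(img, FixedDeviceRois())
score_crop = extractor.score   # crop_xywh(img, rois.score.rounded()) under the hood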

View File

@ -25,12 +25,18 @@ class XYWHRect(NamedTuple):
     def __add__(self, other):
         if not isinstance(other, (list, tuple)) or len(other) != 4:
-            raise TypeError()
+            raise TypeError

         return self.__class__(*[a + b for a, b in zip(self, other)])

     def __sub__(self, other):
         if not isinstance(other, (list, tuple)) or len(other) != 4:
-            raise TypeError()
+            raise TypeError

         return self.__class__(*[a - b for a, b in zip(self, other)])
+
+    def __mul__(self, other):
+        if not isinstance(other, (int, float)):
+            raise TypeError
+
+        return self.__class__(*[v * other for v in self])
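
In short, the rect supports element-wise offsetting by a 4-item list or tuple and, with the new __mul__, scaling by a scalar, which is what lets the utils.apply_factor helper removed in the next file go away. The numbers in the sketch below are arbitrary.

# Quick illustration of the XYWHRect operators touched in this hunk.
from arcaea_offline_ocr.types import XYWHRect

rect = XYWHRect(10, 20, 200, 50)

moved = rect + (5, 5, 0, 0)    # __add__ / __sub__ require a 4-item list or tuple
scaled = rect * 1.5            # new __mul__: scalar only -> XYWHRect(15.0, 30.0, 300.0, 75.0)
# rect * (2, 2, 2, 2) raises TypeError, as does rect + 5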

View File

@ -1,11 +1,6 @@
-from collections.abc import Iterable
-from typing import TypeVar, overload
-
 import cv2
 import numpy as np

-from .types import XYWHRect
-
 __all__ = ["imread_unicode"]
@ -13,27 +8,3 @@ def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED):
     # https://stackoverflow.com/a/57872297/16484891
     # CC BY-SA 4.0
     return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags)
-
-
-@overload
-def apply_factor(item: int, factor: float) -> float: ...
-
-
-@overload
-def apply_factor(item: float, factor: float) -> float: ...
-
-
-T = TypeVar("T", bound=Iterable)
-
-
-@overload
-def apply_factor(item: T, factor: float) -> T: ...
-
-
-def apply_factor(item, factor: float):
-    if isinstance(item, (int, float)):
-        return item * factor
-    if isinstance(item, XYWHRect):
-        return item.__class__(*[i * factor for i in item])
-    if isinstance(item, Iterable):
-        return item.__class__([i * factor for i in item])