1 Commits

Author SHA1 Message Date
5215218526 ci: tag regex 2025-06-25 23:54:27 +08:00
54 changed files with 725 additions and 1376 deletions

View File

@ -0,0 +1,41 @@
name: "Build and draft a release"
on:
workflow_dispatch:
push:
tags:
# regex taken from
# https://packaging.python.org/en/latest/specifications/version-specifiers/#appendix-parsing-version-strings-with-regular-expressions
- '^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$'
permissions:
contents: write
discussions: write
jobs:
build-and-draft-release:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python environment
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build package
run: |
pip install build
python -m build
- name: Draft a release
uses: softprops/action-gh-release@v2
with:
discussion_category_name: New releases
draft: true
generate_release_notes: true
files: |
dist/*

View File

@ -1,103 +0,0 @@
name: Build, Release, Publish
on:
workflow_dispatch:
push:
tags:
- "*.*.*"
jobs:
build:
name: Build package
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python environment
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build package
run: |
pip install build
python -m build
- name: Store the distribution files
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: dist/
draft-release:
name: Draft a release
runs-on: ubuntu-latest
needs:
- build
permissions:
contents: write
discussions: write
steps:
- name: Download the distribution files
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Draft a release
uses: softprops/action-gh-release@v2
with:
discussion_category_name: New releases
draft: true
generate_release_notes: true
files: |
dist/*
publish-to-pypi:
name: Publish distribution to PyPI
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
needs:
- build
environment:
name: pypi
url: https://pypi.org/p/arcaea-offline-ocr
permissions:
id-token: write
steps:
- name: Download the distribution files
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
publish-to-testpypi:
name: Publish distribution to TestPyPI
runs-on: ubuntu-latest
needs:
- build
environment:
name: testpypi
url: https://test.pypi.org/p/arcaea-offline-ocr
permissions:
id-token: write
steps:
- name: Download the distribution files
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/

View File

@ -4,10 +4,11 @@ repos:
hooks: hooks:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/psf/black
- repo: https://github.com/astral-sh/ruff-pre-commit rev: 23.1.0
rev: v0.11.13
hooks: hooks:
- id: ruff - id: black
args: ["--fix"] - repo: https://github.com/PyCQA/isort
- id: ruff-format rev: 5.12.0
hooks:
- id: isort

156
README.md
View File

@ -1,152 +1,32 @@
# Arcaea Offline OCR # Arcaea Offline OCR
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
## Example ## Example
> Results from `arcaea_offline_ocr 0.1.0a2`
### Build an image hash database (ihdb)
```py ```py
import sqlite3
from pathlib import Path
import cv2 import cv2
from arcaea_offline_ocr.device.ocr import DeviceOcr
from arcaea_offline_ocr.device.rois.definition import DeviceRoisAutoT2
from arcaea_offline_ocr.device.rois.extractor import DeviceRoisExtractor
from arcaea_offline_ocr.device.rois.masker import DeviceRoisMaskerAutoT2
from arcaea_offline_ocr.phash_db import ImagePhashDatabase
from arcaea_offline_ocr.builders.ihdb import ( img_path = "/path/to/opencv/supported/image/formats.jpg"
ImageHashDatabaseBuildTask, img = cv2.imread(img_path, cv2.IMREAD_COLOR)
ImageHashesDatabaseBuilder,
)
from arcaea_offline_ocr.providers import ImageCategory, ImageHashDatabaseIdProvider
from arcaea_offline_ocr.scenarios.device import DeviceScenario
def build(): rois = DeviceRoisAutoT2(img.shape[1], img.shape[0])
def _read_partner_icon(image_path: str): extractor = DeviceRoisExtractor(img, rois)
return DeviceScenario.preprocess_char_icon( masker = DeviceRoisMaskerAutoT2()
cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY),
)
builder = ImageHashesDatabaseBuilder() knn_model = cv2.ml.KNearest.load("/path/to/trained/knn/model.dat")
tasks = [ phash_db = ImagePhashDatabase("/path/to/image/phash/database.db")
ImageHashDatabaseBuildTask(
image_path=str(file),
image_id=file.stem,
category=ImageCategory.JACKET,
)
for file in Path("/path/to/some/jackets").glob("*.jpg")
]
tasks.extend( ocr = DeviceOcr(extractor, masker, knn_model, phash_db)
[ print(ocr.ocr())
ImageHashDatabaseBuildTask(
image_path=str(file),
image_id=file.stem,
category=ImageCategory.PARTNER_ICON,
imread_function=_read_partner_icon,
)
for file in Path("/path/to/some/partner_icons").glob("*.png")
],
)
with sqlite3.connect("/path/to/ihdb-X.Y.Z.db") as conn:
builder.build(conn, tasks)
``` ```
### Device OCR ```sh
$ python example.py
```py DeviceOcrResult(rating_class=2, pure=1135, far=11, lost=0, score=9953016, max_recall=1146, song_id='ringedgenesis', song_id_possibility=0.953125, clear_status=2, partner_id='45', partner_id_possibility=0.8046875)
import json
import sqlite3
from dataclasses import asdict
import cv2
from arcaea_offline_ocr.providers import (
ImageHashDatabaseIdProvider,
OcrKNearestTextProvider,
)
from arcaea_offline_ocr.scenarios.device import (
DeviceRoisAutoT2,
DeviceRoisExtractor,
DeviceRoisMaskerAutoT2,
DeviceScenario,
)
with sqlite3.connect("/path/to/ihdb-X.Y.Z.db") as conn:
img = cv2.imread("/path/to/your/screenshot.jpg")
h, w = img.shape[:2]
r = DeviceRoisAutoT2(w, h)
m = DeviceRoisMaskerAutoT2()
e = DeviceRoisExtractor(img, r)
scenario = DeviceScenario(
extractor=e,
masker=m,
knn_provider=OcrKNearestTextProvider(
cv2.ml.KNearest.load("/path/to/knn_model.dat"),
),
image_id_provider=ImageHashDatabaseIdProvider(conn),
)
result = scenario.result()
with open("result.jsonc", "w", encoding="utf-8") as jf:
json.dump(asdict(result), jf, indent=2, ensure_ascii=False)
```
```jsonc
// result.json
{
"song_id": "vector",
"rating_class": 1,
"score": 9990996,
"song_id_results": [
{
"image_id": "vector",
"category": 0,
"confidence": 1.0,
"image_hash_type": 0
},
{
"image_id": "clotho",
"category": 0,
"confidence": 0.71875,
"image_hash_type": 0
}
// 28 more results omitted…
],
"partner_id_results": [
{
"image_id": "23",
"category": 1,
"confidence": 0.90625,
"image_hash_type": 0
},
{
"image_id": "45",
"category": 1,
"confidence": 0.8828125,
"image_hash_type": 0
}
// 28 more results omitted…
],
"pure": 1000,
"pure_inaccurate": null,
"pure_early": null,
"pure_late": null,
"far": 2,
"far_inaccurate": null,
"far_early": null,
"far_late": null,
"lost": 0,
"played_at": null,
"max_recall": 1002,
"clear_status": 2,
"clear_type": null,
"modifier": null
}
``` ```
## License ## License
@ -168,4 +48,4 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
## Credits ## Credits
- [JohannesBuchner/imagehash](https://github.com/JohannesBuchner/imagehash): `arcaea_offline_ocr.core.hashers` implementations reference [283375/image-phash-database](https://github.com/283375/image-phash-database)

View File

@ -1,33 +1,38 @@
[build-system] [build-system]
requires = ["setuptools>=64", "setuptools-scm>=8"] requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
dynamic = ["version"]
name = "arcaea-offline-ocr" name = "arcaea-offline-ocr"
version = "0.0.99"
authors = [{ name = "283375", email = "log_283375@163.com" }] authors = [{ name = "283375", email = "log_283375@163.com" }]
description = "Extract your Arcaea play result from screenshot." description = "Extract your Arcaea play result from screenshot."
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"
dependencies = ["numpy~=2.3", "opencv-python~=4.11"] dependencies = ["attrs==23.1.0", "numpy==1.26.1", "opencv-python==4.8.1.78"]
classifiers = [ classifiers = [
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
] ]
[project.optional-dependencies]
dev = ["ruff", "pre-commit"]
[project.urls] [project.urls]
"Homepage" = "https://github.com/ArcaeaOffline/core-ocr" "Homepage" = "https://github.com/ArcaeaOffline/core-ocr"
"Bug Tracker" = "https://github.com/ArcaeaOffline/core-ocr/issues" "Bug Tracker" = "https://github.com/ArcaeaOffline/core-ocr/issues"
[tool.setuptools_scm] [tool.isort]
profile = "black"
src_paths = ["src/arcaea_offline_ocr"]
[tool.pyright] [tool.pyright]
ignore = ["**/__debug*.*"] ignore = ["**/__debug*.*"]
[tool.ruff.lint] [tool.pylint.main]
select = ["ALL"] # extension-pkg-allow-list = ["cv2"]
ignore = ["ANN", "D", "ERA", "PLR"] generated-members = ["cv2.*"]
[tool.pylint.logging]
disable = [
"missing-module-docstring",
"missing-class-docstring",
"missing-function-docstring"
]

View File

@ -1,2 +1,3 @@
ruff black==23.7.0
pre-commit isort==5.12.0
pre-commit==3.3.3

View File

@ -1,2 +1,3 @@
numpy~=2.3 attrs==23.1.0
opencv-python~=4.11 numpy==1.26.1
opencv-python==4.8.1.78

View File

@ -0,0 +1,4 @@
from .crop import *
from .device import *
from .ocr import *
from .utils import *

View File

@ -1,19 +1,19 @@
import numpy as np import numpy as np
__all__ = [ __all__ = [
"BYD_MAX_HSV",
"BYD_MIN_HSV",
"FAR_BG_MAX_HSV",
"FAR_BG_MIN_HSV",
"FONT_THRESHOLD", "FONT_THRESHOLD",
"FTR_MAX_HSV",
"FTR_MIN_HSV",
"LOST_BG_MAX_HSV",
"LOST_BG_MIN_HSV",
"PRS_MAX_HSV",
"PRS_MIN_HSV",
"PURE_BG_MAX_HSV",
"PURE_BG_MIN_HSV", "PURE_BG_MIN_HSV",
"PURE_BG_MAX_HSV",
"FAR_BG_MIN_HSV",
"FAR_BG_MAX_HSV",
"LOST_BG_MIN_HSV",
"LOST_BG_MAX_HSV",
"BYD_MIN_HSV",
"BYD_MAX_HSV",
"FTR_MIN_HSV",
"FTR_MAX_HSV",
"PRS_MIN_HSV",
"PRS_MAX_HSV",
] ]
FONT_THRESHOLD = 160 FONT_THRESHOLD = 160
@ -27,11 +27,11 @@ FAR_BG_MAX_HSV = np.array([20, 255, 255], np.uint8)
LOST_BG_MIN_HSV = np.array([115, 60, 150], np.uint8) LOST_BG_MIN_HSV = np.array([115, 60, 150], np.uint8)
LOST_BG_MAX_HSV = np.array([140, 255, 255], np.uint8) LOST_BG_MAX_HSV = np.array([140, 255, 255], np.uint8)
BYD_MIN_HSV = np.array([158, 120, 0], np.uint8) BYD_MIN_HSV = (158, 120, 0)
BYD_MAX_HSV = np.array([172, 255, 255], np.uint8) BYD_MAX_HSV = (172, 255, 255)
FTR_MIN_HSV = np.array([145, 70, 0], np.uint8) FTR_MIN_HSV = (145, 70, 0)
FTR_MAX_HSV = np.array([160, 255, 255], np.uint8) FTR_MAX_HSV = (160, 255, 255)
PRS_MIN_HSV = np.array([45, 60, 0], np.uint8) PRS_MIN_HSV = (45, 60, 0)
PRS_MAX_HSV = np.array([70, 255, 255], np.uint8) PRS_MAX_HSV = (70, 255, 255)

View File

@ -1,51 +1,60 @@
from __future__ import annotations from math import floor
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING
import cv2 import cv2
import numpy as np import numpy as np
from arcaea_offline_ocr.crop import crop_xywh from ....crop import crop_xywh
from arcaea_offline_ocr.providers import ( from ....ocr import (
ImageCategory, FixRects,
ImageIdProvider, ocr_digits_by_contour_knn,
OcrKNearestTextProvider, preprocess_hog,
) resize_fill_square,
from arcaea_offline_ocr.scenarios.b30.base import Best30Scenario
from arcaea_offline_ocr.scenarios.base import OcrScenarioResult
if TYPE_CHECKING:
from arcaea_offline_ocr.types import Mat
from .colors import (
BYD_MAX_HSV,
BYD_MIN_HSV,
FAR_BG_MAX_HSV,
FAR_BG_MIN_HSV,
FTR_MAX_HSV,
FTR_MIN_HSV,
LOST_BG_MAX_HSV,
LOST_BG_MIN_HSV,
PRS_MAX_HSV,
PRS_MIN_HSV,
PURE_BG_MAX_HSV,
PURE_BG_MIN_HSV,
) )
from ....phash_db import ImagePhashDatabase
from ....types import Mat
from ....utils import construct_int_xywh_rect
from ...shared import B30OcrResultItem
from .colors import *
from .rois import ChieriBotV4Rois from .rois import ChieriBotV4Rois
class ChieriBotV4Best30Scenario(Best30Scenario): class ChieriBotV4Ocr:
def __init__( def __init__(
self, self,
score_knn_provider: OcrKNearestTextProvider, score_knn: cv2.ml.KNearest,
pfl_knn_provider: OcrKNearestTextProvider, pfl_knn: cv2.ml.KNearest,
image_id_provider: ImageIdProvider, phash_db: ImagePhashDatabase,
factor: float = 1.0, factor: Optional[float] = 1.0,
): ):
self.__score_knn = score_knn
self.__pfl_knn = pfl_knn
self.__phash_db = phash_db
self.__rois = ChieriBotV4Rois(factor) self.__rois = ChieriBotV4Rois(factor)
self.pfl_knn_provider = pfl_knn_provider
self.score_knn_provider = score_knn_provider @property
self.image_id_provider = image_id_provider def score_knn(self):
return self.__score_knn
@score_knn.setter
def score_knn(self, knn_digits_model: cv2.ml.KNearest):
self.__score_knn = knn_digits_model
@property
def pfl_knn(self):
return self.__pfl_knn
@pfl_knn.setter
def pfl_knn(self, knn_digits_model: cv2.ml.KNearest):
self.__pfl_knn = knn_digits_model
@property
def phash_db(self):
return self.__phash_db
@phash_db.setter
def phash_db(self, phash_db: ImagePhashDatabase):
self.__phash_db = phash_db
@property @property
def rois(self): def rois(self):
@ -63,8 +72,9 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
self.factor = img.shape[0] / 4400 self.factor = img.shape[0] / 4400
def ocr_component_rating_class(self, component_bgr: Mat) -> int: def ocr_component_rating_class(self, component_bgr: Mat) -> int:
rating_class_rect = self.rois.component_rois.rating_class_rect.rounded() rating_class_rect = construct_int_xywh_rect(
self.rois.component_rois.rating_class_rect
)
rating_class_roi = crop_xywh(component_bgr, rating_class_rect) rating_class_roi = crop_xywh(component_bgr, rating_class_rect)
rating_class_roi = cv2.cvtColor(rating_class_roi, cv2.COLOR_BGR2HSV) rating_class_roi = cv2.cvtColor(rating_class_roi, cv2.COLOR_BGR2HSV)
rating_class_masks = [ rating_class_masks = [
@ -75,50 +85,41 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
rating_class_results = [np.count_nonzero(m) for m in rating_class_masks] rating_class_results = [np.count_nonzero(m) for m in rating_class_masks]
if max(rating_class_results) < 70: if max(rating_class_results) < 70:
return 0 return 0
else:
return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1
def ocr_component_song_id_results(self, component_bgr: Mat): def ocr_component_song_id(self, component_bgr: Mat):
jacket_rect = self.rois.component_rois.jacket_rect.floored() jacket_rect = construct_int_xywh_rect(
jacket_roi = cv2.cvtColor( self.rois.component_rois.jacket_rect, floor
crop_xywh(component_bgr, jacket_rect),
cv2.COLOR_BGR2GRAY,
) )
return self.image_id_provider.results(jacket_roi, ImageCategory.JACKET) jacket_roi = cv2.cvtColor(
crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY
)
return self.phash_db.lookup_jacket(jacket_roi)[0]
def ocr_component_score_knn(self, component_bgr: Mat) -> int: def ocr_component_score_knn(self, component_bgr: Mat) -> int:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
score_rect = self.rois.component_rois.score_rect.rounded() score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect)
score_roi = cv2.cvtColor( score_roi = cv2.cvtColor(
crop_xywh(component_bgr, score_rect), crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY
cv2.COLOR_BGR2GRAY,
) )
_, score_roi = cv2.threshold( _, score_roi = cv2.threshold(
score_roi, score_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
0,
255,
cv2.THRESH_BINARY + cv2.THRESH_OTSU,
) )
if score_roi[1][1] == 255: if score_roi[1][1] == 255:
score_roi = 255 - score_roi score_roi = 255 - score_roi
contours, _ = cv2.findContours( contours, _ = cv2.findContours(
score_roi, score_roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE,
) )
for contour in contours: for contour in contours:
rect = cv2.boundingRect(contour) rect = cv2.boundingRect(contour)
if rect[3] > score_roi.shape[0] * 0.5: if rect[3] > score_roi.shape[0] * 0.5:
continue continue
score_roi = cv2.fillPoly(score_roi, [contour], 0) score_roi = cv2.fillPoly(score_roi, [contour], 0)
return ocr_digits_by_contour_knn(score_roi, self.score_knn)
ocr_result = self.score_knn_provider.result(score_roi) def find_pfl_rects(self, component_pfl_processed: Mat) -> List[List[int]]:
return int(ocr_result) if ocr_result else 0
def find_pfl_rects(
self,
component_pfl_processed: Mat,
) -> list[tuple[int, int, int, int]]:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
pfl_roi_find = cv2.morphologyEx( pfl_roi_find = cv2.morphologyEx(
component_pfl_processed, component_pfl_processed,
@ -126,16 +127,14 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
cv2.getStructuringElement(cv2.MORPH_RECT, [10, 1]), cv2.getStructuringElement(cv2.MORPH_RECT, [10, 1]),
) )
pfl_contours, _ = cv2.findContours( pfl_contours, _ = cv2.findContours(
pfl_roi_find, pfl_roi_find, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_NONE,
) )
pfl_rects = [cv2.boundingRect(c) for c in pfl_contours] pfl_rects = [cv2.boundingRect(c) for c in pfl_contours]
pfl_rects = [ pfl_rects = [
r for r in pfl_rects if r[3] > component_pfl_processed.shape[0] * 0.1 r for r in pfl_rects if r[3] > component_pfl_processed.shape[0] * 0.1
] ]
pfl_rects = sorted(pfl_rects, key=lambda r: r[1]) pfl_rects = sorted(pfl_rects, key=lambda r: r[1])
return [ pfl_rects_adjusted = [
( (
max(rect[0] - 2, 0), max(rect[0] - 2, 0),
rect[1], rect[1],
@ -144,9 +143,10 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
) )
for rect in pfl_rects for rect in pfl_rects
] ]
return pfl_rects_adjusted
def preprocess_component_pfl(self, component_bgr: Mat) -> Mat: def preprocess_component_pfl(self, component_bgr: Mat) -> Mat:
pfl_rect = self.rois.component_rois.pfl_rect.rounded() pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect)
pfl_roi = crop_xywh(component_bgr, pfl_rect) pfl_roi = crop_xywh(component_bgr, pfl_rect)
pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV) pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV)
@ -166,17 +166,11 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
pfl_roi_blurred = cv2.GaussianBlur(pfl_roi, (5, 5), 0) pfl_roi_blurred = cv2.GaussianBlur(pfl_roi, (5, 5), 0)
# pfl_roi_blurred = cv2.medianBlur(pfl_roi, 3) # pfl_roi_blurred = cv2.medianBlur(pfl_roi, 3)
_, pfl_roi_blurred_threshold = cv2.threshold( _, pfl_roi_blurred_threshold = cv2.threshold(
pfl_roi_blurred, pfl_roi_blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
0,
255,
cv2.THRESH_BINARY + cv2.THRESH_OTSU,
) )
# and a threshold of the original roi # and a threshold of the original roi
_, pfl_roi_threshold = cv2.threshold( _, pfl_roi_threshold = cv2.threshold(
pfl_roi, pfl_roi, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
0,
255,
cv2.THRESH_BINARY + cv2.THRESH_OTSU,
) )
# turn thresholds into black background # turn thresholds into black background
if pfl_roi_blurred_threshold[2][2] == 255: if pfl_roi_blurred_threshold[2][2] == 255:
@ -186,58 +180,64 @@ class ChieriBotV4Best30Scenario(Best30Scenario):
# return a bitwise_and result # return a bitwise_and result
result = cv2.bitwise_and(pfl_roi_blurred_threshold, pfl_roi_threshold) result = cv2.bitwise_and(pfl_roi_blurred_threshold, pfl_roi_threshold)
result_eroded = cv2.erode( result_eroded = cv2.erode(
result, result, cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2))
cv2.getStructuringElement(cv2.MORPH_CROSS, (2, 2)),
) )
return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result return result_eroded if len(self.find_pfl_rects(result_eroded)) == 3 else result
def ocr_component_pfl( def ocr_component_pfl(
self, self, component_bgr: Mat
component_bgr: Mat, ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
) -> tuple[int | None, int | None, int | None]:
try: try:
pfl_roi = self.preprocess_component_pfl(component_bgr) pfl_roi = self.preprocess_component_pfl(component_bgr)
pfl_rects = self.find_pfl_rects(pfl_roi) pfl_rects = self.find_pfl_rects(pfl_roi)
pure_far_lost = [] pure_far_lost = []
for pfl_roi_rect in pfl_rects: for pfl_roi_rect in pfl_rects:
roi = crop_xywh(pfl_roi, pfl_roi_rect) roi = crop_xywh(pfl_roi, pfl_roi_rect)
result = self.pfl_knn_provider.result(roi) digit_contours, _ = cv2.findContours(
pure_far_lost.append(int(result) if result else None) roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
digit_rects = [cv2.boundingRect(c) for c in digit_contours]
digit_rects = FixRects.connect_broken(
digit_rects, roi.shape[1], roi.shape[0]
)
digit_rects = FixRects.split_connected(roi, digit_rects)
digit_rects = sorted(digit_rects, key=lambda r: r[0])
digits = []
for digit_rect in digit_rects:
digit = crop_xywh(roi, digit_rect)
digit = resize_fill_square(digit, 20)
digits.append(digit)
samples = preprocess_hog(digits)
_, results, _, _ = self.pfl_knn.findNearest(samples, 4)
results = [str(int(i)) for i in results.ravel()]
pure_far_lost.append(int("".join(results)))
return tuple(pure_far_lost) return tuple(pure_far_lost)
except Exception: # noqa: BLE001 except Exception:
return (None, None, None) return (None, None, None)
def ocr_component(self, component_bgr: Mat) -> OcrScenarioResult: def ocr_component(self, component_bgr: Mat) -> B30OcrResultItem:
component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0) component_blur = cv2.GaussianBlur(component_bgr, (5, 5), 0)
rating_class = self.ocr_component_rating_class(component_blur) rating_class = self.ocr_component_rating_class(component_blur)
song_id_results = self.ocr_component_song_id_results(component_bgr) song_id = self.ocr_component_song_id(component_bgr)
# title = self.ocr_component_title(component_blur)
# score = self.ocr_component_score(component_blur) # score = self.ocr_component_score(component_blur)
score = self.ocr_component_score_knn(component_bgr) score = self.ocr_component_score_knn(component_bgr)
pure, far, lost = self.ocr_component_pfl(component_bgr) pure, far, lost = self.ocr_component_pfl(component_bgr)
return OcrScenarioResult( return B30OcrResultItem(
song_id=song_id_results[0].image_id, song_id=song_id,
song_id_results=song_id_results,
rating_class=rating_class, rating_class=rating_class,
# title=title,
score=score, score=score,
pure=pure, pure=pure,
far=far, far=far,
lost=lost, lost=lost,
played_at=None, date=None,
) )
def components(self, img: Mat, /): def ocr(self, img_bgr: Mat) -> List[B30OcrResultItem]:
""" self.set_factor(img_bgr)
:param img: BGR format image return [
""" self.ocr_component(component_bgr)
self.set_factor(img) for component_bgr in self.rois.components(img_bgr)
return self.rois.components(img) ]
def result(self, component_img: Mat, /):
return self.ocr_component(component_img)
def results(self, img: Mat, /) -> list[OcrScenarioResult]:
"""
:param img: BGR format image
"""
return [self.ocr_component(component) for component in self.components(img)]

View File

@ -1,11 +1,12 @@
from __future__ import annotations from typing import List, Optional
from arcaea_offline_ocr.crop import crop_xywh from ....crop import crop_xywh
from arcaea_offline_ocr.types import Mat, XYWHRect from ....types import Mat, XYWHRect
from ....utils import apply_factor, construct_int_xywh_rect
class ChieriBotV4ComponentRois: class ChieriBotV4ComponentRois:
def __init__(self, factor: float = 1.0): def __init__(self, factor: Optional[float] = 1.0):
self.__factor = factor self.__factor = factor
@property @property
@ -18,43 +19,43 @@ class ChieriBotV4ComponentRois:
@property @property
def top_font_color_detect(self): def top_font_color_detect(self):
return XYWHRect(35, 10, 120, 100), self.factor return apply_factor((35, 10, 120, 100), self.factor)
@property @property
def bottom_font_color_detect(self): def bottom_font_color_detect(self):
return XYWHRect(30, 125, 175, 110) * self.factor return apply_factor((30, 125, 175, 110), self.factor)
@property @property
def bg_point(self): def bg_point(self):
return (75 * self.factor, 10 * self.factor) return apply_factor((75, 10), self.factor)
@property @property
def rating_class_rect(self): def rating_class_rect(self):
return XYWHRect(21, 40, 7, 20) * self.factor return apply_factor((21, 40, 7, 20), self.factor)
@property @property
def title_rect(self): def title_rect(self):
return XYWHRect(35, 10, 430, 50) * self.factor return apply_factor((35, 10, 430, 50), self.factor)
@property @property
def jacket_rect(self): def jacket_rect(self):
return XYWHRect(263, 0, 239, 239) * self.factor return apply_factor((263, 0, 239, 239), self.factor)
@property @property
def score_rect(self): def score_rect(self):
return XYWHRect(30, 60, 270, 55) * self.factor return apply_factor((30, 60, 270, 55), self.factor)
@property @property
def pfl_rect(self): def pfl_rect(self):
return XYWHRect(50, 125, 80, 100) * self.factor return apply_factor((50, 125, 80, 100), self.factor)
@property @property
def date_rect(self): def date_rect(self):
return XYWHRect(205, 200, 225, 25) * self.factor return apply_factor((205, 200, 225, 25), self.factor)
class ChieriBotV4Rois: class ChieriBotV4Rois:
def __init__(self, factor: float = 1.0): def __init__(self, factor: Optional[float] = 1.0):
self.__factor = factor self.__factor = factor
self.__component_rois = ChieriBotV4ComponentRois(factor) self.__component_rois = ChieriBotV4ComponentRois(factor)
@ -73,53 +74,54 @@ class ChieriBotV4Rois:
@property @property
def top(self): def top(self):
return 823 * self.factor return apply_factor(823, self.factor)
@property @property
def left(self): def left(self):
return 107 * self.factor return apply_factor(107, self.factor)
@property @property
def width(self): def width(self):
return 502 * self.factor return apply_factor(502, self.factor)
@property @property
def height(self): def height(self):
return 240 * self.factor return apply_factor(240, self.factor)
@property @property
def vertical_gap(self): def vertical_gap(self):
return 74 * self.factor return apply_factor(74, self.factor)
@property @property
def horizontal_gap(self): def horizontal_gap(self):
return 40 * self.factor return apply_factor(40, self.factor)
@property @property
def horizontal_items(self): def horizontal_items(self):
return 3 return 3
vertical_items = 10 @property
def vertical_items(self):
return 10
@property @property
def b33_vertical_gap(self): def b33_vertical_gap(self):
return 121 * self.factor return apply_factor(121, self.factor)
def components(self, img_bgr: Mat) -> list[Mat]: def components(self, img_bgr: Mat) -> List[Mat]:
first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height) first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height)
results = [] results = []
last_rect = first_rect
for vi in range(self.vertical_items): for vi in range(self.vertical_items):
rect = XYWHRect(*first_rect) rect = XYWHRect(*first_rect)
rect += (0, (self.vertical_gap + self.height) * vi, 0, 0) rect += (0, (self.vertical_gap + self.height) * vi, 0, 0)
for hi in range(self.horizontal_items): for hi in range(self.horizontal_items):
if hi > 0: if hi > 0:
rect += ((self.width + self.horizontal_gap), 0, 0, 0) rect += ((self.width + self.horizontal_gap), 0, 0, 0)
results.append(crop_xywh(img_bgr, rect.rounded())) int_rect = construct_int_xywh_rect(rect)
last_rect = rect results.append(crop_xywh(img_bgr, int_rect))
last_rect += ( rect += (
-(self.width + self.horizontal_gap) * 2, -(self.width + self.horizontal_gap) * 2,
self.height + self.b33_vertical_gap, self.height + self.b33_vertical_gap,
0, 0,
@ -127,7 +129,8 @@ class ChieriBotV4Rois:
) )
for hi in range(self.horizontal_items): for hi in range(self.horizontal_items):
if hi > 0: if hi > 0:
last_rect += ((self.width + self.horizontal_gap), 0, 0, 0) rect += ((self.width + self.horizontal_gap), 0, 0, 0)
results.append(crop_xywh(img_bgr, last_rect.rounded())) int_rect = construct_int_xywh_rect(rect)
results.append(crop_xywh(img_bgr, int_rect))
return results return results

View File

@ -0,0 +1,16 @@
from datetime import datetime
from typing import Optional
import attrs
@attrs.define
class B30OcrResultItem:
rating_class: int
score: int
pure: Optional[int] = None
far: Optional[int] = None
lost: Optional[int] = None
date: Optional[datetime] = None
title: Optional[str] = None
song_id: Optional[str] = None

View File

@ -1,6 +0,0 @@
from .ihdb import ImageHashDatabaseBuildTask, ImageHashesDatabaseBuilder
__all__ = [
"ImageHashDatabaseBuildTask",
"ImageHashesDatabaseBuilder",
]

View File

@ -1,115 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Callable
import cv2
from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.providers.ihdb import (
PROP_KEY_BUILT_AT,
PROP_KEY_HASH_SIZE,
PROP_KEY_HIGH_FREQ_FACTOR,
ImageHashDatabaseIdProvider,
ImageHashType,
)
if TYPE_CHECKING:
from sqlite3 import Connection
from arcaea_offline_ocr.providers import ImageCategory
from arcaea_offline_ocr.types import Mat
def _default_imread_gray(image_path: str):
return cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY)
@dataclass
class ImageHashDatabaseBuildTask:
image_path: str
image_id: str
category: ImageCategory
imread_function: Callable[[str], Mat] = _default_imread_gray
@dataclass
class _ImageHash:
image_id: str
category: ImageCategory
image_hash_type: ImageHashType
hash: bytes
class ImageHashesDatabaseBuilder:
@staticmethod
def __insert_property(conn: Connection, key: str, value: str):
return conn.execute(
"INSERT INTO properties (key, value) VALUES (?, ?)",
(key, value),
)
@classmethod
def build(
cls,
conn: Connection,
tasks: list[ImageHashDatabaseBuildTask],
*,
hash_size: int = 16,
high_freq_factor: int = 4,
):
hashes: list[_ImageHash] = []
for task in tasks:
img_gray = task.imread_function(task.image_path)
for hash_type, hash_mat in [
(
ImageHashType.AVERAGE,
hashers.average(img_gray, hash_size),
),
(
ImageHashType.DCT,
hashers.dct(img_gray, hash_size, high_freq_factor),
),
(
ImageHashType.DIFFERENCE,
hashers.difference(img_gray, hash_size),
),
]:
hashes.append(
_ImageHash(
image_id=task.image_id,
image_hash_type=hash_type,
category=task.category,
hash=ImageHashDatabaseIdProvider.hash_mat_to_bytes(hash_mat),
),
)
conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
conn.execute(
"""CREATE TABLE hashes (
`id` VARCHAR,
`category` INTEGER,
`hash_type` INTEGER,
`hash` BLOB
)""",
)
now = datetime.now(tz=timezone.utc)
timestamp = int(now.timestamp() * 1000)
cls.__insert_property(conn, PROP_KEY_HASH_SIZE, str(hash_size))
cls.__insert_property(conn, PROP_KEY_HIGH_FREQ_FACTOR, str(high_freq_factor))
cls.__insert_property(conn, PROP_KEY_BUILT_AT, str(timestamp))
conn.executemany(
"""INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`)
VALUES (?, ?, ?, ?)""",
[
(it.image_id, it.category.value, it.image_hash_type.value, it.hash)
for it in hashes
],
)
conn.commit()

View File

@ -1,3 +0,0 @@
from .index import average, dct, difference
__all__ = ["average", "dct", "difference"]

View File

@ -1,7 +0,0 @@
import cv2
from arcaea_offline_ocr.types import Mat
def _resize_image(src: Mat, dsize: ...) -> Mat:
return cv2.resize(src, dsize, fx=0, fy=0, interpolation=cv2.INTER_AREA)

View File

@ -1,35 +0,0 @@
import cv2
import numpy as np
from arcaea_offline_ocr.types import Mat
from ._common import _resize_image
def average(img_gray: Mat, hash_size: int) -> Mat:
img_resized = _resize_image(img_gray, (hash_size, hash_size))
diff = img_resized > img_resized.mean()
return diff.flatten()
def difference(img_gray: Mat, hash_size: int) -> Mat:
    """Difference hash: compare each pixel against its right-hand neighbour."""
    # One extra column so that hash_size comparisons remain after pairing.
    resized = _resize_image(img_gray, (hash_size + 1, hash_size))
    return (resized[:, :-1] > resized[:, 1:]).flatten()
def dct(img_gray: Mat, hash_size: int = 16, high_freq_factor: int = 4) -> Mat:
    """DCT hash: threshold the low-frequency DCT block at its mean.

    Note: unlike `average`/`difference`, the result is not flattened.
    """
    # TODO: consistency? # noqa: FIX002, TD002, TD003
    side = hash_size * high_freq_factor
    resized = _resize_image(img_gray, (side, side)).astype(np.float32)
    low_freq = cv2.dct(resized)[:hash_size, :hash_size]
    return low_freq > low_freq.mean()

View File

@ -1,32 +1,29 @@
from __future__ import annotations
import math import math
from typing import TYPE_CHECKING from typing import Tuple
import cv2 import cv2
import numpy as np import numpy as np
if TYPE_CHECKING: from .types import Mat
from .types import Mat
__all__ = ["CropBlackEdges", "crop_xywh"] __all__ = ["crop_xywh", "CropBlackEdges"]
def crop_xywh(mat: Mat, rect: tuple[int, int, int, int]): def crop_xywh(mat: Mat, rect: Tuple[int, int, int, int]):
x, y, w, h = rect x, y, w, h = rect
return mat[y : y + h, x : x + w] return mat[y : y + h, x : x + w]
class CropBlackEdges: class CropBlackEdges:
@staticmethod @staticmethod
def is_black_edge(img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6): def is_black_edge(__img_gray_slice: Mat, black_pixel: int, ratio: float = 0.6):
pixels_compared = img_gray_slice < black_pixel pixels_compared = __img_gray_slice < black_pixel
return np.count_nonzero(pixels_compared) > math.floor( return np.count_nonzero(pixels_compared) > math.floor(
img_gray_slice.size * ratio, __img_gray_slice.size * ratio
) )
@classmethod @classmethod
def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25): # noqa: C901 def get_crop_rect(cls, img_gray: Mat, black_threshold: int = 25):
height, width = img_gray.shape[:2] height, width = img_gray.shape[:2]
left = 0 left = 0
right = width right = width
@ -57,22 +54,13 @@ class CropBlackEdges:
break break
bottom -= 1 bottom -= 1
if right <= left: assert right > left, "cropped width < 0"
msg = "cropped width < 0" assert bottom > top, "cropped height < 0"
raise ValueError(msg)
if bottom <= top:
msg = "cropped height < 0"
raise ValueError(msg)
return (left, top, right - left, bottom - top) return (left, top, right - left, bottom - top)
@classmethod @classmethod
def crop( def crop(
cls, cls, img: Mat, convert_flag: cv2.COLOR_BGR2GRAY, black_threshold: int = 25
img: Mat,
convert_flag: cv2.COLOR_BGR2GRAY,
black_threshold: int = 25,
) -> Mat: ) -> Mat:
rect = cls.get_crop_rect(cv2.cvtColor(img, convert_flag), black_threshold) rect = cls.get_crop_rect(cv2.cvtColor(img, convert_flag), black_threshold)
return crop_xywh(img, rect) return crop_xywh(img, rect)

View File

@ -0,0 +1,2 @@
from .common import DeviceOcrResult
from .ocr import DeviceOcr

View File

@ -0,0 +1,18 @@
from typing import Optional
import attrs
@attrs.define
class DeviceOcrResult:
    """Aggregated OCR result for one device (play-result screen) screenshot."""

    # Difficulty index as recognized on screen (e.g. PST/PRS/FTR/...).
    rating_class: int
    pure: int
    far: int
    lost: int
    score: int
    max_recall: Optional[int] = None
    song_id: Optional[str] = None
    # Similarity in [0, 1], derived from perceptual-hash distance.
    song_id_possibility: Optional[float] = None
    clear_status: Optional[int] = None
    partner_id: Optional[str] = None
    # Similarity in [0, 1], derived from perceptual-hash distance.
    partner_id_possibility: Optional[float] = None

View File

@ -1,56 +1,58 @@
import cv2 import cv2
import numpy as np import numpy as np
from arcaea_offline_ocr.providers import ( from ..crop import crop_xywh
ImageCategory, from ..ocr import (
ImageIdProvider, FixRects,
OcrKNearestTextProvider, ocr_digit_samples_knn,
ocr_digits_by_contour_knn,
preprocess_hog,
resize_fill_square,
) )
from arcaea_offline_ocr.scenarios.base import OcrScenarioResult from ..phash_db import ImagePhashDatabase
from arcaea_offline_ocr.types import Mat from ..types import Mat
from .common import DeviceOcrResult
from .base import DeviceScenarioBase from .rois.extractor import DeviceRoisExtractor
from .extractor import DeviceRoisExtractor from .rois.masker import DeviceRoisMasker
from .masker import DeviceRoisMasker
class DeviceScenario(DeviceScenarioBase): class DeviceOcr:
def __init__( def __init__(
self, self,
extractor: DeviceRoisExtractor, extractor: DeviceRoisExtractor,
masker: DeviceRoisMasker, masker: DeviceRoisMasker,
knn_provider: OcrKNearestTextProvider, knn_model: cv2.ml.KNearest,
image_id_provider: ImageIdProvider, phash_db: ImagePhashDatabase,
): ):
self.extractor = extractor self.extractor = extractor
self.masker = masker self.masker = masker
self.knn_provider = knn_provider self.knn_model = knn_model
self.image_id_provider = image_id_provider self.phash_db = phash_db
def pfl(self, roi_gray: Mat, factor: float = 1.25): def pfl(self, roi_gray: Mat, factor: float = 1.25):
def contour_filter(cnt): contours, _ = cv2.findContours(
return cv2.contourArea(cnt) >= 5 * factor roi_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
contours = self.knn_provider.contours(roi_gray)
contours_filtered = self.knn_provider.contours(
roi_gray,
contours_filter=contour_filter,
) )
filtered_contours = [c for c in contours if cv2.contourArea(c) >= 5 * factor]
rects = [cv2.boundingRect(c) for c in filtered_contours]
rects = FixRects.connect_broken(rects, roi_gray.shape[1], roi_gray.shape[0])
filtered_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor]
filtered_rects = FixRects.split_connected(roi_gray, filtered_rects)
filtered_rects = sorted(filtered_rects, key=lambda r: r[0])
roi_ocr = roi_gray.copy() roi_ocr = roi_gray.copy()
contours_filtered_flattened = {tuple(c.flatten()) for c in contours_filtered} filtered_contours_flattened = {tuple(c.flatten()) for c in filtered_contours}
for contour in contours: for contour in contours:
if tuple(contour.flatten()) in contours_filtered_flattened: if tuple(contour.flatten()) in filtered_contours_flattened:
continue continue
roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0]) roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0])
digit_rois = [
resize_fill_square(crop_xywh(roi_ocr, r), 20) for r in filtered_rects
]
ocr_result = self.knn_provider.result( samples = preprocess_hog(digit_rois)
roi_ocr, return ocr_digit_samples_knn(samples, self.knn_model)
contours_filter=lambda cnt: cv2.contourArea(cnt) >= 5 * factor,
rects_filter=lambda rect: rect[2] >= 5 * factor and rect[3] >= 6 * factor,
)
return int(ocr_result) if ocr_result else 0
def pure(self): def pure(self):
return self.pfl(self.masker.pure(self.extractor.pure)) return self.pfl(self.masker.pure(self.extractor.pure))
@ -63,14 +65,13 @@ class DeviceScenario(DeviceScenarioBase):
def score(self): def score(self):
roi = self.masker.score(self.extractor.score) roi = self.masker.score(self.extractor.score)
contours = self.knn_provider.contours(roi) contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours: for contour in contours:
if ( if (
cv2.boundingRect(contour)[3] < roi.shape[0] * 0.6 cv2.boundingRect(contour)[3] < roi.shape[0] * 0.6
): # h < score_component_h * 0.6 ): # h < score_component_h * 0.6
roi = cv2.fillPoly(roi, [contour], [0]) roi = cv2.fillPoly(roi, [contour], [0])
ocr_result = self.knn_provider.result(roi) return ocr_digits_by_contour_knn(roi, self.knn_model)
return int(ocr_result) if ocr_result else 0
def rating_class(self): def rating_class(self):
roi = self.extractor.rating_class roi = self.extractor.rating_class
@ -84,10 +85,9 @@ class DeviceScenario(DeviceScenarioBase):
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
def max_recall(self): def max_recall(self):
ocr_result = self.knn_provider.result( return ocr_digits_by_contour_knn(
self.masker.max_recall(self.extractor.max_recall), self.masker.max_recall(self.extractor.max_recall), self.knn_model
) )
return int(ocr_result) if ocr_result else None
def clear_status(self): def clear_status(self):
roi = self.extractor.clear_status roi = self.extractor.clear_status
@ -99,18 +99,20 @@ class DeviceScenario(DeviceScenarioBase):
] ]
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
def song_id_results(self): def lookup_song_id(self):
return self.image_id_provider.results( return self.phash_db.lookup_jacket(
cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY), cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY)
ImageCategory.JACKET,
) )
def song_id(self):
return self.lookup_song_id()[0]
@staticmethod @staticmethod
def preprocess_char_icon(img_gray: Mat): def preprocess_char_icon(img_gray: Mat):
h, w = img_gray.shape[:2] h, w = img_gray.shape[:2]
img = cv2.copyMakeBorder(img_gray, max(w - h, 0), 0, 0, 0, cv2.BORDER_REPLICATE) img = cv2.copyMakeBorder(img_gray, w - h, 0, 0, 0, cv2.BORDER_REPLICATE)
h, w = img.shape[:2] h, w = img.shape[:2]
return cv2.fillPoly( img = cv2.fillPoly(
img, img,
[ [
np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32), np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32),
@ -118,18 +120,21 @@ class DeviceScenario(DeviceScenarioBase):
np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32), np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32),
np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32), np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32),
], ],
(128,), (128),
) )
return img
def partner_id_results(self): def lookup_partner_id(self):
return self.image_id_provider.results( return self.phash_db.lookup_partner_icon(
self.preprocess_char_icon( self.preprocess_char_icon(
cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY), cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY)
), )
ImageCategory.PARTNER_ICON,
) )
def result(self): def partner_id(self):
return self.lookup_partner_id()[0]
def ocr(self) -> DeviceOcrResult:
rating_class = self.rating_class() rating_class = self.rating_class()
pure = self.pure() pure = self.pure()
far = self.far() far = self.far()
@ -138,18 +143,20 @@ class DeviceScenario(DeviceScenarioBase):
max_recall = self.max_recall() max_recall = self.max_recall()
clear_status = self.clear_status() clear_status = self.clear_status()
song_id_results = self.song_id_results() hash_len = self.phash_db.hash_size**2
partner_id_results = self.partner_id_results() song_id, song_id_distance = self.lookup_song_id()
partner_id, partner_id_distance = self.lookup_partner_id()
return OcrScenarioResult( return DeviceOcrResult(
song_id=song_id_results[0].image_id,
song_id_results=song_id_results,
rating_class=rating_class, rating_class=rating_class,
pure=pure, pure=pure,
far=far, far=far,
lost=lost, lost=lost,
score=score, score=score,
max_recall=max_recall, max_recall=max_recall,
partner_id_results=partner_id_results, song_id=song_id,
song_id_possibility=1 - song_id_distance / hash_len,
clear_status=clear_status, clear_status=clear_status,
partner_id=partner_id,
partner_id_possibility=1 - partner_id_distance / hash_len,
) )

View File

@ -0,0 +1,3 @@
from .definition import *
from .extractor import *
from .masker import *

View File

@ -0,0 +1,2 @@
from .auto import *
from .common import DeviceRois

View File

@ -1,6 +1,6 @@
from arcaea_offline_ocr.types import XYWHRect from .common import DeviceRois
from .base import DeviceRois __all__ = ["DeviceRoisAuto", "DeviceRoisAutoT1", "DeviceRoisAutoT2"]
class DeviceRoisAuto(DeviceRois): class DeviceRoisAuto(DeviceRois):
@ -50,7 +50,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def pure(self): def pure(self):
return XYWHRect( return (
self.pfl_x, self.pfl_x,
self.layout_area_h_mid + 110 * self.factor, self.layout_area_h_mid + 110 * self.factor,
self.pfl_w, self.pfl_w,
@ -59,7 +59,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def far(self): def far(self):
return XYWHRect( return (
self.pfl_x, self.pfl_x,
self.pure[1] + self.pure[3] + 12 * self.factor, self.pure[1] + self.pure[3] + 12 * self.factor,
self.pfl_w, self.pfl_w,
@ -68,7 +68,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def lost(self): def lost(self):
return XYWHRect( return (
self.pfl_x, self.pfl_x,
self.far[1] + self.far[3] + 10 * self.factor, self.far[1] + self.far[3] + 10 * self.factor,
self.pfl_w, self.pfl_w,
@ -79,7 +79,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def score(self): def score(self):
w = 280 * self.factor w = 280 * self.factor
h = 45 * self.factor h = 45 * self.factor
return XYWHRect( return (
self.w_mid - w / 2, self.w_mid - w / 2,
self.layout_area_h_mid - 75 * self.factor - h, self.layout_area_h_mid - 75 * self.factor - h,
w, w,
@ -88,7 +88,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def rating_class(self): def rating_class(self):
return XYWHRect( return (
self.w_mid - 610 * self.factor, self.w_mid - 610 * self.factor,
self.layout_area_h_mid - 180 * self.factor, self.layout_area_h_mid - 180 * self.factor,
265 * self.factor, 265 * self.factor,
@ -97,7 +97,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def max_recall(self): def max_recall(self):
return XYWHRect( return (
self.w_mid - 465 * self.factor, self.w_mid - 465 * self.factor,
self.layout_area_h_mid - 215 * self.factor, self.layout_area_h_mid - 215 * self.factor,
150 * self.factor, 150 * self.factor,
@ -106,7 +106,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
@property @property
def jacket(self): def jacket(self):
return XYWHRect( return (
self.w_mid - 610 * self.factor, self.w_mid - 610 * self.factor,
self.layout_area_h_mid - 143 * self.factor, self.layout_area_h_mid - 143 * self.factor,
375 * self.factor, 375 * self.factor,
@ -117,7 +117,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def clear_status(self): def clear_status(self):
w = 550 * self.factor w = 550 * self.factor
h = 60 * self.factor h = 60 * self.factor
return XYWHRect( return (
self.w_mid - w / 2, self.w_mid - w / 2,
self.layout_area_h_mid - 155 * self.factor - h, self.layout_area_h_mid - 155 * self.factor - h,
w * 0.4, w * 0.4,
@ -128,7 +128,7 @@ class DeviceRoisAutoT1(DeviceRoisAuto):
def partner_icon(self): def partner_icon(self):
w = 90 * self.factor w = 90 * self.factor
h = 75 * self.factor h = 75 * self.factor
return XYWHRect(self.w_mid - w / 2, 0, w, h) return (self.w_mid - w / 2, 0, w, h)
class DeviceRoisAutoT2(DeviceRoisAuto): class DeviceRoisAutoT2(DeviceRoisAuto):
@ -174,7 +174,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def pure(self): def pure(self):
return XYWHRect( return (
self.pfl_x, self.pfl_x,
self.layout_area_h_mid + 175 * self.factor, self.layout_area_h_mid + 175 * self.factor,
self.pfl_w, self.pfl_w,
@ -183,7 +183,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def far(self): def far(self):
return XYWHRect( return (
self.pfl_x, self.pfl_x,
self.pure[1] + self.pure[3] + 30 * self.factor, self.pure[1] + self.pure[3] + 30 * self.factor,
self.pfl_w, self.pfl_w,
@ -192,7 +192,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def lost(self): def lost(self):
return XYWHRect( return (
self.pfl_x, self.pfl_x,
self.far[1] + self.far[3] + 35 * self.factor, self.far[1] + self.far[3] + 35 * self.factor,
self.pfl_w, self.pfl_w,
@ -203,7 +203,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def score(self): def score(self):
w = 420 * self.factor w = 420 * self.factor
h = 70 * self.factor h = 70 * self.factor
return XYWHRect( return (
self.w_mid - w / 2, self.w_mid - w / 2,
self.layout_area_h_mid - 110 * self.factor - h, self.layout_area_h_mid - 110 * self.factor - h,
w, w,
@ -212,7 +212,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def rating_class(self): def rating_class(self):
return XYWHRect( return (
max(0, self.w_mid - 965 * self.factor), max(0, self.w_mid - 965 * self.factor),
self.layout_area_h_mid - 330 * self.factor, self.layout_area_h_mid - 330 * self.factor,
350 * self.factor, 350 * self.factor,
@ -221,7 +221,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def max_recall(self): def max_recall(self):
return XYWHRect( return (
self.w_mid - 625 * self.factor, self.w_mid - 625 * self.factor,
self.layout_area_h_mid - 275 * self.factor, self.layout_area_h_mid - 275 * self.factor,
150 * self.factor, 150 * self.factor,
@ -230,7 +230,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
@property @property
def jacket(self): def jacket(self):
return XYWHRect( return (
self.w_mid - 915 * self.factor, self.w_mid - 915 * self.factor,
self.layout_area_h_mid - 215 * self.factor, self.layout_area_h_mid - 215 * self.factor,
565 * self.factor, 565 * self.factor,
@ -241,7 +241,7 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def clear_status(self): def clear_status(self):
w = 825 * self.factor w = 825 * self.factor
h = 90 * self.factor h = 90 * self.factor
return XYWHRect( return (
self.w_mid - w / 2, self.w_mid - w / 2,
self.layout_area_h_mid - 235 * self.factor - h, self.layout_area_h_mid - 235 * self.factor - h,
w * 0.4, w * 0.4,
@ -252,4 +252,4 @@ class DeviceRoisAutoT2(DeviceRoisAuto):
def partner_icon(self): def partner_icon(self):
w = 135 * self.factor w = 135 * self.factor
h = 110 * self.factor h = 110 * self.factor
return XYWHRect(self.w_mid - w / 2, 0, w, h) return (self.w_mid - w / 2, 0, w, h)

View File

@ -0,0 +1,15 @@
from typing import Tuple
Rect = Tuple[int, int, int, int]
class DeviceRois:
    """Declares the screen regions a device layout must provide.

    Each attribute is an (x, y, w, h) rectangle consumed by the ROI
    extractor via `crop_xywh`.
    """

    pure: Rect
    far: Rect
    lost: Rect
    score: Rect
    rating_class: Rect
    max_recall: Rect
    jacket: Rect
    clear_status: Rect
    partner_icon: Rect

View File

@ -0,0 +1 @@
from .common import DeviceRoisExtractor

View File

@ -0,0 +1,48 @@
from ....crop import crop_xywh
from ....types import Mat
from ..definition.common import DeviceRois
class DeviceRoisExtractor:
    """Crops each configured ROI out of a full screenshot.

    :param img: the full screenshot image.
    :param rois: ROI definitions (x, y, w, h rectangles) for this layout.
    """

    def __init__(self, img: Mat, rois: DeviceRois):
        self.img = img
        self.sizes = rois

    def __construct_int_rect(self, rect):
        # crop_xywh slices with integer pixel coordinates; ROI definitions
        # may contain floats (scaled by a layout factor), so round them.
        return tuple(round(r) for r in rect)

    @property
    def pure(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.pure))

    @property
    def far(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.far))

    @property
    def lost(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.lost))

    @property
    def score(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.score))

    @property
    def jacket(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.jacket))

    @property
    def rating_class(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.rating_class))

    @property
    def max_recall(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.max_recall))

    @property
    def clear_status(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.clear_status))

    @property
    def partner_icon(self):
        return crop_xywh(self.img, self.__construct_int_rect(self.sizes.partner_icon))

View File

@ -0,0 +1,2 @@
from .auto import *
from .common import DeviceRoisMasker

View File

@ -1,18 +1,17 @@
import cv2 import cv2
import numpy as np import numpy as np
from arcaea_offline_ocr.types import Mat from ....types import Mat
from .common import DeviceRoisMasker
from .base import DeviceRoisMasker
class DeviceRoisMaskerAuto(DeviceRoisMasker): class DeviceRoisMaskerAuto(DeviceRoisMasker):
# pylint: disable=abstract-method
@staticmethod @staticmethod
def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat): def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat):
return cv2.inRange( return cv2.inRange(
cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV), hsv_lower, hsv_upper
hsv_lower,
hsv_upper,
) )
@ -102,33 +101,25 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto):
@classmethod @classmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.TRACK_LOST_HSV_MIN, cls.TRACK_LOST_HSV_MAX
cls.TRACK_LOST_HSV_MIN,
cls.TRACK_LOST_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.TRACK_COMPLETE_HSV_MIN, cls.TRACK_COMPLETE_HSV_MAX
cls.TRACK_COMPLETE_HSV_MIN,
cls.TRACK_COMPLETE_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.FULL_RECALL_HSV_MIN, cls.FULL_RECALL_HSV_MAX
cls.FULL_RECALL_HSV_MIN,
cls.FULL_RECALL_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.PURE_MEMORY_HSV_MIN, cls.PURE_MEMORY_HSV_MAX
cls.PURE_MEMORY_HSV_MIN,
cls.PURE_MEMORY_HSV_MAX,
) )
@ -212,39 +203,29 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto):
@classmethod @classmethod
def max_recall(cls, roi_bgr: Mat) -> Mat: def max_recall(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.MAX_RECALL_HSV_MIN, cls.MAX_RECALL_HSV_MAX
cls.MAX_RECALL_HSV_MIN,
cls.MAX_RECALL_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.TRACK_LOST_HSV_MIN, cls.TRACK_LOST_HSV_MAX
cls.TRACK_LOST_HSV_MIN,
cls.TRACK_LOST_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.TRACK_COMPLETE_HSV_MIN, cls.TRACK_COMPLETE_HSV_MAX
cls.TRACK_COMPLETE_HSV_MIN,
cls.TRACK_COMPLETE_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.FULL_RECALL_HSV_MIN, cls.FULL_RECALL_HSV_MAX
cls.FULL_RECALL_HSV_MIN,
cls.FULL_RECALL_HSV_MAX,
) )
@classmethod @classmethod
def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(
roi_bgr, roi_bgr, cls.PURE_MEMORY_HSV_MIN, cls.PURE_MEMORY_HSV_MAX
cls.PURE_MEMORY_HSV_MIN,
cls.PURE_MEMORY_HSV_MAX,
) )

View File

@ -0,0 +1,59 @@
from ....types import Mat
class DeviceRoisMasker:
    """Interface for turning a BGR ROI into a binary mask.

    Subclasses isolate the pixels of interest (digits, rating badges,
    clear-status text) — concrete implementations use HSV thresholding.
    Every method takes a BGR ROI and returns a single-channel mask.
    """

    @classmethod
    def pure(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def far(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def lost(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def score(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def rating_class_pst(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def rating_class_prs(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def rating_class_ftr(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def rating_class_byd(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def rating_class_etr(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def max_recall(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

    @classmethod
    def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat:
        raise NotImplementedError()

View File

@ -1,31 +1,27 @@
from __future__ import annotations
import logging
import math import math
from typing import TYPE_CHECKING, Callable, Sequence from typing import Optional, Sequence, Tuple
import cv2 import cv2
import numpy as np import numpy as np
from arcaea_offline_ocr.crop import crop_xywh from .crop import crop_xywh
from .types import Mat
from .base import OcrTextProvider __all__ = [
"FixRects",
if TYPE_CHECKING: "preprocess_hog",
from cv2.ml import KNearest "ocr_digits_by_contour_get_samples",
"ocr_digits_by_contour_knn",
from arcaea_offline_ocr.types import Mat ]
logger = logging.getLogger(__name__)
class FixRects: class FixRects:
@staticmethod @staticmethod
def connect_broken( def connect_broken(
rects: Sequence[tuple[int, int, int, int]], rects: Sequence[Tuple[int, int, int, int]],
img_width: int, img_width: int,
img_height: int, img_height: int,
tolerance: int | None = None, tolerance: Optional[int] = None,
): ):
# for a "broken" digit, please refer to # for a "broken" digit, please refer to
# /assets/fix_rects/broken_masked.jpg # /assets/fix_rects/broken_masked.jpg
@ -73,7 +69,7 @@ class FixRects:
@staticmethod @staticmethod
def split_connected( def split_connected(
img_masked: Mat, img_masked: Mat,
rects: Sequence[tuple[int, int, int, int]], rects: Sequence[Tuple[int, int, int, int]],
rect_wh_ratio: float = 1.05, rect_wh_ratio: float = 1.05,
width_range_ratio: float = 0.1, width_range_ratio: float = 0.1,
): ):
@ -114,7 +110,7 @@ class FixRects:
# split the rect # split the rect
new_rects.extend( new_rects.extend(
[(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)], [(rx, ry, x_mid - rx, rh), (x_mid, ry, rx + rw - x_mid, rh)]
) )
return_rects = [r for r in rects if r not in connected_rects] return_rects = [r for r in rects if r not in connected_rects]
@ -135,21 +131,11 @@ def resize_fill_square(img: Mat, target: int = 20):
border_size = math.ceil((max(new_w, new_h) - min(new_w, new_h)) / 2) border_size = math.ceil((max(new_w, new_h) - min(new_w, new_h)) / 2)
if new_w < new_h: if new_w < new_h:
resized = cv2.copyMakeBorder( resized = cv2.copyMakeBorder(
resized, resized, 0, 0, border_size, border_size, cv2.BORDER_CONSTANT
0,
0,
border_size,
border_size,
cv2.BORDER_CONSTANT,
) )
else: else:
resized = cv2.copyMakeBorder( resized = cv2.copyMakeBorder(
resized, resized, border_size, border_size, 0, 0, cv2.BORDER_CONSTANT
border_size,
border_size,
0,
0,
cv2.BORDER_CONSTANT,
) )
return cv2.resize(resized, (target, target)) return cv2.resize(resized, (target, target))
@ -164,94 +150,31 @@ def preprocess_hog(digit_rois):
return np.float32(samples) return np.float32(samples)
def ocr_digit_samples_knn(samples, knn_model: cv2.ml.KNearest, k: int = 4): def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4):
_, results, _, _ = knn_model.findNearest(samples, k) _, results, _, _ = knn_model.findNearest(__samples, k)
return [int(r) for r in results.ravel()] result_list = [int(r) for r in results.ravel()]
result_str = "".join(str(r) for r in result_list if r > -1)
return int(result_str) if result_str else 0
class OcrKNearestTextProvider(OcrTextProvider): def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int):
_ContourFilter = Callable[["Mat"], bool] roi = __roi_gray.copy()
_RectsFilter = Callable[[Sequence[int]], bool] contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
rects = [cv2.boundingRect(c) for c in contours]
def __init__(self, model: KNearest): rects = FixRects.connect_broken(rects, roi.shape[1], roi.shape[0])
self.model = model rects = FixRects.split_connected(roi, rects)
def contours(
self,
img: Mat,
/,
*,
contours_filter: _ContourFilter | None = None,
):
cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if contours_filter:
cnts = list(filter(contours_filter, cnts))
return cnts
def result_raw(
self,
img: Mat,
/,
*,
fix_rects: bool = True,
contours_filter: _ContourFilter | None = None,
rects_filter: _RectsFilter | None = None,
):
"""
:param img: grayscaled roi
"""
try:
cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if contours_filter:
cnts = list(filter(contours_filter, cnts))
rects = [cv2.boundingRect(cnt) for cnt in cnts]
if fix_rects and rects_filter:
rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])
rects = list(filter(rects_filter, rects))
rects = FixRects.split_connected(img, rects)
elif fix_rects:
rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])
rects = FixRects.split_connected(img, rects)
elif rects_filter:
rects = list(filter(rects_filter, rects))
rects = sorted(rects, key=lambda r: r[0]) rects = sorted(rects, key=lambda r: r[0])
# digit_rois = [cv2.resize(crop_xywh(roi, rect), size) for rect in rects]
digit_rois = [resize_fill_square(crop_xywh(roi, rect), size) for rect in rects]
return preprocess_hog(digit_rois)
digits = []
for rect in rects:
digit = crop_xywh(img, rect)
digit = resize_fill_square(digit, 20)
digits.append(digit)
samples = preprocess_hog(digits)
return ocr_digit_samples_knn(samples, self.model)
except Exception:
logger.exception("Error occurred during KNearest OCR")
return None
def result( def ocr_digits_by_contour_knn(
self, __roi_gray: Mat,
img: Mat, knn_model: cv2.ml.KNearest,
/,
*, *,
fix_rects: bool = True, k=4,
contours_filter: _ContourFilter | None = None, size: int = 20,
rects_filter: _RectsFilter | None = None, ) -> int:
): samples = ocr_digits_by_contour_get_samples(__roi_gray, size)
""" return ocr_digit_samples_knn(samples, knn_model, k)
:param img: grayscaled roi
"""
raw = self.result_raw(
img,
fix_rects=fix_rects,
contours_filter=contours_filter,
rects_filter=rects_filter,
)
return (
"".join(["".join(str(r) for r in raw if r > -1)])
if raw is not None
else None
)

View File

@ -0,0 +1,119 @@
import sqlite3
from typing import List, Union
import cv2
import numpy as np
from .types import Mat
def phash_opencv(img_gray, hash_size=8, highfreq_factor=4):
    # type: (Union[Mat, np.ndarray], int, int) -> np.ndarray
    """
    Perceptual Hash computation.

    Implementation follows
    http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html

    Adapted from `imagehash.phash`, pure opencv implementation.

    The result is slightly different from `imagehash.phash`
    (e.g. the resize interpolation differs), so hashes from the two
    implementations must not be mixed in one database.

    :param img_gray: grayscale input image.
    :param hash_size: side length of the resulting boolean hash matrix.
    :param highfreq_factor: oversampling factor before the DCT.
    :return: `hash_size` x `hash_size` boolean ndarray.
    :raises ValueError: if `hash_size` < 2.
    """
    if hash_size < 2:
        raise ValueError("Hash size must be greater than or equal to 2")

    img_size = hash_size * highfreq_factor
    image = cv2.resize(img_gray, (img_size, img_size), interpolation=cv2.INTER_LANCZOS4)
    image = np.float32(image)
    dct = cv2.dct(image)
    # Keep only the top-left (low-frequency) block and threshold at its
    # median: low frequencies capture overall structure, not detail/noise.
    dctlowfreq = dct[:hash_size, :hash_size]
    med = np.median(dctlowfreq)
    diff = dctlowfreq > med
    return diff
def hamming_distance_sql_function(user_input, db_entry) -> int:
    """Hamming distance between two packed bool-array hash blobs.

    Both arguments are raw bytes where each byte encodes one boolean;
    the distance is the number of positions at which they differ.
    """
    lhs = np.frombuffer(user_input, bool)
    rhs = np.frombuffer(db_entry, bool)
    return np.count_nonzero(lhs ^ rhs)
class ImagePhashDatabase:
    """Perceptual-hash index loaded from a prebuilt SQLite database.

    The database must contain a ``properties`` table (hash parameters and
    build timestamp) and a ``hashes`` table mapping ids to packed
    bool-array hashes.  Ids of the form ``partner_icon||<id>`` are partner
    icons; every other id is a song jacket.

    :param db_path: path to the SQLite database file.
    """

    def __init__(self, db_path: str):
        with sqlite3.connect(db_path) as conn:
            def _int_property(key: str) -> int:
                # Parameterized query instead of string comparison in SQL text.
                row = conn.execute(
                    "SELECT value FROM properties WHERE key = ?", (key,)
                ).fetchone()
                return int(row[0])

            self.hash_size = _int_property("hash_size")
            self.highfreq_factor = _int_property("highfreq_factor")
            self.built_timestamp = _int_property("built_timestamp")

            # Fetch id and hash in ONE query so the two lists are guaranteed
            # to be aligned (two separate SELECTs only align by accident of
            # row-return order).
            rows = conn.execute("SELECT id, hash FROM hashes").fetchall()

        self.ids: List[str] = [row[0] for row in rows]
        self.hashes_byte = [row[1] for row in rows]
        self.hashes = [np.frombuffer(hb, bool) for hb in self.hashes_byte]

        # Split the flat index into jackets and partner icons by id prefix.
        self.jacket_ids: List[str] = []
        self.jacket_hashes = []
        self.partner_icon_ids: List[str] = []
        self.partner_icon_hashes = []

        for image_id, image_hash in zip(self.ids, self.hashes):
            id_parts = image_id.split("||")
            if len(id_parts) > 1 and id_parts[0] == "partner_icon":
                self.partner_icon_ids.append(id_parts[1])
                self.partner_icon_hashes.append(image_hash)
            else:
                self.jacket_ids.append(image_id)
                self.jacket_hashes.append(image_hash)

    @staticmethod
    def _rank_by_distance(image_hash, ids, hashes, limit):
        """Return ``(id, hamming_distance)`` pairs, closest first (shared by all lookups)."""
        flat_hash = image_hash.flatten()
        distances = [
            (image_id, np.count_nonzero(flat_hash ^ h))
            for image_id, h in zip(ids, hashes)
        ]
        return sorted(distances, key=lambda r: r[1])[:limit]

    def calculate_phash(self, img_gray: "Mat"):
        """Compute the phash of `img_gray` with this database's parameters."""
        return phash_opencv(
            img_gray, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor
        )

    def lookup_hash(self, image_hash: np.ndarray, *, limit: int = 5):
        """Rank ALL stored hashes against a precomputed `image_hash`."""
        return self._rank_by_distance(image_hash, self.ids, self.hashes, limit)

    def lookup_image(self, img_gray: "Mat"):
        """Best ``(id, distance)`` match for a grayscale image, over all entries."""
        image_hash = self.calculate_phash(img_gray)
        return self.lookup_hash(image_hash)[0]

    def lookup_jackets(self, img_gray: "Mat", *, limit: int = 5):
        """Top `limit` jacket matches for a grayscale image."""
        return self._rank_by_distance(
            self.calculate_phash(img_gray), self.jacket_ids, self.jacket_hashes, limit
        )

    def lookup_jacket(self, img_gray: "Mat"):
        """Best ``(id, distance)`` jacket match."""
        return self.lookup_jackets(img_gray)[0]

    def lookup_partner_icons(self, img_gray: "Mat", *, limit: int = 5):
        """Top `limit` partner-icon matches for a grayscale image."""
        return self._rank_by_distance(
            self.calculate_phash(img_gray),
            self.partner_icon_ids,
            self.partner_icon_hashes,
            limit,
        )

    def lookup_partner_icon(self, img_gray: "Mat"):
        """Best ``(id, distance)`` partner-icon match."""
        return self.lookup_partner_icons(img_gray)[0]

View File

@ -1,12 +0,0 @@
from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult, OcrTextProvider
from .ihdb import ImageHashDatabaseIdProvider
from .knn import OcrKNearestTextProvider
__all__ = [
"ImageCategory",
"ImageHashDatabaseIdProvider",
"ImageIdProvider",
"ImageIdProviderResult",
"OcrKNearestTextProvider",
"OcrTextProvider",
]

View File

@ -1,50 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING:
from arcaea_offline_ocr.types import Mat
class OcrTextProvider(ABC):
    """Interface for OCR backends that extract text from an image."""

    # backend-specific raw recognition output; structure depends on the impl
    @abstractmethod
    def result_raw(self, img: Mat, /, *args, **kwargs) -> Any: ...

    # recognized text, or None when nothing could be read
    @abstractmethod
    def result(self, img: Mat, /, *args, **kwargs) -> str | None: ...
class ImageCategory(IntEnum):
    """Kind of game asset an image-id provider can identify."""

    JACKET = 0
    PARTNER_ICON = 1
@dataclass(kw_only=True)
class ImageIdProviderResult:
    """Single candidate match returned by an `ImageIdProvider` lookup."""

    # identifier of the matched image asset
    image_id: str
    # which asset family the match belongs to
    category: ImageCategory
    # match confidence; the hash-db provider produces values in [0, 1],
    # higher meaning a closer match
    confidence: float
class ImageIdProvider(ABC):
    """Interface for providers that identify which known asset `img` shows."""

    # single best match for the image within the given category
    @abstractmethod
    def result(
        self,
        img: Mat,
        category: ImageCategory,
        /,
        *args,
        **kwargs,
    ) -> ImageIdProviderResult: ...

    # all candidate matches for the image within the given category
    @abstractmethod
    def results(
        self,
        img: Mat,
        category: ImageCategory,
        /,
        *args,
        **kwargs,
    ) -> Sequence[ImageIdProviderResult]: ...

View File

@ -1,203 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import IntEnum
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from arcaea_offline_ocr.core import hashers
from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult
if TYPE_CHECKING:
import sqlite3
from arcaea_offline_ocr.types import Mat
T = TypeVar("T")
PROP_KEY_HASH_SIZE = "hash_size"
PROP_KEY_HIGH_FREQ_FACTOR = "high_freq_factor"
PROP_KEY_BUILT_AT = "built_at"
def _sql_hamming_distance(hash1: bytes, hash2: bytes):
if len(hash1) != len(hash2):
msg = "hash size does not match!"
raise ValueError(msg)
return sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
class ImageHashType(IntEnum):
    """Perceptual-hash algorithm used to produce a stored hash row."""

    AVERAGE = 0
    DIFFERENCE = 1
    DCT = 2
@dataclass(kw_only=True)
class ImageHashDatabaseIdProviderResult(ImageIdProviderResult):
    """Id-provider result extended with the hash algorithm that produced it."""

    # which hash algorithm yielded this match
    image_hash_type: ImageHashType
class MissingPropertiesError(Exception):
    """Raised when required keys are absent from the database `properties` table."""

    # the property keys that were required but not found
    keys: list[str]

    def __init__(self, keys, *args):
        # Include the missing keys in the exception message so str(exc) is
        # informative; previously the message was empty.
        super().__init__(f"missing properties: {', '.join(keys)}", *args)
        self.keys = keys
class ImageHashDatabaseIdProvider(ImageIdProvider):
    """Image id provider backed by a SQLite database of perceptual hashes.

    Expects a `properties` table holding hash parameters and a `hashes`
    table holding one row per (id, category, hash_type). Lookups are ranked
    by Hamming distance, computed inside SQL via a registered function.
    """

    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn
        # make the Python Hamming-distance helper callable from SQL queries
        self.conn.create_function("HAMMING_DISTANCE", 2, _sql_hamming_distance)

        # property values read from the database; placeholders until _initialize
        self.properties = {
            PROP_KEY_HASH_SIZE: -1,
            PROP_KEY_HIGH_FREQ_FACTOR: -1,
            PROP_KEY_BUILT_AT: None,
        }

        # number of distinct hashed images per category
        self._hashes_count = {
            ImageCategory.JACKET: 0,
            ImageCategory.PARTNER_ICON: 0,
        }

        # total elements per hash (hash_size ** 2); -1 until initialized
        self._hash_length: int = -1

        self._initialize()

    @property
    def hash_size(self) -> int:
        """Edge length of the square hash, as stored in the database."""
        return self.properties[PROP_KEY_HASH_SIZE]

    @property
    def high_freq_factor(self) -> int:
        """High-frequency oversampling factor used when the hashes were built."""
        return self.properties[PROP_KEY_HIGH_FREQ_FACTOR]

    @property
    def built_at(self) -> datetime | None:
        """UTC time the database was built, or None if not recorded."""
        return self.properties.get(PROP_KEY_BUILT_AT)

    @property
    def hash_length(self):
        """Total number of elements in one hash (hash_size squared)."""
        return self._hash_length

    def _initialize(self):
        """Load properties and per-category hash counts from the database.

        Raises:
            MissingPropertiesError: if a required property key is absent.
        """

        def get_property(key, converter: Callable[[Any], T]) -> T | None:
            # single-row property lookup; None when the key does not exist
            result = self.conn.execute(
                "SELECT value FROM properties WHERE key = ?",
                (key,),
            ).fetchone()
            return converter(result[0]) if result is not None else None

        def set_hashes_count(category: ImageCategory):
            # distinct image ids hashed for this category
            self._hashes_count[category] = self.conn.execute(
                "SELECT COUNT(DISTINCT `id`) FROM hashes WHERE category = ?",
                (category.value,),
            ).fetchone()[0]

        properties_converter_map = {
            PROP_KEY_HASH_SIZE: lambda x: int(x),
            PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x),
            # stored as milliseconds since the Unix epoch
            PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp(
                int(ts) / 1000,
                tz=timezone.utc,
            ),
        }
        # BUILT_AT is optional; the two hash parameters are mandatory
        required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR]

        missing_properties = []
        for property_key, converter in properties_converter_map.items():
            value = get_property(property_key, converter)
            if value is None:
                if property_key in required_properties:
                    missing_properties.append(property_key)
                continue
            self.properties[property_key] = value

        if missing_properties:
            raise MissingPropertiesError(keys=missing_properties)

        set_hashes_count(ImageCategory.JACKET)
        set_hashes_count(ImageCategory.PARTNER_ICON)

        self._hash_length = self.hash_size**2

    def lookup_hash(
        self,
        category: ImageCategory,
        hash_type: ImageHashType,
        hash_data: bytes,
    ) -> list[ImageHashDatabaseIdProviderResult]:
        """Return up to 10 closest matches for `hash_data`, best first.

        Confidence is (hash_length - hamming_distance) / hash_length.
        """
        cursor = self.conn.execute(
            """
            SELECT
                `id`,
                HAMMING_DISTANCE(hash, ?) AS distance
            FROM hashes
            WHERE category = ? AND hash_type = ?
            ORDER BY distance ASC LIMIT 10""",
            (hash_data, category.value, hash_type.value),
        )

        results = []
        for id_, distance in cursor.fetchall():
            results.append(
                ImageHashDatabaseIdProviderResult(
                    image_id=id_,
                    category=category,
                    confidence=(self.hash_length - distance) / self.hash_length,
                    image_hash_type=hash_type,
                ),
            )
        return results

    @staticmethod
    def hash_mat_to_bytes(hash_mat: Mat) -> bytes:
        """Serialize a boolean hash matrix to bytes: 255 per truthy element, else 0."""
        return bytes([255 if b else 0 for b in hash_mat.flatten()])

    def results(self, img: Mat, category: ImageCategory, /):
        """Look up `img` with all three hash algorithms and concatenate matches."""
        results: list[ImageHashDatabaseIdProviderResult] = []

        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.AVERAGE,
                self.hash_mat_to_bytes(hashers.average(img, self.hash_size)),
            ),
        )
        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.DIFFERENCE,
                self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)),
            ),
        )
        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.DCT,
                self.hash_mat_to_bytes(
                    hashers.dct(img, self.hash_size, self.high_freq_factor),
                ),
            ),
        )

        return results

    def result(
        self,
        img: Mat,
        category: ImageCategory,
        /,
        *,
        hash_type: ImageHashType = ImageHashType.DCT,
    ):
        """Best match for `img` using the given hash algorithm.

        NOTE(review): `next()` has no default here, so this raises
        StopIteration when no match of the requested type exists.
        """
        return next(
            it for it in self.results(img, category) if it.image_hash_type == hash_type
        )

View File

@ -1,3 +0,0 @@
from .chieri import ChieriBotV4Best30Scenario
__all__ = ["ChieriBotV4Best30Scenario"]

View File

@ -1,24 +0,0 @@
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING
from arcaea_offline_ocr.scenarios.base import OcrScenario, OcrScenarioResult
if TYPE_CHECKING:
from arcaea_offline_ocr.types import Mat
class Best30Scenario(OcrScenario):
    """Scenario that reads a "best 30" results image made of per-score cells."""

    # split the full image into one sub-image per score cell
    @abstractmethod
    def components(self, img: Mat, /) -> list[Mat]: ...

    # OCR a single score cell
    @abstractmethod
    def result(self, component_img: Mat, /, *args, **kwargs) -> OcrScenarioResult: ...

    @abstractmethod
    def results(self, img: Mat, /, *args, **kwargs) -> list[OcrScenarioResult]:
        """
        Commonly a shorthand for `[self.result(comp) for comp in self.components(img)]`
        """
        ...

View File

@ -1,3 +0,0 @@
from .v4 import ChieriBotV4Best30Scenario
__all__ = ["ChieriBotV4Best30Scenario"]

View File

@ -1,3 +0,0 @@
from .impl import ChieriBotV4Best30Scenario
__all__ = ["ChieriBotV4Best30Scenario"]

View File

@ -1,42 +0,0 @@
from __future__ import annotations
from abc import ABC
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Sequence
if TYPE_CHECKING:
from datetime import datetime
from arcaea_offline_ocr.providers import ImageIdProviderResult
@dataclass(kw_only=True)
class OcrScenarioResult:
    """Everything an OCR scenario could read from one play result.

    Only `song_id`, `rating_class` and `score` are mandatory; the rest are
    None (or empty) when the scenario could not extract them.
    """

    song_id: str
    rating_class: int
    score: int

    # candidate matches that produced song_id / the partner id
    song_id_results: Sequence[ImageIdProviderResult] = field(default_factory=list)
    partner_id_results: Sequence[ImageIdProviderResult] = field(
        default_factory=list,
    )

    # judgement counts; *_inaccurate/early/late are finer-grained splits
    pure: int | None = None
    pure_inaccurate: int | None = None
    pure_early: int | None = None
    pure_late: int | None = None
    far: int | None = None
    far_inaccurate: int | None = None
    far_early: int | None = None
    far_late: int | None = None
    lost: int | None = None

    played_at: datetime | None = None
    max_recall: int | None = None
    clear_status: int | None = None
    clear_type: int | None = None
    modifier: int | None = None
# Marker base class for all OCR scenarios; it deliberately declares no
# abstract methods (hence the B024 suppression) — subclasses define their
# own interfaces.
class OcrScenario(ABC):  # noqa: B024
    pass

View File

@ -1,13 +0,0 @@
from .extractor import DeviceRoisExtractor
from .impl import DeviceScenario
from .masker import DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .rois import DeviceRoisAutoT1, DeviceRoisAutoT2
__all__ = [
"DeviceRoisAutoT1",
"DeviceRoisAutoT2",
"DeviceRoisExtractor",
"DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2",
"DeviceScenario",
]

View File

@ -1,8 +0,0 @@
from abc import abstractmethod
from arcaea_offline_ocr.scenarios.base import OcrScenario, OcrScenarioResult
class DeviceScenarioBase(OcrScenario):
    """Scenario that reads a single play-result screenshot from a device."""

    # the fully-parsed play result for the screenshot this scenario wraps
    @abstractmethod
    def result(self) -> OcrScenarioResult: ...

View File

@ -1,3 +0,0 @@
from .base import DeviceRoisExtractor
__all__ = ["DeviceRoisExtractor"]

View File

@ -1,45 +0,0 @@
from arcaea_offline_ocr.crop import crop_xywh
from arcaea_offline_ocr.scenarios.device.rois import DeviceRois
from arcaea_offline_ocr.types import Mat
class DeviceRoisExtractor:
    """Crops each region of interest out of a device screenshot.

    Args:
        img: full screenshot (BGR).
        rois: ROI definition supplying one `XYWHRect` per region.

    Each property returns the cropped sub-image for that region.
    """

    def __init__(self, img: Mat, rois: DeviceRois):
        self.img = img
        self.sizes = rois

    def _crop(self, rect):
        # Shared implementation for every property: round the rect to ints,
        # then crop — previously duplicated nine times.
        return crop_xywh(self.img, rect.rounded())

    @property
    def pure(self):
        return self._crop(self.sizes.pure)

    @property
    def far(self):
        return self._crop(self.sizes.far)

    @property
    def lost(self):
        return self._crop(self.sizes.lost)

    @property
    def score(self):
        return self._crop(self.sizes.score)

    @property
    def jacket(self):
        return self._crop(self.sizes.jacket)

    @property
    def rating_class(self):
        return self._crop(self.sizes.rating_class)

    @property
    def max_recall(self):
        return self._crop(self.sizes.max_recall)

    @property
    def clear_status(self):
        return self._crop(self.sizes.clear_status)

    @property
    def partner_icon(self):
        return self._crop(self.sizes.partner_icon)

View File

@ -1,9 +0,0 @@
from .auto import DeviceRoisMaskerAuto, DeviceRoisMaskerAutoT1, DeviceRoisMaskerAutoT2
from .base import DeviceRoisMasker
__all__ = [
"DeviceRoisMasker",
"DeviceRoisMaskerAuto",
"DeviceRoisMaskerAutoT1",
"DeviceRoisMaskerAutoT2",
]

View File

@ -1,61 +0,0 @@
from abc import ABC, abstractmethod
from arcaea_offline_ocr.types import Mat
class DeviceRoisMasker(ABC):
    """Produces binary masks isolating the relevant pixels of each ROI.

    Every classmethod takes a BGR crop of the corresponding region and
    returns a single-channel mask (`Mat`) for downstream recognition.
    """

    # judgement counters
    @classmethod
    @abstractmethod
    def pure(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def far(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def lost(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def score(cls, roi_bgr: Mat) -> Mat: ...

    # one mask per difficulty color of the rating-class label
    @classmethod
    @abstractmethod
    def rating_class_pst(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_prs(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_ftr(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_byd(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def rating_class_etr(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def max_recall(cls, roi_bgr: Mat) -> Mat: ...

    # one mask per possible clear-status banner
    @classmethod
    @abstractmethod
    def clear_status_track_lost(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def clear_status_track_complete(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def clear_status_full_recall(cls, roi_bgr: Mat) -> Mat: ...

    @classmethod
    @abstractmethod
    def clear_status_pure_memory(cls, roi_bgr: Mat) -> Mat: ...

View File

@ -1,9 +0,0 @@
from .auto import DeviceRoisAuto, DeviceRoisAutoT1, DeviceRoisAutoT2
from .base import DeviceRois
__all__ = [
"DeviceRois",
"DeviceRoisAuto",
"DeviceRoisAutoT1",
"DeviceRoisAutoT2",
]

View File

@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from arcaea_offline_ocr.types import XYWHRect
class DeviceRois(ABC):
    """Defines where each region of interest sits on a device screenshot.

    Each property returns the `XYWHRect` (x, y, width, height) of one
    region; concrete subclasses compute these from the screenshot geometry.
    """

    @property
    @abstractmethod
    def pure(self) -> XYWHRect: ...

    @property
    @abstractmethod
    def far(self) -> XYWHRect: ...

    @property
    @abstractmethod
    def lost(self) -> XYWHRect: ...

    @property
    @abstractmethod
    def score(self) -> XYWHRect: ...

    @property
    @abstractmethod
    def rating_class(self) -> XYWHRect: ...

    @property
    @abstractmethod
    def max_recall(self) -> XYWHRect: ...

    @property
    @abstractmethod
    def jacket(self) -> XYWHRect: ...

    @property
    @abstractmethod
    def clear_status(self) -> XYWHRect: ...

    @property
    @abstractmethod
    def partner_icon(self) -> XYWHRect: ...

View File

@ -1,42 +1,25 @@
from math import floor from collections.abc import Iterable
from typing import Callable, NamedTuple, Union from typing import NamedTuple, Tuple, Union
import numpy as np import numpy as np
Mat = np.ndarray Mat = np.ndarray
_IntOrFloat = Union[int, float]
class XYWHRect(NamedTuple): class XYWHRect(NamedTuple):
x: _IntOrFloat x: int
y: _IntOrFloat y: int
w: _IntOrFloat w: int
h: _IntOrFloat h: int
def _to_int(self, func: Callable[[_IntOrFloat], int]): def __add__(self, other: Union["XYWHRect", Tuple[int, int, int, int]]):
return (func(self.x), func(self.y), func(self.w), func(self.h)) if not isinstance(other, Iterable) or len(other) != 4:
raise ValueError()
def rounded(self):
return self._to_int(round)
def floored(self):
return self._to_int(floor)
def __add__(self, other):
if not isinstance(other, (list, tuple)) or len(other) != 4:
raise TypeError
return self.__class__(*[a + b for a, b in zip(self, other)]) return self.__class__(*[a + b for a, b in zip(self, other)])
def __sub__(self, other): def __sub__(self, other: Union["XYWHRect", Tuple[int, int, int, int]]):
if not isinstance(other, (list, tuple)) or len(other) != 4: if not isinstance(other, Iterable) or len(other) != 4:
raise TypeError raise ValueError()
return self.__class__(*[a - b for a, b in zip(self, other)]) return self.__class__(*[a - b for a, b in zip(self, other)])
def __mul__(self, other):
if not isinstance(other, (int, float)):
raise TypeError
return self.__class__(*[v * other for v in self])

View File

@ -1,6 +1,11 @@
from collections.abc import Iterable
from typing import Callable, TypeVar, Union, overload
import cv2 import cv2
import numpy as np import numpy as np
from .types import XYWHRect
__all__ = ["imread_unicode"] __all__ = ["imread_unicode"]
@ -8,3 +13,34 @@ def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED):
# https://stackoverflow.com/a/57872297/16484891 # https://stackoverflow.com/a/57872297/16484891
# CC BY-SA 4.0 # CC BY-SA 4.0
return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags) return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags)
def construct_int_xywh_rect(
    rect: XYWHRect, func: Callable[[Union[int, float]], int] = round
):
    """Build a new XYWHRect with every component converted to int via `func`."""
    return XYWHRect(*(func(component) for component in rect))
@overload
def apply_factor(item: int, factor: float) -> float:
    ...
@overload
def apply_factor(item: float, factor: float) -> float:
    ...
T = TypeVar("T", bound=Iterable)
@overload
def apply_factor(item: T, factor: float) -> T:
    ...
def apply_factor(item, factor: float):
    """Scale a scalar, or every element of an iterable, by `factor`.

    Iterables are rebuilt via their own class — assumes the class accepts a
    list as its single constructor argument (true for list/tuple/NamedTuple).
    NOTE(review): inputs that are neither scalar nor iterable fall through;
    confirm whether a trailing branch follows in the full file.
    """
    if isinstance(item, (int, float)):
        return item * factor
    if isinstance(item, Iterable):
        return item.__class__([i * factor for i in item])