14 Commits

23 changed files with 521 additions and 93 deletions

View File

@ -0,0 +1,48 @@
name: "Build and draft a release"
on:
workflow_dispatch:
push:
tags:
- "v[0-9]+.[0-9]+.[0-9]+"
permissions:
contents: write
discussions: write
jobs:
build-and-draft-release:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python environment
uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build package
run: |
pip install build
python -m build
- name: Remove `v` in tag name
uses: mad9000/actions-find-and-replace-string@5
id: tagNameReplaced
with:
source: ${{ github.ref_name }}
find: "v"
replace: ""
- name: Draft a release
uses: softprops/action-gh-release@v2
with:
discussion_category_name: New releases
draft: true
generate_release_notes: true
files: |
dist/arcaea_offline_ocr-${{ steps.tagNameReplaced.outputs.value }}*.whl
dist/arcaea-offline-ocr-${{ steps.tagNameReplaced.outputs.value }}.tar.gz

View File

@ -4,11 +4,10 @@ repos:
hooks: hooks:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.1.0 - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.13
hooks: hooks:
- id: black - id: ruff
- repo: https://github.com/PyCQA/isort args: ["--fix"]
rev: 5.12.0 - id: ruff-format
hooks:
- id: isort

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "arcaea-offline-ocr" name = "arcaea-offline-ocr"
version = "0.0.97" version = "0.0.99"
authors = [{ name = "283375", email = "log_283375@163.com" }] authors = [{ name = "283375", email = "log_283375@163.com" }]
description = "Extract your Arcaea play result from screenshot." description = "Extract your Arcaea play result from screenshot."
readme = "README.md" readme = "README.md"
@ -16,8 +16,8 @@ classifiers = [
] ]
[project.urls] [project.urls]
"Homepage" = "https://github.com/283375/arcaea-offline-ocr" "Homepage" = "https://github.com/ArcaeaOffline/core-ocr"
"Bug Tracker" = "https://github.com/283375/arcaea-offline-ocr/issues" "Bug Tracker" = "https://github.com/ArcaeaOffline/core-ocr/issues"
[tool.isort] [tool.isort]
profile = "black" profile = "black"
@ -25,3 +25,14 @@ src_paths = ["src/arcaea_offline_ocr"]
[tool.pyright] [tool.pyright]
ignore = ["**/__debug*.*"] ignore = ["**/__debug*.*"]
[tool.pylint.main]
# extension-pkg-allow-list = ["cv2"]
generated-members = ["cv2.*"]
[tool.pylint.logging]
disable = [
"missing-module-docstring",
"missing-class-docstring",
"missing-function-docstring"
]

View File

@ -1,3 +1,2 @@
attrs==23.1.0 numpy~=2.3
numpy==1.26.1 opencv-python~=4.11
opencv-python==4.8.1.78

View File

@ -1,4 +1,3 @@
from math import floor
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import cv2 import cv2
@ -13,9 +12,21 @@ from ....ocr import (
) )
from ....phash_db import ImagePhashDatabase from ....phash_db import ImagePhashDatabase
from ....types import Mat from ....types import Mat
from ....utils import construct_int_xywh_rect
from ...shared import B30OcrResultItem from ...shared import B30OcrResultItem
from .colors import * from .colors import (
BYD_MAX_HSV,
BYD_MIN_HSV,
FAR_BG_MAX_HSV,
FAR_BG_MIN_HSV,
FTR_MAX_HSV,
FTR_MIN_HSV,
LOST_BG_MAX_HSV,
LOST_BG_MIN_HSV,
PRS_MAX_HSV,
PRS_MIN_HSV,
PURE_BG_MAX_HSV,
PURE_BG_MIN_HSV,
)
from .rois import ChieriBotV4Rois from .rois import ChieriBotV4Rois
@ -25,7 +36,7 @@ class ChieriBotV4Ocr:
score_knn: cv2.ml.KNearest, score_knn: cv2.ml.KNearest,
pfl_knn: cv2.ml.KNearest, pfl_knn: cv2.ml.KNearest,
phash_db: ImagePhashDatabase, phash_db: ImagePhashDatabase,
factor: Optional[float] = 1.0, factor: float = 1.0,
): ):
self.__score_knn = score_knn self.__score_knn = score_knn
self.__pfl_knn = pfl_knn self.__pfl_knn = pfl_knn
@ -72,9 +83,8 @@ class ChieriBotV4Ocr:
self.factor = img.shape[0] / 4400 self.factor = img.shape[0] / 4400
def ocr_component_rating_class(self, component_bgr: Mat) -> int: def ocr_component_rating_class(self, component_bgr: Mat) -> int:
rating_class_rect = construct_int_xywh_rect( rating_class_rect = self.rois.component_rois.rating_class_rect.rounded()
self.rois.component_rois.rating_class_rect
)
rating_class_roi = crop_xywh(component_bgr, rating_class_rect) rating_class_roi = crop_xywh(component_bgr, rating_class_rect)
rating_class_roi = cv2.cvtColor(rating_class_roi, cv2.COLOR_BGR2HSV) rating_class_roi = cv2.cvtColor(rating_class_roi, cv2.COLOR_BGR2HSV)
rating_class_masks = [ rating_class_masks = [
@ -89,9 +99,7 @@ class ChieriBotV4Ocr:
return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1
def ocr_component_song_id(self, component_bgr: Mat): def ocr_component_song_id(self, component_bgr: Mat):
jacket_rect = construct_int_xywh_rect( jacket_rect = self.rois.component_rois.jacket_rect.floored()
self.rois.component_rois.jacket_rect, floor
)
jacket_roi = cv2.cvtColor( jacket_roi = cv2.cvtColor(
crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY
) )
@ -99,7 +107,7 @@ class ChieriBotV4Ocr:
def ocr_component_score_knn(self, component_bgr: Mat) -> int: def ocr_component_score_knn(self, component_bgr: Mat) -> int:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) score_rect = self.rois.component_rois.score_rect.rounded()
score_roi = cv2.cvtColor( score_roi = cv2.cvtColor(
crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY
) )
@ -119,7 +127,9 @@ class ChieriBotV4Ocr:
score_roi = cv2.fillPoly(score_roi, [contour], 0) score_roi = cv2.fillPoly(score_roi, [contour], 0)
return ocr_digits_by_contour_knn(score_roi, self.score_knn) return ocr_digits_by_contour_knn(score_roi, self.score_knn)
def find_pfl_rects(self, component_pfl_processed: Mat) -> List[List[int]]: def find_pfl_rects(
self, component_pfl_processed: Mat
) -> List[Tuple[int, int, int, int]]:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
pfl_roi_find = cv2.morphologyEx( pfl_roi_find = cv2.morphologyEx(
component_pfl_processed, component_pfl_processed,
@ -146,7 +156,7 @@ class ChieriBotV4Ocr:
return pfl_rects_adjusted return pfl_rects_adjusted
def preprocess_component_pfl(self, component_bgr: Mat) -> Mat: def preprocess_component_pfl(self, component_bgr: Mat) -> Mat:
pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect) pfl_rect = self.rois.component_rois.pfl_rect.rounded()
pfl_roi = crop_xywh(component_bgr, pfl_rect) pfl_roi = crop_xywh(component_bgr, pfl_rect)
pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV) pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV)

View File

@ -1,12 +1,12 @@
from typing import List, Optional from typing import List
from ....crop import crop_xywh from ....crop import crop_xywh
from ....types import Mat, XYWHRect from ....types import Mat, XYWHRect
from ....utils import apply_factor, construct_int_xywh_rect from ....utils import apply_factor
class ChieriBotV4ComponentRois: class ChieriBotV4ComponentRois:
def __init__(self, factor: Optional[float] = 1.0): def __init__(self, factor: float = 1.0):
self.__factor = factor self.__factor = factor
@property @property
@ -19,11 +19,11 @@ class ChieriBotV4ComponentRois:
@property @property
def top_font_color_detect(self): def top_font_color_detect(self):
return apply_factor((35, 10, 120, 100), self.factor) return apply_factor(XYWHRect(35, 10, 120, 100), self.factor)
@property @property
def bottom_font_color_detect(self): def bottom_font_color_detect(self):
return apply_factor((30, 125, 175, 110), self.factor) return apply_factor(XYWHRect(30, 125, 175, 110), self.factor)
@property @property
def bg_point(self): def bg_point(self):
@ -31,31 +31,31 @@ class ChieriBotV4ComponentRois:
@property @property
def rating_class_rect(self): def rating_class_rect(self):
return apply_factor((21, 40, 7, 20), self.factor) return apply_factor(XYWHRect(21, 40, 7, 20), self.factor)
@property @property
def title_rect(self): def title_rect(self):
return apply_factor((35, 10, 430, 50), self.factor) return apply_factor(XYWHRect(35, 10, 430, 50), self.factor)
@property @property
def jacket_rect(self): def jacket_rect(self):
return apply_factor((263, 0, 239, 239), self.factor) return apply_factor(XYWHRect(263, 0, 239, 239), self.factor)
@property @property
def score_rect(self): def score_rect(self):
return apply_factor((30, 60, 270, 55), self.factor) return apply_factor(XYWHRect(30, 60, 270, 55), self.factor)
@property @property
def pfl_rect(self): def pfl_rect(self):
return apply_factor((50, 125, 80, 100), self.factor) return apply_factor(XYWHRect(50, 125, 80, 100), self.factor)
@property @property
def date_rect(self): def date_rect(self):
return apply_factor((205, 200, 225, 25), self.factor) return apply_factor(XYWHRect(205, 200, 225, 25), self.factor)
class ChieriBotV4Rois: class ChieriBotV4Rois:
def __init__(self, factor: Optional[float] = 1.0): def __init__(self, factor: float = 1.0):
self.__factor = factor self.__factor = factor
self.__component_rois = ChieriBotV4ComponentRois(factor) self.__component_rois = ChieriBotV4ComponentRois(factor)
@ -100,9 +100,7 @@ class ChieriBotV4Rois:
def horizontal_items(self): def horizontal_items(self):
return 3 return 3
@property vertical_items = 10
def vertical_items(self):
return 10
@property @property
def b33_vertical_gap(self): def b33_vertical_gap(self):
@ -112,16 +110,17 @@ class ChieriBotV4Rois:
first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height) first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height)
results = [] results = []
last_rect = first_rect
for vi in range(self.vertical_items): for vi in range(self.vertical_items):
rect = XYWHRect(*first_rect) rect = XYWHRect(*first_rect)
rect += (0, (self.vertical_gap + self.height) * vi, 0, 0) rect += (0, (self.vertical_gap + self.height) * vi, 0, 0)
for hi in range(self.horizontal_items): for hi in range(self.horizontal_items):
if hi > 0: if hi > 0:
rect += ((self.width + self.horizontal_gap), 0, 0, 0) rect += ((self.width + self.horizontal_gap), 0, 0, 0)
int_rect = construct_int_xywh_rect(rect) results.append(crop_xywh(img_bgr, rect.rounded()))
results.append(crop_xywh(img_bgr, int_rect)) last_rect = rect
rect += ( last_rect += (
-(self.width + self.horizontal_gap) * 2, -(self.width + self.horizontal_gap) * 2,
self.height + self.b33_vertical_gap, self.height + self.b33_vertical_gap,
0, 0,
@ -129,8 +128,7 @@ class ChieriBotV4Rois:
) )
for hi in range(self.horizontal_items): for hi in range(self.horizontal_items):
if hi > 0: if hi > 0:
rect += ((self.width + self.horizontal_gap), 0, 0, 0) last_rect += ((self.width + self.horizontal_gap), 0, 0, 0)
int_rect = construct_int_xywh_rect(rect) results.append(crop_xywh(img_bgr, last_rect.rounded()))
results.append(crop_xywh(img_bgr, int_rect))
return results return results

View File

@ -1,10 +1,9 @@
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Optional, Union from typing import Optional
import attrs
@attrs.define @dataclass
class B30OcrResultItem: class B30OcrResultItem:
rating_class: int rating_class: int
score: int score: int

View File

View File

@ -0,0 +1,3 @@
from .index import average, dct, difference
__all__ = ["average", "dct", "difference"]

View File

@ -0,0 +1,7 @@
import cv2
from arcaea_offline_ocr.types import Mat
def _resize_image(src: Mat, dsize: ...) -> Mat:
return cv2.resize(src, dsize, fx=0, fy=0, interpolation=cv2.INTER_AREA)

View File

@ -0,0 +1,35 @@
import cv2
import numpy as np
from arcaea_offline_ocr.types import Mat
from ._common import _resize_image
def average(img_gray: Mat, hash_size: int) -> Mat:
img_resized = _resize_image(img_gray, (hash_size, hash_size))
diff = img_resized > img_resized.mean()
return diff.flatten()
def difference(img_gray: Mat, hash_size: int) -> Mat:
img_size = (hash_size + 1, hash_size)
img_resized = _resize_image(img_gray, img_size)
previous = img_resized[:, :-1]
current = img_resized[:, 1:]
diff = previous > current
return diff.flatten()
def dct(img_gray: Mat, hash_size: int = 16, high_freq_factor: int = 4) -> Mat:
# TODO: consistency?
img_size_base = hash_size * high_freq_factor
img_size = (img_size_base, img_size_base)
img_resized = _resize_image(img_gray, img_size)
img_resized = img_resized.astype(np.float32)
dct_mat = cv2.dct(img_resized)
hash_mat = dct_mat[:hash_size, :hash_size]
return hash_mat > hash_mat.mean()

View File

@ -0,0 +1,18 @@
from .builder import ImageHashesDatabaseBuilder
from .index import ImageHashesDatabase, ImageHashesDatabasePropertyMissingError
from .models import (
ImageHashBuildTask,
ImageHashHashType,
ImageHashResult,
ImageHashCategory,
)
__all__ = [
"ImageHashesDatabase",
"ImageHashesDatabasePropertyMissingError",
"ImageHashHashType",
"ImageHashResult",
"ImageHashCategory",
"ImageHashesDatabaseBuilder",
"ImageHashBuildTask",
]

View File

@ -0,0 +1,85 @@
import logging
from datetime import datetime, timezone
from sqlite3 import Connection
from typing import List
from arcaea_offline_ocr.core import hashers
from .index import ImageHashesDatabase
from .models import ImageHash, ImageHashBuildTask, ImageHashHashType
logger = logging.getLogger(__name__)
class ImageHashesDatabaseBuilder:
@staticmethod
def __insert_property(conn: Connection, key: str, value: str):
return conn.execute(
"INSERT INTO properties (key, value) VALUES (?, ?)",
(key, value),
)
@classmethod
def build(
cls,
conn: Connection,
tasks: List[ImageHashBuildTask],
*,
hash_size: int = 16,
high_freq_factor: int = 4,
):
rows: List[ImageHash] = []
for task in tasks:
try:
img_gray = task.imread_function(task.image_path)
for hash_type, hash_mat in [
(
ImageHashHashType.AVERAGE,
hashers.average(img_gray, hash_size),
),
(
ImageHashHashType.DCT,
hashers.dct(img_gray, hash_size, high_freq_factor),
),
(
ImageHashHashType.DIFFERENCE,
hashers.difference(img_gray, hash_size),
),
]:
rows.append(
ImageHash(
hash_type=hash_type,
category=task.category,
label=task.label,
hash=ImageHashesDatabase.hash_mat_to_bytes(hash_mat),
)
)
except Exception:
logger.exception("Error processing task %r", task)
conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
conn.execute(
"CREATE TABLE hashes (`hash_type` INTEGER, `category` INTEGER, `label` VARCHAR, `hash` BLOB)"
)
now = datetime.now(tz=timezone.utc)
timestamp = int(now.timestamp() * 1000)
cls.__insert_property(conn, ImageHashesDatabase.KEY_HASH_SIZE, str(hash_size))
cls.__insert_property(
conn, ImageHashesDatabase.KEY_HIGH_FREQ_FACTOR, str(high_freq_factor)
)
cls.__insert_property(
conn, ImageHashesDatabase.KEY_BUILT_TIMESTAMP, str(timestamp)
)
conn.executemany(
"INSERT INTO hashes (hash_type, category, label, hash) VALUES (?, ?, ?, ?)",
[
(row.hash_type.value, row.category.value, row.label, row.hash)
for row in rows
],
)
conn.commit()

View File

@ -0,0 +1,144 @@
import sqlite3
from datetime import datetime, timezone
from typing import Any, Callable, List, Optional, TypeVar
from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.types import Mat
from .models import ImageHashHashType, ImageHashResult, ImageHashCategory
T = TypeVar("T")
def _sql_hamming_distance(hash1: bytes, hash2: bytes):
assert len(hash1) == len(hash2), "hash size does not match!"
count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
return count
class ImageHashesDatabasePropertyMissingError(Exception):
pass
class ImageHashesDatabase:
KEY_HASH_SIZE = "hash_size"
KEY_HIGH_FREQ_FACTOR = "high_freq_factor"
KEY_BUILT_TIMESTAMP = "built_timestamp"
def __init__(self, conn: sqlite3.Connection):
self.conn = conn
self.conn.create_function("HAMMING_DISTANCE", 2, _sql_hamming_distance)
self._hash_size: int = -1
self._high_freq_factor: int = -1
self._built_time: Optional[datetime] = None
self._hashes_count = {
ImageHashCategory.JACKET: 0,
ImageHashCategory.PARTNER_ICON: 0,
}
self._hash_length: int = -1
self._initialize()
@property
def hash_size(self):
return self._hash_size
@property
def high_freq_factor(self):
return self._high_freq_factor
@property
def hash_length(self):
return self._hash_length
def _initialize(self):
def query_property(key, convert_func: Callable[[Any], T]) -> Optional[T]:
result = self.conn.execute(
"SELECT value FROM properties WHERE key = ?",
(key,),
).fetchone()
return convert_func(result[0]) if result is not None else None
def set_hashes_count(category: ImageHashCategory):
self._hashes_count[category] = self.conn.execute(
"SELECT COUNT(DISTINCT label) FROM hashes WHERE category = ?",
(category.value,),
).fetchone()[0]
hash_size = query_property(self.KEY_HASH_SIZE, lambda x: int(x))
if hash_size is None:
raise ImageHashesDatabasePropertyMissingError("hash_size")
self._hash_size = hash_size
high_freq_factor = query_property(self.KEY_HIGH_FREQ_FACTOR, lambda x: int(x))
if high_freq_factor is None:
raise ImageHashesDatabasePropertyMissingError("high_freq_factor")
self._high_freq_factor = high_freq_factor
self._built_time = query_property(
self.KEY_BUILT_TIMESTAMP,
lambda ts: datetime.fromtimestamp(int(ts) / 1000, tz=timezone.utc),
)
set_hashes_count(ImageHashCategory.JACKET)
set_hashes_count(ImageHashCategory.PARTNER_ICON)
self._hash_length = self._hash_size**2
def lookup_hash(
self, category: ImageHashCategory, hash_type: ImageHashHashType, hash: bytes
) -> List[ImageHashResult]:
cursor = self.conn.execute(
"SELECT"
" label,"
" HAMMING_DISTANCE(hash, ?) AS distance"
" FROM hashes"
" WHERE category = ? AND hash_type = ?"
" ORDER BY distance ASC LIMIT 10",
(hash, category.value, hash_type.value),
)
results = []
for label, distance in cursor.fetchall():
results.append(
ImageHashResult(
hash_type=hash_type,
category=category,
label=label,
confidence=(self.hash_length - distance) / self.hash_length,
)
)
return results
@staticmethod
def hash_mat_to_bytes(hash: Mat) -> bytes:
return bytes([255 if b else 0 for b in hash.flatten()])
def identify_image(self, category: ImageHashCategory, img) -> List[ImageHashResult]:
results = []
ahash = hashers.average(img, self.hash_size)
dhash = hashers.difference(img, self.hash_size)
phash = hashers.dct(img, self.hash_size, self.high_freq_factor)
results.extend(
self.lookup_hash(
category, ImageHashHashType.AVERAGE, self.hash_mat_to_bytes(ahash)
)
)
results.extend(
self.lookup_hash(
category, ImageHashHashType.DIFFERENCE, self.hash_mat_to_bytes(dhash)
)
)
results.extend(
self.lookup_hash(
category, ImageHashHashType.DCT, self.hash_mat_to_bytes(phash)
)
)
return results

View File

@ -0,0 +1,46 @@
import dataclasses
from enum import IntEnum
from typing import Callable
import cv2
from arcaea_offline_ocr.types import Mat
class ImageHashHashType(IntEnum):
AVERAGE = 0
DIFFERENCE = 1
DCT = 2
class ImageHashCategory(IntEnum):
JACKET = 0
PARTNER_ICON = 1
@dataclasses.dataclass
class ImageHash:
hash_type: ImageHashHashType
category: ImageHashCategory
label: str
hash: bytes
@dataclasses.dataclass
class ImageHashResult:
hash_type: ImageHashHashType
category: ImageHashCategory
label: str
confidence: float
def _default_imread_gray(image_path: str):
return cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY)
@dataclasses.dataclass
class ImageHashBuildTask:
image_path: str
category: ImageHashCategory
label: str
imread_function: Callable[[str], Mat] = _default_imread_gray

View File

@ -1,9 +1,8 @@
from dataclasses import dataclass
from typing import Optional from typing import Optional
import attrs
@dataclass
@attrs.define
class DeviceOcrResult: class DeviceOcrResult:
rating_class: int rating_class: int
pure: int pure: int

View File

@ -67,8 +67,9 @@ class DeviceOcr:
roi = self.masker.score(self.extractor.score) roi = self.masker.score(self.extractor.score)
contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours: for contour in contours:
x, y, w, h = cv2.boundingRect(contour) if (
if h < roi.shape[0] * 0.6: cv2.boundingRect(contour)[3] < roi.shape[0] * 0.6
): # h < score_component_h * 0.6
roi = cv2.fillPoly(roi, [contour], [0]) roi = cv2.fillPoly(roi, [contour], [0])
return ocr_digits_by_contour_knn(roi, self.knn_model) return ocr_digits_by_contour_knn(roi, self.knn_model)
@ -79,6 +80,7 @@ class DeviceOcr:
self.masker.rating_class_prs(roi), self.masker.rating_class_prs(roi),
self.masker.rating_class_ftr(roi), self.masker.rating_class_ftr(roi),
self.masker.rating_class_byd(roi), self.masker.rating_class_byd(roi),
self.masker.rating_class_etr(roi),
] ]
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
@ -108,7 +110,7 @@ class DeviceOcr:
@staticmethod @staticmethod
def preprocess_char_icon(img_gray: Mat): def preprocess_char_icon(img_gray: Mat):
h, w = img_gray.shape[:2] h, w = img_gray.shape[:2]
img = cv2.copyMakeBorder(img_gray, w - h, 0, 0, 0, cv2.BORDER_REPLICATE) img = cv2.copyMakeBorder(img_gray, max(w - h, 0), 0, 0, 0, cv2.BORDER_REPLICATE)
h, w = img.shape[:2] h, w = img.shape[:2]
img = cv2.fillPoly( img = cv2.fillPoly(
img, img,

View File

@ -6,6 +6,8 @@ from .common import DeviceRoisMasker
class DeviceRoisMaskerAuto(DeviceRoisMasker): class DeviceRoisMaskerAuto(DeviceRoisMasker):
# pylint: disable=abstract-method
@staticmethod @staticmethod
def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat): def mask_bgr_in_hsv(roi_bgr: Mat, hsv_lower: Mat, hsv_upper: Mat):
return cv2.inRange( return cv2.inRange(
@ -32,6 +34,9 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto):
BYD_HSV_MIN = np.array([170, 50, 50], np.uint8) BYD_HSV_MIN = np.array([170, 50, 50], np.uint8)
BYD_HSV_MAX = np.array([179, 210, 198], np.uint8) BYD_HSV_MAX = np.array([179, 210, 198], np.uint8)
ETR_HSV_MIN = np.array([130, 60, 80], np.uint8)
ETR_HSV_MAX = np.array([140, 145, 180], np.uint8)
TRACK_LOST_HSV_MIN = np.array([170, 75, 90], np.uint8) TRACK_LOST_HSV_MIN = np.array([170, 75, 90], np.uint8)
TRACK_LOST_HSV_MAX = np.array([175, 170, 160], np.uint8) TRACK_LOST_HSV_MAX = np.array([175, 170, 160], np.uint8)
@ -85,6 +90,10 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto):
def rating_class_byd(cls, roi_bgr: Mat) -> Mat: def rating_class_byd(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv(roi_bgr, cls.BYD_HSV_MIN, cls.BYD_HSV_MAX) return cls.mask_bgr_in_hsv(roi_bgr, cls.BYD_HSV_MIN, cls.BYD_HSV_MAX)
@classmethod
def rating_class_etr(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv(roi_bgr, cls.ETR_HSV_MIN, cls.ETR_HSV_MAX)
@classmethod @classmethod
def max_recall(cls, roi_bgr: Mat) -> Mat: def max_recall(cls, roi_bgr: Mat) -> Mat:
return cls.gray(roi_bgr) return cls.gray(roi_bgr)
@ -116,7 +125,7 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto):
class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto):
PFL_HSV_MIN = np.array([0, 0, 248], np.uint8) PFL_HSV_MIN = np.array([0, 0, 248], np.uint8)
PFL_HSV_MAX = np.array([179, 10, 255], np.uint8) PFL_HSV_MAX = np.array([179, 40, 255], np.uint8)
SCORE_HSV_MIN = np.array([0, 0, 180], np.uint8) SCORE_HSV_MIN = np.array([0, 0, 180], np.uint8)
SCORE_HSV_MAX = np.array([179, 255, 255], np.uint8) SCORE_HSV_MAX = np.array([179, 255, 255], np.uint8)
@ -133,6 +142,9 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto):
BYD_HSV_MIN = np.array([170, 50, 50], np.uint8) BYD_HSV_MIN = np.array([170, 50, 50], np.uint8)
BYD_HSV_MAX = np.array([179, 210, 198], np.uint8) BYD_HSV_MAX = np.array([179, 210, 198], np.uint8)
ETR_HSV_MIN = np.array([130, 60, 80], np.uint8)
ETR_HSV_MAX = np.array([140, 145, 180], np.uint8)
MAX_RECALL_HSV_MIN = np.array([125, 0, 0], np.uint8) MAX_RECALL_HSV_MIN = np.array([125, 0, 0], np.uint8)
MAX_RECALL_HSV_MAX = np.array([145, 100, 150], np.uint8) MAX_RECALL_HSV_MAX = np.array([145, 100, 150], np.uint8)
@ -184,6 +196,10 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto):
def rating_class_byd(cls, roi_bgr: Mat) -> Mat: def rating_class_byd(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv(roi_bgr, cls.BYD_HSV_MIN, cls.BYD_HSV_MAX) return cls.mask_bgr_in_hsv(roi_bgr, cls.BYD_HSV_MIN, cls.BYD_HSV_MAX)
@classmethod
def rating_class_etr(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv(roi_bgr, cls.ETR_HSV_MIN, cls.ETR_HSV_MAX)
@classmethod @classmethod
def max_recall(cls, roi_bgr: Mat) -> Mat: def max_recall(cls, roi_bgr: Mat) -> Mat:
return cls.mask_bgr_in_hsv( return cls.mask_bgr_in_hsv(

View File

@ -34,6 +34,10 @@ class DeviceRoisMasker:
def rating_class_byd(cls, roi_bgr: Mat) -> Mat: def rating_class_byd(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError() raise NotImplementedError()
@classmethod
def rating_class_etr(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError()
@classmethod @classmethod
def max_recall(cls, roi_bgr: Mat) -> Mat: def max_recall(cls, roi_bgr: Mat) -> Mat:
raise NotImplementedError() raise NotImplementedError()

View File

@ -36,7 +36,7 @@ class FixRects:
if rect in consumed_rects: if rect in consumed_rects:
continue continue
x, y, w, h = rect x, _, w, h = rect
# grab those small rects # grab those small rects
if not img_height * 0.1 <= h <= img_height * 0.6: if not img_height * 0.1 <= h <= img_height * 0.6:
continue continue
@ -46,7 +46,7 @@ class FixRects:
for other_rect in rects: for other_rect in rects:
if rect == other_rect: if rect == other_rect:
continue continue
ox, oy, ow, oh = other_rect ox, _, ow, _ = other_rect
if abs(x - ox) < tolerance and abs((x + w) - (ox + ow)) < tolerance: if abs(x - ox) < tolerance and abs((x + w) - (ox + ow)) < tolerance:
group.append(other_rect) group.append(other_rect)

View File

@ -12,7 +12,8 @@ def phash_opencv(img_gray, hash_size=8, highfreq_factor=4):
""" """
Perceptual Hash computation. Perceptual Hash computation.
Implementation follows http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html Implementation follows
http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
Adapted from `imagehash.phash`, pure opencv implementation Adapted from `imagehash.phash`, pure opencv implementation
@ -69,14 +70,14 @@ class ImagePhashDatabase:
self.partner_icon_ids: List[str] = [] self.partner_icon_ids: List[str] = []
self.partner_icon_hashes = [] self.partner_icon_hashes = []
for id, hash in zip(self.ids, self.hashes): for _id, _hash in zip(self.ids, self.hashes):
id_splitted = id.split("||") id_splitted = _id.split("||")
if len(id_splitted) > 1 and id_splitted[0] == "partner_icon": if len(id_splitted) > 1 and id_splitted[0] == "partner_icon":
self.partner_icon_ids.append(id_splitted[1]) self.partner_icon_ids.append(id_splitted[1])
self.partner_icon_hashes.append(hash) self.partner_icon_hashes.append(_hash)
else: else:
self.jacket_ids.append(id) self.jacket_ids.append(_id)
self.jacket_hashes.append(hash) self.jacket_hashes.append(_hash)
def calculate_phash(self, img_gray: Mat): def calculate_phash(self, img_gray: Mat):
return phash_opencv( return phash_opencv(

View File

@ -1,25 +1,36 @@
from collections.abc import Iterable from math import floor
from typing import NamedTuple, Tuple, Union from typing import Callable, NamedTuple, Union
import numpy as np import numpy as np
Mat = np.ndarray Mat = np.ndarray
_IntOrFloat = Union[int, float]
class XYWHRect(NamedTuple): class XYWHRect(NamedTuple):
x: int x: _IntOrFloat
y: int y: _IntOrFloat
w: int w: _IntOrFloat
h: int h: _IntOrFloat
def __add__(self, other: Union["XYWHRect", Tuple[int, int, int, int]]): def _to_int(self, func: Callable[[_IntOrFloat], int]):
if not isinstance(other, Iterable) or len(other) != 4: return (func(self.x), func(self.y), func(self.w), func(self.h))
raise ValueError()
def rounded(self):
return self._to_int(round)
def floored(self):
return self._to_int(floor)
def __add__(self, other):
if not isinstance(other, (list, tuple)) or len(other) != 4:
raise TypeError()
return self.__class__(*[a + b for a, b in zip(self, other)]) return self.__class__(*[a + b for a, b in zip(self, other)])
def __sub__(self, other: Union["XYWHRect", Tuple[int, int, int, int]]): def __sub__(self, other):
if not isinstance(other, Iterable) or len(other) != 4: if not isinstance(other, (list, tuple)) or len(other) != 4:
raise ValueError() raise TypeError()
return self.__class__(*[a - b for a, b in zip(self, other)]) return self.__class__(*[a - b for a, b in zip(self, other)])

View File

@ -1,5 +1,5 @@
from collections.abc import Iterable from collections.abc import Iterable
from typing import Callable, TypeVar, Union, overload from typing import TypeVar, overload
import cv2 import cv2
import numpy as np import numpy as np
@ -15,32 +15,25 @@ def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED):
return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags) return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags)
def construct_int_xywh_rect( @overload
rect: XYWHRect, func: Callable[[Union[int, float]], int] = round def apply_factor(item: int, factor: float) -> float: ...
):
return XYWHRect(*[func(num) for num in rect])
@overload @overload
def apply_factor(item: int, factor: float) -> float: def apply_factor(item: float, factor: float) -> float: ...
...
@overload
def apply_factor(item: float, factor: float) -> float:
...
T = TypeVar("T", bound=Iterable) T = TypeVar("T", bound=Iterable)
@overload @overload
def apply_factor(item: T, factor: float) -> T: def apply_factor(item: T, factor: float) -> T: ...
...
def apply_factor(item, factor: float): def apply_factor(item, factor: float):
if isinstance(item, (int, float)): if isinstance(item, (int, float)):
return item * factor return item * factor
elif isinstance(item, Iterable): if isinstance(item, XYWHRect):
return item.__class__(*[i * factor for i in item])
if isinstance(item, Iterable):
return item.__class__([i * factor for i in item]) return item.__class__([i * factor for i in item])