mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-07-01 12:26:27 +00:00
chore: apply ruff rules
This commit is contained in:
@ -1,14 +1,17 @@
|
||||
import sqlite3
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from enum import IntEnum
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
|
||||
from arcaea_offline_ocr.core import hashers
|
||||
|
||||
from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sqlite3
|
||||
|
||||
from arcaea_offline_ocr.types import Mat
|
||||
|
||||
|
||||
@ -19,9 +22,11 @@ PROP_KEY_BUILT_AT = "built_at"
|
||||
|
||||
|
||||
def _sql_hamming_distance(hash1: bytes, hash2: bytes):
|
||||
assert len(hash1) == len(hash2), "hash size does not match!"
|
||||
count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
|
||||
return count
|
||||
if len(hash1) != len(hash2):
|
||||
msg = "hash size does not match!"
|
||||
raise ValueError(msg)
|
||||
|
||||
return sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
|
||||
|
||||
|
||||
class ImageHashType(IntEnum):
|
||||
@ -36,7 +41,7 @@ class ImageHashDatabaseIdProviderResult(ImageIdProviderResult):
|
||||
|
||||
|
||||
class MissingPropertiesError(Exception):
|
||||
keys: List[str]
|
||||
keys: list[str]
|
||||
|
||||
def __init__(self, keys, *args):
|
||||
super().__init__(*args)
|
||||
@ -72,7 +77,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
|
||||
return self.properties[PROP_KEY_HIGH_FREQ_FACTOR]
|
||||
|
||||
@property
|
||||
def built_at(self) -> Optional[datetime]:
|
||||
def built_at(self) -> datetime | None:
|
||||
return self.properties.get(PROP_KEY_BUILT_AT)
|
||||
|
||||
@property
|
||||
@ -80,7 +85,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
|
||||
return self._hash_length
|
||||
|
||||
def _initialize(self):
|
||||
def get_property(key, converter: Callable[[Any], T]) -> Optional[T]:
|
||||
def get_property(key, converter: Callable[[Any], T]) -> T | None:
|
||||
result = self.conn.execute(
|
||||
"SELECT value FROM properties WHERE key = ?",
|
||||
(key,),
|
||||
@ -97,7 +102,8 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
|
||||
PROP_KEY_HASH_SIZE: lambda x: int(x),
|
||||
PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x),
|
||||
PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp(
|
||||
int(ts) / 1000, tz=timezone.utc
|
||||
int(ts) / 1000,
|
||||
tz=timezone.utc,
|
||||
),
|
||||
}
|
||||
required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR]
|
||||
@ -122,8 +128,11 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
|
||||
self._hash_length = self.hash_size**2
|
||||
|
||||
def lookup_hash(
|
||||
self, category: ImageCategory, hash_type: ImageHashType, hash: bytes
|
||||
) -> List[ImageHashDatabaseIdProviderResult]:
|
||||
self,
|
||||
category: ImageCategory,
|
||||
hash_type: ImageHashType,
|
||||
hash_data: bytes,
|
||||
) -> list[ImageHashDatabaseIdProviderResult]:
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
@ -132,7 +141,7 @@ SELECT
|
||||
FROM hashes
|
||||
WHERE category = ? AND hash_type = ?
|
||||
ORDER BY distance ASC LIMIT 10""",
|
||||
(hash, category.value, hash_type.value),
|
||||
(hash_data, category.value, hash_type.value),
|
||||
)
|
||||
|
||||
results = []
|
||||
@ -143,52 +152,52 @@ ORDER BY distance ASC LIMIT 10""",
|
||||
category=category,
|
||||
confidence=(self.hash_length - distance) / self.hash_length,
|
||||
image_hash_type=hash_type,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def hash_mat_to_bytes(hash: "Mat") -> bytes:
|
||||
return bytes([255 if b else 0 for b in hash.flatten()])
|
||||
def hash_mat_to_bytes(hash_mat: Mat) -> bytes:
|
||||
return bytes([255 if b else 0 for b in hash_mat.flatten()])
|
||||
|
||||
def results(self, img: "Mat", category: ImageCategory, /):
|
||||
results: List[ImageHashDatabaseIdProviderResult] = []
|
||||
def results(self, img: Mat, category: ImageCategory, /):
|
||||
results: list[ImageHashDatabaseIdProviderResult] = []
|
||||
|
||||
results.extend(
|
||||
self.lookup_hash(
|
||||
category,
|
||||
ImageHashType.AVERAGE,
|
||||
self.hash_mat_to_bytes(hashers.average(img, self.hash_size)),
|
||||
)
|
||||
),
|
||||
)
|
||||
results.extend(
|
||||
self.lookup_hash(
|
||||
category,
|
||||
ImageHashType.DIFFERENCE,
|
||||
self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)),
|
||||
)
|
||||
),
|
||||
)
|
||||
results.extend(
|
||||
self.lookup_hash(
|
||||
category,
|
||||
ImageHashType.DCT,
|
||||
self.hash_mat_to_bytes(
|
||||
hashers.dct(img, self.hash_size, self.high_freq_factor)
|
||||
hashers.dct(img, self.hash_size, self.high_freq_factor),
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def result(
|
||||
self,
|
||||
img: "Mat",
|
||||
img: Mat,
|
||||
category: ImageCategory,
|
||||
/,
|
||||
*,
|
||||
hash_type: ImageHashType = ImageHashType.DCT,
|
||||
):
|
||||
return [
|
||||
return next(
|
||||
it for it in self.results(img, category) if it.image_hash_type == hash_type
|
||||
][0]
|
||||
)
|
||||
|
Reference in New Issue
Block a user