chore: apply ruff rules

This commit is contained in:
2025-06-27 01:36:52 +08:00
parent 57f430770e
commit 673e45834d
22 changed files with 264 additions and 173 deletions

View File

@ -1,14 +1,17 @@
import sqlite3
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import IntEnum
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from arcaea_offline_ocr.core import hashers
from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult
if TYPE_CHECKING:
import sqlite3
from arcaea_offline_ocr.types import Mat
@ -19,9 +22,11 @@ PROP_KEY_BUILT_AT = "built_at"
def _sql_hamming_distance(hash1: bytes, hash2: bytes):
assert len(hash1) == len(hash2), "hash size does not match!"
count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
return count
if len(hash1) != len(hash2):
msg = "hash size does not match!"
raise ValueError(msg)
return sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
class ImageHashType(IntEnum):
@ -36,7 +41,7 @@ class ImageHashDatabaseIdProviderResult(ImageIdProviderResult):
class MissingPropertiesError(Exception):
keys: List[str]
keys: list[str]
def __init__(self, keys, *args):
super().__init__(*args)
@ -72,7 +77,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
return self.properties[PROP_KEY_HIGH_FREQ_FACTOR]
@property
def built_at(self) -> datetime | None:
    """Return the ``built_at`` property, or ``None`` when it is not set.

    The value is read from ``self.properties`` under ``PROP_KEY_BUILT_AT``;
    ``dict.get`` is used so a missing key yields ``None`` instead of raising.
    """
    return self.properties.get(PROP_KEY_BUILT_AT)
@property
@ -80,7 +85,7 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
return self._hash_length
def _initialize(self):
def get_property(key, converter: Callable[[Any], T]) -> Optional[T]:
def get_property(key, converter: Callable[[Any], T]) -> T | None:
result = self.conn.execute(
"SELECT value FROM properties WHERE key = ?",
(key,),
@ -97,7 +102,8 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
PROP_KEY_HASH_SIZE: lambda x: int(x),
PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x),
PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp(
int(ts) / 1000, tz=timezone.utc
int(ts) / 1000,
tz=timezone.utc,
),
}
required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR]
@ -122,8 +128,11 @@ class ImageHashDatabaseIdProvider(ImageIdProvider):
self._hash_length = self.hash_size**2
def lookup_hash(
self, category: ImageCategory, hash_type: ImageHashType, hash: bytes
) -> List[ImageHashDatabaseIdProviderResult]:
self,
category: ImageCategory,
hash_type: ImageHashType,
hash_data: bytes,
) -> list[ImageHashDatabaseIdProviderResult]:
cursor = self.conn.execute(
"""
SELECT
@ -132,7 +141,7 @@ SELECT
FROM hashes
WHERE category = ? AND hash_type = ?
ORDER BY distance ASC LIMIT 10""",
(hash, category.value, hash_type.value),
(hash_data, category.value, hash_type.value),
)
results = []
@ -143,52 +152,52 @@ ORDER BY distance ASC LIMIT 10""",
category=category,
confidence=(self.hash_length - distance) / self.hash_length,
image_hash_type=hash_type,
)
),
)
return results
@staticmethod
def hash_mat_to_bytes(hash: "Mat") -> bytes:
return bytes([255 if b else 0 for b in hash.flatten()])
def hash_mat_to_bytes(hash_mat: Mat) -> bytes:
return bytes([255 if b else 0 for b in hash_mat.flatten()])
def results(
    self,
    img: Mat,
    category: ImageCategory,
    /,
) -> list[ImageHashDatabaseIdProviderResult]:
    """Look up ``img`` under every supported hash type.

    The image is hashed with the average, difference and DCT hashers, in
    that order. Each hash is converted with ``hash_mat_to_bytes`` and passed
    to ``lookup_hash``; the matches are concatenated in the same order.

    Args:
        img: Image to identify.
        category: Category forwarded to every ``lookup_hash`` call.

    Returns:
        All matches returned by the three lookups, grouped by hash type.
    """
    results: list[ImageHashDatabaseIdProviderResult] = []

    results.extend(
        self.lookup_hash(
            category,
            ImageHashType.AVERAGE,
            self.hash_mat_to_bytes(hashers.average(img, self.hash_size)),
        ),
    )
    results.extend(
        self.lookup_hash(
            category,
            ImageHashType.DIFFERENCE,
            self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)),
        ),
    )
    results.extend(
        self.lookup_hash(
            category,
            ImageHashType.DCT,
            self.hash_mat_to_bytes(
                hashers.dct(img, self.hash_size, self.high_freq_factor),
            ),
        ),
    )

    return results
def result(
    self,
    img: Mat,
    category: ImageCategory,
    /,
    *,
    hash_type: ImageHashType = ImageHashType.DCT,
):
    """Return the first match for ``img`` produced by ``hash_type``.

    Args:
        img: Image to identify.
        category: Category forwarded to ``results``.
        hash_type: Only matches whose ``image_hash_type`` equals this value
            are considered.

    Returns:
        The first entry of ``self.results(img, category)`` with the
        requested hash type.

    Raises:
        IndexError: If no match with the requested hash type exists.
    """
    for candidate in self.results(img, category):
        if candidate.image_hash_type == hash_type:
            return candidate

    # A bare ``next(...)`` would leak ``StopIteration`` here, which silently
    # ends (or, per PEP 479, becomes ``RuntimeError`` inside) any generator
    # that calls this method. Raise ``IndexError`` instead, matching the
    # list-indexing form this code used before the ruff rewrite.
    msg = f"no result found for hash type {hash_type!r}"
    raise IndexError(msg)