refactor!: ImageHashType -> ImageHashCategory

This commit is contained in:
2025-06-17 22:28:16 +08:00
parent 212afa32db
commit b545c5b6bf
4 changed files with 29 additions and 26 deletions

View File

@ -4,7 +4,7 @@ from .models import (
ImageHashBuildTask,
ImageHashHashType,
ImageHashResult,
ImageHashType,
ImageHashCategory,
)
__all__ = [
@ -12,7 +12,7 @@ __all__ = [
"ImageHashesDatabasePropertyMissingError",
"ImageHashHashType",
"ImageHashResult",
"ImageHashType",
"ImageHashCategory",
"ImageHashesDatabaseBuilder",
"ImageHashBuildTask",
]

View File

@ -51,7 +51,7 @@ class ImageHashesDatabaseBuilder:
rows.append(
ImageHash(
hash_type=hash_type,
type=task.type,
category=task.category,
label=task.label,
hash=ImageHashesDatabase.hash_mat_to_bytes(hash_mat),
)
@ -61,7 +61,7 @@ class ImageHashesDatabaseBuilder:
conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
conn.execute(
"CREATE TABLE hashes (`hash_type` INTEGER, `type` INTEGER, `label` VARCHAR, `hash` BLOB)"
"CREATE TABLE hashes (`hash_type` INTEGER, `category` INTEGER, `label` VARCHAR, `hash` BLOB)"
)
now = datetime.now(tz=timezone.utc)
@ -76,9 +76,9 @@ class ImageHashesDatabaseBuilder:
)
conn.executemany(
"INSERT INTO hashes (hash_type, type, label, hash) VALUES (?, ?, ?, ?)",
"INSERT INTO hashes (hash_type, category, label, hash) VALUES (?, ?, ?, ?)",
[
(row.hash_type.value, row.type.value, row.label, row.hash)
(row.hash_type.value, row.category.value, row.label, row.hash)
for row in rows
],
)

View File

@ -5,7 +5,7 @@ from typing import Any, Callable, List, Optional, TypeVar
from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.types import Mat
from .models import ImageHashHashType, ImageHashResult, ImageHashType
from .models import ImageHashHashType, ImageHashResult, ImageHashCategory
T = TypeVar("T")
@ -34,8 +34,8 @@ class ImageHashesDatabase:
self._built_time: Optional[datetime] = None
self._hashes_count = {
ImageHashType.JACKET: 0,
ImageHashType.PARTNER_ICON: 0,
ImageHashCategory.JACKET: 0,
ImageHashCategory.PARTNER_ICON: 0,
}
self._hash_length: int = -1
@ -62,9 +62,10 @@ class ImageHashesDatabase:
).fetchone()
return convert_func(result[0]) if result is not None else None
def set_hashes_count(type: ImageHashType):
self._hashes_count[type] = self.conn.execute(
"SELECT COUNT(DISTINCT label) FROM hashes WHERE type = ?", (type.value,)
def set_hashes_count(category: ImageHashCategory):
self._hashes_count[category] = self.conn.execute(
"SELECT COUNT(DISTINCT label) FROM hashes WHERE category = ?",
(category.value,),
).fetchone()[0]
hash_size = query_property(self.KEY_HASH_SIZE, lambda x: int(x))
@ -82,22 +83,22 @@ class ImageHashesDatabase:
lambda ts: datetime.fromtimestamp(int(ts) / 1000, tz=timezone.utc),
)
set_hashes_count(ImageHashType.JACKET)
set_hashes_count(ImageHashType.PARTNER_ICON)
set_hashes_count(ImageHashCategory.JACKET)
set_hashes_count(ImageHashCategory.PARTNER_ICON)
self._hash_length = self._hash_size**2
def lookup_hash(
self, type: ImageHashType, hash_type: ImageHashHashType, hash: bytes
self, category: ImageHashCategory, hash_type: ImageHashHashType, hash: bytes
) -> List[ImageHashResult]:
cursor = self.conn.execute(
"SELECT"
" label,"
" HAMMING_DISTANCE(hash, ?) AS distance"
" FROM hashes"
" WHERE type = ? AND hash_type = ?"
" WHERE category = ? AND hash_type = ?"
" ORDER BY distance ASC LIMIT 10",
(hash, type.value, hash_type.value),
(hash, category.value, hash_type.value),
)
results = []
@ -105,7 +106,7 @@ class ImageHashesDatabase:
results.append(
ImageHashResult(
hash_type=hash_type,
type=type,
category=category,
label=label,
confidence=(self.hash_length - distance) / self.hash_length,
)
@ -117,7 +118,7 @@ class ImageHashesDatabase:
def hash_mat_to_bytes(hash: Mat) -> bytes:
return bytes([255 if b else 0 for b in hash.flatten()])
def identify_image(self, type: ImageHashType, img) -> List[ImageHashResult]:
def identify_image(self, category: ImageHashCategory, img) -> List[ImageHashResult]:
results = []
ahash = hashers.average(img, self.hash_size)
@ -126,16 +127,18 @@ class ImageHashesDatabase:
results.extend(
self.lookup_hash(
type, ImageHashHashType.AVERAGE, self.hash_mat_to_bytes(ahash)
category, ImageHashHashType.AVERAGE, self.hash_mat_to_bytes(ahash)
)
)
results.extend(
self.lookup_hash(
type, ImageHashHashType.DIFFERENCE, self.hash_mat_to_bytes(dhash)
category, ImageHashHashType.DIFFERENCE, self.hash_mat_to_bytes(dhash)
)
)
results.extend(
self.lookup_hash(type, ImageHashHashType.DCT, self.hash_mat_to_bytes(phash))
self.lookup_hash(
category, ImageHashHashType.DCT, self.hash_mat_to_bytes(phash)
)
)
return results

View File

@ -13,7 +13,7 @@ class ImageHashHashType(IntEnum):
DCT = 2
class ImageHashType(IntEnum):
class ImageHashCategory(IntEnum):
JACKET = 0
PARTNER_ICON = 1
@ -21,7 +21,7 @@ class ImageHashType(IntEnum):
@dataclasses.dataclass
class ImageHash:
hash_type: ImageHashHashType
type: ImageHashType
category: ImageHashCategory
label: str
hash: bytes
@ -29,7 +29,7 @@ class ImageHash:
@dataclasses.dataclass
class ImageHashResult:
hash_type: ImageHashHashType
type: ImageHashType
category: ImageHashCategory
label: str
confidence: float
@ -41,6 +41,6 @@ def _default_imread_gray(image_path: str):
@dataclasses.dataclass
class ImageHashBuildTask:
image_path: str
type: ImageHashType
category: ImageHashCategory
label: str
imread_function: Callable[[str], Mat] = _default_imread_gray