refactor!: ImageHashType -> ImageHashCategory

This commit is contained in:
2025-06-17 22:28:16 +08:00
parent 212afa32db
commit b545c5b6bf
4 changed files with 29 additions and 26 deletions

View File

@ -4,7 +4,7 @@ from .models import (
ImageHashBuildTask, ImageHashBuildTask,
ImageHashHashType, ImageHashHashType,
ImageHashResult, ImageHashResult,
ImageHashType, ImageHashCategory,
) )
__all__ = [ __all__ = [
@ -12,7 +12,7 @@ __all__ = [
"ImageHashesDatabasePropertyMissingError", "ImageHashesDatabasePropertyMissingError",
"ImageHashHashType", "ImageHashHashType",
"ImageHashResult", "ImageHashResult",
"ImageHashType", "ImageHashCategory",
"ImageHashesDatabaseBuilder", "ImageHashesDatabaseBuilder",
"ImageHashBuildTask", "ImageHashBuildTask",
] ]

View File

@ -51,7 +51,7 @@ class ImageHashesDatabaseBuilder:
rows.append( rows.append(
ImageHash( ImageHash(
hash_type=hash_type, hash_type=hash_type,
type=task.type, category=task.category,
label=task.label, label=task.label,
hash=ImageHashesDatabase.hash_mat_to_bytes(hash_mat), hash=ImageHashesDatabase.hash_mat_to_bytes(hash_mat),
) )
@ -61,7 +61,7 @@ class ImageHashesDatabaseBuilder:
conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)") conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
conn.execute( conn.execute(
"CREATE TABLE hashes (`hash_type` INTEGER, `type` INTEGER, `label` VARCHAR, `hash` BLOB)" "CREATE TABLE hashes (`hash_type` INTEGER, `category` INTEGER, `label` VARCHAR, `hash` BLOB)"
) )
now = datetime.now(tz=timezone.utc) now = datetime.now(tz=timezone.utc)
@ -76,9 +76,9 @@ class ImageHashesDatabaseBuilder:
) )
conn.executemany( conn.executemany(
"INSERT INTO hashes (hash_type, type, label, hash) VALUES (?, ?, ?, ?)", "INSERT INTO hashes (hash_type, category, label, hash) VALUES (?, ?, ?, ?)",
[ [
(row.hash_type.value, row.type.value, row.label, row.hash) (row.hash_type.value, row.category.value, row.label, row.hash)
for row in rows for row in rows
], ],
) )

View File

@ -5,7 +5,7 @@ from typing import Any, Callable, List, Optional, TypeVar
from arcaea_offline_ocr.core import hashers from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.types import Mat from arcaea_offline_ocr.types import Mat
from .models import ImageHashHashType, ImageHashResult, ImageHashType from .models import ImageHashHashType, ImageHashResult, ImageHashCategory
T = TypeVar("T") T = TypeVar("T")
@ -34,8 +34,8 @@ class ImageHashesDatabase:
self._built_time: Optional[datetime] = None self._built_time: Optional[datetime] = None
self._hashes_count = { self._hashes_count = {
ImageHashType.JACKET: 0, ImageHashCategory.JACKET: 0,
ImageHashType.PARTNER_ICON: 0, ImageHashCategory.PARTNER_ICON: 0,
} }
self._hash_length: int = -1 self._hash_length: int = -1
@ -62,9 +62,10 @@ class ImageHashesDatabase:
).fetchone() ).fetchone()
return convert_func(result[0]) if result is not None else None return convert_func(result[0]) if result is not None else None
def set_hashes_count(type: ImageHashType): def set_hashes_count(category: ImageHashCategory):
self._hashes_count[type] = self.conn.execute( self._hashes_count[category] = self.conn.execute(
"SELECT COUNT(DISTINCT label) FROM hashes WHERE type = ?", (type.value,) "SELECT COUNT(DISTINCT label) FROM hashes WHERE category = ?",
(category.value,),
).fetchone()[0] ).fetchone()[0]
hash_size = query_property(self.KEY_HASH_SIZE, lambda x: int(x)) hash_size = query_property(self.KEY_HASH_SIZE, lambda x: int(x))
@ -82,22 +83,22 @@ class ImageHashesDatabase:
lambda ts: datetime.fromtimestamp(int(ts) / 1000, tz=timezone.utc), lambda ts: datetime.fromtimestamp(int(ts) / 1000, tz=timezone.utc),
) )
set_hashes_count(ImageHashType.JACKET) set_hashes_count(ImageHashCategory.JACKET)
set_hashes_count(ImageHashType.PARTNER_ICON) set_hashes_count(ImageHashCategory.PARTNER_ICON)
self._hash_length = self._hash_size**2 self._hash_length = self._hash_size**2
def lookup_hash( def lookup_hash(
self, type: ImageHashType, hash_type: ImageHashHashType, hash: bytes self, category: ImageHashCategory, hash_type: ImageHashHashType, hash: bytes
) -> List[ImageHashResult]: ) -> List[ImageHashResult]:
cursor = self.conn.execute( cursor = self.conn.execute(
"SELECT" "SELECT"
" label," " label,"
" HAMMING_DISTANCE(hash, ?) AS distance" " HAMMING_DISTANCE(hash, ?) AS distance"
" FROM hashes" " FROM hashes"
" WHERE type = ? AND hash_type = ?" " WHERE category = ? AND hash_type = ?"
" ORDER BY distance ASC LIMIT 10", " ORDER BY distance ASC LIMIT 10",
(hash, type.value, hash_type.value), (hash, category.value, hash_type.value),
) )
results = [] results = []
@ -105,7 +106,7 @@ class ImageHashesDatabase:
results.append( results.append(
ImageHashResult( ImageHashResult(
hash_type=hash_type, hash_type=hash_type,
type=type, category=category,
label=label, label=label,
confidence=(self.hash_length - distance) / self.hash_length, confidence=(self.hash_length - distance) / self.hash_length,
) )
@ -117,7 +118,7 @@ class ImageHashesDatabase:
def hash_mat_to_bytes(hash: Mat) -> bytes: def hash_mat_to_bytes(hash: Mat) -> bytes:
return bytes([255 if b else 0 for b in hash.flatten()]) return bytes([255 if b else 0 for b in hash.flatten()])
def identify_image(self, type: ImageHashType, img) -> List[ImageHashResult]: def identify_image(self, category: ImageHashCategory, img) -> List[ImageHashResult]:
results = [] results = []
ahash = hashers.average(img, self.hash_size) ahash = hashers.average(img, self.hash_size)
@ -126,16 +127,18 @@ class ImageHashesDatabase:
results.extend( results.extend(
self.lookup_hash( self.lookup_hash(
type, ImageHashHashType.AVERAGE, self.hash_mat_to_bytes(ahash) category, ImageHashHashType.AVERAGE, self.hash_mat_to_bytes(ahash)
) )
) )
results.extend( results.extend(
self.lookup_hash( self.lookup_hash(
type, ImageHashHashType.DIFFERENCE, self.hash_mat_to_bytes(dhash) category, ImageHashHashType.DIFFERENCE, self.hash_mat_to_bytes(dhash)
) )
) )
results.extend( results.extend(
self.lookup_hash(type, ImageHashHashType.DCT, self.hash_mat_to_bytes(phash)) self.lookup_hash(
category, ImageHashHashType.DCT, self.hash_mat_to_bytes(phash)
)
) )
return results return results

View File

@ -13,7 +13,7 @@ class ImageHashHashType(IntEnum):
DCT = 2 DCT = 2
class ImageHashType(IntEnum): class ImageHashCategory(IntEnum):
JACKET = 0 JACKET = 0
PARTNER_ICON = 1 PARTNER_ICON = 1
@ -21,7 +21,7 @@ class ImageHashType(IntEnum):
@dataclasses.dataclass @dataclasses.dataclass
class ImageHash: class ImageHash:
hash_type: ImageHashHashType hash_type: ImageHashHashType
type: ImageHashType category: ImageHashCategory
label: str label: str
hash: bytes hash: bytes
@ -29,7 +29,7 @@ class ImageHash:
@dataclasses.dataclass @dataclasses.dataclass
class ImageHashResult: class ImageHashResult:
hash_type: ImageHashHashType hash_type: ImageHashHashType
type: ImageHashType category: ImageHashCategory
label: str label: str
confidence: float confidence: float
@ -41,6 +41,6 @@ def _default_imread_gray(image_path: str):
@dataclasses.dataclass @dataclasses.dataclass
class ImageHashBuildTask: class ImageHashBuildTask:
image_path: str image_path: str
type: ImageHashType category: ImageHashCategory
label: str label: str
imread_function: Callable[[str], Mat] = _default_imread_gray imread_function: Callable[[str], Mat] = _default_imread_gray