import logging
from datetime import datetime, timezone
from sqlite3 import Connection
from typing import List

from arcaea_offline_ocr.core import hashers

from .index import ImageHashesDatabase
from .models import ImageHash, ImageHashBuildTask, ImageHashHashType

logger = logging.getLogger(__name__)


class ImageHashesDatabaseBuilder:
    """Build the SQLite image hashes database read by ``ImageHashesDatabase``."""

    @staticmethod
    def __insert_property(conn: Connection, key: str, value: str):
        """Insert one ``key``/``value`` row into the ``properties`` table."""
        return conn.execute(
            "INSERT INTO properties (key, value) VALUES (?, ?)",
            (key, value),
        )

    @staticmethod
    def __hash_task(
        task: ImageHashBuildTask, hash_size: int, high_freq_factor: int
    ) -> List[ImageHash]:
        """Return the average, DCT and difference hash rows for one task.

        Every hash is computed and serialized before the list is returned, so
        an exception raised anywhere in here yields no rows at all for the
        task rather than a partial set.
        """
        img_gray = task.imread_function(task.image_path)

        hash_mats = [
            (ImageHashHashType.AVERAGE, hashers.average(img_gray, hash_size)),
            (
                ImageHashHashType.DCT,
                hashers.dct(img_gray, hash_size, high_freq_factor),
            ),
            (ImageHashHashType.DIFFERENCE, hashers.difference(img_gray, hash_size)),
        ]

        return [
            ImageHash(
                hash_type=hash_type,
                type=task.type,
                label=task.label,
                hash=ImageHashesDatabase.hash_mat_to_bytes(hash_mat),
            )
            for hash_type, hash_mat in hash_mats
        ]

    @classmethod
    def build(
        cls,
        conn: Connection,
        tasks: List[ImageHashBuildTask],
        *,
        hash_size: int = 16,
        high_freq_factor: int = 4,
    ):
        """Hash every task image and write the database through ``conn``.

        Creates the ``properties`` and ``hashes`` tables, stores ``hash_size``,
        ``high_freq_factor`` and the build time (milliseconds since the epoch,
        UTC) as properties, then inserts one row per task and hash type.

        A task whose image cannot be read or hashed is logged and skipped; it
        contributes no rows and does not abort the build.

        Args:
            conn: Connection to the target database. The tables must not
                exist yet, because ``CREATE TABLE`` is issued unconditionally.
            tasks: Images to hash, with their type and label.
            hash_size: Size argument passed to every hasher.
            high_freq_factor: Extra argument passed to the DCT hasher.

        Raises:
            sqlite3.Error: If creating the tables or inserting rows fails. The
                transaction open on ``conn`` is rolled back in that case.
        """
        rows: List[ImageHash] = []

        for task in tasks:
            try:
                rows.extend(cls.__hash_task(task, hash_size, high_freq_factor))
            except Exception:
                # Best effort: one unreadable image must not abort the build.
                logger.exception("Error processing task %r", task)

        # The connection context manager commits on success and rolls back if
        # any statement raises, so a failed build never leaves half-inserted
        # rows pending on the caller's connection.
        with conn:
            conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
            conn.execute(
                "CREATE TABLE hashes (`hash_type` INTEGER, `type` INTEGER, `label` VARCHAR, `hash` BLOB)"
            )

            now = datetime.now(tz=timezone.utc)
            timestamp = int(now.timestamp() * 1000)

            cls.__insert_property(
                conn, ImageHashesDatabase.KEY_HASH_SIZE, str(hash_size)
            )
            cls.__insert_property(
                conn, ImageHashesDatabase.KEY_HIGH_FREQ_FACTOR, str(high_freq_factor)
            )
            cls.__insert_property(
                conn, ImageHashesDatabase.KEY_BUILT_TIMESTAMP, str(timestamp)
            )

            conn.executemany(
                "INSERT INTO hashes (hash_type, type, label, hash) VALUES (?, ?, ?, ?)",
                [
                    (row.hash_type.value, row.type.value, row.label, row.hash)
                    for row in rows
                ],
            )
import sqlite3
from datetime import datetime, timezone
from typing import Any, Callable, List, Optional, TypeVar

from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.types import Mat

from .models import ImageHashHashType, ImageHashResult, ImageHashType

T = TypeVar("T")


def _sql_hamming_distance(hash1: bytes, hash2: bytes):
    """Return the number of byte positions at which the two hashes differ.

    Registered as the SQL function ``HAMMING_DISTANCE``. Hashes are serialized
    by ``ImageHashesDatabase.hash_mat_to_bytes`` as one byte (0 or 255) per
    hash element, so the count of differing bytes is the count of differing
    hash elements.

    Raises:
        ValueError: If the hashes have different lengths. This is an explicit
            check rather than an ``assert`` because asserts are stripped under
            ``python -O``, after which ``zip`` would silently truncate and
            return a wrong distance.
    """
    if len(hash1) != len(hash2):
        raise ValueError("hash size does not match!")
    return sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)


class ImageHashesDatabasePropertyMissingError(Exception):
    """Raised when a required row is absent from the ``properties`` table."""


class ImageHashesDatabase:
    """Read-only view over an image hashes SQLite database.

    The build parameters are loaded from the ``properties`` table on
    construction, and images are identified by comparing their hashes with
    the rows of the ``hashes`` table.
    """

    KEY_HASH_SIZE = "hash_size"
    KEY_HIGH_FREQ_FACTOR = "high_freq_factor"
    KEY_BUILT_TIMESTAMP = "built_timestamp"

    def __init__(self, conn: sqlite3.Connection):
        """Register ``HAMMING_DISTANCE`` on ``conn`` and load the properties.

        Raises:
            ImageHashesDatabasePropertyMissingError: If ``hash_size`` or
                ``high_freq_factor`` is not stored in the database.
        """
        self.conn = conn
        self.conn.create_function("HAMMING_DISTANCE", 2, _sql_hamming_distance)

        # Placeholders; the real values are filled in by _initialize().
        self._hash_size: int = -1
        self._high_freq_factor: int = -1
        self._built_time: Optional[datetime] = None

        # Number of distinct labels stored per image type.
        self._hashes_count = {
            ImageHashType.JACKET: 0,
            ImageHashType.PARTNER_ICON: 0,
        }

        self._hash_length: int = -1

        self._initialize()

    @property
    def hash_size(self):
        """Hash size the database was built with."""
        return self._hash_size

    @property
    def high_freq_factor(self):
        """High frequency factor the database was built with."""
        return self._high_freq_factor

    @property
    def hash_length(self):
        """``hash_size ** 2``; the denominator of the lookup confidence."""
        return self._hash_length

    def _initialize(self):
        """Load the stored properties and the per-type label counts."""

        def query_property(
            key: str, convert_func: Callable[[Any], T]
        ) -> Optional[T]:
            # Returns None when the property row does not exist.
            row = self.conn.execute(
                "SELECT value FROM properties WHERE key = ?",
                (key,),
            ).fetchone()
            return convert_func(row[0]) if row is not None else None

        def set_hashes_count(image_type: ImageHashType):
            self._hashes_count[image_type] = self.conn.execute(
                "SELECT COUNT(DISTINCT label) FROM hashes WHERE type = ?",
                (image_type.value,),
            ).fetchone()[0]

        hash_size = query_property(self.KEY_HASH_SIZE, int)
        if hash_size is None:
            raise ImageHashesDatabasePropertyMissingError(self.KEY_HASH_SIZE)
        self._hash_size = hash_size

        high_freq_factor = query_property(self.KEY_HIGH_FREQ_FACTOR, int)
        if high_freq_factor is None:
            raise ImageHashesDatabasePropertyMissingError(self.KEY_HIGH_FREQ_FACTOR)
        self._high_freq_factor = high_freq_factor

        # Optional property, stored as milliseconds since the epoch;
        # _built_time stays None when the row is missing.
        self._built_time = query_property(
            self.KEY_BUILT_TIMESTAMP,
            lambda ts: datetime.fromtimestamp(int(ts) / 1000, tz=timezone.utc),
        )

        set_hashes_count(ImageHashType.JACKET)
        set_hashes_count(ImageHashType.PARTNER_ICON)

        self._hash_length = self._hash_size**2

    def lookup_hash(
        self, type: ImageHashType, hash_type: ImageHashHashType, hash: bytes
    ) -> List[ImageHashResult]:
        """Return up to 10 stored hashes closest to ``hash``, nearest first.

        Only rows matching both ``type`` and ``hash_type`` are considered.
        Each result's confidence is
        ``(hash_length - distance) / hash_length``.

        Args:
            type: Image type to search within.
            hash_type: Hash algorithm the stored rows must have been built with.
            hash: Query hash as produced by ``hash_mat_to_bytes``.
        """
        cursor = self.conn.execute(
            "SELECT"
            " label,"
            " HAMMING_DISTANCE(hash, ?) AS distance"
            " FROM hashes"
            " WHERE type = ? AND hash_type = ?"
            " ORDER BY distance ASC LIMIT 10",
            (hash, type.value, hash_type.value),
        )

        return [
            ImageHashResult(
                hash_type=hash_type,
                type=type,
                label=label,
                confidence=(self.hash_length - distance) / self.hash_length,
            )
            for label, distance in cursor.fetchall()
        ]

    @staticmethod
    def hash_mat_to_bytes(hash: Mat) -> bytes:
        """Serialize a hash matrix to one byte per element.

        The matrix is flattened; truthy elements become 255 and falsy ones 0.
        """
        return bytes([255 if b else 0 for b in hash.flatten()])

    def identify_image(self, type: ImageHashType, img) -> List[ImageHashResult]:
        """Hash ``img`` with all three hashers and look each hash up.

        Args:
            type: Image type to search within.
            img: Image passed unchanged to the hashers.
                NOTE(review): presumably grayscale, as at build time - confirm.

        Returns:
            The average-hash results, then the difference-hash results, then
            the DCT-hash results, each group ordered nearest first.
        """
        ahash = hashers.average(img, self.hash_size)
        dhash = hashers.difference(img, self.hash_size)
        phash = hashers.dct(img, self.hash_size, self.high_freq_factor)

        results: List[ImageHashResult] = []
        for hash_type, hash_mat in (
            (ImageHashHashType.AVERAGE, ahash),
            (ImageHashHashType.DIFFERENCE, dhash),
            (ImageHashHashType.DCT, phash),
        ):
            results.extend(
                self.lookup_hash(type, hash_type, self.hash_mat_to_bytes(hash_mat))
            )

        return results
import dataclasses
from enum import IntEnum
from typing import Callable

import cv2

from arcaea_offline_ocr.types import Mat


class ImageHashHashType(IntEnum):
    """Hash algorithm of a stored hash; the value is written to ``hash_type``."""

    AVERAGE = 0
    DIFFERENCE = 1
    DCT = 2


class ImageHashType(IntEnum):
    """Kind of image a hash belongs to; the value is written to ``type``."""

    JACKET = 0
    PARTNER_ICON = 1


@dataclasses.dataclass
class ImageHash:
    """One row of the ``hashes`` table."""

    hash_type: ImageHashHashType
    type: ImageHashType
    label: str
    # Serialized hash, one byte per hash element.
    hash: bytes


@dataclasses.dataclass
class ImageHashResult:
    """One candidate match returned by a hash lookup."""

    hash_type: ImageHashHashType
    type: ImageHashType
    label: str
    # Fraction of matching hash elements; higher means more similar.
    confidence: float


def _default_imread_gray(image_path: str):
    """Read ``image_path`` as a color image and convert it to grayscale.

    Raises:
        cv2.error: If the image cannot be read. ``cv2.imread`` reports
            failure by returning ``None`` instead of raising, which
            previously surfaced as an opaque assertion failure inside
            ``cv2.cvtColor``. The exception type is kept so existing
            handlers still match; only the message now names the path.
    """
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if img is None:
        raise cv2.error(f"cannot read image: {image_path!r}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)


@dataclasses.dataclass
class ImageHashBuildTask:
    """An image to hash into the database, with its type and label."""

    image_path: str
    type: ImageHashType
    label: str
    # Loader called with image_path; its result is passed to the hashers.
    imread_function: Callable[[str], Mat] = _default_imread_gray