diff --git a/src/arcaea_offline_ocr/builders/__init__.py b/src/arcaea_offline_ocr/builders/__init__.py new file mode 100644 index 0000000..94fdce2 --- /dev/null +++ b/src/arcaea_offline_ocr/builders/__init__.py @@ -0,0 +1,6 @@ +from .ihdb import ImageHashDatabaseBuildTask, ImageHashesDatabaseBuilder + +__all__ = [ + "ImageHashDatabaseBuildTask", + "ImageHashesDatabaseBuilder", +] diff --git a/src/arcaea_offline_ocr/builders/ihdb.py b/src/arcaea_offline_ocr/builders/ihdb.py new file mode 100644 index 0000000..6088684 --- /dev/null +++ b/src/arcaea_offline_ocr/builders/ihdb.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Callable, List + +import cv2 + +from arcaea_offline_ocr.core import hashers +from arcaea_offline_ocr.providers import ImageCategory +from arcaea_offline_ocr.providers.ihdb import ( + PROP_KEY_BUILT_AT, + PROP_KEY_HASH_SIZE, + PROP_KEY_HIGH_FREQ_FACTOR, + ImageHashDatabaseIdProvider, + ImageHashType, +) + +if TYPE_CHECKING: + from sqlite3 import Connection + + from arcaea_offline_ocr.types import Mat + + +def _default_imread_gray(image_path: str): + return cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY) + + +@dataclass +class ImageHashDatabaseBuildTask: + image_path: str + image_id: str + category: ImageCategory + imread_function: Callable[[str], "Mat"] = _default_imread_gray + + +@dataclass +class _ImageHash: + image_id: str + category: ImageCategory + image_hash_type: ImageHashType + hash: bytes + + +class ImageHashesDatabaseBuilder: + @staticmethod + def __insert_property(conn: "Connection", key: str, value: str): + return conn.execute( + "INSERT INTO properties (key, value) VALUES (?, ?)", + (key, value), + ) + + @classmethod + def build( + cls, + conn: "Connection", + tasks: List[ImageHashDatabaseBuildTask], + *, + hash_size: int = 16, + high_freq_factor: int = 4, + ): + hashes: List[_ImageHash] = [] + + for task in tasks: + img_gray = task.imread_function(task.image_path) + + for hash_type, hash_mat in [ + ( + ImageHashType.AVERAGE, + hashers.average(img_gray, hash_size), + ), + ( + ImageHashType.DCT, + hashers.dct(img_gray, hash_size, high_freq_factor), + ), + ( + ImageHashType.DIFFERENCE, + hashers.difference(img_gray, hash_size), + ), + ]: + hashes.append( + _ImageHash( + image_id=task.image_id, + image_hash_type=hash_type, + category=task.category, + hash=ImageHashDatabaseIdProvider.hash_mat_to_bytes(hash_mat), + ) + ) + + conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)") + conn.execute( + """CREATE TABLE hashes ( +`id` VARCHAR, +`category` INTEGER, +`hash_type` INTEGER, +`hash` BLOB +)""" + ) + + now = datetime.now(tz=timezone.utc) + timestamp = int(now.timestamp() * 1000) + + cls.__insert_property(conn, PROP_KEY_HASH_SIZE, str(hash_size)) + cls.__insert_property(conn, PROP_KEY_HIGH_FREQ_FACTOR, str(high_freq_factor)) + cls.__insert_property(conn, PROP_KEY_BUILT_AT, str(timestamp)) + + conn.executemany( + "INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`) VALUES (?, ?, ?, ?)", + [ + (it.image_id, it.category.value, it.image_hash_type.value, it.hash) + for it in hashes + ], + ) + conn.commit() diff --git a/src/arcaea_offline_ocr/dependencies/ihdb/__init__.py b/src/arcaea_offline_ocr/dependencies/ihdb/__init__.py deleted file mode 100644 index b114bc2..0000000 --- a/src/arcaea_offline_ocr/dependencies/ihdb/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from .builder import ImageHashesDatabaseBuilder -from .index import ImageHashesDatabase, ImageHashesDatabasePropertyMissingError -from .models import ( - ImageHashBuildTask, - ImageHashHashType, - ImageHashResult, - ImageHashCategory, -) - -__all__ = [ - "ImageHashesDatabase", - "ImageHashesDatabasePropertyMissingError", - "ImageHashHashType", - "ImageHashResult", - "ImageHashCategory", - "ImageHashesDatabaseBuilder", - "ImageHashBuildTask", -] diff --git a/src/arcaea_offline_ocr/dependencies/ihdb/builder.py b/src/arcaea_offline_ocr/dependencies/ihdb/builder.py deleted file mode 100644 index 07b72aa..0000000 --- a/src/arcaea_offline_ocr/dependencies/ihdb/builder.py +++ /dev/null @@ -1,85 +0,0 @@ -import logging -from datetime import datetime, timezone -from sqlite3 import Connection -from typing import List - -from arcaea_offline_ocr.core import hashers - -from .index import ImageHashesDatabase -from .models import ImageHash, ImageHashBuildTask, ImageHashHashType - -logger = logging.getLogger(__name__) - - -class ImageHashesDatabaseBuilder: - @staticmethod - def __insert_property(conn: Connection, key: str, value: str): - return conn.execute( - "INSERT INTO properties (key, value) VALUES (?, ?)", - (key, value), - ) - - @classmethod - def build( - cls, - conn: Connection, - tasks: List[ImageHashBuildTask], - *, - hash_size: int = 16, - high_freq_factor: int = 4, - ): - rows: List[ImageHash] = [] - - for task in tasks: - try: - img_gray = task.imread_function(task.image_path) - - for hash_type, hash_mat in [ - ( - ImageHashHashType.AVERAGE, - hashers.average(img_gray, hash_size), - ), - ( - ImageHashHashType.DCT, - hashers.dct(img_gray, hash_size, high_freq_factor), - ), - ( - ImageHashHashType.DIFFERENCE, - hashers.difference(img_gray, hash_size), - ), - ]: - rows.append( - ImageHash( - hash_type=hash_type, - category=task.category, - label=task.label, - hash=ImageHashesDatabase.hash_mat_to_bytes(hash_mat), - ) - ) - except Exception: - logger.exception("Error processing task %r", task) - - conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)") - conn.execute( - "CREATE TABLE hashes (`hash_type` INTEGER, `category` INTEGER, `label` VARCHAR, `hash` BLOB)" - ) - - now = datetime.now(tz=timezone.utc) - timestamp = int(now.timestamp() * 1000) - - cls.__insert_property(conn, ImageHashesDatabase.KEY_HASH_SIZE, str(hash_size)) - cls.__insert_property( - conn, ImageHashesDatabase.KEY_HIGH_FREQ_FACTOR, str(high_freq_factor) - ) - cls.__insert_property( - conn, ImageHashesDatabase.KEY_BUILT_TIMESTAMP, str(timestamp) - ) - - conn.executemany( - "INSERT INTO hashes (hash_type, category, label, hash) VALUES (?, ?, ?, ?)", - [ - (row.hash_type.value, row.category.value, row.label, row.hash) - for row in rows - ], - ) - conn.commit() diff --git a/src/arcaea_offline_ocr/dependencies/ihdb/index.py b/src/arcaea_offline_ocr/dependencies/ihdb/index.py deleted file mode 100644 index ff13656..0000000 --- a/src/arcaea_offline_ocr/dependencies/ihdb/index.py +++ /dev/null @@ -1,144 +0,0 @@ -import sqlite3 -from datetime import datetime, timezone -from typing import Any, Callable, List, Optional, TypeVar - -from arcaea_offline_ocr.core import hashers -from arcaea_offline_ocr.types import Mat - -from .models import ImageHashHashType, ImageHashResult, ImageHashCategory - -T = TypeVar("T") - - -def _sql_hamming_distance(hash1: bytes, hash2: bytes): - assert len(hash1) == len(hash2), "hash size does not match!" - count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2) - return count - - -class ImageHashesDatabasePropertyMissingError(Exception): - pass - - -class ImageHashesDatabase: - KEY_HASH_SIZE = "hash_size" - KEY_HIGH_FREQ_FACTOR = "high_freq_factor" - KEY_BUILT_TIMESTAMP = "built_timestamp" - - def __init__(self, conn: sqlite3.Connection): - self.conn = conn - self.conn.create_function("HAMMING_DISTANCE", 2, _sql_hamming_distance) - - self._hash_size: int = -1 - self._high_freq_factor: int = -1 - self._built_time: Optional[datetime] = None - - self._hashes_count = { - ImageHashCategory.JACKET: 0, - ImageHashCategory.PARTNER_ICON: 0, - } - - self._hash_length: int = -1 - - self._initialize() - - @property - def hash_size(self): - return self._hash_size - - @property - def high_freq_factor(self): - return self._high_freq_factor - - @property - def hash_length(self): - return self._hash_length - - def _initialize(self): - def query_property(key, convert_func: Callable[[Any], T]) -> Optional[T]: - result = self.conn.execute( - "SELECT value FROM properties WHERE key = ?", - (key,), - ).fetchone() - return convert_func(result[0]) if result is not None else None - - def set_hashes_count(category: ImageHashCategory): - self._hashes_count[category] = self.conn.execute( - "SELECT COUNT(DISTINCT label) FROM hashes WHERE category = ?", - (category.value,), - ).fetchone()[0] - - hash_size = query_property(self.KEY_HASH_SIZE, lambda x: int(x)) - if hash_size is None: - raise ImageHashesDatabasePropertyMissingError("hash_size") - self._hash_size = hash_size - - high_freq_factor = query_property(self.KEY_HIGH_FREQ_FACTOR, lambda x: int(x)) - if high_freq_factor is None: - raise ImageHashesDatabasePropertyMissingError("high_freq_factor") - self._high_freq_factor = high_freq_factor - - self._built_time = query_property( - self.KEY_BUILT_TIMESTAMP, - lambda ts: datetime.fromtimestamp(int(ts) / 1000, tz=timezone.utc), - ) - - set_hashes_count(ImageHashCategory.JACKET) - set_hashes_count(ImageHashCategory.PARTNER_ICON) - - self._hash_length = self._hash_size**2 - - def lookup_hash( - self, category: ImageHashCategory, hash_type: ImageHashHashType, hash: bytes - ) -> List[ImageHashResult]: - cursor = self.conn.execute( - "SELECT" - " label," - " HAMMING_DISTANCE(hash, ?) AS distance" - " FROM hashes" - " WHERE category = ? AND hash_type = ?" - " ORDER BY distance ASC LIMIT 10", - (hash, category.value, hash_type.value), - ) - - results = [] - for label, distance in cursor.fetchall(): - results.append( - ImageHashResult( - hash_type=hash_type, - category=category, - label=label, - confidence=(self.hash_length - distance) / self.hash_length, - ) - ) - - return results - - @staticmethod - def hash_mat_to_bytes(hash: Mat) -> bytes: - return bytes([255 if b else 0 for b in hash.flatten()]) - - def identify_image(self, category: ImageHashCategory, img) -> List[ImageHashResult]: - results = [] - - ahash = hashers.average(img, self.hash_size) - dhash = hashers.difference(img, self.hash_size) - phash = hashers.dct(img, self.hash_size, self.high_freq_factor) - - results.extend( - self.lookup_hash( - category, ImageHashHashType.AVERAGE, self.hash_mat_to_bytes(ahash) - ) - ) - results.extend( - self.lookup_hash( - category, ImageHashHashType.DIFFERENCE, self.hash_mat_to_bytes(dhash) - ) - ) - results.extend( - self.lookup_hash( - category, ImageHashHashType.DCT, self.hash_mat_to_bytes(phash) - ) - ) - - return results diff --git a/src/arcaea_offline_ocr/dependencies/ihdb/models.py b/src/arcaea_offline_ocr/dependencies/ihdb/models.py deleted file mode 100644 index dd0af39..0000000 --- a/src/arcaea_offline_ocr/dependencies/ihdb/models.py +++ /dev/null @@ -1,46 +0,0 @@ -import dataclasses -from enum import IntEnum -from typing import Callable - -import cv2 - -from arcaea_offline_ocr.types import Mat - - -class ImageHashHashType(IntEnum): - AVERAGE = 0 - DIFFERENCE = 1 - DCT = 2 - - -class ImageHashCategory(IntEnum): - JACKET = 0 - PARTNER_ICON = 1 - - -@dataclasses.dataclass -class ImageHash: - hash_type: ImageHashHashType - category: ImageHashCategory - label: str - hash: bytes - - -@dataclasses.dataclass -class ImageHashResult: - hash_type: ImageHashHashType - category: ImageHashCategory - label: str - confidence: float - - -def _default_imread_gray(image_path: str): - return cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY) - - -@dataclasses.dataclass -class ImageHashBuildTask: - image_path: str - category: ImageHashCategory - label: str - imread_function: Callable[[str], Mat] = _default_imread_gray diff --git a/src/arcaea_offline_ocr/providers/__init__.py b/src/arcaea_offline_ocr/providers/__init__.py index e69de29..baa233b 100644 --- a/src/arcaea_offline_ocr/providers/__init__.py +++ b/src/arcaea_offline_ocr/providers/__init__.py @@ -0,0 +1,12 @@ +from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult, OcrTextProvider +from .ihdb import ImageHashDatabaseIdProvider +from .knn import OcrKNearestTextProvider + +__all__ = [ + "ImageCategory", + "ImageHashDatabaseIdProvider", + "OcrKNearestTextProvider", + "ImageIdProvider", + "OcrTextProvider", + "ImageIdProviderResult", +] diff --git a/src/arcaea_offline_ocr/providers/base.py b/src/arcaea_offline_ocr/providers/base.py index 6e84be1..b98c058 100644 --- a/src/arcaea_offline_ocr/providers/base.py +++ b/src/arcaea_offline_ocr/providers/base.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from dataclasses import dataclass +from enum import IntEnum +from typing import TYPE_CHECKING, Any, Sequence, Optional if TYPE_CHECKING: from ..types import Mat @@ -10,3 +12,27 @@ class OcrTextProvider(ABC): def result_raw(self, img: "Mat", /, *args, **kwargs) -> Any: ... @abstractmethod def result(self, img: "Mat", /, *args, **kwargs) -> Optional[str]: ... + + +class ImageCategory(IntEnum): + JACKET = 0 + PARTNER_ICON = 1 + + +@dataclass(kw_only=True) +class ImageIdProviderResult: + image_id: str + category: ImageCategory + confidence: float + + +class ImageIdProvider(ABC): + @abstractmethod + def result( + self, img: "Mat", category: ImageCategory, /, *args, **kwargs + ) -> ImageIdProviderResult: ... + + @abstractmethod + def results( + self, img: "Mat", category: ImageCategory, /, *args, **kwargs + ) -> Sequence[ImageIdProviderResult]: ... diff --git a/src/arcaea_offline_ocr/providers/ihdb.py b/src/arcaea_offline_ocr/providers/ihdb.py new file mode 100644 index 0000000..0539264 --- /dev/null +++ b/src/arcaea_offline_ocr/providers/ihdb.py @@ -0,0 +1,194 @@ +import sqlite3 +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import IntEnum +from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar + +from arcaea_offline_ocr.core import hashers + +from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult + +if TYPE_CHECKING: + from arcaea_offline_ocr.types import Mat + + +T = TypeVar("T") +PROP_KEY_HASH_SIZE = "hash_size" +PROP_KEY_HIGH_FREQ_FACTOR = "high_freq_factor" +PROP_KEY_BUILT_AT = "built_at" + + +def _sql_hamming_distance(hash1: bytes, hash2: bytes): + assert len(hash1) == len(hash2), "hash size does not match!" + count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2) + return count + + +class ImageHashType(IntEnum): + AVERAGE = 0 + DIFFERENCE = 1 + DCT = 2 + + +@dataclass(kw_only=True) +class ImageHashDatabaseIdProviderResult(ImageIdProviderResult): + image_hash_type: ImageHashType + + +class MissingPropertiesError(Exception): + keys: List[str] + + def __init__(self, keys, *args): + super().__init__(*args) + self.keys = keys + + +class ImageHashDatabaseIdProvider(ImageIdProvider): + def __init__(self, conn: sqlite3.Connection): + self.conn = conn + self.conn.create_function("HAMMING_DISTANCE", 2, _sql_hamming_distance) + + self.properties = { + PROP_KEY_HASH_SIZE: -1, + PROP_KEY_HIGH_FREQ_FACTOR: -1, + PROP_KEY_BUILT_AT: None, + } + + self._hashes_count = { + ImageCategory.JACKET: 0, + ImageCategory.PARTNER_ICON: 0, + } + + self._hash_length: int = -1 + + self._initialize() + + @property + def hash_size(self) -> int: + return self.properties[PROP_KEY_HASH_SIZE] + + @property + def high_freq_factor(self) -> int: + return self.properties[PROP_KEY_HIGH_FREQ_FACTOR] + + @property + def built_at(self) -> Optional[datetime]: + return self.properties.get(PROP_KEY_BUILT_AT) + + @property + def hash_length(self): + return self._hash_length + + def _initialize(self): + def get_property(key, converter: Callable[[Any], T]) -> Optional[T]: + result = self.conn.execute( + "SELECT value FROM properties WHERE key = ?", + (key,), + ).fetchone() + return converter(result[0]) if result is not None else None + + def set_hashes_count(category: ImageCategory): + self._hashes_count[category] = self.conn.execute( + "SELECT COUNT(DISTINCT `id`) FROM hashes WHERE category = ?", + (category.value,), + ).fetchone()[0] + + properties_converter_map = { + PROP_KEY_HASH_SIZE: lambda x: int(x), + PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x), + PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp( + int(ts) / 1000, tz=timezone.utc + ), + } + required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR] + + missing_properties = [] + for property_key, converter in properties_converter_map.items(): + value = get_property(property_key, converter) + if value is None: + if property_key in required_properties: + missing_properties.append(property_key) + + continue + + self.properties[property_key] = value + + if missing_properties: + raise MissingPropertiesError(keys=missing_properties) + + set_hashes_count(ImageCategory.JACKET) + set_hashes_count(ImageCategory.PARTNER_ICON) + + self._hash_length = self.hash_size**2 + + def lookup_hash( + self, category: ImageCategory, hash_type: ImageHashType, hash: bytes + ) -> List[ImageHashDatabaseIdProviderResult]: + cursor = self.conn.execute( + """ +SELECT + `id`, + HAMMING_DISTANCE(hash, ?) AS distance +FROM hashes +WHERE category = ? AND hash_type = ? +ORDER BY distance ASC LIMIT 10""", + (hash, category.value, hash_type.value), + ) + + results = [] + for id_, distance in cursor.fetchall(): + results.append( + ImageHashDatabaseIdProviderResult( + image_id=id_, + category=category, + confidence=(self.hash_length - distance) / self.hash_length, + image_hash_type=hash_type, + ) + ) + + return results + + @staticmethod + def hash_mat_to_bytes(hash: "Mat") -> bytes: + return bytes([255 if b else 0 for b in hash.flatten()]) + + def results(self, img: "Mat", category: ImageCategory, /): + results: List[ImageHashDatabaseIdProviderResult] = [] + + results.extend( + self.lookup_hash( + category, + ImageHashType.AVERAGE, + self.hash_mat_to_bytes(hashers.average(img, self.hash_size)), + ) + ) + results.extend( + self.lookup_hash( + category, + ImageHashType.DIFFERENCE, + self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)), + ) + ) + results.extend( + self.lookup_hash( + category, + ImageHashType.DCT, + self.hash_mat_to_bytes( + hashers.dct(img, self.hash_size, self.high_freq_factor) + ), + ) + ) + + return results + + def result( + self, + img: "Mat", + category: ImageCategory, + /, + *, + hash_type: ImageHashType = ImageHashType.DCT, + ): + return [ + it for it in self.results(img, category) if it.image_hash_type == hash_type + ][0]