2 Commits

Author SHA1 Message Date
2b18906935 refactor!: image hash database provider 2025-06-22 01:28:59 +08:00
abfd37dbef refactor!: OCR text result provider 2025-06-22 00:32:31 +08:00
14 changed files with 492 additions and 405 deletions

View File

@@ -1,4 +1,3 @@
 from .crop import *
 from .device import *
-from .ocr import *
 from .utils import *

View File

@@ -4,12 +4,6 @@ import cv2
 import numpy as np
 from ....crop import crop_xywh
-from ....ocr import (
-    FixRects,
-    ocr_digits_by_contour_knn,
-    preprocess_hog,
-    resize_fill_square,
-)
 from ....phash_db import ImagePhashDatabase
 from ....types import Mat
 from ...shared import B30OcrResultItem
@@ -28,36 +22,21 @@ from .colors import (
     PURE_BG_MIN_HSV,
 )
 from .rois import ChieriBotV4Rois
+from ....providers.knn import OcrKNearestTextProvider


 class ChieriBotV4Ocr:
     def __init__(
         self,
-        score_knn: cv2.ml.KNearest,
-        pfl_knn: cv2.ml.KNearest,
+        score_knn_provider: OcrKNearestTextProvider,
+        pfl_knn_provider: OcrKNearestTextProvider,
         phash_db: ImagePhashDatabase,
         factor: float = 1.0,
     ):
-        self.__score_knn = score_knn
-        self.__pfl_knn = pfl_knn
         self.__phash_db = phash_db
         self.__rois = ChieriBotV4Rois(factor)
+        self.pfl_knn_provider = pfl_knn_provider
+        self.score_knn_provider = score_knn_provider

-    @property
-    def score_knn(self):
-        return self.__score_knn
-
-    @score_knn.setter
-    def score_knn(self, knn_digits_model: cv2.ml.KNearest):
-        self.__score_knn = knn_digits_model
-
-    @property
-    def pfl_knn(self):
-        return self.__pfl_knn
-
-    @pfl_knn.setter
-    def pfl_knn(self, knn_digits_model: cv2.ml.KNearest):
-        self.__pfl_knn = knn_digits_model
-
     @property
     def phash_db(self):
@@ -125,7 +104,9 @@ class ChieriBotV4Ocr:
             if rect[3] > score_roi.shape[0] * 0.5:
                 continue
             score_roi = cv2.fillPoly(score_roi, [contour], 0)
-        return ocr_digits_by_contour_knn(score_roi, self.score_knn)
+
+        ocr_result = self.score_knn_provider.result(score_roi)
+        return int(ocr_result) if ocr_result else 0

     def find_pfl_rects(
         self, component_pfl_processed: Mat
@@ -203,25 +184,9 @@ class ChieriBotV4Ocr:
             pure_far_lost = []
             for pfl_roi_rect in pfl_rects:
                 roi = crop_xywh(pfl_roi, pfl_roi_rect)
-                digit_contours, _ = cv2.findContours(
-                    roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
-                )
-                digit_rects = [cv2.boundingRect(c) for c in digit_contours]
-                digit_rects = FixRects.connect_broken(
-                    digit_rects, roi.shape[1], roi.shape[0]
-                )
-                digit_rects = FixRects.split_connected(roi, digit_rects)
-                digit_rects = sorted(digit_rects, key=lambda r: r[0])
-                digits = []
-                for digit_rect in digit_rects:
-                    digit = crop_xywh(roi, digit_rect)
-                    digit = resize_fill_square(digit, 20)
-                    digits.append(digit)
-                samples = preprocess_hog(digits)
-                _, results, _, _ = self.pfl_knn.findNearest(samples, 4)
-                results = [str(int(i)) for i in results.ravel()]
-                pure_far_lost.append(int("".join(results)))
+                result = self.pfl_knn_provider.result(roi)
+                pure_far_lost.append(int(result) if result else None)
             return tuple(pure_far_lost)
         except Exception:
             return (None, None, None)

View File

@@ -0,0 +1,6 @@
from .ihdb import ImageHashDatabaseBuildTask, ImageHashesDatabaseBuilder

__all__ = [
    "ImageHashDatabaseBuildTask",
    "ImageHashesDatabaseBuilder",
]

View File

@@ -0,0 +1,112 @@
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Callable, List

import cv2

from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.providers import ImageCategory
from arcaea_offline_ocr.providers.ihdb import (
    PROP_KEY_BUILT_AT,
    PROP_KEY_HASH_SIZE,
    PROP_KEY_HIGH_FREQ_FACTOR,
    ImageHashDatabaseIdProvider,
    ImageHashType,
)

if TYPE_CHECKING:
    from sqlite3 import Connection

    from arcaea_offline_ocr.types import Mat


def _default_imread_gray(image_path: str):
    return cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY)


@dataclass
class ImageHashDatabaseBuildTask:
    image_path: str
    image_id: str
    category: ImageCategory
    imread_function: Callable[[str], "Mat"] = _default_imread_gray


@dataclass
class _ImageHash:
    image_id: str
    category: ImageCategory
    image_hash_type: ImageHashType
    hash: bytes


class ImageHashesDatabaseBuilder:
    @staticmethod
    def __insert_property(conn: "Connection", key: str, value: str):
        return conn.execute(
            "INSERT INTO properties (key, value) VALUES (?, ?)",
            (key, value),
        )

    @classmethod
    def build(
        cls,
        conn: "Connection",
        tasks: List[ImageHashDatabaseBuildTask],
        *,
        hash_size: int = 16,
        high_freq_factor: int = 4,
    ):
        hashes: List[_ImageHash] = []

        for task in tasks:
            img_gray = task.imread_function(task.image_path)

            for hash_type, hash_mat in [
                (
                    ImageHashType.AVERAGE,
                    hashers.average(img_gray, hash_size),
                ),
                (
                    ImageHashType.DCT,
                    hashers.dct(img_gray, hash_size, high_freq_factor),
                ),
                (
                    ImageHashType.DIFFERENCE,
                    hashers.difference(img_gray, hash_size),
                ),
            ]:
                hashes.append(
                    _ImageHash(
                        image_id=task.image_id,
                        image_hash_type=hash_type,
                        category=task.category,
                        hash=ImageHashDatabaseIdProvider.hash_mat_to_bytes(hash_mat),
                    )
                )

        conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
        conn.execute(
            """CREATE TABLE hashes (
            `id` VARCHAR,
            `category` INTEGER,
            `hash_type` INTEGER,
            `hash` BLOB
            )"""
        )

        now = datetime.now(tz=timezone.utc)
        timestamp = int(now.timestamp() * 1000)

        cls.__insert_property(conn, PROP_KEY_HASH_SIZE, str(hash_size))
        cls.__insert_property(conn, PROP_KEY_HIGH_FREQ_FACTOR, str(high_freq_factor))
        cls.__insert_property(conn, PROP_KEY_BUILT_AT, str(timestamp))

        conn.executemany(
            "INSERT INTO hashes (`id`, `category`, `hash_type`, `hash`) VALUES (?, ?, ?, ?)",
            [
                (it.image_id, it.category.value, it.image_hash_type.value, it.hash)
                for it in hashes
            ],
        )

        conn.commit()
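
A minimal usage sketch for the new builder (the builder package import path, image path, and image ID below are illustrative assumptions, not part of this change):

    import sqlite3

    from arcaea_offline_ocr.providers import ImageCategory

    # assumed import location; use wherever the builders package actually lives
    from builders.ihdb import ImageHashDatabaseBuildTask, ImageHashesDatabaseBuilder

    tasks = [
        ImageHashDatabaseBuildTask(
            image_path="assets/jackets/example_song.jpg",  # placeholder file
            image_id="example_song",  # placeholder image ID
            category=ImageCategory.JACKET,
        ),
    ]

    conn = sqlite3.connect("image-hashes.db")
    # creates the `properties` and `hashes` tables, inserts the hashes, and commits
    ImageHashesDatabaseBuilder.build(conn, tasks, hash_size=16, high_freq_factor=4)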

View File

@@ -1,18 +0,0 @@
from .builder import ImageHashesDatabaseBuilder
from .index import ImageHashesDatabase, ImageHashesDatabasePropertyMissingError
from .models import (
    ImageHashBuildTask,
    ImageHashHashType,
    ImageHashResult,
    ImageHashCategory,
)

__all__ = [
    "ImageHashesDatabase",
    "ImageHashesDatabasePropertyMissingError",
    "ImageHashHashType",
    "ImageHashResult",
    "ImageHashCategory",
    "ImageHashesDatabaseBuilder",
    "ImageHashBuildTask",
]

View File

@@ -1,85 +0,0 @@
import logging
from datetime import datetime, timezone
from sqlite3 import Connection
from typing import List

from arcaea_offline_ocr.core import hashers

from .index import ImageHashesDatabase
from .models import ImageHash, ImageHashBuildTask, ImageHashHashType

logger = logging.getLogger(__name__)


class ImageHashesDatabaseBuilder:
    @staticmethod
    def __insert_property(conn: Connection, key: str, value: str):
        return conn.execute(
            "INSERT INTO properties (key, value) VALUES (?, ?)",
            (key, value),
        )

    @classmethod
    def build(
        cls,
        conn: Connection,
        tasks: List[ImageHashBuildTask],
        *,
        hash_size: int = 16,
        high_freq_factor: int = 4,
    ):
        rows: List[ImageHash] = []

        for task in tasks:
            try:
                img_gray = task.imread_function(task.image_path)

                for hash_type, hash_mat in [
                    (
                        ImageHashHashType.AVERAGE,
                        hashers.average(img_gray, hash_size),
                    ),
                    (
                        ImageHashHashType.DCT,
                        hashers.dct(img_gray, hash_size, high_freq_factor),
                    ),
                    (
                        ImageHashHashType.DIFFERENCE,
                        hashers.difference(img_gray, hash_size),
                    ),
                ]:
                    rows.append(
                        ImageHash(
                            hash_type=hash_type,
                            category=task.category,
                            label=task.label,
                            hash=ImageHashesDatabase.hash_mat_to_bytes(hash_mat),
                        )
                    )
            except Exception:
                logger.exception("Error processing task %r", task)

        conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
        conn.execute(
            "CREATE TABLE hashes (`hash_type` INTEGER, `category` INTEGER, `label` VARCHAR, `hash` BLOB)"
        )

        now = datetime.now(tz=timezone.utc)
        timestamp = int(now.timestamp() * 1000)

        cls.__insert_property(conn, ImageHashesDatabase.KEY_HASH_SIZE, str(hash_size))
        cls.__insert_property(
            conn, ImageHashesDatabase.KEY_HIGH_FREQ_FACTOR, str(high_freq_factor)
        )
        cls.__insert_property(
            conn, ImageHashesDatabase.KEY_BUILT_TIMESTAMP, str(timestamp)
        )

        conn.executemany(
            "INSERT INTO hashes (hash_type, category, label, hash) VALUES (?, ?, ?, ?)",
            [
                (row.hash_type.value, row.category.value, row.label, row.hash)
                for row in rows
            ],
        )

        conn.commit()

View File

@@ -1,144 +0,0 @@
import sqlite3
from datetime import datetime, timezone
from typing import Any, Callable, List, Optional, TypeVar

from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.types import Mat

from .models import ImageHashHashType, ImageHashResult, ImageHashCategory

T = TypeVar("T")


def _sql_hamming_distance(hash1: bytes, hash2: bytes):
    assert len(hash1) == len(hash2), "hash size does not match!"
    count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
    return count


class ImageHashesDatabasePropertyMissingError(Exception):
    pass


class ImageHashesDatabase:
    KEY_HASH_SIZE = "hash_size"
    KEY_HIGH_FREQ_FACTOR = "high_freq_factor"
    KEY_BUILT_TIMESTAMP = "built_timestamp"

    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn
        self.conn.create_function("HAMMING_DISTANCE", 2, _sql_hamming_distance)

        self._hash_size: int = -1
        self._high_freq_factor: int = -1
        self._built_time: Optional[datetime] = None

        self._hashes_count = {
            ImageHashCategory.JACKET: 0,
            ImageHashCategory.PARTNER_ICON: 0,
        }

        self._hash_length: int = -1

        self._initialize()

    @property
    def hash_size(self):
        return self._hash_size

    @property
    def high_freq_factor(self):
        return self._high_freq_factor

    @property
    def hash_length(self):
        return self._hash_length

    def _initialize(self):
        def query_property(key, convert_func: Callable[[Any], T]) -> Optional[T]:
            result = self.conn.execute(
                "SELECT value FROM properties WHERE key = ?",
                (key,),
            ).fetchone()
            return convert_func(result[0]) if result is not None else None

        def set_hashes_count(category: ImageHashCategory):
            self._hashes_count[category] = self.conn.execute(
                "SELECT COUNT(DISTINCT label) FROM hashes WHERE category = ?",
                (category.value,),
            ).fetchone()[0]

        hash_size = query_property(self.KEY_HASH_SIZE, lambda x: int(x))
        if hash_size is None:
            raise ImageHashesDatabasePropertyMissingError("hash_size")
        self._hash_size = hash_size

        high_freq_factor = query_property(self.KEY_HIGH_FREQ_FACTOR, lambda x: int(x))
        if high_freq_factor is None:
            raise ImageHashesDatabasePropertyMissingError("high_freq_factor")
        self._high_freq_factor = high_freq_factor

        self._built_time = query_property(
            self.KEY_BUILT_TIMESTAMP,
            lambda ts: datetime.fromtimestamp(int(ts) / 1000, tz=timezone.utc),
        )

        set_hashes_count(ImageHashCategory.JACKET)
        set_hashes_count(ImageHashCategory.PARTNER_ICON)

        self._hash_length = self._hash_size**2

    def lookup_hash(
        self, category: ImageHashCategory, hash_type: ImageHashHashType, hash: bytes
    ) -> List[ImageHashResult]:
        cursor = self.conn.execute(
            "SELECT"
            " label,"
            " HAMMING_DISTANCE(hash, ?) AS distance"
            " FROM hashes"
            " WHERE category = ? AND hash_type = ?"
            " ORDER BY distance ASC LIMIT 10",
            (hash, category.value, hash_type.value),
        )

        results = []
        for label, distance in cursor.fetchall():
            results.append(
                ImageHashResult(
                    hash_type=hash_type,
                    category=category,
                    label=label,
                    confidence=(self.hash_length - distance) / self.hash_length,
                )
            )

        return results

    @staticmethod
    def hash_mat_to_bytes(hash: Mat) -> bytes:
        return bytes([255 if b else 0 for b in hash.flatten()])

    def identify_image(self, category: ImageHashCategory, img) -> List[ImageHashResult]:
        results = []

        ahash = hashers.average(img, self.hash_size)
        dhash = hashers.difference(img, self.hash_size)
        phash = hashers.dct(img, self.hash_size, self.high_freq_factor)

        results.extend(
            self.lookup_hash(
                category, ImageHashHashType.AVERAGE, self.hash_mat_to_bytes(ahash)
            )
        )
        results.extend(
            self.lookup_hash(
                category, ImageHashHashType.DIFFERENCE, self.hash_mat_to_bytes(dhash)
            )
        )
        results.extend(
            self.lookup_hash(
                category, ImageHashHashType.DCT, self.hash_mat_to_bytes(phash)
            )
        )
        return results

View File

@@ -1,46 +0,0 @@
import dataclasses
from enum import IntEnum
from typing import Callable

import cv2

from arcaea_offline_ocr.types import Mat


class ImageHashHashType(IntEnum):
    AVERAGE = 0
    DIFFERENCE = 1
    DCT = 2


class ImageHashCategory(IntEnum):
    JACKET = 0
    PARTNER_ICON = 1


@dataclasses.dataclass
class ImageHash:
    hash_type: ImageHashHashType
    category: ImageHashCategory
    label: str
    hash: bytes


@dataclasses.dataclass
class ImageHashResult:
    hash_type: ImageHashHashType
    category: ImageHashCategory
    label: str
    confidence: float


def _default_imread_gray(image_path: str):
    return cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY)


@dataclasses.dataclass
class ImageHashBuildTask:
    image_path: str
    category: ImageHashCategory
    label: str
    imread_function: Callable[[str], Mat] = _default_imread_gray

View File

@@ -5,10 +5,10 @@ from typing import Optional
 @dataclass
 class DeviceOcrResult:
     rating_class: int
-    pure: int
-    far: int
-    lost: int
     score: int
+    pure: Optional[int] = None
+    far: Optional[int] = None
+    lost: Optional[int] = None
     max_recall: Optional[int] = None
     song_id: Optional[str] = None
     song_id_possibility: Optional[float] = None

View File

@@ -1,15 +1,8 @@
 import cv2
 import numpy as np
-from ..crop import crop_xywh
-from ..ocr import (
-    FixRects,
-    ocr_digit_samples_knn,
-    ocr_digits_by_contour_knn,
-    preprocess_hog,
-    resize_fill_square,
-)
 from ..phash_db import ImagePhashDatabase
+from ..providers.knn import OcrKNearestTextProvider
 from ..types import Mat
 from .common import DeviceOcrResult
 from .rois.extractor import DeviceRoisExtractor
@@ -21,38 +14,37 @@ class DeviceOcr:
         self,
         extractor: DeviceRoisExtractor,
         masker: DeviceRoisMasker,
-        knn_model: cv2.ml.KNearest,
+        knn_provider: OcrKNearestTextProvider,
         phash_db: ImagePhashDatabase,
     ):
         self.extractor = extractor
         self.masker = masker
-        self.knn_model = knn_model
+        self.knn_provider = knn_provider
         self.phash_db = phash_db

     def pfl(self, roi_gray: Mat, factor: float = 1.25):
-        contours, _ = cv2.findContours(
-            roi_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
-        )
-        filtered_contours = [c for c in contours if cv2.contourArea(c) >= 5 * factor]
-        rects = [cv2.boundingRect(c) for c in filtered_contours]
-        rects = FixRects.connect_broken(rects, roi_gray.shape[1], roi_gray.shape[0])
-        filtered_rects = [r for r in rects if r[2] >= 5 * factor and r[3] >= 6 * factor]
-        filtered_rects = FixRects.split_connected(roi_gray, filtered_rects)
-        filtered_rects = sorted(filtered_rects, key=lambda r: r[0])
+        def contour_filter(cnt):
+            return cv2.contourArea(cnt) >= 5 * factor
+
+        contours = self.knn_provider.contours(roi_gray)
+        contours_filtered = self.knn_provider.contours(
+            roi_gray, contours_filter=contour_filter
+        )

         roi_ocr = roi_gray.copy()
-        filtered_contours_flattened = {tuple(c.flatten()) for c in filtered_contours}
+        contours_filtered_flattened = {tuple(c.flatten()) for c in contours_filtered}
         for contour in contours:
-            if tuple(contour.flatten()) in filtered_contours_flattened:
+            if tuple(contour.flatten()) in contours_filtered_flattened:
                 continue
             roi_ocr = cv2.fillPoly(roi_ocr, [contour], [0])
-        digit_rois = [
-            resize_fill_square(crop_xywh(roi_ocr, r), 20) for r in filtered_rects
-        ]

-        samples = preprocess_hog(digit_rois)
-        return ocr_digit_samples_knn(samples, self.knn_model)
+        ocr_result = self.knn_provider.result(
+            roi_ocr,
+            contours_filter=lambda cnt: cv2.contourArea(cnt) >= 5 * factor,
+            rects_filter=lambda rect: rect[2] >= 5 * factor and rect[3] >= 6 * factor,
+        )
+        return int(ocr_result) if ocr_result else 0

     def pure(self):
         return self.pfl(self.masker.pure(self.extractor.pure))
@@ -65,13 +57,14 @@ class DeviceOcr:
     def score(self):
         roi = self.masker.score(self.extractor.score)
-        contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+        contours = self.knn_provider.contours(roi)
         for contour in contours:
             if (
                 cv2.boundingRect(contour)[3] < roi.shape[0] * 0.6
             ):  # h < score_component_h * 0.6
                 roi = cv2.fillPoly(roi, [contour], [0])
-        return ocr_digits_by_contour_knn(roi, self.knn_model)
+        ocr_result = self.knn_provider.result(roi)
+        return int(ocr_result) if ocr_result else 0

     def rating_class(self):
         roi = self.extractor.rating_class
@@ -85,9 +78,10 @@ class DeviceOcr:
         return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]

     def max_recall(self):
-        return ocr_digits_by_contour_knn(
-            self.masker.max_recall(self.extractor.max_recall), self.knn_model
+        ocr_result = self.knn_provider.result(
+            self.masker.max_recall(self.extractor.max_recall)
         )
+        return int(ocr_result) if ocr_result else None

     def clear_status(self):
         roi = self.extractor.clear_status
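
For downstream callers, the practical change is that DeviceOcr now takes an OcrKNearestTextProvider instead of a bare cv2.ml.KNearest model. A rough wiring sketch; the import path, model file name, and the pre-built extractor/masker/phash_db instances are assumptions for illustration:

    import cv2

    from arcaea_offline_ocr.device.ocr import DeviceOcr  # assumed import path
    from arcaea_offline_ocr.providers.knn import OcrKNearestTextProvider

    # "digits.knn.dat" is a placeholder path to a KNearest model saved with OpenCV
    knn_provider = OcrKNearestTextProvider(cv2.ml.KNearest_load("digits.knn.dat"))

    # `extractor`, `masker` and `phash_db` are assumed to be already-constructed
    # DeviceRoisExtractor / DeviceRoisMasker / ImagePhashDatabase instances.
    ocr = DeviceOcr(extractor, masker, knn_provider, phash_db)

    score = ocr.score()            # falls back to 0 when nothing is recognised
    max_recall = ocr.max_recall()  # falls back to None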

View File

@@ -0,0 +1,12 @@
from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult, OcrTextProvider
from .ihdb import ImageHashDatabaseIdProvider
from .knn import OcrKNearestTextProvider

__all__ = [
    "ImageCategory",
    "ImageHashDatabaseIdProvider",
    "OcrKNearestTextProvider",
    "ImageIdProvider",
    "OcrTextProvider",
    "ImageIdProviderResult",
]

View File

@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from typing import TYPE_CHECKING, Any, Sequence, Optional

if TYPE_CHECKING:
    from ..types import Mat


class OcrTextProvider(ABC):
    @abstractmethod
    def result_raw(self, img: "Mat", /, *args, **kwargs) -> Any: ...

    @abstractmethod
    def result(self, img: "Mat", /, *args, **kwargs) -> Optional[str]: ...


class ImageCategory(IntEnum):
    JACKET = 0
    PARTNER_ICON = 1


@dataclass(kw_only=True)
class ImageIdProviderResult:
    image_id: str
    category: ImageCategory
    confidence: float


class ImageIdProvider(ABC):
    @abstractmethod
    def result(
        self, img: "Mat", category: ImageCategory, /, *args, **kwargs
    ) -> ImageIdProviderResult: ...

    @abstractmethod
    def results(
        self, img: "Mat", category: ImageCategory, /, *args, **kwargs
    ) -> Sequence[ImageIdProviderResult]: ...

View File

@@ -0,0 +1,194 @@
import sqlite3
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import IntEnum
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar

from arcaea_offline_ocr.core import hashers

from .base import ImageCategory, ImageIdProvider, ImageIdProviderResult

if TYPE_CHECKING:
    from arcaea_offline_ocr.types import Mat

T = TypeVar("T")

PROP_KEY_HASH_SIZE = "hash_size"
PROP_KEY_HIGH_FREQ_FACTOR = "high_freq_factor"
PROP_KEY_BUILT_AT = "built_at"


def _sql_hamming_distance(hash1: bytes, hash2: bytes):
    assert len(hash1) == len(hash2), "hash size does not match!"
    count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
    return count


class ImageHashType(IntEnum):
    AVERAGE = 0
    DIFFERENCE = 1
    DCT = 2


@dataclass(kw_only=True)
class ImageHashDatabaseIdProviderResult(ImageIdProviderResult):
    image_hash_type: ImageHashType


class MissingPropertiesError(Exception):
    keys: List[str]

    def __init__(self, keys, *args):
        super().__init__(*args)
        self.keys = keys


class ImageHashDatabaseIdProvider(ImageIdProvider):
    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn
        self.conn.create_function("HAMMING_DISTANCE", 2, _sql_hamming_distance)

        self.properties = {
            PROP_KEY_HASH_SIZE: -1,
            PROP_KEY_HIGH_FREQ_FACTOR: -1,
            PROP_KEY_BUILT_AT: None,
        }

        self._hashes_count = {
            ImageCategory.JACKET: 0,
            ImageCategory.PARTNER_ICON: 0,
        }

        self._hash_length: int = -1

        self._initialize()

    @property
    def hash_size(self) -> int:
        return self.properties[PROP_KEY_HASH_SIZE]

    @property
    def high_freq_factor(self) -> int:
        return self.properties[PROP_KEY_HIGH_FREQ_FACTOR]

    @property
    def built_at(self) -> Optional[datetime]:
        return self.properties.get(PROP_KEY_BUILT_AT)

    @property
    def hash_length(self):
        return self._hash_length

    def _initialize(self):
        def get_property(key, converter: Callable[[Any], T]) -> Optional[T]:
            result = self.conn.execute(
                "SELECT value FROM properties WHERE key = ?",
                (key,),
            ).fetchone()
            return converter(result[0]) if result is not None else None

        def set_hashes_count(category: ImageCategory):
            self._hashes_count[category] = self.conn.execute(
                "SELECT COUNT(DISTINCT `id`) FROM hashes WHERE category = ?",
                (category.value,),
            ).fetchone()[0]

        properties_converter_map = {
            PROP_KEY_HASH_SIZE: lambda x: int(x),
            PROP_KEY_HIGH_FREQ_FACTOR: lambda x: int(x),
            PROP_KEY_BUILT_AT: lambda ts: datetime.fromtimestamp(
                int(ts) / 1000, tz=timezone.utc
            ),
        }
        required_properties = [PROP_KEY_HASH_SIZE, PROP_KEY_HIGH_FREQ_FACTOR]

        missing_properties = []
        for property_key, converter in properties_converter_map.items():
            value = get_property(property_key, converter)
            if value is None:
                if property_key in required_properties:
                    missing_properties.append(property_key)
                continue

            self.properties[property_key] = value

        if missing_properties:
            raise MissingPropertiesError(keys=missing_properties)

        set_hashes_count(ImageCategory.JACKET)
        set_hashes_count(ImageCategory.PARTNER_ICON)

        self._hash_length = self.hash_size**2

    def lookup_hash(
        self, category: ImageCategory, hash_type: ImageHashType, hash: bytes
    ) -> List[ImageHashDatabaseIdProviderResult]:
        cursor = self.conn.execute(
            """
            SELECT
                `id`,
                HAMMING_DISTANCE(hash, ?) AS distance
            FROM hashes
            WHERE category = ? AND hash_type = ?
            ORDER BY distance ASC LIMIT 10""",
            (hash, category.value, hash_type.value),
        )

        results = []
        for id_, distance in cursor.fetchall():
            results.append(
                ImageHashDatabaseIdProviderResult(
                    image_id=id_,
                    category=category,
                    confidence=(self.hash_length - distance) / self.hash_length,
                    image_hash_type=hash_type,
                )
            )

        return results

    @staticmethod
    def hash_mat_to_bytes(hash: "Mat") -> bytes:
        return bytes([255 if b else 0 for b in hash.flatten()])

    def results(self, img: "Mat", category: ImageCategory, /):
        results: List[ImageHashDatabaseIdProviderResult] = []

        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.AVERAGE,
                self.hash_mat_to_bytes(hashers.average(img, self.hash_size)),
            )
        )
        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.DIFFERENCE,
                self.hash_mat_to_bytes(hashers.difference(img, self.hash_size)),
            )
        )
        results.extend(
            self.lookup_hash(
                category,
                ImageHashType.DCT,
                self.hash_mat_to_bytes(
                    hashers.dct(img, self.hash_size, self.high_freq_factor)
                ),
            )
        )

        return results

    def result(
        self,
        img: "Mat",
        category: ImageCategory,
        /,
        *,
        hash_type: ImageHashType = ImageHashType.DCT,
    ):
        return [
            it for it in self.results(img, category) if it.image_hash_type == hash_type
        ][0]
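
A short lookup sketch against a database built with the builder above (the database file name and the cropped jacket image are placeholders):

    import sqlite3

    import cv2

    from arcaea_offline_ocr.providers import ImageCategory, ImageHashDatabaseIdProvider

    provider = ImageHashDatabaseIdProvider(sqlite3.connect("image-hashes.db"))

    # placeholder: a grayscale crop of the song jacket region
    jacket_gray = cv2.imread("jacket_crop.png", cv2.IMREAD_GRAYSCALE)

    best = provider.result(jacket_gray, ImageCategory.JACKET)  # DCT hash by default
    print(best.image_id, best.confidence)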

View File

@@ -1,18 +1,19 @@
+import logging
 import math
-from typing import Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple

 import cv2
 import numpy as np

-from .crop import crop_xywh
-from .types import Mat
+from ..crop import crop_xywh
+from .base import OcrTextProvider

-__all__ = [
-    "FixRects",
-    "preprocess_hog",
-    "ocr_digits_by_contour_get_samples",
-    "ocr_digits_by_contour_knn",
-]
+if TYPE_CHECKING:
+    from cv2.ml import KNearest
+
+    from ..types import Mat
+
+logger = logging.getLogger(__name__)


 class FixRects:
@@ -68,7 +69,7 @@ class FixRects:
     @staticmethod
     def split_connected(
-        img_masked: Mat,
+        img_masked: "Mat",
         rects: Sequence[Tuple[int, int, int, int]],
         rect_wh_ratio: float = 1.05,
         width_range_ratio: float = 0.1,
@@ -118,7 +119,7 @@
     return return_rects


-def resize_fill_square(img: Mat, target: int = 20):
+def resize_fill_square(img: "Mat", target: int = 20):
     h, w = img.shape[:2]
     if h > w:
         new_h = target
@@ -152,29 +153,88 @@ def preprocess_hog(digit_rois):
 def ocr_digit_samples_knn(__samples, knn_model: cv2.ml.KNearest, k: int = 4):
     _, results, _, _ = knn_model.findNearest(__samples, k)
-    result_list = [int(r) for r in results.ravel()]
-    result_str = "".join(str(r) for r in result_list if r > -1)
-    return int(result_str) if result_str else 0
+    return [int(r) for r in results.ravel()]


-def ocr_digits_by_contour_get_samples(__roi_gray: Mat, size: int):
-    roi = __roi_gray.copy()
-    contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-    rects = [cv2.boundingRect(c) for c in contours]
-    rects = FixRects.connect_broken(rects, roi.shape[1], roi.shape[0])
-    rects = FixRects.split_connected(roi, rects)
-    rects = sorted(rects, key=lambda r: r[0])
-    # digit_rois = [cv2.resize(crop_xywh(roi, rect), size) for rect in rects]
-    digit_rois = [resize_fill_square(crop_xywh(roi, rect), size) for rect in rects]
-    return preprocess_hog(digit_rois)
-
-
-def ocr_digits_by_contour_knn(
-    __roi_gray: Mat,
-    knn_model: cv2.ml.KNearest,
-    *,
-    k=4,
-    size: int = 20,
-) -> int:
-    samples = ocr_digits_by_contour_get_samples(__roi_gray, size)
-    return ocr_digit_samples_knn(samples, knn_model, k)
+class OcrKNearestTextProvider(OcrTextProvider):
+    _ContourFilter = Callable[["Mat"], bool]
+    _RectsFilter = Callable[[Sequence[int]], bool]
+
+    def __init__(self, model: "KNearest"):
+        self.model = model
+
+    def contours(
+        self, img: "Mat", /, *, contours_filter: Optional[_ContourFilter] = None
+    ):
+        cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+        if contours_filter:
+            cnts = list(filter(contours_filter, cnts))
+
+        return cnts
+
+    def result_raw(
+        self,
+        img: "Mat",
+        /,
+        *,
+        fix_rects: bool = True,
+        contours_filter: Optional[_ContourFilter] = None,
+        rects_filter: Optional[_RectsFilter] = None,
+    ):
+        """
+        :param img: grayscaled roi
+        """
+
+        try:
+            cnts, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            if contours_filter:
+                cnts = list(filter(contours_filter, cnts))
+
+            rects = [cv2.boundingRect(cnt) for cnt in cnts]
+            if fix_rects and rects_filter:
+                rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])  # type: ignore
+                rects = list(filter(rects_filter, rects))
+                rects = FixRects.split_connected(img, rects)
+            elif fix_rects:
+                rects = FixRects.connect_broken(rects, img.shape[1], img.shape[0])  # type: ignore
+                rects = FixRects.split_connected(img, rects)
+            elif rects_filter:
+                rects = list(filter(rects_filter, rects))
+
+            rects = sorted(rects, key=lambda r: r[0])
+            digits = []
+            for rect in rects:
+                digit = crop_xywh(img, rect)
+                digit = resize_fill_square(digit, 20)
+                digits.append(digit)
+            samples = preprocess_hog(digits)
+            return ocr_digit_samples_knn(samples, self.model)
+        except Exception:
+            logger.exception("Error occurred during KNearest OCR")
+            return None
+
+    def result(
+        self,
+        img: "Mat",
+        /,
+        *,
+        fix_rects: bool = True,
+        contours_filter: Optional[_ContourFilter] = None,
+        rects_filter: Optional[_RectsFilter] = None,
+    ):
+        """
+        :param img: grayscaled roi
+        """
+
+        raw = self.result_raw(
+            img,
+            fix_rects=fix_rects,
+            contours_filter=contours_filter,
+            rects_filter=rects_filter,
+        )
+        return (
+            "".join(["".join(str(r) for r in raw if r > -1)])
+            if raw is not None
+            else None
+        )
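
And a small sketch of the provider's call pattern on its own, mirroring how DeviceOcr.pfl uses it above (the model file and ROI image are placeholders):

    import cv2

    from arcaea_offline_ocr.providers.knn import OcrKNearestTextProvider

    # "digits.knn.dat" is a placeholder path to a trained KNearest model
    provider = OcrKNearestTextProvider(cv2.ml.KNearest_load("digits.knn.dat"))

    roi = cv2.imread("score_roi.png", cv2.IMREAD_GRAYSCALE)  # placeholder grayscale ROI
    text = provider.result(
        roi,
        contours_filter=lambda cnt: cv2.contourArea(cnt) >= 5,
        rects_filter=lambda rect: rect[2] >= 5 and rect[3] >= 6,
    )
    value = int(text) if text else None  # result() returns a digit string or None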
)