feat!: ImagePHashDatabase

This commit is contained in:
283375 2023-09-27 17:15:21 +08:00
parent 65430a30b8
commit be87a0fbe1
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
3 changed files with 87 additions and 18 deletions

View File

@ -3,9 +3,11 @@ from typing import List, Optional, Tuple
import cv2
import numpy as np
from PIL import Image
from ....crop import crop_xywh
from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog
from ....phash_db import ImagePHashDatabase
from ....sift_db import SIFTDatabase
from ....types import Mat, cv2_ml_KNearest
from ....utils import construct_int_xywh_rect
@ -19,12 +21,12 @@ class ChieriBotV4Ocr:
self,
score_knn: cv2_ml_KNearest,
pfl_knn: cv2_ml_KNearest,
sift_db: SIFTDatabase,
phash_db: ImagePHashDatabase,
factor: Optional[float] = 1.0,
):
self.__score_knn = score_knn
self.__pfl_knn = pfl_knn
self.__sift_db = sift_db
self.__phash_db = phash_db
self.__rois = ChieriBotV4Rois(factor)
@property
@ -44,12 +46,12 @@ class ChieriBotV4Ocr:
self.__pfl_knn = knn_digits_model
@property
def sift_db(self):
return self.__sift_db
def phash_db(self):
return self.__phash_db
@sift_db.setter
def sift_db(self, sift_db: SIFTDatabase):
self.__sift_db = sift_db
@phash_db.setter
def phash_db(self, phash_db: ImagePHashDatabase):
self.__phash_db = phash_db
@property
def rois(self):
@ -98,7 +100,7 @@ class ChieriBotV4Ocr:
jacket_roi = cv2.cvtColor(
crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY
)
return self.sift_db.lookup_img(jacket_roi)[0]
return self.phash_db.lookup_image(Image.fromarray(jacket_roi))[0]
# def ocr_component_score_paddle(self, component_bgr: Mat) -> int:
# # sourcery skip: inline-immediately-returned-variable

View File

@ -4,6 +4,7 @@ from typing import Sequence
import cv2
import numpy as np
from PIL import Image
from ...crop import crop_xywh
from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white
@ -14,6 +15,7 @@ from ...ocr import (
preprocess_hog,
resize_fill_square,
)
from ...phash_db import ImagePHashDatabase
from ...sift_db import SIFTDatabase
from ...types import Mat, cv2_ml_KNearest
from ..shared import DeviceOcrResult
@ -23,9 +25,9 @@ from .shared import MAX_RECALL_CLOSE_KERNEL
class DeviceV2Ocr:
def __init__(self, knn_model: cv2_ml_KNearest, sift_db: SIFTDatabase):
def __init__(self, knn_model: cv2_ml_KNearest, phash_db: ImagePHashDatabase):
self.__knn_model = knn_model
self.__sift_db = sift_db
self.__phash_db = phash_db
@property
def knn_model(self):
@ -38,14 +40,14 @@ class DeviceV2Ocr:
self.__knn_model = value
@property
def sift_db(self):
if not self.__sift_db:
raise ValueError("`sift_db` unset.")
return self.__sift_db
def phash_db(self):
if not self.__phash_db:
raise ValueError("`phash_db` unset.")
return self.__phash_db
@sift_db.setter
def sift_db(self, value: SIFTDatabase):
self.__sift_db = value
@phash_db.setter
def phash_db(self, value: SIFTDatabase):
self.__phash_db = value
@lru_cache
def _get_digit_widths(self, num_list: Sequence[int], factor: float):
@ -86,7 +88,7 @@ class DeviceV2Ocr:
def ocr_song_id(self, rois: DeviceV2Rois):
jacket = cv2.cvtColor(rois.jacket, cv2.COLOR_BGR2GRAY)
return self.sift_db.lookup_img(jacket)[0]
return self.phash_db.lookup_image(Image.fromarray(jacket))[0]
def ocr_rating_class(self, rois: DeviceV2Rois):
roi = cv2.cvtColor(rois.max_recall_rating_class, cv2.COLOR_BGR2HSV)

View File

@ -0,0 +1,65 @@
import sqlite3
import imagehash
import numpy as np
from PIL import Image
def hamming_distance_sql_function(user_input, db_entry) -> int:
return np.count_nonzero(
np.frombuffer(user_input, bool) ^ np.frombuffer(db_entry, bool)
)
class ImagePHashDatabase:
def __init__(self, db_path: str):
with sqlite3.connect(db_path) as conn:
self.hash_size = int(
conn.execute(
"SELECT value FROM properties WHERE key = 'hash_size'"
).fetchone()[0]
)
self.highfreq_factor = int(
conn.execute(
"SELECT value FROM properties WHERE key = 'highfreq_factor'"
).fetchone()[0]
)
self.built_timestamp = int(
conn.execute(
"SELECT value FROM properties WHERE key = 'built_timestamp'"
).fetchone()[0]
)
# self.conn.create_function(
# "HAMMING_DISTANCE",
# 2,
# hamming_distance_sql_function,
# deterministic=True,
# )
self.ids = [i[0] for i in conn.execute("SELECT id FROM hashes").fetchall()]
self.hashes_byte = [
i[0] for i in conn.execute("SELECT hash FROM hashes").fetchall()
]
self.hashes = [np.frombuffer(hb, bool) for hb in self.hashes_byte]
self.hashes_slice_size = round(len(self.hashes_byte[0]) * 0.25)
self.hashes_head = [h[: self.hashes_slice_size] for h in self.hashes]
self.hashes_tail = [h[-self.hashes_slice_size :] for h in self.hashes]
def lookup_hash(self, image_hash: imagehash.ImageHash, *, limit: int = 5):
image_hash = image_hash.hash.flatten()
# image_hash_head = image_hash[: self.hashes_slice_size]
# image_hash_tail = image_hash[-self.hashes_slice_size :]
# head_xor_results = [image_hash_head ^ h for h in self.hashes]
# tail_xor_results = [image_hash_head ^ h for h in self.hashes]
xor_results = [
(id, np.count_nonzero(image_hash ^ h))
for id, h in zip(self.ids, self.hashes)
]
return sorted(xor_results, key=lambda r: r[1])[:limit]
def lookup_image(self, pil_image: Image.Image):
image_hash = imagehash.phash(
pil_image, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor
)
return self.lookup_hash(image_hash)[0]