# Patch for src/arcaea_offline_ocr/phash_db.py (blob 4d0d954 -> 384b244).
#
# Hunk 1 (module header / phash_opencv):
#   * adds `from typing import List, Union`;
#   * rewrites the phash_opencv type comment from `cv2.Mat | np.ndarray`
#     to `Union[cv2.Mat, np.ndarray]`.
#
# Hunk 2 (ImagePHashDatabase):
#   * drops the commented-out HAMMING_DISTANCE create_function call;
#   * annotates self.ids as List[str];
#   * removes hashes_slice_size / hashes_head / hashes_tail;
#   * adds jacket_ids / jacket_hashes / partner_ids / partner_hashes, filled
#     by splitting every id on "||": an id with more than one part whose
#     first part is "partner" goes to the partner lists, every other id goes
#     to the jacket lists;
#   * adds calculate_phash(), which calls phash_opencv with self.hash_size
#     and self.highfreq_factor;
#   * removes the commented-out head/tail experiments from lookup_hash.
#
# Hunk 3 (ImagePHashDatabase lookups):
#   * lookup_image now obtains its hash through calculate_phash();
#   * adds lookup_jackets / lookup_partners, which pair each stored id with
#     np.count_nonzero(image_hash ^ stored_hash), sort ascending by that
#     count and return the first `limit` (default 5) pairs;
#   * adds lookup_jacket / lookup_partner, which return element [0] of the
#     corresponding plural lookup.
#
# NOTE(review): the line breaks of this patch appear to have been collapsed
#   into spaces; presumably they must be restored before the patch can be
#   applied - confirm against the original commit.
# NOTE(review): lookup_jacket / lookup_partner index [0] on the plural
#   result, so they raise IndexError when the matching list is empty (e.g.
#   no id starts with "partner||").
# NOTE(review): the loop variables `id` and `hash` in the added for-loop
#   shadow the Python builtins of the same name.
# NOTE(review): lookup_jackets and lookup_partners repeat the ranking
#   expression already present in lookup_hash; a shared helper taking the
#   id/hash lists would remove the duplication.
diff --git a/src/arcaea_offline_ocr/phash_db.py b/src/arcaea_offline_ocr/phash_db.py index 4d0d954..384b244 100644 --- a/src/arcaea_offline_ocr/phash_db.py +++ b/src/arcaea_offline_ocr/phash_db.py @@ -1,11 +1,12 @@ import sqlite3 +from typing import List, Union import cv2 import numpy as np def phash_opencv(img_gray, hash_size=8, highfreq_factor=4): - # type: (cv2.Mat | np.ndarray, int, int) -> np.ndarray + # type: (Union[cv2.Mat, np.ndarray], int, int) -> np.ndarray """ Perceptual Hash computation. @@ -53,28 +54,35 @@ class ImagePHashDatabase: ).fetchone()[0] ) - # self.conn.create_function( - # "HAMMING_DISTANCE", - # 2, - # hamming_distance_sql_function, - # deterministic=True, - # ) - - self.ids = [i[0] for i in conn.execute("SELECT id FROM hashes").fetchall()] + self.ids: List[str] = [ + i[0] for i in conn.execute("SELECT id FROM hashes").fetchall() + ] self.hashes_byte = [ i[0] for i in conn.execute("SELECT hash FROM hashes").fetchall() ] self.hashes = [np.frombuffer(hb, bool) for hb in self.hashes_byte] - self.hashes_slice_size = round(len(self.hashes_byte[0]) * 0.25) - self.hashes_head = [h[: self.hashes_slice_size] for h in self.hashes] - self.hashes_tail = [h[-self.hashes_slice_size :] for h in self.hashes] + + self.jacket_ids: List[str] = [] + self.jacket_hashes = [] + self.partner_ids: List[str] = [] + self.partner_hashes = [] + + for id, hash in zip(self.ids, self.hashes): + id_splitted = id.split("||") + if len(id_splitted) > 1 and id_splitted[0] == "partner": + self.partner_ids.append(id) + self.partner_hashes.append(hash) + else: + self.jacket_ids.append(id) + self.jacket_hashes.append(hash) + + def calculate_phash(self, img_gray: cv2.Mat): + return phash_opencv( + img_gray, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor + ) def lookup_hash(self, image_hash: np.ndarray, *, limit: int = 5): image_hash = image_hash.flatten() - # image_hash_head = image_hash[: self.hashes_slice_size] - # image_hash_tail = 
image_hash[-self.hashes_slice_size :] - # head_xor_results = [image_hash_head ^ h for h in self.hashes] - # tail_xor_results = [image_hash_head ^ h for h in self.hashes] xor_results = [ (id, np.count_nonzero(image_hash ^ h)) for id, h in zip(self.ids, self.hashes) @@ -82,7 +90,27 @@ class ImagePHashDatabase: return sorted(xor_results, key=lambda r: r[1])[:limit] def lookup_image(self, img_gray: cv2.Mat): - image_hash = phash_opencv( - img_gray, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor - ) + image_hash = self.calculate_phash(img_gray) return self.lookup_hash(image_hash)[0] + + def lookup_jackets(self, img_gray: cv2.Mat, *, limit: int = 5): + image_hash = self.calculate_phash(img_gray).flatten() + xor_results = [ + (id, np.count_nonzero(image_hash ^ h)) + for id, h in zip(self.jacket_ids, self.jacket_hashes) + ] + return sorted(xor_results, key=lambda r: r[1])[:limit] + + def lookup_jacket(self, img_gray: cv2.Mat): + return self.lookup_jackets(img_gray)[0] + + def lookup_partners(self, img_gray: cv2.Mat, *, limit: int = 5): + image_hash = self.calculate_phash(img_gray).flatten() + xor_results = [ + (id, np.count_nonzero(image_hash ^ h)) + for id, h in zip(self.partner_ids, self.partner_hashes) + ] + return sorted(xor_results, key=lambda r: r[1])[:limit] + + def lookup_partner(self, img_gray: cv2.Mat): + return self.lookup_partners(img_gray)[0]