diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py
index 6f68adb..c033088 100644
--- a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py
+++ b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py
@@ -3,9 +3,11 @@ from typing import List, Optional, Tuple
 
 import cv2
 import numpy as np
+from PIL import Image
 
 from ....crop import crop_xywh
 from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog
+from ....phash_db import ImagePHashDatabase
 from ....sift_db import SIFTDatabase
 from ....types import Mat, cv2_ml_KNearest
 from ....utils import construct_int_xywh_rect
@@ -19,12 +21,12 @@ class ChieriBotV4Ocr:
         self,
         score_knn: cv2_ml_KNearest,
         pfl_knn: cv2_ml_KNearest,
-        sift_db: SIFTDatabase,
+        phash_db: ImagePHashDatabase,
         factor: Optional[float] = 1.0,
     ):
         self.__score_knn = score_knn
         self.__pfl_knn = pfl_knn
-        self.__sift_db = sift_db
+        self.__phash_db = phash_db
         self.__rois = ChieriBotV4Rois(factor)
 
     @property
@@ -44,12 +46,12 @@ class ChieriBotV4Ocr:
         self.__pfl_knn = knn_digits_model
 
     @property
-    def sift_db(self):
-        return self.__sift_db
+    def phash_db(self):
+        return self.__phash_db
 
-    @sift_db.setter
-    def sift_db(self, sift_db: SIFTDatabase):
-        self.__sift_db = sift_db
+    @phash_db.setter
+    def phash_db(self, phash_db: ImagePHashDatabase):
+        self.__phash_db = phash_db
 
     @property
     def rois(self):
@@ -98,7 +100,7 @@ class ChieriBotV4Ocr:
         jacket_roi = cv2.cvtColor(
             crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY
         )
-        return self.sift_db.lookup_img(jacket_roi)[0]
+        return self.phash_db.lookup_image(Image.fromarray(jacket_roi))[0]
 
     # def ocr_component_score_paddle(self, component_bgr: Mat) -> int:
     #     # sourcery skip: inline-immediately-returned-variable
diff --git a/src/arcaea_offline_ocr/device/v2/ocr.py b/src/arcaea_offline_ocr/device/v2/ocr.py
index 27c36b4..cff2f0a 100644
--- a/src/arcaea_offline_ocr/device/v2/ocr.py
+++ b/src/arcaea_offline_ocr/device/v2/ocr.py
@@ -4,6 +4,7 @@ from typing import Sequence
 
 import cv2
 import numpy as np
+from PIL import Image
 
 from ...crop import crop_xywh
 from ...mask import mask_byd, mask_ftr, mask_gray, mask_prs, mask_pst, mask_white
@@ -14,6 +15,7 @@ from ...ocr import (
     preprocess_hog,
     resize_fill_square,
 )
+from ...phash_db import ImagePHashDatabase
 from ...sift_db import SIFTDatabase
 from ...types import Mat, cv2_ml_KNearest
 from ..shared import DeviceOcrResult
@@ -23,9 +25,9 @@ from .shared import MAX_RECALL_CLOSE_KERNEL
 
 
 class DeviceV2Ocr:
-    def __init__(self, knn_model: cv2_ml_KNearest, sift_db: SIFTDatabase):
+    def __init__(self, knn_model: cv2_ml_KNearest, phash_db: ImagePHashDatabase):
         self.__knn_model = knn_model
-        self.__sift_db = sift_db
+        self.__phash_db = phash_db
 
     @property
     def knn_model(self):
@@ -38,14 +40,14 @@ class DeviceV2Ocr:
         self.__knn_model = value
 
     @property
-    def sift_db(self):
-        if not self.__sift_db:
-            raise ValueError("`sift_db` unset.")
-        return self.__sift_db
+    def phash_db(self):
+        if not self.__phash_db:
+            raise ValueError("`phash_db` unset.")
+        return self.__phash_db
 
-    @sift_db.setter
-    def sift_db(self, value: SIFTDatabase):
-        self.__sift_db = value
+    @phash_db.setter
+    def phash_db(self, value: ImagePHashDatabase):
+        self.__phash_db = value
 
     @lru_cache
     def _get_digit_widths(self, num_list: Sequence[int], factor: float):
@@ -86,7 +88,7 @@ class DeviceV2Ocr:
 
     def ocr_song_id(self, rois: DeviceV2Rois):
         jacket = cv2.cvtColor(rois.jacket, cv2.COLOR_BGR2GRAY)
-        return self.sift_db.lookup_img(jacket)[0]
+        return self.phash_db.lookup_image(Image.fromarray(jacket))[0]
 
     def ocr_rating_class(self, rois: DeviceV2Rois):
         roi = cv2.cvtColor(rois.max_recall_rating_class, cv2.COLOR_BGR2HSV)
diff --git a/src/arcaea_offline_ocr/phash_db.py b/src/arcaea_offline_ocr/phash_db.py
new file mode 100644
index 0000000..6bbcd5b
--- /dev/null
+++ b/src/arcaea_offline_ocr/phash_db.py
@@ -0,0 +1,84 @@
+import sqlite3
+from contextlib import closing
+
+import imagehash
+import numpy as np
+from PIL import Image
+
+
+def hamming_distance_sql_function(user_input, db_entry) -> int:
+    """Count the differing positions of two buffers viewed as bool arrays."""
+    return np.count_nonzero(
+        np.frombuffer(user_input, bool) ^ np.frombuffer(db_entry, bool)
+    )
+
+
+class ImagePHashDatabase:
+    """Perceptual-hash lookup table loaded into memory from a SQLite file."""
+
+    def __init__(self, db_path: str):
+        """Read the hashing parameters and every stored hash from ``db_path``.
+
+        Raises:
+            ValueError: a required row is missing from ``properties``.
+        """
+        # sqlite3's own context manager only ends the transaction; ``closing``
+        # is what releases the connection once everything has been read.
+        with closing(sqlite3.connect(db_path)) as conn:
+            self.hash_size = self._read_int_property(conn, "hash_size")
+            self.highfreq_factor = self._read_int_property(conn, "highfreq_factor")
+            self.built_timestamp = self._read_int_property(conn, "built_timestamp")
+
+            # self.conn.create_function(
+            #     "HAMMING_DISTANCE",
+            #     2,
+            #     hamming_distance_sql_function,
+            #     deterministic=True,
+            # )
+
+            self.ids = [i[0] for i in conn.execute("SELECT id FROM hashes").fetchall()]
+            self.hashes_byte = [
+                i[0] for i in conn.execute("SELECT hash FROM hashes").fetchall()
+            ]
+
+        self.hashes = [np.frombuffer(hb, bool) for hb in self.hashes_byte]
+        # A quarter of the hash length; an empty table yields 0 instead of an
+        # IndexError on ``hashes_byte[0]``.
+        self.hashes_slice_size = (
+            round(len(self.hashes_byte[0]) * 0.25) if self.hashes_byte else 0
+        )
+        self.hashes_head = [h[: self.hashes_slice_size] for h in self.hashes]
+        self.hashes_tail = [h[-self.hashes_slice_size :] for h in self.hashes]
+
+    @staticmethod
+    def _read_int_property(conn: sqlite3.Connection, key: str) -> int:
+        """Return the ``properties`` value stored under ``key`` as an int."""
+        row = conn.execute(
+            "SELECT value FROM properties WHERE key = ?", (key,)
+        ).fetchone()
+        if row is None:
+            raise ValueError(f"property `{key}` is missing from the database.")
+        return int(row[0])
+
+    def lookup_hash(self, image_hash: imagehash.ImageHash, *, limit: int = 5):
+        """Return up to ``limit`` ``(id, distance)`` pairs, closest first."""
+        image_hash = image_hash.hash.flatten()
+        # image_hash_head = image_hash[: self.hashes_slice_size]
+        # image_hash_tail = image_hash[-self.hashes_slice_size :]
+        # head_xor_results = [image_hash_head ^ h for h in self.hashes]
+        # tail_xor_results = [image_hash_head ^ h for h in self.hashes]
+        xor_results = [
+            (entry_id, np.count_nonzero(image_hash ^ h))
+            for entry_id, h in zip(self.ids, self.hashes)
+        ]
+        return sorted(xor_results, key=lambda r: r[1])[:limit]
+
+    def lookup_image(self, pil_image: Image.Image):
+        """Hash ``pil_image`` with the stored parameters; return the best match.
+
+        The result is the closest ``(id, distance)`` pair from ``lookup_hash``.
+        """
+        image_hash = imagehash.phash(
+            pil_image, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor
+        )
+        return self.lookup_hash(image_hash)[0]