def phash_opencv(img_gray, hash_size=8, highfreq_factor=4):
    # type: (cv2.Mat | np.ndarray, int, int) -> np.ndarray
    """
    Compute the perceptual hash (pHash) of a grayscale image using OpenCV.

    Implementation follows http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html

    Adapted from `imagehash.phash`, pure opencv implementation.

    The result is slightly different from `imagehash.phash`.

    The image is resized to a square of side ``hash_size * highfreq_factor``,
    transformed with a 2D DCT, and the top-left ``hash_size`` x ``hash_size``
    coefficients are compared against their median.

    Args:
        img_gray: Grayscale image to hash.
        hash_size: Side length of the resulting hash matrix; must be >= 2.
        highfreq_factor: Multiplier applied to ``hash_size`` to obtain the
            side length the image is resized to before the DCT.

    Returns:
        A boolean array of shape ``(hash_size, hash_size)`` that is ``True``
        where a low-frequency DCT coefficient exceeds the median.

    Raises:
        ValueError: If ``hash_size`` is smaller than 2.
    """
    if hash_size < 2:
        raise ValueError("Hash size must be greater than or equal to 2")

    img_size = hash_size * highfreq_factor
    image = cv2.resize(img_gray, (img_size, img_size), interpolation=cv2.INTER_LANCZOS4)
    # The DCT operates on floating-point data, so convert before transforming.
    image = np.float32(image)
    # A single full 2D DCT is all that is needed; the previous extra
    # row-wise pass (cv2.DCT_ROWS) was computed and immediately discarded.
    dct = cv2.dct(image)
    # Keep only the low-frequency corner, which carries the coarse structure.
    dctlowfreq = dct[:hash_size, :hash_size]
    med = np.median(dctlowfreq)
    return dctlowfreq > med
def lookup_image(self, img_gray: cv2.Mat):
    """
    Hash a grayscale image and return its closest database match.

    The hash is computed with this database's ``hash_size`` and
    ``highfreq_factor``, then handed to ``lookup_hash``; only the first
    (best-ranked) entry of that result is returned.
    """
    hash_options = {
        "hash_size": self.hash_size,
        "highfreq_factor": self.highfreq_factor,
    }
    computed_hash = phash_opencv(img_gray, **hash_options)
    candidates = self.lookup_hash(computed_hash)
    return candidates[0]