diff --git a/src/arcaea_offline_ocr/device/common.py b/src/arcaea_offline_ocr/device/common.py index 5a48d37..69c49cb 100644 --- a/src/arcaea_offline_ocr/device/common.py +++ b/src/arcaea_offline_ocr/device/common.py @@ -10,7 +10,9 @@ class DeviceOcrResult: far: int lost: int score: int - max_recall: int + max_recall: Optional[int] = None song_id: Optional[str] = None - title: Optional[str] = None - clear_type: Optional[str] = None + song_id_possibility: Optional[float] = None + clear_status: Optional[str] = None + partner_id: Optional[str] = None + partner_id_possibility: Optional[float] = None diff --git a/src/arcaea_offline_ocr/device/ocr.py b/src/arcaea_offline_ocr/device/ocr.py index 7232bcb..17927b0 100644 --- a/src/arcaea_offline_ocr/device/ocr.py +++ b/src/arcaea_offline_ocr/device/ocr.py @@ -1,6 +1,5 @@ import cv2 import numpy as np -from PIL import Image from ..crop import crop_xywh from ..ocr import ( @@ -98,5 +97,64 @@ class DeviceOcr: ] return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0] + def lookup_song_id(self): + return self.phash_db.lookup_jacket( + cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY) + ) + def song_id(self): - return self.phash_db.lookup_image(Image.fromarray(self.extractor.jacket))[0] + return self.lookup_song_id()[0] + + @staticmethod + def preprocess_char_icon(img_gray: cv2.Mat): + h, w = img_gray.shape[:2] + img = cv2.copyMakeBorder(img_gray, w - h, 0, 0, 0, cv2.BORDER_REPLICATE) + h, w = img.shape[:2] + img = cv2.fillPoly( + img, + [ + np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32), + np.array([[w, 0], [round(w / 2), 0], [w, round(h / 2)]], np.int32), + np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32), + np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32), + ], + (128), + ) + return img + + def lookup_partner_id(self): + return self.phash_db.lookup_partner_icon( + self.preprocess_char_icon( + cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY) + ) + ) + + def partner_id(self): + return self.lookup_partner_id()[0] + + def ocr(self) -> DeviceOcrResult: + rating_class = self.rating_class() + pure = self.pure() + far = self.far() + lost = self.lost() + score = self.score() + max_recall = self.max_recall() + clear_status = self.clear_status() + + hash_len = self.phash_db.hash_size**2 + song_id, song_id_distance = self.lookup_song_id() + partner_id, partner_id_distance = self.lookup_partner_id() + + return DeviceOcrResult( + rating_class=rating_class, + pure=pure, + far=far, + lost=lost, + score=score, + max_recall=max_recall, + song_id=song_id, + song_id_possibility=1 - song_id_distance / hash_len, + clear_status=clear_status, + partner_id=partner_id, + partner_id_possibility=1 - partner_id_distance / hash_len, + ) diff --git a/src/arcaea_offline_ocr/phash_db.py b/src/arcaea_offline_ocr/phash_db.py index a0a406d..8d95b6b 100644 --- a/src/arcaea_offline_ocr/phash_db.py +++ b/src/arcaea_offline_ocr/phash_db.py @@ -64,14 +64,14 @@ class ImagePhashDatabase: self.jacket_ids: List[str] = [] self.jacket_hashes = [] - self.partner_ids: List[str] = [] - self.partner_hashes = [] + self.partner_icon_ids: List[str] = [] + self.partner_icon_hashes = [] for id, hash in zip(self.ids, self.hashes): id_splitted = id.split("||") - if len(id_splitted) > 1 and id_splitted[0] == "partner": - self.partner_ids.append(id) - self.partner_hashes.append(hash) + if len(id_splitted) > 1 and id_splitted[0] == "partner_icon": + self.partner_icon_ids.append(id_splitted[1]) + self.partner_icon_hashes.append(hash) else: self.jacket_ids.append(id) self.jacket_hashes.append(hash) @@ -104,13 +104,13 @@ class ImagePhashDatabase: def lookup_jacket(self, img_gray: cv2.Mat): return self.lookup_jackets(img_gray)[0] - def lookup_partners(self, img_gray: cv2.Mat, *, limit: int = 5): + def lookup_partner_icons(self, img_gray: cv2.Mat, *, limit: int = 5): image_hash = self.calculate_phash(img_gray).flatten() xor_results = [ (id, np.count_nonzero(image_hash ^ h)) - for id, h in zip(self.partner_ids, self.partner_hashes) + for id, h in zip(self.partner_icon_ids, self.partner_icon_hashes) ] return sorted(xor_results, key=lambda r: r[1])[:limit] - def lookup_partner(self, img_gray: cv2.Mat): - return self.lookup_partners(img_gray)[0] + def lookup_partner_icon(self, img_gray: cv2.Mat): + return self.lookup_partner_icons(img_gray)[0]