refactor: DeviceOcr

This commit is contained in:
283375 2023-10-10 01:45:02 +08:00
parent 6a19ead8d1
commit ede2b4ec51
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
3 changed files with 74 additions and 14 deletions

View File

@ -10,7 +10,9 @@ class DeviceOcrResult:
far: int
lost: int
score: int
max_recall: int
max_recall: Optional[int] = None
song_id: Optional[str] = None
title: Optional[str] = None
clear_type: Optional[str] = None
song_id_possibility: Optional[float] = None
clear_status: Optional[str] = None
partner_id: Optional[str] = None
partner_id_possibility: Optional[float] = None

View File

@ -1,6 +1,5 @@
import cv2
import numpy as np
from PIL import Image
from ..crop import crop_xywh
from ..ocr import (
@ -98,5 +97,64 @@ class DeviceOcr:
]
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
def lookup_song_id(self):
    """Look up the extracted jacket image in the phash database.

    The jacket taken from ``self.extractor`` is converted with
    ``cv2.COLOR_BGR2GRAY`` and handed to ``self.phash_db.lookup_jacket``.

    Returns:
        Whatever ``lookup_jacket`` returns; ``ocr`` unpacks it as a
        ``(song_id, distance)`` pair.
    """
    jacket_gray = cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY)
    return self.phash_db.lookup_jacket(jacket_gray)
def song_id(self):
    """Return the id of the best-matching jacket.

    Delegates to ``lookup_song_id`` and keeps only the first element of its
    result (``ocr`` unpacks that same result as ``(song_id, distance)``).
    """
    # Only the ``lookup_song_id`` path is kept: the earlier PIL-based
    # ``lookup_image`` return made this statement unreachable and bypassed
    # the grayscale jacket lookup that ``ocr`` relies on.
    return self.lookup_song_id()[0]
@staticmethod
def preprocess_char_icon(img_gray: cv2.Mat):
    """Pad a grayscale partner icon at the top and gray out its corners.

    Args:
        img_gray: Single-channel image; only its first two shape entries
            (height, width) are read.

    Returns:
        The padded image with four corner triangles filled with value 128.
    """
    src_h, src_w = img_gray.shape[:2]
    # Add (width - height) rows on the top edge only, replicating the border
    # pixels, so the padded height equals the original width.
    padded = cv2.copyMakeBorder(
        img_gray, src_w - src_h, 0, 0, 0, cv2.BORDER_REPLICATE
    )

    h, w = padded.shape[:2]
    mid_x = round(w / 2)
    mid_y = round(h / 2)
    # One triangle per corner: the corner itself plus the midpoints of the
    # two edges meeting there (order: top-left, top-right, bottom-left,
    # bottom-right).
    corner_triangles = [
        np.array([[x, y], [mid_x, y], [x, mid_y]], np.int32)
        for y in (0, h)
        for x in (0, w)
    ]
    return cv2.fillPoly(padded, corner_triangles, (128))
def lookup_partner_id(self):
    """Look up the extracted partner icon in the phash database.

    The icon from ``self.extractor`` is converted with
    ``cv2.COLOR_BGR2GRAY``, run through ``preprocess_char_icon`` and passed
    to ``self.phash_db.lookup_partner_icon``.

    Returns:
        Whatever ``lookup_partner_icon`` returns; ``ocr`` unpacks it as a
        ``(partner_id, distance)`` pair.
    """
    icon_gray = cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY)
    prepared_icon = self.preprocess_char_icon(icon_gray)
    return self.phash_db.lookup_partner_icon(prepared_icon)
def partner_id(self):
    """Return only the first element of the ``lookup_partner_id`` result."""
    best_match = self.lookup_partner_id()
    return best_match[0]
def ocr(self) -> DeviceOcrResult:
    """Run every recognizer and bundle the results into a DeviceOcrResult.

    The song and partner "possibility" values are computed as
    ``1 - distance / hash_size**2`` from the distances returned by the
    respective phash lookups.
    """
    # Plain recognizer outputs, gathered in the same order as before.
    recognized = {
        "rating_class": self.rating_class(),
        "pure": self.pure(),
        "far": self.far(),
        "lost": self.lost(),
        "score": self.score(),
        "max_recall": self.max_recall(),
        "clear_status": self.clear_status(),
    }

    # Distances are normalized by the squared hash size.
    total_bits = self.phash_db.hash_size**2
    song, song_distance = self.lookup_song_id()
    partner, partner_distance = self.lookup_partner_id()

    return DeviceOcrResult(
        **recognized,
        song_id=song,
        song_id_possibility=1 - song_distance / total_bits,
        partner_id=partner,
        partner_id_possibility=1 - partner_distance / total_bits,
    )

View File

@ -64,14 +64,14 @@ class ImagePhashDatabase:
self.jacket_ids: List[str] = []
self.jacket_hashes = []
self.partner_ids: List[str] = []
self.partner_hashes = []
self.partner_icon_ids: List[str] = []
self.partner_icon_hashes = []
for id, hash in zip(self.ids, self.hashes):
id_splitted = id.split("||")
if len(id_splitted) > 1 and id_splitted[0] == "partner":
self.partner_ids.append(id)
self.partner_hashes.append(hash)
if len(id_splitted) > 1 and id_splitted[0] == "partner_icon":
self.partner_icon_ids.append(id_splitted[1])
self.partner_icon_hashes.append(hash)
else:
self.jacket_ids.append(id)
self.jacket_hashes.append(hash)
@ -104,13 +104,13 @@ class ImagePhashDatabase:
def lookup_jacket(self, img_gray: cv2.Mat):
    """Return the first entry of ``lookup_jackets`` for *img_gray*."""
    ranked_matches = self.lookup_jackets(img_gray)
    return ranked_matches[0]
def lookup_partner_icons(self, img_gray: cv2.Mat, *, limit: int = 5):
    """Rank stored partner icons by similarity to *img_gray*.

    Args:
        img_gray: Image passed to ``self.calculate_phash``.
        limit: Maximum number of matches to return (keyword-only).

    Returns:
        Up to ``limit`` ``(partner_icon_id, distance)`` tuples, smallest
        distance first. The distance is the count of non-zero elements in the
        XOR of the two flattened hashes (a Hamming distance, assuming the
        hashes are boolean arrays).
    """
    image_hash = self.calculate_phash(img_gray).flatten()
    # Compare against the partner-icon entries only; the stale
    # ``partner_ids`` / ``partner_hashes`` pairing is no longer consulted.
    distances = [
        (icon_id, np.count_nonzero(image_hash ^ stored_hash))
        for icon_id, stored_hash in zip(
            self.partner_icon_ids, self.partner_icon_hashes
        )
    ]
    return sorted(distances, key=lambda match: match[1])[:limit]
def lookup_partner_icon(self, img_gray: cv2.Mat):
    """Return the closest partner-icon match for *img_gray*.

    Returns:
        The first entry of ``lookup_partner_icons(img_gray)``.

    Raises:
        IndexError: If ``lookup_partner_icons`` returns no matches.
    """
    # Single definition under the partner-icon name; the stale
    # ``lookup_partner`` / ``lookup_partners`` pair is dropped.
    return self.lookup_partner_icons(img_gray)[0]