mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-04-18 13:00:18 +00:00
refactor: DeviceOcr
This commit is contained in:
parent
6a19ead8d1
commit
ede2b4ec51
@ -10,7 +10,9 @@ class DeviceOcrResult:
|
|||||||
far: int
|
far: int
|
||||||
lost: int
|
lost: int
|
||||||
score: int
|
score: int
|
||||||
max_recall: int
|
max_recall: Optional[int] = None
|
||||||
song_id: Optional[str] = None
|
song_id: Optional[str] = None
|
||||||
title: Optional[str] = None
|
song_id_possibility: Optional[float] = None
|
||||||
clear_type: Optional[str] = None
|
clear_status: Optional[str] = None
|
||||||
|
partner_id: Optional[str] = None
|
||||||
|
partner_id_possibility: Optional[float] = None
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..crop import crop_xywh
|
from ..crop import crop_xywh
|
||||||
from ..ocr import (
|
from ..ocr import (
|
||||||
@ -98,5 +97,64 @@ class DeviceOcr:
|
|||||||
]
|
]
|
||||||
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
|
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
|
||||||
|
|
||||||
|
def lookup_song_id(self):
|
||||||
|
return self.phash_db.lookup_jacket(
|
||||||
|
cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY)
|
||||||
|
)
|
||||||
|
|
||||||
def song_id(self):
|
def song_id(self):
|
||||||
return self.phash_db.lookup_image(Image.fromarray(self.extractor.jacket))[0]
|
return self.lookup_song_id()[0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def preprocess_char_icon(img_gray: cv2.Mat):
|
||||||
|
h, w = img_gray.shape[:2]
|
||||||
|
img = cv2.copyMakeBorder(img_gray, w - h, 0, 0, 0, cv2.BORDER_REPLICATE)
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
img = cv2.fillPoly(
|
||||||
|
img,
|
||||||
|
[
|
||||||
|
np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32),
|
||||||
|
np.array([[w, 0], [round(w / 2), 0], [w, round(h / 2)]], np.int32),
|
||||||
|
np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32),
|
||||||
|
np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32),
|
||||||
|
],
|
||||||
|
(128),
|
||||||
|
)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def lookup_partner_id(self):
|
||||||
|
return self.phash_db.lookup_partner_icon(
|
||||||
|
self.preprocess_char_icon(
|
||||||
|
cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def partner_id(self):
|
||||||
|
return self.lookup_partner_id()[0]
|
||||||
|
|
||||||
|
def ocr(self) -> DeviceOcrResult:
|
||||||
|
rating_class = self.rating_class()
|
||||||
|
pure = self.pure()
|
||||||
|
far = self.far()
|
||||||
|
lost = self.lost()
|
||||||
|
score = self.score()
|
||||||
|
max_recall = self.max_recall()
|
||||||
|
clear_status = self.clear_status()
|
||||||
|
|
||||||
|
hash_len = self.phash_db.hash_size**2
|
||||||
|
song_id, song_id_distance = self.lookup_song_id()
|
||||||
|
partner_id, partner_id_distance = self.lookup_partner_id()
|
||||||
|
|
||||||
|
return DeviceOcrResult(
|
||||||
|
rating_class=rating_class,
|
||||||
|
pure=pure,
|
||||||
|
far=far,
|
||||||
|
lost=lost,
|
||||||
|
score=score,
|
||||||
|
max_recall=max_recall,
|
||||||
|
song_id=song_id,
|
||||||
|
song_id_possibility=1 - song_id_distance / hash_len,
|
||||||
|
clear_status=clear_status,
|
||||||
|
partner_id=partner_id,
|
||||||
|
partner_id_possibility=1 - partner_id_distance / hash_len,
|
||||||
|
)
|
||||||
|
@ -64,14 +64,14 @@ class ImagePhashDatabase:
|
|||||||
|
|
||||||
self.jacket_ids: List[str] = []
|
self.jacket_ids: List[str] = []
|
||||||
self.jacket_hashes = []
|
self.jacket_hashes = []
|
||||||
self.partner_ids: List[str] = []
|
self.partner_icon_ids: List[str] = []
|
||||||
self.partner_hashes = []
|
self.partner_icon_hashes = []
|
||||||
|
|
||||||
for id, hash in zip(self.ids, self.hashes):
|
for id, hash in zip(self.ids, self.hashes):
|
||||||
id_splitted = id.split("||")
|
id_splitted = id.split("||")
|
||||||
if len(id_splitted) > 1 and id_splitted[0] == "partner":
|
if len(id_splitted) > 1 and id_splitted[0] == "partner_icon":
|
||||||
self.partner_ids.append(id)
|
self.partner_icon_ids.append(id_splitted[1])
|
||||||
self.partner_hashes.append(hash)
|
self.partner_icon_hashes.append(hash)
|
||||||
else:
|
else:
|
||||||
self.jacket_ids.append(id)
|
self.jacket_ids.append(id)
|
||||||
self.jacket_hashes.append(hash)
|
self.jacket_hashes.append(hash)
|
||||||
@ -104,13 +104,13 @@ class ImagePhashDatabase:
|
|||||||
def lookup_jacket(self, img_gray: cv2.Mat):
|
def lookup_jacket(self, img_gray: cv2.Mat):
|
||||||
return self.lookup_jackets(img_gray)[0]
|
return self.lookup_jackets(img_gray)[0]
|
||||||
|
|
||||||
def lookup_partners(self, img_gray: cv2.Mat, *, limit: int = 5):
|
def lookup_partner_icons(self, img_gray: cv2.Mat, *, limit: int = 5):
|
||||||
image_hash = self.calculate_phash(img_gray).flatten()
|
image_hash = self.calculate_phash(img_gray).flatten()
|
||||||
xor_results = [
|
xor_results = [
|
||||||
(id, np.count_nonzero(image_hash ^ h))
|
(id, np.count_nonzero(image_hash ^ h))
|
||||||
for id, h in zip(self.partner_ids, self.partner_hashes)
|
for id, h in zip(self.partner_icon_ids, self.partner_icon_hashes)
|
||||||
]
|
]
|
||||||
return sorted(xor_results, key=lambda r: r[1])[:limit]
|
return sorted(xor_results, key=lambda r: r[1])[:limit]
|
||||||
|
|
||||||
def lookup_partner(self, img_gray: cv2.Mat):
|
def lookup_partner_icon(self, img_gray: cv2.Mat):
|
||||||
return self.lookup_partners(img_gray)[0]
|
return self.lookup_partner_icons(img_gray)[0]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user