mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-07-01 12:26:27 +00:00
Compare commits
4 Commits
1aa71685ce
...
5e0642c832
Author | SHA1 | Date | |
---|---|---|---|
5e0642c832
|
|||
ede2b4ec51
|
|||
6a19ead8d1
|
|||
d13076c667
|
@ -7,7 +7,7 @@ from PIL import Image
|
||||
|
||||
from ....crop import crop_xywh
|
||||
from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog
|
||||
from ....phash_db import ImagePHashDatabase
|
||||
from ....phash_db import ImagePhashDatabase
|
||||
from ....types import Mat, cv2_ml_KNearest
|
||||
from ....utils import construct_int_xywh_rect
|
||||
from ...shared import B30OcrResultItem
|
||||
@ -20,7 +20,7 @@ class ChieriBotV4Ocr:
|
||||
self,
|
||||
score_knn: cv2_ml_KNearest,
|
||||
pfl_knn: cv2_ml_KNearest,
|
||||
phash_db: ImagePHashDatabase,
|
||||
phash_db: ImagePhashDatabase,
|
||||
factor: Optional[float] = 1.0,
|
||||
):
|
||||
self.__score_knn = score_knn
|
||||
@ -49,7 +49,7 @@ class ChieriBotV4Ocr:
|
||||
return self.__phash_db
|
||||
|
||||
@phash_db.setter
|
||||
def phash_db(self, phash_db: ImagePHashDatabase):
|
||||
def phash_db(self, phash_db: ImagePhashDatabase):
|
||||
self.__phash_db = phash_db
|
||||
|
||||
@property
|
||||
|
@ -10,7 +10,9 @@ class DeviceOcrResult:
|
||||
far: int
|
||||
lost: int
|
||||
score: int
|
||||
max_recall: int
|
||||
max_recall: Optional[int] = None
|
||||
song_id: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
clear_type: Optional[str] = None
|
||||
song_id_possibility: Optional[float] = None
|
||||
clear_status: Optional[str] = None
|
||||
partner_id: Optional[str] = None
|
||||
partner_id_possibility: Optional[float] = None
|
||||
|
@ -1,6 +1,5 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..crop import crop_xywh
|
||||
from ..ocr import (
|
||||
@ -10,18 +9,19 @@ from ..ocr import (
|
||||
preprocess_hog,
|
||||
resize_fill_square,
|
||||
)
|
||||
from ..phash_db import ImagePHashDatabase
|
||||
from .roi.extractor import DeviceRoiExtractor
|
||||
from .roi.masker import DeviceRoiMasker
|
||||
from ..phash_db import ImagePhashDatabase
|
||||
from .common import DeviceOcrResult
|
||||
from .rois.extractor import DeviceRoisExtractor
|
||||
from .rois.masker import DeviceRoisMasker
|
||||
|
||||
|
||||
class DeviceOcr:
|
||||
def __init__(
|
||||
self,
|
||||
extractor: DeviceRoiExtractor,
|
||||
masker: DeviceRoiMasker,
|
||||
extractor: DeviceRoisExtractor,
|
||||
masker: DeviceRoisMasker,
|
||||
knn_model: cv2.ml.KNearest,
|
||||
phash_db: ImagePHashDatabase,
|
||||
phash_db: ImagePhashDatabase,
|
||||
):
|
||||
self.extractor = extractor
|
||||
self.masker = masker
|
||||
@ -97,5 +97,64 @@ class DeviceOcr:
|
||||
]
|
||||
return max(enumerate(results), key=lambda i: np.count_nonzero(i[1]))[0]
|
||||
|
||||
def lookup_song_id(self):
|
||||
return self.phash_db.lookup_jacket(
|
||||
cv2.cvtColor(self.extractor.jacket, cv2.COLOR_BGR2GRAY)
|
||||
)
|
||||
|
||||
def song_id(self):
|
||||
return self.phash_db.lookup_image(Image.fromarray(self.extractor.jacket))[0]
|
||||
return self.lookup_song_id()[0]
|
||||
|
||||
@staticmethod
|
||||
def preprocess_char_icon(img_gray: cv2.Mat):
|
||||
h, w = img_gray.shape[:2]
|
||||
img = cv2.copyMakeBorder(img_gray, w - h, 0, 0, 0, cv2.BORDER_REPLICATE)
|
||||
h, w = img.shape[:2]
|
||||
img = cv2.fillPoly(
|
||||
img,
|
||||
[
|
||||
np.array([[0, 0], [round(w / 2), 0], [0, round(h / 2)]], np.int32),
|
||||
np.array([[w, 0], [round(w / 2), 0], [w, round(h / 2)]], np.int32),
|
||||
np.array([[0, h], [round(w / 2), h], [0, round(h / 2)]], np.int32),
|
||||
np.array([[w, h], [round(w / 2), h], [w, round(h / 2)]], np.int32),
|
||||
],
|
||||
(128),
|
||||
)
|
||||
return img
|
||||
|
||||
def lookup_partner_id(self):
|
||||
return self.phash_db.lookup_partner_icon(
|
||||
self.preprocess_char_icon(
|
||||
cv2.cvtColor(self.extractor.partner_icon, cv2.COLOR_BGR2GRAY)
|
||||
)
|
||||
)
|
||||
|
||||
def partner_id(self):
|
||||
return self.lookup_partner_id()[0]
|
||||
|
||||
def ocr(self) -> DeviceOcrResult:
|
||||
rating_class = self.rating_class()
|
||||
pure = self.pure()
|
||||
far = self.far()
|
||||
lost = self.lost()
|
||||
score = self.score()
|
||||
max_recall = self.max_recall()
|
||||
clear_status = self.clear_status()
|
||||
|
||||
hash_len = self.phash_db.hash_size**2
|
||||
song_id, song_id_distance = self.lookup_song_id()
|
||||
partner_id, partner_id_distance = self.lookup_partner_id()
|
||||
|
||||
return DeviceOcrResult(
|
||||
rating_class=rating_class,
|
||||
pure=pure,
|
||||
far=far,
|
||||
lost=lost,
|
||||
score=score,
|
||||
max_recall=max_recall,
|
||||
song_id=song_id,
|
||||
song_id_possibility=1 - song_id_distance / hash_len,
|
||||
clear_status=clear_status,
|
||||
partner_id=partner_id,
|
||||
partner_id_possibility=1 - partner_id_distance / hash_len,
|
||||
)
|
||||
|
@ -149,7 +149,7 @@ class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto):
|
||||
BYD_HSV_MAX = np.array([179, 210, 198], np.uint8)
|
||||
|
||||
MAX_RECALL_HSV_MIN = np.array([125, 0, 0], np.uint8)
|
||||
MAX_RECALL_HSV_MAX = np.array([130, 100, 150], np.uint8)
|
||||
MAX_RECALL_HSV_MAX = np.array([145, 100, 150], np.uint8)
|
||||
|
||||
TRACK_LOST_HSV_MIN = np.array([170, 75, 90], np.uint8)
|
||||
TRACK_LOST_HSV_MAX = np.array([175, 170, 160], np.uint8)
|
||||
|
@ -1,11 +1,12 @@
|
||||
import sqlite3
|
||||
from typing import List, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def phash_opencv(img_gray, hash_size=8, highfreq_factor=4):
|
||||
# type: (cv2.Mat | np.ndarray, int, int) -> np.ndarray
|
||||
# type: (Union[cv2.Mat, np.ndarray], int, int) -> np.ndarray
|
||||
"""
|
||||
Perceptual Hash computation.
|
||||
|
||||
@ -34,7 +35,7 @@ def hamming_distance_sql_function(user_input, db_entry) -> int:
|
||||
)
|
||||
|
||||
|
||||
class ImagePHashDatabase:
|
||||
class ImagePhashDatabase:
|
||||
def __init__(self, db_path: str):
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
self.hash_size = int(
|
||||
@ -53,28 +54,35 @@ class ImagePHashDatabase:
|
||||
).fetchone()[0]
|
||||
)
|
||||
|
||||
# self.conn.create_function(
|
||||
# "HAMMING_DISTANCE",
|
||||
# 2,
|
||||
# hamming_distance_sql_function,
|
||||
# deterministic=True,
|
||||
# )
|
||||
|
||||
self.ids = [i[0] for i in conn.execute("SELECT id FROM hashes").fetchall()]
|
||||
self.ids: List[str] = [
|
||||
i[0] for i in conn.execute("SELECT id FROM hashes").fetchall()
|
||||
]
|
||||
self.hashes_byte = [
|
||||
i[0] for i in conn.execute("SELECT hash FROM hashes").fetchall()
|
||||
]
|
||||
self.hashes = [np.frombuffer(hb, bool) for hb in self.hashes_byte]
|
||||
self.hashes_slice_size = round(len(self.hashes_byte[0]) * 0.25)
|
||||
self.hashes_head = [h[: self.hashes_slice_size] for h in self.hashes]
|
||||
self.hashes_tail = [h[-self.hashes_slice_size :] for h in self.hashes]
|
||||
|
||||
self.jacket_ids: List[str] = []
|
||||
self.jacket_hashes = []
|
||||
self.partner_icon_ids: List[str] = []
|
||||
self.partner_icon_hashes = []
|
||||
|
||||
for id, hash in zip(self.ids, self.hashes):
|
||||
id_splitted = id.split("||")
|
||||
if len(id_splitted) > 1 and id_splitted[0] == "partner_icon":
|
||||
self.partner_icon_ids.append(id_splitted[1])
|
||||
self.partner_icon_hashes.append(hash)
|
||||
else:
|
||||
self.jacket_ids.append(id)
|
||||
self.jacket_hashes.append(hash)
|
||||
|
||||
def calculate_phash(self, img_gray: cv2.Mat):
|
||||
return phash_opencv(
|
||||
img_gray, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor
|
||||
)
|
||||
|
||||
def lookup_hash(self, image_hash: np.ndarray, *, limit: int = 5):
|
||||
image_hash = image_hash.flatten()
|
||||
# image_hash_head = image_hash[: self.hashes_slice_size]
|
||||
# image_hash_tail = image_hash[-self.hashes_slice_size :]
|
||||
# head_xor_results = [image_hash_head ^ h for h in self.hashes]
|
||||
# tail_xor_results = [image_hash_head ^ h for h in self.hashes]
|
||||
xor_results = [
|
||||
(id, np.count_nonzero(image_hash ^ h))
|
||||
for id, h in zip(self.ids, self.hashes)
|
||||
@ -82,7 +90,27 @@ class ImagePHashDatabase:
|
||||
return sorted(xor_results, key=lambda r: r[1])[:limit]
|
||||
|
||||
def lookup_image(self, img_gray: cv2.Mat):
|
||||
image_hash = phash_opencv(
|
||||
img_gray, hash_size=self.hash_size, highfreq_factor=self.highfreq_factor
|
||||
)
|
||||
image_hash = self.calculate_phash(img_gray)
|
||||
return self.lookup_hash(image_hash)[0]
|
||||
|
||||
def lookup_jackets(self, img_gray: cv2.Mat, *, limit: int = 5):
|
||||
image_hash = self.calculate_phash(img_gray).flatten()
|
||||
xor_results = [
|
||||
(id, np.count_nonzero(image_hash ^ h))
|
||||
for id, h in zip(self.jacket_ids, self.jacket_hashes)
|
||||
]
|
||||
return sorted(xor_results, key=lambda r: r[1])[:limit]
|
||||
|
||||
def lookup_jacket(self, img_gray: cv2.Mat):
|
||||
return self.lookup_jackets(img_gray)[0]
|
||||
|
||||
def lookup_partner_icons(self, img_gray: cv2.Mat, *, limit: int = 5):
|
||||
image_hash = self.calculate_phash(img_gray).flatten()
|
||||
xor_results = [
|
||||
(id, np.count_nonzero(image_hash ^ h))
|
||||
for id, h in zip(self.partner_icon_ids, self.partner_icon_hashes)
|
||||
]
|
||||
return sorted(xor_results, key=lambda r: r[1])[:limit]
|
||||
|
||||
def lookup_partner_icon(self, img_gray: cv2.Mat):
|
||||
return self.lookup_partner_icons(img_gray)[0]
|
||||
|
Reference in New Issue
Block a user