mirror of
https://github.com/283375/arcaea-offline-pyside-ui.git
synced 2025-07-01 12:26:26 +00:00
wip: arcaea-offline-ocr==0.1.0
API changes, modifier & clear_type support
This commit is contained in:
@ -1,69 +1,58 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Type
|
||||
|
||||
import cv2
|
||||
import exif
|
||||
from arcaea_offline.database import Database
|
||||
from arcaea_offline.models import Chart, Score
|
||||
from arcaea_offline_ocr.device.shared import DeviceOcrResult
|
||||
from arcaea_offline_ocr.device.v2 import DeviceV2AutoRois, DeviceV2Ocr, DeviceV2Rois
|
||||
from arcaea_offline_ocr.device.v2.sizes import SizesV1, SizesV2
|
||||
from arcaea_offline.utils.partner import KanaeDayNight, kanae_day_night
|
||||
from arcaea_offline_ocr.device import DeviceOcr, DeviceOcrResult
|
||||
from arcaea_offline_ocr.device.rois import (
|
||||
DeviceRois,
|
||||
DeviceRoisAuto,
|
||||
DeviceRoisExtractor,
|
||||
DeviceRoisMasker,
|
||||
)
|
||||
from arcaea_offline_ocr.phash_db import ImagePhashDatabase
|
||||
from arcaea_offline_ocr.utils import imread_unicode
|
||||
from PySide6.QtCore import QDateTime, QFileInfo
|
||||
|
||||
from ui.extends.components.ocrQueue import OcrRunnable
|
||||
from ui.extends.shared.data import Data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import exif
|
||||
|
||||
|
||||
class TabDeviceV2OcrRunnable(OcrRunnable):
|
||||
def __init__(self, imagePath, device, knnModel, phashDb, *, sizesV2: bool):
|
||||
class TabDeviceOcrRunnable(OcrRunnable):
|
||||
def __init__(
|
||||
self,
|
||||
imagePath: str,
|
||||
rois: DeviceRois | Type[DeviceRoisAuto],
|
||||
masker: DeviceRoisMasker,
|
||||
knnModel: cv2.ml.KNearest,
|
||||
phashDb: ImagePhashDatabase,
|
||||
):
|
||||
super().__init__()
|
||||
self.imagePath = imagePath
|
||||
self.device = device
|
||||
self.rois = rois
|
||||
self.masker = masker
|
||||
self.knnModel = knnModel
|
||||
self.phashDb = phashDb
|
||||
self.sizesV2 = sizesV2
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
rois = DeviceV2Rois(
|
||||
self.device, imread_unicode(self.imagePath, cv2.IMREAD_COLOR)
|
||||
)
|
||||
rois.sizes = (
|
||||
SizesV2(self.device.factor)
|
||||
if self.sizesV2
|
||||
else SizesV1(self.device.factor)
|
||||
)
|
||||
ocr = DeviceV2Ocr(self.knnModel, self.phashDb)
|
||||
result = ocr.ocr(rois)
|
||||
img = imread_unicode(self.imagePath, cv2.IMREAD_COLOR)
|
||||
if isinstance(self.rois, type) and issubclass(self.rois, DeviceRoisAuto):
|
||||
rois = self.rois(img.shape[1], img.shape[0])
|
||||
else:
|
||||
rois = self.rois
|
||||
extractor = DeviceRoisExtractor(img, rois)
|
||||
ocr = DeviceOcr(extractor, self.masker, self.knnModel, self.phashDb)
|
||||
result = ocr.ocr()
|
||||
self.signals.resultReady.emit(result)
|
||||
except Exception:
|
||||
logger.exception(f"DeviceV2 ocr {self.imagePath} error")
|
||||
finally:
|
||||
self.signals.finished.emit()
|
||||
|
||||
|
||||
class TabDeviceV2AutoRoisOcrRunnable(OcrRunnable):
|
||||
def __init__(self, imagePath, knnModel, phashDb, *, sizesV2: bool):
|
||||
super().__init__()
|
||||
self.imagePath = imagePath
|
||||
self.knnModel = knnModel
|
||||
self.phashDb = phashDb
|
||||
self.sizesV2 = sizesV2
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
rois = DeviceV2AutoRois(imread_unicode(self.imagePath, cv2.IMREAD_COLOR))
|
||||
factor = rois.sizes.factor
|
||||
rois.sizes = SizesV2(factor) if self.sizesV2 else SizesV1(factor)
|
||||
ocr = DeviceV2Ocr(self.knnModel, self.phashDb)
|
||||
result = ocr.ocr(rois)
|
||||
self.signals.resultReady.emit(result)
|
||||
except Exception:
|
||||
logger.exception(f"DeviceV2AutoRois ocr {self.imagePath} error")
|
||||
logger.exception("DeviceOcr error:")
|
||||
finally:
|
||||
self.signals.finished.emit()
|
||||
|
||||
@ -83,7 +72,24 @@ def getImageDate(imagePath: str) -> QDateTime:
|
||||
|
||||
class ScoreConverter:
|
||||
@staticmethod
|
||||
def deviceV2(imagePath: str, _, result: DeviceOcrResult) -> Tuple[Chart, Score]:
|
||||
def device(imagePath: str, _, result: DeviceOcrResult) -> Tuple[Chart, Score]:
|
||||
partnerModifiers = Data().partnerModifiers
|
||||
imageDate = getImageDate(imagePath)
|
||||
|
||||
# calculate clear type
|
||||
if result.partner_id == "50":
|
||||
dayNight = kanae_day_night(imageDate)
|
||||
modifier = 1 if dayNight == KanaeDayNight.Day else 2
|
||||
else:
|
||||
modifier = partnerModifiers.get(result.partner_id, 0)
|
||||
|
||||
if result.clear_status == 1 and modifier == 1:
|
||||
clearType = 4
|
||||
elif result.clear_status == 1 and modifier == 2:
|
||||
clearType = 5
|
||||
else:
|
||||
clearType = result.clear_status
|
||||
|
||||
db = Database()
|
||||
score = Score(
|
||||
song_id=result.song_id,
|
||||
@ -92,16 +98,16 @@ class ScoreConverter:
|
||||
pure=result.pure,
|
||||
far=result.far,
|
||||
lost=result.lost,
|
||||
date=getImageDate(imagePath).toSecsSinceEpoch(),
|
||||
date=imageDate.toSecsSinceEpoch(),
|
||||
max_recall=result.max_recall,
|
||||
modifier=modifier,
|
||||
clear_type=clearType,
|
||||
comment=f"OCR {QFileInfo(imagePath).fileName()}",
|
||||
)
|
||||
chart = db.get_chart(score.song_id, score.rating_class)
|
||||
if not chart:
|
||||
chart = Chart(
|
||||
song_id=result.song_id,
|
||||
rating_class=result.rating_class,
|
||||
title=result.song_id,
|
||||
constant=0.0,
|
||||
)
|
||||
chart = db.get_chart(score.song_id, score.rating_class) or Chart(
|
||||
song_id=result.song_id,
|
||||
rating_class=result.rating_class,
|
||||
title=result.song_id,
|
||||
constant=0.0,
|
||||
)
|
||||
return (chart, score)
|
||||
|
Reference in New Issue
Block a user