From 5c5c1a227d682fca6031892d4b692baa3c2f1214 Mon Sep 17 00:00:00 2001 From: 283375 Date: Thu, 12 Oct 2023 17:05:04 +0800 Subject: [PATCH] wip: arcaea-offline-ocr==0.1.0 API changes, modifier & clear_type support --- .gitignore | 1 + ui/extends/components/ocrQueue.py | 2 +- ui/extends/shared/data.py | 47 +++++++++ ui/extends/tabs/tabOcr/tabOcr_Device.py | 112 +++++++++++---------- ui/implements/tabs/tabOcr/tabOcr_Device.py | 82 ++++++++++----- ui/resources/partnerModifiers.json | 35 +++++++ ui/resources/resources.qrc | 2 + 7 files changed, 200 insertions(+), 81 deletions(-) create mode 100644 ui/extends/shared/data.py create mode 100644 ui/resources/partnerModifiers.json diff --git a/.gitignore b/.gitignore index 2195cba..0cbc18a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __debug* arcaea_offline.db arcaea_offline.ini +/data ui/resources/VERSION diff --git a/ui/extends/components/ocrQueue.py b/ui/extends/components/ocrQueue.py index 2dfbd3f..4fb89ed 100644 --- a/ui/extends/components/ocrQueue.py +++ b/ui/extends/components/ocrQueue.py @@ -6,7 +6,7 @@ from arcaea_offline.calculate import calculate_score_range from arcaea_offline.database import Database from arcaea_offline.models import Chart, Score from arcaea_offline_ocr.b30.shared import B30OcrResultItem -from arcaea_offline_ocr.device.shared import DeviceOcrResult +from arcaea_offline_ocr.device.common import DeviceOcrResult from arcaea_offline_ocr.utils import convert_to_srgb from PIL import Image from PIL.ImageQt import ImageQt diff --git a/ui/extends/shared/data.py b/ui/extends/shared/data.py new file mode 100644 index 0000000..683b076 --- /dev/null +++ b/ui/extends/shared/data.py @@ -0,0 +1,47 @@ +import json +import sys +from functools import cached_property +from pathlib import Path +from typing import Literal + +from PySide6.QtCore import QFile + +from .singleton import Singleton + +TPartnerModifier = dict[str, Literal[0, 1, 2]] + + +class Data(metaclass=Singleton): + def __init__(self): + root = Path(sys.argv[0]).parent + self.__dataPath = (root / "data").resolve() + + @property + def dataPath(self): + return self.__dataPath + + @cached_property + def partnerModifiers(self) -> TPartnerModifier: + data = {} + builtinFile = QFile(":/partnerModifiers.json") + builtinFile.open(QFile.OpenModeFlag.ReadOnly) + builtinData = json.loads(str(builtinFile.readAll(), encoding="utf-8")) + builtinFile.close() + data |= builtinData + + customFile = self.dataPath / "partnerModifiers.json" + if customFile.exists(): + with open(customFile, "r", encoding="utf-8") as f: + customData = json.loads(f.read()) + data |= customData + + return data + + def expirePartnerModifiersCache(self): + # expire property caches + # https://stackoverflow.com/a/69367025/16484891, CC BY-SA 4.0 + self.__dict__.pop("partnerModifiers", None) + + @property + def arcaeaPath(self): + return self.dataPath / "Arcaea" diff --git a/ui/extends/tabs/tabOcr/tabOcr_Device.py b/ui/extends/tabs/tabOcr/tabOcr_Device.py index 795ee10..69b031a 100644 --- a/ui/extends/tabs/tabOcr/tabOcr_Device.py +++ b/ui/extends/tabs/tabOcr/tabOcr_Device.py @@ -1,69 +1,58 @@ import contextlib import logging -from typing import Tuple +from typing import Tuple, Type import cv2 +import exif from arcaea_offline.database import Database from arcaea_offline.models import Chart, Score -from arcaea_offline_ocr.device.shared import DeviceOcrResult -from arcaea_offline_ocr.device.v2 import DeviceV2AutoRois, DeviceV2Ocr, DeviceV2Rois -from arcaea_offline_ocr.device.v2.sizes import SizesV1, SizesV2 +from arcaea_offline.utils.partner import KanaeDayNight, kanae_day_night +from arcaea_offline_ocr.device import DeviceOcr, DeviceOcrResult +from arcaea_offline_ocr.device.rois import ( + DeviceRois, + DeviceRoisAuto, + DeviceRoisExtractor, + DeviceRoisMasker, +) +from arcaea_offline_ocr.phash_db import ImagePhashDatabase from arcaea_offline_ocr.utils import imread_unicode from PySide6.QtCore import QDateTime, QFileInfo from ui.extends.components.ocrQueue import OcrRunnable +from ui.extends.shared.data import Data logger = logging.getLogger(__name__) -import exif - -class TabDeviceV2OcrRunnable(OcrRunnable): - def __init__(self, imagePath, device, knnModel, phashDb, *, sizesV2: bool): +class TabDeviceOcrRunnable(OcrRunnable): + def __init__( + self, + imagePath: str, + rois: DeviceRois | Type[DeviceRoisAuto], + masker: DeviceRoisMasker, + knnModel: cv2.ml.KNearest, + phashDb: ImagePhashDatabase, + ): super().__init__() self.imagePath = imagePath - self.device = device + self.rois = rois + self.masker = masker self.knnModel = knnModel self.phashDb = phashDb - self.sizesV2 = sizesV2 def run(self): try: - rois = DeviceV2Rois( - self.device, imread_unicode(self.imagePath, cv2.IMREAD_COLOR) - ) - rois.sizes = ( - SizesV2(self.device.factor) - if self.sizesV2 - else SizesV1(self.device.factor) - ) - ocr = DeviceV2Ocr(self.knnModel, self.phashDb) - result = ocr.ocr(rois) + img = imread_unicode(self.imagePath, cv2.IMREAD_COLOR) + if isinstance(self.rois, type) and issubclass(self.rois, DeviceRoisAuto): + rois = self.rois(img.shape[1], img.shape[0]) + else: + rois = self.rois + extractor = DeviceRoisExtractor(img, rois) + ocr = DeviceOcr(extractor, self.masker, self.knnModel, self.phashDb) + result = ocr.ocr() self.signals.resultReady.emit(result) except Exception: - logger.exception(f"DeviceV2 ocr {self.imagePath} error") - finally: - self.signals.finished.emit() - - -class TabDeviceV2AutoRoisOcrRunnable(OcrRunnable): - def __init__(self, imagePath, knnModel, phashDb, *, sizesV2: bool): - super().__init__() - self.imagePath = imagePath - self.knnModel = knnModel - self.phashDb = phashDb - self.sizesV2 = sizesV2 - - def run(self): - try: - rois = DeviceV2AutoRois(imread_unicode(self.imagePath, cv2.IMREAD_COLOR)) - factor = rois.sizes.factor - rois.sizes = SizesV2(factor) if self.sizesV2 else SizesV1(factor) - ocr = DeviceV2Ocr(self.knnModel, self.phashDb) - result = ocr.ocr(rois) - self.signals.resultReady.emit(result) - except Exception: - logger.exception(f"DeviceV2AutoRois ocr {self.imagePath} error") + logger.exception("DeviceOcr error:") finally: self.signals.finished.emit() @@ -83,7 +72,24 @@ def getImageDate(imagePath: str) -> QDateTime: class ScoreConverter: @staticmethod - def deviceV2(imagePath: str, _, result: DeviceOcrResult) -> Tuple[Chart, Score]: + def device(imagePath: str, _, result: DeviceOcrResult) -> Tuple[Chart, Score]: + partnerModifiers = Data().partnerModifiers + imageDate = getImageDate(imagePath) + + # calculate clear type + if result.partner_id == "50": + dayNight = kanae_day_night(imageDate) + modifier = 1 if dayNight == KanaeDayNight.Day else 2 + else: + modifier = partnerModifiers.get(result.partner_id, 0) + + if result.clear_status == 1 and modifier == 1: + clearType = 4 + elif result.clear_status == 1 and modifier == 2: + clearType = 5 + else: + clearType = result.clear_status + db = Database() score = Score( song_id=result.song_id, @@ -92,16 +98,16 @@ class ScoreConverter: pure=result.pure, far=result.far, lost=result.lost, - date=getImageDate(imagePath).toSecsSinceEpoch(), + date=imageDate.toSecsSinceEpoch(), max_recall=result.max_recall, + modifier=modifier, + clear_type=clearType, comment=f"OCR {QFileInfo(imagePath).fileName()}", ) - chart = db.get_chart(score.song_id, score.rating_class) - if not chart: - chart = Chart( - song_id=result.song_id, - rating_class=result.rating_class, - title=result.song_id, - constant=0.0, - ) + chart = db.get_chart(score.song_id, score.rating_class) or Chart( + song_id=result.song_id, + rating_class=result.rating_class, + title=result.song_id, + constant=0.0, + ) return (chart, score) diff --git a/ui/implements/tabs/tabOcr/tabOcr_Device.py b/ui/implements/tabs/tabOcr/tabOcr_Device.py index 57b1bf9..6c5635e 100644 --- a/ui/implements/tabs/tabOcr/tabOcr_Device.py +++ b/ui/implements/tabs/tabOcr/tabOcr_Device.py @@ -1,20 +1,21 @@ import logging import cv2 +from arcaea_offline_ocr.device.rois import ( + DeviceRoisAutoT1, + DeviceRoisAutoT2, + DeviceRoisMaskerAutoT1, + DeviceRoisMaskerAutoT2, +) from arcaea_offline_ocr.phash_db import ImagePhashDatabase -from PySide6.QtCore import Qt, Slot -from PySide6.QtWidgets import QApplication, QFileDialog, QWidget +from PySide6.QtCore import Slot +from PySide6.QtWidgets import QApplication, QFileDialog, QMessageBox, QWidget from ui.designer.tabs.tabOcr.tabOcr_Device_ui import Ui_TabOcr_Device from ui.extends.components.ocrQueue import OcrQueueModel from ui.extends.shared.language import LanguageChangeEventFilter from ui.extends.shared.settings import KNN_MODEL_FILE, PHASH_DATABASE_FILE -from ui.extends.tabs.tabOcr.tabOcr_Device import ( - ScoreConverter, - TabDeviceV2AutoRoisOcrRunnable, - TabDeviceV2OcrRunnable, -) - +from ui.extends.tabs.tabOcr.tabOcr_Device import ScoreConverter, TabDeviceOcrRunnable logger = logging.getLogger(__name__) @@ -106,9 +107,15 @@ class TabOcr_Device(Ui_TabOcr_Device, QWidget): try: knnModelFile = self.dependencies_knnModelSelector.selectedFiles()[0] self.knnModel = cv2.ml.KNearest.load(knnModelFile) - self.dependencies_knnModelStatusLabel.setText( - f'OK, varCount {self.knnModel.getVarCount()}' - ) + varCount = self.knnModel.getVarCount() + if varCount != 81: + self.dependencies_knnModelStatusLabel.setText( + f'WARN, varCount {varCount}' + ) + else: + self.dependencies_knnModelStatusLabel.setText( + f'OK, varCount {varCount}' + ) except Exception: logger.exception("Error loading knn model:") self.dependencies_knnModelStatusLabel.setText( @@ -150,30 +157,51 @@ class TabOcr_Device(Ui_TabOcr_Device, QWidget): QApplication.processEvents() self.ocrQueue.resizeTableView() + def deviceRois(self): + if self.options_roisUseCustomCheckBox.isChecked(): + ... + else: + selectedPreset = self.options_roisComboBox.currentData() + if selectedPreset == "AutoT1": + return DeviceRoisAutoT1 + elif selectedPreset == "AutoT2": + return DeviceRoisAutoT2 + else: + QMessageBox.critical(self, None, "Select a Rois preset first.") + return None + + def deviceRoisMasker(self): + if self.options_maskerUseCustomCheckBox.isChecked(): + ... + else: + selectedPreset = self.options_maskerComboBox.currentData() + if selectedPreset == "AutoT1": + return DeviceRoisMaskerAutoT1() + elif selectedPreset == "AutoT2": + return DeviceRoisMaskerAutoT2() + else: + QMessageBox.critical(self, None, "Select a Masker preset first.") + return None + @Slot() def on_ocr_startButton_clicked(self): for row in range(self.ocrQueueModel.rowCount()): index = self.ocrQueueModel.index(row, 0) imagePath = index.data(OcrQueueModel.ImagePathRole) - if self.deviceUseAutoFactorCheckBox.checkState() == Qt.CheckState.Checked: - runnable = TabDeviceV2AutoRoisOcrRunnable( - imagePath, - self.knnModel, - self.phashDatabase, - sizesV2=self.deviceSizesV2CheckBox.isChecked(), - ) - else: - runnable = TabDeviceV2OcrRunnable( - imagePath, - self.deviceComboBox.currentData(), - self.knnModel, - self.phashDatabase, - sizesV2=self.deviceSizesV2CheckBox.isChecked(), - ) + + rois = self.deviceRois() + masker = self.deviceRoisMasker() + + if rois is None or masker is None: + return + + runnable = TabDeviceOcrRunnable( + imagePath, rois, masker, self.knnModel, self.phashDatabase + ) self.ocrQueueModel.setData(index, runnable, OcrQueueModel.OcrRunnableRole) self.ocrQueueModel.setData( index, - ScoreConverter.deviceV2, + ScoreConverter.device, OcrQueueModel.ProcessOcrResultFuncRole, ) self.ocrQueueModel.startQueue() diff --git a/ui/resources/partnerModifiers.json b/ui/resources/partnerModifiers.json new file mode 100644 index 0000000..48e8548 --- /dev/null +++ b/ui/resources/partnerModifiers.json @@ -0,0 +1,35 @@ +{ + "__COMMENT__": "1: EASY, 2: HARD", + "0": 1, + "0u": 1, + "7": 2, + "9": 1, + "10": 2, + "10u": 2, + "15": 1, + "16": 1, + "20": 1, + "28": 2, + "28u": 2, + "29": 2, + "29u": 2, + "35": 2, + "36": 2, + "36u": 2, + "37": 2, + "41": 2, + "42": 2, + "42u": 2, + "43": 2, + "43u": 2, + "54": 2, + "55": 2, + "57": 2, + "61": 2, + "64": 2, + "66": 2, + "66u": 2, + "67": 2, + "68": 1, + "70": 2 +} diff --git a/ui/resources/resources.qrc b/ui/resources/resources.qrc index d09e49f..dd4b4b4 100644 --- a/ui/resources/resources.qrc +++ b/ui/resources/resources.qrc @@ -4,6 +4,8 @@ VERSION LICENSE + partnerModifiers.json + images/icon.png images/logo.png images/stepCalculator/stamina.png