mirror of
https://github.com/283375/arcaea-offline-pyside-ui.git
synced 2025-04-04 10:20:18 +00:00
wip: arcaea-offline-ocr==0.1.0
API changes, modifier & clear_type support
This commit is contained in:
parent
cde8a047a7
commit
5c5c1a227d
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,6 +3,7 @@ __debug*
|
||||
|
||||
arcaea_offline.db
|
||||
arcaea_offline.ini
|
||||
/data
|
||||
|
||||
ui/resources/VERSION
|
||||
|
||||
|
@ -6,7 +6,7 @@ from arcaea_offline.calculate import calculate_score_range
|
||||
from arcaea_offline.database import Database
|
||||
from arcaea_offline.models import Chart, Score
|
||||
from arcaea_offline_ocr.b30.shared import B30OcrResultItem
|
||||
from arcaea_offline_ocr.device.shared import DeviceOcrResult
|
||||
from arcaea_offline_ocr.device.common import DeviceOcrResult
|
||||
from arcaea_offline_ocr.utils import convert_to_srgb
|
||||
from PIL import Image
|
||||
from PIL.ImageQt import ImageQt
|
||||
|
47
ui/extends/shared/data.py
Normal file
47
ui/extends/shared/data.py
Normal file
@ -0,0 +1,47 @@
|
||||
import json
|
||||
import sys
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from PySide6.QtCore import QFile
|
||||
|
||||
from .singleton import Singleton
|
||||
|
||||
TPartnerModifier = dict[str, Literal[0, 1, 2]]
|
||||
|
||||
|
||||
class Data(metaclass=Singleton):
|
||||
def __init__(self):
|
||||
root = Path(sys.argv[0]).parent
|
||||
self.__dataPath = (root / "data").resolve()
|
||||
|
||||
@property
|
||||
def dataPath(self):
|
||||
return self.__dataPath
|
||||
|
||||
@cached_property
|
||||
def partnerModifiers(self) -> TPartnerModifier:
|
||||
data = {}
|
||||
builtinFile = QFile(":/partnerModifiers.json")
|
||||
builtinFile.open(QFile.OpenModeFlag.ReadOnly)
|
||||
builtinData = json.loads(str(builtinFile.readAll(), encoding="utf-8"))
|
||||
builtinFile.close()
|
||||
data |= builtinData
|
||||
|
||||
customFile = self.dataPath / "partnerModifiers.json"
|
||||
if customFile.exists():
|
||||
with open(customFile, "r", encoding="utf-8") as f:
|
||||
customData = json.loads(f.read())
|
||||
data |= customData
|
||||
|
||||
return data
|
||||
|
||||
def expirePartnerModifiersCache(self):
|
||||
# expire property caches
|
||||
# https://stackoverflow.com/a/69367025/16484891, CC BY-SA 4.0
|
||||
self.__dict__.pop("partnerModifiers", None)
|
||||
|
||||
@property
|
||||
def arcaeaPath(self):
|
||||
return self.dataPath / "Arcaea"
|
@ -1,69 +1,58 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Type
|
||||
|
||||
import cv2
|
||||
import exif
|
||||
from arcaea_offline.database import Database
|
||||
from arcaea_offline.models import Chart, Score
|
||||
from arcaea_offline_ocr.device.shared import DeviceOcrResult
|
||||
from arcaea_offline_ocr.device.v2 import DeviceV2AutoRois, DeviceV2Ocr, DeviceV2Rois
|
||||
from arcaea_offline_ocr.device.v2.sizes import SizesV1, SizesV2
|
||||
from arcaea_offline.utils.partner import KanaeDayNight, kanae_day_night
|
||||
from arcaea_offline_ocr.device import DeviceOcr, DeviceOcrResult
|
||||
from arcaea_offline_ocr.device.rois import (
|
||||
DeviceRois,
|
||||
DeviceRoisAuto,
|
||||
DeviceRoisExtractor,
|
||||
DeviceRoisMasker,
|
||||
)
|
||||
from arcaea_offline_ocr.phash_db import ImagePhashDatabase
|
||||
from arcaea_offline_ocr.utils import imread_unicode
|
||||
from PySide6.QtCore import QDateTime, QFileInfo
|
||||
|
||||
from ui.extends.components.ocrQueue import OcrRunnable
|
||||
from ui.extends.shared.data import Data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import exif
|
||||
|
||||
|
||||
class TabDeviceV2OcrRunnable(OcrRunnable):
|
||||
def __init__(self, imagePath, device, knnModel, phashDb, *, sizesV2: bool):
|
||||
class TabDeviceOcrRunnable(OcrRunnable):
|
||||
def __init__(
|
||||
self,
|
||||
imagePath: str,
|
||||
rois: DeviceRois | Type[DeviceRoisAuto],
|
||||
masker: DeviceRoisMasker,
|
||||
knnModel: cv2.ml.KNearest,
|
||||
phashDb: ImagePhashDatabase,
|
||||
):
|
||||
super().__init__()
|
||||
self.imagePath = imagePath
|
||||
self.device = device
|
||||
self.rois = rois
|
||||
self.masker = masker
|
||||
self.knnModel = knnModel
|
||||
self.phashDb = phashDb
|
||||
self.sizesV2 = sizesV2
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
rois = DeviceV2Rois(
|
||||
self.device, imread_unicode(self.imagePath, cv2.IMREAD_COLOR)
|
||||
)
|
||||
rois.sizes = (
|
||||
SizesV2(self.device.factor)
|
||||
if self.sizesV2
|
||||
else SizesV1(self.device.factor)
|
||||
)
|
||||
ocr = DeviceV2Ocr(self.knnModel, self.phashDb)
|
||||
result = ocr.ocr(rois)
|
||||
img = imread_unicode(self.imagePath, cv2.IMREAD_COLOR)
|
||||
if isinstance(self.rois, type) and issubclass(self.rois, DeviceRoisAuto):
|
||||
rois = self.rois(img.shape[1], img.shape[0])
|
||||
else:
|
||||
rois = self.rois
|
||||
extractor = DeviceRoisExtractor(img, rois)
|
||||
ocr = DeviceOcr(extractor, self.masker, self.knnModel, self.phashDb)
|
||||
result = ocr.ocr()
|
||||
self.signals.resultReady.emit(result)
|
||||
except Exception:
|
||||
logger.exception(f"DeviceV2 ocr {self.imagePath} error")
|
||||
finally:
|
||||
self.signals.finished.emit()
|
||||
|
||||
|
||||
class TabDeviceV2AutoRoisOcrRunnable(OcrRunnable):
|
||||
def __init__(self, imagePath, knnModel, phashDb, *, sizesV2: bool):
|
||||
super().__init__()
|
||||
self.imagePath = imagePath
|
||||
self.knnModel = knnModel
|
||||
self.phashDb = phashDb
|
||||
self.sizesV2 = sizesV2
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
rois = DeviceV2AutoRois(imread_unicode(self.imagePath, cv2.IMREAD_COLOR))
|
||||
factor = rois.sizes.factor
|
||||
rois.sizes = SizesV2(factor) if self.sizesV2 else SizesV1(factor)
|
||||
ocr = DeviceV2Ocr(self.knnModel, self.phashDb)
|
||||
result = ocr.ocr(rois)
|
||||
self.signals.resultReady.emit(result)
|
||||
except Exception:
|
||||
logger.exception(f"DeviceV2AutoRois ocr {self.imagePath} error")
|
||||
logger.exception("DeviceOcr error:")
|
||||
finally:
|
||||
self.signals.finished.emit()
|
||||
|
||||
@ -83,7 +72,24 @@ def getImageDate(imagePath: str) -> QDateTime:
|
||||
|
||||
class ScoreConverter:
|
||||
@staticmethod
|
||||
def deviceV2(imagePath: str, _, result: DeviceOcrResult) -> Tuple[Chart, Score]:
|
||||
def device(imagePath: str, _, result: DeviceOcrResult) -> Tuple[Chart, Score]:
|
||||
partnerModifiers = Data().partnerModifiers
|
||||
imageDate = getImageDate(imagePath)
|
||||
|
||||
# calculate clear type
|
||||
if result.partner_id == "50":
|
||||
dayNight = kanae_day_night(imageDate)
|
||||
modifier = 1 if dayNight == KanaeDayNight.Day else 2
|
||||
else:
|
||||
modifier = partnerModifiers.get(result.partner_id, 0)
|
||||
|
||||
if result.clear_status == 1 and modifier == 1:
|
||||
clearType = 4
|
||||
elif result.clear_status == 1 and modifier == 2:
|
||||
clearType = 5
|
||||
else:
|
||||
clearType = result.clear_status
|
||||
|
||||
db = Database()
|
||||
score = Score(
|
||||
song_id=result.song_id,
|
||||
@ -92,16 +98,16 @@ class ScoreConverter:
|
||||
pure=result.pure,
|
||||
far=result.far,
|
||||
lost=result.lost,
|
||||
date=getImageDate(imagePath).toSecsSinceEpoch(),
|
||||
date=imageDate.toSecsSinceEpoch(),
|
||||
max_recall=result.max_recall,
|
||||
modifier=modifier,
|
||||
clear_type=clearType,
|
||||
comment=f"OCR {QFileInfo(imagePath).fileName()}",
|
||||
)
|
||||
chart = db.get_chart(score.song_id, score.rating_class)
|
||||
if not chart:
|
||||
chart = Chart(
|
||||
song_id=result.song_id,
|
||||
rating_class=result.rating_class,
|
||||
title=result.song_id,
|
||||
constant=0.0,
|
||||
)
|
||||
chart = db.get_chart(score.song_id, score.rating_class) or Chart(
|
||||
song_id=result.song_id,
|
||||
rating_class=result.rating_class,
|
||||
title=result.song_id,
|
||||
constant=0.0,
|
||||
)
|
||||
return (chart, score)
|
||||
|
@ -1,20 +1,21 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
from arcaea_offline_ocr.device.rois import (
|
||||
DeviceRoisAutoT1,
|
||||
DeviceRoisAutoT2,
|
||||
DeviceRoisMaskerAutoT1,
|
||||
DeviceRoisMaskerAutoT2,
|
||||
)
|
||||
from arcaea_offline_ocr.phash_db import ImagePhashDatabase
|
||||
from PySide6.QtCore import Qt, Slot
|
||||
from PySide6.QtWidgets import QApplication, QFileDialog, QWidget
|
||||
from PySide6.QtCore import Slot
|
||||
from PySide6.QtWidgets import QApplication, QFileDialog, QMessageBox, QWidget
|
||||
|
||||
from ui.designer.tabs.tabOcr.tabOcr_Device_ui import Ui_TabOcr_Device
|
||||
from ui.extends.components.ocrQueue import OcrQueueModel
|
||||
from ui.extends.shared.language import LanguageChangeEventFilter
|
||||
from ui.extends.shared.settings import KNN_MODEL_FILE, PHASH_DATABASE_FILE
|
||||
from ui.extends.tabs.tabOcr.tabOcr_Device import (
|
||||
ScoreConverter,
|
||||
TabDeviceV2AutoRoisOcrRunnable,
|
||||
TabDeviceV2OcrRunnable,
|
||||
)
|
||||
|
||||
from ui.extends.tabs.tabOcr.tabOcr_Device import ScoreConverter, TabDeviceOcrRunnable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -106,9 +107,15 @@ class TabOcr_Device(Ui_TabOcr_Device, QWidget):
|
||||
try:
|
||||
knnModelFile = self.dependencies_knnModelSelector.selectedFiles()[0]
|
||||
self.knnModel = cv2.ml.KNearest.load(knnModelFile)
|
||||
self.dependencies_knnModelStatusLabel.setText(
|
||||
f'<font color="green">OK</font>, varCount {self.knnModel.getVarCount()}'
|
||||
)
|
||||
varCount = self.knnModel.getVarCount()
|
||||
if varCount != 81:
|
||||
self.dependencies_knnModelStatusLabel.setText(
|
||||
f'<font color="darkorange">WARN</font>, varCount {varCount}'
|
||||
)
|
||||
else:
|
||||
self.dependencies_knnModelStatusLabel.setText(
|
||||
f'<font color="green">OK</font>, varCount {varCount}'
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error loading knn model:")
|
||||
self.dependencies_knnModelStatusLabel.setText(
|
||||
@ -150,30 +157,51 @@ class TabOcr_Device(Ui_TabOcr_Device, QWidget):
|
||||
QApplication.processEvents()
|
||||
self.ocrQueue.resizeTableView()
|
||||
|
||||
def deviceRois(self):
|
||||
if self.options_roisUseCustomCheckBox.isChecked():
|
||||
...
|
||||
else:
|
||||
selectedPreset = self.options_roisComboBox.currentData()
|
||||
if selectedPreset == "AutoT1":
|
||||
return DeviceRoisAutoT1
|
||||
elif selectedPreset == "AutoT2":
|
||||
return DeviceRoisAutoT2
|
||||
else:
|
||||
QMessageBox.critical(self, None, "Select a Rois preset first.")
|
||||
return None
|
||||
|
||||
def deviceRoisMasker(self):
|
||||
if self.options_maskerUseCustomCheckBox.isChecked():
|
||||
...
|
||||
else:
|
||||
selectedPreset = self.options_maskerComboBox.currentData()
|
||||
if selectedPreset == "AutoT1":
|
||||
return DeviceRoisMaskerAutoT1()
|
||||
elif selectedPreset == "AutoT2":
|
||||
return DeviceRoisMaskerAutoT2()
|
||||
else:
|
||||
QMessageBox.critical(self, None, "Select a Masker preset first.")
|
||||
return None
|
||||
|
||||
@Slot()
|
||||
def on_ocr_startButton_clicked(self):
|
||||
for row in range(self.ocrQueueModel.rowCount()):
|
||||
index = self.ocrQueueModel.index(row, 0)
|
||||
imagePath = index.data(OcrQueueModel.ImagePathRole)
|
||||
if self.deviceUseAutoFactorCheckBox.checkState() == Qt.CheckState.Checked:
|
||||
runnable = TabDeviceV2AutoRoisOcrRunnable(
|
||||
imagePath,
|
||||
self.knnModel,
|
||||
self.phashDatabase,
|
||||
sizesV2=self.deviceSizesV2CheckBox.isChecked(),
|
||||
)
|
||||
else:
|
||||
runnable = TabDeviceV2OcrRunnable(
|
||||
imagePath,
|
||||
self.deviceComboBox.currentData(),
|
||||
self.knnModel,
|
||||
self.phashDatabase,
|
||||
sizesV2=self.deviceSizesV2CheckBox.isChecked(),
|
||||
)
|
||||
|
||||
rois = self.deviceRois()
|
||||
masker = self.deviceRoisMasker()
|
||||
|
||||
if rois is None or masker is None:
|
||||
return
|
||||
|
||||
runnable = TabDeviceOcrRunnable(
|
||||
imagePath, rois, masker, self.knnModel, self.phashDatabase
|
||||
)
|
||||
self.ocrQueueModel.setData(index, runnable, OcrQueueModel.OcrRunnableRole)
|
||||
self.ocrQueueModel.setData(
|
||||
index,
|
||||
ScoreConverter.deviceV2,
|
||||
ScoreConverter.device,
|
||||
OcrQueueModel.ProcessOcrResultFuncRole,
|
||||
)
|
||||
self.ocrQueueModel.startQueue()
|
||||
|
35
ui/resources/partnerModifiers.json
Normal file
35
ui/resources/partnerModifiers.json
Normal file
@ -0,0 +1,35 @@
|
||||
{
|
||||
"__COMMENT__": "1: EASY, 2: HARD",
|
||||
"0": 1,
|
||||
"0u": 1,
|
||||
"7": 2,
|
||||
"9": 1,
|
||||
"10": 2,
|
||||
"10u": 2,
|
||||
"15": 1,
|
||||
"16": 1,
|
||||
"20": 1,
|
||||
"28": 2,
|
||||
"28u": 2,
|
||||
"29": 2,
|
||||
"29u": 2,
|
||||
"35": 2,
|
||||
"36": 2,
|
||||
"36u": 2,
|
||||
"37": 2,
|
||||
"41": 2,
|
||||
"42": 2,
|
||||
"42u": 2,
|
||||
"43": 2,
|
||||
"43u": 2,
|
||||
"54": 2,
|
||||
"55": 2,
|
||||
"57": 2,
|
||||
"61": 2,
|
||||
"64": 2,
|
||||
"66": 2,
|
||||
"66u": 2,
|
||||
"67": 2,
|
||||
"68": 1,
|
||||
"70": 2
|
||||
}
|
@ -4,6 +4,8 @@
|
||||
<file>VERSION</file>
|
||||
<file>LICENSE</file>
|
||||
|
||||
<file>partnerModifiers.json</file>
|
||||
|
||||
<file>images/icon.png</file>
|
||||
<file>images/logo.png</file>
|
||||
<file>images/stepCalculator/stamina.png</file>
|
||||
|
Loading…
x
Reference in New Issue
Block a user