wip: arcaea-offline-ocr==0.1.0

API changes, modifier & clear_type support
This commit is contained in:
283375 2023-10-12 17:05:04 +08:00
parent cde8a047a7
commit 5c5c1a227d
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
7 changed files with 200 additions and 81 deletions

1
.gitignore vendored
View File

@ -3,6 +3,7 @@ __debug*
arcaea_offline.db
arcaea_offline.ini
/data
ui/resources/VERSION

View File

@ -6,7 +6,7 @@ from arcaea_offline.calculate import calculate_score_range
from arcaea_offline.database import Database
from arcaea_offline.models import Chart, Score
from arcaea_offline_ocr.b30.shared import B30OcrResultItem
from arcaea_offline_ocr.device.shared import DeviceOcrResult
from arcaea_offline_ocr.device.common import DeviceOcrResult
from arcaea_offline_ocr.utils import convert_to_srgb
from PIL import Image
from PIL.ImageQt import ImageQt

47
ui/extends/shared/data.py Normal file
View File

@ -0,0 +1,47 @@
import json
import sys
from functools import cached_property
from pathlib import Path
from typing import Literal
from PySide6.QtCore import QFile
from .singleton import Singleton
TPartnerModifier = dict[str, Literal[0, 1, 2]]
class Data(metaclass=Singleton):
def __init__(self):
root = Path(sys.argv[0]).parent
self.__dataPath = (root / "data").resolve()
@property
def dataPath(self):
return self.__dataPath
@cached_property
def partnerModifiers(self) -> TPartnerModifier:
data = {}
builtinFile = QFile(":/partnerModifiers.json")
builtinFile.open(QFile.OpenModeFlag.ReadOnly)
builtinData = json.loads(str(builtinFile.readAll(), encoding="utf-8"))
builtinFile.close()
data |= builtinData
customFile = self.dataPath / "partnerModifiers.json"
if customFile.exists():
with open(customFile, "r", encoding="utf-8") as f:
customData = json.loads(f.read())
data |= customData
return data
def expirePartnerModifiersCache(self):
# expire property caches
# https://stackoverflow.com/a/69367025/16484891, CC BY-SA 4.0
self.__dict__.pop("partnerModifiers", None)
@property
def arcaeaPath(self):
return self.dataPath / "Arcaea"

View File

@ -1,69 +1,58 @@
import contextlib
import logging
from typing import Tuple
from typing import Tuple, Type
import cv2
import exif
from arcaea_offline.database import Database
from arcaea_offline.models import Chart, Score
from arcaea_offline_ocr.device.shared import DeviceOcrResult
from arcaea_offline_ocr.device.v2 import DeviceV2AutoRois, DeviceV2Ocr, DeviceV2Rois
from arcaea_offline_ocr.device.v2.sizes import SizesV1, SizesV2
from arcaea_offline.utils.partner import KanaeDayNight, kanae_day_night
from arcaea_offline_ocr.device import DeviceOcr, DeviceOcrResult
from arcaea_offline_ocr.device.rois import (
DeviceRois,
DeviceRoisAuto,
DeviceRoisExtractor,
DeviceRoisMasker,
)
from arcaea_offline_ocr.phash_db import ImagePhashDatabase
from arcaea_offline_ocr.utils import imread_unicode
from PySide6.QtCore import QDateTime, QFileInfo
from ui.extends.components.ocrQueue import OcrRunnable
from ui.extends.shared.data import Data
logger = logging.getLogger(__name__)
import exif
class TabDeviceV2OcrRunnable(OcrRunnable):
def __init__(self, imagePath, device, knnModel, phashDb, *, sizesV2: bool):
class TabDeviceOcrRunnable(OcrRunnable):
def __init__(
self,
imagePath: str,
rois: DeviceRois | Type[DeviceRoisAuto],
masker: DeviceRoisMasker,
knnModel: cv2.ml.KNearest,
phashDb: ImagePhashDatabase,
):
super().__init__()
self.imagePath = imagePath
self.device = device
self.rois = rois
self.masker = masker
self.knnModel = knnModel
self.phashDb = phashDb
self.sizesV2 = sizesV2
def run(self):
try:
rois = DeviceV2Rois(
self.device, imread_unicode(self.imagePath, cv2.IMREAD_COLOR)
)
rois.sizes = (
SizesV2(self.device.factor)
if self.sizesV2
else SizesV1(self.device.factor)
)
ocr = DeviceV2Ocr(self.knnModel, self.phashDb)
result = ocr.ocr(rois)
img = imread_unicode(self.imagePath, cv2.IMREAD_COLOR)
if isinstance(self.rois, type) and issubclass(self.rois, DeviceRoisAuto):
rois = self.rois(img.shape[1], img.shape[0])
else:
rois = self.rois
extractor = DeviceRoisExtractor(img, rois)
ocr = DeviceOcr(extractor, self.masker, self.knnModel, self.phashDb)
result = ocr.ocr()
self.signals.resultReady.emit(result)
except Exception:
logger.exception(f"DeviceV2 ocr {self.imagePath} error")
finally:
self.signals.finished.emit()
class TabDeviceV2AutoRoisOcrRunnable(OcrRunnable):
def __init__(self, imagePath, knnModel, phashDb, *, sizesV2: bool):
super().__init__()
self.imagePath = imagePath
self.knnModel = knnModel
self.phashDb = phashDb
self.sizesV2 = sizesV2
def run(self):
try:
rois = DeviceV2AutoRois(imread_unicode(self.imagePath, cv2.IMREAD_COLOR))
factor = rois.sizes.factor
rois.sizes = SizesV2(factor) if self.sizesV2 else SizesV1(factor)
ocr = DeviceV2Ocr(self.knnModel, self.phashDb)
result = ocr.ocr(rois)
self.signals.resultReady.emit(result)
except Exception:
logger.exception(f"DeviceV2AutoRois ocr {self.imagePath} error")
logger.exception("DeviceOcr error:")
finally:
self.signals.finished.emit()
@ -83,7 +72,24 @@ def getImageDate(imagePath: str) -> QDateTime:
class ScoreConverter:
@staticmethod
def deviceV2(imagePath: str, _, result: DeviceOcrResult) -> Tuple[Chart, Score]:
def device(imagePath: str, _, result: DeviceOcrResult) -> Tuple[Chart, Score]:
partnerModifiers = Data().partnerModifiers
imageDate = getImageDate(imagePath)
# calculate clear type
if result.partner_id == "50":
dayNight = kanae_day_night(imageDate)
modifier = 1 if dayNight == KanaeDayNight.Day else 2
else:
modifier = partnerModifiers.get(result.partner_id, 0)
if result.clear_status == 1 and modifier == 1:
clearType = 4
elif result.clear_status == 1 and modifier == 2:
clearType = 5
else:
clearType = result.clear_status
db = Database()
score = Score(
song_id=result.song_id,
@ -92,16 +98,16 @@ class ScoreConverter:
pure=result.pure,
far=result.far,
lost=result.lost,
date=getImageDate(imagePath).toSecsSinceEpoch(),
date=imageDate.toSecsSinceEpoch(),
max_recall=result.max_recall,
modifier=modifier,
clear_type=clearType,
comment=f"OCR {QFileInfo(imagePath).fileName()}",
)
chart = db.get_chart(score.song_id, score.rating_class)
if not chart:
chart = Chart(
song_id=result.song_id,
rating_class=result.rating_class,
title=result.song_id,
constant=0.0,
)
chart = db.get_chart(score.song_id, score.rating_class) or Chart(
song_id=result.song_id,
rating_class=result.rating_class,
title=result.song_id,
constant=0.0,
)
return (chart, score)

View File

@ -1,20 +1,21 @@
import logging
import cv2
from arcaea_offline_ocr.device.rois import (
DeviceRoisAutoT1,
DeviceRoisAutoT2,
DeviceRoisMaskerAutoT1,
DeviceRoisMaskerAutoT2,
)
from arcaea_offline_ocr.phash_db import ImagePhashDatabase
from PySide6.QtCore import Qt, Slot
from PySide6.QtWidgets import QApplication, QFileDialog, QWidget
from PySide6.QtCore import Slot
from PySide6.QtWidgets import QApplication, QFileDialog, QMessageBox, QWidget
from ui.designer.tabs.tabOcr.tabOcr_Device_ui import Ui_TabOcr_Device
from ui.extends.components.ocrQueue import OcrQueueModel
from ui.extends.shared.language import LanguageChangeEventFilter
from ui.extends.shared.settings import KNN_MODEL_FILE, PHASH_DATABASE_FILE
from ui.extends.tabs.tabOcr.tabOcr_Device import (
ScoreConverter,
TabDeviceV2AutoRoisOcrRunnable,
TabDeviceV2OcrRunnable,
)
from ui.extends.tabs.tabOcr.tabOcr_Device import ScoreConverter, TabDeviceOcrRunnable
logger = logging.getLogger(__name__)
@ -106,9 +107,15 @@ class TabOcr_Device(Ui_TabOcr_Device, QWidget):
try:
knnModelFile = self.dependencies_knnModelSelector.selectedFiles()[0]
self.knnModel = cv2.ml.KNearest.load(knnModelFile)
self.dependencies_knnModelStatusLabel.setText(
f'<font color="green">OK</font>, varCount {self.knnModel.getVarCount()}'
)
varCount = self.knnModel.getVarCount()
if varCount != 81:
self.dependencies_knnModelStatusLabel.setText(
f'<font color="darkorange">WARN</font>, varCount {varCount}'
)
else:
self.dependencies_knnModelStatusLabel.setText(
f'<font color="green">OK</font>, varCount {varCount}'
)
except Exception:
logger.exception("Error loading knn model:")
self.dependencies_knnModelStatusLabel.setText(
@ -150,30 +157,51 @@ class TabOcr_Device(Ui_TabOcr_Device, QWidget):
QApplication.processEvents()
self.ocrQueue.resizeTableView()
def deviceRois(self):
if self.options_roisUseCustomCheckBox.isChecked():
...
else:
selectedPreset = self.options_roisComboBox.currentData()
if selectedPreset == "AutoT1":
return DeviceRoisAutoT1
elif selectedPreset == "AutoT2":
return DeviceRoisAutoT2
else:
QMessageBox.critical(self, None, "Select a Rois preset first.")
return None
def deviceRoisMasker(self):
if self.options_maskerUseCustomCheckBox.isChecked():
...
else:
selectedPreset = self.options_maskerComboBox.currentData()
if selectedPreset == "AutoT1":
return DeviceRoisMaskerAutoT1()
elif selectedPreset == "AutoT2":
return DeviceRoisMaskerAutoT2()
else:
QMessageBox.critical(self, None, "Select a Masker preset first.")
return None
@Slot()
def on_ocr_startButton_clicked(self):
for row in range(self.ocrQueueModel.rowCount()):
index = self.ocrQueueModel.index(row, 0)
imagePath = index.data(OcrQueueModel.ImagePathRole)
if self.deviceUseAutoFactorCheckBox.checkState() == Qt.CheckState.Checked:
runnable = TabDeviceV2AutoRoisOcrRunnable(
imagePath,
self.knnModel,
self.phashDatabase,
sizesV2=self.deviceSizesV2CheckBox.isChecked(),
)
else:
runnable = TabDeviceV2OcrRunnable(
imagePath,
self.deviceComboBox.currentData(),
self.knnModel,
self.phashDatabase,
sizesV2=self.deviceSizesV2CheckBox.isChecked(),
)
rois = self.deviceRois()
masker = self.deviceRoisMasker()
if rois is None or masker is None:
return
runnable = TabDeviceOcrRunnable(
imagePath, rois, masker, self.knnModel, self.phashDatabase
)
self.ocrQueueModel.setData(index, runnable, OcrQueueModel.OcrRunnableRole)
self.ocrQueueModel.setData(
index,
ScoreConverter.deviceV2,
ScoreConverter.device,
OcrQueueModel.ProcessOcrResultFuncRole,
)
self.ocrQueueModel.startQueue()

View File

@ -0,0 +1,35 @@
{
"__COMMENT__": "1: EASY, 2: HARD",
"0": 1,
"0u": 1,
"7": 2,
"9": 1,
"10": 2,
"10u": 2,
"15": 1,
"16": 1,
"20": 1,
"28": 2,
"28u": 2,
"29": 2,
"29u": 2,
"35": 2,
"36": 2,
"36u": 2,
"37": 2,
"41": 2,
"42": 2,
"42u": 2,
"43": 2,
"43u": 2,
"54": 2,
"55": 2,
"57": 2,
"61": 2,
"64": 2,
"66": 2,
"66u": 2,
"67": 2,
"68": 1,
"70": 2
}

View File

@ -4,6 +4,8 @@
<file>VERSION</file>
<file>LICENSE</file>
<file>partnerModifiers.json</file>
<file>images/icon.png</file>
<file>images/logo.png</file>
<file>images/stepCalculator/stamina.png</file>