mirror of
https://github.com/283375/arcaea-offline-pyside-ui.git
synced 2025-04-11 13:40:17 +00:00
461 lines
15 KiB
Python
461 lines
15 KiB
Python
import logging
|
|
from enum import IntEnum
|
|
from typing import Any, Callable, Optional, overload
|
|
|
|
from arcaea_offline.calculate import calculate_score_range
|
|
from arcaea_offline.database import Database
|
|
from arcaea_offline.models import Chart, Score
|
|
from arcaea_offline_ocr.b30.shared import B30OcrResultItem
|
|
from arcaea_offline_ocr.device.common import DeviceOcrResult
|
|
from arcaea_offline_ocr.utils import convert_to_srgb
|
|
from PIL import Image
|
|
from PIL.ImageQt import ImageQt
|
|
from PySide6.QtCore import (
|
|
QAbstractListModel,
|
|
QAbstractTableModel,
|
|
QCoreApplication,
|
|
QFileInfo,
|
|
QModelIndex,
|
|
QObject,
|
|
QRunnable,
|
|
Qt,
|
|
QThreadPool,
|
|
Signal,
|
|
Slot,
|
|
)
|
|
from PySide6.QtGui import QImage, QPixmap
|
|
|
|
from ui.extends.shared.delegates.chartDelegate import ChartDelegate
|
|
from ui.extends.shared.delegates.imageDelegate import ImageDelegate
|
|
from ui.extends.shared.delegates.scoreDelegate import ScoreDelegate
|
|
|
|
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OcrRunnableSignals(QObject):
    """Signal carrier attached to every ``OcrRunnable``.

    ``OcrRunnable`` stores an instance of this class in its ``signals``
    attribute; ``OcrQueueModel.startQueue()`` connects to the two signals
    declared here and writes the queue row into ``rowId``.
    """

    # Row of the queue item this object reports for; -1 until a row is
    # assigned (``OcrQueueModel.startQueue()`` overwrites it).
    rowId: int = -1

    # Carries the OCR result as a generic QVariant payload.
    resultReady = Signal("QVariant")
    # Parameterless completion signal; the queue model counts these to
    # track progress.
    finished = Signal()
|
|
|
|
|
|
class OcrRunnable(QRunnable):
    """Base ``QRunnable`` for OCR tasks scheduled by ``OcrQueueModel``.

    Each instance owns an ``OcrRunnableSignals`` object, exposed as
    ``signals``, which the queue model connects to before starting the task.
    """

    def __init__(self):
        super().__init__()
        # Per-task signal carrier (resultReady / finished / rowId).
        self.signals = OcrRunnableSignals()
|
|
|
|
|
|
class IccOption(IntEnum):
    """How ``OcrQueueModel.addItem()`` loads an image given as a file path."""

    # Load the file directly with ``QImage``; PIL is not involved.
    Ignore = 0
    # Open the file with PIL and wrap it with ``ImageQt``.
    UsePIL = 1
    # Open with PIL and pass it through ``convert_to_srgb`` before wrapping.
    TryFix = 2
|
|
|
|
|
|
class OcrQueueModel(QAbstractListModel):
    """Single-column list model holding the images queued for OCR.

    Every row is a ``dict`` keyed by the custom item-data roles declared
    below: the source image (path / ``QImage`` / ``QPixmap``), the raw OCR
    result, the ``Chart`` and ``Score`` derived from it, the score validation
    flag, the ``OcrRunnable`` performing the OCR and the callable that turns
    an OCR result into a ``(Chart, Score)`` pair.

    Signals:
        started: emitted by ``startQueue()`` before any task is scheduled.
        progress(int): emitted with the number of finished tasks every time
            a runnable reports ``finished``.
        finished: emitted once all scheduled tasks have finished (or
            immediately when ``startQueue()`` has nothing to schedule).
    """

    # Source image, in three representations.
    ImagePathRole = Qt.ItemDataRole.UserRole + 1
    ImageQImageRole = Qt.ItemDataRole.UserRole + 2
    ImagePixmapRole = Qt.ItemDataRole.UserRole + 3

    # OCR output and the objects derived from it.
    OcrResultRole = Qt.ItemDataRole.UserRole + 10
    ScoreRole = Qt.ItemDataRole.UserRole + 11
    ChartRole = Qt.ItemDataRole.UserRole + 12
    ScoreValidateOkRole = Qt.ItemDataRole.UserRole + 13

    # Task plumbing.
    OcrRunnableRole = Qt.ItemDataRole.UserRole + 20
    # Callable[[Optional[str], QImage, Any], tuple[Chart, Score]];
    # updateOcrResult() calls it as func(imagePath, qImage, ocrResult).
    ProcessOcrResultFuncRole = Qt.ItemDataRole.UserRole + 21

    started = Signal()
    progress = Signal(int)
    finished = Signal()

    def __init__(self, parent=None):
        super().__init__(parent)
        self.__db = Database()
        self.__items: list[dict[int, Any]] = []
        self.__iccOption = IccOption.UsePIL

        # Both counters exist from construction so ocrTaskFinished() can never
        # hit a missing attribute, even if it fires before startQueue().
        self.__taskNum = 0
        self.__taskFinishedNum = 0

    @property
    def imagePaths(self):
        """Image paths of all rows (``None`` for rows added from a QImage)."""
        return [item.get(self.ImagePathRole) for item in self.__items]

    def clear(self):
        """Remove every row and reset the finished-task counter."""
        self.beginResetModel()
        lastRow = self.rowCount() - 1
        # beginRemoveRows() requires first <= last; announcing the range
        # 0..-1 for an already empty model is invalid.
        if lastRow >= 0:
            self.beginRemoveRows(QModelIndex(), 0, lastRow)
            self.__items.clear()
            self.endRemoveRows()
        self.__taskFinishedNum = 0
        self.endResetModel()

    def rowCount(self, *args):
        """Return the number of queued items."""
        return len(self.__items)

    def data(self, index, role=Qt.ItemDataRole.DisplayRole):
        """Return the value stored for ``role`` at ``index``.

        Returns ``None`` for invalid indexes, columns other than 0 and roles
        the row does not store.
        """
        if (
            index.isValid()
            and 0 <= index.row() < self.rowCount()
            and index.column() == 0
        ):
            return self.__items[index.row()].get(role)
        return None

    def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole):
        """Store ``value`` under ``role`` for the row of ``index``.

        Only the OCR result, chart, score, validation flag, runnable and
        process-function roles are writable, and each one (except the OCR
        result) only accepts a value of the matching type. Setting a chart or
        a score also re-evaluates the score validation flag.

        Returns ``True`` and emits ``dataChanged`` when the value is stored;
        otherwise logs a warning and returns ``False``. An out-of-range row
        returns ``False`` silently.
        """
        if not 0 <= index.row() < self.rowCount():
            return False

        item = self.__items[index.row()]
        accepted = True

        if role == self.OcrResultRole:
            item[self.OcrResultRole] = value
        elif role == self.ChartRole and isinstance(value, Chart):
            item[self.ChartRole] = value
            self.updateScoreValidateOk(index.row())
        elif role == self.ScoreRole and isinstance(value, Score):
            item[self.ScoreRole] = value
            self.updateScoreValidateOk(index.row())
        elif role == self.ScoreValidateOkRole and isinstance(value, bool):
            item[self.ScoreValidateOkRole] = value
        elif role == self.OcrRunnableRole and isinstance(value, OcrRunnable):
            item[self.OcrRunnableRole] = value
        elif role == self.ProcessOcrResultFuncRole and callable(value):
            item[self.ProcessOcrResultFuncRole] = value
        else:
            accepted = False

        if accepted:
            self.dataChanged.emit(index, index, [role])
            return True

        logger.warning(
            f"{repr(self)} setData at row {index.row()} with role {role} and value {value} rejected."
        )
        return False

    @property
    def iccOption(self):
        """The ``IccOption`` applied when images are added by file path."""
        return self.__iccOption

    @iccOption.setter
    def iccOption(self, opt: IccOption):
        self.__iccOption = opt

    @overload
    def addItem(
        self,
        image: str,
        runnable: Optional[OcrRunnable] = None,
        process_func: Optional[
            Callable[[Optional[str], QImage, Any], tuple[Chart, Score]]
        ] = None,
    ):
        ...

    @overload
    def addItem(
        self,
        image: QImage,
        runnable: Optional[OcrRunnable] = None,
        process_func: Optional[
            Callable[[Optional[str], QImage, Any], tuple[Chart, Score]]
        ] = None,
    ):
        ...

    def addItem(
        self,
        image,
        runnable=None,
        process_func=None,
    ):
        """Append an image to the queue.

        Args:
            image: Either a file path or a ``QImage``. A path that is already
                queued or does not exist is ignored with a warning. A
                ``QImage`` is copied and stored without a path.
            runnable: The ``OcrRunnable`` that will process the image.
            process_func: Called by ``updateOcrResult()`` as
                ``process_func(imagePath, qImage, ocrResult)``; must return a
                ``(Chart, Score)`` pair.

        Raises:
            ValueError: ``image`` is neither a ``str`` nor a ``QImage``.
        """
        if isinstance(image, str):
            if image in self.imagePaths or not QFileInfo(image).exists():
                logger.warning(f"Attempting to add an invalid file {image}")
                return
            imagePath = image

            iccOption = self.iccOption
            if iccOption in (IccOption.TryFix, IccOption.UsePIL):
                # Image.open() keeps the file open lazily; the context manager
                # releases the handle once ImageQt has taken the pixel data.
                with Image.open(image) as pilImage:
                    converted = (
                        convert_to_srgb(pilImage)
                        if iccOption == IccOption.TryFix
                        else pilImage
                    )
                    qImage = ImageQt(converted)
            else:
                qImage = QImage(image)

            qPixmap = QPixmap(qImage)
        elif isinstance(image, QImage):
            imagePath = None
            qImage = image.copy()
            qPixmap = QPixmap(qImage)
        else:
            raise ValueError("Unsupported type for `image`")

        newRow = self.rowCount()
        self.beginInsertRows(QModelIndex(), newRow, newRow)
        self.__items.append(
            {
                self.ImagePathRole: imagePath,
                self.ImageQImageRole: qImage,
                self.ImagePixmapRole: qPixmap,
                self.OcrResultRole: None,
                self.ScoreRole: None,
                self.ChartRole: None,
                self.ScoreValidateOkRole: False,
                self.OcrRunnableRole: runnable,
                self.ProcessOcrResultFuncRole: process_func,
            }
        )
        self.endInsertRows()

    def updateOcrResult(self, row: int, result: Any) -> bool:
        """Store an OCR ``result`` for ``row`` and derive its chart and score.

        The row's process function converts the result into a
        ``(Chart, Score)`` pair, which is written through ``setData()``.

        Returns:
            ``False`` when ``row`` is out of range or the row has no callable
            process function; ``True`` otherwise.
        """
        if not 0 <= row < self.rowCount():
            return False

        index = self.index(row, 0)
        imagePath: str = index.data(self.ImagePathRole)
        qImage: QImage = index.data(self.ImageQImageRole)
        logger.debug(f"update request: {result}@row{row}")
        processOcrResultFunc = index.data(self.ProcessOcrResultFuncRole)

        # addItem() allows process_func=None; calling it would raise TypeError.
        if not callable(processOcrResultFunc):
            logger.warning(
                "Row %s has no callable process function, OCR result dropped.", row
            )
            return False

        chart, scoreInsert = processOcrResultFunc(imagePath, qImage, result)

        self.setData(index, result, self.OcrResultRole)
        self.setData(index, chart, self.ChartRole)
        self.setData(index, scoreInsert, self.ScoreRole)
        return True

    @Slot(DeviceOcrResult)
    def ocrTaskReady(self, result: DeviceOcrResult):
        """Slot for ``OcrRunnableSignals.resultReady``.

        The target row is read from the ``rowId`` of the sending signals
        object.
        """
        row = self.sender().rowId
        self.updateOcrResult(row, result)

    @Slot()
    def ocrTaskFinished(self):
        """Slot for ``OcrRunnableSignals.finished``; tracks queue progress."""
        self.__taskFinishedNum += 1
        self.progress.emit(self.__taskFinishedNum)
        if self.__taskFinishedNum == self.__taskNum:
            self.finished.emit()

    def startQueue(self):
        """Schedule the runnable of every row on the global thread pool.

        Rows without a runnable are skipped with a warning and are not
        counted as tasks. ``started`` is always emitted; when no task could be
        scheduled ``finished`` is emitted right away so listeners are not left
        waiting for a completion that can never happen.
        """
        tasks = []
        for row in range(self.rowCount()):
            runnable: OcrRunnable = self.index(row, 0).data(self.OcrRunnableRole)
            if runnable is None:
                logger.warning("Row %s has no OCR runnable, skipped.", row)
                continue
            tasks.append((row, runnable))

        self.__taskNum = len(tasks)
        self.__taskFinishedNum = 0
        self.started.emit()

        if not tasks:
            self.finished.emit()
            return

        threadPool = QThreadPool.globalInstance()
        for row, runnable in tasks:
            runnable.signals.rowId = row
            runnable.signals.resultReady.connect(self.ocrTaskReady)
            runnable.signals.finished.connect(self.ocrTaskFinished)
            threadPool.start(runnable)

    def updateScoreValidateOk(self, row: int):
        """Recompute the score validation flag of ``row``.

        The flag is ``True`` only when the row holds both a ``Chart`` and a
        ``Score`` with the needed fields set and the score lies inside the
        range returned by ``calculate_score_range(notes, pure, far)``.
        """
        if not 0 <= row < self.rowCount():
            return

        index = self.index(row, 0)
        chart = index.data(self.ChartRole)
        score = index.data(self.ScoreRole)

        scoreValidateOk = False
        if (
            isinstance(chart, Chart)
            and isinstance(score, Score)
            and chart.notes is not None
            and score.pure is not None
            and score.far is not None
            # A missing score value cannot be range-checked (None is not
            # orderable against numbers).
            and score.score is not None
        ):
            scoreRange = calculate_score_range(chart.notes, score.pure, score.far)
            scoreValidateOk = bool(scoreRange[0] <= score.score <= scoreRange[1])

        self.setData(index, scoreValidateOk, self.ScoreValidateOkRole)

    def acceptItem(self, row: int, ignoreValidate: bool = False):
        """Insert the score of ``row`` into the database and drop the row.

        Nothing happens when the row is out of range, holds no ``Score``, or
        failed validation while ``ignoreValidate`` is false. A database error
        is logged and the row is kept.
        """
        if not 0 <= row < self.rowCount():
            return

        item = self.__items[row]
        score = item[self.ScoreRole]
        if not isinstance(score, Score) or (
            not item[self.ScoreValidateOkRole] and not ignoreValidate
        ):
            return

        try:
            self.__db.insert_score(score)
        except Exception:
            # Best effort: keep the row in the queue so the user can retry.
            logger.exception(f"Error accepting {repr(item)}")
            return

        self.beginRemoveRows(QModelIndex(), row, row)
        self.__items.pop(row)
        self.endRemoveRows()

    def acceptItems(self, __rows: list[int], ignoreValidate: bool = False):
        """Accept several rows; see ``acceptItem()``.

        Rows are de-duplicated and processed from the highest index down so
        that removing one row never shifts a row that is still pending.
        """
        for row in sorted(set(__rows), reverse=True):
            self.acceptItem(row, ignoreValidate)

    def acceptAllItems(self, ignoreValidate: bool = False):
        """Accept every row currently in the queue."""
        self.acceptItems(list(range(self.rowCount())), ignoreValidate)

    def removeItem(self, row: int):
        """Remove ``row`` from the queue; out-of-range rows are ignored."""
        if not 0 <= row < self.rowCount():
            return

        self.beginRemoveRows(QModelIndex(), row, row)
        self.__items.pop(row)
        self.endRemoveRows()

    def removeItems(self, __rows: list[int]):
        """Remove several rows.

        Rows are de-duplicated and removed from the highest index down so
        that earlier removals do not shift the remaining targets.
        """
        for row in sorted(set(__rows), reverse=True):
            self.removeItem(row)
|
|
|
|
|
|
class OcrQueueTableProxyModel(QAbstractTableModel):
    """Four-column table view of an ``OcrQueueModel``.

    Each table row mirrors one row of the source list model. Columns are, in
    order: select, image preview, chart and score; every column only serves
    the roles listed for it in the column/role mapping built in ``__init__``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.retranslateHeaders()
        self.__sourceModel = None
        # Roles each column answers in data(), indexed by column number.
        self.__columnRoleMapping = [
            [Qt.ItemDataRole.CheckStateRole],
            [
                OcrQueueModel.ImagePathRole,
                OcrQueueModel.ImageQImageRole,
                OcrQueueModel.ImagePixmapRole,
            ],
            [
                OcrQueueModel.OcrResultRole,
                OcrQueueModel.ChartRole,
            ],
            [
                OcrQueueModel.OcrResultRole,
                OcrQueueModel.ScoreRole,
                OcrQueueModel.ChartRole,
                OcrQueueModel.ScoreValidateOkRole,
            ],
        ]

    def retranslateHeaders(self):
        """(Re)build the translated horizontal header titles."""
        self.__horizontalHeaders = [
            # fmt: off
            QCoreApplication.translate("OcrTableModel", "horizontalHeader.title.select"),
            QCoreApplication.translate("OcrTableModel", "horizontalHeader.title.imagePreview"),
            QCoreApplication.translate("OcrTableModel", "horizontalHeader.title.chart"),
            QCoreApplication.translate("OcrTableModel", "horizontalHeader.title.score"),
            # fmt: on
        ]

    def sourceModel(self) -> OcrQueueModel:
        """Return the attached source model, or ``None`` if none was set."""
        return self.__sourceModel

    def setSourceModel(self, sourceModel):
        """Attach ``sourceModel`` and forward its change notifications.

        Returns:
            ``False`` (and changes nothing) unless ``sourceModel`` is an
            ``OcrQueueModel``; ``True`` once it is attached.
        """
        if not isinstance(sourceModel, OcrQueueModel):
            return False

        # Row and layout notifications carry no cell indexes that need
        # translating, so they are relayed signal-to-signal.
        sourceModel.rowsAboutToBeInserted.connect(self.rowsAboutToBeInserted)
        sourceModel.rowsInserted.connect(self.rowsInserted)
        sourceModel.rowsAboutToBeRemoved.connect(self.rowsAboutToBeRemoved)
        sourceModel.rowsRemoved.connect(self.rowsRemoved)
        # dataChanged carries source-model indexes (always column 0); they are
        # translated into this model's own indexes before being re-emitted.
        sourceModel.dataChanged.connect(self._forwardSourceDataChanged)
        sourceModel.layoutAboutToBeChanged.connect(self.layoutAboutToBeChanged)
        sourceModel.layoutChanged.connect(self.layoutChanged)

        self.__sourceModel = sourceModel
        return True

    def _forwardSourceDataChanged(self, topLeft, bottomRight, roles=None):
        """Re-emit a source ``dataChanged`` for the matching rows of this model.

        One source row backs every column of the table row, so the whole row
        span (first to last column) is reported as changed.
        """
        self.dataChanged.emit(
            self.index(topLeft.row(), 0),
            self.index(bottomRight.row(), self.columnCount() - 1),
            list(roles) if roles else [],
        )

    def rowCount(self, *args):
        """Return the source model's row count, or 0 without a source model."""
        source = self.sourceModel()
        return 0 if source is None else source.rowCount()

    def columnCount(self, *args):
        """Return the number of columns (one per header title)."""
        return len(self.__horizontalHeaders)

    def headerData(
        self,
        section: int,
        orientation: Qt.Orientation,
        role: int = Qt.ItemDataRole.DisplayRole,
    ):
        """Return the translated title of a horizontal header section.

        Any other orientation, role or out-of-range section yields ``None``.
        """
        if (
            orientation == Qt.Orientation.Horizontal
            and 0 <= section < len(self.__horizontalHeaders)
            and role == Qt.ItemDataRole.DisplayRole
        ):
            return self.__horizontalHeaders[section]
        return None

    def data(self, index, role=Qt.ItemDataRole.DisplayRole):
        """Return the source model's value for ``role`` at the row of ``index``.

        Only roles mapped to the column of ``index`` are served; everything
        else yields ``None``.
        """
        if (
            0 <= index.row() < self.rowCount()
            and 0 <= index.column() < self.columnCount()
            and role in self.__columnRoleMapping[index.column()]
        ):
            srcIndex = self.sourceModel().index(index.row(), 0)
            return srcIndex.data(role)
        return None

    def setData(self, index, value, role=Qt.ItemDataRole.EditRole):
        """Write an edited chart (column 2) or score (column 3) to the source.

        Returns the source model's ``setData()`` result, or ``False`` for any
        other column/role combination or when no source model is attached.
        """
        source = self.sourceModel()
        if source is None:
            return False

        column = index.column()
        editable = (column == 2 and role == OcrQueueModel.ChartRole) or (
            column == 3 and role == OcrQueueModel.ScoreRole
        )
        if not editable:
            return False

        # The source model is a single-column list: address it with one of its
        # own indexes instead of handing over this model's index.
        return source.setData(source.index(index.row(), 0), value, role)

    def flags(self, index: QModelIndex) -> Qt.ItemFlag:
        """Return the item flags for ``index``.

        Every cell is enabled, editable and selectable, except column 0 where
        the enabled and editable flags are cleared.
        """
        flags = (
            self.sourceModel().flags(index)
            if isinstance(self.sourceModel(), OcrQueueModel)
            else super().flags(index)
        )
        flags = flags | Qt.ItemFlag.ItemIsEnabled
        flags = flags | Qt.ItemFlag.ItemIsEditable
        flags = flags | Qt.ItemFlag.ItemIsSelectable
        if index.column() == 0:
            flags = flags & ~Qt.ItemFlag.ItemIsEnabled & ~Qt.ItemFlag.ItemIsEditable
        return flags
|
|
|
|
|
|
class OcrImageDelegate(ImageDelegate):
    """``ImageDelegate`` that reads its image from ``OcrQueueModel`` roles."""

    def getPixmap(self, index: QModelIndex):
        """Return the value stored under ``ImagePixmapRole`` for ``index``."""
        pixmap = index.data(OcrQueueModel.ImagePixmapRole)
        return pixmap

    def getImagePath(self, index: QModelIndex):
        """Return the value stored under ``ImagePathRole`` for ``index``."""
        imagePath = index.data(OcrQueueModel.ImagePathRole)
        return imagePath
|
|
|
|
|
|
class OcrChartDelegate(ChartDelegate):
    """``ChartDelegate`` bound to the chart column of the OCR queue table."""

    def getChart(self, index: QModelIndex) -> Chart | None:
        """Return the value stored under ``ChartRole`` for ``index``."""
        chart = index.data(OcrQueueModel.ChartRole)
        return chart

    def paintWarningBackground(self, index: QModelIndex) -> bool:
        """Return ``True`` when the row's OCR result is a ``DeviceOcrResult``."""
        ocrResult = index.data(OcrQueueModel.OcrResultRole)
        return isinstance(ocrResult, DeviceOcrResult)

    def setModelData(self, editor, model: OcrQueueTableProxyModel, index):
        """Write the editor's value to ``model`` as the chart, if it validates."""
        if not editor.validate():
            return
        model.setData(index, editor.value(), OcrQueueModel.ChartRole)
|
|
|
|
|
|
class OcrScoreDelegate(ScoreDelegate):
    """``ScoreDelegate`` bound to the score column of the OCR queue table."""

    def getScore(self, index: QModelIndex):
        """Return the value stored under ``ScoreRole`` for ``index``."""
        score = index.data(OcrQueueModel.ScoreRole)
        return score

    def getChart(self, index: QModelIndex):
        """Return the value stored under ``ChartRole`` for ``index``."""
        chart = index.data(OcrQueueModel.ChartRole)
        return chart

    def getScoreValidateOk(self, index: QModelIndex):
        """Return the value stored under ``ScoreValidateOkRole`` for ``index``."""
        validateOk = index.data(OcrQueueModel.ScoreValidateOkRole)
        return validateOk

    def paintWarningBackground(self, index: QModelIndex) -> bool:
        """Always return ``True``, whatever the row holds."""
        # NOTE: narrower conditions were tried and disabled: requiring both a
        # Chart and a Score on the row, or requiring the OCR result to be a
        # DeviceOcrResult / B30OcrResultItem.
        return True

    def setModelData(self, editor, model: OcrQueueTableProxyModel, index):
        """Write the editor's value to ``model`` as the score once confirmed."""
        if not super().confirmSetModelData(editor):
            return
        model.setData(index, editor.value(), OcrQueueModel.ScoreRole)
|