283375 5c5c1a227d
wip: arcaea-offline-ocr==0.1.0
API changes, modifier & clear_type support
2023-10-12 17:05:04 +08:00

461 lines
15 KiB
Python

import logging
from enum import IntEnum
from typing import Any, Callable, Optional, overload
from arcaea_offline.calculate import calculate_score_range
from arcaea_offline.database import Database
from arcaea_offline.models import Chart, Score
from arcaea_offline_ocr.b30.shared import B30OcrResultItem
from arcaea_offline_ocr.device.common import DeviceOcrResult
from arcaea_offline_ocr.utils import convert_to_srgb
from PIL import Image
from PIL.ImageQt import ImageQt
from PySide6.QtCore import (
QAbstractListModel,
QAbstractTableModel,
QCoreApplication,
QFileInfo,
QModelIndex,
QObject,
QRunnable,
Qt,
QThreadPool,
Signal,
Slot,
)
from PySide6.QtGui import QImage, QPixmap
from ui.extends.shared.delegates.chartDelegate import ChartDelegate
from ui.extends.shared.delegates.imageDelegate import ImageDelegate
from ui.extends.shared.delegates.scoreDelegate import ScoreDelegate
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)
class OcrRunnableSignals(QObject):
    """Signal carrier attached to every :class:`OcrRunnable`.

    ``QRunnable`` is not a ``QObject`` and therefore cannot declare signals
    itself, so the signals live on this helper object instead.
    """

    # Emitted with the OCR result payload.
    resultReady = Signal("QVariant")
    # Emitted when the runnable is done.
    finished = Signal()

    # Row of the queue model this runnable works for. ``OcrQueueModel.startQueue``
    # assigns it and ``OcrQueueModel.ocrTaskReady`` reads it back through
    # ``sender()``; -1 means "not assigned yet".
    rowId: int = -1
class OcrRunnable(QRunnable):
    """Base runnable for OCR tasks; exposes its signals through ``self.signals``."""

    def __init__(self):
        """Initialise the runnable and create its dedicated signal carrier."""
        super().__init__()
        # One signal object per runnable so the model can tell senders apart.
        self.signals: OcrRunnableSignals = OcrRunnableSignals()
class IccOption(IntEnum):
    """How ``OcrQueueModel.addItem`` loads an image given as a file path."""

    # Load the file directly with ``QImage(path)``.
    Ignore = 0

    # Open the file with PIL and wrap the result in ``ImageQt``.
    UsePIL = 1

    # Open the file with PIL, pass it through ``convert_to_srgb`` and then
    # wrap the result in ``ImageQt``.
    TryFix = 2
class OcrQueueModel(QAbstractListModel):
    """Single-column list model holding the queue of images to be OCR'd.

    Every row is a ``dict`` keyed by the custom item-data roles declared
    below: the source image (path / ``QImage`` / ``QPixmap``), the raw OCR
    result, the ``Chart`` and ``Score`` derived from it, a validation flag,
    the runnable performing the OCR and the function that turns a raw result
    into ``(chart, score)``.

    Signals:
        started: emitted by :meth:`startQueue` before any runnable is started.
        progress(int): emitted with the number of finished tasks every time a
            runnable reports ``finished``.
        finished: emitted once all tasks started by :meth:`startQueue` have
            finished.
    """

    # --- image roles ---
    ImagePathRole = Qt.ItemDataRole.UserRole + 1
    ImageQImageRole = Qt.ItemDataRole.UserRole + 2
    ImagePixmapRole = Qt.ItemDataRole.UserRole + 3

    # --- result roles ---
    OcrResultRole = Qt.ItemDataRole.UserRole + 10
    ScoreRole = Qt.ItemDataRole.UserRole + 11
    ChartRole = Qt.ItemDataRole.UserRole + 12
    ScoreValidateOkRole = Qt.ItemDataRole.UserRole + 13

    # --- processing roles ---
    OcrRunnableRole = Qt.ItemDataRole.UserRole + 20
    # Callable[[Optional[str], QImage, Any], tuple[Chart, Score]]
    # (called as ``func(imagePath, qImage, ocrResult)`` by updateOcrResult)
    ProcessOcrResultFuncRole = Qt.ItemDataRole.UserRole + 21

    started = Signal()
    progress = Signal(int)
    finished = Signal()

    def __init__(self, parent=None):
        """Create an empty queue that loads images through PIL by default."""
        super().__init__(parent)
        self.__db = Database()
        self.__items: list[dict[int, Any]] = []
        self.__iccOption = IccOption.UsePIL
        # Both counters exist from construction so that a stray ``finished``
        # signal arriving before ``startQueue`` cannot raise AttributeError.
        self.__taskNum = 0
        self.__taskFinishedNum = 0

    @property
    def imagePaths(self):
        """Image path of every row, in row order (``None`` for in-memory images)."""
        return [item.get(self.ImagePathRole) for item in self.__items]

    def clear(self):
        """Remove every row and reset the finished-task counter."""
        # The reset notification is kept together with the row-removal one:
        # listeners that only track rowsAboutToBeRemoved/rowsRemoved (such as
        # OcrQueueTableProxyModel) still have to learn that the rows are gone.
        self.beginResetModel()
        if self.__items:
            # beginRemoveRows needs last >= first, so skip it when empty.
            self.beginRemoveRows(QModelIndex(), 0, len(self.__items) - 1)
            self.__items.clear()
            self.endRemoveRows()
        self.__taskFinishedNum = 0
        self.endResetModel()

    def rowCount(self, *args):
        """Return the number of queued items."""
        return len(self.__items)

    def data(self, index, role=Qt.ItemDataRole.DisplayRole):
        """Return the value stored under ``role`` for ``index``, or ``None``.

        ``None`` is returned for invalid indexes, rows out of range, columns
        other than 0 and roles the row does not store.
        """
        if not index.isValid():
            return None
        if index.column() != 0 or not 0 <= index.row() < self.rowCount():
            return None
        return self.__items[index.row()].get(role)

    def setData(self, index: QModelIndex, value: Any, role: int = Qt.ItemDataRole.EditRole):
        """Store ``value`` under ``role`` for the row of ``index``.

        Only the result and processing roles are writable, and each (except
        ``OcrResultRole``) is type-checked. Setting the chart or the score also
        re-evaluates ``ScoreValidateOkRole``.

        Returns:
            ``True`` when the value was stored (``dataChanged`` is emitted),
            ``False`` when the row is out of range or the role/value pair was
            rejected (a warning is logged in the latter case).
        """
        row = index.row()
        if not 0 <= row < self.rowCount():
            return False

        item = self.__items[row]
        accepted = False

        if role == self.OcrResultRole:
            item[role] = value
            accepted = True
        elif role in (self.ChartRole, self.ScoreRole):
            expectedType = Chart if role == self.ChartRole else Score
            if isinstance(value, expectedType):
                item[role] = value
                # Chart and score together decide whether the score is plausible.
                self.updateScoreValidateOk(row)
                accepted = True
        elif role == self.ScoreValidateOkRole:
            if isinstance(value, bool):
                item[role] = value
                accepted = True
        elif role == self.OcrRunnableRole:
            if isinstance(value, OcrRunnable):
                item[role] = value
                accepted = True
        elif role == self.ProcessOcrResultFuncRole:
            if callable(value):
                item[role] = value
                accepted = True

        if not accepted:
            logger.warning(
                f"{repr(self)} setData at row {index.row()} with role {role} and value {value} rejected."
            )
            return False

        self.dataChanged.emit(index, index, [role])
        return True

    @property
    def iccOption(self):
        """The :class:`IccOption` used when an image is added by file path."""
        return self.__iccOption

    @iccOption.setter
    def iccOption(self, opt: IccOption):
        self.__iccOption = opt

    @overload
    def addItem(
        self,
        image: str,
        runnable: Optional[OcrRunnable] = None,
        process_func: Optional[
            Callable[[Optional[str], QImage, Any], "tuple[Chart, Score]"]
        ] = None,
    ):
        ...

    @overload
    def addItem(
        self,
        image: QImage,
        runnable: Optional[OcrRunnable] = None,
        process_func: Optional[
            Callable[[Optional[str], QImage, Any], "tuple[Chart, Score]"]
        ] = None,
    ):
        ...

    def addItem(
        self,
        image,
        runnable=None,
        process_func=None,
    ):
        """Append an image to the queue.

        Args:
            image: Either a file path (``str``) or a ``QImage``. A path that is
                already queued or does not exist is skipped with a warning.
                A ``QImage`` is copied before being stored.
            runnable: The ``OcrRunnable`` that will OCR this image.
            process_func: Called as ``process_func(imagePath, qImage, result)``
                and expected to return ``(chart, score)``.

        Raises:
            ValueError: ``image`` is neither a ``str`` nor a ``QImage``.
        """
        if isinstance(image, str):
            if image in self.imagePaths or not QFileInfo(image).exists():
                logger.warning(f"Attempting to add an invalid file {image}")
                return
            imagePath = image
            qImage = self._loadQImage(image)
        elif isinstance(image, QImage):
            imagePath = None
            qImage = image.copy()
        else:
            raise ValueError("Unsupported type for `image`")

        qPixmap = QPixmap(qImage)

        newRow = self.rowCount()
        self.beginInsertRows(QModelIndex(), newRow, newRow)
        self.__items.append(
            {
                self.ImagePathRole: imagePath,
                self.ImageQImageRole: qImage,
                self.ImagePixmapRole: qPixmap,
                self.OcrResultRole: None,
                self.ScoreRole: None,
                self.ChartRole: None,
                self.ScoreValidateOkRole: False,
                self.OcrRunnableRole: runnable,
                self.ProcessOcrResultFuncRole: process_func,
            }
        )
        self.endInsertRows()

    def _loadQImage(self, path: str) -> QImage:
        """Load the image at ``path`` according to the current ``iccOption``."""
        option = self.iccOption
        if option not in (IccOption.TryFix, IccOption.UsePIL):
            return QImage(path)

        # ``with`` closes the file handle PIL keeps open; the ImageQt object is
        # built inside the block, while the image data is still readable.
        with Image.open(path) as opened:
            pilImage = convert_to_srgb(opened) if option == IccOption.TryFix else opened
            return ImageQt(pilImage)

    def updateOcrResult(self, row: int, result: Any) -> bool:
        """Process a raw OCR ``result`` for ``row`` and store its outcome.

        The row's process function converts the result into ``(chart, score)``;
        the result, chart and score are then written through :meth:`setData`.

        Returns:
            ``False`` when ``row`` is out of range or the row has no callable
            process function, ``True`` otherwise.
        """
        if not 0 <= row < self.rowCount():
            return False

        item = self.__items[row]
        logger.debug(f"update request: {result}@row{row}")

        processOcrResultFunc = item.get(self.ProcessOcrResultFuncRole)
        if not callable(processOcrResultFunc):
            # addItem allows process_func=None; without it the result cannot
            # be turned into a chart/score pair.
            logger.warning("No OCR result processing function set for row %s", row)
            return False

        chart, scoreInsert = processOcrResultFunc(
            item.get(self.ImagePathRole), item.get(self.ImageQImageRole), result
        )

        index = self.index(row, 0)
        self.setData(index, result, self.OcrResultRole)
        self.setData(index, chart, self.ChartRole)
        self.setData(index, scoreInsert, self.ScoreRole)
        return True

    @Slot(DeviceOcrResult)
    def ocrTaskReady(self, result: DeviceOcrResult):
        """Slot for ``OcrRunnableSignals.resultReady``; routes the result to its row."""
        # NOTE(review): rowId is the row number assigned in startQueue; if rows
        # are removed while tasks are still running it would point at a
        # different item - confirm callers prevent that.
        row = self.sender().rowId
        self.updateOcrResult(row, result)

    @Slot()
    def ocrTaskFinished(self):
        """Slot for ``OcrRunnableSignals.finished``; tracks overall progress."""
        self.__taskFinishedNum += 1
        self.progress.emit(self.__taskFinishedNum)
        if self.__taskFinishedNum == self.__taskNum:
            self.finished.emit()

    def startQueue(self):
        """Start every queued runnable on the global thread pool.

        Rows without a runnable are skipped with a warning and do not count as
        tasks. ``started`` is always emitted; ``finished`` is emitted right
        away when there is nothing to run, otherwise once all started tasks
        have reported back.
        """
        pending = []
        for row, item in enumerate(self.__items):
            runnable = item.get(self.OcrRunnableRole)
            if runnable is None:
                logger.warning("Row %s has no OCR runnable, skipping", row)
                continue
            pending.append((row, runnable))

        # Counters are set before anything is started so they are already in
        # place when the first ``finished`` signal is handled.
        self.__taskNum = len(pending)
        self.__taskFinishedNum = 0
        self.started.emit()

        if not pending:
            # No task will ever report back; finish now instead of leaving
            # listeners waiting forever.
            self.finished.emit()
            return

        threadPool = QThreadPool.globalInstance()
        for row, runnable in pending:
            runnable.signals.rowId = row
            runnable.signals.resultReady.connect(self.ocrTaskReady)
            runnable.signals.finished.connect(self.ocrTaskFinished)
            threadPool.start(runnable)

    def updateScoreValidateOk(self, row: int):
        """Recompute ``ScoreValidateOkRole`` for ``row``.

        The flag is ``True`` only when the row has both a ``Chart`` and a
        ``Score`` with all needed fields set and the score lies inside the
        range returned by ``calculate_score_range(notes, pure, far)``.
        """
        if not 0 <= row < self.rowCount():
            return

        item = self.__items[row]
        chart = item.get(self.ChartRole)
        score = item.get(self.ScoreRole)

        scoreValidateOk = False
        if (
            isinstance(chart, Chart)
            and isinstance(score, Score)
            and chart.notes is not None
            and score.pure is not None
            and score.far is not None
            # A missing score value cannot be range-checked.
            and score.score is not None
        ):
            scoreRange = calculate_score_range(chart.notes, score.pure, score.far)
            scoreValidateOk = scoreRange[0] <= score.score <= scoreRange[1]

        self.setData(self.index(row, 0), scoreValidateOk, self.ScoreValidateOkRole)

    def acceptItem(self, row: int, ignoreValidate: bool = False):
        """Insert the score of ``row`` into the database and drop the row.

        Nothing happens when ``row`` is out of range, the row has no ``Score``
        or the score failed validation (unless ``ignoreValidate`` is set).
        If the insert raises, the error is logged and the row is kept.
        """
        if not 0 <= row < self.rowCount():
            return

        item = self.__items[row]
        score = item[self.ScoreRole]
        if not isinstance(score, Score):
            return
        if not (ignoreValidate or item[self.ScoreValidateOkRole]):
            return

        try:
            self.__db.insert_score(score)
        except Exception:
            # Best effort: keep the row in the queue so the user can retry.
            logger.exception(f"Error accepting {repr(item)}")
            return

        self.removeItem(row)

    def acceptItems(self, __rows: list[int], ignoreValidate: bool = False):
        """Accept several rows; see :meth:`acceptItem`."""
        # Highest row first so earlier removals do not shift the rows still
        # to be processed; duplicates are dropped for the same reason.
        for row in sorted(set(__rows), reverse=True):
            self.acceptItem(row, ignoreValidate)

    def acceptAllItems(self, ignoreValidate: bool = False):
        """Accept every row currently in the queue."""
        self.acceptItems(list(range(self.rowCount())), ignoreValidate)

    def removeItem(self, row: int):
        """Remove ``row`` from the queue; out-of-range rows are ignored."""
        if not 0 <= row < self.rowCount():
            return

        self.beginRemoveRows(QModelIndex(), row, row)
        self.__items.pop(row)
        self.endRemoveRows()

    def removeItems(self, __rows: list[int]):
        """Remove several rows from the queue."""
        # Same ordering/deduplication rationale as in acceptItems.
        for row in sorted(set(__rows), reverse=True):
            self.removeItem(row)
class OcrQueueTableProxyModel(QAbstractTableModel):
    """Four-column table view (select / image / chart / score) over an ``OcrQueueModel``.

    Each column only exposes the source roles listed for it in the
    column-role mapping; all data is read from column 0 of the matching row
    of the source model.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.__sourceModel = None
        self.retranslateHeaders()
        # Roles served by each column, indexed by column number.
        self.__columnRoleMapping = [
            [Qt.ItemDataRole.CheckStateRole],
            [
                OcrQueueModel.ImagePathRole,
                OcrQueueModel.ImageQImageRole,
                OcrQueueModel.ImagePixmapRole,
            ],
            [
                OcrQueueModel.OcrResultRole,
                OcrQueueModel.ChartRole,
            ],
            [
                OcrQueueModel.OcrResultRole,
                OcrQueueModel.ScoreRole,
                OcrQueueModel.ChartRole,
                OcrQueueModel.ScoreValidateOkRole,
            ],
        ]

    def retranslateHeaders(self):
        """(Re)build the translated horizontal header titles."""
        self.__horizontalHeaders = [
            # fmt: off
            QCoreApplication.translate("OcrTableModel", "horizontalHeader.title.select"),
            QCoreApplication.translate("OcrTableModel", "horizontalHeader.title.imagePreview"),
            QCoreApplication.translate("OcrTableModel", "horizontalHeader.title.chart"),
            QCoreApplication.translate("OcrTableModel", "horizontalHeader.title.score"),
            # fmt: on
        ]

    def sourceModel(self) -> Optional[OcrQueueModel]:
        """Return the attached ``OcrQueueModel``, or ``None`` if none is set."""
        return self.__sourceModel

    def setSourceModel(self, sourceModel):
        """Attach ``sourceModel`` and relay its change notifications.

        Returns:
            ``False`` when ``sourceModel`` is not an ``OcrQueueModel``,
            ``True`` otherwise.
        """
        if not isinstance(sourceModel, OcrQueueModel):
            return False
        if sourceModel is self.__sourceModel:
            # Already attached; connecting again would deliver every
            # notification twice.
            return True

        # Structural changes are relayed signal-to-signal.
        sourceModel.rowsAboutToBeInserted.connect(self.rowsAboutToBeInserted)
        sourceModel.rowsInserted.connect(self.rowsInserted)
        sourceModel.rowsAboutToBeRemoved.connect(self.rowsAboutToBeRemoved)
        sourceModel.rowsRemoved.connect(self.rowsRemoved)
        sourceModel.layoutAboutToBeChanged.connect(self.layoutAboutToBeChanged)
        sourceModel.layoutChanged.connect(self.layoutChanged)
        # Data changes carry source-model indexes and must be translated into
        # indexes of this model before views attached here can use them.
        sourceModel.dataChanged.connect(self._forwardSourceDataChanged)

        self.__sourceModel = sourceModel
        return True

    def _forwardSourceDataChanged(self, topLeft, bottomRight, roles=None):
        """Re-emit a source ``dataChanged`` for the affected rows of this model."""
        firstRow = topLeft.row()
        lastRow = bottomRight.row()
        if firstRow < 0 or lastRow < firstRow:
            return
        # A source role may feed several columns, so the whole row is reported.
        self.dataChanged.emit(
            self.index(firstRow, 0),
            self.index(lastRow, self.columnCount() - 1),
            list(roles or []),
        )

    def rowCount(self, *args):
        """Return the source model's row count (0 when no source is attached)."""
        source = self.sourceModel()
        return 0 if source is None else source.rowCount()

    def columnCount(self, *args):
        """Return the number of columns, one per header title."""
        return len(self.__horizontalHeaders)

    def headerData(
        self,
        section: int,
        orientation: Qt.Orientation,
        role: int = Qt.ItemDataRole.DisplayRole,
    ):
        """Return the horizontal header title for ``section`` (display role only)."""
        if orientation != Qt.Orientation.Horizontal:
            return None
        if role != Qt.ItemDataRole.DisplayRole:
            return None
        if not 0 <= section < len(self.__horizontalHeaders):
            return None
        return self.__horizontalHeaders[section]

    def data(self, index, role=Qt.ItemDataRole.DisplayRole):
        """Return the source value for ``role`` if the column exposes that role."""
        source = self.sourceModel()
        if source is None:
            return None
        if not 0 <= index.row() < source.rowCount():
            return None
        if not 0 <= index.column() < self.columnCount():
            return None
        if role not in self.__columnRoleMapping[index.column()]:
            return None
        return source.index(index.row(), 0).data(role)

    def setData(self, index, value, role=Qt.ItemDataRole.EditRole):
        """Write a chart (column 2) or a score (column 3) through to the source model.

        Returns the source model's ``setData`` result, or ``False`` for any
        other column/role combination or when no source model is attached.
        """
        source = self.sourceModel()
        if source is None:
            return False

        column = index.column()
        writable = (column == 2 and role == OcrQueueModel.ChartRole) or (
            column == 3 and role == OcrQueueModel.ScoreRole
        )
        if not writable:
            return False

        # Hand the source model one of its own indexes, not one of this model's.
        return source.setData(source.index(index.row(), 0), value, role)

    def flags(self, index: QModelIndex) -> Qt.ItemFlag:
        """Return item flags: enabled/editable/selectable, except column 0.

        Column 0 (the select column) is neither enabled nor editable.
        """
        source = self.sourceModel()
        flags = (
            source.flags(index)
            if isinstance(source, OcrQueueModel)
            else super().flags(index)
        )
        flags = (
            flags
            | Qt.ItemFlag.ItemIsEnabled
            | Qt.ItemFlag.ItemIsEditable
            | Qt.ItemFlag.ItemIsSelectable
        )

        if index.column() == 0:
            flags = flags & ~Qt.ItemFlag.ItemIsEnabled & ~Qt.ItemFlag.ItemIsEditable

        return flags
class OcrImageDelegate(ImageDelegate):
    """Image delegate that reads its pixmap and image path from ``OcrQueueModel`` roles."""

    def getPixmap(self, index: QModelIndex):
        """Return the value stored under ``ImagePixmapRole`` for ``index``."""
        pixmap = index.data(OcrQueueModel.ImagePixmapRole)
        return pixmap

    def getImagePath(self, index: QModelIndex):
        """Return the value stored under ``ImagePathRole`` for ``index``."""
        imagePath = index.data(OcrQueueModel.ImagePathRole)
        return imagePath
class OcrChartDelegate(ChartDelegate):
    """Chart delegate that reads and writes the chart through ``OcrQueueModel.ChartRole``."""

    def getChart(self, index: QModelIndex) -> Chart | None:
        """Return the value stored under ``ChartRole`` for ``index``."""
        chart = index.data(OcrQueueModel.ChartRole)
        return chart

    def paintWarningBackground(self, index: QModelIndex) -> bool:
        """Return ``True`` when the row's OCR result is a ``DeviceOcrResult``."""
        ocrResult = index.data(OcrQueueModel.OcrResultRole)
        return isinstance(ocrResult, DeviceOcrResult)

    def setModelData(self, editor, model: OcrQueueTableProxyModel, index):
        """Write the editor's value to the model if the editor validates."""
        if not editor.validate():
            return
        model.setData(index, editor.value(), OcrQueueModel.ChartRole)
class OcrScoreDelegate(ScoreDelegate):
    """Score delegate that reads score, chart and validation state from ``OcrQueueModel`` roles."""

    def getScore(self, index: QModelIndex):
        """Return the value stored under ``ScoreRole`` for ``index``."""
        score = index.data(OcrQueueModel.ScoreRole)
        return score

    def getChart(self, index: QModelIndex):
        """Return the value stored under ``ChartRole`` for ``index``."""
        chart = index.data(OcrQueueModel.ChartRole)
        return chart

    def getScoreValidateOk(self, index: QModelIndex):
        """Return the value stored under ``ScoreValidateOkRole`` for ``index``."""
        validateOk = index.data(OcrQueueModel.ScoreValidateOkRole)
        return validateOk

    def paintWarningBackground(self, index: QModelIndex) -> bool:
        """Always return ``True``.

        Earlier variants that depended on the row holding a ``Chart`` and a
        ``Score``, or on the type of the OCR result, are disabled.
        """
        return True

    def setModelData(self, editor, model: OcrQueueTableProxyModel, index):
        """Write the editor's value to the model once the base class confirms it."""
        if not super().confirmSetModelData(editor):
            return
        model.setData(index, editor.value(), OcrQueueModel.ScoreRole)