mirror of
https://github.com/283375/arcaea-offline-ocr-model.git
synced 2025-04-04 14:10:18 +00:00
wip(ui): sample classifying
This commit is contained in:
parent
67a794f4f1
commit
7e11f3ee5d
@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import CHAR, TEXT, TIMESTAMP, text, event, DDL
|
||||
from sqlalchemy import CHAR, DDL, TEXT, TIMESTAMP, event, text
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ class ClassifiedSample(ProjectBase):
|
||||
)
|
||||
tag: Mapped[str] = mapped_column(TEXT(), primary_key=True)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
TIMESTAMP(), server_default=text("CURRENT_TIMESTAMP")
|
||||
TIMESTAMP(timezone=True), server_default=text("CURRENT_TIMESTAMP")
|
||||
)
|
||||
|
||||
|
||||
|
15
project.py
15
project.py
@ -102,27 +102,28 @@ class Project:
|
||||
@property
|
||||
def samplesClassified(self):
|
||||
with self.__sessionmaker() as session:
|
||||
return [
|
||||
samplesClassifiedMd5s = [
|
||||
cs.sampleNumpyMd5 for cs in session.scalars(select(ClassifiedSample))
|
||||
]
|
||||
return [p for p in self.samples if p.stem in samplesClassifiedMd5s]
|
||||
|
||||
@property
|
||||
def samplesIgnored(self):
|
||||
with self.__sessionmaker() as session:
|
||||
return [
|
||||
samplesIgnoredMd5s = [
|
||||
cs.sampleNumpyMd5
|
||||
for cs in session.scalars(
|
||||
select(ClassifiedSample).where(ClassifiedSample.tag == "ignored")
|
||||
)
|
||||
]
|
||||
return [p for p in self.samples if p.stem in samplesIgnoredMd5s]
|
||||
|
||||
@property
|
||||
def samplesUnclassified(self):
|
||||
samplesNumpyMd5s = [s.stem for s in self.samples]
|
||||
classifiedSamples = []
|
||||
classifiedSamples += self.samplesClassified
|
||||
classifiedSamples += self.samplesIgnored
|
||||
return [s for s in samplesNumpyMd5s if s not in classifiedSamples]
|
||||
classifiedList = []
|
||||
classifiedList += self.samplesClassified
|
||||
classifiedList += self.samplesIgnored
|
||||
return list(filter(lambda p: p not in classifiedList, self.samples))
|
||||
|
||||
def samplesByTag(self, tag: str):
|
||||
if tag != "ignored" and tag not in self.tags:
|
||||
|
@ -16,6 +16,7 @@ class ProjectEntry(Ui_ProjectEntry, QWidget):
|
||||
def setProject(self, project: Project):
|
||||
self.project = project
|
||||
self.tabManage.setProject(project)
|
||||
self.tabClassify.setProject(project)
|
||||
|
||||
def reloadProject(self):
|
||||
self.project.reload()
|
||||
|
@ -24,49 +24,33 @@
|
||||
<string>Manage</string>
|
||||
</attribute>
|
||||
</widget>
|
||||
<widget class="QWidget" name="tabClassify">
|
||||
<widget class="ProjectEntry_Classify" name="tabClassify">
|
||||
<attribute name="title">
|
||||
<string>Classify</string>
|
||||
</attribute>
|
||||
<layout class="QHBoxLayout" name="horizontalLayout">
|
||||
<item>
|
||||
<widget class="SamplesListWidget" name="unclassifiedListWidget"/>
|
||||
</item>
|
||||
<item>
|
||||
<layout class="QVBoxLayout" name="verticalLayout_2">
|
||||
<item>
|
||||
<widget class="QListWidget" name="tagsListWidget">
|
||||
<property name="sizePolicy">
|
||||
<sizepolicy hsizetype="Expanding" vsizetype="Preferred">
|
||||
<horstretch>0</horstretch>
|
||||
<verstretch>0</verstretch>
|
||||
</sizepolicy>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item>
|
||||
<widget class="SamplesListWidget" name="classfiedListWidget"/>
|
||||
</item>
|
||||
</layout>
|
||||
</item>
|
||||
</layout>
|
||||
</widget>
|
||||
<widget class="QWidget" name="tabSamples">
|
||||
<attribute name="title">
|
||||
<string>Samples</string>
|
||||
</attribute>
|
||||
</widget>
|
||||
</widget>
|
||||
</item>
|
||||
</layout>
|
||||
</widget>
|
||||
<customwidgets>
|
||||
<customwidget>
|
||||
<class>SamplesListWidget</class>
|
||||
<extends>QListWidget</extends>
|
||||
<header>ui.components.samplesListWidget</header>
|
||||
</customwidget>
|
||||
<customwidget>
|
||||
<class>ProjectEntry_Manage</class>
|
||||
<extends>QWidget</extends>
|
||||
<header>ui.components.projectEntry_Manage</header>
|
||||
<container>1</container>
|
||||
</customwidget>
|
||||
<customwidget>
|
||||
<class>ProjectEntry_Classify</class>
|
||||
<extends>QWidget</extends>
|
||||
<header>ui.components.projectEntry_Classify</header>
|
||||
<container>1</container>
|
||||
</customwidget>
|
||||
</customwidgets>
|
||||
<resources/>
|
||||
<connections/>
|
||||
|
92
ui/components/projectEntry_Classify.py
Normal file
92
ui/components/projectEntry_Classify.py
Normal file
@ -0,0 +1,92 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from PySide6.QtCore import Qt, Slot
|
||||
from PySide6.QtGui import QDragEnterEvent, QDropEvent
|
||||
from PySide6.QtWidgets import QLabel, QWidget
|
||||
|
||||
from project import Project
|
||||
|
||||
from .projectEntry_Classify_ui import Ui_ProjectEntry_Classify
|
||||
|
||||
|
||||
class TagLabel(QLabel):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.setAcceptDrops(True)
|
||||
|
||||
self.project: Project | None = None
|
||||
self.tag = None
|
||||
|
||||
def enableDropEffect(self):
|
||||
# palette = self.palette()
|
||||
# palette.setBrush(QPalette.ColorRole.Base, palette.highlight())
|
||||
# palette.setBrush(QPalette.ColorRole.Window, palette.highlight())
|
||||
# palette.setBrush(QPalette.ColorRole.Text, palette.highlightedText())
|
||||
# self.setPalette(palette)
|
||||
font = self.font()
|
||||
font.setBold(True)
|
||||
font.setUnderline(True)
|
||||
self.setFont(font)
|
||||
|
||||
def disableDropEffect(self):
|
||||
# palette = self.palette()
|
||||
# palette.setBrush(QPalette.ColorRole.Base, palette.base())
|
||||
# palette.setBrush(QPalette.ColorRole.Window, palette.window())
|
||||
# palette.setBrush(QPalette.ColorRole.Text, palette.text())
|
||||
# self.setPalette(palette)
|
||||
font = self.font()
|
||||
font.setBold(False)
|
||||
font.setUnderline(False)
|
||||
self.setFont(font)
|
||||
|
||||
def dragEnterEvent(self, event: QDragEnterEvent):
|
||||
mimeData = event.mimeData()
|
||||
if mimeData.hasFormat("application/ao-ocr-model_sample"):
|
||||
self.enableDropEffect()
|
||||
event.accept()
|
||||
|
||||
def dragLeaveEvent(self, event):
|
||||
self.disableDropEffect()
|
||||
return super().dragLeaveEvent(event)
|
||||
|
||||
def dropEvent(self, event: QDropEvent):
|
||||
if self.project and self.tag and event.dropAction() == Qt.DropAction.MoveAction:
|
||||
data = bytes(event.mimeData().data("application/ao-ocr-model_sample"))
|
||||
paths = json.loads(data.decode("utf-8"))
|
||||
paths = [Path(p) for p in paths]
|
||||
for path in paths:
|
||||
self.project.classify(path, self.tag)
|
||||
event.acceptProposedAction()
|
||||
|
||||
if not event.isAccepted():
|
||||
event.ignore()
|
||||
self.disableDropEffect()
|
||||
|
||||
|
||||
class ProjectEntry_Classify(Ui_ProjectEntry_Classify, QWidget):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.setupUi(self)
|
||||
self.project = None
|
||||
|
||||
self.tagLabels: list[TagLabel] = []
|
||||
|
||||
def setProject(self, project: Project):
|
||||
self.project = project
|
||||
|
||||
for tagLabel in self.tagLabels:
|
||||
self.frame.layout().removeWidget(tagLabel)
|
||||
tagLabel.deleteLater()
|
||||
|
||||
for tag in self.project.tags:
|
||||
tagLabel = TagLabel(self)
|
||||
tagLabel.tag = tag
|
||||
tagLabel.project = project
|
||||
tagLabel.setText(tag)
|
||||
self.frame.layout().addWidget(tagLabel)
|
||||
self.tagLabels.append(tagLabel)
|
||||
|
||||
@Slot()
|
||||
def on_loadSamplesButton_clicked(self):
|
||||
self.samplesListWidget.setSamples(self.project.samplesUnclassified)
|
85
ui/components/projectEntry_Classify.ui
Normal file
85
ui/components/projectEntry_Classify.ui
Normal file
@ -0,0 +1,85 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<ui version="4.0">
|
||||
<class>ProjectEntry_Classify</class>
|
||||
<widget class="QWidget" name="ProjectEntry_Classify">
|
||||
<property name="geometry">
|
||||
<rect>
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>677</width>
|
||||
<height>523</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="windowTitle">
|
||||
<string notr="true">ProjectEntry_Classify</string>
|
||||
</property>
|
||||
<layout class="QHBoxLayout" name="horizontalLayout">
|
||||
<item>
|
||||
<widget class="SamplesListWidget" name="samplesListWidget"/>
|
||||
</item>
|
||||
<item>
|
||||
<layout class="QVBoxLayout" name="verticalLayout">
|
||||
<item>
|
||||
<spacer name="verticalSpacer">
|
||||
<property name="orientation">
|
||||
<enum>Qt::Vertical</enum>
|
||||
</property>
|
||||
<property name="sizeHint" stdset="0">
|
||||
<size>
|
||||
<width>20</width>
|
||||
<height>40</height>
|
||||
</size>
|
||||
</property>
|
||||
</spacer>
|
||||
</item>
|
||||
<item>
|
||||
<widget class="QPushButton" name="loadSamplesButton">
|
||||
<property name="text">
|
||||
<string>Load samples</string>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item>
|
||||
<widget class="QFrame" name="frame">
|
||||
<property name="minimumSize">
|
||||
<size>
|
||||
<width>100</width>
|
||||
<height>200</height>
|
||||
</size>
|
||||
</property>
|
||||
<property name="frameShape">
|
||||
<enum>QFrame::StyledPanel</enum>
|
||||
</property>
|
||||
<property name="frameShadow">
|
||||
<enum>QFrame::Raised</enum>
|
||||
</property>
|
||||
<layout class="QVBoxLayout" name="verticalLayout_2"/>
|
||||
</widget>
|
||||
</item>
|
||||
<item>
|
||||
<spacer name="verticalSpacer_2">
|
||||
<property name="orientation">
|
||||
<enum>Qt::Vertical</enum>
|
||||
</property>
|
||||
<property name="sizeHint" stdset="0">
|
||||
<size>
|
||||
<width>20</width>
|
||||
<height>40</height>
|
||||
</size>
|
||||
</property>
|
||||
</spacer>
|
||||
</item>
|
||||
</layout>
|
||||
</item>
|
||||
</layout>
|
||||
</widget>
|
||||
<customwidgets>
|
||||
<customwidget>
|
||||
<class>SamplesListWidget</class>
|
||||
<extends>QListWidget</extends>
|
||||
<header>ui.extends.samplesListWidget</header>
|
||||
</customwidget>
|
||||
</customwidgets>
|
||||
<resources/>
|
||||
<connections/>
|
||||
</ui>
|
75
ui/components/projectEntry_Classify_ui.py
Normal file
75
ui/components/projectEntry_Classify_ui.py
Normal file
@ -0,0 +1,75 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
################################################################################
|
||||
## Form generated from reading UI file 'projectEntry_Classify.ui'
|
||||
##
|
||||
## Created by: Qt User Interface Compiler version 6.5.2
|
||||
##
|
||||
## WARNING! All changes made in this file will be lost when recompiling UI file!
|
||||
################################################################################
|
||||
|
||||
from PySide6.QtCore import (QCoreApplication, QDate, QDateTime, QLocale,
|
||||
QMetaObject, QObject, QPoint, QRect,
|
||||
QSize, QTime, QUrl, Qt)
|
||||
from PySide6.QtGui import (QBrush, QColor, QConicalGradient, QCursor,
|
||||
QFont, QFontDatabase, QGradient, QIcon,
|
||||
QImage, QKeySequence, QLinearGradient, QPainter,
|
||||
QPalette, QPixmap, QRadialGradient, QTransform)
|
||||
from PySide6.QtWidgets import (QApplication, QFrame, QHBoxLayout, QListWidgetItem,
|
||||
QPushButton, QSizePolicy, QSpacerItem, QVBoxLayout,
|
||||
QWidget)
|
||||
|
||||
from ui.extends.samplesListWidget import SamplesListWidget
|
||||
|
||||
class Ui_ProjectEntry_Classify(object):
|
||||
def setupUi(self, ProjectEntry_Classify):
|
||||
if not ProjectEntry_Classify.objectName():
|
||||
ProjectEntry_Classify.setObjectName(u"ProjectEntry_Classify")
|
||||
ProjectEntry_Classify.resize(677, 523)
|
||||
ProjectEntry_Classify.setWindowTitle(u"ProjectEntry_Classify")
|
||||
self.horizontalLayout = QHBoxLayout(ProjectEntry_Classify)
|
||||
self.horizontalLayout.setObjectName(u"horizontalLayout")
|
||||
self.samplesListWidget = SamplesListWidget(ProjectEntry_Classify)
|
||||
self.samplesListWidget.setObjectName(u"samplesListWidget")
|
||||
|
||||
self.horizontalLayout.addWidget(self.samplesListWidget)
|
||||
|
||||
self.verticalLayout = QVBoxLayout()
|
||||
self.verticalLayout.setObjectName(u"verticalLayout")
|
||||
self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
|
||||
|
||||
self.verticalLayout.addItem(self.verticalSpacer)
|
||||
|
||||
self.loadSamplesButton = QPushButton(ProjectEntry_Classify)
|
||||
self.loadSamplesButton.setObjectName(u"loadSamplesButton")
|
||||
|
||||
self.verticalLayout.addWidget(self.loadSamplesButton)
|
||||
|
||||
self.frame = QFrame(ProjectEntry_Classify)
|
||||
self.frame.setObjectName(u"frame")
|
||||
self.frame.setMinimumSize(QSize(100, 200))
|
||||
self.frame.setFrameShape(QFrame.StyledPanel)
|
||||
self.frame.setFrameShadow(QFrame.Raised)
|
||||
self.verticalLayout_2 = QVBoxLayout(self.frame)
|
||||
self.verticalLayout_2.setObjectName(u"verticalLayout_2")
|
||||
|
||||
self.verticalLayout.addWidget(self.frame)
|
||||
|
||||
self.verticalSpacer_2 = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
|
||||
|
||||
self.verticalLayout.addItem(self.verticalSpacer_2)
|
||||
|
||||
|
||||
self.horizontalLayout.addLayout(self.verticalLayout)
|
||||
|
||||
|
||||
self.retranslateUi(ProjectEntry_Classify)
|
||||
|
||||
QMetaObject.connectSlotsByName(ProjectEntry_Classify)
|
||||
# setupUi
|
||||
|
||||
def retranslateUi(self, ProjectEntry_Classify):
|
||||
self.loadSamplesButton.setText(QCoreApplication.translate("ProjectEntry_Classify", u"Load samples", None))
|
||||
pass
|
||||
# retranslateUi
|
||||
|
@ -15,11 +15,11 @@ from PySide6.QtGui import (QBrush, QColor, QConicalGradient, QCursor,
|
||||
QFont, QFontDatabase, QGradient, QIcon,
|
||||
QImage, QKeySequence, QLinearGradient, QPainter,
|
||||
QPalette, QPixmap, QRadialGradient, QTransform)
|
||||
from PySide6.QtWidgets import (QApplication, QHBoxLayout, QListWidget, QListWidgetItem,
|
||||
QSizePolicy, QTabWidget, QVBoxLayout, QWidget)
|
||||
from PySide6.QtWidgets import (QApplication, QSizePolicy, QTabWidget, QVBoxLayout,
|
||||
QWidget)
|
||||
|
||||
from ui.components.projectEntry_Classify import ProjectEntry_Classify
|
||||
from ui.components.projectEntry_Manage import ProjectEntry_Manage
|
||||
from ui.components.samplesListWidget import SamplesListWidget
|
||||
|
||||
class Ui_ProjectEntry(object):
|
||||
def setupUi(self, ProjectEntry):
|
||||
@ -33,36 +33,12 @@ class Ui_ProjectEntry(object):
|
||||
self.tabManage = ProjectEntry_Manage()
|
||||
self.tabManage.setObjectName(u"tabManage")
|
||||
self.tabWidget.addTab(self.tabManage, "")
|
||||
self.tabClassify = QWidget()
|
||||
self.tabClassify = ProjectEntry_Classify()
|
||||
self.tabClassify.setObjectName(u"tabClassify")
|
||||
self.horizontalLayout = QHBoxLayout(self.tabClassify)
|
||||
self.horizontalLayout.setObjectName(u"horizontalLayout")
|
||||
self.unclassifiedListWidget = SamplesListWidget(self.tabClassify)
|
||||
self.unclassifiedListWidget.setObjectName(u"unclassifiedListWidget")
|
||||
|
||||
self.horizontalLayout.addWidget(self.unclassifiedListWidget)
|
||||
|
||||
self.verticalLayout_2 = QVBoxLayout()
|
||||
self.verticalLayout_2.setObjectName(u"verticalLayout_2")
|
||||
self.tagsListWidget = QListWidget(self.tabClassify)
|
||||
self.tagsListWidget.setObjectName(u"tagsListWidget")
|
||||
sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred)
|
||||
sizePolicy.setHorizontalStretch(0)
|
||||
sizePolicy.setVerticalStretch(0)
|
||||
sizePolicy.setHeightForWidth(self.tagsListWidget.sizePolicy().hasHeightForWidth())
|
||||
self.tagsListWidget.setSizePolicy(sizePolicy)
|
||||
|
||||
self.verticalLayout_2.addWidget(self.tagsListWidget)
|
||||
|
||||
self.classfiedListWidget = SamplesListWidget(self.tabClassify)
|
||||
self.classfiedListWidget.setObjectName(u"classfiedListWidget")
|
||||
|
||||
self.verticalLayout_2.addWidget(self.classfiedListWidget)
|
||||
|
||||
|
||||
self.horizontalLayout.addLayout(self.verticalLayout_2)
|
||||
|
||||
self.tabWidget.addTab(self.tabClassify, "")
|
||||
self.tabSamples = QWidget()
|
||||
self.tabSamples.setObjectName(u"tabSamples")
|
||||
self.tabWidget.addTab(self.tabSamples, "")
|
||||
|
||||
self.verticalLayout.addWidget(self.tabWidget)
|
||||
|
||||
@ -79,5 +55,6 @@ class Ui_ProjectEntry(object):
|
||||
ProjectEntry.setWindowTitle(QCoreApplication.translate("ProjectEntry", u"projectEntry", None))
|
||||
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tabManage), QCoreApplication.translate("ProjectEntry", u"Manage", None))
|
||||
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tabClassify), QCoreApplication.translate("ProjectEntry", u"Classify", None))
|
||||
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tabSamples), QCoreApplication.translate("ProjectEntry", u"Samples", None))
|
||||
# retranslateUi
|
||||
|
||||
|
61
ui/extends/samplesListWidget.py
Normal file
61
ui/extends/samplesListWidget.py
Normal file
@ -0,0 +1,61 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from PySide6.QtCore import QByteArray, QMimeData, Qt
|
||||
from PySide6.QtGui import QDrag, QPixmap
|
||||
from PySide6.QtWidgets import QListWidget, QListWidgetItem, QMessageBox, QProgressDialog
|
||||
|
||||
|
||||
class SamplesListWidget(QListWidget):
|
||||
PathlibPathRole = Qt.ItemDataRole.UserRole
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.setViewMode(QListWidget.ViewMode.IconMode)
|
||||
self.setEditTriggers(QListWidget.EditTrigger.NoEditTriggers)
|
||||
self.setMovement(QListWidget.Movement.Static)
|
||||
self.setDragDropMode(QListWidget.DragDropMode.DragOnly)
|
||||
self.setDragEnabled(True)
|
||||
self.setSelectionMode(QListWidget.SelectionMode.MultiSelection)
|
||||
|
||||
def setSamples(self, samples: list[Path]):
|
||||
self.clear()
|
||||
|
||||
samplesNum = len(samples)
|
||||
progressDialog = QProgressDialog("", "Abort", 0, samplesNum, self)
|
||||
progressDialog.setWindowModality(Qt.WindowModality.ApplicationModal)
|
||||
|
||||
for i, sample in enumerate(samples):
|
||||
item = QListWidgetItem(QPixmap(str(sample)), f"{sample.stem[:3]}...", self)
|
||||
item.setData(self.PathlibPathRole, sample)
|
||||
self.addItem(item)
|
||||
progressDialog.setValue(i)
|
||||
progressDialog.setLabelText(f"{i + 1}/{samplesNum}")
|
||||
|
||||
if progressDialog.wasCanceled():
|
||||
break
|
||||
|
||||
progressDialog.setValue(samplesNum)
|
||||
QMessageBox.information(
|
||||
self, None, f"Loaded {self.model().rowCount()} samples."
|
||||
)
|
||||
|
||||
def startDrag(self, supportedActions: Qt.DropAction):
|
||||
drag = QDrag(self)
|
||||
items = self.selectedItems()
|
||||
paths = [str(item.data(self.PathlibPathRole).resolve()) for item in items]
|
||||
mimeDataString = json.dumps(paths, ensure_ascii=False)
|
||||
|
||||
mimeData = QMimeData()
|
||||
mimeData.setData(
|
||||
"application/ao-ocr-model_sample",
|
||||
QByteArray(mimeDataString.encode("utf-8")),
|
||||
)
|
||||
|
||||
drag.setPixmap(items[0].icon().pixmap(items[0].icon().availableSizes()[0]))
|
||||
drag.setMimeData(mimeData)
|
||||
if drag.exec(Qt.DropAction.MoveAction) == Qt.DropAction.MoveAction:
|
||||
for item in items:
|
||||
index = self.indexFromItem(item)
|
||||
self.model().removeRow(index.row())
|
@ -6,6 +6,9 @@ from ui.mainWindow import MainWindow
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
|
||||
app.setStyle("fusion")
|
||||
|
||||
window = MainWindow()
|
||||
window.show()
|
||||
sys.exit(app.exec())
|
||||
|
Loading…
x
Reference in New Issue
Block a user