diff --git a/dbModels.py b/dbModels.py
index 642afbf..1cde18a 100644
--- a/dbModels.py
+++ b/dbModels.py
@@ -1,6 +1,6 @@
from datetime import datetime
-from sqlalchemy import CHAR, TEXT, TIMESTAMP, text, event, DDL
+from sqlalchemy import CHAR, DDL, TEXT, TIMESTAMP, event, text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
@@ -30,7 +30,7 @@ class ClassifiedSample(ProjectBase):
)
tag: Mapped[str] = mapped_column(TEXT(), primary_key=True)
timestamp: Mapped[datetime] = mapped_column(
- TIMESTAMP(), server_default=text("CURRENT_TIMESTAMP")
+ TIMESTAMP(timezone=True), server_default=text("CURRENT_TIMESTAMP")
)
diff --git a/project.py b/project.py
index a4e14cf..263a549 100644
--- a/project.py
+++ b/project.py
@@ -102,27 +102,28 @@ class Project:
@property
def samplesClassified(self):
with self.__sessionmaker() as session:
- return [
+ samplesClassifiedMd5s = [
cs.sampleNumpyMd5 for cs in session.scalars(select(ClassifiedSample))
]
+ return [p for p in self.samples if p.stem in samplesClassifiedMd5s]
@property
def samplesIgnored(self):
with self.__sessionmaker() as session:
- return [
+ samplesIgnoredMd5s = [
cs.sampleNumpyMd5
for cs in session.scalars(
select(ClassifiedSample).where(ClassifiedSample.tag == "ignored")
)
]
+ return [p for p in self.samples if p.stem in samplesIgnoredMd5s]
@property
def samplesUnclassified(self):
- samplesNumpyMd5s = [s.stem for s in self.samples]
- classifiedSamples = []
- classifiedSamples += self.samplesClassified
- classifiedSamples += self.samplesIgnored
- return [s for s in samplesNumpyMd5s if s not in classifiedSamples]
+ classifiedList = []
+ classifiedList += self.samplesClassified
+ classifiedList += self.samplesIgnored
+ return list(filter(lambda p: p not in classifiedList, self.samples))
def samplesByTag(self, tag: str):
if tag != "ignored" and tag not in self.tags:
diff --git a/ui/components/projectEntry.py b/ui/components/projectEntry.py
index f60d48b..252e597 100644
--- a/ui/components/projectEntry.py
+++ b/ui/components/projectEntry.py
@@ -16,6 +16,7 @@ class ProjectEntry(Ui_ProjectEntry, QWidget):
def setProject(self, project: Project):
self.project = project
self.tabManage.setProject(project)
+ self.tabClassify.setProject(project)
def reloadProject(self):
self.project.reload()
diff --git a/ui/components/projectEntry.ui b/ui/components/projectEntry.ui
index 66028fa..d12d51f 100644
--- a/ui/components/projectEntry.ui
+++ b/ui/components/projectEntry.ui
@@ -24,49 +24,33 @@
Manage
-
+
Classify
-
- -
-
-
- -
-
-
-
-
-
-
- 0
- 0
-
-
-
-
- -
-
-
-
-
-
+
+
+
+ Samples
+
-
- SamplesListWidget
- QListWidget
- ui.components.samplesListWidget
-
ProjectEntry_Manage
QWidget
ui.components.projectEntry_Manage
1
+
+ ProjectEntry_Classify
+ QWidget
+ ui.components.projectEntry_Classify
+ 1
+
diff --git a/ui/components/projectEntry_Classify.py b/ui/components/projectEntry_Classify.py
new file mode 100644
index 0000000..7c5c1f5
--- /dev/null
+++ b/ui/components/projectEntry_Classify.py
@@ -0,0 +1,92 @@
+import json
+from pathlib import Path
+
+from PySide6.QtCore import Qt, Slot
+from PySide6.QtGui import QDragEnterEvent, QDropEvent
+from PySide6.QtWidgets import QLabel, QWidget
+
+from project import Project
+
+from .projectEntry_Classify_ui import Ui_ProjectEntry_Classify
+
+
+class TagLabel(QLabel):
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.setAcceptDrops(True)
+
+ self.project: Project | None = None
+ self.tag = None
+
+ def enableDropEffect(self):
+ # palette = self.palette()
+ # palette.setBrush(QPalette.ColorRole.Base, palette.highlight())
+ # palette.setBrush(QPalette.ColorRole.Window, palette.highlight())
+ # palette.setBrush(QPalette.ColorRole.Text, palette.highlightedText())
+ # self.setPalette(palette)
+ font = self.font()
+ font.setBold(True)
+ font.setUnderline(True)
+ self.setFont(font)
+
+ def disableDropEffect(self):
+ # palette = self.palette()
+ # palette.setBrush(QPalette.ColorRole.Base, palette.base())
+ # palette.setBrush(QPalette.ColorRole.Window, palette.window())
+ # palette.setBrush(QPalette.ColorRole.Text, palette.text())
+ # self.setPalette(palette)
+ font = self.font()
+ font.setBold(False)
+ font.setUnderline(False)
+ self.setFont(font)
+
+ def dragEnterEvent(self, event: QDragEnterEvent):
+ mimeData = event.mimeData()
+ if mimeData.hasFormat("application/ao-ocr-model_sample"):
+ self.enableDropEffect()
+ event.accept()
+
+ def dragLeaveEvent(self, event):
+ self.disableDropEffect()
+ return super().dragLeaveEvent(event)
+
+ def dropEvent(self, event: QDropEvent):
+ if self.project and self.tag and event.dropAction() == Qt.DropAction.MoveAction:
+ data = bytes(event.mimeData().data("application/ao-ocr-model_sample"))
+ paths = json.loads(data.decode("utf-8"))
+ paths = [Path(p) for p in paths]
+ for path in paths:
+ self.project.classify(path, self.tag)
+ event.acceptProposedAction()
+
+ if not event.isAccepted():
+ event.ignore()
+ self.disableDropEffect()
+
+
+class ProjectEntry_Classify(Ui_ProjectEntry_Classify, QWidget):
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.setupUi(self)
+ self.project = None
+
+ self.tagLabels: list[TagLabel] = []
+
+ def setProject(self, project: Project):
+ self.project = project
+
+ for tagLabel in self.tagLabels:
+ self.frame.layout().removeWidget(tagLabel)
+ tagLabel.deleteLater()
+
+ for tag in self.project.tags:
+ tagLabel = TagLabel(self)
+ tagLabel.tag = tag
+ tagLabel.project = project
+ tagLabel.setText(tag)
+ self.frame.layout().addWidget(tagLabel)
+ self.tagLabels.append(tagLabel)
+
+ @Slot()
+ def on_loadSamplesButton_clicked(self):
+ self.samplesListWidget.setSamples(self.project.samplesUnclassified)
diff --git a/ui/components/projectEntry_Classify.ui b/ui/components/projectEntry_Classify.ui
new file mode 100644
index 0000000..2ff61fb
--- /dev/null
+++ b/ui/components/projectEntry_Classify.ui
@@ -0,0 +1,85 @@
+
+
+ ProjectEntry_Classify
+
+
+
+ 0
+ 0
+ 677
+ 523
+
+
+
+ ProjectEntry_Classify
+
+
+ -
+
+
+ -
+
+
-
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
+ Load samples
+
+
+
+ -
+
+
+
+ 100
+ 200
+
+
+
+ QFrame::StyledPanel
+
+
+ QFrame::Raised
+
+
+
+
+ -
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+
+
+
+ SamplesListWidget
+ QListWidget
+ ui.extends.samplesListWidget
+
+
+
+
+
diff --git a/ui/components/projectEntry_Classify_ui.py b/ui/components/projectEntry_Classify_ui.py
new file mode 100644
index 0000000..9bd5304
--- /dev/null
+++ b/ui/components/projectEntry_Classify_ui.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+
+################################################################################
+## Form generated from reading UI file 'projectEntry_Classify.ui'
+##
+## Created by: Qt User Interface Compiler version 6.5.2
+##
+## WARNING! All changes made in this file will be lost when recompiling UI file!
+################################################################################
+
+from PySide6.QtCore import (QCoreApplication, QDate, QDateTime, QLocale,
+ QMetaObject, QObject, QPoint, QRect,
+ QSize, QTime, QUrl, Qt)
+from PySide6.QtGui import (QBrush, QColor, QConicalGradient, QCursor,
+ QFont, QFontDatabase, QGradient, QIcon,
+ QImage, QKeySequence, QLinearGradient, QPainter,
+ QPalette, QPixmap, QRadialGradient, QTransform)
+from PySide6.QtWidgets import (QApplication, QFrame, QHBoxLayout, QListWidgetItem,
+ QPushButton, QSizePolicy, QSpacerItem, QVBoxLayout,
+ QWidget)
+
+from ui.extends.samplesListWidget import SamplesListWidget
+
+class Ui_ProjectEntry_Classify(object):
+ def setupUi(self, ProjectEntry_Classify):
+ if not ProjectEntry_Classify.objectName():
+ ProjectEntry_Classify.setObjectName(u"ProjectEntry_Classify")
+ ProjectEntry_Classify.resize(677, 523)
+ ProjectEntry_Classify.setWindowTitle(u"ProjectEntry_Classify")
+ self.horizontalLayout = QHBoxLayout(ProjectEntry_Classify)
+ self.horizontalLayout.setObjectName(u"horizontalLayout")
+ self.samplesListWidget = SamplesListWidget(ProjectEntry_Classify)
+ self.samplesListWidget.setObjectName(u"samplesListWidget")
+
+ self.horizontalLayout.addWidget(self.samplesListWidget)
+
+ self.verticalLayout = QVBoxLayout()
+ self.verticalLayout.setObjectName(u"verticalLayout")
+ self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
+
+ self.verticalLayout.addItem(self.verticalSpacer)
+
+ self.loadSamplesButton = QPushButton(ProjectEntry_Classify)
+ self.loadSamplesButton.setObjectName(u"loadSamplesButton")
+
+ self.verticalLayout.addWidget(self.loadSamplesButton)
+
+ self.frame = QFrame(ProjectEntry_Classify)
+ self.frame.setObjectName(u"frame")
+ self.frame.setMinimumSize(QSize(100, 200))
+ self.frame.setFrameShape(QFrame.StyledPanel)
+ self.frame.setFrameShadow(QFrame.Raised)
+ self.verticalLayout_2 = QVBoxLayout(self.frame)
+ self.verticalLayout_2.setObjectName(u"verticalLayout_2")
+
+ self.verticalLayout.addWidget(self.frame)
+
+ self.verticalSpacer_2 = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
+
+ self.verticalLayout.addItem(self.verticalSpacer_2)
+
+
+ self.horizontalLayout.addLayout(self.verticalLayout)
+
+
+ self.retranslateUi(ProjectEntry_Classify)
+
+ QMetaObject.connectSlotsByName(ProjectEntry_Classify)
+ # setupUi
+
+ def retranslateUi(self, ProjectEntry_Classify):
+ self.loadSamplesButton.setText(QCoreApplication.translate("ProjectEntry_Classify", u"Load samples", None))
+ pass
+ # retranslateUi
+
diff --git a/ui/components/projectEntry_ui.py b/ui/components/projectEntry_ui.py
index edec07f..c6f2844 100644
--- a/ui/components/projectEntry_ui.py
+++ b/ui/components/projectEntry_ui.py
@@ -15,11 +15,11 @@ from PySide6.QtGui import (QBrush, QColor, QConicalGradient, QCursor,
QFont, QFontDatabase, QGradient, QIcon,
QImage, QKeySequence, QLinearGradient, QPainter,
QPalette, QPixmap, QRadialGradient, QTransform)
-from PySide6.QtWidgets import (QApplication, QHBoxLayout, QListWidget, QListWidgetItem,
- QSizePolicy, QTabWidget, QVBoxLayout, QWidget)
+from PySide6.QtWidgets import (QApplication, QSizePolicy, QTabWidget, QVBoxLayout,
+ QWidget)
+from ui.components.projectEntry_Classify import ProjectEntry_Classify
from ui.components.projectEntry_Manage import ProjectEntry_Manage
-from ui.components.samplesListWidget import SamplesListWidget
class Ui_ProjectEntry(object):
def setupUi(self, ProjectEntry):
@@ -33,36 +33,12 @@ class Ui_ProjectEntry(object):
self.tabManage = ProjectEntry_Manage()
self.tabManage.setObjectName(u"tabManage")
self.tabWidget.addTab(self.tabManage, "")
- self.tabClassify = QWidget()
+ self.tabClassify = ProjectEntry_Classify()
self.tabClassify.setObjectName(u"tabClassify")
- self.horizontalLayout = QHBoxLayout(self.tabClassify)
- self.horizontalLayout.setObjectName(u"horizontalLayout")
- self.unclassifiedListWidget = SamplesListWidget(self.tabClassify)
- self.unclassifiedListWidget.setObjectName(u"unclassifiedListWidget")
-
- self.horizontalLayout.addWidget(self.unclassifiedListWidget)
-
- self.verticalLayout_2 = QVBoxLayout()
- self.verticalLayout_2.setObjectName(u"verticalLayout_2")
- self.tagsListWidget = QListWidget(self.tabClassify)
- self.tagsListWidget.setObjectName(u"tagsListWidget")
- sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred)
- sizePolicy.setHorizontalStretch(0)
- sizePolicy.setVerticalStretch(0)
- sizePolicy.setHeightForWidth(self.tagsListWidget.sizePolicy().hasHeightForWidth())
- self.tagsListWidget.setSizePolicy(sizePolicy)
-
- self.verticalLayout_2.addWidget(self.tagsListWidget)
-
- self.classfiedListWidget = SamplesListWidget(self.tabClassify)
- self.classfiedListWidget.setObjectName(u"classfiedListWidget")
-
- self.verticalLayout_2.addWidget(self.classfiedListWidget)
-
-
- self.horizontalLayout.addLayout(self.verticalLayout_2)
-
self.tabWidget.addTab(self.tabClassify, "")
+ self.tabSamples = QWidget()
+ self.tabSamples.setObjectName(u"tabSamples")
+ self.tabWidget.addTab(self.tabSamples, "")
self.verticalLayout.addWidget(self.tabWidget)
@@ -79,5 +55,6 @@ class Ui_ProjectEntry(object):
ProjectEntry.setWindowTitle(QCoreApplication.translate("ProjectEntry", u"projectEntry", None))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tabManage), QCoreApplication.translate("ProjectEntry", u"Manage", None))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tabClassify), QCoreApplication.translate("ProjectEntry", u"Classify", None))
+ self.tabWidget.setTabText(self.tabWidget.indexOf(self.tabSamples), QCoreApplication.translate("ProjectEntry", u"Samples", None))
# retranslateUi
diff --git a/ui/extends/samplesListWidget.py b/ui/extends/samplesListWidget.py
new file mode 100644
index 0000000..81c38fe
--- /dev/null
+++ b/ui/extends/samplesListWidget.py
@@ -0,0 +1,61 @@
+import json
+from pathlib import Path
+
+from PySide6.QtCore import QByteArray, QMimeData, Qt
+from PySide6.QtGui import QDrag, QPixmap
+from PySide6.QtWidgets import QListWidget, QListWidgetItem, QMessageBox, QProgressDialog
+
+
+class SamplesListWidget(QListWidget):
+ PathlibPathRole = Qt.ItemDataRole.UserRole
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+
+ self.setViewMode(QListWidget.ViewMode.IconMode)
+ self.setEditTriggers(QListWidget.EditTrigger.NoEditTriggers)
+ self.setMovement(QListWidget.Movement.Static)
+ self.setDragDropMode(QListWidget.DragDropMode.DragOnly)
+ self.setDragEnabled(True)
+ self.setSelectionMode(QListWidget.SelectionMode.MultiSelection)
+
+ def setSamples(self, samples: list[Path]):
+ self.clear()
+
+ samplesNum = len(samples)
+ progressDialog = QProgressDialog("", "Abort", 0, samplesNum, self)
+ progressDialog.setWindowModality(Qt.WindowModality.ApplicationModal)
+
+ for i, sample in enumerate(samples):
+ item = QListWidgetItem(QPixmap(str(sample)), f"{sample.stem[:3]}...", self)
+ item.setData(self.PathlibPathRole, sample)
+ self.addItem(item)
+ progressDialog.setValue(i)
+ progressDialog.setLabelText(f"{i + 1}/{samplesNum}")
+
+ if progressDialog.wasCanceled():
+ break
+
+ progressDialog.setValue(samplesNum)
+ QMessageBox.information(
+ self, None, f"Loaded {self.model().rowCount()} samples."
+ )
+
+ def startDrag(self, supportedActions: Qt.DropAction):
+ drag = QDrag(self)
+ items = self.selectedItems()
+ paths = [str(item.data(self.PathlibPathRole).resolve()) for item in items]
+ mimeDataString = json.dumps(paths, ensure_ascii=False)
+
+ mimeData = QMimeData()
+ mimeData.setData(
+ "application/ao-ocr-model_sample",
+ QByteArray(mimeDataString.encode("utf-8")),
+ )
+
+ drag.setPixmap(items[0].icon().pixmap(items[0].icon().availableSizes()[0]))
+ drag.setMimeData(mimeData)
+ if drag.exec(Qt.DropAction.MoveAction) == Qt.DropAction.MoveAction:
+ for item in items:
+ index = self.indexFromItem(item)
+ self.model().removeRow(index.row())
diff --git a/uiIndex.py b/uiIndex.py
index dfabef9..a35682c 100644
--- a/uiIndex.py
+++ b/uiIndex.py
@@ -6,6 +6,9 @@ from ui.mainWindow import MainWindow
if __name__ == "__main__":
app = QApplication(sys.argv)
+
+ app.setStyle("fusion")
+
window = MainWindow()
window.show()
sys.exit(app.exec())