Compare commits

...

8 Commits

Author SHA1 Message Date
750c3c6819
feat: model training 2023-09-27 22:17:38 +08:00
bd69c32098
feat: ProjectEntry_Samples 2023-09-24 02:07:42 +08:00
0ae960a405
feat: BlockLabelDialog 2023-09-24 02:07:21 +08:00
01b8b2e26c
impr: better dialog when loading samples 2023-09-24 01:42:59 +08:00
5a5b0887c3
fix: sample issues 2023-09-24 01:27:10 +08:00
6a527bc1ff
impr: better describing ui 2023-09-24 00:30:45 +08:00
7c043e89ab
fix: exclude ignored from classified samples 2023-09-24 00:29:50 +08:00
3193c88779
fix: ignored tag 2023-09-23 22:55:43 +08:00
15 changed files with 418 additions and 65 deletions

View File

@ -103,7 +103,10 @@ class Project:
def samplesClassified(self):
with self.__sessionmaker() as session:
samplesClassifiedMd5s = [
cs.sampleNumpyMd5 for cs in session.scalars(select(ClassifiedSample))
cs.sampleNumpyMd5
for cs in session.scalars(
select(ClassifiedSample).where(ClassifiedSample.tag != "ignored")
)
]
return [p for p in self.samples if p.stem in samplesClassifiedMd5s]
@ -130,12 +133,13 @@ class Project:
raise ValueError(f'Unknown tag "{tag}"')
with self.__sessionmaker() as session:
return [
sampleMd5s = [
cs.sampleNumpyMd5
for cs in session.scalars(
select(ClassifiedSample).where(ClassifiedSample.tag == tag)
)
]
return [p for p in self.samples if p.stem in sampleMd5s]
def getModule(self, moduleName: str):
cwdPath = Path(os.getcwd())
@ -219,8 +223,24 @@ class Project:
def redactSources(self):
list(self.redactSourcesYield())
def train(self):
trainModule = self.getModule("train")
trainClass = trainModule.Train
trainItems = [
{"tag": tag, "value": int(value), "samples": self.samplesByTag(tag)}
for tag, value in self.tagValueMap.items()
]
trainClassInstance = trainClass(trainItems)
knnModel = trainClassInstance.train_knn()
knnModel.save(str((self.path / "knn.dat").resolve()))
svmModel = trainClassInstance.train_svm()
svmModel.save(str((self.path / "svm.dat").resolve()))
def classify(self, sample: Path, tag: str):
if tag not in self.tags:
if tag != "ignored" and tag not in self.tags:
raise ValueError(f'Unknown tag "{tag}"')
with self.__sessionmaker() as session:
@ -232,8 +252,10 @@ class Project:
def unclassify(self, sample: Path):
with self.__sessionmaker() as session:
cs = ClassifiedSample()
cs.sampleNumpyMd5 = sample.stem
stmt = select(ClassifiedSample).where(
ClassifiedSample.sampleNumpyMd5 == sample.stem
)
cs = session.scalar(stmt)
session.delete(cs)
session.commit()

View File

@ -0,0 +1,38 @@
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QApplication, QLabel
class BlockLabelDialog(QLabel):
def __init__(
self,
parent=None,
modality: Qt.WindowModality = Qt.WindowModality.ApplicationModal,
*,
autoShow: bool = False
):
super().__init__(parent)
self.setWindowFlag(Qt.WindowType.Dialog, True)
self.setWindowFlag(Qt.WindowType.WindowMinimizeButtonHint, False)
self.setWindowFlag(Qt.WindowType.WindowMaximizeButtonHint, False)
self.setWindowFlag(Qt.WindowType.WindowCloseButtonHint, False)
self.setWindowModality(modality)
self.setWindowTitle("Please Wait")
self.setMinimumWidth(200)
self.setMargin(20)
self.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.autoShow = autoShow
def show(self):
super().show()
QApplication.processEvents()
def __enter__(self):
if self.autoShow:
self.show()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
self.deleteLater()

View File

@ -17,6 +17,7 @@ class ProjectEntry(Ui_ProjectEntry, QWidget):
self.project = project
self.tabManage.setProject(project)
self.tabClassify.setProject(project)
self.tabSamples.setProject(project)
def reloadProject(self):
self.project.reload()

View File

@ -29,7 +29,7 @@
<string>Classify</string>
</attribute>
</widget>
<widget class="QWidget" name="tabSamples">
<widget class="ProjectEntry_Samples" name="tabSamples">
<attribute name="title">
<string>Samples</string>
</attribute>
@ -51,6 +51,12 @@
<header>ui.components.projectEntry_Classify</header>
<container>1</container>
</customwidget>
<customwidget>
<class>ProjectEntry_Samples</class>
<extends>QWidget</extends>
<header>ui.components.projectEntry_Samples</header>
<container>1</container>
</customwidget>
</customwidgets>
<resources/>
<connections/>

View File

@ -7,6 +7,7 @@ from PySide6.QtWidgets import QLabel, QWidget
from project import Project
from .blockLabelDialog import BlockLabelDialog
from .projectEntry_Classify_ui import Ui_ProjectEntry_Classify
@ -89,4 +90,8 @@ class ProjectEntry_Classify(Ui_ProjectEntry_Classify, QWidget):
@Slot()
def on_loadSamplesButton_clicked(self):
with BlockLabelDialog(self) as block:
block.setText(f"{self.project.name}<br>Loading unclassified samples")
block.show()
self.samplesListWidget.setSamples(self.project.samplesUnclassified)

View File

@ -3,6 +3,7 @@ from PySide6.QtWidgets import QApplication, QWidget
from project import Project
from .blockLabelDialog import BlockLabelDialog
from .projectEntry_Manage_ui import Ui_ProjectEntry_Manage
from .yieldProgress import YieldProgress
@ -26,13 +27,21 @@ class ProjectEntry_Manage(Ui_ProjectEntry_Manage, QWidget):
self.projectDescriptionLabel.setText("-")
return
with BlockLabelDialog(self) as block:
block.setText(f"{self.project.name}<br>Updating status")
block.show()
QApplication.processEvents()
self.projectNameLabel.setText(self.project.name)
self.projectDescriptionLabel.setText(
"<br>".join(
[
str(self.project.path.resolve()),
f"{len(self.project.sources)} sources",
f"{len(self.project.samples)} samples ({len(self.project.samplesUnclassified)} unclassified)",
f"{len(self.project.samples)} samples",
f"- {len(self.project.samplesClassified)} classified",
f"- {len(self.project.samplesIgnored)} ignored",
f"- {len(self.project.samplesUnclassified)} unclassified",
]
)
)
@ -110,3 +119,14 @@ class ProjectEntry_Manage(Ui_ProjectEntry_Manage, QWidget):
self.abort = False
progressDialog.close()
progressDialog.deleteLater()
@Slot()
def on_trainButton_clicked(self):
if not self.project:
return
with BlockLabelDialog(self) as block:
block.setText(f"{self.project.name}<br>Training")
block.show()
self.project.train()

View File

@ -21,10 +21,23 @@
</property>
</widget>
</item>
<item row="3" column="0">
<widget class="QPushButton" name="extractButton">
<item row="5" column="0" colspan="2">
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>20</width>
<height>40</height>
</size>
</property>
</spacer>
</item>
<item row="3" column="1">
<widget class="QPushButton" name="redactSourcesButton">
<property name="text">
<string>Extract</string>
<string>Redact sources</string>
</property>
</widget>
</item>
@ -41,26 +54,6 @@
</property>
</widget>
</item>
<item row="3" column="1">
<widget class="QPushButton" name="redactSourcesButton">
<property name="text">
<string>Redact sources</string>
</property>
</widget>
</item>
<item row="5" column="0" colspan="2">
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>20</width>
<height>40</height>
</size>
</property>
</spacer>
</item>
<item row="2" column="0">
<widget class="QPushButton" name="updateButton">
<property name="text">
@ -68,6 +61,20 @@
</property>
</widget>
</item>
<item row="3" column="0">
<widget class="QPushButton" name="extractButton">
<property name="text">
<string>Extract</string>
</property>
</widget>
</item>
<item row="4" column="0">
<widget class="QPushButton" name="trainButton">
<property name="text">
<string>Train</string>
</property>
</widget>
</item>
</layout>
</widget>
<resources/>

View File

@ -31,10 +31,14 @@ class Ui_ProjectEntry_Manage(object):
self.gridLayout.addWidget(self.projectDescriptionLabel, 1, 0, 1, 2)
self.extractButton = QPushButton(ProjectEntry_Manage)
self.extractButton.setObjectName(u"extractButton")
self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1)
self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2)
self.redactSourcesButton = QPushButton(ProjectEntry_Manage)
self.redactSourcesButton.setObjectName(u"redactSourcesButton")
self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1)
self.projectNameLabel = QLabel(ProjectEntry_Manage)
self.projectNameLabel.setObjectName(u"projectNameLabel")
@ -45,20 +49,21 @@ class Ui_ProjectEntry_Manage(object):
self.gridLayout.addWidget(self.projectNameLabel, 0, 0, 1, 2)
self.redactSourcesButton = QPushButton(ProjectEntry_Manage)
self.redactSourcesButton.setObjectName(u"redactSourcesButton")
self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1)
self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2)
self.updateButton = QPushButton(ProjectEntry_Manage)
self.updateButton.setObjectName(u"updateButton")
self.gridLayout.addWidget(self.updateButton, 2, 0, 1, 1)
self.extractButton = QPushButton(ProjectEntry_Manage)
self.extractButton.setObjectName(u"extractButton")
self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1)
self.trainButton = QPushButton(ProjectEntry_Manage)
self.trainButton.setObjectName(u"trainButton")
self.gridLayout.addWidget(self.trainButton, 4, 0, 1, 1)
self.retranslateUi(ProjectEntry_Manage)
@ -67,10 +72,11 @@ class Ui_ProjectEntry_Manage(object):
def retranslateUi(self, ProjectEntry_Manage):
self.projectDescriptionLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None))
self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
self.redactSourcesButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Redact sources", None))
self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
self.updateButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Update", None))
self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None))
self.trainButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Train", None))
pass
# retranslateUi

View File

@ -0,0 +1,64 @@
import logging
from PySide6.QtCore import Qt, Slot
from PySide6.QtWidgets import QListWidgetItem, QWidget
from project import Project
from ..extends.samplesListWidget import SamplesListWidget
from .blockLabelDialog import BlockLabelDialog
from .projectEntry_Samples_ui import Ui_ProjectEntry_Samples
logger = logging.getLogger(__name__)
class ProjectEntry_Samples(Ui_ProjectEntry_Samples, QWidget):
TagRole = Qt.ItemDataRole.UserRole + 1
def __init__(self, parent=None):
super().__init__(parent)
self.setupUi(self)
self.samplesListWidget.setDragEnabled(False)
self.project = None
def setProject(self, project: Project):
self.project = project
self.updateSampleTags()
def updateSampleTags(self):
self.tagsListWidget.clear()
with BlockLabelDialog(self) as block:
block.setText(f"{self.project.name}<br>Loading tags")
block.show()
for tag in self.project.tags + ["ignored"]:
samples = self.project.samplesByTag(tag)
item = QListWidgetItem(f"{tag} ({len(samples)} samples)")
item.setData(self.TagRole, tag)
self.tagsListWidget.addItem(item)
@Slot()
def on_loadSamplesButton_clicked(self):
tag = self.tagsListWidget.currentItem().data(self.TagRole)
samples = self.project.samplesByTag(tag)
self.samplesListWidget.setSamples(samples, cancellable=False)
@Slot()
def on_reloadButton_clicked(self):
self.updateSampleTags()
@Slot()
def on_unclassifyButton_clicked(self):
selectedSampleItems = self.samplesListWidget.selectedItems()
paths = [
item.data(SamplesListWidget.PathlibPathRole) for item in selectedSampleItems
]
for item, path in zip(selectedSampleItems, paths):
try:
self.project.unclassify(path)
index = self.samplesListWidget.indexFromItem(item)
self.samplesListWidget.model().removeRow(index.row())
except Exception:
logger.exception(f"cannot unclassify {path}")

View File

@ -0,0 +1,78 @@
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>ProjectEntry_Samples</class>
<widget class="QWidget" name="ProjectEntry_Samples">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>648</width>
<height>504</height>
</rect>
</property>
<property name="windowTitle">
<string notr="true">ProjectEntry_Samples</string>
</property>
<layout class="QHBoxLayout" name="horizontalLayout">
<item>
<layout class="QVBoxLayout" name="verticalLayout_2">
<item>
<widget class="QListWidget" name="tagsListWidget">
<property name="sizePolicy">
<sizepolicy hsizetype="Minimum" vsizetype="Expanding">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="loadSamplesButton">
<property name="text">
<string>Load &gt;</string>
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="reloadButton">
<property name="text">
<string>Reload</string>
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="SamplesListWidget" name="samplesListWidget"/>
</item>
<item>
<widget class="QFrame" name="frame">
<property name="frameShape">
<enum>QFrame::StyledPanel</enum>
</property>
<property name="frameShadow">
<enum>QFrame::Raised</enum>
</property>
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<widget class="QPushButton" name="unclassifyButton">
<property name="text">
<string>Unclassify</string>
</property>
</widget>
</item>
</layout>
</widget>
</item>
</layout>
</widget>
<customwidgets>
<customwidget>
<class>SamplesListWidget</class>
<extends>QListWidget</extends>
<header>ui.extends.samplesListWidget</header>
</customwidget>
</customwidgets>
<resources/>
<connections/>
</ui>

View File

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
################################################################################
## Form generated from reading UI file 'projectEntry_Samples.ui'
##
## Created by: Qt User Interface Compiler version 6.5.2
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
from PySide6.QtCore import (QCoreApplication, QDate, QDateTime, QLocale,
QMetaObject, QObject, QPoint, QRect,
QSize, QTime, QUrl, Qt)
from PySide6.QtGui import (QBrush, QColor, QConicalGradient, QCursor,
QFont, QFontDatabase, QGradient, QIcon,
QImage, QKeySequence, QLinearGradient, QPainter,
QPalette, QPixmap, QRadialGradient, QTransform)
from PySide6.QtWidgets import (QApplication, QFrame, QHBoxLayout, QListWidget,
QListWidgetItem, QPushButton, QSizePolicy, QVBoxLayout,
QWidget)
from ui.extends.samplesListWidget import SamplesListWidget
class Ui_ProjectEntry_Samples(object):
def setupUi(self, ProjectEntry_Samples):
if not ProjectEntry_Samples.objectName():
ProjectEntry_Samples.setObjectName(u"ProjectEntry_Samples")
ProjectEntry_Samples.resize(648, 504)
ProjectEntry_Samples.setWindowTitle(u"ProjectEntry_Samples")
self.horizontalLayout = QHBoxLayout(ProjectEntry_Samples)
self.horizontalLayout.setObjectName(u"horizontalLayout")
self.verticalLayout_2 = QVBoxLayout()
self.verticalLayout_2.setObjectName(u"verticalLayout_2")
self.tagsListWidget = QListWidget(ProjectEntry_Samples)
self.tagsListWidget.setObjectName(u"tagsListWidget")
sizePolicy = QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Expanding)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
sizePolicy.setHeightForWidth(self.tagsListWidget.sizePolicy().hasHeightForWidth())
self.tagsListWidget.setSizePolicy(sizePolicy)
self.verticalLayout_2.addWidget(self.tagsListWidget)
self.loadSamplesButton = QPushButton(ProjectEntry_Samples)
self.loadSamplesButton.setObjectName(u"loadSamplesButton")
self.verticalLayout_2.addWidget(self.loadSamplesButton)
self.reloadButton = QPushButton(ProjectEntry_Samples)
self.reloadButton.setObjectName(u"reloadButton")
self.verticalLayout_2.addWidget(self.reloadButton)
self.horizontalLayout.addLayout(self.verticalLayout_2)
self.samplesListWidget = SamplesListWidget(ProjectEntry_Samples)
self.samplesListWidget.setObjectName(u"samplesListWidget")
self.horizontalLayout.addWidget(self.samplesListWidget)
self.frame = QFrame(ProjectEntry_Samples)
self.frame.setObjectName(u"frame")
self.frame.setFrameShape(QFrame.StyledPanel)
self.frame.setFrameShadow(QFrame.Raised)
self.verticalLayout = QVBoxLayout(self.frame)
self.verticalLayout.setObjectName(u"verticalLayout")
self.unclassifyButton = QPushButton(self.frame)
self.unclassifyButton.setObjectName(u"unclassifyButton")
self.verticalLayout.addWidget(self.unclassifyButton)
self.horizontalLayout.addWidget(self.frame)
self.retranslateUi(ProjectEntry_Samples)
QMetaObject.connectSlotsByName(ProjectEntry_Samples)
# setupUi
def retranslateUi(self, ProjectEntry_Samples):
self.loadSamplesButton.setText(QCoreApplication.translate("ProjectEntry_Samples", u"Load >", None))
self.reloadButton.setText(QCoreApplication.translate("ProjectEntry_Samples", u"Reload", None))
self.unclassifyButton.setText(QCoreApplication.translate("ProjectEntry_Samples", u"Unclassify", None))
pass
# retranslateUi

View File

@ -20,6 +20,7 @@ from PySide6.QtWidgets import (QApplication, QSizePolicy, QTabWidget, QVBoxLayou
from ui.components.projectEntry_Classify import ProjectEntry_Classify
from ui.components.projectEntry_Manage import ProjectEntry_Manage
from ui.components.projectEntry_Samples import ProjectEntry_Samples
class Ui_ProjectEntry(object):
def setupUi(self, ProjectEntry):
@ -36,7 +37,7 @@ class Ui_ProjectEntry(object):
self.tabClassify = ProjectEntry_Classify()
self.tabClassify.setObjectName(u"tabClassify")
self.tabWidget.addTab(self.tabClassify, "")
self.tabSamples = QWidget()
self.tabSamples = ProjectEntry_Samples()
self.tabSamples.setObjectName(u"tabSamples")
self.tabWidget.addTab(self.tabSamples, "")

View File

@ -1,6 +0,0 @@
from PySide6.QtWidgets import QListWidget
class SamplesListWidget(QListWidget):
def __init__(self, parent=None):
super().__init__(parent)

View File

@ -3,7 +3,13 @@ from pathlib import Path
from PySide6.QtCore import QByteArray, QMimeData, Qt
from PySide6.QtGui import QDrag, QPixmap
from PySide6.QtWidgets import QListWidget, QListWidgetItem, QMessageBox, QProgressDialog
from PySide6.QtWidgets import (
QApplication,
QListWidget,
QListWidgetItem,
QMessageBox,
QProgressDialog,
)
class SamplesListWidget(QListWidget):
@ -19,12 +25,19 @@ class SamplesListWidget(QListWidget):
self.setDragEnabled(True)
self.setSelectionMode(QListWidget.SelectionMode.MultiSelection)
def setSamples(self, samples: list[Path]):
def setSamples(self, samples: list[Path], *, cancellable: bool = True):
self.clear()
samplesNum = len(samples)
progressDialog = QProgressDialog("", "Abort", 0, samplesNum, self)
progressDialog.setWindowFlag(Qt.WindowType.WindowMinimizeButtonHint, False)
progressDialog.setWindowFlag(Qt.WindowType.WindowMaximizeButtonHint, False)
progressDialog.setWindowFlag(Qt.WindowType.WindowCloseButtonHint, False)
progressDialog.setWindowModality(Qt.WindowModality.ApplicationModal)
if not cancellable:
progressDialog.setCancelButton(None)
progressDialog.show()
QApplication.processEvents()
for i, sample in enumerate(samples):
item = QListWidgetItem(QPixmap(str(sample)), f"{sample.stem[:3]}...", self)
@ -45,6 +58,7 @@ class SamplesListWidget(QListWidget):
break
progressDialog.setValue(samplesNum)
if i + 1 != samplesNum:
QMessageBox.information(
self, None, f"Loaded {self.model().rowCount()} samples."
)

View File

@ -1,9 +1,18 @@
import logging
import sys
import time
from PySide6.QtWidgets import QApplication
from ui.mainWindow import MainWindow
# logging.basicConfig(
# filename=f"ui-{int(time.time() * 1000)}.log",
# filemode="w",
# level=logging.DEBUG,
# )
if __name__ == "__main__":
app = QApplication(sys.argv)