feat: model training

This commit is contained in:
283375 2023-09-27 22:17:38 +08:00
parent bd69c32098
commit 750c3c6819
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
5 changed files with 86 additions and 37 deletions

View File

@ -223,6 +223,22 @@ class Project:
def redactSources(self):
list(self.redactSourcesYield())
def train(self):
trainModule = self.getModule("train")
trainClass = trainModule.Train
trainItems = [
{"tag": tag, "value": int(value), "samples": self.samplesByTag(tag)}
for tag, value in self.tagValueMap.items()
]
trainClassInstance = trainClass(trainItems)
knnModel = trainClassInstance.train_knn()
knnModel.save(str((self.path / "knn.dat").resolve()))
svmModel = trainClassInstance.train_svm()
svmModel.save(str((self.path / "svm.dat").resolve()))
def classify(self, sample: Path, tag: str):
if tag != "ignored" and tag not in self.tags:
raise ValueError(f'Unknown tag "{tag}"')

View File

@ -119,3 +119,14 @@ class ProjectEntry_Manage(Ui_ProjectEntry_Manage, QWidget):
self.abort = False
progressDialog.close()
progressDialog.deleteLater()
@Slot()
def on_trainButton_clicked(self):
if not self.project:
return
with BlockLabelDialog(self) as block:
block.setText(f"{self.project.name}<br>Training")
block.show()
self.project.train()

View File

@ -21,10 +21,23 @@
</property>
</widget>
</item>
<item row="3" column="0">
<widget class="QPushButton" name="extractButton">
<item row="5" column="0" colspan="2">
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>20</width>
<height>40</height>
</size>
</property>
</spacer>
</item>
<item row="3" column="1">
<widget class="QPushButton" name="redactSourcesButton">
<property name="text">
<string>Extract</string>
<string>Redact sources</string>
</property>
</widget>
</item>
@ -41,26 +54,6 @@
</property>
</widget>
</item>
<item row="3" column="1">
<widget class="QPushButton" name="redactSourcesButton">
<property name="text">
<string>Redact sources</string>
</property>
</widget>
</item>
<item row="5" column="0" colspan="2">
<spacer name="verticalSpacer">
<property name="orientation">
<enum>Qt::Vertical</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>20</width>
<height>40</height>
</size>
</property>
</spacer>
</item>
<item row="2" column="0">
<widget class="QPushButton" name="updateButton">
<property name="text">
@ -68,6 +61,20 @@
</property>
</widget>
</item>
<item row="3" column="0">
<widget class="QPushButton" name="extractButton">
<property name="text">
<string>Extract</string>
</property>
</widget>
</item>
<item row="4" column="0">
<widget class="QPushButton" name="trainButton">
<property name="text">
<string>Train</string>
</property>
</widget>
</item>
</layout>
</widget>
<resources/>

View File

@ -31,10 +31,14 @@ class Ui_ProjectEntry_Manage(object):
self.gridLayout.addWidget(self.projectDescriptionLabel, 1, 0, 1, 2)
self.extractButton = QPushButton(ProjectEntry_Manage)
self.extractButton.setObjectName(u"extractButton")
self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1)
self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2)
self.redactSourcesButton = QPushButton(ProjectEntry_Manage)
self.redactSourcesButton.setObjectName(u"redactSourcesButton")
self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1)
self.projectNameLabel = QLabel(ProjectEntry_Manage)
self.projectNameLabel.setObjectName(u"projectNameLabel")
@ -45,20 +49,21 @@ class Ui_ProjectEntry_Manage(object):
self.gridLayout.addWidget(self.projectNameLabel, 0, 0, 1, 2)
self.redactSourcesButton = QPushButton(ProjectEntry_Manage)
self.redactSourcesButton.setObjectName(u"redactSourcesButton")
self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1)
self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2)
self.updateButton = QPushButton(ProjectEntry_Manage)
self.updateButton.setObjectName(u"updateButton")
self.gridLayout.addWidget(self.updateButton, 2, 0, 1, 1)
self.extractButton = QPushButton(ProjectEntry_Manage)
self.extractButton.setObjectName(u"extractButton")
self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1)
self.trainButton = QPushButton(ProjectEntry_Manage)
self.trainButton.setObjectName(u"trainButton")
self.gridLayout.addWidget(self.trainButton, 4, 0, 1, 1)
self.retranslateUi(ProjectEntry_Manage)
@ -67,10 +72,11 @@ class Ui_ProjectEntry_Manage(object):
def retranslateUi(self, ProjectEntry_Manage):
self.projectDescriptionLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None))
self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
self.redactSourcesButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Redact sources", None))
self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
self.updateButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Update", None))
self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None))
self.trainButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Train", None))
pass
# retranslateUi

View File

@ -1,9 +1,18 @@
import logging
import sys
import time
from PySide6.QtWidgets import QApplication
from ui.mainWindow import MainWindow
# logging.basicConfig(
# filename=f"ui-{int(time.time() * 1000)}.log",
# filemode="w",
# level=logging.DEBUG,
# )
if __name__ == "__main__":
app = QApplication(sys.argv)