diff --git a/project.py b/project.py index 197fd0b..edce01e 100644 --- a/project.py +++ b/project.py @@ -223,6 +223,22 @@ class Project: def redactSources(self): list(self.redactSourcesYield()) + def train(self): + trainModule = self.getModule("train") + trainClass = trainModule.Train + + trainItems = [ + {"tag": tag, "value": int(value), "samples": self.samplesByTag(tag)} + for tag, value in self.tagValueMap.items() + ] + + trainClassInstance = trainClass(trainItems) + + knnModel = trainClassInstance.train_knn() + knnModel.save(str((self.path / "knn.dat").resolve())) + svmModel = trainClassInstance.train_svm() + svmModel.save(str((self.path / "svm.dat").resolve())) + def classify(self, sample: Path, tag: str): if tag != "ignored" and tag not in self.tags: raise ValueError(f'Unknown tag "{tag}"') diff --git a/ui/components/projectEntry_Manage.py b/ui/components/projectEntry_Manage.py index 8d14aa8..b869e39 100644 --- a/ui/components/projectEntry_Manage.py +++ b/ui/components/projectEntry_Manage.py @@ -119,3 +119,14 @@ class ProjectEntry_Manage(Ui_ProjectEntry_Manage, QWidget): self.abort = False progressDialog.close() progressDialog.deleteLater() + + @Slot() + def on_trainButton_clicked(self): + if not self.project: + return + + with BlockLabelDialog(self) as block: + block.setText(f"{self.project.name}
Training") + block.show() + + self.project.train() diff --git a/ui/components/projectEntry_Manage.ui b/ui/components/projectEntry_Manage.ui index 2501d44..61efdf9 100644 --- a/ui/components/projectEntry_Manage.ui +++ b/ui/components/projectEntry_Manage.ui @@ -21,10 +21,23 @@ - - + + + + Qt::Vertical + + + + 20 + 40 + + + + + + - Extract + Redact sources @@ -41,26 +54,6 @@ - - - - Redact sources - - - - - - - Qt::Vertical - - - - 20 - 40 - - - - @@ -68,6 +61,20 @@ + + + + Extract + + + + + + + Train + + + diff --git a/ui/components/projectEntry_Manage_ui.py b/ui/components/projectEntry_Manage_ui.py index 5431246..02bdf4a 100644 --- a/ui/components/projectEntry_Manage_ui.py +++ b/ui/components/projectEntry_Manage_ui.py @@ -31,10 +31,14 @@ class Ui_ProjectEntry_Manage(object): self.gridLayout.addWidget(self.projectDescriptionLabel, 1, 0, 1, 2) - self.extractButton = QPushButton(ProjectEntry_Manage) - self.extractButton.setObjectName(u"extractButton") + self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding) - self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1) + self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2) + + self.redactSourcesButton = QPushButton(ProjectEntry_Manage) + self.redactSourcesButton.setObjectName(u"redactSourcesButton") + + self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1) self.projectNameLabel = QLabel(ProjectEntry_Manage) self.projectNameLabel.setObjectName(u"projectNameLabel") @@ -45,20 +49,21 @@ class Ui_ProjectEntry_Manage(object): self.gridLayout.addWidget(self.projectNameLabel, 0, 0, 1, 2) - self.redactSourcesButton = QPushButton(ProjectEntry_Manage) - self.redactSourcesButton.setObjectName(u"redactSourcesButton") - - self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1) - - self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding) - - self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2) - self.updateButton = QPushButton(ProjectEntry_Manage) self.updateButton.setObjectName(u"updateButton") self.gridLayout.addWidget(self.updateButton, 2, 0, 1, 1) + self.extractButton = QPushButton(ProjectEntry_Manage) + self.extractButton.setObjectName(u"extractButton") + + self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1) + + self.trainButton = QPushButton(ProjectEntry_Manage) + self.trainButton.setObjectName(u"trainButton") + + self.gridLayout.addWidget(self.trainButton, 4, 0, 1, 1) + self.retranslateUi(ProjectEntry_Manage) @@ -67,10 +72,11 @@ class Ui_ProjectEntry_Manage(object): def retranslateUi(self, ProjectEntry_Manage): self.projectDescriptionLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None)) - self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None)) - self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None)) self.redactSourcesButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Redact sources", None)) + self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None)) self.updateButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Update", None)) + self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None)) + self.trainButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Train", None)) pass # retranslateUi diff --git a/uiIndex.py b/uiIndex.py index a35682c..cd49c44 100644 --- a/uiIndex.py +++ b/uiIndex.py @@ -1,9 +1,18 @@ +import logging import sys +import time from PySide6.QtWidgets import QApplication from ui.mainWindow import MainWindow +# logging.basicConfig( +# filename=f"ui-{int(time.time() * 1000)}.log", +# filemode="w", +# level=logging.DEBUG, +# ) + + if __name__ == "__main__": app = QApplication(sys.argv)