diff --git a/project.py b/project.py
index 197fd0b..edce01e 100644
--- a/project.py
+++ b/project.py
@@ -223,6 +223,22 @@ class Project:
def redactSources(self):
list(self.redactSourcesYield())
+ def train(self):
+ trainModule = self.getModule("train")
+ trainClass = trainModule.Train
+
+ trainItems = [
+ {"tag": tag, "value": int(value), "samples": self.samplesByTag(tag)}
+ for tag, value in self.tagValueMap.items()
+ ]
+
+ trainClassInstance = trainClass(trainItems)
+
+ knnModel = trainClassInstance.train_knn()
+ knnModel.save(str((self.path / "knn.dat").resolve()))
+ svmModel = trainClassInstance.train_svm()
+ svmModel.save(str((self.path / "svm.dat").resolve()))
+
def classify(self, sample: Path, tag: str):
if tag != "ignored" and tag not in self.tags:
raise ValueError(f'Unknown tag "{tag}"')
diff --git a/ui/components/projectEntry_Manage.py b/ui/components/projectEntry_Manage.py
index 8d14aa8..b869e39 100644
--- a/ui/components/projectEntry_Manage.py
+++ b/ui/components/projectEntry_Manage.py
@@ -119,3 +119,14 @@ class ProjectEntry_Manage(Ui_ProjectEntry_Manage, QWidget):
self.abort = False
progressDialog.close()
progressDialog.deleteLater()
+
+ @Slot()
+ def on_trainButton_clicked(self):
+ if not self.project:
+ return
+
+ with BlockLabelDialog(self) as block:
+ block.setText(f"{self.project.name}
Training")
+ block.show()
+
+ self.project.train()
diff --git a/ui/components/projectEntry_Manage.ui b/ui/components/projectEntry_Manage.ui
index 2501d44..61efdf9 100644
--- a/ui/components/projectEntry_Manage.ui
+++ b/ui/components/projectEntry_Manage.ui
@@ -21,10 +21,23 @@
- -
-
+
-
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
- Extract
+ Redact sources
@@ -41,26 +54,6 @@
- -
-
-
- Redact sources
-
-
-
- -
-
-
- Qt::Vertical
-
-
-
- 20
- 40
-
-
-
-
-
@@ -68,6 +61,20 @@
+ -
+
+
+ Extract
+
+
+
+ -
+
+
+ Train
+
+
+
diff --git a/ui/components/projectEntry_Manage_ui.py b/ui/components/projectEntry_Manage_ui.py
index 5431246..02bdf4a 100644
--- a/ui/components/projectEntry_Manage_ui.py
+++ b/ui/components/projectEntry_Manage_ui.py
@@ -31,10 +31,14 @@ class Ui_ProjectEntry_Manage(object):
self.gridLayout.addWidget(self.projectDescriptionLabel, 1, 0, 1, 2)
- self.extractButton = QPushButton(ProjectEntry_Manage)
- self.extractButton.setObjectName(u"extractButton")
+ self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
- self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1)
+ self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2)
+
+ self.redactSourcesButton = QPushButton(ProjectEntry_Manage)
+ self.redactSourcesButton.setObjectName(u"redactSourcesButton")
+
+ self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1)
self.projectNameLabel = QLabel(ProjectEntry_Manage)
self.projectNameLabel.setObjectName(u"projectNameLabel")
@@ -45,20 +49,21 @@ class Ui_ProjectEntry_Manage(object):
self.gridLayout.addWidget(self.projectNameLabel, 0, 0, 1, 2)
- self.redactSourcesButton = QPushButton(ProjectEntry_Manage)
- self.redactSourcesButton.setObjectName(u"redactSourcesButton")
-
- self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1)
-
- self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
-
- self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2)
-
self.updateButton = QPushButton(ProjectEntry_Manage)
self.updateButton.setObjectName(u"updateButton")
self.gridLayout.addWidget(self.updateButton, 2, 0, 1, 1)
+ self.extractButton = QPushButton(ProjectEntry_Manage)
+ self.extractButton.setObjectName(u"extractButton")
+
+ self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1)
+
+ self.trainButton = QPushButton(ProjectEntry_Manage)
+ self.trainButton.setObjectName(u"trainButton")
+
+ self.gridLayout.addWidget(self.trainButton, 4, 0, 1, 1)
+
self.retranslateUi(ProjectEntry_Manage)
@@ -67,10 +72,11 @@ class Ui_ProjectEntry_Manage(object):
def retranslateUi(self, ProjectEntry_Manage):
self.projectDescriptionLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
- self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None))
- self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
self.redactSourcesButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Redact sources", None))
+ self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
self.updateButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Update", None))
+ self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None))
+ self.trainButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Train", None))
pass
# retranslateUi
diff --git a/uiIndex.py b/uiIndex.py
index a35682c..cd49c44 100644
--- a/uiIndex.py
+++ b/uiIndex.py
@@ -1,9 +1,18 @@
+import logging
import sys
+import time
from PySide6.QtWidgets import QApplication
from ui.mainWindow import MainWindow
+# logging.basicConfig(
+# filename=f"ui-{int(time.time() * 1000)}.log",
+# filemode="w",
+# level=logging.DEBUG,
+# )
+
+
if __name__ == "__main__":
app = QApplication(sys.argv)