mirror of
https://github.com/283375/arcaea-offline-ocr-model.git
synced 2025-04-04 14:10:18 +00:00
feat: model training
This commit is contained in:
parent
bd69c32098
commit
750c3c6819
16
project.py
16
project.py
@ -223,6 +223,22 @@ class Project:
|
||||
def redactSources(self):
|
||||
list(self.redactSourcesYield())
|
||||
|
||||
def train(self):
|
||||
trainModule = self.getModule("train")
|
||||
trainClass = trainModule.Train
|
||||
|
||||
trainItems = [
|
||||
{"tag": tag, "value": int(value), "samples": self.samplesByTag(tag)}
|
||||
for tag, value in self.tagValueMap.items()
|
||||
]
|
||||
|
||||
trainClassInstance = trainClass(trainItems)
|
||||
|
||||
knnModel = trainClassInstance.train_knn()
|
||||
knnModel.save(str((self.path / "knn.dat").resolve()))
|
||||
svmModel = trainClassInstance.train_svm()
|
||||
svmModel.save(str((self.path / "svm.dat").resolve()))
|
||||
|
||||
def classify(self, sample: Path, tag: str):
|
||||
if tag != "ignored" and tag not in self.tags:
|
||||
raise ValueError(f'Unknown tag "{tag}"')
|
||||
|
@ -119,3 +119,14 @@ class ProjectEntry_Manage(Ui_ProjectEntry_Manage, QWidget):
|
||||
self.abort = False
|
||||
progressDialog.close()
|
||||
progressDialog.deleteLater()
|
||||
|
||||
@Slot()
|
||||
def on_trainButton_clicked(self):
|
||||
if not self.project:
|
||||
return
|
||||
|
||||
with BlockLabelDialog(self) as block:
|
||||
block.setText(f"{self.project.name}<br>Training")
|
||||
block.show()
|
||||
|
||||
self.project.train()
|
||||
|
@ -21,10 +21,23 @@
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item row="3" column="0">
|
||||
<widget class="QPushButton" name="extractButton">
|
||||
<item row="5" column="0" colspan="2">
|
||||
<spacer name="verticalSpacer">
|
||||
<property name="orientation">
|
||||
<enum>Qt::Vertical</enum>
|
||||
</property>
|
||||
<property name="sizeHint" stdset="0">
|
||||
<size>
|
||||
<width>20</width>
|
||||
<height>40</height>
|
||||
</size>
|
||||
</property>
|
||||
</spacer>
|
||||
</item>
|
||||
<item row="3" column="1">
|
||||
<widget class="QPushButton" name="redactSourcesButton">
|
||||
<property name="text">
|
||||
<string>Extract</string>
|
||||
<string>Redact sources</string>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
@ -41,26 +54,6 @@
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item row="3" column="1">
|
||||
<widget class="QPushButton" name="redactSourcesButton">
|
||||
<property name="text">
|
||||
<string>Redact sources</string>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item row="5" column="0" colspan="2">
|
||||
<spacer name="verticalSpacer">
|
||||
<property name="orientation">
|
||||
<enum>Qt::Vertical</enum>
|
||||
</property>
|
||||
<property name="sizeHint" stdset="0">
|
||||
<size>
|
||||
<width>20</width>
|
||||
<height>40</height>
|
||||
</size>
|
||||
</property>
|
||||
</spacer>
|
||||
</item>
|
||||
<item row="2" column="0">
|
||||
<widget class="QPushButton" name="updateButton">
|
||||
<property name="text">
|
||||
@ -68,6 +61,20 @@
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item row="3" column="0">
|
||||
<widget class="QPushButton" name="extractButton">
|
||||
<property name="text">
|
||||
<string>Extract</string>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item row="4" column="0">
|
||||
<widget class="QPushButton" name="trainButton">
|
||||
<property name="text">
|
||||
<string>Train</string>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
</layout>
|
||||
</widget>
|
||||
<resources/>
|
||||
|
@ -31,10 +31,14 @@ class Ui_ProjectEntry_Manage(object):
|
||||
|
||||
self.gridLayout.addWidget(self.projectDescriptionLabel, 1, 0, 1, 2)
|
||||
|
||||
self.extractButton = QPushButton(ProjectEntry_Manage)
|
||||
self.extractButton.setObjectName(u"extractButton")
|
||||
self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
|
||||
|
||||
self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1)
|
||||
self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2)
|
||||
|
||||
self.redactSourcesButton = QPushButton(ProjectEntry_Manage)
|
||||
self.redactSourcesButton.setObjectName(u"redactSourcesButton")
|
||||
|
||||
self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1)
|
||||
|
||||
self.projectNameLabel = QLabel(ProjectEntry_Manage)
|
||||
self.projectNameLabel.setObjectName(u"projectNameLabel")
|
||||
@ -45,20 +49,21 @@ class Ui_ProjectEntry_Manage(object):
|
||||
|
||||
self.gridLayout.addWidget(self.projectNameLabel, 0, 0, 1, 2)
|
||||
|
||||
self.redactSourcesButton = QPushButton(ProjectEntry_Manage)
|
||||
self.redactSourcesButton.setObjectName(u"redactSourcesButton")
|
||||
|
||||
self.gridLayout.addWidget(self.redactSourcesButton, 3, 1, 1, 1)
|
||||
|
||||
self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
|
||||
|
||||
self.gridLayout.addItem(self.verticalSpacer, 5, 0, 1, 2)
|
||||
|
||||
self.updateButton = QPushButton(ProjectEntry_Manage)
|
||||
self.updateButton.setObjectName(u"updateButton")
|
||||
|
||||
self.gridLayout.addWidget(self.updateButton, 2, 0, 1, 1)
|
||||
|
||||
self.extractButton = QPushButton(ProjectEntry_Manage)
|
||||
self.extractButton.setObjectName(u"extractButton")
|
||||
|
||||
self.gridLayout.addWidget(self.extractButton, 3, 0, 1, 1)
|
||||
|
||||
self.trainButton = QPushButton(ProjectEntry_Manage)
|
||||
self.trainButton.setObjectName(u"trainButton")
|
||||
|
||||
self.gridLayout.addWidget(self.trainButton, 4, 0, 1, 1)
|
||||
|
||||
|
||||
self.retranslateUi(ProjectEntry_Manage)
|
||||
|
||||
@ -67,10 +72,11 @@ class Ui_ProjectEntry_Manage(object):
|
||||
|
||||
def retranslateUi(self, ProjectEntry_Manage):
|
||||
self.projectDescriptionLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
|
||||
self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None))
|
||||
self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
|
||||
self.redactSourcesButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Redact sources", None))
|
||||
self.projectNameLabel.setText(QCoreApplication.translate("ProjectEntry_Manage", u"-", None))
|
||||
self.updateButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Update", None))
|
||||
self.extractButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Extract", None))
|
||||
self.trainButton.setText(QCoreApplication.translate("ProjectEntry_Manage", u"Train", None))
|
||||
pass
|
||||
# retranslateUi
|
||||
|
||||
|
@ -1,9 +1,18 @@
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
from PySide6.QtWidgets import QApplication
|
||||
|
||||
from ui.mainWindow import MainWindow
|
||||
|
||||
# logging.basicConfig(
|
||||
# filename=f"ui-{int(time.time() * 1000)}.log",
|
||||
# filemode="w",
|
||||
# level=logging.DEBUG,
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user