impr: TabOcr_BuildPHashDatabase

This commit is contained in:
2023-10-10 01:26:20 +08:00
parent 4a1e20a45f
commit 94e4d73a95
6 changed files with 226 additions and 113 deletions

View File

@ -4,13 +4,15 @@ import sqlite3
import time
from pathlib import Path
import cv2
from PySide6.QtCore import QThread, Signal, Slot
from PySide6.QtWidgets import QFileDialog, QMessageBox, QWidget
from ui.designer.tabs.tabOcr.tabOcr_BuildPHashDatabase_ui import (
Ui_tabOcr_BuildPHashDatabase,
Ui_TabOcr_BuildPHashDatabase,
)
from ui.extends.ocr import build_image_phash_database
from ui.extends.ocr.build_phash import build_image_phash_database, preprocess_char_icon
from ui.extends.shared.language import LanguageChangeEventFilter
logger = logging.getLogger(__name__)
@ -57,11 +59,14 @@ class BuildDatabaseThread(QThread):
self.finished.emit()
class TabOcr_BuildPHashDatabase(Ui_tabOcr_BuildPHashDatabase, QWidget):
class TabOcr_BuildPHashDatabase(Ui_TabOcr_BuildPHashDatabase, QWidget):
def __init__(self, parent=None):
super().__init__(parent)
self.setupUi(self)
self.languageChangeEventFilter = LanguageChangeEventFilter(self)
self.installEventFilter(self.languageChangeEventFilter)
self.songDirSelector.setMode(self.songDirSelector.getExistingDirectory)
self.charIconDirSelector.setMode(self.charIconDirSelector.getExistingDirectory)
@ -93,11 +98,30 @@ class TabOcr_BuildPHashDatabase(Ui_tabOcr_BuildPHashDatabase, QWidget):
charIconFilePaths = [
p for p in Path(charIconDir).glob("**/*") if p.suffix in acceptExts
]
self.readImageProgressBar.setMaximum(
len(songFilePaths) + len(charIconFilePaths)
)
i = 0
songMats = []
charIconMats = []
for image_path in songFilePaths:
songMats.append(cv2.imread(str(image_path.resolve()), cv2.IMREAD_GRAYSCALE))
i += 1
self.readImageProgressBar.setValue(i)
for image_path in charIconFilePaths:
mat = cv2.imread(str(image_path.resolve()), cv2.IMREAD_GRAYSCALE)
if self.preprocessCharIconCheckBox.isChecked():
mat = preprocess_char_icon(mat)
charIconMats.append(mat)
i += 1
self.readImageProgressBar.setValue(i)
songLabels = [re.sub(r"_.*$", "", p.stem) for p in songFilePaths]
charLabels = [f"character||{p.stem}" for p in charIconFilePaths]
charLabels = [f"partner||{p.stem}" for p in charIconFilePaths]
self.databaseBuildThread = BuildDatabaseThread(
songFilePaths + charIconFilePaths, songLabels + charLabels
songMats + charIconMats, songLabels + charLabels
)
self.databaseBuildThread.progress.connect(self.databaseBuildProgress)
self.databaseBuildThread.success.connect(self.databaseBuildSuccess)
@ -108,8 +132,8 @@ class TabOcr_BuildPHashDatabase(Ui_tabOcr_BuildPHashDatabase, QWidget):
@Slot(int, int)
def databaseBuildProgress(self, i: int, total: int):
if i < 5:
self.progressBar.setMaximum(total)
self.progressBar.setValue(i)
self.calculateHashProgressBar.setMaximum(total)
self.calculateHashProgressBar.setValue(i)
@Slot(str)
def databaseBuildError(self, msg: str):
@ -133,6 +157,8 @@ class TabOcr_BuildPHashDatabase(Ui_tabOcr_BuildPHashDatabase, QWidget):
def databaseBuildCleanUp(self):
self.databaseBuildThread.deleteLater()
self.databaseBuildThread = None
self.progressBar.setMaximum(0)
self.progressBar.setValue(0)
self.readImageProgressBar.setMaximum(0)
self.readImageProgressBar.setValue(0)
self.calculateHashProgressBar.setMaximum(0)
self.calculateHashProgressBar.setValue(0)
self.buildButton.setEnabled(True)