mirror of
https://github.com/283375/arcaea-offline-pyside-ui.git
synced 2025-07-01 12:26:26 +00:00
impr: TabOcr_BuildPHashDatabase
This commit is contained in:
@ -4,13 +4,15 @@ import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
from PySide6.QtCore import QThread, Signal, Slot
|
||||
from PySide6.QtWidgets import QFileDialog, QMessageBox, QWidget
|
||||
|
||||
from ui.designer.tabs.tabOcr.tabOcr_BuildPHashDatabase_ui import (
|
||||
Ui_tabOcr_BuildPHashDatabase,
|
||||
Ui_TabOcr_BuildPHashDatabase,
|
||||
)
|
||||
from ui.extends.ocr import build_image_phash_database
|
||||
from ui.extends.ocr.build_phash import build_image_phash_database, preprocess_char_icon
|
||||
from ui.extends.shared.language import LanguageChangeEventFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,11 +59,14 @@ class BuildDatabaseThread(QThread):
|
||||
self.finished.emit()
|
||||
|
||||
|
||||
class TabOcr_BuildPHashDatabase(Ui_tabOcr_BuildPHashDatabase, QWidget):
|
||||
class TabOcr_BuildPHashDatabase(Ui_TabOcr_BuildPHashDatabase, QWidget):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.setupUi(self)
|
||||
|
||||
self.languageChangeEventFilter = LanguageChangeEventFilter(self)
|
||||
self.installEventFilter(self.languageChangeEventFilter)
|
||||
|
||||
self.songDirSelector.setMode(self.songDirSelector.getExistingDirectory)
|
||||
self.charIconDirSelector.setMode(self.charIconDirSelector.getExistingDirectory)
|
||||
|
||||
@ -93,11 +98,30 @@ class TabOcr_BuildPHashDatabase(Ui_tabOcr_BuildPHashDatabase, QWidget):
|
||||
charIconFilePaths = [
|
||||
p for p in Path(charIconDir).glob("**/*") if p.suffix in acceptExts
|
||||
]
|
||||
|
||||
self.readImageProgressBar.setMaximum(
|
||||
len(songFilePaths) + len(charIconFilePaths)
|
||||
)
|
||||
i = 0
|
||||
songMats = []
|
||||
charIconMats = []
|
||||
for image_path in songFilePaths:
|
||||
songMats.append(cv2.imread(str(image_path.resolve()), cv2.IMREAD_GRAYSCALE))
|
||||
i += 1
|
||||
self.readImageProgressBar.setValue(i)
|
||||
for image_path in charIconFilePaths:
|
||||
mat = cv2.imread(str(image_path.resolve()), cv2.IMREAD_GRAYSCALE)
|
||||
if self.preprocessCharIconCheckBox.isChecked():
|
||||
mat = preprocess_char_icon(mat)
|
||||
charIconMats.append(mat)
|
||||
i += 1
|
||||
self.readImageProgressBar.setValue(i)
|
||||
|
||||
songLabels = [re.sub(r"_.*$", "", p.stem) for p in songFilePaths]
|
||||
charLabels = [f"character||{p.stem}" for p in charIconFilePaths]
|
||||
charLabels = [f"partner||{p.stem}" for p in charIconFilePaths]
|
||||
|
||||
self.databaseBuildThread = BuildDatabaseThread(
|
||||
songFilePaths + charIconFilePaths, songLabels + charLabels
|
||||
songMats + charIconMats, songLabels + charLabels
|
||||
)
|
||||
self.databaseBuildThread.progress.connect(self.databaseBuildProgress)
|
||||
self.databaseBuildThread.success.connect(self.databaseBuildSuccess)
|
||||
@ -108,8 +132,8 @@ class TabOcr_BuildPHashDatabase(Ui_tabOcr_BuildPHashDatabase, QWidget):
|
||||
@Slot(int, int)
|
||||
def databaseBuildProgress(self, i: int, total: int):
|
||||
if i < 5:
|
||||
self.progressBar.setMaximum(total)
|
||||
self.progressBar.setValue(i)
|
||||
self.calculateHashProgressBar.setMaximum(total)
|
||||
self.calculateHashProgressBar.setValue(i)
|
||||
|
||||
@Slot(str)
|
||||
def databaseBuildError(self, msg: str):
|
||||
@ -133,6 +157,8 @@ class TabOcr_BuildPHashDatabase(Ui_tabOcr_BuildPHashDatabase, QWidget):
|
||||
def databaseBuildCleanUp(self):
|
||||
self.databaseBuildThread.deleteLater()
|
||||
self.databaseBuildThread = None
|
||||
self.progressBar.setMaximum(0)
|
||||
self.progressBar.setValue(0)
|
||||
self.readImageProgressBar.setMaximum(0)
|
||||
self.readImageProgressBar.setValue(0)
|
||||
self.calculateHashProgressBar.setMaximum(0)
|
||||
self.calculateHashProgressBar.setValue(0)
|
||||
self.buildButton.setEnabled(True)
|
||||
|
Reference in New Issue
Block a user