mirror of
https://github.com/283375/arcaea-offline-ocr-model.git
synced 2025-04-11 09:10:17 +00:00
wip: use database for management
This commit is contained in:
parent
b9d69fe577
commit
ed1dfd11ea
29
dbModels.py
Normal file
29
dbModels.py
Normal file
@ -0,0 +1,29 @@
|
||||
from sqlalchemy import CHAR, TEXT
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class ProjectBase(DeclarativeBase):
    """Declarative base shared by every ORM model of a project database.

    All tables declared on subclasses are registered on
    ``ProjectBase.metadata``.
    """
|
||||
|
||||
|
||||
class Property(ProjectBase):
    """Single key/value setting of a project (one row per key)."""

    __tablename__ = "properties"

    # The key alone identifies a setting.
    key: Mapped[str] = mapped_column(TEXT(), primary_key=True)
    # Deliberately NOT part of the primary key: with a composite (key, value)
    # key the same setting could be stored several times with different
    # values, and a lookup by key would return an arbitrary one of them.
    value: Mapped[str] = mapped_column(TEXT())
|
||||
|
||||
|
||||
class TagValue(ProjectBase):
    """Mapping from a classification tag to the value it stands for."""

    __tablename__ = "tag_values"

    # Each tag appears exactly once.
    tag: Mapped[str] = mapped_column(TEXT(), primary_key=True)
    # Deliberately NOT part of the primary key: a composite (tag, value) key
    # would allow one tag to carry several values, which a tag -> value
    # dictionary cannot represent (the last row read would silently win).
    value: Mapped[str] = mapped_column(TEXT())
|
||||
|
||||
|
||||
class ClassifiedSample(ProjectBase):
    """Classification record: which tag a sample has been given.

    A sample is identified by the 32-character hex digest stored in
    ``sample_numpy_md5`` and can carry at most one tag.
    """

    __tablename__ = "classified_samples"

    # Sole primary key. The original declared a composite (md5, tag) key and
    # additionally ``unique=True`` on this column; the uniqueness constraint
    # already limited the table to one row per sample, so the tag never
    # contributed to row identity and the extra constraint was redundant.
    sampleNumpyMd5: Mapped[str] = mapped_column(
        "sample_numpy_md5", CHAR(32), primary_key=True
    )
    tag: Mapped[str] = mapped_column(TEXT())
|
169
project.py
169
project.py
@ -1,37 +1,73 @@
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from dbModels import ClassifiedSample, ProjectBase, Property, TagValue
|
||||
|
||||
PROJECTS_ROOT_PATH = Path("projects")
|
||||
ACCEPT_EXTS = [".jpg", ".png"]
|
||||
|
||||
|
||||
def initProject(path: Path):
    """Create the on-disk skeleton of a project at *path*.

    Creates the ``sources`` and ``samples`` folders (including *path* itself
    when missing) and a ``project.db`` SQLite file containing every table
    registered on ``ProjectBase``. Safe to call on an existing project:
    existing folders and tables are left untouched.

    Args:
        path: Root folder of the project.
    """
    # Folders first: SQLite cannot create the database file inside a
    # directory that does not exist yet.
    (path / "sources").mkdir(parents=True, exist_ok=True)
    (path / "samples").mkdir(parents=True, exist_ok=True)

    engine = create_engine(
        f"sqlite:///{(path / 'project.db').resolve().as_posix()}", poolclass=NullPool
    )
    try:
        ProjectBase.metadata.create_all(engine)
    finally:
        # The engine is only needed for schema creation; release it explicitly.
        engine.dispose()
|
||||
|
||||
|
||||
class Project:
|
||||
path: Path
|
||||
|
||||
def __init__(self, path: Path):
    """Open the project stored in the folder *path*.

    Connects to ``project.db`` inside *path* and loads the project name and
    tag/value pairs from it via ``reload()``. The legacy ``project.json``
    loading is gone: it assigned to ``self.name``, which is a read-only
    property, and its data now lives in the database.

    Args:
        path: Root folder of the project.
    """
    self.path = path
    self._tagValueDict = {}

    # NullPool: no connection is kept open between sessions.
    self.__engine = create_engine(
        f"sqlite:///{(path / 'project.db').resolve().as_posix()}",
        poolclass=NullPool,
    )
    self.__sessionmaker = sessionmaker(self.__engine)
    self.reload()
|
||||
|
||||
def reload(self):
    """Re-read the project name and the tag/value pairs from the database.

    The name falls back to the project folder's name when no ``name``
    property row exists. Cached attribute values are dropped afterwards so
    they are recomputed from the fresh data on next access.
    """
    nameQuery = select(Property).where(Property.key == "name")

    with self.__sessionmaker() as dbSession:
        storedName = dbSession.scalar(nameQuery)
        if storedName:
            self.__name = storedName.value
        else:
            self.__name = self.path.name

        self._tagValueDict = {
            row.tag: row.value for row in dbSession.scalars(select(TagValue))
        }
        self._tags = list(self._tagValueDict.keys())
        self._values = list(self._tagValueDict.values())

    # expire property caches
    # https://stackoverflow.com/a/69367025/16484891, CC BY-SA 4.0
    for cachedAttr in ("name", "tags", "values", "tagValueMap"):
        self.__dict__.pop(cachedAttr, None)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Project(path={repr(self.path)})"
|
||||
|
||||
@property
def name(self):
    """Project name as last loaded by ``reload()`` (read-only)."""
    projectName = self.__name
    return projectName
|
||||
|
||||
@cached_property
def tags(self):
    """Independent copy of the known tags, so callers cannot mutate ``_tags``."""
    tagsCopy = deepcopy(self._tags)
    return tagsCopy
|
||||
@ -44,11 +80,6 @@ class Project:
|
||||
def tagValueMap(self):
|
||||
return deepcopy(self._tagValueDict)
|
||||
|
||||
@cached_property
|
||||
def tagsReExp(self):
|
||||
tagsDivided = "|".join(str(tag) for tag in self.tags)
|
||||
return re.compile(f"^({tagsDivided})\\^")
|
||||
|
||||
@cached_property
def sourcesPath(self):
    """Path of the ``sources`` folder inside the project folder."""
    return self.path.joinpath("sources")
|
||||
@ -57,29 +88,6 @@ class Project:
|
||||
def samplesPath(self):
|
||||
return self.path / "samples"
|
||||
|
||||
@cached_property
|
||||
def samplesUnclassifiedPath(self):
|
||||
return self.samplesPath / "unclassified"
|
||||
|
||||
@cached_property
|
||||
def samplesClassifiedPath(self):
|
||||
return self.samplesPath / "classified"
|
||||
|
||||
@cached_property
|
||||
def samplesIgnoredPath(self):
|
||||
return self.samplesPath / "ignored"
|
||||
|
||||
def createFolders(self):
|
||||
folders = [
|
||||
self.sourcesPath,
|
||||
self.samplesClassifiedPath,
|
||||
self.samplesUnclassifiedPath,
|
||||
self.samplesIgnoredPath,
|
||||
]
|
||||
|
||||
for folder in folders:
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def listPathFiles(self, path: Path, acceptSuffixes: list[str] = ACCEPT_EXTS):
    """Return every path below *path* (recursively) with an accepted suffix.

    Args:
        path: Folder to search.
        acceptSuffixes: Suffixes (including the leading dot) to keep; the
            comparison is case-sensitive.
    """
    matched = []
    for candidate in path.glob("**/*"):
        if candidate.suffix in acceptSuffixes:
            matched.append(candidate)
    return matched
|
||||
|
||||
@ -91,24 +99,42 @@ class Project:
|
||||
def samples(self):
|
||||
return self.listPathFiles(self.samplesPath)
|
||||
|
||||
@property
|
||||
def samplesUnclassified(self):
|
||||
return self.listPathFiles(self.samplesUnclassifiedPath)
|
||||
|
||||
@property
def samplesClassified(self):
    """MD5 identifiers of every sample that has a classification row.

    Read from the ``classified_samples`` table; the former lookup in the
    ``samples/classified`` folder no longer applies.

    NOTE(review): no tag filter is applied, so rows tagged ``"ignored"`` are
    included as well — confirm this is intended.
    """
    with self.__sessionmaker() as session:
        return list(session.scalars(select(ClassifiedSample.sampleNumpyMd5)))
|
||||
|
||||
@property
def samplesIgnored(self):
    """MD5 identifiers of every sample whose classification tag is ``"ignored"``.

    Read from the ``classified_samples`` table; the former lookup in the
    ``samples/ignored`` folder no longer applies.
    """
    with self.__sessionmaker() as session:
        return list(
            session.scalars(
                select(ClassifiedSample.sampleNumpyMd5).where(
                    ClassifiedSample.tag == "ignored"
                )
            )
        )
|
||||
|
||||
@property
def samplesUnclassified(self):
    """MD5 identifiers (file stems) of samples that have no classification yet.

    A sample counts as handled when its stem appears in either
    ``samplesClassified`` or ``samplesIgnored``. The result keeps the order
    of ``self.samples``.
    """
    # A set gives O(1) membership tests instead of scanning a list for
    # every sample file.
    handledMd5s = set(self.samplesClassified)
    handledMd5s.update(self.samplesIgnored)
    return [s.stem for s in self.samples if s.stem not in handledMd5s]
|
||||
|
||||
def samplesByTag(self, tag: str):
    """Return the MD5 identifiers of all samples classified as *tag*.

    Args:
        tag: One of ``self.tags`` or the pseudo-tag ``"ignored"``.

    Raises:
        ValueError: if *tag* is neither a known tag nor ``"ignored"``.
    """
    if tag != "ignored" and tag not in self.tags:
        raise ValueError(f'Unknown tag "{tag}"')

    with self.__sessionmaker() as session:
        return list(
            session.scalars(
                select(ClassifiedSample.sampleNumpyMd5).where(
                    ClassifiedSample.tag == tag
                )
            )
        )
|
||||
|
||||
def getModule(self, moduleName: str):
|
||||
cwdPath = Path(os.getcwd())
|
||||
@ -119,9 +145,9 @@ class Project:
|
||||
importName = ".".join(importParts)
|
||||
return importlib.import_module(importName)
|
||||
|
||||
def extractYield(self):
|
||||
def extractSamplesYield(self):
|
||||
extractModule = self.getModule("extract")
|
||||
getSamples = extractModule.getSamples
|
||||
getSamples = extractModule.extractSamples
|
||||
assert callable(getSamples)
|
||||
|
||||
extractLogger = logging.getLogger(
|
||||
@ -157,7 +183,7 @@ class Project:
|
||||
continue
|
||||
|
||||
extractLogger.info(f"{sampleMd5} <- {source.name}")
|
||||
sampleSavePath = self.samplesUnclassifiedPath / f"{sampleMd5}.jpg"
|
||||
sampleSavePath = self.samplesPath / f"{sampleMd5}.jpg"
|
||||
with open(sampleSavePath, "wb") as sf:
|
||||
sf.write(sampleBuffer)
|
||||
existingSamplesMd5.append(sampleMd5)
|
||||
@ -166,10 +192,10 @@ class Project:
|
||||
finally:
|
||||
yield (source, i, sourcesNum)
|
||||
|
||||
def extract(self):
|
||||
list(self.extractYield())
|
||||
def extractSamples(self):
    """Run the sample extraction to completion, discarding progress updates."""
    for _progress in self.extractSamplesYield():
        pass
|
||||
|
||||
def redactYield(self):
|
||||
def redactSourcesYield(self):
|
||||
redactModule = self.getModule("redact")
|
||||
redactSource = redactModule.redactSource
|
||||
assert callable(redactSource)
|
||||
@ -189,27 +215,29 @@ class Project:
|
||||
finally:
|
||||
yield (source, i, sourcesNum)
|
||||
|
||||
def redact(self):
|
||||
list(self.redactYield())
|
||||
|
||||
def getSampleOriginalFileName(self, sample: Path):
|
||||
return self.tagsReExp.sub("", sample.name)
|
||||
def redactSources(self):
    """Run the source redaction to completion, discarding progress updates."""
    for _progress in self.redactSourcesYield():
        pass
|
||||
|
||||
def classify(self, sample: Path, tag: str):
    """Record *sample* as classified under *tag* in the project database.

    The sample is identified by its file stem. If the sample already has a
    classification row, its tag is replaced instead of inserting a second
    row, which the table's uniqueness constraint would reject.

    Args:
        sample: Path of the sample file; only ``sample.stem`` is used.
        tag: One of ``self.tags`` or the pseudo-tag ``"ignored"``.

    Raises:
        ValueError: if *tag* is neither a known tag nor ``"ignored"``.
    """
    # "ignored" must pass validation: ignore() classifies with exactly that
    # tag and samplesByTag() accepts it, but it is not listed in self.tags.
    if tag != "ignored" and tag not in self.tags:
        raise ValueError(f'Unknown tag "{tag}"')

    with self.__sessionmaker() as session:
        cs = session.scalar(
            select(ClassifiedSample).where(
                ClassifiedSample.sampleNumpyMd5 == sample.stem
            )
        )
        if cs is None:
            cs = ClassifiedSample()
            cs.sampleNumpyMd5 = sample.stem
            cs.tag = tag
            session.add(cs)
        else:
            cs.tag = tag
        session.commit()
|
||||
|
||||
def unclassify(self, sample: Path):
    """Remove the classification of *sample* from the project database.

    The sample is identified by its file stem. Does nothing when the sample
    has no classification row.

    Args:
        sample: Path of the sample file; only ``sample.stem`` is used.
    """
    with self.__sessionmaker() as session:
        # Load the stored rows first: Session.delete() only accepts
        # persistent instances, so deleting a freshly constructed
        # ClassifiedSample() raises instead of removing anything.
        storedRows = session.scalars(
            select(ClassifiedSample).where(
                ClassifiedSample.sampleNumpyMd5 == sample.stem
            )
        ).all()
        for cs in storedRows:
            session.delete(cs)
        session.commit()
|
||||
|
||||
def ignore(self, sample: Path):
    """Mark *sample* as ignored by classifying it with the ``"ignored"`` tag.

    Replaces the former behaviour of moving the file into the
    ``samples/ignored`` folder.

    Args:
        sample: Path of the sample file.
    """
    self.classify(sample, "ignored")
|
||||
|
||||
|
||||
class Projects:
|
||||
@ -223,14 +251,9 @@ class Projects:
|
||||
|
||||
folders = [p for p in self.rootFolderPath.iterdir() if p.is_dir()]
|
||||
for folder in folders:
|
||||
if not (folder / "project.json").exists():
|
||||
if not (folder / "project.db").exists():
|
||||
continue
|
||||
project = Project(folder)
|
||||
if not (
|
||||
project.sourcesPath.exists()
|
||||
and project.samplesClassifiedPath.exists()
|
||||
and project.samplesUnclassifiedPath.exists()
|
||||
and project.samplesIgnoredPath.exists()
|
||||
):
|
||||
if not (project.sourcesPath.exists() and project.samplesPath.exists()):
|
||||
continue
|
||||
self.projects.append(project)
|
||||
|
Loading…
x
Reference in New Issue
Block a user