mirror of
https://github.com/283375/arcaea-offline-ocr-model.git
synced 2025-04-04 14:10:18 +00:00
283 lines
9.2 KiB
Python
283 lines
9.2 KiB
Python
import importlib
|
|
import logging
|
|
import os
|
|
import time
|
|
from copy import deepcopy
|
|
from functools import cached_property
|
|
from hashlib import md5
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
from sqlalchemy import create_engine, select
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import NullPool
|
|
|
|
from dbModels import ClassifiedSample, ProjectBase, Property, TagValue
|
|
|
|
# Folder (relative to the current working directory) scanned for projects.
PROJECTS_ROOT_PATH = Path("projects")

# Suffixes of the image files listed as sources and samples.
# Matched exactly (case-sensitively) against ``Path.suffix``.
ACCEPT_EXTS = [
    ".jpg",
    ".png",
]
|
|
|
|
|
|
def initProject(path: Path):
    """Create the on-disk layout of a project at ``path``.

    Creates ``path`` itself when it is missing, then ``project.db`` with every
    table declared on ``ProjectBase.metadata``, and finally the ``sources``
    and ``samples`` folders. Calling it on an existing project keeps the
    existing folders; ``create_all`` only creates tables that are missing.

    Args:
        path: Project directory (``str`` or ``Path``).
    """
    path = Path(path)

    # SQLite cannot create the database file inside a directory that does
    # not exist, so the project folder must exist before the engine connects.
    path.mkdir(parents=True, exist_ok=True)

    engine = create_engine(
        f"sqlite:///{(path / 'project.db').resolve().as_posix()}", poolclass=NullPool
    )
    try:
        ProjectBase.metadata.create_all(engine)
    finally:
        # Release the engine's resources; it is not used after schema creation.
        engine.dispose()

    (path / "sources").mkdir(parents=True, exist_ok=True)
    (path / "samples").mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
class Project:
    """A sample-classification project stored in a single directory.

    The directory holds:

    * ``project.db``, a SQLite database read through SQLAlchemy.
    * a ``sources`` folder of source images.
    * a ``samples`` folder of extracted samples, saved as ``<md5>.jpg``.
    * project-specific ``extract``, ``redact`` and ``train`` modules, which
      are imported on demand through :meth:`getModule`.
    """

    path: Path

    def __init__(self, path: Path):
        self.path = path

        # NullPool: connections are not pooled; each session opens and closes
        # its own SQLite connection.
        self.__engine = create_engine(
            f"sqlite:///{(path / 'project.db').resolve().as_posix()}",
            poolclass=NullPool,
        )
        self.__sessionmaker = sessionmaker(self.__engine)
        self.reload()

    def reload(self):
        """Re-read the project name and the tag/value pairs from the database.

        Also drops the cached ``tags``, ``values`` and ``tagValueMap``
        properties so their next access reflects the fresh data.
        """
        with self.__sessionmaker() as session:
            nameProperty = session.scalar(
                select(Property).where(Property.key == "name")
            )
            # Fall back to the folder name when no "name" property is stored.
            self.__name = nameProperty.value if nameProperty else self.path.name

            self._tagValueDict = {}
            for tagValue in session.scalars(select(TagValue)):
                self._tagValueDict[tagValue.tag] = tagValue.value
            self._tags = list(self._tagValueDict.keys())
            self._values = list(self._tagValueDict.values())

        # expire property caches
        # https://stackoverflow.com/a/69367025/16484891, CC BY-SA 4.0
        self.__dict__.pop("name", None)
        self.__dict__.pop("tags", None)
        self.__dict__.pop("values", None)
        self.__dict__.pop("tagValueMap", None)

    def __repr__(self):
        return f"Project(path={repr(self.path)})"

    @property
    def name(self):
        """Stored ``name`` property, or the project folder name if absent."""
        return self.__name

    # The three copies below are cached until the next reload(). Callers
    # share the cached copy, but the internal lists/dict are never exposed.
    @cached_property
    def tags(self):
        """Copy of the tag names, in the order the rows were read."""
        return deepcopy(self._tags)

    @cached_property
    def values(self):
        """Copy of the tag values, aligned with :attr:`tags`."""
        return deepcopy(self._values)

    @cached_property
    def tagValueMap(self):
        """Copy of the ``{tag: value}`` mapping."""
        return deepcopy(self._tagValueDict)

    @cached_property
    def sourcesPath(self):
        return self.path / "sources"

    @cached_property
    def samplesPath(self):
        return self.path / "samples"

    def listPathFiles(self, path: Path, acceptSuffixes: "list[str] | None" = None):
        """Return all files below ``path`` (recursively) with an accepted suffix.

        Args:
            path: Folder to search.
            acceptSuffixes: Suffixes to keep, compared exactly against
                ``Path.suffix``. Defaults to the module-level ``ACCEPT_EXTS``.
        """
        if acceptSuffixes is None:
            acceptSuffixes = ACCEPT_EXTS
        # is_file() drops directories whose name happens to end in an image
        # suffix; those would later fail when opened for reading.
        return [
            p for p in path.glob("**/*") if p.suffix in acceptSuffixes and p.is_file()
        ]

    @property
    def sources(self):
        return self.listPathFiles(self.sourcesPath)

    @property
    def samples(self):
        return self.listPathFiles(self.samplesPath)

    def _classifiedMd5s(self, condition) -> set[str]:
        """Return ``sampleNumpyMd5`` of every ClassifiedSample row matching ``condition``."""
        with self.__sessionmaker() as session:
            return {
                cs.sampleNumpyMd5
                for cs in session.scalars(select(ClassifiedSample).where(condition))
            }

    def _samplesWithMd5s(self, md5s) -> list[Path]:
        """Return the sample files whose stem (their MD5 file name) is in ``md5s``."""
        return [p for p in self.samples if p.stem in md5s]

    @property
    def samplesClassified(self):
        """Sample files recorded with a tag other than ``"ignored"``."""
        return self._samplesWithMd5s(
            self._classifiedMd5s(ClassifiedSample.tag != "ignored")
        )

    @property
    def samplesIgnored(self):
        """Sample files recorded with the ``"ignored"`` tag."""
        return self._samplesWithMd5s(
            self._classifiedMd5s(ClassifiedSample.tag == "ignored")
        )

    @property
    def samplesUnclassified(self):
        """Sample files in neither :attr:`samplesClassified` nor :attr:`samplesIgnored`."""
        # Two separate conditions, rather than "any row", so that the result
        # matches the two properties above exactly.
        knownMd5s = self._classifiedMd5s(
            ClassifiedSample.tag != "ignored"
        ) | self._classifiedMd5s(ClassifiedSample.tag == "ignored")
        return [p for p in self.samples if p.stem not in knownMd5s]

    def samplesByTag(self, tag: str):
        """Return the sample files recorded with ``tag``.

        Raises:
            ValueError: if ``tag`` is neither a project tag nor ``"ignored"``.
        """
        if tag != "ignored" and tag not in self.tags:
            raise ValueError(f'Unknown tag "{tag}"')

        return self._samplesWithMd5s(self._classifiedMd5s(ClassifiedSample.tag == tag))

    def getModule(self, moduleName: str):
        """Import and return ``<project folder as dotted path>.<moduleName>``.

        The dotted path is the project folder relative to the current working
        directory. If the project is not under the cwd, ``Path.relative_to``
        raises ``ValueError``.
        """
        cwdPath = Path(os.getcwd())
        importParts = [
            *self.path.resolve().relative_to(cwdPath.resolve()).parts,
            moduleName,
        ]
        importName = ".".join(importParts)
        return importlib.import_module(importName)

    def extractSamplesYield(self):
        """Extract samples from every source image, yielding progress.

        For each source, the project's ``extract.extractSamples(source)`` is
        called. Each returned image is JPEG-encoded and saved to
        :attr:`samplesPath` as ``<md5 of the encoded bytes>.jpg``. Images whose
        encoded bytes match an existing sample are skipped. An error while
        processing one source is logged and does not stop the others.

        Yields:
            ``(source, index, total)`` after each source has been processed.
        """
        extractModule = self.getModule("extract")
        getSamples = extractModule.extractSamples
        assert callable(getSamples)

        extractLogger = logging.getLogger(
            f"extract-{self.name}-{int(time.time() * 1000)}"
        )

        extractLogger.info("Reading existing samples MD5...")
        # A set keeps each duplicate check O(1) however many samples exist.
        existingSamplesMd5 = set()
        for sample in self.samples:
            with open(sample, "rb") as sf:
                existingSamplesMd5.add(md5(sf.read()).hexdigest())

        sources = self.sources
        sourcesNum = len(sources)
        for i, source in enumerate(sources):
            try:
                extractLogger.info(f"Extracting {source.resolve()}")
                samples = getSamples(source)
                for sampleIndex, sample in enumerate(samples):
                    success, sampleBuffer = cv2.imencode(".jpg", sample)
                    if not success:
                        # The MD5 is computed from the encoded buffer, which
                        # does not exist here; identify the sample by position.
                        extractLogger.warning(
                            f"cv2 cannot encode sample #{sampleIndex} "
                            f"from {source.name}, skipping"
                        )
                        continue

                    sampleMd5 = md5(sampleBuffer).hexdigest()
                    if sampleMd5 in existingSamplesMd5:
                        extractLogger.debug(f"{sampleMd5} from {source.name} skipped")
                        continue

                    extractLogger.info(f"{sampleMd5} <- {source.name}")
                    sampleSavePath = self.samplesPath / f"{sampleMd5}.jpg"
                    with open(sampleSavePath, "wb") as sf:
                        sf.write(sampleBuffer)
                    existingSamplesMd5.add(sampleMd5)
            except Exception:
                extractLogger.exception(f"Error extracting {source.resolve()}")
            finally:
                yield (source, i, sourcesNum)

    def extractSamples(self):
        """Run :meth:`extractSamplesYield` to completion."""
        list(self.extractSamplesYield())

    def redactSourcesYield(self):
        """Call the project's ``redact.redactSource(source)`` on every source.

        Errors for one source are logged and do not stop the others.

        Yields:
            ``(source, index, total)`` after each source has been processed.
        """
        redactModule = self.getModule("redact")
        redactSource = redactModule.redactSource
        assert callable(redactSource)

        redactLogger = logging.getLogger(
            f"redact-{self.name}-{int(time.time() * 1000)}"
        )

        sources = self.sources
        sourcesNum = len(sources)
        for i, source in enumerate(sources):
            try:
                redactLogger.info(f"Redacting {source.resolve()}")
                redactSource(source)
            except Exception:
                redactLogger.exception(f"Error redacting {source.resolve()}")
            finally:
                yield (source, i, sourcesNum)

    def redactSources(self):
        """Run :meth:`redactSourcesYield` to completion."""
        list(self.redactSourcesYield())

    def train(self):
        """Train KNN and SVM models and save them in the project folder.

        Builds one item per tag in :attr:`tagValueMap`, so ``"ignored"``
        samples are not included. Each item is ``{"tag", "value" (as int),
        "samples"}``. The items are passed to the project's ``train.Train``,
        and the models are saved as ``knn.dat`` and ``svm.dat``.
        """
        trainModule = self.getModule("train")
        trainClass = trainModule.Train

        trainItems = [
            {"tag": tag, "value": int(value), "samples": self.samplesByTag(tag)}
            for tag, value in self.tagValueMap.items()
        ]

        trainClassInstance = trainClass(trainItems)

        knnModel = trainClassInstance.train_knn()
        knnModel.save(str((self.path / "knn.dat").resolve()))
        svmModel = trainClassInstance.train_svm()
        svmModel.save(str((self.path / "svm.dat").resolve()))

    def classify(self, sample: Path, tag: str):
        """Record ``sample`` (identified by its file stem) as having ``tag``.

        Raises:
            ValueError: if ``tag`` is neither a project tag nor ``"ignored"``.
        """
        if tag != "ignored" and tag not in self.tags:
            raise ValueError(f'Unknown tag "{tag}"')

        # NOTE(review): always inserts a new row; whether re-classifying a
        # sample yields a duplicate row or an integrity error depends on the
        # ClassifiedSample schema — confirm.
        with self.__sessionmaker() as session:
            cs = ClassifiedSample()
            cs.sampleNumpyMd5 = sample.stem
            cs.tag = tag
            session.add(cs)
            session.commit()

    def unclassify(self, sample: Path):
        """Remove every classification row recorded for ``sample``.

        Does nothing if the sample has no classification.
        """
        with self.__sessionmaker() as session:
            stmt = select(ClassifiedSample).where(
                ClassifiedSample.sampleNumpyMd5 == sample.stem
            )
            # .all() materializes the rows before any of them is deleted.
            for cs in session.scalars(stmt).all():
                session.delete(cs)
            session.commit()

    def ignore(self, sample: Path):
        """Classify ``sample`` with the reserved ``"ignored"`` tag."""
        self.classify(sample, "ignored")
|
|
|
|
|
|
class Projects:
    """Discovers the projects stored in the subfolders of a root folder."""

    def __init__(self, rootFolderPath=PROJECTS_ROOT_PATH):
        # Accept str / os.PathLike as well as Path; detectProjects needs a Path.
        self.rootFolderPath = Path(rootFolderPath)
        self.projects: list[Project] = []
        self.detectProjects()

    def detectProjects(self):
        """Rebuild ``self.projects`` from the immediate subfolders of the root.

        A subfolder counts as a project when it contains ``project.db`` and
        both a ``sources`` and a ``samples`` entry. Subfolders are scanned in
        sorted order so the resulting list is stable between runs.

        Raises:
            FileNotFoundError: if the root folder does not exist.
        """
        self.projects.clear()

        folders = sorted(p for p in self.rootFolderPath.iterdir() if p.is_dir())
        for folder in folders:
            # Check the on-disk layout before constructing Project: its
            # __init__ opens project.db and reads it via reload().
            if not (folder / "project.db").exists():
                continue
            if not ((folder / "sources").exists() and (folder / "samples").exists()):
                continue
            self.projects.append(Project(folder))