diff --git a/dbModels.py b/dbModels.py
new file mode 100644
index 0000000..4752b32
--- /dev/null
+++ b/dbModels.py
@@ -0,0 +1,29 @@
+from sqlalchemy import CHAR, TEXT
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+
+
+class ProjectBase(DeclarativeBase):
+    pass
+
+
+class Property(ProjectBase):
+    __tablename__ = "properties"
+
+    key: Mapped[str] = mapped_column(TEXT(), primary_key=True)
+    value: Mapped[str] = mapped_column(TEXT(), primary_key=True)
+
+
+class TagValue(ProjectBase):
+    __tablename__ = "tag_values"
+
+    tag: Mapped[str] = mapped_column(TEXT(), primary_key=True)
+    value: Mapped[str] = mapped_column(TEXT(), primary_key=True)
+
+
+class ClassifiedSample(ProjectBase):
+    __tablename__ = "classified_samples"
+
+    sampleNumpyMd5: Mapped[str] = mapped_column(
+        "sample_numpy_md5", CHAR(32), primary_key=True, unique=True
+    )
+    tag: Mapped[str] = mapped_column(TEXT(), primary_key=True)
diff --git a/project.py b/project.py
index 2eadc73..a4e14cf 100644
--- a/project.py
+++ b/project.py
@@ -1,37 +1,76 @@
 import importlib
-import json
 import logging
 import os
-import re
 import time
 from copy import deepcopy
 from functools import cached_property
-from pathlib import Path
-from typing import Any
 from hashlib import md5
+from pathlib import Path
 
 import cv2
+from sqlalchemy import create_engine, delete, select
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import NullPool
+
+from dbModels import ClassifiedSample, ProjectBase, Property, TagValue
 
 PROJECTS_ROOT_PATH = Path("projects")
 ACCEPT_EXTS = [".jpg", ".png"]
 
 
+def initProject(path: Path):
+    """Create the sources/samples folders and the SQLite schema for a project."""
+    # make the folders first: SQLite cannot create project.db in a missing directory
+    (path / "sources").mkdir(parents=True, exist_ok=True)
+    (path / "samples").mkdir(parents=True, exist_ok=True)
+    engine = create_engine(
+        f"sqlite:///{(path / 'project.db').resolve().as_posix()}", poolclass=NullPool
+    )
+    ProjectBase.metadata.create_all(engine)
+
+
 class Project:
     path: Path
 
     def __init__(self, path: Path):
         self.path = path
-        self._tagValueDict = {}
-        with open(self.path / "project.json", "r", encoding="utf-8") as jf:
-            projectJson = json.loads(jf.read())
-            self._tagValueDict: dict[str, Any] = projectJson["tagValueMap"]
-            self.name = projectJson.get("name", self.path.name)
-        self._tags = list(self._tagValueDict.keys())
-        self._values = list(self._tagValueDict.values())
+
+        self.__engine = create_engine(
+            f"sqlite:///{(path / 'project.db').resolve().as_posix()}",
+            poolclass=NullPool,
+        )
+        self.__sessionmaker = sessionmaker(self.__engine)
+        self.reload()
+
+    def reload(self):
+        """Re-read the project name and the tag/value map from project.db."""
+        with self.__sessionmaker() as session:
+            nameProperty = session.scalar(
+                select(Property).where(Property.key == "name")
+            )
+            self.__name = nameProperty.value if nameProperty else self.path.name
+
+            self._tagValueDict = {}
+            tagValues = session.scalars(select(TagValue))
+            for tagValue in tagValues:
+                self._tagValueDict[tagValue.tag] = tagValue.value
+            self._tags = list(self._tagValueDict.keys())
+            self._values = list(self._tagValueDict.values())
+
+        # expire property caches
+        # https://stackoverflow.com/a/69367025/16484891, CC BY-SA 4.0
+        self.__dict__.pop("name", None)
+        self.__dict__.pop("tags", None)
+        self.__dict__.pop("values", None)
+        self.__dict__.pop("tagValueMap", None)
 
     def __repr__(self):
         return f"Project(path={repr(self.path)})"
 
+    @property
+    def name(self):
+        return self.__name
+
     @cached_property
     def tags(self):
         return deepcopy(self._tags)
@@ -44,11 +83,6 @@ class Project:
     def tagValueMap(self):
         return deepcopy(self._tagValueDict)
 
-    @cached_property
-    def tagsReExp(self):
-        tagsDivided = "|".join(str(tag) for tag in self.tags)
-        return re.compile(f"^({tagsDivided})\\^")
-
     @cached_property
     def sourcesPath(self):
         return self.path / "sources"
@@ -57,29 +91,6 @@ class Project:
     def samplesPath(self):
         return self.path / "samples"
 
-    @cached_property
-    def samplesUnclassifiedPath(self):
-        return self.samplesPath / "unclassified"
-
-    @cached_property
-    def samplesClassifiedPath(self):
-        return self.samplesPath / "classified"
-
-    @cached_property
-    def samplesIgnoredPath(self):
-        return self.samplesPath / "ignored"
-
-    def createFolders(self):
-        folders = [
-            self.sourcesPath,
-            self.samplesClassifiedPath,
-            self.samplesUnclassifiedPath,
-            self.samplesIgnoredPath,
-        ]
-
-        for folder in folders:
-            folder.mkdir(parents=True, exist_ok=True)
-
     def listPathFiles(self, path: Path, acceptSuffixes: list[str] = ACCEPT_EXTS):
         return [p for p in path.glob("**/*") if p.suffix in acceptSuffixes]
 
@@ -91,24 +102,46 @@ class Project:
     def samples(self):
         return self.listPathFiles(self.samplesPath)
 
-    @property
-    def samplesUnclassified(self):
-        return self.listPathFiles(self.samplesUnclassifiedPath)
-
     @property
     def samplesClassified(self):
-        return self.listPathFiles(self.samplesClassifiedPath)
+        """MD5 stems of samples carrying a real (non-"ignored") tag."""
+        with self.__sessionmaker() as session:
+            return [
+                cs.sampleNumpyMd5
+                for cs in session.scalars(
+                    select(ClassifiedSample).where(ClassifiedSample.tag != "ignored")
+                )
+            ]
 
     @property
     def samplesIgnored(self):
-        return self.listPathFiles(self.samplesIgnoredPath)
+        """MD5 stems of samples marked with the reserved "ignored" tag."""
+        with self.__sessionmaker() as session:
+            return [
+                cs.sampleNumpyMd5
+                for cs in session.scalars(
+                    select(ClassifiedSample).where(ClassifiedSample.tag == "ignored")
+                )
+            ]
+
+    @property
+    def samplesUnclassified(self):
+        """MD5 stems of sample files that have no row in classified_samples."""
+        # a set keeps the membership test O(1) per sample file
+        handled = set(self.samplesClassified) | set(self.samplesIgnored)
+        return [s.stem for s in self.samples if s.stem not in handled]
 
     def samplesByTag(self, tag: str):
-        if tag not in self.tags:
+        if tag != "ignored" and tag not in self.tags:
             raise ValueError(f'Unknown tag "{tag}"')
 
-        samples = self.samples
-        return [p for p in samples if p.stem.startswith(f"{tag}^")]
+        with self.__sessionmaker() as session:
+            return [
+                cs.sampleNumpyMd5
+                for cs in session.scalars(
+                    select(ClassifiedSample).where(ClassifiedSample.tag == tag)
+                )
+            ]
 
     def getModule(self, moduleName: str):
         cwdPath = Path(os.getcwd())
@@ -119,9 +152,9 @@ class Project:
         importName = ".".join(importParts)
         return importlib.import_module(importName)
 
-    def extractYield(self):
+    def extractSamplesYield(self):
         extractModule = self.getModule("extract")
-        getSamples = extractModule.getSamples
+        getSamples = extractModule.extractSamples
         assert callable(getSamples)
 
         extractLogger = logging.getLogger(
@@ -157,7 +190,7 @@ class Project:
                         continue
 
                     extractLogger.info(f"{sampleMd5} <- {source.name}")
-                    sampleSavePath = self.samplesUnclassifiedPath / f"{sampleMd5}.jpg"
+                    sampleSavePath = self.samplesPath / f"{sampleMd5}.jpg"
                     with open(sampleSavePath, "wb") as sf:
                         sf.write(sampleBuffer)
                     existingSamplesMd5.append(sampleMd5)
@@ -166,10 +199,10 @@ class Project:
             finally:
                 yield (source, i, sourcesNum)
 
-    def extract(self):
-        list(self.extractYield())
+    def extractSamples(self):
+        list(self.extractSamplesYield())
 
-    def redactYield(self):
+    def redactSourcesYield(self):
         redactModule = self.getModule("redact")
         redactSource = redactModule.redactSource
         assert callable(redactSource)
@@ -189,27 +222,40 @@ class Project:
             finally:
                 yield (source, i, sourcesNum)
 
-    def redact(self):
-        list(self.redactYield())
-
-    def getSampleOriginalFileName(self, sample: Path):
-        return self.tagsReExp.sub("", sample.name)
+    def redactSources(self):
+        list(self.redactSourcesYield())
 
     def classify(self, sample: Path, tag: str):
-        if tag not in self.tags:
+        """Record `tag` for `sample`, replacing any tag it already has."""
+        # "ignored" is a reserved tag (see samplesByTag); ignore() relies on it
+        if tag != "ignored" and tag not in self.tags:
             raise ValueError(f'Unknown tag "{tag}"')
 
-        originalFileName = self.getSampleOriginalFileName(sample)
-        classifiedFileName = f"{tag}^{originalFileName}"
-        return sample.rename(self.samplesClassifiedPath / classifiedFileName)
+        with self.__sessionmaker() as session:
+            # sample_numpy_md5 is unique, so drop the old row before inserting
+            session.execute(
+                delete(ClassifiedSample).where(
+                    ClassifiedSample.sampleNumpyMd5 == sample.stem
+                )
+            )
+            session.add(ClassifiedSample(sampleNumpyMd5=sample.stem, tag=tag))
+            session.commit()
 
     def unclassify(self, sample: Path):
-        originalFileName = self.getSampleOriginalFileName(sample)
-        return sample.rename(self.samplesUnclassifiedPath / originalFileName)
+        """Remove any tag recorded for `sample`; a no-op if it has none."""
+        with self.__sessionmaker() as session:
+            # Session.delete() only accepts persistent objects, so issue a
+            # DELETE statement keyed on the sample hash instead
+            session.execute(
+                delete(ClassifiedSample).where(
+                    ClassifiedSample.sampleNumpyMd5 == sample.stem
+                )
+            )
+            session.commit()
 
     def ignore(self, sample: Path):
-        originalFileName = self.getSampleOriginalFileName(sample)
-        return sample.rename(self.samplesIgnoredPath / originalFileName)
+        """Mark `sample` with the reserved "ignored" tag."""
+        self.classify(sample, "ignored")
 
 
 class Projects:
@@ -223,14 +269,9 @@ class Projects:
         folders = [p for p in self.rootFolderPath.iterdir() if p.is_dir()]
 
         for folder in folders:
-            if not (folder / "project.json").exists():
+            if not (folder / "project.db").exists():
                 continue
             project = Project(folder)
-            if not (
-                project.sourcesPath.exists()
-                and project.samplesClassifiedPath.exists()
-                and project.samplesUnclassifiedPath.exists()
-                and project.samplesIgnoredPath.exists()
-            ):
+            if not (project.sourcesPath.exists() and project.samplesPath.exists()):
                 continue
             self.projects.append(project)