diff --git a/project.py b/project.py index 8135962..197fd0b 100644 --- a/project.py +++ b/project.py @@ -133,12 +133,13 @@ class Project: raise ValueError(f'Unknown tag "{tag}"') with self.__sessionmaker() as session: - return [ + sampleMd5s = [ cs.sampleNumpyMd5 for cs in session.scalars( select(ClassifiedSample).where(ClassifiedSample.tag == tag) ) ] + return [p for p in self.samples if p.stem in sampleMd5s] def getModule(self, moduleName: str): cwdPath = Path(os.getcwd()) @@ -235,8 +236,10 @@ class Project: def unclassify(self, sample: Path): with self.__sessionmaker() as session: - cs = ClassifiedSample() - cs.sampleNumpyMd5 = sample.stem + stmt = select(ClassifiedSample).where( + ClassifiedSample.sampleNumpyMd5 == sample.stem + ) + cs = session.scalar(stmt) session.delete(cs) session.commit()