From e948b6abeab2a42f73a179df72ca93d13783ea18 Mon Sep 17 00:00:00 2001 From: 283375 Date: Mon, 28 Aug 2023 22:54:37 +0800 Subject: [PATCH] chore(db)!: `Database` methods --- src/arcaea_offline/database.py | 51 ++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/src/arcaea_offline/database.py b/src/arcaea_offline/database.py index cc2a4d2..10cca5a 100644 --- a/src/arcaea_offline/database.py +++ b/src/arcaea_offline/database.py @@ -50,6 +50,8 @@ class Database(metaclass=Singleton): def sessionmaker(self): return self.__sessionmaker + # region init + def init(self, checkfirst: bool = True): # create tables & views if checkfirst: @@ -80,6 +82,7 @@ class Database(metaclass=Singleton): + list(ScoresBase.metadata.tables.keys()) + list(ConfigBase.metadata.tables.keys()) + [ + Chart.__tablename__, ScoreCalculated.__tablename__, ScoreBest.__tablename__, CalculatedPotential.__tablename__, @@ -87,30 +90,56 @@ class Database(metaclass=Singleton): ) return all(inspect(self.engine).has_table(t) for t in expect_tables) + # endregion + def version(self) -> Union[int, None]: stmt = select(Property).where(Property.key == "version") with self.sessionmaker() as session: result = session.scalar(stmt) return None if result is None else int(result.value) + # region Pack + def get_packs(self): stmt = select(Pack) with self.sessionmaker() as session: - results = list(session.scalars(stmt)) - return results + results = session.scalars(stmt) + return list(results) - def get_pack_by_id(self, value: str): - stmt = select(Pack).where(Pack.id == value) + def get_pack_by_id(self, pack_id: str): + stmt = select(Pack).where(Pack.id == pack_id) with self.sessionmaker() as session: result = session.scalar(stmt) return result - def get_charts_in_pack(self, pack: str): - stmt = ( - select(ChartInfo) - .join(Song, (Song.id == ChartInfo.song_id)) - .where(Song.set == pack) + # endregion + + # region Chart + + def get_charts_by_pack_id(self, pack_id: str): + stmt = select(Chart).where(Chart.set == pack_id) + with self.sessionmaker() as session: + results = session.scalars(stmt) + return list(results) + + def get_charts_by_song_id(self, song_id: str): + stmt = select(Chart).where(Chart.song_id == song_id) + with self.sessionmaker() as session: + results = session.scalars(stmt) + return list(results) + + def get_chart(self, song_id: str, rating_class: int): + stmt = select(Chart).where( + (Chart.song_id == song_id) & (Chart.rating_class == rating_class) ) with self.sessionmaker() as session: - results = list(session.scalars(stmt)) - return results + result = session.scalar(stmt) + return result + + # endregion + + def get_b30(self): + stmt = select(CalculatedPotential.b30).select_from(CalculatedPotential) + with self.sessionmaker() as session: + result = session.scalar(stmt) + return result