From 01bfd0f350c4fa06ac41bc5fe0ad8f4bc59755c1 Mon Sep 17 00:00:00 2001 From: 283375 Date: Thu, 31 Aug 2023 21:51:49 +0800 Subject: [PATCH] feat(db): `COUNT` related methods --- src/arcaea_offline/database.py | 60 +++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/src/arcaea_offline/database.py b/src/arcaea_offline/database.py index 73e22ce..aeb24c1 100644 --- a/src/arcaea_offline/database.py +++ b/src/arcaea_offline/database.py @@ -1,8 +1,8 @@ import logging -from typing import Optional, Union +from typing import Optional, Type, Union from sqlalchemy import Engine, func, inspect, select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, sessionmaker from .external.arcsong.arcsong_json import ArcSongJsonBuilder from .models.config import * @@ -153,12 +153,6 @@ class Database(metaclass=Singleton): result = session.scalar(stmt) return result - def count_scores(self): - stmt = select(func.count(Score.id)) - with self.sessionmaker() as session: - result = session.scalar(stmt) - return result or 0 - # endregion def get_b30(self): @@ -167,6 +161,56 @@ class Database(metaclass=Singleton): result = session.scalar(stmt) return result + # region COUNT + + def __count_table(self, base: Type[DeclarativeBase]): + stmt = select(func.count()).select_from(base) + with self.sessionmaker() as session: + result = session.scalar(stmt) + return result or 0 + + def __count_column(self, column: InstrumentedAttribute): + stmt = select(func.count(column)) + with self.sessionmaker() as session: + result = session.scalar(stmt) + return result or 0 + + def count_packs(self): + return self.__count_column(Pack.id) + + def count_songs(self): + return self.__count_column(Song.id) + + def count_difficulties(self): + return self.__count_table(Difficulty) + + def count_chart_infos(self): + return self.__count_table(ChartInfo) + + def count_complete_chart_infos(self): + stmt = ( + select(func.count()) + .select_from(ChartInfo) + .where((ChartInfo.constant != None) & (ChartInfo.note != None)) + ) + with self.sessionmaker() as session: + result = session.scalar(stmt) + return result or 0 + + def count_charts(self): + return self.__count_table(Chart) + + def count_scores(self): + return self.__count_column(Score.id) + + def count_scores_calculated(self): + return self.__count_table(ScoreCalculated) + + def count_scores_best(self): + return self.__count_table(ScoreBest) + + # endregion + def generate_arcsong(self): with self.sessionmaker() as session: arcsong = ArcSongJsonBuilder(session).generate_arcsong_json()