# Mirror of https://github.com/283375/arcaea-offline.git
# (synced 2025-04-21 06:50:18 +00:00; 146 lines, 4.8 KiB, Python)
import logging
|
|
from typing import Optional, Union
|
|
|
|
from sqlalchemy import Engine, inspect, select
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from .models.config import *
|
|
from .models.scores import *
|
|
from .models.songs import *
|
|
from .singleton import Singleton
|
|
|
|
# Module-level logger; Database.__init__ uses it to warn when an engine is replaced.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Database(metaclass=Singleton):
    """Access point to the arcaea-offline database.

    Holds one ``sqlalchemy.Engine`` and a ``sessionmaker`` bound to it, and
    exposes schema setup (``init`` / ``check_init``) plus simple read queries
    for packs, charts and the b30 value.

    NOTE(review): instance sharing is governed by the project's ``Singleton``
    metaclass (not visible here) — presumably one shared instance; confirm
    against ``.singleton``.
    """

    def __init__(self, engine: Optional[Engine] = None):
        """Attach ``engine`` to this object, or keep the one already attached.

        :param engine: the engine to use. When ``None`` (the default), an
            engine must already have been attached earlier.
        :raises ValueError: if ``engine`` is ``None`` and no engine was attached
            before, or if ``engine`` is not a ``sqlalchemy.Engine``.
        """
        # Guard against __init__ running on an object that already holds an
        # engine (presumably possible through the Singleton metaclass): only
        # create the backing attribute the first time.
        try:
            self.__engine
        except AttributeError:
            self.__engine = None

        if engine is None:
            if isinstance(self.engine, Engine):
                return
            raise ValueError("No sqlalchemy.Engine instance specified before.")

        if not isinstance(engine, Engine):
            raise ValueError(
                f"A sqlalchemy.Engine instance expected, not {repr(engine)}"
            )

        current = self.engine
        # Only warn on an actual replacement, not when the same engine is
        # passed again.
        if isinstance(current, Engine) and current is not engine:
            logger.warning(
                "A sqlalchemy.Engine instance %s has been specified "
                "and will be replaced to %s",
                current,
                engine,
            )
        self.engine = engine

    @property
    def engine(self) -> Engine:
        """The attached engine (``None`` until one has been set)."""
        return self.__engine  # type: ignore

    @engine.setter
    def engine(self, value: Engine):
        """Attach ``value`` and rebuild the bound ``sessionmaker``.

        :raises ValueError: if ``value`` is not a ``sqlalchemy.Engine``.
        """
        if not isinstance(value, Engine):
            raise ValueError("Database.engine only accepts sqlalchemy.Engine")
        self.__engine = value
        # Module-level sqlalchemy.orm.sessionmaker (the property below is a
        # class attribute and is not visible from inside method bodies).
        self.__sessionmaker = sessionmaker(self.__engine)

    @property
    def sessionmaker(self):
        """The ``sessionmaker`` bound to the current engine (set by ``engine``)."""
        return self.__sessionmaker

    # region init

    def init(self, checkfirst: bool = True):
        """Create all tables and views, then ensure a ``version`` property row.

        :param checkfirst: when True, existing tables are left in place, the
            views are dropped and recreated, and the ``version`` row is only
            inserted if it is missing. When False, everything is created
            unconditionally and a ``version`` row is always inserted.
        """
        # create tables & views
        if checkfirst:
            # > https://github.com/kvesteri/sqlalchemy-utils/issues/396
            # > view.create_view() causes DuplicateTableError on Base.metadata.create_all(checkfirst=True)
            # so if `checkfirst` is True, drop these views before creating
            SongsViewBase.metadata.drop_all(self.engine)
            ScoresViewBase.metadata.drop_all(self.engine)

        # Tables first: the views are created after the bases they follow.
        SongsBase.metadata.create_all(self.engine, checkfirst=checkfirst)
        SongsViewBase.metadata.create_all(self.engine)
        ScoresBase.metadata.create_all(self.engine, checkfirst=checkfirst)
        ScoresViewBase.metadata.create_all(self.engine)
        ConfigBase.metadata.create_all(self.engine, checkfirst=checkfirst)

        # insert version property
        with self.sessionmaker() as session:
            stmt = select(Property.value).where(Property.key == "version")
            existing = session.execute(stmt).fetchone()
            if not checkfirst or existing is None:
                session.add(Property(key="version", value="3"))
            session.commit()

    def check_init(self) -> bool:
        """Return True if every expected table and view exists in the database."""
        # check table exists
        expect_tables = [
            *SongsBase.metadata.tables.keys(),
            *ScoresBase.metadata.tables.keys(),
            *ConfigBase.metadata.tables.keys(),
            Chart.__tablename__,
            ScoreCalculated.__tablename__,
            ScoreBest.__tablename__,
            CalculatedPotential.__tablename__,
        ]
        # One inspector for all lookups instead of one per table.
        inspector = inspect(self.engine)
        return all(inspector.has_table(t) for t in expect_tables)

    # endregion

    def version(self) -> Union[int, None]:
        """Return the ``version`` property value as an int, or None if absent."""
        stmt = select(Property).where(Property.key == "version")
        with self.sessionmaker() as session:
            result = session.scalar(stmt)
        return None if result is None else int(result.value)

    # region Pack

    def get_packs(self):
        """Return a list of all ``Pack`` rows."""
        stmt = select(Pack)
        with self.sessionmaker() as session:
            # Materialize while the session (and its connection) is open.
            return list(session.scalars(stmt))

    def get_pack_by_id(self, pack_id: str):
        """Return the ``Pack`` whose ``id`` equals ``pack_id``, or None."""
        stmt = select(Pack).where(Pack.id == pack_id)
        with self.sessionmaker() as session:
            return session.scalar(stmt)

    # endregion

    # region Chart

    def get_charts_by_pack_id(self, pack_id: str):
        """Return a list of ``Chart`` rows whose ``set`` equals ``pack_id``."""
        stmt = select(Chart).where(Chart.set == pack_id)
        with self.sessionmaker() as session:
            return list(session.scalars(stmt))

    def get_charts_by_song_id(self, song_id: str):
        """Return a list of ``Chart`` rows for ``song_id``."""
        stmt = select(Chart).where(Chart.song_id == song_id)
        with self.sessionmaker() as session:
            return list(session.scalars(stmt))

    def get_chart(self, song_id: str, rating_class: int):
        """Return the ``Chart`` matching ``song_id`` and ``rating_class``, or None."""
        stmt = select(Chart).where(
            (Chart.song_id == song_id) & (Chart.rating_class == rating_class)
        )
        with self.sessionmaker() as session:
            return session.scalar(stmt)

    # endregion

    def get_b30(self):
        """Return ``b30`` from the first ``CalculatedPotential`` row, or None."""
        stmt = select(CalculatedPotential.b30).select_from(CalculatedPotential)
        with self.sessionmaker() as session:
            return session.scalar(stmt)
|