146 lines
4.8 KiB
Python

import logging
from typing import Optional, Union
from sqlalchemy import Engine, inspect, select
from sqlalchemy.orm import sessionmaker
from .models.config import *
from .models.scores import *
from .models.songs import *
from .singleton import Singleton
logger = logging.getLogger(__name__)
class Database(metaclass=Singleton):
def __init__(self, engine: Optional[Engine]):
try:
self.__engine
except AttributeError:
self.__engine = None
if engine is None:
if isinstance(self.engine, Engine):
return
raise ValueError("No sqlalchemy.Engine instance specified before.")
elif isinstance(engine, Engine):
if isinstance(self.engine, Engine):
logger.warning(
f"A sqlalchemy.Engine instance {self.engine} has been specified "
f"and will be replaced to {engine}"
)
self.engine = engine
else:
raise ValueError(
f"A sqlalchemy.Engine instance expected, not {repr(engine)}"
)
@property
def engine(self) -> Engine:
return self.__engine # type: ignore
@engine.setter
def engine(self, value: Engine):
if not isinstance(value, Engine):
raise ValueError("Database.engine only accepts sqlalchemy.Engine")
self.__engine = value
self.__sessionmaker = sessionmaker(self.__engine)
@property
def sessionmaker(self):
return self.__sessionmaker
# region init
def init(self, checkfirst: bool = True):
# create tables & views
if checkfirst:
# > https://github.com/kvesteri/sqlalchemy-utils/issues/396
# > view.create_view() causes DuplicateTableError on Base.metadata.create_all(checkfirst=True)
# so if `checkfirst` is True, drop these views before creating
SongsViewBase.metadata.drop_all(self.engine)
ScoresViewBase.metadata.drop_all(self.engine)
SongsBase.metadata.create_all(self.engine, checkfirst=checkfirst)
SongsViewBase.metadata.create_all(self.engine)
ScoresBase.metadata.create_all(self.engine, checkfirst=checkfirst)
ScoresViewBase.metadata.create_all(self.engine)
ConfigBase.metadata.create_all(self.engine, checkfirst=checkfirst)
# insert version property
with self.sessionmaker() as session:
stmt = select(Property.value).where(Property.key == "version")
result = session.execute(stmt).fetchone()
if not checkfirst or not result:
session.add(Property(key="version", value="3"))
session.commit()
def check_init(self) -> bool:
# check table exists
expect_tables = (
list(SongsBase.metadata.tables.keys())
+ list(ScoresBase.metadata.tables.keys())
+ list(ConfigBase.metadata.tables.keys())
+ [
Chart.__tablename__,
ScoreCalculated.__tablename__,
ScoreBest.__tablename__,
CalculatedPotential.__tablename__,
]
)
return all(inspect(self.engine).has_table(t) for t in expect_tables)
# endregion
def version(self) -> Union[int, None]:
stmt = select(Property).where(Property.key == "version")
with self.sessionmaker() as session:
result = session.scalar(stmt)
return None if result is None else int(result.value)
# region Pack
def get_packs(self):
stmt = select(Pack)
with self.sessionmaker() as session:
results = session.scalars(stmt)
return list(results)
def get_pack_by_id(self, pack_id: str):
stmt = select(Pack).where(Pack.id == pack_id)
with self.sessionmaker() as session:
result = session.scalar(stmt)
return result
# endregion
# region Chart
def get_charts_by_pack_id(self, pack_id: str):
stmt = select(Chart).where(Chart.set == pack_id)
with self.sessionmaker() as session:
results = session.scalars(stmt)
return list(results)
def get_charts_by_song_id(self, song_id: str):
stmt = select(Chart).where(Chart.song_id == song_id)
with self.sessionmaker() as session:
results = session.scalars(stmt)
return list(results)
def get_chart(self, song_id: str, rating_class: int):
stmt = select(Chart).where(
(Chart.song_id == song_id) & (Chart.rating_class == rating_class)
)
with self.sessionmaker() as session:
result = session.scalar(stmt)
return result
# endregion
def get_b30(self):
stmt = select(CalculatedPotential.b30).select_from(CalculatedPotential)
with self.sessionmaker() as session:
result = session.scalar(stmt)
return result