From 003e1a7289526f4727af71175955f7749cf8cf51 Mon Sep 17 00:00:00 2001 From: 283375 Date: Sat, 26 Aug 2023 16:52:19 +0800 Subject: [PATCH] refactor: `Database` class --- src/arcaea_offline/database.py | 56 +++++++++++++++++++++-------- src/arcaea_offline/models/common.py | 2 +- src/arcaea_offline/models/scores.py | 3 ++ src/arcaea_offline/singleton.py | 12 +++++++ 4 files changed, 57 insertions(+), 16 deletions(-) create mode 100644 src/arcaea_offline/singleton.py diff --git a/src/arcaea_offline/database.py b/src/arcaea_offline/database.py index 9a0e20f..b366e1a 100644 --- a/src/arcaea_offline/database.py +++ b/src/arcaea_offline/database.py @@ -1,22 +1,48 @@ -from sqlalchemy import Engine -from sqlalchemy.orm import Session +from sqlalchemy import Engine, select +from sqlalchemy.orm import sessionmaker from .models.common import * from .models.scores import * from .models.songs import * +from .singleton import Singleton -def init(engine: Engine, checkfirst: bool = True): - # sqlalchemy-utils issue #396 - # view.create_view() causes DuplicateTableError on Base.metadata.create_all(checkfirst=True) - # https://github.com/kvesteri/sqlalchemy-utils/issues/396 - if checkfirst: - ScoresViewBase.metadata.drop_all(engine) +class Database(metaclass=Singleton): + def __init__(self, engine: Engine): + self.engine = engine - SongsBase.metadata.create_all(engine, checkfirst=checkfirst) - ScoresBase.metadata.create_all(engine, checkfirst=checkfirst) - ScoresViewBase.metadata.create_all(engine) - CommonBase.metadata.create_all(engine, checkfirst=checkfirst) - with Session(engine) as session: - session.add(Property(id="version", value="2")) - session.commit() + @property + def engine(self): + return self.__engine + + @engine.setter + def engine(self, value: Engine): + if not isinstance(value, Engine): + raise ValueError("Database.engine only accepts sqlalchemy.Engine") + self.__engine = value + self.__sessionmaker = sessionmaker(self.__engine) + + @property + def sessionmaker(self): + return self.__sessionmaker + + def init(self, checkfirst: bool = True): + # create tables & views + if checkfirst: + # > https://github.com/kvesteri/sqlalchemy-utils/issues/396 + # > view.create_view() causes DuplicateTableError on Base.metadata.create_all(checkfirst=True) + # so if `checkfirst` is True, drop these views before creating + ScoresViewBase.metadata.drop_all(self.engine) + + SongsBase.metadata.create_all(self.engine, checkfirst=checkfirst) + ScoresBase.metadata.create_all(self.engine, checkfirst=checkfirst) + ScoresViewBase.metadata.create_all(self.engine) + CommonBase.metadata.create_all(self.engine, checkfirst=checkfirst) + + # insert version property + with self.sessionmaker() as session: + stmt = select(Property.value).where(Property.key == "version") + result = session.execute(stmt).fetchone() + if not checkfirst or not result: + session.add(Property(key="version", value="2")) + session.commit() diff --git a/src/arcaea_offline/models/common.py b/src/arcaea_offline/models/common.py index 435c075..85fba75 100644 --- a/src/arcaea_offline/models/common.py +++ b/src/arcaea_offline/models/common.py @@ -14,5 +14,5 @@ class CommonBase(DeclarativeBase): class Property(CommonBase): __tablename__ = "property" - id: Mapped[str] = mapped_column(TEXT(), primary_key=True) + key: Mapped[str] = mapped_column(TEXT(), primary_key=True) value: Mapped[str] = mapped_column(TEXT()) diff --git a/src/arcaea_offline/models/scores.py b/src/arcaea_offline/models/scores.py index 10c6b72..6ac279a 100644 --- a/src/arcaea_offline/models/scores.py +++ b/src/arcaea_offline/models/scores.py @@ -104,6 +104,7 @@ class Calculated(ScoresViewBase): & (Chart.rating_class == Score.rating_class), ), metadata=ScoresViewBase.metadata, + cascade_on_drop=False, ) @@ -131,6 +132,7 @@ class Best(ScoresViewBase): .group_by(Calculated.song_id, Calculated.rating_class) .order_by(Calculated.potential.desc()), metadata=ScoresViewBase.metadata, + cascade_on_drop=False, ) @@ -147,4 +149,5 @@ class CalculatedPotential(ScoresViewBase): name="calculated_potential", selectable=select(func.avg(_select_bests_subquery.c.b30_sum).label("b30")), metadata=ScoresViewBase.metadata, + cascade_on_drop=False, ) diff --git a/src/arcaea_offline/singleton.py b/src/arcaea_offline/singleton.py new file mode 100644 index 0000000..6776678 --- /dev/null +++ b/src/arcaea_offline/singleton.py @@ -0,0 +1,12 @@ +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class Singleton(type, Generic[T]): + _instance = None + + def __call__(cls, *args, **kwargs) -> T: + if cls._instance is None: + cls._instance = super().__call__(*args, **kwargs) + return cls._instance