refactor: sqlalchemy custom types

- Unify `IntEnum` type decorators to single `DbIntEnum`
- Add timezone aware `TZDateTime` from sqlalchemy docs
This commit is contained in:
2024-05-20 21:21:49 +08:00
parent 0d5e21a90e
commit ce715bfccc
2 changed files with 74 additions and 114 deletions

View File

@ -1,95 +1,67 @@
from datetime import datetime, timedelta, timezone
from enum import IntEnum
from typing import Optional
from sqlalchemy import text
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from arcaea_offline.constants.enums import (
ArcaeaPlayResultClearType,
ArcaeaPlayResultModifier,
ArcaeaRatingClass,
)
from arcaea_offline.database.models._custom_types import (
DbClearType,
DbModifier,
DbRatingClass,
)
from arcaea_offline.database.models._custom_types import DbIntEnum, TZDateTime
class Base(DeclarativeBase):
pass
class TestIntEnum(IntEnum):
__test__ = False
ONE = 1
TWO = 2
THREE = 3
class RatingClassTestModel(Base):
__tablename__ = "test_rating_class"
class TestBase(DeclarativeBase):
__test__ = False
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaRatingClass]] = mapped_column(
DbRatingClass, nullable=True
)
class ClearTypeTestModel(Base):
__tablename__ = "test_clear_type"
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaPlayResultClearType]] = mapped_column(
DbClearType, nullable=True
)
class IntEnumTestModel(TestBase):
__tablename__ = "test_int_enum"
value: Mapped[Optional[TestIntEnum]] = mapped_column(DbIntEnum(TestIntEnum))
class ModifierTestModel(Base):
__tablename__ = "test_modifier"
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaPlayResultModifier]] = mapped_column(
DbModifier, nullable=True
)
class TZDatetimeTestModel(TestBase):
__tablename__ = "test_tz_datetime"
value: Mapped[Optional[datetime]] = mapped_column(TZDateTime)
class TestCustomTypes:
def _common_test_method(self, session: Session, obj: Base, value_in_db):
"""
This method stores the `obj` into the given `session`,
then fetches the raw value of `obj.value` from database,
and asserts that the value is equal to `value_in_db`.
"""
session.add(obj)
session.commit()
def test_int_enum(self, db_session):
def _query_value(_id: int):
return db_session.execute(
text(
f"SELECT value FROM {IntEnumTestModel.__tablename__} WHERE id = {_id}"
)
).one()[0]
exec_result = session.execute(
text(
f"SELECT value FROM {obj.__tablename__} WHERE id = {obj.id}" # type: ignore
)
).fetchone()[0]
TestBase.metadata.create_all(db_session.bind)
if value_in_db is None:
assert exec_result is value_in_db
else:
assert exec_result == value_in_db
basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO)
null_obj = IntEnumTestModel(id=2, value=None)
db_session.add(basic_obj)
db_session.add(null_obj)
db_session.commit()
def test_rating_class(self, db_session):
Base.metadata.create_all(db_session.bind)
assert _query_value(1) == TestIntEnum.TWO.value
assert _query_value(2) is None
basic_obj = RatingClassTestModel(id=1, value=ArcaeaRatingClass.FUTURE)
self._common_test_method(db_session, basic_obj, 2)
def test_tz_datetime(self, db_session):
TestBase.metadata.create_all(db_session.bind)
null_obj = RatingClassTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)
dt1 = datetime.now(tz=timezone(timedelta(hours=8)))
def test_clear_type(self, db_session):
Base.metadata.create_all(db_session.bind)
basic_obj = TZDatetimeTestModel(id=1, value=dt1)
null_obj = TZDatetimeTestModel(id=2, value=None)
db_session.add(basic_obj)
db_session.add(null_obj)
db_session.commit()
basic_obj = ClearTypeTestModel(id=1, value=ArcaeaPlayResultClearType.TRACK_LOST)
self._common_test_method(db_session, basic_obj, 0)
null_obj = ClearTypeTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)
def test_modifier(self, db_session):
Base.metadata.create_all(db_session.bind)
basic_obj = ModifierTestModel(id=1, value=ArcaeaPlayResultModifier.HARD)
self._common_test_method(db_session, basic_obj, 2)
null_obj = ModifierTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)
assert basic_obj.value == dt1
assert null_obj.value is None