refactor: sqlalchemy custom types

- Unify `IntEnum` type decorators to single `DbIntEnum`
- Add timezone aware `TZDateTime` from sqlalchemy docs
This commit is contained in:
283375 2024-05-20 21:21:49 +08:00
parent 0d5e21a90e
commit ce715bfccc
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
2 changed files with 74 additions and 114 deletions

View File

@ -1,58 +1,46 @@
from typing import Optional from datetime import datetime, timezone
from enum import IntEnum
from typing import Optional, Type
from sqlalchemy import Integer from sqlalchemy import DateTime, Integer
from sqlalchemy.types import TypeDecorator from sqlalchemy.types import TypeDecorator
from arcaea_offline.constants.enums import (
ArcaeaPlayResultClearType,
ArcaeaPlayResultModifier,
ArcaeaRatingClass,
)
class DbIntEnum(TypeDecorator):
class DbRatingClass(TypeDecorator): """sqlalchemy `TypeDecorator` for `IntEnum`s"""
"""sqlalchemy rating_class type decorator"""
impl = Integer impl = Integer
cache_ok = True
def process_bind_param( def __init__(self, enum_class: Type[IntEnum]):
self, value: Optional[ArcaeaRatingClass], dialect super().__init__()
) -> Optional[int]: self.enum_class = enum_class
def process_bind_param(self, value: Optional[IntEnum], dialect) -> Optional[int]:
return None if value is None else value.value return None if value is None else value.value
def process_result_value( def process_result_value(self, value: Optional[int], dialect) -> Optional[IntEnum]:
self, value: Optional[int], dialect return None if value is None else self.enum_class(value)
) -> Optional[ArcaeaRatingClass]:
return None if value is None else ArcaeaRatingClass(value)
class DbClearType(TypeDecorator): class TZDateTime(TypeDecorator):
"""sqlalchemy clear_type type decorator""" """
Store Timezone Aware Timestamps as Timezone Naive UTC
impl = Integer https://docs.sqlalchemy.org/en/20/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc
"""
def process_bind_param( impl = DateTime
self, value: Optional[ArcaeaPlayResultClearType], dialect cache_ok = True
) -> Optional[int]:
return None if value is None else value.value
def process_result_value( def process_bind_param(self, value: Optional[datetime], dialect):
self, value: Optional[int], dialect if value is not None:
) -> Optional[ArcaeaPlayResultClearType]: if not value.tzinfo or value.tzinfo.utcoffset(value) is None:
return None if value is None else ArcaeaPlayResultClearType(value) raise TypeError("tzinfo is required")
value = value.astimezone(timezone.utc).replace(tzinfo=None)
return value
def process_result_value(self, value: Optional[datetime], dialect):
class DbModifier(TypeDecorator): if value is not None:
"""sqlalchemy modifier type decorator""" value = value.replace(tzinfo=timezone.utc)
return value
impl = Integer
def process_bind_param(
self, value: Optional[ArcaeaPlayResultModifier], dialect
) -> Optional[int]:
return None if value is None else value.value
def process_result_value(
self, value: Optional[int], dialect
) -> Optional[ArcaeaPlayResultModifier]:
return None if value is None else ArcaeaPlayResultModifier(value)

View File

@ -1,95 +1,67 @@
from datetime import datetime, timedelta, timezone
from enum import IntEnum
from typing import Optional from typing import Optional
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from arcaea_offline.constants.enums import ( from arcaea_offline.database.models._custom_types import DbIntEnum, TZDateTime
ArcaeaPlayResultClearType,
ArcaeaPlayResultModifier,
ArcaeaRatingClass,
)
from arcaea_offline.database.models._custom_types import (
DbClearType,
DbModifier,
DbRatingClass,
)
class Base(DeclarativeBase): class TestIntEnum(IntEnum):
pass __test__ = False
ONE = 1
TWO = 2
THREE = 3
class RatingClassTestModel(Base): class TestBase(DeclarativeBase):
__tablename__ = "test_rating_class" __test__ = False
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaRatingClass]] = mapped_column(
DbRatingClass, nullable=True
)
class ClearTypeTestModel(Base): class IntEnumTestModel(TestBase):
__tablename__ = "test_clear_type" __tablename__ = "test_int_enum"
value: Mapped[Optional[TestIntEnum]] = mapped_column(DbIntEnum(TestIntEnum))
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaPlayResultClearType]] = mapped_column(
DbClearType, nullable=True
)
class ModifierTestModel(Base): class TZDatetimeTestModel(TestBase):
__tablename__ = "test_modifier" __tablename__ = "test_tz_datetime"
value: Mapped[Optional[datetime]] = mapped_column(TZDateTime)
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaPlayResultModifier]] = mapped_column(
DbModifier, nullable=True
)
class TestCustomTypes: class TestCustomTypes:
def _common_test_method(self, session: Session, obj: Base, value_in_db): def test_int_enum(self, db_session):
""" def _query_value(_id: int):
This method stores the `obj` into the given `session`, return db_session.execute(
then fetches the raw value of `obj.value` from database, text(
and asserts that the value is equal to `value_in_db`. f"SELECT value FROM {IntEnumTestModel.__tablename__} WHERE id = {_id}"
""" )
session.add(obj) ).one()[0]
session.commit()
exec_result = session.execute( TestBase.metadata.create_all(db_session.bind)
text(
f"SELECT value FROM {obj.__tablename__} WHERE id = {obj.id}" # type: ignore
)
).fetchone()[0]
if value_in_db is None: basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO)
assert exec_result is value_in_db null_obj = IntEnumTestModel(id=2, value=None)
else: db_session.add(basic_obj)
assert exec_result == value_in_db db_session.add(null_obj)
db_session.commit()
def test_rating_class(self, db_session): assert _query_value(1) == TestIntEnum.TWO.value
Base.metadata.create_all(db_session.bind) assert _query_value(2) is None
basic_obj = RatingClassTestModel(id=1, value=ArcaeaRatingClass.FUTURE) def test_tz_datetime(self, db_session):
self._common_test_method(db_session, basic_obj, 2) TestBase.metadata.create_all(db_session.bind)
null_obj = RatingClassTestModel(id=2, value=None) dt1 = datetime.now(tz=timezone(timedelta(hours=8)))
self._common_test_method(db_session, null_obj, None)
def test_clear_type(self, db_session): basic_obj = TZDatetimeTestModel(id=1, value=dt1)
Base.metadata.create_all(db_session.bind) null_obj = TZDatetimeTestModel(id=2, value=None)
db_session.add(basic_obj)
db_session.add(null_obj)
db_session.commit()
basic_obj = ClearTypeTestModel(id=1, value=ArcaeaPlayResultClearType.TRACK_LOST) assert basic_obj.value == dt1
self._common_test_method(db_session, basic_obj, 0) assert null_obj.value is None
null_obj = ClearTypeTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)
def test_modifier(self, db_session):
Base.metadata.create_all(db_session.bind)
basic_obj = ModifierTestModel(id=1, value=ArcaeaPlayResultModifier.HARD)
self._common_test_method(db_session, basic_obj, 2)
null_obj = ModifierTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)