From ce715bfccc55a2111a29cd2bf753b487efe78f8b Mon Sep 17 00:00:00 2001 From: 283375 Date: Mon, 20 May 2024 21:21:49 +0800 Subject: [PATCH] refactor: sqlalchemy custom types - Unify `IntEnum` type decorators to single `DbIntEnum` - Add timezone aware `TZDateTime` from sqlalchemy docs --- .../database/models/_custom_types.py | 74 +++++------- tests/db/models/test_custom_types.py | 114 +++++++----------- 2 files changed, 74 insertions(+), 114 deletions(-) diff --git a/src/arcaea_offline/database/models/_custom_types.py b/src/arcaea_offline/database/models/_custom_types.py index 1f3f09b..e42beb5 100644 --- a/src/arcaea_offline/database/models/_custom_types.py +++ b/src/arcaea_offline/database/models/_custom_types.py @@ -1,58 +1,46 @@ -from typing import Optional +from datetime import datetime, timezone +from enum import IntEnum +from typing import Optional, Type -from sqlalchemy import Integer +from sqlalchemy import DateTime, Integer from sqlalchemy.types import TypeDecorator -from arcaea_offline.constants.enums import ( - ArcaeaPlayResultClearType, - ArcaeaPlayResultModifier, - ArcaeaRatingClass, -) - -class DbRatingClass(TypeDecorator): - """sqlalchemy rating_class type decorator""" +class DbIntEnum(TypeDecorator): + """sqlalchemy `TypeDecorator` for `IntEnum`s""" impl = Integer + cache_ok = True - def process_bind_param( - self, value: Optional[ArcaeaRatingClass], dialect - ) -> Optional[int]: + def __init__(self, enum_class: Type[IntEnum]): + super().__init__() + self.enum_class = enum_class + + def process_bind_param(self, value: Optional[IntEnum], dialect) -> Optional[int]: return None if value is None else value.value - def process_result_value( - self, value: Optional[int], dialect - ) -> Optional[ArcaeaRatingClass]: - return None if value is None else ArcaeaRatingClass(value) + def process_result_value(self, value: Optional[int], dialect) -> Optional[IntEnum]: + return None if value is None else self.enum_class(value) -class DbClearType(TypeDecorator): - """sqlalchemy clear_type type decorator""" +class TZDateTime(TypeDecorator): + """ + Store Timezone Aware Timestamps as Timezone Naive UTC - impl = Integer + https://docs.sqlalchemy.org/en/20/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc + """ - def process_bind_param( - self, value: Optional[ArcaeaPlayResultClearType], dialect - ) -> Optional[int]: - return None if value is None else value.value + impl = DateTime + cache_ok = True - def process_result_value( - self, value: Optional[int], dialect - ) -> Optional[ArcaeaPlayResultClearType]: - return None if value is None else ArcaeaPlayResultClearType(value) + def process_bind_param(self, value: Optional[datetime], dialect): + if value is not None: + if not value.tzinfo or value.tzinfo.utcoffset(value) is None: + raise TypeError("tzinfo is required") + value = value.astimezone(timezone.utc).replace(tzinfo=None) + return value - -class DbModifier(TypeDecorator): - """sqlalchemy modifier type decorator""" - - impl = Integer - - def process_bind_param( - self, value: Optional[ArcaeaPlayResultModifier], dialect - ) -> Optional[int]: - return None if value is None else value.value - - def process_result_value( - self, value: Optional[int], dialect - ) -> Optional[ArcaeaPlayResultModifier]: - return None if value is None else ArcaeaPlayResultModifier(value) + def process_result_value(self, value: Optional[datetime], dialect): + if value is not None: + value = value.replace(tzinfo=timezone.utc) + return value diff --git a/tests/db/models/test_custom_types.py b/tests/db/models/test_custom_types.py index a6b3fb8..8b77b0f 100644 --- a/tests/db/models/test_custom_types.py +++ b/tests/db/models/test_custom_types.py @@ -1,95 +1,67 @@ +from datetime import datetime, timedelta, timezone +from enum import IntEnum from typing import Optional from sqlalchemy import text -from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from arcaea_offline.constants.enums import ( - ArcaeaPlayResultClearType, - ArcaeaPlayResultModifier, - ArcaeaRatingClass, -) -from arcaea_offline.database.models._custom_types import ( - DbClearType, - DbModifier, - DbRatingClass, -) +from arcaea_offline.database.models._custom_types import DbIntEnum, TZDateTime -class Base(DeclarativeBase): - pass +class TestIntEnum(IntEnum): + __test__ = False + + ONE = 1 + TWO = 2 + THREE = 3 -class RatingClassTestModel(Base): - __tablename__ = "test_rating_class" +class TestBase(DeclarativeBase): + __test__ = False id: Mapped[int] = mapped_column(primary_key=True) - value: Mapped[Optional[ArcaeaRatingClass]] = mapped_column( - DbRatingClass, nullable=True - ) -class ClearTypeTestModel(Base): - __tablename__ = "test_clear_type" - - id: Mapped[int] = mapped_column(primary_key=True) - value: Mapped[Optional[ArcaeaPlayResultClearType]] = mapped_column( - DbClearType, nullable=True - ) +class IntEnumTestModel(TestBase): + __tablename__ = "test_int_enum" + value: Mapped[Optional[TestIntEnum]] = mapped_column(DbIntEnum(TestIntEnum)) -class ModifierTestModel(Base): - __tablename__ = "test_modifier" - - id: Mapped[int] = mapped_column(primary_key=True) - value: Mapped[Optional[ArcaeaPlayResultModifier]] = mapped_column( - DbModifier, nullable=True - ) +class TZDatetimeTestModel(TestBase): + __tablename__ = "test_tz_datetime" + value: Mapped[Optional[datetime]] = mapped_column(TZDateTime) class TestCustomTypes: - def _common_test_method(self, session: Session, obj: Base, value_in_db): - """ - This method stores the `obj` into the given `session`, - then fetches the raw value of `obj.value` from database, - and asserts that the value is equal to `value_in_db`. - """ - session.add(obj) - session.commit() + def test_int_enum(self, db_session): + def _query_value(_id: int): + return db_session.execute( + text( + f"SELECT value FROM {IntEnumTestModel.__tablename__} WHERE id = {_id}" + ) + ).one()[0] - exec_result = session.execute( - text( - f"SELECT value FROM {obj.__tablename__} WHERE id = {obj.id}" # type: ignore - ) - ).fetchone()[0] + TestBase.metadata.create_all(db_session.bind) - if value_in_db is None: - assert exec_result is value_in_db - else: - assert exec_result == value_in_db + basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO) + null_obj = IntEnumTestModel(id=2, value=None) + db_session.add(basic_obj) + db_session.add(null_obj) + db_session.commit() - def test_rating_class(self, db_session): - Base.metadata.create_all(db_session.bind) + assert _query_value(1) == TestIntEnum.TWO.value + assert _query_value(2) is None - basic_obj = RatingClassTestModel(id=1, value=ArcaeaRatingClass.FUTURE) - self._common_test_method(db_session, basic_obj, 2) + def test_tz_datetime(self, db_session): + TestBase.metadata.create_all(db_session.bind) - null_obj = RatingClassTestModel(id=2, value=None) - self._common_test_method(db_session, null_obj, None) + dt1 = datetime.now(tz=timezone(timedelta(hours=8))) - def test_clear_type(self, db_session): - Base.metadata.create_all(db_session.bind) + basic_obj = TZDatetimeTestModel(id=1, value=dt1) + null_obj = TZDatetimeTestModel(id=2, value=None) + db_session.add(basic_obj) + db_session.add(null_obj) + db_session.commit() - basic_obj = ClearTypeTestModel(id=1, value=ArcaeaPlayResultClearType.TRACK_LOST) - self._common_test_method(db_session, basic_obj, 0) - - null_obj = ClearTypeTestModel(id=2, value=None) - self._common_test_method(db_session, null_obj, None) - - def test_modifier(self, db_session): - Base.metadata.create_all(db_session.bind) - - basic_obj = ModifierTestModel(id=1, value=ArcaeaPlayResultModifier.HARD) - self._common_test_method(db_session, basic_obj, 2) - - null_obj = ModifierTestModel(id=2, value=None) - self._common_test_method(db_session, null_obj, None) + assert basic_obj.value == dt1 + assert null_obj.value is None