From e295e58388480b0cdba48629a3352bd981cc369a Mon Sep 17 00:00:00 2001 From: 283375 Date: Sat, 13 Apr 2024 22:56:15 +0800 Subject: [PATCH] feat: sqlalchemy `TypeDecorator`s for arcaea enums --- .../database/models/_custom_types.py | 58 +++++++++++ tests/db/models/test_custom_types.py | 95 +++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 src/arcaea_offline/database/models/_custom_types.py create mode 100644 tests/db/models/test_custom_types.py diff --git a/src/arcaea_offline/database/models/_custom_types.py b/src/arcaea_offline/database/models/_custom_types.py new file mode 100644 index 0000000..1f3f09b --- /dev/null +++ b/src/arcaea_offline/database/models/_custom_types.py @@ -0,0 +1,58 @@ +from typing import Optional + +from sqlalchemy import Integer +from sqlalchemy.types import TypeDecorator + +from arcaea_offline.constants.enums import ( + ArcaeaPlayResultClearType, + ArcaeaPlayResultModifier, + ArcaeaRatingClass, +) + + +class DbRatingClass(TypeDecorator): + """sqlalchemy rating_class type decorator""" + + impl = Integer + + def process_bind_param( + self, value: Optional[ArcaeaRatingClass], dialect + ) -> Optional[int]: + return None if value is None else value.value + + def process_result_value( + self, value: Optional[int], dialect + ) -> Optional[ArcaeaRatingClass]: + return None if value is None else ArcaeaRatingClass(value) + + +class DbClearType(TypeDecorator): + """sqlalchemy clear_type type decorator""" + + impl = Integer + + def process_bind_param( + self, value: Optional[ArcaeaPlayResultClearType], dialect + ) -> Optional[int]: + return None if value is None else value.value + + def process_result_value( + self, value: Optional[int], dialect + ) -> Optional[ArcaeaPlayResultClearType]: + return None if value is None else ArcaeaPlayResultClearType(value) + + +class DbModifier(TypeDecorator): + """sqlalchemy modifier type decorator""" + + impl = Integer + + def process_bind_param( + self, value: Optional[ArcaeaPlayResultModifier], dialect + ) -> Optional[int]: + return None if value is None else value.value + + def process_result_value( + self, value: Optional[int], dialect + ) -> Optional[ArcaeaPlayResultModifier]: + return None if value is None else ArcaeaPlayResultModifier(value) diff --git a/tests/db/models/test_custom_types.py b/tests/db/models/test_custom_types.py new file mode 100644 index 0000000..a6808e4 --- /dev/null +++ b/tests/db/models/test_custom_types.py @@ -0,0 +1,95 @@ +from typing import Optional + +from sqlalchemy import text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +from arcaea_offline.constants.enums import ( + ArcaeaPlayResultClearType, + ArcaeaPlayResultModifier, + ArcaeaRatingClass, +) +from arcaea_offline.database.models._custom_types import ( + DbClearType, + DbModifier, + DbRatingClass, +) + + +class Base(DeclarativeBase): + pass + + +class RatingClassTestModel(Base): + __tablename__ = "test_rating_class" + + id: Mapped[int] = mapped_column(primary_key=True) + value: Mapped[Optional[ArcaeaRatingClass]] = mapped_column( + DbRatingClass, nullable=True + ) + + +class ClearTypeTestModel(Base): + __tablename__ = "test_clear_type" + + id: Mapped[int] = mapped_column(primary_key=True) + value: Mapped[Optional[ArcaeaPlayResultClearType]] = mapped_column( + DbClearType, nullable=True + ) + + +class ModifierTestModel(Base): + __tablename__ = "test_modifier" + + id: Mapped[int] = mapped_column(primary_key=True) + value: Mapped[Optional[ArcaeaPlayResultModifier]] = mapped_column( + DbModifier, nullable=True + ) + + +class TestCustomTypes: + def _common_test_method(self, db_session, obj: Base, value_in_db): + """ + This method stores the `obj` into the given `db_session`, + then fetches the raw value of `obj.value` from database, + and asserts that the value is equal to `value_in_db`. + """ + db_session.add(obj) + db_session.commit() + + exec_result = db_session.execute( + text( + f"SELECT value FROM {obj.__tablename__} WHERE id = {obj.id}" # type: ignore + ) + ).fetchone()[0] + + if value_in_db is None: + assert exec_result is value_in_db + else: + assert exec_result == value_in_db + + def test_rating_class(self, db_session): + Base.metadata.create_all(db_session.bind) + + basic_obj = RatingClassTestModel(id=1, value=ArcaeaRatingClass.FUTURE) + self._common_test_method(db_session, basic_obj, 2) + + null_obj = RatingClassTestModel(id=2, value=None) + self._common_test_method(db_session, null_obj, None) + + def test_clear_type(self, db_session): + Base.metadata.create_all(db_session.bind) + + basic_obj = ClearTypeTestModel(id=1, value=ArcaeaPlayResultClearType.TRACK_LOST) + self._common_test_method(db_session, basic_obj, 0) + + null_obj = ClearTypeTestModel(id=2, value=None) + self._common_test_method(db_session, null_obj, None) + + def test_modifier(self, db_session): + Base.metadata.create_all(db_session.bind) + + basic_obj = ModifierTestModel(id=1, value=ArcaeaPlayResultModifier.HARD) + self._common_test_method(db_session, basic_obj, 2) + + null_obj = ModifierTestModel(id=2, value=None) + self._common_test_method(db_session, null_obj, None)