refactor: sqlalchemy custom types

- Unify `IntEnum` type decorators to single `DbIntEnum`
- Add timezone aware `TZDateTime` from sqlalchemy docs
This commit is contained in:
283375 2024-05-20 21:21:49 +08:00
parent 0d5e21a90e
commit ce715bfccc
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
2 changed files with 74 additions and 114 deletions

View File

@ -1,58 +1,46 @@
from typing import Optional
from datetime import datetime, timezone
from enum import IntEnum
from typing import Optional, Type
from sqlalchemy import Integer
from sqlalchemy import DateTime, Integer
from sqlalchemy.types import TypeDecorator
from arcaea_offline.constants.enums import (
ArcaeaPlayResultClearType,
ArcaeaPlayResultModifier,
ArcaeaRatingClass,
)
class DbRatingClass(TypeDecorator):
"""sqlalchemy rating_class type decorator"""
class DbIntEnum(TypeDecorator):
"""sqlalchemy `TypeDecorator` for `IntEnum`s"""
impl = Integer
cache_ok = True
def process_bind_param(
self, value: Optional[ArcaeaRatingClass], dialect
) -> Optional[int]:
def __init__(self, enum_class: Type[IntEnum]):
super().__init__()
self.enum_class = enum_class
def process_bind_param(self, value: Optional[IntEnum], dialect) -> Optional[int]:
return None if value is None else value.value
def process_result_value(
self, value: Optional[int], dialect
) -> Optional[ArcaeaRatingClass]:
return None if value is None else ArcaeaRatingClass(value)
def process_result_value(self, value: Optional[int], dialect) -> Optional[IntEnum]:
return None if value is None else self.enum_class(value)
class DbClearType(TypeDecorator):
"""sqlalchemy clear_type type decorator"""
class TZDateTime(TypeDecorator):
"""
Store Timezone Aware Timestamps as Timezone Naive UTC
impl = Integer
https://docs.sqlalchemy.org/en/20/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc
"""
def process_bind_param(
self, value: Optional[ArcaeaPlayResultClearType], dialect
) -> Optional[int]:
return None if value is None else value.value
impl = DateTime
cache_ok = True
def process_result_value(
self, value: Optional[int], dialect
) -> Optional[ArcaeaPlayResultClearType]:
return None if value is None else ArcaeaPlayResultClearType(value)
def process_bind_param(self, value: Optional[datetime], dialect):
if value is not None:
if not value.tzinfo or value.tzinfo.utcoffset(value) is None:
raise TypeError("tzinfo is required")
value = value.astimezone(timezone.utc).replace(tzinfo=None)
return value
class DbModifier(TypeDecorator):
"""sqlalchemy modifier type decorator"""
impl = Integer
def process_bind_param(
self, value: Optional[ArcaeaPlayResultModifier], dialect
) -> Optional[int]:
return None if value is None else value.value
def process_result_value(
self, value: Optional[int], dialect
) -> Optional[ArcaeaPlayResultModifier]:
return None if value is None else ArcaeaPlayResultModifier(value)
def process_result_value(self, value: Optional[datetime], dialect):
if value is not None:
value = value.replace(tzinfo=timezone.utc)
return value

View File

@ -1,95 +1,67 @@
from datetime import datetime, timedelta, timezone
from enum import IntEnum
from typing import Optional
from sqlalchemy import text
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from arcaea_offline.constants.enums import (
ArcaeaPlayResultClearType,
ArcaeaPlayResultModifier,
ArcaeaRatingClass,
)
from arcaea_offline.database.models._custom_types import (
DbClearType,
DbModifier,
DbRatingClass,
)
from arcaea_offline.database.models._custom_types import DbIntEnum, TZDateTime
class Base(DeclarativeBase):
pass
class TestIntEnum(IntEnum):
__test__ = False
ONE = 1
TWO = 2
THREE = 3
class RatingClassTestModel(Base):
__tablename__ = "test_rating_class"
class TestBase(DeclarativeBase):
__test__ = False
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaRatingClass]] = mapped_column(
DbRatingClass, nullable=True
)
class ClearTypeTestModel(Base):
__tablename__ = "test_clear_type"
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaPlayResultClearType]] = mapped_column(
DbClearType, nullable=True
)
class IntEnumTestModel(TestBase):
__tablename__ = "test_int_enum"
value: Mapped[Optional[TestIntEnum]] = mapped_column(DbIntEnum(TestIntEnum))
class ModifierTestModel(Base):
__tablename__ = "test_modifier"
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaPlayResultModifier]] = mapped_column(
DbModifier, nullable=True
)
class TZDatetimeTestModel(TestBase):
__tablename__ = "test_tz_datetime"
value: Mapped[Optional[datetime]] = mapped_column(TZDateTime)
class TestCustomTypes:
def _common_test_method(self, session: Session, obj: Base, value_in_db):
"""
This method stores the `obj` into the given `session`,
then fetches the raw value of `obj.value` from database,
and asserts that the value is equal to `value_in_db`.
"""
session.add(obj)
session.commit()
def test_int_enum(self, db_session):
def _query_value(_id: int):
return db_session.execute(
text(
f"SELECT value FROM {IntEnumTestModel.__tablename__} WHERE id = {_id}"
)
).one()[0]
exec_result = session.execute(
text(
f"SELECT value FROM {obj.__tablename__} WHERE id = {obj.id}" # type: ignore
)
).fetchone()[0]
TestBase.metadata.create_all(db_session.bind)
if value_in_db is None:
assert exec_result is value_in_db
else:
assert exec_result == value_in_db
basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO)
null_obj = IntEnumTestModel(id=2, value=None)
db_session.add(basic_obj)
db_session.add(null_obj)
db_session.commit()
def test_rating_class(self, db_session):
Base.metadata.create_all(db_session.bind)
assert _query_value(1) == TestIntEnum.TWO.value
assert _query_value(2) is None
basic_obj = RatingClassTestModel(id=1, value=ArcaeaRatingClass.FUTURE)
self._common_test_method(db_session, basic_obj, 2)
def test_tz_datetime(self, db_session):
TestBase.metadata.create_all(db_session.bind)
null_obj = RatingClassTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)
dt1 = datetime.now(tz=timezone(timedelta(hours=8)))
def test_clear_type(self, db_session):
Base.metadata.create_all(db_session.bind)
basic_obj = TZDatetimeTestModel(id=1, value=dt1)
null_obj = TZDatetimeTestModel(id=2, value=None)
db_session.add(basic_obj)
db_session.add(null_obj)
db_session.commit()
basic_obj = ClearTypeTestModel(id=1, value=ArcaeaPlayResultClearType.TRACK_LOST)
self._common_test_method(db_session, basic_obj, 0)
null_obj = ClearTypeTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)
def test_modifier(self, db_session):
Base.metadata.create_all(db_session.bind)
basic_obj = ModifierTestModel(id=1, value=ArcaeaPlayResultModifier.HARD)
self._common_test_method(db_session, basic_obj, 2)
null_obj = ModifierTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)
assert basic_obj.value == dt1
assert null_obj.value is None