mirror of
https://github.com/283375/arcaea-offline.git
synced 2025-04-19 06:00:18 +00:00
refactor: sqlalchemy custom types
- Unify `IntEnum` type decorators to single `DbIntEnum` - Add timezone aware `TZDateTime` from sqlalchemy docs
This commit is contained in:
parent
0d5e21a90e
commit
ce715bfccc
@ -1,58 +1,46 @@
|
|||||||
from typing import Optional
|
from datetime import datetime, timezone
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
from sqlalchemy import Integer
|
from sqlalchemy import DateTime, Integer
|
||||||
from sqlalchemy.types import TypeDecorator
|
from sqlalchemy.types import TypeDecorator
|
||||||
|
|
||||||
from arcaea_offline.constants.enums import (
|
|
||||||
ArcaeaPlayResultClearType,
|
|
||||||
ArcaeaPlayResultModifier,
|
|
||||||
ArcaeaRatingClass,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
class DbIntEnum(TypeDecorator):
|
||||||
class DbRatingClass(TypeDecorator):
|
"""sqlalchemy `TypeDecorator` for `IntEnum`s"""
|
||||||
"""sqlalchemy rating_class type decorator"""
|
|
||||||
|
|
||||||
impl = Integer
|
impl = Integer
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
def process_bind_param(
|
def __init__(self, enum_class: Type[IntEnum]):
|
||||||
self, value: Optional[ArcaeaRatingClass], dialect
|
super().__init__()
|
||||||
) -> Optional[int]:
|
self.enum_class = enum_class
|
||||||
|
|
||||||
|
def process_bind_param(self, value: Optional[IntEnum], dialect) -> Optional[int]:
|
||||||
return None if value is None else value.value
|
return None if value is None else value.value
|
||||||
|
|
||||||
def process_result_value(
|
def process_result_value(self, value: Optional[int], dialect) -> Optional[IntEnum]:
|
||||||
self, value: Optional[int], dialect
|
return None if value is None else self.enum_class(value)
|
||||||
) -> Optional[ArcaeaRatingClass]:
|
|
||||||
return None if value is None else ArcaeaRatingClass(value)
|
|
||||||
|
|
||||||
|
|
||||||
class DbClearType(TypeDecorator):
|
class TZDateTime(TypeDecorator):
|
||||||
"""sqlalchemy clear_type type decorator"""
|
"""
|
||||||
|
Store Timezone Aware Timestamps as Timezone Naive UTC
|
||||||
|
|
||||||
impl = Integer
|
https://docs.sqlalchemy.org/en/20/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc
|
||||||
|
"""
|
||||||
|
|
||||||
def process_bind_param(
|
impl = DateTime
|
||||||
self, value: Optional[ArcaeaPlayResultClearType], dialect
|
cache_ok = True
|
||||||
) -> Optional[int]:
|
|
||||||
return None if value is None else value.value
|
|
||||||
|
|
||||||
def process_result_value(
|
def process_bind_param(self, value: Optional[datetime], dialect):
|
||||||
self, value: Optional[int], dialect
|
if value is not None:
|
||||||
) -> Optional[ArcaeaPlayResultClearType]:
|
if not value.tzinfo or value.tzinfo.utcoffset(value) is None:
|
||||||
return None if value is None else ArcaeaPlayResultClearType(value)
|
raise TypeError("tzinfo is required")
|
||||||
|
value = value.astimezone(timezone.utc).replace(tzinfo=None)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def process_result_value(self, value: Optional[datetime], dialect):
|
||||||
class DbModifier(TypeDecorator):
|
if value is not None:
|
||||||
"""sqlalchemy modifier type decorator"""
|
value = value.replace(tzinfo=timezone.utc)
|
||||||
|
return value
|
||||||
impl = Integer
|
|
||||||
|
|
||||||
def process_bind_param(
|
|
||||||
self, value: Optional[ArcaeaPlayResultModifier], dialect
|
|
||||||
) -> Optional[int]:
|
|
||||||
return None if value is None else value.value
|
|
||||||
|
|
||||||
def process_result_value(
|
|
||||||
self, value: Optional[int], dialect
|
|
||||||
) -> Optional[ArcaeaPlayResultModifier]:
|
|
||||||
return None if value is None else ArcaeaPlayResultModifier(value)
|
|
||||||
|
@ -1,95 +1,67 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import IntEnum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
from arcaea_offline.constants.enums import (
|
from arcaea_offline.database.models._custom_types import DbIntEnum, TZDateTime
|
||||||
ArcaeaPlayResultClearType,
|
|
||||||
ArcaeaPlayResultModifier,
|
|
||||||
ArcaeaRatingClass,
|
|
||||||
)
|
|
||||||
from arcaea_offline.database.models._custom_types import (
|
|
||||||
DbClearType,
|
|
||||||
DbModifier,
|
|
||||||
DbRatingClass,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class TestIntEnum(IntEnum):
|
||||||
pass
|
__test__ = False
|
||||||
|
|
||||||
|
ONE = 1
|
||||||
|
TWO = 2
|
||||||
|
THREE = 3
|
||||||
|
|
||||||
|
|
||||||
class RatingClassTestModel(Base):
|
class TestBase(DeclarativeBase):
|
||||||
__tablename__ = "test_rating_class"
|
__test__ = False
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
value: Mapped[Optional[ArcaeaRatingClass]] = mapped_column(
|
|
||||||
DbRatingClass, nullable=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClearTypeTestModel(Base):
|
class IntEnumTestModel(TestBase):
|
||||||
__tablename__ = "test_clear_type"
|
__tablename__ = "test_int_enum"
|
||||||
|
value: Mapped[Optional[TestIntEnum]] = mapped_column(DbIntEnum(TestIntEnum))
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
|
||||||
value: Mapped[Optional[ArcaeaPlayResultClearType]] = mapped_column(
|
|
||||||
DbClearType, nullable=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModifierTestModel(Base):
|
class TZDatetimeTestModel(TestBase):
|
||||||
__tablename__ = "test_modifier"
|
__tablename__ = "test_tz_datetime"
|
||||||
|
value: Mapped[Optional[datetime]] = mapped_column(TZDateTime)
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
|
||||||
value: Mapped[Optional[ArcaeaPlayResultModifier]] = mapped_column(
|
|
||||||
DbModifier, nullable=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCustomTypes:
|
class TestCustomTypes:
|
||||||
def _common_test_method(self, session: Session, obj: Base, value_in_db):
|
def test_int_enum(self, db_session):
|
||||||
"""
|
def _query_value(_id: int):
|
||||||
This method stores the `obj` into the given `session`,
|
return db_session.execute(
|
||||||
then fetches the raw value of `obj.value` from database,
|
text(
|
||||||
and asserts that the value is equal to `value_in_db`.
|
f"SELECT value FROM {IntEnumTestModel.__tablename__} WHERE id = {_id}"
|
||||||
"""
|
)
|
||||||
session.add(obj)
|
).one()[0]
|
||||||
session.commit()
|
|
||||||
|
|
||||||
exec_result = session.execute(
|
TestBase.metadata.create_all(db_session.bind)
|
||||||
text(
|
|
||||||
f"SELECT value FROM {obj.__tablename__} WHERE id = {obj.id}" # type: ignore
|
|
||||||
)
|
|
||||||
).fetchone()[0]
|
|
||||||
|
|
||||||
if value_in_db is None:
|
basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO)
|
||||||
assert exec_result is value_in_db
|
null_obj = IntEnumTestModel(id=2, value=None)
|
||||||
else:
|
db_session.add(basic_obj)
|
||||||
assert exec_result == value_in_db
|
db_session.add(null_obj)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
def test_rating_class(self, db_session):
|
assert _query_value(1) == TestIntEnum.TWO.value
|
||||||
Base.metadata.create_all(db_session.bind)
|
assert _query_value(2) is None
|
||||||
|
|
||||||
basic_obj = RatingClassTestModel(id=1, value=ArcaeaRatingClass.FUTURE)
|
def test_tz_datetime(self, db_session):
|
||||||
self._common_test_method(db_session, basic_obj, 2)
|
TestBase.metadata.create_all(db_session.bind)
|
||||||
|
|
||||||
null_obj = RatingClassTestModel(id=2, value=None)
|
dt1 = datetime.now(tz=timezone(timedelta(hours=8)))
|
||||||
self._common_test_method(db_session, null_obj, None)
|
|
||||||
|
|
||||||
def test_clear_type(self, db_session):
|
basic_obj = TZDatetimeTestModel(id=1, value=dt1)
|
||||||
Base.metadata.create_all(db_session.bind)
|
null_obj = TZDatetimeTestModel(id=2, value=None)
|
||||||
|
db_session.add(basic_obj)
|
||||||
|
db_session.add(null_obj)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
basic_obj = ClearTypeTestModel(id=1, value=ArcaeaPlayResultClearType.TRACK_LOST)
|
assert basic_obj.value == dt1
|
||||||
self._common_test_method(db_session, basic_obj, 0)
|
assert null_obj.value is None
|
||||||
|
|
||||||
null_obj = ClearTypeTestModel(id=2, value=None)
|
|
||||||
self._common_test_method(db_session, null_obj, None)
|
|
||||||
|
|
||||||
def test_modifier(self, db_session):
|
|
||||||
Base.metadata.create_all(db_session.bind)
|
|
||||||
|
|
||||||
basic_obj = ModifierTestModel(id=1, value=ArcaeaPlayResultModifier.HARD)
|
|
||||||
self._common_test_method(db_session, basic_obj, 2)
|
|
||||||
|
|
||||||
null_obj = ModifierTestModel(id=2, value=None)
|
|
||||||
self._common_test_method(db_session, null_obj, None)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user