refactor: sqlalchemy custom types

- Unify `IntEnum` type decorators to single `DbIntEnum`
- Add timezone aware `TZDateTime` from sqlalchemy docs
This commit is contained in:
2024-05-20 21:21:49 +08:00
parent 0d5e21a90e
commit ce715bfccc
2 changed files with 74 additions and 114 deletions

View File

@ -1,58 +1,46 @@
from typing import Optional
from datetime import datetime, timezone
from enum import IntEnum
from typing import Optional, Type
from sqlalchemy import Integer
from sqlalchemy import DateTime, Integer
from sqlalchemy.types import TypeDecorator
from arcaea_offline.constants.enums import (
ArcaeaPlayResultClearType,
ArcaeaPlayResultModifier,
ArcaeaRatingClass,
)
class DbRatingClass(TypeDecorator):
"""sqlalchemy rating_class type decorator"""
class DbIntEnum(TypeDecorator):
"""sqlalchemy `TypeDecorator` for `IntEnum`s"""
impl = Integer
cache_ok = True
def process_bind_param(
self, value: Optional[ArcaeaRatingClass], dialect
) -> Optional[int]:
def __init__(self, enum_class: Type[IntEnum]):
super().__init__()
self.enum_class = enum_class
def process_bind_param(self, value: Optional[IntEnum], dialect) -> Optional[int]:
return None if value is None else value.value
def process_result_value(
self, value: Optional[int], dialect
) -> Optional[ArcaeaRatingClass]:
return None if value is None else ArcaeaRatingClass(value)
def process_result_value(self, value: Optional[int], dialect) -> Optional[IntEnum]:
return None if value is None else self.enum_class(value)
class DbClearType(TypeDecorator):
"""sqlalchemy clear_type type decorator"""
class TZDateTime(TypeDecorator):
"""
Store Timezone Aware Timestamps as Timezone Naive UTC
impl = Integer
https://docs.sqlalchemy.org/en/20/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc
"""
def process_bind_param(
self, value: Optional[ArcaeaPlayResultClearType], dialect
) -> Optional[int]:
return None if value is None else value.value
impl = DateTime
cache_ok = True
def process_result_value(
self, value: Optional[int], dialect
) -> Optional[ArcaeaPlayResultClearType]:
return None if value is None else ArcaeaPlayResultClearType(value)
def process_bind_param(self, value: Optional[datetime], dialect):
if value is not None:
if not value.tzinfo or value.tzinfo.utcoffset(value) is None:
raise TypeError("tzinfo is required")
value = value.astimezone(timezone.utc).replace(tzinfo=None)
return value
class DbModifier(TypeDecorator):
"""sqlalchemy modifier type decorator"""
impl = Integer
def process_bind_param(
self, value: Optional[ArcaeaPlayResultModifier], dialect
) -> Optional[int]:
return None if value is None else value.value
def process_result_value(
self, value: Optional[int], dialect
) -> Optional[ArcaeaPlayResultModifier]:
return None if value is None else ArcaeaPlayResultModifier(value)
def process_result_value(self, value: Optional[datetime], dialect):
if value is not None:
value = value.replace(tzinfo=timezone.utc)
return value