feat: add Version

This commit is contained in:
2025-08-04 00:58:11 +08:00
parent ab03b27730
commit a8164f37e2
4 changed files with 59 additions and 2 deletions

View File

@ -4,10 +4,13 @@ from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.exc import DetachedInstanceError from sqlalchemy.orm.exc import DetachedInstanceError
from ._types import ForceTimezoneDateTime from arcaea_offline.utils import Version
from ._types import ForceTimezoneDateTime, VersionDatabaseType
TYPE_ANNOTATION_MAP = { TYPE_ANNOTATION_MAP = {
datetime: ForceTimezoneDateTime, datetime: ForceTimezoneDateTime,
Version: VersionDatabaseType,
} }

View File

@ -1,9 +1,11 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional from typing import Optional
from sqlalchemy import DateTime from sqlalchemy import DateTime, String
from sqlalchemy.types import TypeDecorator from sqlalchemy.types import TypeDecorator
from arcaea_offline.utils import Version
class ForceTimezoneDateTime(TypeDecorator): class ForceTimezoneDateTime(TypeDecorator):
""" """
@ -26,3 +28,23 @@ class ForceTimezoneDateTime(TypeDecorator):
if value is not None: if value is not None:
value = value.replace(tzinfo=timezone.utc) value = value.replace(tzinfo=timezone.utc)
return value return value
class VersionDatabaseType(TypeDecorator):
impl = String
cache_ok = True
def process_bind_param(self, value: Optional[Version], dialect):
if value is None:
return None
if not isinstance(value, Version):
raise ValueError("Input is not a Version instance.")
return str(f"{value.first}.{value.second}.{value.third}")
def process_result_value(self, value: Optional[str], dialect):
if value is None:
return None
return Version(*(map(int, value.split("."))))

View File

@ -0,0 +1,5 @@
from .version import Version
__all__ = [
"Version",
]

View File

@ -0,0 +1,27 @@
from typing import NamedTuple
class Version(NamedTuple):
first: int
second: int
third: int
@classmethod
def from_string(cls, version_str: str):
version_str = version_str.removesuffix("c")
parts = version_str.split(".")
if len(parts) not in {2, 3}:
raise ValueError(f"Invalid version string {version_str}")
try:
if len(parts) == 2: # noqa: PLR2004
parts.append("0")
first, second, third = map(int, parts)
except ValueError as e:
raise ValueError(f"Invalid version string {version_str}") from e
return cls(first, second, third)
def __str__(self):
return f"{self.first}.{self.second}.{self.third}"