Compare commits

...

4 Commits

Author SHA1 Message Date
8e79ffedce
ci: sync changes in master branch 2024-05-22 01:37:39 +08:00
677ab6c31e
test: refactor legacy tests 2024-05-21 21:12:11 +08:00
ab88b6903c
test: conftest database clean-up 2024-05-21 21:04:30 +08:00
ce715bfccc
refactor: sqlalchemy custom types
- Unify `IntEnum` type decorators to single `DbIntEnum`
- Add timezone aware `TZDateTime` from sqlalchemy docs
2024-05-20 21:21:49 +08:00
11 changed files with 290 additions and 290 deletions

40
.github/workflows/main.yml vendored Normal file
View File

@ -0,0 +1,40 @@
name: test & lint
on:
push:
branches:
- '*'
pull_request:
types: [opened, reopened]
workflow_dispatch:
jobs:
pytest:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dev dependencies
run: 'pip install .[dev]'
- name: Run tests
run: 'pytest -v'
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install dev dependencies
run: 'pip install .[dev]'
- name: Run linter
run: 'ruff check'

View File

@ -1,23 +0,0 @@
name: Run tests
on:
push:
branches:
- 'master'
pull_request:
types: [opened, reopened]
workflow_dispatch:
jobs:
pytest:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- run: 'pip install -r requirements.dev.txt .'
- run: 'pytest -v'

View File

@ -4,11 +4,10 @@ repos:
hooks: hooks:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.1.0 - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
hooks: hooks:
- id: black - id: ruff
- repo: https://github.com/PyCQA/isort args: ["--fix"]
rev: 5.12.0 - id: ruff-format
hooks:
- id: isort

View File

@ -18,24 +18,34 @@ classifiers = [
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
] ]
[project.optional-dependencies]
dev = ["ruff~=0.4", "pre-commit~=3.3", "pytest~=7.4", "tox~=4.11"]
[project.urls] [project.urls]
"Homepage" = "https://github.com/283375/arcaea-offline" "Homepage" = "https://github.com/283375/arcaea-offline"
"Bug Tracker" = "https://github.com/283375/arcaea-offline/issues" "Bug Tracker" = "https://github.com/283375/arcaea-offline/issues"
[tool.isort]
profile = "black"
src_paths = ["src/arcaea_offline"]
[tool.pyright] [tool.pyright]
ignore = ["build/"] ignore = ["build/"]
[tool.pylint.main] [tool.ruff.lint]
jobs = 0 # Full list: https://docs.astral.sh/ruff/rules
select = [
[tool.pylint.logging] "E", # pycodestyle (Error)
disable = [ "W", # pycodestyle (Warning)
"missing-module-docstring", "F", # pyflakes
"missing-class-docstring", "I", # isort
"missing-function-docstring", "PL", # pylint
"not-callable", # false positive to sqlalchemy `func.*`, remove this when pylint-dev/pylint(#8138) closed "N", # pep8-naming
"FBT", # flake8-boolean-trap
"A", # flake8-builtins
"DTZ", # flake8-datetimez
"LOG", # flake8-logging
"Q", # flake8-quotes
"G", # flake8-logging-format
"PIE", # flake8-pie
"PT", # flake8-pytest-style
]
ignore = [
"E501", # line-too-long
] ]

View File

@ -1,6 +1,4 @@
black==23.3.0 ruff~=0.4
isort==5.12.0 pre-commit~=3.3
pre-commit==3.3.1 pytest~=7.4
pylint==3.0.2 tox~=4.11
pytest==7.4.3
tox==4.11.3

View File

@ -1,58 +1,46 @@
from typing import Optional from datetime import datetime, timezone
from enum import IntEnum
from typing import Optional, Type
from sqlalchemy import Integer from sqlalchemy import DateTime, Integer
from sqlalchemy.types import TypeDecorator from sqlalchemy.types import TypeDecorator
from arcaea_offline.constants.enums import (
ArcaeaPlayResultClearType,
ArcaeaPlayResultModifier,
ArcaeaRatingClass,
)
class DbIntEnum(TypeDecorator):
class DbRatingClass(TypeDecorator): """sqlalchemy `TypeDecorator` for `IntEnum`s"""
"""sqlalchemy rating_class type decorator"""
impl = Integer impl = Integer
cache_ok = True
def process_bind_param( def __init__(self, enum_class: Type[IntEnum]):
self, value: Optional[ArcaeaRatingClass], dialect super().__init__()
) -> Optional[int]: self.enum_class = enum_class
def process_bind_param(self, value: Optional[IntEnum], dialect) -> Optional[int]:
return None if value is None else value.value return None if value is None else value.value
def process_result_value( def process_result_value(self, value: Optional[int], dialect) -> Optional[IntEnum]:
self, value: Optional[int], dialect return None if value is None else self.enum_class(value)
) -> Optional[ArcaeaRatingClass]:
return None if value is None else ArcaeaRatingClass(value)
class DbClearType(TypeDecorator): class TZDateTime(TypeDecorator):
"""sqlalchemy clear_type type decorator""" """
Store Timezone Aware Timestamps as Timezone Naive UTC
impl = Integer https://docs.sqlalchemy.org/en/20/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc
"""
def process_bind_param( impl = DateTime
self, value: Optional[ArcaeaPlayResultClearType], dialect cache_ok = True
) -> Optional[int]:
return None if value is None else value.value
def process_result_value( def process_bind_param(self, value: Optional[datetime], dialect):
self, value: Optional[int], dialect if value is not None:
) -> Optional[ArcaeaPlayResultClearType]: if not value.tzinfo or value.tzinfo.utcoffset(value) is None:
return None if value is None else ArcaeaPlayResultClearType(value) raise TypeError("tzinfo is required")
value = value.astimezone(timezone.utc).replace(tzinfo=None)
return value
def process_result_value(self, value: Optional[datetime], dialect):
class DbModifier(TypeDecorator): if value is not None:
"""sqlalchemy modifier type decorator""" value = value.replace(tzinfo=timezone.utc)
return value
impl = Integer
def process_bind_param(
self, value: Optional[ArcaeaPlayResultModifier], dialect
) -> Optional[int]:
return None if value is None else value.value
def process_result_value(
self, value: Optional[int], dialect
) -> Optional[ArcaeaPlayResultModifier]:
return None if value is None else ArcaeaPlayResultModifier(value)

View File

@ -1,27 +1,53 @@
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
# region sqlalchemy fixtures # region sqlalchemy fixtures
# from https://medium.com/@vittorio.camisa/agile-database-integration-tests-with-python-sqlalchemy-and-factory-boy-6824e8fe33a1
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")
Session = sessionmaker() Session = sessionmaker()
@pytest.fixture(scope="module") @pytest.fixture(scope="session")
def db_conn(): def db_conn():
connection = engine.connect() conn = engine.connect()
yield connection yield conn
connection.close() conn.close()
@pytest.fixture(scope="function") @pytest.fixture()
def db_session(db_conn): def db_session(db_conn):
transaction = db_conn.begin()
session = Session(bind=db_conn) session = Session(bind=db_conn)
yield session yield session
session.close() session.close()
transaction.rollback()
# drop everything
query_tables = db_conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table'")
).fetchall()
for row in query_tables:
table_name = row[0]
db_conn.execute(text(f"DROP TABLE {table_name}"))
query_views = db_conn.execute(
text("SELECT name FROM sqlite_master WHERE type='view'")
).fetchall()
for row in query_views:
view_name = row[0]
db_conn.execute(text(f"DROP VIEW {view_name}"))
query_indexes = db_conn.execute(
text("SELECT name FROM sqlite_master WHERE type='index'")
).fetchall()
for row in query_indexes:
index_name = row[0]
db_conn.execute(text(f"DROP INDEX {index_name}"))
query_triggers = db_conn.execute(
text("SELECT name FROM sqlite_master WHERE type='trigger'")
).fetchall()
for row in query_triggers:
trigger_name = row[0]
db_conn.execute(text(f"DROP TRIGGER {trigger_name}"))
# endregion # endregion

View File

@ -1,95 +1,67 @@
from datetime import datetime, timedelta, timezone
from enum import IntEnum
from typing import Optional from typing import Optional
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from arcaea_offline.constants.enums import ( from arcaea_offline.database.models._custom_types import DbIntEnum, TZDateTime
ArcaeaPlayResultClearType,
ArcaeaPlayResultModifier,
ArcaeaRatingClass,
)
from arcaea_offline.database.models._custom_types import (
DbClearType,
DbModifier,
DbRatingClass,
)
class Base(DeclarativeBase): class TestIntEnum(IntEnum):
pass __test__ = False
ONE = 1
TWO = 2
THREE = 3
class RatingClassTestModel(Base): class TestBase(DeclarativeBase):
__tablename__ = "test_rating_class" __test__ = False
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaRatingClass]] = mapped_column(
DbRatingClass, nullable=True
)
class ClearTypeTestModel(Base): class IntEnumTestModel(TestBase):
__tablename__ = "test_clear_type" __tablename__ = "test_int_enum"
value: Mapped[Optional[TestIntEnum]] = mapped_column(DbIntEnum(TestIntEnum))
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaPlayResultClearType]] = mapped_column(
DbClearType, nullable=True
)
class ModifierTestModel(Base): class TZDatetimeTestModel(TestBase):
__tablename__ = "test_modifier" __tablename__ = "test_tz_datetime"
value: Mapped[Optional[datetime]] = mapped_column(TZDateTime)
id: Mapped[int] = mapped_column(primary_key=True)
value: Mapped[Optional[ArcaeaPlayResultModifier]] = mapped_column(
DbModifier, nullable=True
)
class TestCustomTypes: class TestCustomTypes:
def _common_test_method(self, session: Session, obj: Base, value_in_db): def test_int_enum(self, db_session):
""" def _query_value(_id: int):
This method stores the `obj` into the given `session`, return db_session.execute(
then fetches the raw value of `obj.value` from database, text(
and asserts that the value is equal to `value_in_db`. f"SELECT value FROM {IntEnumTestModel.__tablename__} WHERE id = {_id}"
""" )
session.add(obj) ).one()[0]
session.commit()
exec_result = session.execute( TestBase.metadata.create_all(db_session.bind, checkfirst=False)
text(
f"SELECT value FROM {obj.__tablename__} WHERE id = {obj.id}" # type: ignore
)
).fetchone()[0]
if value_in_db is None: basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO)
assert exec_result is value_in_db null_obj = IntEnumTestModel(id=2, value=None)
else: db_session.add(basic_obj)
assert exec_result == value_in_db db_session.add(null_obj)
db_session.commit()
def test_rating_class(self, db_session): assert _query_value(1) == TestIntEnum.TWO.value
Base.metadata.create_all(db_session.bind) assert _query_value(2) is None
basic_obj = RatingClassTestModel(id=1, value=ArcaeaRatingClass.FUTURE) def test_tz_datetime(self, db_session):
self._common_test_method(db_session, basic_obj, 2) TestBase.metadata.create_all(db_session.bind, checkfirst=False)
null_obj = RatingClassTestModel(id=2, value=None) dt1 = datetime.now(tz=timezone(timedelta(hours=8)))
self._common_test_method(db_session, null_obj, None)
def test_clear_type(self, db_session): basic_obj = TZDatetimeTestModel(id=1, value=dt1)
Base.metadata.create_all(db_session.bind) null_obj = TZDatetimeTestModel(id=2, value=None)
db_session.add(basic_obj)
db_session.add(null_obj)
db_session.commit()
basic_obj = ClearTypeTestModel(id=1, value=ArcaeaPlayResultClearType.TRACK_LOST) assert basic_obj.value == dt1
self._common_test_method(db_session, basic_obj, 0) assert null_obj.value is None
null_obj = ClearTypeTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)
def test_modifier(self, db_session):
Base.metadata.create_all(db_session.bind)
basic_obj = ModifierTestModel(id=1, value=ArcaeaPlayResultModifier.HARD)
self._common_test_method(db_session, basic_obj, 2)
null_obj = ModifierTestModel(id=2, value=None)
self._common_test_method(db_session, null_obj, None)

View File

@ -1,118 +0,0 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from arcaea_offline.models.songs import (
Chart,
ChartInfo,
Difficulty,
Pack,
Song,
SongsBase,
SongsViewBase,
)
from ..db import create_engine_in_memory
def _song(**kw):
defaults = {"artist": "test"}
defaults.update(kw)
return Song(**defaults)
def _difficulty(**kw):
defaults = {"rating_plus": False, "audio_override": False, "jacket_override": False}
defaults.update(kw)
return Difficulty(**defaults)
class Test_Chart:
def init_db(self, engine: Engine):
SongsBase.metadata.create_all(engine)
SongsViewBase.metadata.create_all(engine)
def db(self):
db = create_engine_in_memory()
self.init_db(db)
return db
def test_chart_info(self):
pre_entites = [
Pack(id="test", name="Test Pack"),
_song(idx=0, id="song0", set="test", title="Full Chart Info"),
_song(idx=1, id="song1", set="test", title="Partial Chart Info"),
_song(idx=2, id="song2", set="test", title="No Chart Info"),
_difficulty(song_id="song0", rating_class=2, rating=9),
_difficulty(song_id="song1", rating_class=2, rating=9),
_difficulty(song_id="song2", rating_class=2, rating=9),
ChartInfo(song_id="song0", rating_class=2, constant=90, notes=1234),
ChartInfo(song_id="song1", rating_class=2, constant=90),
]
db = self.db()
with Session(db) as session:
session.add_all(pre_entites)
session.commit()
chart_song0_ratingclass2 = (
session.query(Chart)
.where((Chart.song_id == "song0") & (Chart.rating_class == 2))
.one()
)
assert chart_song0_ratingclass2.constant == 90
assert chart_song0_ratingclass2.notes == 1234
chart_song1_ratingclass2 = (
session.query(Chart)
.where((Chart.song_id == "song1") & (Chart.rating_class == 2))
.one()
)
assert chart_song1_ratingclass2.constant == 90
assert chart_song1_ratingclass2.notes is None
chart_song2_ratingclass2 = (
session.query(Chart)
.where((Chart.song_id == "song2") & (Chart.rating_class == 2))
.first()
)
assert chart_song2_ratingclass2 is None
def test_difficulty_title_override(self):
pre_entites = [
Pack(id="test", name="Test Pack"),
_song(idx=0, id="test", set="test", title="Test"),
_difficulty(song_id="test", rating_class=0, rating=2),
_difficulty(song_id="test", rating_class=1, rating=5),
_difficulty(song_id="test", rating_class=2, rating=8),
_difficulty(
song_id="test", rating_class=3, rating=10, title="TEST ~REVIVE~"
),
ChartInfo(song_id="test", rating_class=0, constant=10),
ChartInfo(song_id="test", rating_class=1, constant=10),
ChartInfo(song_id="test", rating_class=2, constant=10),
ChartInfo(song_id="test", rating_class=3, constant=10),
]
db = self.db()
with Session(db) as session:
session.add_all(pre_entites)
session.commit()
charts_original_title = (
session.query(Chart)
.where((Chart.song_id == "test") & (Chart.rating_class in [0, 1, 2]))
.all()
)
assert all(chart.title == "Test" for chart in charts_original_title)
chart_overrided_title = (
session.query(Chart)
.where((Chart.song_id == "test") & (Chart.rating_class == 3))
.one()
)
assert chart_overrided_title.title == "TEST ~REVIVE~"

View File

View File

@ -0,0 +1,108 @@
from arcaea_offline.database.models.v4.songs import (
Chart,
ChartInfo,
Difficulty,
Pack,
Song,
SongsBase,
SongsViewBase,
)
def _song(**kw):
defaults = {"artist": "test"}
defaults.update(kw)
return Song(**defaults)
def _difficulty(**kw):
defaults = {"rating_plus": False, "audio_override": False, "jacket_override": False}
defaults.update(kw)
return Difficulty(**defaults)
class Test_Chart:
def init_db(self, session):
SongsBase.metadata.create_all(session.bind, checkfirst=False)
SongsViewBase.metadata.create_all(session.bind, checkfirst=False)
def test_chart_info(self, db_session):
self.init_db(db_session)
pre_entites = [
Pack(id="test", name="Test Pack"),
_song(idx=0, id="song0", set="test", title="Full Chart Info"),
_song(idx=1, id="song1", set="test", title="Partial Chart Info"),
_song(idx=2, id="song2", set="test", title="No Chart Info"),
_difficulty(song_id="song0", rating_class=2, rating=9),
_difficulty(song_id="song1", rating_class=2, rating=9),
_difficulty(song_id="song2", rating_class=2, rating=9),
ChartInfo(song_id="song0", rating_class=2, constant=90, notes=1234),
ChartInfo(song_id="song1", rating_class=2, constant=90),
]
db_session.add_all(pre_entites)
db_session.commit()
chart_song0_ratingclass2 = (
db_session.query(Chart)
.where((Chart.song_id == "song0") & (Chart.rating_class == 2))
.one()
)
assert chart_song0_ratingclass2.constant == 90
assert chart_song0_ratingclass2.notes == 1234
chart_song1_ratingclass2 = (
db_session.query(Chart)
.where((Chart.song_id == "song1") & (Chart.rating_class == 2))
.one()
)
assert chart_song1_ratingclass2.constant == 90
assert chart_song1_ratingclass2.notes is None
chart_song2_ratingclass2 = (
db_session.query(Chart)
.where((Chart.song_id == "song2") & (Chart.rating_class == 2))
.first()
)
assert chart_song2_ratingclass2 is None
def test_difficulty_title_override(self, db_session):
self.init_db(db_session)
pre_entites = [
Pack(id="test", name="Test Pack"),
_song(idx=0, id="test", set="test", title="Test"),
_difficulty(song_id="test", rating_class=0, rating=2),
_difficulty(song_id="test", rating_class=1, rating=5),
_difficulty(song_id="test", rating_class=2, rating=8),
_difficulty(
song_id="test", rating_class=3, rating=10, title="TEST ~REVIVE~"
),
ChartInfo(song_id="test", rating_class=0, constant=10),
ChartInfo(song_id="test", rating_class=1, constant=10),
ChartInfo(song_id="test", rating_class=2, constant=10),
ChartInfo(song_id="test", rating_class=3, constant=10),
]
db_session.add_all(pre_entites)
db_session.commit()
charts_original_title = (
db_session.query(Chart)
.where((Chart.song_id == "test") & (Chart.rating_class in [0, 1, 2]))
.all()
)
assert all(chart.title == "Test" for chart in charts_original_title)
chart_overrided_title = (
db_session.query(Chart)
.where((Chart.song_id == "test") & (Chart.rating_class == 3))
.one()
)
assert chart_overrided_title.title == "TEST ~REVIVE~"