diff --git a/src/arcaea_offline/database/db.py b/src/arcaea_offline/database/db.py index f4ad1f8..9947aa3 100644 --- a/src/arcaea_offline/database/db.py +++ b/src/arcaea_offline/database/db.py @@ -5,7 +5,6 @@ from typing import Iterable, Optional, Type, Union from sqlalchemy import Engine, func, inspect, select from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, sessionmaker -from arcaea_offline.external.arcsong.arcsong_json import ArcSongJsonBuilder from arcaea_offline.singleton import Singleton from .models.v4.config import ConfigBase, Property @@ -403,11 +402,3 @@ class Database(metaclass=Singleton): return self.__count_table(ScoreBest) # endregion - - # region export - def generate_arcsong(self): - with self.sessionmaker() as session: - arcsong = ArcSongJsonBuilder(session).generate_arcsong_json() - return arcsong - - # endregion diff --git a/src/arcaea_offline/external/arcsong/__init__.py b/src/arcaea_offline/external/arcsong/__init__.py deleted file mode 100644 index 28eda7b..0000000 --- a/src/arcaea_offline/external/arcsong/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .arcsong_db import ArcsongDbParser - -__all__ = ["ArcsongDbParser"] diff --git a/src/arcaea_offline/external/arcsong/arcsong_db.py b/src/arcaea_offline/external/arcsong/arcsong_db.py deleted file mode 100644 index cc4abb1..0000000 --- a/src/arcaea_offline/external/arcsong/arcsong_db.py +++ /dev/null @@ -1,34 +0,0 @@ -import sqlite3 -from typing import List - -from sqlalchemy.orm import Session - -from arcaea_offline.database.models.v4 import ChartInfo - - -class ArcsongDbParser: - def __init__(self, filepath): - self.filepath = filepath - - def parse(self) -> List[ChartInfo]: - results = [] - with sqlite3.connect(self.filepath) as conn: - cursor = conn.cursor() - arcsong_db_results = cursor.execute( - "SELECT song_id, rating_class, rating, note FROM charts" - ) - for result in arcsong_db_results: - chart = ChartInfo( - song_id=result[0], - rating_class=result[1], - constant=result[2], - notes=result[3] or None, - ) - results.append(chart) - - return results - - def write_database(self, session: Session): - results = self.parse() - for result in results: - session.merge(result) diff --git a/src/arcaea_offline/external/arcsong/arcsong_json.py b/src/arcaea_offline/external/arcsong/arcsong_json.py deleted file mode 100644 index ce23d3a..0000000 --- a/src/arcaea_offline/external/arcsong/arcsong_json.py +++ /dev/null @@ -1,157 +0,0 @@ -import logging -import re -from typing import List, Optional, TypedDict - -from sqlalchemy import func, select -from sqlalchemy.orm import Session - -from arcaea_offline.database.models.v4 import ( - ChartInfo, - Difficulty, - DifficultyLocalized, - Pack, - Song, - SongLocalized, -) - -logger = logging.getLogger(__name__) - - -class TArcSongJsonDifficultyItem(TypedDict): - name_en: str - name_jp: str - artist: str - bpm: str - bpm_base: float - set: str - set_friendly: str - time: int - side: int - world_unlock: bool - remote_download: bool - bg: str - date: int - version: str - difficulty: int - rating: int - note: int - chart_designer: str - jacket_designer: str - jacket_override: bool - audio_override: bool - - -class TArcSongJsonSongItem(TypedDict): - song_id: str - difficulties: List[TArcSongJsonDifficultyItem] - alias: List[str] - - -class TArcSongJson(TypedDict): - songs: List[TArcSongJsonSongItem] - - -class ArcSongJsonBuilder: - def __init__(self, session: Session): - self.session = session - - def get_difficulty_item( - self, - difficulty: Difficulty, - song: Song, - pack: Pack, - song_localized: Optional[SongLocalized], - ) -> TArcSongJsonDifficultyItem: - if "_append_" in pack.id: - base_pack = self.session.scalar( - select(Pack).where(Pack.id == re.sub(r"_append_.*$", "", pack.id)) - ) - else: - base_pack = None - - difficulty_localized = self.session.scalar( - select(DifficultyLocalized).where( - (DifficultyLocalized.song_id == difficulty.song_id) - & (DifficultyLocalized.rating_class == difficulty.rating_class) - ) - ) - chart_info = self.session.scalar( - select(ChartInfo).where( - (ChartInfo.song_id == difficulty.song_id) - & (ChartInfo.rating_class == difficulty.rating_class) - ) - ) - - if difficulty_localized: - name_jp = difficulty_localized.title_ja or "" - elif song_localized: - name_jp = song_localized.title_ja or "" - else: - name_jp = "" - - return { - "name_en": difficulty.title or song.title, - "name_jp": name_jp, - "artist": difficulty.artist or song.artist, - "bpm": difficulty.bpm or song.bpm or "", - "bpm_base": difficulty.bpm_base or song.bpm_base or 0.0, - "set": song.set, - "set_friendly": f"{base_pack.name} - {pack.name}" - if base_pack - else pack.name, - "time": 0, - "side": song.side or 0, - "world_unlock": False, - "remote_download": False, - "bg": difficulty.bg or song.bg or "", - "date": difficulty.date or song.date or 0, - "version": difficulty.version or song.version or "", - "difficulty": difficulty.rating * 2 + int(difficulty.rating_plus), - "rating": chart_info.constant or 0 if chart_info else 0, - "note": chart_info.notes or 0 if chart_info else 0, - "chart_designer": difficulty.chart_designer or "", - "jacket_designer": difficulty.jacket_desginer or "", - "jacket_override": difficulty.jacket_override, - "audio_override": difficulty.audio_override, - } - - def get_song_item(self, song: Song) -> TArcSongJsonSongItem: - difficulties = self.session.scalars( - select(Difficulty).where(Difficulty.song_id == song.id) - ) - - pack = self.session.scalar(select(Pack).where(Pack.id == song.set)) - if not pack: - logger.warning( - 'Cannot find pack "%s", using placeholder instead.', song.set - ) - pack = Pack(id="unknown", name="Unknown", description="__PLACEHOLDER__") - song_localized = self.session.scalar( - select(SongLocalized).where(SongLocalized.id == song.id) - ) - - return { - "song_id": song.id, - "difficulties": [ - self.get_difficulty_item(difficulty, song, pack, song_localized) - for difficulty in difficulties - ], - "alias": [], - } - - def generate_arcsong_json(self) -> TArcSongJson: - songs = self.session.scalars(select(Song)) - arcsong_songs = [] - for song in songs: - proceed = self.session.scalar( - select(func.count(Difficulty.rating_class)).where( - Difficulty.song_id == song.id - ) - ) - - if not proceed: - continue - - arcsong_songs.append(self.get_song_item(song)) - - return {"songs": arcsong_songs} diff --git a/src/arcaea_offline/external/exporters/arcsong/__init__.py b/src/arcaea_offline/external/exporters/arcsong/__init__.py new file mode 100644 index 0000000..54ad51f --- /dev/null +++ b/src/arcaea_offline/external/exporters/arcsong/__init__.py @@ -0,0 +1,3 @@ +from .json import ArcsongJsonExporter + +__all__ = ["ArcsongJsonExporter"] diff --git a/src/arcaea_offline/external/exporters/arcsong/definitions.py b/src/arcaea_offline/external/exporters/arcsong/definitions.py new file mode 100644 index 0000000..9c1f301 --- /dev/null +++ b/src/arcaea_offline/external/exporters/arcsong/definitions.py @@ -0,0 +1,35 @@ +from typing import List, TypedDict + + +class ArcsongJsonDifficultyItem(TypedDict): + name_en: str + name_jp: str + artist: str + bpm: str + bpm_base: float + set: str + set_friendly: str + time: int + side: int + world_unlock: bool + remote_download: bool + bg: str + date: int + version: str + difficulty: int + rating: int + note: int + chart_designer: str + jacket_designer: str + jacket_override: bool + audio_override: bool + + +class ArcsongJsonSongItem(TypedDict): + song_id: str + difficulties: List[ArcsongJsonDifficultyItem] + alias: List[str] + + +class ArcsongJsonRoot(TypedDict): + songs: List[ArcsongJsonSongItem] diff --git a/src/arcaea_offline/external/exporters/arcsong/json.py b/src/arcaea_offline/external/exporters/arcsong/json.py new file mode 100644 index 0000000..52fecbf --- /dev/null +++ b/src/arcaea_offline/external/exporters/arcsong/json.py @@ -0,0 +1,98 @@ +import logging +import re +from typing import List, Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from arcaea_offline.constants.enums.arcaea import ArcaeaLanguage +from arcaea_offline.database.models.v5 import Difficulty, Pack, Song + +from .definitions import ArcsongJsonDifficultyItem, ArcsongJsonRoot, ArcsongJsonSongItem + +logger = logging.getLogger(__name__) + + +class ArcsongJsonExporter: + @staticmethod + def craft_difficulty_item( + difficulty: Difficulty, *, base_pack: Optional[Pack] + ) -> ArcsongJsonDifficultyItem: + song = difficulty.song + pack = song.pack + chart_info = difficulty.chart_info + + song_localized_ja = next( + (lo for lo in song.localized_objects if lo.lang == ArcaeaLanguage.JA), + None, + ) + difficulty_localized_ja = next( + (lo for lo in difficulty.localized_objects if lo.lang == ArcaeaLanguage.JA), + None, + ) + + if difficulty_localized_ja: + name_jp = difficulty_localized_ja.title or "" + elif song_localized_ja: + name_jp = song_localized_ja.title or "" + else: + name_jp = "" + + return { + "name_en": difficulty.title or song.title, + "name_jp": name_jp, + "artist": difficulty.artist or song.artist, + "bpm": difficulty.bpm or song.bpm or "", + "bpm_base": difficulty.bpm_base or song.bpm_base or 0.0, + "set": song.pack_id, + "set_friendly": f"{base_pack.name} - {pack.name}" + if base_pack + else pack.name, + "time": 0, + "side": song.side or 0, + "world_unlock": False, + "remote_download": False, + "bg": difficulty.bg or song.bg or "", + "date": difficulty.date or song.date or 0, + "version": difficulty.version or song.version or "", + "difficulty": difficulty.rating * 2 + int(difficulty.rating_plus), + "rating": chart_info.constant or 0 if chart_info else 0, + "note": chart_info.notes or 0 if chart_info else 0, + "chart_designer": difficulty.chart_designer or "", + "jacket_designer": difficulty.jacket_desginer or "", + "jacket_override": difficulty.jacket_override, + "audio_override": difficulty.audio_override, + } + + @classmethod + def craft(cls, session: Session) -> ArcsongJsonRoot: + songs = session.scalars(select(Song)) + + arcsong_songs: List[ArcsongJsonSongItem] = [] + for song in songs: + if len(song.difficulties) == 0: + continue + + pack = song.pack + if "_append_" in pack.id: + base_pack = session.scalar( + select(Pack).where(Pack.id == re.sub(r"_append_.*$", "", pack.id)) + ) + else: + base_pack = None + + arcsong_difficulties = [] + for difficulty in song.difficulties: + arcsong_difficulties.append( + cls.craft_difficulty_item(difficulty, base_pack=base_pack) + ) + + arcsong_songs.append( + { + "song_id": song.id, + "difficulties": arcsong_difficulties, + "alias": [], + } + ) + + return {"songs": arcsong_songs} diff --git a/src/arcaea_offline/external/importers/arcsong.py b/src/arcaea_offline/external/importers/arcsong.py new file mode 100644 index 0000000..5ed20b9 --- /dev/null +++ b/src/arcaea_offline/external/importers/arcsong.py @@ -0,0 +1,38 @@ +import sqlite3 +from typing import List, overload + +from arcaea_offline.constants.enums.arcaea import ArcaeaRatingClass +from arcaea_offline.database.models.v5 import ChartInfo + + +class ArcsongDatabaseImporter: + @classmethod + @overload + def parse(cls, conn: sqlite3.Connection) -> List[ChartInfo]: ... + + @classmethod + @overload + def parse(cls, conn: sqlite3.Cursor) -> List[ChartInfo]: ... + + @classmethod + def parse(cls, conn) -> List[ChartInfo]: + if isinstance(conn, sqlite3.Connection): + return cls.parse(conn.cursor()) + + assert isinstance(conn, sqlite3.Cursor) + + results = [] + db_results = conn.execute( + "SELECT song_id, rating_class, rating, note FROM charts" + ) + for result in db_results: + results.append( + ChartInfo( + song_id=result[0], + rating_class=ArcaeaRatingClass(result[1]), + constant=result[2], + notes=result[3] or None, + ) + ) + + return results diff --git a/tests/external/importers/test_arcsong.py b/tests/external/importers/test_arcsong.py new file mode 100644 index 0000000..fb5104b --- /dev/null +++ b/tests/external/importers/test_arcsong.py @@ -0,0 +1,45 @@ +import sqlite3 + +import tests.resources +from arcaea_offline.constants.enums.arcaea import ArcaeaRatingClass +from arcaea_offline.database.models.v5 import ChartInfo +from arcaea_offline.external.importers.arcsong import ( + ArcsongDatabaseImporter, +) + +db = sqlite3.connect(":memory:") +db.executescript( + tests.resources.get_resource("arcsong.sql").read_text(encoding="utf-8") +) + + +class TestArcsongDatabaseImporter: + def test_parse(self): + items = ArcsongDatabaseImporter.parse(db) + + assert all(isinstance(item, ChartInfo) for item in items) + assert len(items) == 3 + + base1_pst = next( + it + for it in items + if it.song_id == "base1" and it.rating_class is ArcaeaRatingClass.PAST + ) + assert base1_pst.constant == 30 + assert base1_pst.notes == 500 + + base1_prs = next( + it + for it in items + if it.song_id == "base1" and it.rating_class is ArcaeaRatingClass.PRESENT + ) + assert base1_prs.constant == 60 + assert base1_prs.notes == 700 + + base1_ftr = next( + it + for it in items + if it.song_id == "base1" and it.rating_class is ArcaeaRatingClass.FUTURE + ) + assert base1_ftr.constant == 90 + assert base1_ftr.notes == 1000 diff --git a/tests/resources/arcsong.sql b/tests/resources/arcsong.sql new file mode 100644 index 0000000..c4b1fad --- /dev/null +++ b/tests/resources/arcsong.sql @@ -0,0 +1,40 @@ +CREATE TABLE packages( + `id` TEXT PRIMARY KEY NOT NULL, + `name` TEXT NOT NULL DEFAULT "" +); + +CREATE TABLE charts( + song_id TEXT NOT NULL DEFAULT '', + rating_class INTEGER NOT NULL DEFAULT 0, + name_en TEXT NOT NULL DEFAULT '', + name_jp TEXT DEFAULT '', + artist TEXT NOT NULL DEFAULT '', + bpm TEXT NOT NULL DEFAULT '', + bpm_base DOUBLE NOT NULL DEFAULT 0, + `set` TEXT NOT NULL DEFAULT '', + `time` INTEGER DEFAULT 0, + side INTEGER NOT NULL DEFAULT 0, + world_unlock BOOLEAN NOT NULL DEFAULT 0, + remote_download BOOLEAN DEFAULT '', + bg TEXT NOT NULL DEFAULT '', + `date` INTEGER NOT NULL DEFAULT 0, + `version` TEXT NOT NULL DEFAULT '', + difficulty INTEGER NOT NULL DEFAULT 0, + rating INTEGER NOT NULL DEFAULT 0, + note INTEGER NOT NULL DEFAULT 0, + chart_designer TEXT DEFAULT '', + jacket_designer TEXT DEFAULT '', + jacket_override BOOLEAN NOT NULL DEFAULT 0, + audio_override BOOLEAN NOT NULL DEFAULT 0, + PRIMARY KEY(song_id, rating_class) +); + + +INSERT INTO packages ("id", "name") VALUES + ('base', 'Base Pack'), + ('core', 'Core Pack'); + +INSERT INTO charts ("song_id", "rating_class", "name_en", "name_jp", "artist", "bpm", "bpm_base", "set", "time", "side", "world_unlock", "remote_download", "bg", "date", "version", "difficulty", "rating", "note", "chart_designer", "jacket_designer", "jacket_override", "audio_override") VALUES + ('base1', '0', 'Base song 1', 'ベース・ソング・ワン', 'Artist', '1024', '1024.0', 'base', '1024', '1', '1', '0', '', '1400067914', '1.0', '6', '30', '500', 'Charter', '78rwey63a', '0', '0'), + ('base1', '1', 'Base song 1', 'ベース・ソング・ワン', 'Artist', '1024', '1024.0', 'base', '1024', '1', '1', '0', '', '1400067914', '1.0', '12', '60', '700', 'Charter', '78rwey63b', '0', '0'), + ('base1', '2', 'Base song 1', 'ベース・ソング・ワン', 'Artist', '1024', '1024.0', 'base', '1024', '1', '1', '0', '', '1400067914', '1.0', '18', '90', '1000', 'Charter', '78rwey63c', '0', '0');