diff --git a/src/arcaea_offline/external/chart_info_db/__init__.py b/src/arcaea_offline/external/chart_info_db/__init__.py deleted file mode 100644 index 6622093..0000000 --- a/src/arcaea_offline/external/chart_info_db/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .parser import ChartInfoDbParser - -__all__ = ["ChartInfoDbParser"] diff --git a/src/arcaea_offline/external/chart_info_db/parser.py b/src/arcaea_offline/external/chart_info_db/parser.py deleted file mode 100644 index a80a891..0000000 --- a/src/arcaea_offline/external/chart_info_db/parser.py +++ /dev/null @@ -1,35 +0,0 @@ -import contextlib -import sqlite3 -from typing import List - -from sqlalchemy.orm import Session - -from ...models.songs import ChartInfo - - -class ChartInfoDbParser: - def __init__(self, filepath): - self.filepath = filepath - - def parse(self) -> List[ChartInfo]: - results = [] - with sqlite3.connect(self.filepath) as conn: - with contextlib.closing(conn.cursor()) as cursor: - db_results = cursor.execute( - "SELECT song_id, rating_class, constant, notes FROM charts_info" - ).fetchall() - for result in db_results: - chart = ChartInfo( - song_id=result[0], - rating_class=result[1], - constant=result[2], - notes=result[3] or None, - ) - results.append(chart) - - return results - - def write_database(self, session: Session): - results = self.parse() - for result in results: - session.merge(result) diff --git a/src/arcaea_offline/external/importers/chart_info_database.py b/src/arcaea_offline/external/importers/chart_info_database.py new file mode 100644 index 0000000..e82ceec --- /dev/null +++ b/src/arcaea_offline/external/importers/chart_info_database.py @@ -0,0 +1,42 @@ +import sqlite3 +from contextlib import closing +from typing import List, overload + +from arcaea_offline.constants.enums.arcaea import ArcaeaRatingClass +from arcaea_offline.database.models.v5 import ChartInfo + + +class ChartInfoDatabaseParser: + @classmethod + @overload + def parse(cls, conn: sqlite3.Connection) -> List[ChartInfo]: ... + + @classmethod + @overload + def parse(cls, conn: sqlite3.Cursor) -> List[ChartInfo]: ... + + @classmethod + def parse(cls, conn) -> List[ChartInfo]: + if isinstance(conn, sqlite3.Connection): + with closing(conn.cursor()) as cur: + return cls.parse(cur) + + if not isinstance(conn, sqlite3.Cursor): + raise ValueError("conn must be sqlite3.Connection or sqlite3.Cursor!") + + db_items = conn.execute( + "SELECT song_id, rating_class, constant, notes FROM charts_info" + ).fetchall() + + results: List[ChartInfo] = [] + for item in db_items: + (song_id, rating_class, constant, notes) = item + + chart_info = ChartInfo() + chart_info.song_id = song_id + chart_info.rating_class = ArcaeaRatingClass(rating_class) + chart_info.constant = constant + chart_info.notes = notes + + results.append(chart_info) + return results diff --git a/tests/external/importers/test_chart_info_database.py b/tests/external/importers/test_chart_info_database.py new file mode 100644 index 0000000..c73b707 --- /dev/null +++ b/tests/external/importers/test_chart_info_database.py @@ -0,0 +1,34 @@ +import sqlite3 + +import tests.resources +from arcaea_offline.constants.enums.arcaea import ArcaeaRatingClass +from arcaea_offline.database.models.v5 import ChartInfo +from arcaea_offline.external.importers.chart_info_database import ( + ChartInfoDatabaseParser, +) + +db = sqlite3.connect(":memory:") +db.executescript(tests.resources.get_resource("cidb.sql").read_text(encoding="utf-8")) + + +class TestChartInfoDatabaseParser: + def test_parse(self): + items = ChartInfoDatabaseParser.parse(db) + assert all(isinstance(item, ChartInfo) for item in items) + + assert len(items) == 3 + + test1 = next(filter(lambda x: x.song_id == "test1", items)) + assert test1.rating_class is ArcaeaRatingClass.PRESENT + assert test1.constant == 90 + assert test1.notes == 900 + + test2 = next(filter(lambda x: x.song_id == "test2", items)) + assert test2.rating_class is ArcaeaRatingClass.FUTURE + assert test2.constant == 95 + assert test2.notes == 950 + + test3 = next(filter(lambda x: x.song_id == "test3", items)) + assert test3.rating_class is ArcaeaRatingClass.BEYOND + assert test3.constant == 100 + assert test3.notes is None diff --git a/tests/resources/cidb.sql b/tests/resources/cidb.sql new file mode 100644 index 0000000..a6f411b --- /dev/null +++ b/tests/resources/cidb.sql @@ -0,0 +1,11 @@ +CREATE TABLE charts_info ( + song_id TEXT NOT NULL, + rating_class INTEGER NOT NULL, + constant INTEGER NOT NULL, + notes INTEGER +); + +INSERT INTO charts_info (song_id, rating_class, constant, notes) VALUES + ("test1", 1, 90, 900), + ("test2", 2, 95, 950), + ("test3", 3, 100, NULL);