From 3896efd8de37e5f32889fd3f01693629630ce905 Mon Sep 17 00:00:00 2001 From: 283375 Date: Tue, 6 Jun 2023 19:36:32 +0800 Subject: [PATCH] impr(db): general improvements --- src/arcaea_offline/database.py | 213 ++++++++++----------------- src/arcaea_offline/utils/__init__.py | 0 src/arcaea_offline/utils/types.py | 6 + 3 files changed, 80 insertions(+), 139 deletions(-) create mode 100644 src/arcaea_offline/utils/__init__.py create mode 100644 src/arcaea_offline/utils/types.py diff --git a/src/arcaea_offline/database.py b/src/arcaea_offline/database.py index 8d220cb..cb8f2aa 100644 --- a/src/arcaea_offline/database.py +++ b/src/arcaea_offline/database.py @@ -1,13 +1,16 @@ import os import sqlite3 from dataclasses import fields, is_dataclass -from typing import List, NamedTuple, Optional, Union +from typing import List, NamedTuple, Optional, TypeVar, Union from thefuzz import fuzz from thefuzz import process as fuzz_process from .models import DbAliasRow, DbCalculatedRow, DbChartRow, DbPackageRow, DbScoreRow from .utils.singleton import Singleton +from .utils.types import TDataclass + +T = TypeVar("T", bound=TDataclass) class Database(metaclass=Singleton): @@ -26,38 +29,28 @@ class Database(metaclass=Singleton): def validate_song_id(self, song_id): with self.conn as conn: cursor = conn.cursor() - result = cursor.execute( - "SELECT song_id FROM charts WHERE song_id = ?", (song_id,) - ).fetchall() - return len(result) > 0 + cursor.execute("SELECT COUNT(*) FROM charts WHERE song_id = ?", (song_id,)) + result = cursor.fetchone() + return result[0] > 0 def update_arcsong_db(self, path: Union[str, bytes]): - arcsong_conn = sqlite3.connect(path) - data = {} - with arcsong_conn: + with sqlite3.connect(path) as arcsong_conn: arcsong_cursor = arcsong_conn.cursor() - data["charts"] = arcsong_cursor.execute( - """ - SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, "set", time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, 
chart_designer, jacket_designer, jacket_override, audio_override - FROM charts - """ - ).fetchall() - data["aliases"] = arcsong_cursor.execute( - """ - SELECT sid, alias - FROM alias - """ - ).fetchall() - data["packages"] = arcsong_cursor.execute( - """ - SELECT id, name - FROM packages - """ - ).fetchall() + data = { + "charts": arcsong_cursor.execute( + "SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, \"set\", time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override FROM charts" + ).fetchall(), + "aliases": arcsong_cursor.execute( + "SELECT sid, alias FROM alias" + ).fetchall(), + "packages": arcsong_cursor.execute( + "SELECT id, name FROM packages" + ).fetchall(), + } with self.conn as conn: cursor = conn.cursor() - for table in data: + for table, rows in data.items(): columns = [ row[0] for row in cursor.execute( @@ -65,15 +58,16 @@ class Database(metaclass=Singleton): ).description ] column_count = len(columns) - assert column_count == len(data[table][0]) - columns_insert_str = ", ".join(columns) - values_insert_str = ", ".join("?" * column_count) + assert column_count == len( + rows[0] + ), f"Incompatible column count for table '{table}'" + placeholders = ", ".join(["?" 
for _ in range(column_count)]) update_clauses = ", ".join( [f"{column} = excluded.{column}" for column in columns] ) cursor.executemany( - f"INSERT INTO {table} ({columns_insert_str}) VALUES ({values_insert_str}) ON CONFLICT DO UPDATE SET {update_clauses}", - data[table], + f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders}) ON CONFLICT DO UPDATE SET {update_clauses}", + rows, ) conn.commit() @@ -225,7 +219,7 @@ class Database(metaclass=Singleton): ) AS subquery WHERE name IS NOT NULL AND name <> '' GROUP BY song_id, name - """ + """, ] with self.conn as conn: @@ -243,111 +237,67 @@ class Database(metaclass=Singleton): def __get_columns_clause(self, columns: List[str]): return ", ".join([f'"{column}"' for column in columns]) - def get_packages(self): + def __get_table( + self, table_name: str, datacls: T, where_clause: str = "", params=None + ) -> List[T]: + if params is None: + params = [] + columns_clause = self.__get_columns_clause( + self.__get_columns_from_dataclass(datacls) + ) + + sql = f"SELECT {columns_clause} FROM {table_name}" + if where_clause: + sql += " WHERE " + sql += where_clause with self.conn as conn: - columns_clause = self.__get_columns_clause( - self.__get_columns_from_dataclass(DbPackageRow) - ) cursor = conn.cursor() - return [ - DbPackageRow(*row) - for row in cursor.execute( - f"SELECT {columns_clause} FROM packages" - ).fetchall() - ] + cursor.execute(sql, params) + return [datacls(*row) for row in cursor.fetchall()] + + def get_packages(self): + return self.__get_table("packages", DbPackageRow) def get_aliases(self): - with self.conn as conn: - columns_clause = self.__get_columns_clause( - self.__get_columns_from_dataclass(DbAliasRow) - ) - cursor = conn.cursor() - return [ - DbAliasRow(*row) - for row in cursor.execute( - f"SELECT {columns_clause} FROM aliases" - ).fetchall() - ] + return self.__get_table("aliases", DbAliasRow) def get_aliases_by_song_id(self, song_id: str): - with self.conn as conn: - columns_clause = 
self.__get_columns_clause( - self.__get_columns_from_dataclass(DbAliasRow) - ) - cursor = conn.cursor() - return [ - DbAliasRow(*row) - for row in ( - cursor.execute( - f"SELECT {columns_clause} FROM aliases WHERE song_id = ?", - (song_id,), - ).fetchall() - ) - ] + return self.__get_table("aliases", DbAliasRow, "song_id = ?", (song_id,)) def get_charts(self): - with self.conn as conn: - columns_clause = self.__get_columns_clause( - self.__get_columns_from_dataclass(DbChartRow) - ) - cursor = conn.cursor() - return [ - DbChartRow(*row) - for row in cursor.execute( - f"SELECT {columns_clause} FROM charts" - ).fetchall() - ] + return self.__get_table("charts", DbChartRow) def get_charts_by_song_id(self, song_id: str): - with self.conn as conn: - columns_clause = self.__get_columns_clause( - self.__get_columns_from_dataclass(DbChartRow) - ) - cursor = conn.cursor() - return [ - DbChartRow(*row) - for row in ( - cursor.execute( - f"SELECT {columns_clause} FROM charts WHERE song_id = ?", - (song_id,), - ).fetchall() - ) - ] + return self.__get_table("charts", DbChartRow, "song_id = ?", (song_id,)) def get_charts_by_package_id(self, package_id: str): - with self.conn as conn: - columns_clause = self.__get_columns_clause( - self.__get_columns_from_dataclass(DbChartRow) - ) - cursor = conn.cursor() - return [ - DbChartRow(*row) - for row in cursor.execute( - f"SELECT {columns_clause} FROM charts WHERE package_id = ?", - (package_id,), - ).fetchall() - ] + return self.__get_table("charts", DbChartRow, "package_id = ?", (package_id,)) class FuzzySearchSongIdResult(NamedTuple): song_id: str confidence: int - def fuzzy_search_song_id(self, input_str: str, limit: int= 5) -> List[FuzzySearchSongIdResult]: + def fuzzy_search_song_id( + self, input_str: str, limit: int = 5 + ) -> List[FuzzySearchSongIdResult]: with self.conn as conn: cursor = conn.cursor() - db_results = cursor.execute("SELECT song_id, name FROM song_id_names").fetchall() + db_results = cursor.execute( + "SELECT 
song_id, name FROM song_id_names" + ).fetchall() name_song_id_map = {r[1]: r[0] for r in db_results} names = name_song_id_map.keys() fuzzy_results = fuzz_process.extractBests(input_str, names, scorer=fuzz.partial_ratio, limit=limit) # type: ignore results = {} - for fuzzy_result in fuzzy_results: + for fuzzy_result in fuzzy_results: name = fuzzy_result[0] confidence = fuzzy_result[1] song_id = name_song_id_map[name] results[song_id] = max(confidence, results.get(song_id, 0)) - return [self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()] - + return [ + self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items() + ] def get_scores( self, @@ -355,26 +305,18 @@ class Database(metaclass=Singleton): song_id: Optional[List[str]] = None, rating_class: Optional[List[int]] = None, ): - columns = ",".join([f"[{field.name}]" for field in fields(DbScoreRow)]) where_clauses = [] params = [] if song_id: where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})") params.extend(song_id) if rating_class: - where_clauses.append( - f"rating_class IN ({','.join('?'*len(rating_class))})" - ) + where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})") params.extend(rating_class) - final_sql = f"SELECT {columns} FROM scores" - if where_clauses: - final_sql += " WHERE " - final_sql += " AND ".join(where_clauses) - with self.conn as conn: - cursor = conn.cursor() - return [ - DbScoreRow(*row) for row in cursor.execute(final_sql, params).fetchall() - ] + + return self.__get_table( + "scores", DbScoreRow, " AND ".join(where_clauses), params + ) def get_calculated( self, @@ -382,27 +324,18 @@ class Database(metaclass=Singleton): song_id: Optional[List[str]] = None, rating_class: Optional[List[int]] = None, ): - columns = ",".join([f"[{field.name}]" for field in fields(DbCalculatedRow)]) where_clauses = [] params = [] if song_id: where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})") params.extend(song_id) if rating_class: 
- where_clauses.append( - f"rating_class IN ({','.join('?'*len(rating_class))})" - ) + where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})") params.extend(rating_class) - final_sql = f"SELECT {columns} FROM calculated" - if where_clauses: - final_sql += " WHERE " - final_sql += " AND ".join(where_clauses) - with self.conn as conn: - cursor = conn.cursor() - return [ - DbCalculatedRow(*row) - for row in cursor.execute(final_sql, params).fetchall() - ] + + return self.__get_table( + "calculated", DbCalculatedRow, " AND ".join(where_clauses), params + ) def get_b30(self) -> float: with self.conn as conn: @@ -417,12 +350,14 @@ class Database(metaclass=Singleton): def get_potential(self) -> float: with self.conn as conn: cursor = conn.cursor() - return cursor.execute("SELECT potential FROM calculated_potential").fetchone()[0] + return cursor.execute( + "SELECT potential FROM calculated_potential" + ).fetchone()[0] def insert_score(self, score: DbScoreRow): columns = self.__get_columns_from_dataclass(DbScoreRow) columns_clause = self.__get_columns_clause(columns) - params = [score.__getattribute__(column) for column in columns] + params = [getattr(score, column) for column in columns] with self.conn as conn: cursor = conn.cursor() cursor.execute( diff --git a/src/arcaea_offline/utils/__init__.py b/src/arcaea_offline/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/arcaea_offline/utils/types.py b/src/arcaea_offline/utils/types.py new file mode 100644 index 0000000..54773e2 --- /dev/null +++ b/src/arcaea_offline/utils/types.py @@ -0,0 +1,6 @@ +from typing import Any, Protocol, ClassVar, Dict + +class TDataclass(Protocol): + __dataclass_fields__: ClassVar[Dict] + def __call__(self, *args: Any, **kwds: Any) -> Any: + ...