impr(db): general improvements

This commit is contained in:
283375 2023-06-06 19:36:32 +08:00
parent e874e38d3b
commit 3896efd8de
3 changed files with 80 additions and 139 deletions

View File

@ -1,13 +1,16 @@
import os import os
import sqlite3 import sqlite3
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from typing import List, NamedTuple, Optional, Union from typing import List, NamedTuple, Optional, TypeVar, Union
from thefuzz import fuzz from thefuzz import fuzz
from thefuzz import process as fuzz_process from thefuzz import process as fuzz_process
from .models import DbAliasRow, DbCalculatedRow, DbChartRow, DbPackageRow, DbScoreRow from .models import DbAliasRow, DbCalculatedRow, DbChartRow, DbPackageRow, DbScoreRow
from .utils.singleton import Singleton from .utils.singleton import Singleton
from .utils.types import TDataclass
T = TypeVar("T", bound=TDataclass)
class Database(metaclass=Singleton): class Database(metaclass=Singleton):
@ -26,38 +29,28 @@ class Database(metaclass=Singleton):
def validate_song_id(self, song_id): def validate_song_id(self, song_id):
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
result = cursor.execute( cursor.execute("SELECT COUNT(*) FROM charts WHERE song_id = ?", (song_id,))
"SELECT song_id FROM charts WHERE song_id = ?", (song_id,) result = cursor.fetchone()
).fetchall() return result[0] > 0
return len(result) > 0
def update_arcsong_db(self, path: Union[str, bytes]): def update_arcsong_db(self, path: Union[str, bytes]):
arcsong_conn = sqlite3.connect(path) with sqlite3.connect(path) as arcsong_conn:
data = {}
with arcsong_conn:
arcsong_cursor = arcsong_conn.cursor() arcsong_cursor = arcsong_conn.cursor()
data["charts"] = arcsong_cursor.execute( data = {
""" "charts": arcsong_cursor.execute(
SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, "set", time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override "SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, 'set', time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override FROM charts"
FROM charts ).fetchall(),
""" "aliases": arcsong_cursor.execute(
).fetchall() "SELECT sid, alias FROM alias"
data["aliases"] = arcsong_cursor.execute( ).fetchall(),
""" "packages": arcsong_cursor.execute(
SELECT sid, alias "SELECT id, name FROM packages"
FROM alias ).fetchall(),
""" }
).fetchall()
data["packages"] = arcsong_cursor.execute(
"""
SELECT id, name
FROM packages
"""
).fetchall()
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
for table in data: for table, rows in data.items():
columns = [ columns = [
row[0] row[0]
for row in cursor.execute( for row in cursor.execute(
@ -65,15 +58,16 @@ class Database(metaclass=Singleton):
).description ).description
] ]
column_count = len(columns) column_count = len(columns)
assert column_count == len(data[table][0]) assert column_count == len(
columns_insert_str = ", ".join(columns) rows[0]
values_insert_str = ", ".join("?" * column_count) ), f"Incompatible column count for table '{table}'"
placeholders = ", ".join(["?" for _ in range(column_count)])
update_clauses = ", ".join( update_clauses = ", ".join(
[f"{column} = excluded.{column}" for column in columns] [f"{column} = excluded.{column}" for column in columns]
) )
cursor.executemany( cursor.executemany(
f"INSERT INTO {table} ({columns_insert_str}) VALUES ({values_insert_str}) ON CONFLICT DO UPDATE SET {update_clauses}", f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders}) ON CONFLICT DO UPDATE SET {update_clauses}",
data[table], rows,
) )
conn.commit() conn.commit()
@ -225,7 +219,7 @@ class Database(metaclass=Singleton):
) AS subquery ) AS subquery
WHERE name IS NOT NULL AND name <> '' WHERE name IS NOT NULL AND name <> ''
GROUP BY song_id, name GROUP BY song_id, name
""" """,
] ]
with self.conn as conn: with self.conn as conn:
@ -243,99 +237,54 @@ class Database(metaclass=Singleton):
def __get_columns_clause(self, columns: List[str]): def __get_columns_clause(self, columns: List[str]):
return ", ".join([f'"{column}"' for column in columns]) return ", ".join([f'"{column}"' for column in columns])
def get_packages(self): def __get_table(
with self.conn as conn: self, table_name: str, datacls: T, where_clause: str = "", params=None
) -> List[T]:
if params is None:
params = []
columns_clause = self.__get_columns_clause( columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbPackageRow) self.__get_columns_from_dataclass(datacls)
) )
sql = f"SELECT {columns_clause} FROM {table_name}"
if where_clause:
sql += " WHERE "
sql += where_clause
with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
return [ cursor.execute(sql, params)
DbPackageRow(*row) return [datacls(*row) for row in cursor.fetchall()]
for row in cursor.execute(
f"SELECT {columns_clause} FROM packages" def get_packages(self):
).fetchall() return self.__get_table("packages", DbPackageRow)
]
def get_aliases(self): def get_aliases(self):
with self.conn as conn: return self.__get_table("aliases", DbAliasRow)
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbAliasRow)
)
cursor = conn.cursor()
return [
DbAliasRow(*row)
for row in cursor.execute(
f"SELECT {columns_clause} FROM aliases"
).fetchall()
]
def get_aliases_by_song_id(self, song_id: str): def get_aliases_by_song_id(self, song_id: str):
with self.conn as conn: return self.__get_table("aliases", DbAliasRow, "song_id = ?", (song_id,))
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbAliasRow)
)
cursor = conn.cursor()
return [
DbAliasRow(*row)
for row in (
cursor.execute(
f"SELECT {columns_clause} FROM aliases WHERE song_id = ?",
(song_id,),
).fetchall()
)
]
def get_charts(self): def get_charts(self):
with self.conn as conn: return self.__get_table("charts", DbChartRow)
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbChartRow)
)
cursor = conn.cursor()
return [
DbChartRow(*row)
for row in cursor.execute(
f"SELECT {columns_clause} FROM charts"
).fetchall()
]
def get_charts_by_song_id(self, song_id: str): def get_charts_by_song_id(self, song_id: str):
with self.conn as conn: return self.__get_table("charts", DbChartRow, "song_id = ?", (song_id,))
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbChartRow)
)
cursor = conn.cursor()
return [
DbChartRow(*row)
for row in (
cursor.execute(
f"SELECT {columns_clause} FROM charts WHERE song_id = ?",
(song_id,),
).fetchall()
)
]
def get_charts_by_package_id(self, package_id: str): def get_charts_by_package_id(self, package_id: str):
with self.conn as conn: return self.__get_table("charts", DbChartRow, "package_id = ?", (package_id,))
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbChartRow)
)
cursor = conn.cursor()
return [
DbChartRow(*row)
for row in cursor.execute(
f"SELECT {columns_clause} FROM charts WHERE package_id = ?",
(package_id,),
).fetchall()
]
class FuzzySearchSongIdResult(NamedTuple): class FuzzySearchSongIdResult(NamedTuple):
song_id: str song_id: str
confidence: int confidence: int
def fuzzy_search_song_id(self, input_str: str, limit: int= 5) -> List[FuzzySearchSongIdResult]: def fuzzy_search_song_id(
self, input_str: str, limit: int = 5
) -> List[FuzzySearchSongIdResult]:
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
db_results = cursor.execute("SELECT song_id, name FROM song_id_names").fetchall() db_results = cursor.execute(
"SELECT song_id, name FROM song_id_names"
).fetchall()
name_song_id_map = {r[1]: r[0] for r in db_results} name_song_id_map = {r[1]: r[0] for r in db_results}
names = name_song_id_map.keys() names = name_song_id_map.keys()
fuzzy_results = fuzz_process.extractBests(input_str, names, scorer=fuzz.partial_ratio, limit=limit) # type: ignore fuzzy_results = fuzz_process.extractBests(input_str, names, scorer=fuzz.partial_ratio, limit=limit) # type: ignore
@ -346,8 +295,9 @@ class Database(metaclass=Singleton):
song_id = name_song_id_map[name] song_id = name_song_id_map[name]
results[song_id] = max(confidence, results.get(song_id, 0)) results[song_id] = max(confidence, results.get(song_id, 0))
return [self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()] return [
self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()
]
def get_scores( def get_scores(
self, self,
@ -355,26 +305,18 @@ class Database(metaclass=Singleton):
song_id: Optional[List[str]] = None, song_id: Optional[List[str]] = None,
rating_class: Optional[List[int]] = None, rating_class: Optional[List[int]] = None,
): ):
columns = ",".join([f"[{field.name}]" for field in fields(DbScoreRow)])
where_clauses = [] where_clauses = []
params = [] params = []
if song_id: if song_id:
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})") where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
params.extend(song_id) params.extend(song_id)
if rating_class: if rating_class:
where_clauses.append( where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
f"rating_class IN ({','.join('?'*len(rating_class))})"
)
params.extend(rating_class) params.extend(rating_class)
final_sql = f"SELECT {columns} FROM scores"
if where_clauses: return self.__get_table(
final_sql += " WHERE " "scores", DbScoreRow, " AND ".join(where_clauses), params
final_sql += " AND ".join(where_clauses) )
with self.conn as conn:
cursor = conn.cursor()
return [
DbScoreRow(*row) for row in cursor.execute(final_sql, params).fetchall()
]
def get_calculated( def get_calculated(
self, self,
@ -382,27 +324,18 @@ class Database(metaclass=Singleton):
song_id: Optional[List[str]] = None, song_id: Optional[List[str]] = None,
rating_class: Optional[List[int]] = None, rating_class: Optional[List[int]] = None,
): ):
columns = ",".join([f"[{field.name}]" for field in fields(DbCalculatedRow)])
where_clauses = [] where_clauses = []
params = [] params = []
if song_id: if song_id:
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})") where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
params.extend(song_id) params.extend(song_id)
if rating_class: if rating_class:
where_clauses.append( where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
f"rating_class IN ({','.join('?'*len(rating_class))})"
)
params.extend(rating_class) params.extend(rating_class)
final_sql = f"SELECT {columns} FROM calculated"
if where_clauses: return self.__get_table(
final_sql += " WHERE " "calculated", DbCalculatedRow, " AND ".join(where_clauses), params
final_sql += " AND ".join(where_clauses) )
with self.conn as conn:
cursor = conn.cursor()
return [
DbCalculatedRow(*row)
for row in cursor.execute(final_sql, params).fetchall()
]
def get_b30(self) -> float: def get_b30(self) -> float:
with self.conn as conn: with self.conn as conn:
@ -417,12 +350,14 @@ class Database(metaclass=Singleton):
def get_potential(self) -> float: def get_potential(self) -> float:
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
return cursor.execute("SELECT potential FROM calculated_potential").fetchone()[0] return cursor.execute(
"SELECT potential FROM calculated_potential"
).fetchone()[0]
def insert_score(self, score: DbScoreRow): def insert_score(self, score: DbScoreRow):
columns = self.__get_columns_from_dataclass(DbScoreRow) columns = self.__get_columns_from_dataclass(DbScoreRow)
columns_clause = self.__get_columns_clause(columns) columns_clause = self.__get_columns_clause(columns)
params = [score.__getattribute__(column) for column in columns] params = [getattr(score, column) for column in columns]
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(

View File

View File

@ -0,0 +1,6 @@
from typing import Any, Protocol, ClassVar, Dict
class TDataclass(Protocol):
__dataclass_fields__: ClassVar[Dict]
def __call__(self, *args: Any, **kwds: Any) -> Any:
...