impr(db): general improvements

This commit is contained in:
283375 2023-06-06 19:36:32 +08:00
parent e874e38d3b
commit 3896efd8de
3 changed files with 80 additions and 139 deletions

View File

@ -1,13 +1,16 @@
import os
import sqlite3
from dataclasses import fields, is_dataclass
from typing import List, NamedTuple, Optional, Union
from typing import List, NamedTuple, Optional, TypeVar, Union
from thefuzz import fuzz
from thefuzz import process as fuzz_process
from .models import DbAliasRow, DbCalculatedRow, DbChartRow, DbPackageRow, DbScoreRow
from .utils.singleton import Singleton
from .utils.types import TDataclass
T = TypeVar("T", bound=TDataclass)
class Database(metaclass=Singleton):
@ -26,38 +29,28 @@ class Database(metaclass=Singleton):
def validate_song_id(self, song_id):
with self.conn as conn:
cursor = conn.cursor()
result = cursor.execute(
"SELECT song_id FROM charts WHERE song_id = ?", (song_id,)
).fetchall()
return len(result) > 0
cursor.execute("SELECT COUNT(*) FROM charts WHERE song_id = ?", (song_id,))
result = cursor.fetchone()
return result[0] > 0
def update_arcsong_db(self, path: Union[str, bytes]):
arcsong_conn = sqlite3.connect(path)
data = {}
with arcsong_conn:
with sqlite3.connect(path) as arcsong_conn:
arcsong_cursor = arcsong_conn.cursor()
data["charts"] = arcsong_cursor.execute(
"""
SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, "set", time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override
FROM charts
"""
).fetchall()
data["aliases"] = arcsong_cursor.execute(
"""
SELECT sid, alias
FROM alias
"""
).fetchall()
data["packages"] = arcsong_cursor.execute(
"""
SELECT id, name
FROM packages
"""
).fetchall()
# Read every source table up front so the arcsong connection is only needed
# for these three queries.
# NOTE: `set` is an SQL keyword, so the column has to be quoted as an
# identifier with double quotes. Written as 'set' (single quotes) SQLite
# parses it as a string literal in the result-column list, and every chart row
# would carry the text "set" instead of the column's value.
data = {
    "charts": arcsong_cursor.execute(
        'SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, "set", time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override FROM charts'
    ).fetchall(),
    "aliases": arcsong_cursor.execute(
        "SELECT sid, alias FROM alias"
    ).fetchall(),
    "packages": arcsong_cursor.execute(
        "SELECT id, name FROM packages"
    ).fetchall(),
}
with self.conn as conn:
cursor = conn.cursor()
for table in data:
for table, rows in data.items():
columns = [
row[0]
for row in cursor.execute(
@ -65,15 +58,16 @@ class Database(metaclass=Singleton):
).description
]
column_count = len(columns)
assert column_count == len(data[table][0])
columns_insert_str = ", ".join(columns)
values_insert_str = ", ".join("?" * column_count)
assert column_count == len(
rows[0]
), f"Incompatible column count for table '{table}'"
placeholders = ", ".join(["?" for _ in range(column_count)])
update_clauses = ", ".join(
[f"{column} = excluded.{column}" for column in columns]
)
cursor.executemany(
f"INSERT INTO {table} ({columns_insert_str}) VALUES ({values_insert_str}) ON CONFLICT DO UPDATE SET {update_clauses}",
data[table],
f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders}) ON CONFLICT DO UPDATE SET {update_clauses}",
rows,
)
conn.commit()
@ -225,7 +219,7 @@ class Database(metaclass=Singleton):
) AS subquery
WHERE name IS NOT NULL AND name <> ''
GROUP BY song_id, name
"""
""",
]
with self.conn as conn:
@ -243,111 +237,67 @@ class Database(metaclass=Singleton):
# Render a list of column names as a SQL column list: each name is wrapped in
# double quotes (identifier quoting) and the quoted names are joined with ", ".
# Names are inserted verbatim; a double quote inside a name is not escaped.
def __get_columns_clause(self, columns: List[str]) -> str:
return ", ".join([f'"{column}"' for column in columns])
def get_packages(self):
def __get_table(
self, table_name: str, datacls: T, where_clause: str = "", params=None
) -> List[T]:
if params is None:
params = []
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(datacls)
)
sql = f"SELECT {columns_clause} FROM {table_name}"
if where_clause:
sql += " WHERE "
sql += where_clause
with self.conn as conn:
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbPackageRow)
)
cursor = conn.cursor()
return [
DbPackageRow(*row)
for row in cursor.execute(
f"SELECT {columns_clause} FROM packages"
).fetchall()
]
cursor.execute(sql, params)
return [datacls(*row) for row in cursor.fetchall()]
def get_packages(self):
return self.__get_table("packages", DbPackageRow)
def get_aliases(self):
with self.conn as conn:
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbAliasRow)
)
cursor = conn.cursor()
return [
DbAliasRow(*row)
for row in cursor.execute(
f"SELECT {columns_clause} FROM aliases"
).fetchall()
]
return self.__get_table("aliases", DbAliasRow)
def get_aliases_by_song_id(self, song_id: str):
with self.conn as conn:
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbAliasRow)
)
cursor = conn.cursor()
return [
DbAliasRow(*row)
for row in (
cursor.execute(
f"SELECT {columns_clause} FROM aliases WHERE song_id = ?",
(song_id,),
).fetchall()
)
]
return self.__get_table("aliases", DbAliasRow, "song_id = ?", (song_id,))
def get_charts(self):
with self.conn as conn:
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbChartRow)
)
cursor = conn.cursor()
return [
DbChartRow(*row)
for row in cursor.execute(
f"SELECT {columns_clause} FROM charts"
).fetchall()
]
return self.__get_table("charts", DbChartRow)
def get_charts_by_song_id(self, song_id: str):
with self.conn as conn:
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbChartRow)
)
cursor = conn.cursor()
return [
DbChartRow(*row)
for row in (
cursor.execute(
f"SELECT {columns_clause} FROM charts WHERE song_id = ?",
(song_id,),
).fetchall()
)
]
return self.__get_table("charts", DbChartRow, "song_id = ?", (song_id,))
def get_charts_by_package_id(self, package_id: str):
with self.conn as conn:
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(DbChartRow)
)
cursor = conn.cursor()
return [
DbChartRow(*row)
for row in cursor.execute(
f"SELECT {columns_clause} FROM charts WHERE package_id = ?",
(package_id,),
).fetchall()
]
return self.__get_table("charts", DbChartRow, "package_id = ?", (package_id,))
# One entry of the list returned by fuzzy_search_song_id: a song id paired with
# the best match score found for it.
class FuzzySearchSongIdResult(NamedTuple):
# Song id looked up from the matched name.
song_id: str
# Highest score the fuzzy matcher reported for any name mapping to this song
# (presumably 0-100, as produced by thefuzz scorers; confirm against thefuzz).
confidence: int
def fuzzy_search_song_id(self, input_str: str, limit: int= 5) -> List[FuzzySearchSongIdResult]:
def fuzzy_search_song_id(
self, input_str: str, limit: int = 5
) -> List[FuzzySearchSongIdResult]:
with self.conn as conn:
cursor = conn.cursor()
db_results = cursor.execute("SELECT song_id, name FROM song_id_names").fetchall()
db_results = cursor.execute(
"SELECT song_id, name FROM song_id_names"
).fetchall()
name_song_id_map = {r[1]: r[0] for r in db_results}
names = name_song_id_map.keys()
fuzzy_results = fuzz_process.extractBests(input_str, names, scorer=fuzz.partial_ratio, limit=limit) # type: ignore
results = {}
for fuzzy_result in fuzzy_results:
for fuzzy_result in fuzzy_results:
name = fuzzy_result[0]
confidence = fuzzy_result[1]
song_id = name_song_id_map[name]
results[song_id] = max(confidence, results.get(song_id, 0))
return [self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()]
return [
self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()
]
def get_scores(
self,
@ -355,26 +305,18 @@ class Database(metaclass=Singleton):
song_id: Optional[List[str]] = None,
rating_class: Optional[List[int]] = None,
):
columns = ",".join([f"[{field.name}]" for field in fields(DbScoreRow)])
where_clauses = []
params = []
if song_id:
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
params.extend(song_id)
if rating_class:
where_clauses.append(
f"rating_class IN ({','.join('?'*len(rating_class))})"
)
where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
params.extend(rating_class)
final_sql = f"SELECT {columns} FROM scores"
if where_clauses:
final_sql += " WHERE "
final_sql += " AND ".join(where_clauses)
with self.conn as conn:
cursor = conn.cursor()
return [
DbScoreRow(*row) for row in cursor.execute(final_sql, params).fetchall()
]
return self.__get_table(
"scores", DbScoreRow, " AND ".join(where_clauses), params
)
def get_calculated(
self,
@ -382,27 +324,18 @@ class Database(metaclass=Singleton):
song_id: Optional[List[str]] = None,
rating_class: Optional[List[int]] = None,
):
columns = ",".join([f"[{field.name}]" for field in fields(DbCalculatedRow)])
where_clauses = []
params = []
if song_id:
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
params.extend(song_id)
if rating_class:
where_clauses.append(
f"rating_class IN ({','.join('?'*len(rating_class))})"
)
where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
params.extend(rating_class)
final_sql = f"SELECT {columns} FROM calculated"
if where_clauses:
final_sql += " WHERE "
final_sql += " AND ".join(where_clauses)
with self.conn as conn:
cursor = conn.cursor()
return [
DbCalculatedRow(*row)
for row in cursor.execute(final_sql, params).fetchall()
]
return self.__get_table(
"calculated", DbCalculatedRow, " AND ".join(where_clauses), params
)
def get_b30(self) -> float:
with self.conn as conn:
@ -417,12 +350,14 @@ class Database(metaclass=Singleton):
def get_potential(self) -> float:
with self.conn as conn:
cursor = conn.cursor()
return cursor.execute("SELECT potential FROM calculated_potential").fetchone()[0]
return cursor.execute(
"SELECT potential FROM calculated_potential"
).fetchone()[0]
def insert_score(self, score: DbScoreRow):
columns = self.__get_columns_from_dataclass(DbScoreRow)
columns_clause = self.__get_columns_clause(columns)
params = [score.__getattribute__(column) for column in columns]
params = [getattr(score, column) for column in columns]
with self.conn as conn:
cursor = conn.cursor()
cursor.execute(

View File

View File

@ -0,0 +1,6 @@
from typing import Any, Protocol, ClassVar, Dict
class TDataclass(Protocol):
    """Structural type for a callable object that carries dataclass metadata.

    An object satisfies this protocol when it exposes the
    ``__dataclass_fields__`` class attribute and can be called with arbitrary
    positional and keyword arguments.
    """

    # Mapping of field name to field metadata, as attached to dataclasses.
    __dataclass_fields__: ClassVar[Dict[str, Any]]

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Accept any arguments; the result type is left unconstrained."""
        ...