mirror of
https://github.com/283375/arcaea-offline.git
synced 2025-04-19 22:20:17 +00:00
impr(db): general improvements
This commit is contained in:
parent
e874e38d3b
commit
3896efd8de
@ -1,13 +1,16 @@
|
|||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from typing import List, NamedTuple, Optional, Union
|
from typing import List, NamedTuple, Optional, TypeVar, Union
|
||||||
|
|
||||||
from thefuzz import fuzz
|
from thefuzz import fuzz
|
||||||
from thefuzz import process as fuzz_process
|
from thefuzz import process as fuzz_process
|
||||||
|
|
||||||
from .models import DbAliasRow, DbCalculatedRow, DbChartRow, DbPackageRow, DbScoreRow
|
from .models import DbAliasRow, DbCalculatedRow, DbChartRow, DbPackageRow, DbScoreRow
|
||||||
from .utils.singleton import Singleton
|
from .utils.singleton import Singleton
|
||||||
|
from .utils.types import TDataclass
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=TDataclass)
|
||||||
|
|
||||||
|
|
||||||
class Database(metaclass=Singleton):
|
class Database(metaclass=Singleton):
|
||||||
@ -26,38 +29,28 @@ class Database(metaclass=Singleton):
|
|||||||
def validate_song_id(self, song_id):
|
def validate_song_id(self, song_id):
|
||||||
with self.conn as conn:
|
with self.conn as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
result = cursor.execute(
|
cursor.execute("SELECT COUNT(*) FROM charts WHERE song_id = ?", (song_id,))
|
||||||
"SELECT song_id FROM charts WHERE song_id = ?", (song_id,)
|
result = cursor.fetchone()
|
||||||
).fetchall()
|
return result[0] > 0
|
||||||
return len(result) > 0
|
|
||||||
|
|
||||||
def update_arcsong_db(self, path: Union[str, bytes]):
|
def update_arcsong_db(self, path: Union[str, bytes]):
|
||||||
arcsong_conn = sqlite3.connect(path)
|
with sqlite3.connect(path) as arcsong_conn:
|
||||||
data = {}
|
|
||||||
with arcsong_conn:
|
|
||||||
arcsong_cursor = arcsong_conn.cursor()
|
arcsong_cursor = arcsong_conn.cursor()
|
||||||
data["charts"] = arcsong_cursor.execute(
|
data = {
|
||||||
"""
|
"charts": arcsong_cursor.execute(
|
||||||
SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, "set", time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override
|
"SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, 'set', time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override FROM charts"
|
||||||
FROM charts
|
).fetchall(),
|
||||||
"""
|
"aliases": arcsong_cursor.execute(
|
||||||
).fetchall()
|
"SELECT sid, alias FROM alias"
|
||||||
data["aliases"] = arcsong_cursor.execute(
|
).fetchall(),
|
||||||
"""
|
"packages": arcsong_cursor.execute(
|
||||||
SELECT sid, alias
|
"SELECT id, name FROM packages"
|
||||||
FROM alias
|
).fetchall(),
|
||||||
"""
|
}
|
||||||
).fetchall()
|
|
||||||
data["packages"] = arcsong_cursor.execute(
|
|
||||||
"""
|
|
||||||
SELECT id, name
|
|
||||||
FROM packages
|
|
||||||
"""
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
with self.conn as conn:
|
with self.conn as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
for table in data:
|
for table, rows in data.items():
|
||||||
columns = [
|
columns = [
|
||||||
row[0]
|
row[0]
|
||||||
for row in cursor.execute(
|
for row in cursor.execute(
|
||||||
@ -65,15 +58,16 @@ class Database(metaclass=Singleton):
|
|||||||
).description
|
).description
|
||||||
]
|
]
|
||||||
column_count = len(columns)
|
column_count = len(columns)
|
||||||
assert column_count == len(data[table][0])
|
assert column_count == len(
|
||||||
columns_insert_str = ", ".join(columns)
|
rows[0]
|
||||||
values_insert_str = ", ".join("?" * column_count)
|
), f"Incompatible column count for table '{table}'"
|
||||||
|
placeholders = ", ".join(["?" for _ in range(column_count)])
|
||||||
update_clauses = ", ".join(
|
update_clauses = ", ".join(
|
||||||
[f"{column} = excluded.{column}" for column in columns]
|
[f"{column} = excluded.{column}" for column in columns]
|
||||||
)
|
)
|
||||||
cursor.executemany(
|
cursor.executemany(
|
||||||
f"INSERT INTO {table} ({columns_insert_str}) VALUES ({values_insert_str}) ON CONFLICT DO UPDATE SET {update_clauses}",
|
f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders}) ON CONFLICT DO UPDATE SET {update_clauses}",
|
||||||
data[table],
|
rows,
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
@ -225,7 +219,7 @@ class Database(metaclass=Singleton):
|
|||||||
) AS subquery
|
) AS subquery
|
||||||
WHERE name IS NOT NULL AND name <> ''
|
WHERE name IS NOT NULL AND name <> ''
|
||||||
GROUP BY song_id, name
|
GROUP BY song_id, name
|
||||||
"""
|
""",
|
||||||
]
|
]
|
||||||
|
|
||||||
with self.conn as conn:
|
with self.conn as conn:
|
||||||
@ -243,99 +237,54 @@ class Database(metaclass=Singleton):
|
|||||||
def __get_columns_clause(self, columns: List[str]):
|
def __get_columns_clause(self, columns: List[str]):
|
||||||
return ", ".join([f'"{column}"' for column in columns])
|
return ", ".join([f'"{column}"' for column in columns])
|
||||||
|
|
||||||
def get_packages(self):
|
def __get_table(
|
||||||
with self.conn as conn:
|
self, table_name: str, datacls: T, where_clause: str = "", params=None
|
||||||
|
) -> List[T]:
|
||||||
|
if params is None:
|
||||||
|
params = []
|
||||||
columns_clause = self.__get_columns_clause(
|
columns_clause = self.__get_columns_clause(
|
||||||
self.__get_columns_from_dataclass(DbPackageRow)
|
self.__get_columns_from_dataclass(datacls)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sql = f"SELECT {columns_clause} FROM {table_name}"
|
||||||
|
if where_clause:
|
||||||
|
sql += " WHERE "
|
||||||
|
sql += where_clause
|
||||||
|
with self.conn as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
return [
|
cursor.execute(sql, params)
|
||||||
DbPackageRow(*row)
|
return [datacls(*row) for row in cursor.fetchall()]
|
||||||
for row in cursor.execute(
|
|
||||||
f"SELECT {columns_clause} FROM packages"
|
def get_packages(self):
|
||||||
).fetchall()
|
return self.__get_table("packages", DbPackageRow)
|
||||||
]
|
|
||||||
|
|
||||||
def get_aliases(self):
|
def get_aliases(self):
|
||||||
with self.conn as conn:
|
return self.__get_table("aliases", DbAliasRow)
|
||||||
columns_clause = self.__get_columns_clause(
|
|
||||||
self.__get_columns_from_dataclass(DbAliasRow)
|
|
||||||
)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
return [
|
|
||||||
DbAliasRow(*row)
|
|
||||||
for row in cursor.execute(
|
|
||||||
f"SELECT {columns_clause} FROM aliases"
|
|
||||||
).fetchall()
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_aliases_by_song_id(self, song_id: str):
|
def get_aliases_by_song_id(self, song_id: str):
|
||||||
with self.conn as conn:
|
return self.__get_table("aliases", DbAliasRow, "song_id = ?", (song_id,))
|
||||||
columns_clause = self.__get_columns_clause(
|
|
||||||
self.__get_columns_from_dataclass(DbAliasRow)
|
|
||||||
)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
return [
|
|
||||||
DbAliasRow(*row)
|
|
||||||
for row in (
|
|
||||||
cursor.execute(
|
|
||||||
f"SELECT {columns_clause} FROM aliases WHERE song_id = ?",
|
|
||||||
(song_id,),
|
|
||||||
).fetchall()
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_charts(self):
|
def get_charts(self):
|
||||||
with self.conn as conn:
|
return self.__get_table("charts", DbChartRow)
|
||||||
columns_clause = self.__get_columns_clause(
|
|
||||||
self.__get_columns_from_dataclass(DbChartRow)
|
|
||||||
)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
return [
|
|
||||||
DbChartRow(*row)
|
|
||||||
for row in cursor.execute(
|
|
||||||
f"SELECT {columns_clause} FROM charts"
|
|
||||||
).fetchall()
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_charts_by_song_id(self, song_id: str):
|
def get_charts_by_song_id(self, song_id: str):
|
||||||
with self.conn as conn:
|
return self.__get_table("charts", DbChartRow, "song_id = ?", (song_id,))
|
||||||
columns_clause = self.__get_columns_clause(
|
|
||||||
self.__get_columns_from_dataclass(DbChartRow)
|
|
||||||
)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
return [
|
|
||||||
DbChartRow(*row)
|
|
||||||
for row in (
|
|
||||||
cursor.execute(
|
|
||||||
f"SELECT {columns_clause} FROM charts WHERE song_id = ?",
|
|
||||||
(song_id,),
|
|
||||||
).fetchall()
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_charts_by_package_id(self, package_id: str):
|
def get_charts_by_package_id(self, package_id: str):
|
||||||
with self.conn as conn:
|
return self.__get_table("charts", DbChartRow, "package_id = ?", (package_id,))
|
||||||
columns_clause = self.__get_columns_clause(
|
|
||||||
self.__get_columns_from_dataclass(DbChartRow)
|
|
||||||
)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
return [
|
|
||||||
DbChartRow(*row)
|
|
||||||
for row in cursor.execute(
|
|
||||||
f"SELECT {columns_clause} FROM charts WHERE package_id = ?",
|
|
||||||
(package_id,),
|
|
||||||
).fetchall()
|
|
||||||
]
|
|
||||||
|
|
||||||
class FuzzySearchSongIdResult(NamedTuple):
|
class FuzzySearchSongIdResult(NamedTuple):
|
||||||
song_id: str
|
song_id: str
|
||||||
confidence: int
|
confidence: int
|
||||||
|
|
||||||
def fuzzy_search_song_id(self, input_str: str, limit: int= 5) -> List[FuzzySearchSongIdResult]:
|
def fuzzy_search_song_id(
|
||||||
|
self, input_str: str, limit: int = 5
|
||||||
|
) -> List[FuzzySearchSongIdResult]:
|
||||||
with self.conn as conn:
|
with self.conn as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
db_results = cursor.execute("SELECT song_id, name FROM song_id_names").fetchall()
|
db_results = cursor.execute(
|
||||||
|
"SELECT song_id, name FROM song_id_names"
|
||||||
|
).fetchall()
|
||||||
name_song_id_map = {r[1]: r[0] for r in db_results}
|
name_song_id_map = {r[1]: r[0] for r in db_results}
|
||||||
names = name_song_id_map.keys()
|
names = name_song_id_map.keys()
|
||||||
fuzzy_results = fuzz_process.extractBests(input_str, names, scorer=fuzz.partial_ratio, limit=limit) # type: ignore
|
fuzzy_results = fuzz_process.extractBests(input_str, names, scorer=fuzz.partial_ratio, limit=limit) # type: ignore
|
||||||
@ -346,8 +295,9 @@ class Database(metaclass=Singleton):
|
|||||||
song_id = name_song_id_map[name]
|
song_id = name_song_id_map[name]
|
||||||
results[song_id] = max(confidence, results.get(song_id, 0))
|
results[song_id] = max(confidence, results.get(song_id, 0))
|
||||||
|
|
||||||
return [self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()]
|
return [
|
||||||
|
self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()
|
||||||
|
]
|
||||||
|
|
||||||
def get_scores(
|
def get_scores(
|
||||||
self,
|
self,
|
||||||
@ -355,26 +305,18 @@ class Database(metaclass=Singleton):
|
|||||||
song_id: Optional[List[str]] = None,
|
song_id: Optional[List[str]] = None,
|
||||||
rating_class: Optional[List[int]] = None,
|
rating_class: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
columns = ",".join([f"[{field.name}]" for field in fields(DbScoreRow)])
|
|
||||||
where_clauses = []
|
where_clauses = []
|
||||||
params = []
|
params = []
|
||||||
if song_id:
|
if song_id:
|
||||||
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
|
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
|
||||||
params.extend(song_id)
|
params.extend(song_id)
|
||||||
if rating_class:
|
if rating_class:
|
||||||
where_clauses.append(
|
where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
|
||||||
f"rating_class IN ({','.join('?'*len(rating_class))})"
|
|
||||||
)
|
|
||||||
params.extend(rating_class)
|
params.extend(rating_class)
|
||||||
final_sql = f"SELECT {columns} FROM scores"
|
|
||||||
if where_clauses:
|
return self.__get_table(
|
||||||
final_sql += " WHERE "
|
"scores", DbScoreRow, " AND ".join(where_clauses), params
|
||||||
final_sql += " AND ".join(where_clauses)
|
)
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
return [
|
|
||||||
DbScoreRow(*row) for row in cursor.execute(final_sql, params).fetchall()
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_calculated(
|
def get_calculated(
|
||||||
self,
|
self,
|
||||||
@ -382,27 +324,18 @@ class Database(metaclass=Singleton):
|
|||||||
song_id: Optional[List[str]] = None,
|
song_id: Optional[List[str]] = None,
|
||||||
rating_class: Optional[List[int]] = None,
|
rating_class: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
columns = ",".join([f"[{field.name}]" for field in fields(DbCalculatedRow)])
|
|
||||||
where_clauses = []
|
where_clauses = []
|
||||||
params = []
|
params = []
|
||||||
if song_id:
|
if song_id:
|
||||||
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
|
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
|
||||||
params.extend(song_id)
|
params.extend(song_id)
|
||||||
if rating_class:
|
if rating_class:
|
||||||
where_clauses.append(
|
where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
|
||||||
f"rating_class IN ({','.join('?'*len(rating_class))})"
|
|
||||||
)
|
|
||||||
params.extend(rating_class)
|
params.extend(rating_class)
|
||||||
final_sql = f"SELECT {columns} FROM calculated"
|
|
||||||
if where_clauses:
|
return self.__get_table(
|
||||||
final_sql += " WHERE "
|
"calculated", DbCalculatedRow, " AND ".join(where_clauses), params
|
||||||
final_sql += " AND ".join(where_clauses)
|
)
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
return [
|
|
||||||
DbCalculatedRow(*row)
|
|
||||||
for row in cursor.execute(final_sql, params).fetchall()
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_b30(self) -> float:
|
def get_b30(self) -> float:
|
||||||
with self.conn as conn:
|
with self.conn as conn:
|
||||||
@ -417,12 +350,14 @@ class Database(metaclass=Singleton):
|
|||||||
def get_potential(self) -> float:
|
def get_potential(self) -> float:
|
||||||
with self.conn as conn:
|
with self.conn as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
return cursor.execute("SELECT potential FROM calculated_potential").fetchone()[0]
|
return cursor.execute(
|
||||||
|
"SELECT potential FROM calculated_potential"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
def insert_score(self, score: DbScoreRow):
|
def insert_score(self, score: DbScoreRow):
|
||||||
columns = self.__get_columns_from_dataclass(DbScoreRow)
|
columns = self.__get_columns_from_dataclass(DbScoreRow)
|
||||||
columns_clause = self.__get_columns_clause(columns)
|
columns_clause = self.__get_columns_clause(columns)
|
||||||
params = [score.__getattribute__(column) for column in columns]
|
params = [getattr(score, column) for column in columns]
|
||||||
with self.conn as conn:
|
with self.conn as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
|
0
src/arcaea_offline/utils/__init__.py
Normal file
0
src/arcaea_offline/utils/__init__.py
Normal file
6
src/arcaea_offline/utils/types.py
Normal file
6
src/arcaea_offline/utils/types.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from typing import Any, Protocol, ClassVar, Dict
|
||||||
|
|
||||||
|
class TDataclass(Protocol):
|
||||||
|
__dataclass_fields__: ClassVar[Dict]
|
||||||
|
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||||
|
...
|
Loading…
x
Reference in New Issue
Block a user