refactor: database.py

This commit is contained in:
283375 2023-08-26 01:35:00 +08:00
parent b23bd2652a
commit 4a3523d380
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk

View File

@ -1,345 +1,22 @@
import atexit
import os
import sqlite3
from dataclasses import fields, is_dataclass
from functools import wraps
from typing import Callable, List, NamedTuple, Optional, TypeVar, Union
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from thefuzz import fuzz
from thefuzz import process as fuzz_process
from .external import ExternalScoreItem
from .init_sqls import INIT_SQLS
from .models import (
DbAliasRow,
DbCalculatedRow,
DbChartRow,
DbPackageRow,
DbScoreRow,
ScoreInsert,
)
from .utils.singleton import Singleton
from .utils.types import TDataclass
TC = TypeVar("TC", bound=Callable)
TD = TypeVar("TD", bound=TDataclass)
from .models.common import *
from .models.scores import *
from .models.songs import *
class Database(metaclass=Singleton):
dbDir = os.getcwd()
dbFilename = "arcaea_offline.db"
def init(engine: Engine, checkfirst: bool = True):
# sqlalchemy-utils issue #396
# view.create_view() causes DuplicateTableError on Base.metadata.create_all(checkfirst=True)
# https://github.com/kvesteri/sqlalchemy-utils/issues/396
if checkfirst:
ScoresViewBase.metadata.drop_all(engine)
def __init__(self):
self.__conn: sqlite3.Connection = None # type: ignore
self.renew_conn()
self.__update_hooks = []
@property
def conn(self):
return self.__conn
def renew_conn(self):
if self.__conn:
atexit.unregister(self.__conn.close)
self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename))
self.__conn.execute("PRAGMA journal_mode = WAL;")
self.__conn.execute("PRAGMA foreign_keys = ON;")
atexit.register(self.__conn.close)
def _check_conn(self):
if not isinstance(self.__conn, sqlite3.Connection):
raise ValueError("Database not connected")
def check_conn(func: TC) -> TC: # type: ignore
@wraps(func)
def wrapper(self, *args, **kwargs):
self._check_conn()
return func(self, *args, **kwargs)
return wrapper # type: ignore
def register_update_hook(self, hook: Callable) -> bool:
if callable(hook):
if hook not in self.__update_hooks:
self.__update_hooks.append(hook)
return True
return False
def unregister_update_hook(self, hook: Callable) -> bool:
if hook in self.__update_hooks:
self.__update_hooks.remove(hook)
return True
return False
def __trigger_update_hooks(self):
for hook in self.__update_hooks:
hook()
@check_conn
def update_arcsong_db(self, path: Union[str, bytes]):
with sqlite3.connect(path) as arcsong_conn:
arcsong_cursor = arcsong_conn.cursor()
data = {
"charts": arcsong_cursor.execute(
"SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, [set], time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override FROM charts"
).fetchall(),
"aliases": arcsong_cursor.execute(
"SELECT sid, alias FROM alias"
).fetchall(),
"packages": arcsong_cursor.execute(
"SELECT id, name FROM packages"
).fetchall(),
}
with self.conn as conn:
cursor = conn.cursor()
for table, rows in data.items():
columns = [
row[0]
for row in cursor.execute(
f"SELECT * FROM {table} LIMIT 1"
).description
]
column_count = len(columns)
assert column_count == len(
rows[0]
), f"Incompatible column count for table '{table}'"
placeholders = ", ".join(["?" for _ in range(column_count)])
update_clauses = ", ".join(
[f"{column} = excluded.{column}" for column in columns]
)
cursor.executemany(
f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders}) ON CONFLICT DO UPDATE SET {update_clauses}",
rows,
)
conn.commit()
self.__trigger_update_hooks()
@check_conn
def import_external(self, external_scores: List[ExternalScoreItem]):
for external in external_scores:
# 2017 Jan. 22, 00:00, UTC+8
time = external.time if external.time > -1 else 1485014400
pure = external.pure if external.pure > -1 else None
far = external.far if external.far > -1 else None
lost = external.lost if external.lost > -1 else None
max_recall = external.max_recall if external.max_recall > -1 else None
clear_type = external.clear_type if external.clear_type > -1 else None
score = ScoreInsert(
song_id=external.song_id,
rating_class=external.rating_class,
score=external.score,
time=time,
pure=pure,
far=far,
lost=lost,
max_recall=max_recall,
clear_type=clear_type,
)
self.insert_score(score)
@check_conn
def init(self):
create_sqls = INIT_SQLS[1]["init"]
with self.conn as conn:
cursor = conn.cursor()
for sql in create_sqls:
cursor.execute(sql)
conn.commit()
def __get_columns_from_dataclass(self, dataclass) -> List[str]:
if is_dataclass(dataclass):
dc_fields = fields(dataclass)
return [field.name for field in dc_fields]
return []
def __get_columns_clause(self, columns: List[str]):
return ", ".join([f'"{column}"' for column in columns])
@check_conn
def __get_table(
self, table_name: str, datacls: TD, where_clause: str = "", params=None
) -> List[TD]:
if params is None:
params = []
columns_clause = self.__get_columns_clause(
self.__get_columns_from_dataclass(datacls)
)
sql = f"SELECT {columns_clause} FROM {table_name}"
if where_clause:
sql += " WHERE "
sql += where_clause
with self.conn as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
return [datacls(*row) for row in cursor.fetchall()]
def get_packages(self):
return self.__get_table("packages", DbPackageRow)
def get_package_by_package_id(self, package_id: str):
result = self.__get_table(
"packages", DbPackageRow, "package_id = ?", (package_id,)
)
return result[0] if result else None
def get_aliases(self):
return self.__get_table("aliases", DbAliasRow)
def get_aliases_by_song_id(self, song_id: str):
return self.__get_table("aliases", DbAliasRow, "song_id = ?", (song_id,))
def get_charts(self):
return self.__get_table("charts", DbChartRow)
def get_charts_by_song_id(self, song_id: str):
return self.__get_table("charts", DbChartRow, "song_id = ?", (song_id,))
def get_charts_by_package_id(self, package_id: str):
return self.__get_table("charts", DbChartRow, "package_id = ?", (package_id,))
def get_chart(self, song_id: str, rating_class: int):
return self.__get_table(
"charts",
DbChartRow,
"song_id = ? AND rating_class = ?",
(song_id, rating_class),
)[0]
@check_conn
def validate_song_id(self, song_id):
with self.conn as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM charts WHERE song_id = ?", (song_id,))
result = cursor.fetchone()
return result[0] > 0
@check_conn
def validate_chart(self, song_id: str, rating_class: int):
with self.conn as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) FROM charts WHERE song_id = ? AND rating_class = ?",
(song_id, rating_class),
)
result = cursor.fetchone()
return result[0] > 0
class FuzzySearchSongIdResult(NamedTuple):
song_id: str
confidence: int
@check_conn
def fuzzy_search_song_id(
self, input_str: str, limit: int = 5
) -> List[FuzzySearchSongIdResult]:
with self.conn as conn:
cursor = conn.cursor()
db_results = cursor.execute(
"SELECT song_id, name FROM song_id_names"
).fetchall()
name_song_id_map = {r[1]: r[0] for r in db_results}
names = name_song_id_map.keys()
fuzzy_results = fuzz_process.extractBests(input_str, names, scorer=fuzz.partial_ratio, limit=limit) # type: ignore
results = {}
for fuzzy_result in fuzzy_results:
name = fuzzy_result[0]
confidence = fuzzy_result[1]
song_id = name_song_id_map[name]
results[song_id] = max(confidence, results.get(song_id, 0))
return [
self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()
]
def get_scores(
self,
*,
score_id: Optional[int] = None,
song_id: Optional[List[str]] = None,
rating_class: Optional[List[int]] = None,
):
where_clauses = []
params = []
if score_id:
where_clauses.append("id = ?")
params.append(score_id)
if song_id:
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
params.extend(song_id)
if rating_class:
where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
params.extend(rating_class)
return self.__get_table(
"scores", DbScoreRow, " AND ".join(where_clauses), params
)
def get_calculated(
self,
*,
song_id: Optional[List[str]] = None,
rating_class: Optional[List[int]] = None,
):
where_clauses = []
params = []
if song_id:
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
params.extend(song_id)
if rating_class:
where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
params.extend(rating_class)
return self.__get_table(
"calculated", DbCalculatedRow, " AND ".join(where_clauses), params
)
@check_conn
def get_b30(self) -> float:
with self.conn as conn:
cursor = conn.cursor()
return cursor.execute("SELECT b30 FROM calculated_potential").fetchone()[0]
def insert_score(self, score: ScoreInsert):
columns = self.__get_columns_from_dataclass(ScoreInsert)
columns_clause = self.__get_columns_clause(columns)
params = [getattr(score, column) for column in columns]
with self.conn as conn:
cursor = conn.cursor()
cursor.execute(
f"INSERT INTO scores({columns_clause}) VALUES ({', '.join('?' * len(params))})",
params,
)
conn.commit()
self.__trigger_update_hooks()
@check_conn
def update_score(self, score_id: int, new_score: ScoreInsert):
# ensure we are only updating 1 row
scores = self.get_scores(score_id=score_id)
print(score_id)
assert len(scores) == 1, "Cannot update multiple or non-existing score(s)"
columns = self.__get_columns_from_dataclass(ScoreInsert)
params = [getattr(new_score, column) for column in columns] + [score_id]
update_columns_param_clause = ", ".join([f"{column} = ?" for column in columns])
with self.conn as conn:
cursor = conn.cursor()
cursor.execute(
f"UPDATE scores SET {update_columns_param_clause} WHERE id = ?",
params,
)
conn.commit()
self.__trigger_update_hooks()
@check_conn
def delete_score(self, score_id: int):
with self.conn as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM scores WHERE id = ?", (score_id,))
conn.commit()
self.__trigger_update_hooks()
SongsBase.metadata.create_all(engine, checkfirst=checkfirst)
ScoresBase.metadata.create_all(engine, checkfirst=checkfirst)
ScoresViewBase.metadata.create_all(engine)
CommonBase.metadata.create_all(engine, checkfirst=checkfirst)
with Session(engine) as session:
session.add(Property(id="version", value="2"))
session.commit()