mirror of
https://github.com/283375/arcaea-offline.git
synced 2025-04-21 15:00:18 +00:00
refactor: database.py
This commit is contained in:
parent
b23bd2652a
commit
4a3523d380
@ -1,345 +1,22 @@
|
|||||||
import atexit
|
from sqlalchemy import Engine
|
||||||
import os
|
from sqlalchemy.orm import Session
|
||||||
import sqlite3
|
|
||||||
from dataclasses import fields, is_dataclass
|
|
||||||
from functools import wraps
|
|
||||||
from typing import Callable, List, NamedTuple, Optional, TypeVar, Union
|
|
||||||
|
|
||||||
from thefuzz import fuzz
|
from .models.common import *
|
||||||
from thefuzz import process as fuzz_process
|
from .models.scores import *
|
||||||
|
from .models.songs import *
|
||||||
from .external import ExternalScoreItem
|
|
||||||
from .init_sqls import INIT_SQLS
|
|
||||||
from .models import (
|
|
||||||
DbAliasRow,
|
|
||||||
DbCalculatedRow,
|
|
||||||
DbChartRow,
|
|
||||||
DbPackageRow,
|
|
||||||
DbScoreRow,
|
|
||||||
ScoreInsert,
|
|
||||||
)
|
|
||||||
from .utils.singleton import Singleton
|
|
||||||
from .utils.types import TDataclass
|
|
||||||
|
|
||||||
TC = TypeVar("TC", bound=Callable)
|
|
||||||
TD = TypeVar("TD", bound=TDataclass)
|
|
||||||
|
|
||||||
|
|
||||||
class Database(metaclass=Singleton):
|
def init(engine: Engine, checkfirst: bool = True):
|
||||||
dbDir = os.getcwd()
|
# sqlalchemy-utils issue #396
|
||||||
dbFilename = "arcaea_offline.db"
|
# view.create_view() causes DuplicateTableError on Base.metadata.create_all(checkfirst=True)
|
||||||
|
# https://github.com/kvesteri/sqlalchemy-utils/issues/396
|
||||||
|
if checkfirst:
|
||||||
|
ScoresViewBase.metadata.drop_all(engine)
|
||||||
|
|
||||||
def __init__(self):
|
SongsBase.metadata.create_all(engine, checkfirst=checkfirst)
|
||||||
self.__conn: sqlite3.Connection = None # type: ignore
|
ScoresBase.metadata.create_all(engine, checkfirst=checkfirst)
|
||||||
self.renew_conn()
|
ScoresViewBase.metadata.create_all(engine)
|
||||||
|
CommonBase.metadata.create_all(engine, checkfirst=checkfirst)
|
||||||
self.__update_hooks = []
|
with Session(engine) as session:
|
||||||
|
session.add(Property(id="version", value="2"))
|
||||||
@property
|
session.commit()
|
||||||
def conn(self):
|
|
||||||
return self.__conn
|
|
||||||
|
|
||||||
def renew_conn(self):
|
|
||||||
if self.__conn:
|
|
||||||
atexit.unregister(self.__conn.close)
|
|
||||||
|
|
||||||
self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename))
|
|
||||||
self.__conn.execute("PRAGMA journal_mode = WAL;")
|
|
||||||
self.__conn.execute("PRAGMA foreign_keys = ON;")
|
|
||||||
atexit.register(self.__conn.close)
|
|
||||||
|
|
||||||
def _check_conn(self):
|
|
||||||
if not isinstance(self.__conn, sqlite3.Connection):
|
|
||||||
raise ValueError("Database not connected")
|
|
||||||
|
|
||||||
def check_conn(func: TC) -> TC: # type: ignore
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(self, *args, **kwargs):
|
|
||||||
self._check_conn()
|
|
||||||
return func(self, *args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper # type: ignore
|
|
||||||
|
|
||||||
def register_update_hook(self, hook: Callable) -> bool:
|
|
||||||
if callable(hook):
|
|
||||||
if hook not in self.__update_hooks:
|
|
||||||
self.__update_hooks.append(hook)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def unregister_update_hook(self, hook: Callable) -> bool:
|
|
||||||
if hook in self.__update_hooks:
|
|
||||||
self.__update_hooks.remove(hook)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __trigger_update_hooks(self):
|
|
||||||
for hook in self.__update_hooks:
|
|
||||||
hook()
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def update_arcsong_db(self, path: Union[str, bytes]):
|
|
||||||
with sqlite3.connect(path) as arcsong_conn:
|
|
||||||
arcsong_cursor = arcsong_conn.cursor()
|
|
||||||
data = {
|
|
||||||
"charts": arcsong_cursor.execute(
|
|
||||||
"SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, [set], time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override FROM charts"
|
|
||||||
).fetchall(),
|
|
||||||
"aliases": arcsong_cursor.execute(
|
|
||||||
"SELECT sid, alias FROM alias"
|
|
||||||
).fetchall(),
|
|
||||||
"packages": arcsong_cursor.execute(
|
|
||||||
"SELECT id, name FROM packages"
|
|
||||||
).fetchall(),
|
|
||||||
}
|
|
||||||
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
for table, rows in data.items():
|
|
||||||
columns = [
|
|
||||||
row[0]
|
|
||||||
for row in cursor.execute(
|
|
||||||
f"SELECT * FROM {table} LIMIT 1"
|
|
||||||
).description
|
|
||||||
]
|
|
||||||
column_count = len(columns)
|
|
||||||
assert column_count == len(
|
|
||||||
rows[0]
|
|
||||||
), f"Incompatible column count for table '{table}'"
|
|
||||||
placeholders = ", ".join(["?" for _ in range(column_count)])
|
|
||||||
update_clauses = ", ".join(
|
|
||||||
[f"{column} = excluded.{column}" for column in columns]
|
|
||||||
)
|
|
||||||
cursor.executemany(
|
|
||||||
f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders}) ON CONFLICT DO UPDATE SET {update_clauses}",
|
|
||||||
rows,
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
self.__trigger_update_hooks()
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def import_external(self, external_scores: List[ExternalScoreItem]):
|
|
||||||
for external in external_scores:
|
|
||||||
# 2017 Jan. 22, 00:00, UTC+8
|
|
||||||
time = external.time if external.time > -1 else 1485014400
|
|
||||||
pure = external.pure if external.pure > -1 else None
|
|
||||||
far = external.far if external.far > -1 else None
|
|
||||||
lost = external.lost if external.lost > -1 else None
|
|
||||||
max_recall = external.max_recall if external.max_recall > -1 else None
|
|
||||||
clear_type = external.clear_type if external.clear_type > -1 else None
|
|
||||||
|
|
||||||
score = ScoreInsert(
|
|
||||||
song_id=external.song_id,
|
|
||||||
rating_class=external.rating_class,
|
|
||||||
score=external.score,
|
|
||||||
time=time,
|
|
||||||
pure=pure,
|
|
||||||
far=far,
|
|
||||||
lost=lost,
|
|
||||||
max_recall=max_recall,
|
|
||||||
clear_type=clear_type,
|
|
||||||
)
|
|
||||||
self.insert_score(score)
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def init(self):
|
|
||||||
create_sqls = INIT_SQLS[1]["init"]
|
|
||||||
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
for sql in create_sqls:
|
|
||||||
cursor.execute(sql)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def __get_columns_from_dataclass(self, dataclass) -> List[str]:
|
|
||||||
if is_dataclass(dataclass):
|
|
||||||
dc_fields = fields(dataclass)
|
|
||||||
return [field.name for field in dc_fields]
|
|
||||||
return []
|
|
||||||
|
|
||||||
def __get_columns_clause(self, columns: List[str]):
|
|
||||||
return ", ".join([f'"{column}"' for column in columns])
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def __get_table(
|
|
||||||
self, table_name: str, datacls: TD, where_clause: str = "", params=None
|
|
||||||
) -> List[TD]:
|
|
||||||
if params is None:
|
|
||||||
params = []
|
|
||||||
columns_clause = self.__get_columns_clause(
|
|
||||||
self.__get_columns_from_dataclass(datacls)
|
|
||||||
)
|
|
||||||
|
|
||||||
sql = f"SELECT {columns_clause} FROM {table_name}"
|
|
||||||
if where_clause:
|
|
||||||
sql += " WHERE "
|
|
||||||
sql += where_clause
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(sql, params)
|
|
||||||
return [datacls(*row) for row in cursor.fetchall()]
|
|
||||||
|
|
||||||
def get_packages(self):
|
|
||||||
return self.__get_table("packages", DbPackageRow)
|
|
||||||
|
|
||||||
def get_package_by_package_id(self, package_id: str):
|
|
||||||
result = self.__get_table(
|
|
||||||
"packages", DbPackageRow, "package_id = ?", (package_id,)
|
|
||||||
)
|
|
||||||
return result[0] if result else None
|
|
||||||
|
|
||||||
def get_aliases(self):
|
|
||||||
return self.__get_table("aliases", DbAliasRow)
|
|
||||||
|
|
||||||
def get_aliases_by_song_id(self, song_id: str):
|
|
||||||
return self.__get_table("aliases", DbAliasRow, "song_id = ?", (song_id,))
|
|
||||||
|
|
||||||
def get_charts(self):
|
|
||||||
return self.__get_table("charts", DbChartRow)
|
|
||||||
|
|
||||||
def get_charts_by_song_id(self, song_id: str):
|
|
||||||
return self.__get_table("charts", DbChartRow, "song_id = ?", (song_id,))
|
|
||||||
|
|
||||||
def get_charts_by_package_id(self, package_id: str):
|
|
||||||
return self.__get_table("charts", DbChartRow, "package_id = ?", (package_id,))
|
|
||||||
|
|
||||||
def get_chart(self, song_id: str, rating_class: int):
|
|
||||||
return self.__get_table(
|
|
||||||
"charts",
|
|
||||||
DbChartRow,
|
|
||||||
"song_id = ? AND rating_class = ?",
|
|
||||||
(song_id, rating_class),
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def validate_song_id(self, song_id):
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("SELECT COUNT(*) FROM charts WHERE song_id = ?", (song_id,))
|
|
||||||
result = cursor.fetchone()
|
|
||||||
return result[0] > 0
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def validate_chart(self, song_id: str, rating_class: int):
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT COUNT(*) FROM charts WHERE song_id = ? AND rating_class = ?",
|
|
||||||
(song_id, rating_class),
|
|
||||||
)
|
|
||||||
result = cursor.fetchone()
|
|
||||||
return result[0] > 0
|
|
||||||
|
|
||||||
class FuzzySearchSongIdResult(NamedTuple):
|
|
||||||
song_id: str
|
|
||||||
confidence: int
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def fuzzy_search_song_id(
|
|
||||||
self, input_str: str, limit: int = 5
|
|
||||||
) -> List[FuzzySearchSongIdResult]:
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
db_results = cursor.execute(
|
|
||||||
"SELECT song_id, name FROM song_id_names"
|
|
||||||
).fetchall()
|
|
||||||
name_song_id_map = {r[1]: r[0] for r in db_results}
|
|
||||||
names = name_song_id_map.keys()
|
|
||||||
fuzzy_results = fuzz_process.extractBests(input_str, names, scorer=fuzz.partial_ratio, limit=limit) # type: ignore
|
|
||||||
results = {}
|
|
||||||
for fuzzy_result in fuzzy_results:
|
|
||||||
name = fuzzy_result[0]
|
|
||||||
confidence = fuzzy_result[1]
|
|
||||||
song_id = name_song_id_map[name]
|
|
||||||
results[song_id] = max(confidence, results.get(song_id, 0))
|
|
||||||
|
|
||||||
return [
|
|
||||||
self.FuzzySearchSongIdResult(si, confi) for si, confi in results.items()
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_scores(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
score_id: Optional[int] = None,
|
|
||||||
song_id: Optional[List[str]] = None,
|
|
||||||
rating_class: Optional[List[int]] = None,
|
|
||||||
):
|
|
||||||
where_clauses = []
|
|
||||||
params = []
|
|
||||||
if score_id:
|
|
||||||
where_clauses.append("id = ?")
|
|
||||||
params.append(score_id)
|
|
||||||
if song_id:
|
|
||||||
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
|
|
||||||
params.extend(song_id)
|
|
||||||
if rating_class:
|
|
||||||
where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
|
|
||||||
params.extend(rating_class)
|
|
||||||
|
|
||||||
return self.__get_table(
|
|
||||||
"scores", DbScoreRow, " AND ".join(where_clauses), params
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_calculated(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
song_id: Optional[List[str]] = None,
|
|
||||||
rating_class: Optional[List[int]] = None,
|
|
||||||
):
|
|
||||||
where_clauses = []
|
|
||||||
params = []
|
|
||||||
if song_id:
|
|
||||||
where_clauses.append(f"song_id IN ({','.join('?'*len(song_id))})")
|
|
||||||
params.extend(song_id)
|
|
||||||
if rating_class:
|
|
||||||
where_clauses.append(f"rating_class IN ({','.join('?'*len(rating_class))})")
|
|
||||||
params.extend(rating_class)
|
|
||||||
|
|
||||||
return self.__get_table(
|
|
||||||
"calculated", DbCalculatedRow, " AND ".join(where_clauses), params
|
|
||||||
)
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def get_b30(self) -> float:
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
return cursor.execute("SELECT b30 FROM calculated_potential").fetchone()[0]
|
|
||||||
|
|
||||||
def insert_score(self, score: ScoreInsert):
|
|
||||||
columns = self.__get_columns_from_dataclass(ScoreInsert)
|
|
||||||
columns_clause = self.__get_columns_clause(columns)
|
|
||||||
params = [getattr(score, column) for column in columns]
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
f"INSERT INTO scores({columns_clause}) VALUES ({', '.join('?' * len(params))})",
|
|
||||||
params,
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
self.__trigger_update_hooks()
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def update_score(self, score_id: int, new_score: ScoreInsert):
|
|
||||||
# ensure we are only updating 1 row
|
|
||||||
scores = self.get_scores(score_id=score_id)
|
|
||||||
print(score_id)
|
|
||||||
assert len(scores) == 1, "Cannot update multiple or non-existing score(s)"
|
|
||||||
columns = self.__get_columns_from_dataclass(ScoreInsert)
|
|
||||||
params = [getattr(new_score, column) for column in columns] + [score_id]
|
|
||||||
update_columns_param_clause = ", ".join([f"{column} = ?" for column in columns])
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
f"UPDATE scores SET {update_columns_param_clause} WHERE id = ?",
|
|
||||||
params,
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
self.__trigger_update_hooks()
|
|
||||||
|
|
||||||
@check_conn
|
|
||||||
def delete_score(self, score_id: int):
|
|
||||||
with self.conn as conn:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("DELETE FROM scores WHERE id = ?", (score_id,))
|
|
||||||
conn.commit()
|
|
||||||
self.__trigger_update_hooks()
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user