This commit is contained in:
283375 2023-08-06 15:40:43 +08:00
parent 619d4029f8
commit da109e7cb5
2 changed files with 39 additions and 8 deletions

View File

@ -36,6 +36,7 @@ def calculate_score(chart: Chart, score: Score) -> Calculated:
assert isinstance(potential, Decimal)
return Calculated(
id=score.id,
song_id=chart.song_id,
rating_class=chart.rating_class,
score=score.score,

View File

@ -2,6 +2,7 @@ import atexit
import os
import sqlite3
from dataclasses import fields, is_dataclass
from functools import wraps
from typing import Callable, List, NamedTuple, Optional, TypeVar, Union
from thefuzz import fuzz
@ -19,7 +20,8 @@ from .models import (
from .utils.singleton import Singleton
from .utils.types import TDataclass
T = TypeVar("T", bound=TDataclass)
TC = TypeVar("TC", bound=Callable)
TD = TypeVar("TD", bound=TDataclass)
class Database(metaclass=Singleton):
@ -27,11 +29,8 @@ class Database(metaclass=Singleton):
dbFilename = "arcaea_offline.db"
def __init__(self):
self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename))
self.__conn.execute("PRAGMA journal_mode = WAL;")
self.__conn.execute("PRAGMA foreign_keys = ON;")
atexit.register(self.__conn.close)
self.__conn: sqlite3.Connection = None # type: ignore
self.renew_conn()
self.__update_hooks = []
@ -39,6 +38,27 @@ class Database(metaclass=Singleton):
def conn(self):
return self.__conn
def renew_conn(self):
if self.__conn:
atexit.unregister(self.__conn.close)
self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename))
self.__conn.execute("PRAGMA journal_mode = WAL;")
self.__conn.execute("PRAGMA foreign_keys = ON;")
atexit.register(self.__conn.close)
def _check_conn(self):
if not isinstance(self.__conn, sqlite3.Connection):
raise ValueError("Database not connected")
def check_conn(func: TC) -> TC: # type: ignore
@wraps(func)
def wrapper(self, *args, **kwargs):
self._check_conn()
return func(self, *args, **kwargs)
return wrapper # type: ignore
def register_update_hook(self, hook: Callable) -> bool:
if callable(hook):
if hook not in self.__update_hooks:
@ -56,6 +76,7 @@ class Database(metaclass=Singleton):
for hook in self.__update_hooks:
hook()
@check_conn
def update_arcsong_db(self, path: Union[str, bytes]):
with sqlite3.connect(path) as arcsong_conn:
arcsong_cursor = arcsong_conn.cursor()
@ -93,7 +114,9 @@ class Database(metaclass=Singleton):
rows,
)
conn.commit()
self.__trigger_update_hooks()
@check_conn
def init(self):
create_sqls = INIT_SQLS[1]["init"]
@ -112,9 +135,10 @@ class Database(metaclass=Singleton):
def __get_columns_clause(self, columns: List[str]):
return ", ".join([f'"{column}"' for column in columns])
@check_conn
def __get_table(
self, table_name: str, datacls: T, where_clause: str = "", params=None
) -> List[T]:
self, table_name: str, datacls: TD, where_clause: str = "", params=None
) -> List[TD]:
if params is None:
params = []
columns_clause = self.__get_columns_clause(
@ -162,6 +186,7 @@ class Database(metaclass=Singleton):
(song_id, rating_class),
)[0]
@check_conn
def validate_song_id(self, song_id):
with self.conn as conn:
cursor = conn.cursor()
@ -169,6 +194,7 @@ class Database(metaclass=Singleton):
result = cursor.fetchone()
return result[0] > 0
@check_conn
def validate_chart(self, song_id: str, rating_class: int):
with self.conn as conn:
cursor = conn.cursor()
@ -183,6 +209,7 @@ class Database(metaclass=Singleton):
song_id: str
confidence: int
@check_conn
def fuzzy_search_song_id(
self, input_str: str, limit: int = 5
) -> List[FuzzySearchSongIdResult]:
@ -247,6 +274,7 @@ class Database(metaclass=Singleton):
"calculated", DbCalculatedRow, " AND ".join(where_clauses), params
)
@check_conn
def get_b30(self) -> float:
with self.conn as conn:
cursor = conn.cursor()
@ -265,6 +293,7 @@ class Database(metaclass=Singleton):
conn.commit()
self.__trigger_update_hooks()
@check_conn
def update_score(self, score_id: int, new_score: ScoreInsert):
# ensure we are only updating 1 row
scores = self.get_scores(score_id=score_id)
@ -282,6 +311,7 @@ class Database(metaclass=Singleton):
conn.commit()
self.__trigger_update_hooks()
@check_conn
def delete_score(self, score_id: int):
with self.conn as conn:
cursor = conn.cursor()