This commit is contained in:
283375 2023-08-06 15:40:43 +08:00
parent 619d4029f8
commit da109e7cb5
2 changed files with 39 additions and 8 deletions

View File

@ -36,6 +36,7 @@ def calculate_score(chart: Chart, score: Score) -> Calculated:
assert isinstance(potential, Decimal) assert isinstance(potential, Decimal)
return Calculated( return Calculated(
id=score.id,
song_id=chart.song_id, song_id=chart.song_id,
rating_class=chart.rating_class, rating_class=chart.rating_class,
score=score.score, score=score.score,

View File

@ -2,6 +2,7 @@ import atexit
import os import os
import sqlite3 import sqlite3
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from functools import wraps
from typing import Callable, List, NamedTuple, Optional, TypeVar, Union from typing import Callable, List, NamedTuple, Optional, TypeVar, Union
from thefuzz import fuzz from thefuzz import fuzz
@ -19,7 +20,8 @@ from .models import (
from .utils.singleton import Singleton from .utils.singleton import Singleton
from .utils.types import TDataclass from .utils.types import TDataclass
T = TypeVar("T", bound=TDataclass) TC = TypeVar("TC", bound=Callable)
TD = TypeVar("TD", bound=TDataclass)
class Database(metaclass=Singleton): class Database(metaclass=Singleton):
@ -27,11 +29,8 @@ class Database(metaclass=Singleton):
dbFilename = "arcaea_offline.db" dbFilename = "arcaea_offline.db"
def __init__(self): def __init__(self):
self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename)) self.__conn: sqlite3.Connection = None # type: ignore
self.__conn.execute("PRAGMA journal_mode = WAL;") self.renew_conn()
self.__conn.execute("PRAGMA foreign_keys = ON;")
atexit.register(self.__conn.close)
self.__update_hooks = [] self.__update_hooks = []
@ -39,6 +38,27 @@ class Database(metaclass=Singleton):
def conn(self): def conn(self):
return self.__conn return self.__conn
def renew_conn(self):
if self.__conn:
atexit.unregister(self.__conn.close)
self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename))
self.__conn.execute("PRAGMA journal_mode = WAL;")
self.__conn.execute("PRAGMA foreign_keys = ON;")
atexit.register(self.__conn.close)
def _check_conn(self):
if not isinstance(self.__conn, sqlite3.Connection):
raise ValueError("Database not connected")
def check_conn(func: TC) -> TC: # type: ignore
@wraps(func)
def wrapper(self, *args, **kwargs):
self._check_conn()
return func(self, *args, **kwargs)
return wrapper # type: ignore
def register_update_hook(self, hook: Callable) -> bool: def register_update_hook(self, hook: Callable) -> bool:
if callable(hook): if callable(hook):
if hook not in self.__update_hooks: if hook not in self.__update_hooks:
@ -56,6 +76,7 @@ class Database(metaclass=Singleton):
for hook in self.__update_hooks: for hook in self.__update_hooks:
hook() hook()
@check_conn
def update_arcsong_db(self, path: Union[str, bytes]): def update_arcsong_db(self, path: Union[str, bytes]):
with sqlite3.connect(path) as arcsong_conn: with sqlite3.connect(path) as arcsong_conn:
arcsong_cursor = arcsong_conn.cursor() arcsong_cursor = arcsong_conn.cursor()
@ -93,7 +114,9 @@ class Database(metaclass=Singleton):
rows, rows,
) )
conn.commit() conn.commit()
self.__trigger_update_hooks()
@check_conn
def init(self): def init(self):
create_sqls = INIT_SQLS[1]["init"] create_sqls = INIT_SQLS[1]["init"]
@ -112,9 +135,10 @@ class Database(metaclass=Singleton):
def __get_columns_clause(self, columns: List[str]): def __get_columns_clause(self, columns: List[str]):
return ", ".join([f'"{column}"' for column in columns]) return ", ".join([f'"{column}"' for column in columns])
@check_conn
def __get_table( def __get_table(
self, table_name: str, datacls: T, where_clause: str = "", params=None self, table_name: str, datacls: TD, where_clause: str = "", params=None
) -> List[T]: ) -> List[TD]:
if params is None: if params is None:
params = [] params = []
columns_clause = self.__get_columns_clause( columns_clause = self.__get_columns_clause(
@ -162,6 +186,7 @@ class Database(metaclass=Singleton):
(song_id, rating_class), (song_id, rating_class),
)[0] )[0]
@check_conn
def validate_song_id(self, song_id): def validate_song_id(self, song_id):
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
@ -169,6 +194,7 @@ class Database(metaclass=Singleton):
result = cursor.fetchone() result = cursor.fetchone()
return result[0] > 0 return result[0] > 0
@check_conn
def validate_chart(self, song_id: str, rating_class: int): def validate_chart(self, song_id: str, rating_class: int):
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
@ -183,6 +209,7 @@ class Database(metaclass=Singleton):
song_id: str song_id: str
confidence: int confidence: int
@check_conn
def fuzzy_search_song_id( def fuzzy_search_song_id(
self, input_str: str, limit: int = 5 self, input_str: str, limit: int = 5
) -> List[FuzzySearchSongIdResult]: ) -> List[FuzzySearchSongIdResult]:
@ -247,6 +274,7 @@ class Database(metaclass=Singleton):
"calculated", DbCalculatedRow, " AND ".join(where_clauses), params "calculated", DbCalculatedRow, " AND ".join(where_clauses), params
) )
@check_conn
def get_b30(self) -> float: def get_b30(self) -> float:
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()
@ -265,6 +293,7 @@ class Database(metaclass=Singleton):
conn.commit() conn.commit()
self.__trigger_update_hooks() self.__trigger_update_hooks()
@check_conn
def update_score(self, score_id: int, new_score: ScoreInsert): def update_score(self, score_id: int, new_score: ScoreInsert):
# ensure we are only updating 1 row # ensure we are only updating 1 row
scores = self.get_scores(score_id=score_id) scores = self.get_scores(score_id=score_id)
@ -282,6 +311,7 @@ class Database(metaclass=Singleton):
conn.commit() conn.commit()
self.__trigger_update_hooks() self.__trigger_update_hooks()
@check_conn
def delete_score(self, score_id: int): def delete_score(self, score_id: int):
with self.conn as conn: with self.conn as conn:
cursor = conn.cursor() cursor = conn.cursor()