mirror of
https://github.com/283375/arcaea-offline.git
synced 2025-04-17 21:30:18 +00:00
wip: db
This commit is contained in:
parent
619d4029f8
commit
da109e7cb5
@ -36,6 +36,7 @@ def calculate_score(chart: Chart, score: Score) -> Calculated:
|
||||
assert isinstance(potential, Decimal)
|
||||
|
||||
return Calculated(
|
||||
id=score.id,
|
||||
song_id=chart.song_id,
|
||||
rating_class=chart.rating_class,
|
||||
score=score.score,
|
||||
|
@ -2,6 +2,7 @@ import atexit
|
||||
import os
|
||||
import sqlite3
|
||||
from dataclasses import fields, is_dataclass
|
||||
from functools import wraps
|
||||
from typing import Callable, List, NamedTuple, Optional, TypeVar, Union
|
||||
|
||||
from thefuzz import fuzz
|
||||
@ -19,7 +20,8 @@ from .models import (
|
||||
from .utils.singleton import Singleton
|
||||
from .utils.types import TDataclass
|
||||
|
||||
T = TypeVar("T", bound=TDataclass)
|
||||
TC = TypeVar("TC", bound=Callable)
|
||||
TD = TypeVar("TD", bound=TDataclass)
|
||||
|
||||
|
||||
class Database(metaclass=Singleton):
|
||||
@ -27,11 +29,8 @@ class Database(metaclass=Singleton):
|
||||
dbFilename = "arcaea_offline.db"
|
||||
|
||||
def __init__(self):
|
||||
self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename))
|
||||
self.__conn.execute("PRAGMA journal_mode = WAL;")
|
||||
self.__conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
atexit.register(self.__conn.close)
|
||||
self.__conn: sqlite3.Connection = None # type: ignore
|
||||
self.renew_conn()
|
||||
|
||||
self.__update_hooks = []
|
||||
|
||||
@ -39,6 +38,27 @@ class Database(metaclass=Singleton):
|
||||
def conn(self):
|
||||
return self.__conn
|
||||
|
||||
def renew_conn(self):
|
||||
if self.__conn:
|
||||
atexit.unregister(self.__conn.close)
|
||||
|
||||
self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename))
|
||||
self.__conn.execute("PRAGMA journal_mode = WAL;")
|
||||
self.__conn.execute("PRAGMA foreign_keys = ON;")
|
||||
atexit.register(self.__conn.close)
|
||||
|
||||
def _check_conn(self):
|
||||
if not isinstance(self.__conn, sqlite3.Connection):
|
||||
raise ValueError("Database not connected")
|
||||
|
||||
def check_conn(func: TC) -> TC: # type: ignore
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._check_conn()
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
def register_update_hook(self, hook: Callable) -> bool:
|
||||
if callable(hook):
|
||||
if hook not in self.__update_hooks:
|
||||
@ -56,6 +76,7 @@ class Database(metaclass=Singleton):
|
||||
for hook in self.__update_hooks:
|
||||
hook()
|
||||
|
||||
@check_conn
|
||||
def update_arcsong_db(self, path: Union[str, bytes]):
|
||||
with sqlite3.connect(path) as arcsong_conn:
|
||||
arcsong_cursor = arcsong_conn.cursor()
|
||||
@ -93,7 +114,9 @@ class Database(metaclass=Singleton):
|
||||
rows,
|
||||
)
|
||||
conn.commit()
|
||||
self.__trigger_update_hooks()
|
||||
|
||||
@check_conn
|
||||
def init(self):
|
||||
create_sqls = INIT_SQLS[1]["init"]
|
||||
|
||||
@ -112,9 +135,10 @@ class Database(metaclass=Singleton):
|
||||
def __get_columns_clause(self, columns: List[str]):
|
||||
return ", ".join([f'"{column}"' for column in columns])
|
||||
|
||||
@check_conn
|
||||
def __get_table(
|
||||
self, table_name: str, datacls: T, where_clause: str = "", params=None
|
||||
) -> List[T]:
|
||||
self, table_name: str, datacls: TD, where_clause: str = "", params=None
|
||||
) -> List[TD]:
|
||||
if params is None:
|
||||
params = []
|
||||
columns_clause = self.__get_columns_clause(
|
||||
@ -162,6 +186,7 @@ class Database(metaclass=Singleton):
|
||||
(song_id, rating_class),
|
||||
)[0]
|
||||
|
||||
@check_conn
|
||||
def validate_song_id(self, song_id):
|
||||
with self.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
@ -169,6 +194,7 @@ class Database(metaclass=Singleton):
|
||||
result = cursor.fetchone()
|
||||
return result[0] > 0
|
||||
|
||||
@check_conn
|
||||
def validate_chart(self, song_id: str, rating_class: int):
|
||||
with self.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
@ -183,6 +209,7 @@ class Database(metaclass=Singleton):
|
||||
song_id: str
|
||||
confidence: int
|
||||
|
||||
@check_conn
|
||||
def fuzzy_search_song_id(
|
||||
self, input_str: str, limit: int = 5
|
||||
) -> List[FuzzySearchSongIdResult]:
|
||||
@ -247,6 +274,7 @@ class Database(metaclass=Singleton):
|
||||
"calculated", DbCalculatedRow, " AND ".join(where_clauses), params
|
||||
)
|
||||
|
||||
@check_conn
|
||||
def get_b30(self) -> float:
|
||||
with self.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
@ -265,6 +293,7 @@ class Database(metaclass=Singleton):
|
||||
conn.commit()
|
||||
self.__trigger_update_hooks()
|
||||
|
||||
@check_conn
|
||||
def update_score(self, score_id: int, new_score: ScoreInsert):
|
||||
# ensure we are only updating 1 row
|
||||
scores = self.get_scores(score_id=score_id)
|
||||
@ -282,6 +311,7 @@ class Database(metaclass=Singleton):
|
||||
conn.commit()
|
||||
self.__trigger_update_hooks()
|
||||
|
||||
@check_conn
|
||||
def delete_score(self, score_id: int):
|
||||
with self.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
|
Loading…
x
Reference in New Issue
Block a user