From da109e7cb5a0e15e6d7fc71b8911faf7f8fc098f Mon Sep 17 00:00:00 2001 From: 283375 Date: Sun, 6 Aug 2023 15:40:43 +0800 Subject: [PATCH] wip: db --- src/arcaea_offline/calculate.py | 1 + src/arcaea_offline/database.py | 46 +++++++++++++++++++++++++++------ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/arcaea_offline/calculate.py b/src/arcaea_offline/calculate.py index eaed7ab..d9f72ac 100644 --- a/src/arcaea_offline/calculate.py +++ b/src/arcaea_offline/calculate.py @@ -36,6 +36,7 @@ def calculate_score(chart: Chart, score: Score) -> Calculated: assert isinstance(potential, Decimal) return Calculated( + id=score.id, song_id=chart.song_id, rating_class=chart.rating_class, score=score.score, diff --git a/src/arcaea_offline/database.py b/src/arcaea_offline/database.py index 5373676..95dd992 100644 --- a/src/arcaea_offline/database.py +++ b/src/arcaea_offline/database.py @@ -2,6 +2,7 @@ import atexit import os import sqlite3 from dataclasses import fields, is_dataclass +from functools import wraps from typing import Callable, List, NamedTuple, Optional, TypeVar, Union from thefuzz import fuzz @@ -19,7 +20,8 @@ from .models import ( from .utils.singleton import Singleton from .utils.types import TDataclass -T = TypeVar("T", bound=TDataclass) +TC = TypeVar("TC", bound=Callable) +TD = TypeVar("TD", bound=TDataclass) class Database(metaclass=Singleton): @@ -27,11 +29,8 @@ class Database(metaclass=Singleton): dbFilename = "arcaea_offline.db" def __init__(self): - self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename)) - self.__conn.execute("PRAGMA journal_mode = WAL;") - self.__conn.execute("PRAGMA foreign_keys = ON;") - - atexit.register(self.__conn.close) + self.__conn: sqlite3.Connection = None # type: ignore + self.renew_conn() self.__update_hooks = [] @@ -39,6 +38,27 @@ class Database(metaclass=Singleton): def conn(self): return self.__conn + def renew_conn(self): + if self.__conn: + atexit.unregister(self.__conn.close) + + self.__conn = sqlite3.connect(os.path.join(self.dbDir, self.dbFilename)) + self.__conn.execute("PRAGMA journal_mode = WAL;") + self.__conn.execute("PRAGMA foreign_keys = ON;") + atexit.register(self.__conn.close) + + def _check_conn(self): + if not isinstance(self.__conn, sqlite3.Connection): + raise ValueError("Database not connected") + + def check_conn(func: TC) -> TC: # type: ignore + @wraps(func) + def wrapper(self, *args, **kwargs): + self._check_conn() + return func(self, *args, **kwargs) + + return wrapper # type: ignore + def register_update_hook(self, hook: Callable) -> bool: if callable(hook): if hook not in self.__update_hooks: @@ -56,6 +76,7 @@ class Database(metaclass=Singleton): for hook in self.__update_hooks: hook() + @check_conn def update_arcsong_db(self, path: Union[str, bytes]): with sqlite3.connect(path) as arcsong_conn: arcsong_cursor = arcsong_conn.cursor() @@ -93,7 +114,9 @@ class Database(metaclass=Singleton): rows, ) conn.commit() + self.__trigger_update_hooks() + @check_conn def init(self): create_sqls = INIT_SQLS[1]["init"] @@ -112,9 +135,10 @@ class Database(metaclass=Singleton): def __get_columns_clause(self, columns: List[str]): return ", ".join([f'"{column}"' for column in columns]) + @check_conn def __get_table( - self, table_name: str, datacls: T, where_clause: str = "", params=None - ) -> List[T]: + self, table_name: str, datacls: TD, where_clause: str = "", params=None + ) -> List[TD]: if params is None: params = [] columns_clause = self.__get_columns_clause( @@ -162,6 +186,7 @@ class Database(metaclass=Singleton): (song_id, rating_class), )[0] + @check_conn def validate_song_id(self, song_id): with self.conn as conn: cursor = conn.cursor() @@ -169,6 +194,7 @@ class Database(metaclass=Singleton): result = cursor.fetchone() return result[0] > 0 + @check_conn def validate_chart(self, song_id: str, rating_class: int): with self.conn as conn: cursor = conn.cursor() @@ -183,6 +209,7 @@ class Database(metaclass=Singleton): song_id: str confidence: int + @check_conn def fuzzy_search_song_id( self, input_str: str, limit: int = 5 ) -> List[FuzzySearchSongIdResult]: @@ -247,6 +274,7 @@ class Database(metaclass=Singleton): "calculated", DbCalculatedRow, " AND ".join(where_clauses), params ) + @check_conn def get_b30(self) -> float: with self.conn as conn: cursor = conn.cursor() @@ -265,6 +293,7 @@ class Database(metaclass=Singleton): conn.commit() self.__trigger_update_hooks() + @check_conn def update_score(self, score_id: int, new_score: ScoreInsert): # ensure we are only updating 1 row scores = self.get_scores(score_id=score_id) @@ -282,6 +311,7 @@ class Database(metaclass=Singleton): conn.commit() self.__trigger_update_hooks() + @check_conn def delete_score(self, score_id: int): with self.conn as conn: cursor = conn.cursor()