diff --git a/src/arcaea_offline/database.py b/src/arcaea_offline/database.py index 0eb0358..81efec0 100644 --- a/src/arcaea_offline/database.py +++ b/src/arcaea_offline/database.py @@ -1,13 +1,20 @@ import os import sqlite3 from dataclasses import fields, is_dataclass -from typing import List, NamedTuple, Optional, TypeVar, Union +from typing import Callable, List, NamedTuple, Optional, TypeVar, Union from thefuzz import fuzz from thefuzz import process as fuzz_process from .init_sqls import INIT_SQLS -from .models import DbAliasRow, DbCalculatedRow, DbChartRow, DbPackageRow, DbScoreRow +from .models import ( + DbAliasRow, + DbCalculatedRow, + DbChartRow, + DbPackageRow, + DbScoreRow, + ScoreInsert, +) from .utils.singleton import Singleton from .utils.types import TDataclass @@ -23,16 +30,35 @@ class Database(metaclass=Singleton): self.__conn.execute("PRAGMA journal_mode = WAL;") self.__conn.execute("PRAGMA foreign_keys = ON;") + self.__update_hooks = [] + @property def conn(self): return self.__conn + def register_update_hook(self, hook: Callable) -> bool: + if callable(hook): + if hook not in self.__update_hooks: + self.__update_hooks.append(hook) + return True + return False + + def unregister_update_hook(self, hook: Callable) -> bool: + if hook in self.__update_hooks: + self.__update_hooks.remove(hook) + return True + return False + + def __trigger_update_hooks(self): + for hook in self.__update_hooks: + hook() + def update_arcsong_db(self, path: Union[str, bytes]): with sqlite3.connect(path) as arcsong_conn: arcsong_cursor = arcsong_conn.cursor() data = { "charts": arcsong_cursor.execute( - "SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, 'set', time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override FROM charts" + "SELECT song_id, rating_class, name_en, name_jp, artist, bpm, bpm_base, [set], time, side, world_unlock, remote_download, bg, date, version, difficulty, rating, note, chart_designer, jacket_designer, jacket_override, audio_override FROM charts" ).fetchall(), "aliases": arcsong_cursor.execute( "SELECT sid, alias FROM alias" @@ -211,8 +237,8 @@ class Database(metaclass=Singleton): cursor = conn.cursor() return cursor.execute("SELECT b30 FROM calculated_potential").fetchone()[0] - def insert_score(self, score: DbScoreRow): - columns = self.__get_columns_from_dataclass(DbScoreRow) + def insert_score(self, score: ScoreInsert): + columns = self.__get_columns_from_dataclass(ScoreInsert) columns_clause = self.__get_columns_clause(columns) params = [getattr(score, column) for column in columns] with self.conn as conn: @@ -222,3 +248,4 @@ class Database(metaclass=Singleton): params, ) conn.commit() + self.__trigger_update_hooks() diff --git a/src/arcaea_offline/models.py b/src/arcaea_offline/models.py index 6ac438e..c709468 100644 --- a/src/arcaea_offline/models.py +++ b/src/arcaea_offline/models.py @@ -127,6 +127,19 @@ class Score: return DbScoreRow(*values) +@dataclass(kw_only=True) +class ScoreInsert: + song_id: str + rating_class: int + score: int + time: int + pure: Optional[int] = None + far: Optional[int] = None + lost: Optional[int] = None + max_recall: Optional[int] = None + clear_type: Optional[int] = None + + @dataclass class DbCalculatedRow: id: int