From ab88b6903c9bdddd7e873e39ba1e8decac012082 Mon Sep 17 00:00:00 2001 From: 283375 Date: Tue, 21 May 2024 21:01:35 +0800 Subject: [PATCH] test: conftest database clean-up --- tests/conftest.py | 44 ++++++++++++++++++++++------ tests/db/models/test_custom_types.py | 4 +-- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f723cf1..56fc95c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,27 +1,53 @@ import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker # region sqlalchemy fixtures -# from https://medium.com/@vittorio.camisa/agile-database-integration-tests-with-python-sqlalchemy-and-factory-boy-6824e8fe33a1 engine = create_engine("sqlite:///:memory:") Session = sessionmaker() -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def db_conn(): - connection = engine.connect() - yield connection - connection.close() + conn = engine.connect() + yield conn + conn.close() -@pytest.fixture(scope="function") +@pytest.fixture() def db_session(db_conn): - transaction = db_conn.begin() session = Session(bind=db_conn) yield session session.close() - transaction.rollback() + + # drop everything + query_tables = db_conn.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ).fetchall() + for row in query_tables: + table_name = row[0] + db_conn.execute(text(f"DROP TABLE {table_name}")) + + query_views = db_conn.execute( + text("SELECT name FROM sqlite_master WHERE type='view'") + ).fetchall() + for row in query_views: + view_name = row[0] + db_conn.execute(text(f"DROP VIEW {view_name}")) + + query_indexes = db_conn.execute( + text("SELECT name FROM sqlite_master WHERE type='index'") + ).fetchall() + for row in query_indexes: + index_name = row[0] + db_conn.execute(text(f"DROP INDEX {index_name}")) + + query_triggers = db_conn.execute( + text("SELECT name FROM sqlite_master WHERE type='trigger'") + ).fetchall() + for row in query_triggers: + trigger_name = row[0] + db_conn.execute(text(f"DROP TRIGGER {trigger_name}")) # endregion diff --git a/tests/db/models/test_custom_types.py b/tests/db/models/test_custom_types.py index 8b77b0f..077fdd2 100644 --- a/tests/db/models/test_custom_types.py +++ b/tests/db/models/test_custom_types.py @@ -41,7 +41,7 @@ class TestCustomTypes: ) ).one()[0] - TestBase.metadata.create_all(db_session.bind) + TestBase.metadata.create_all(db_session.bind, checkfirst=False) basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO) null_obj = IntEnumTestModel(id=2, value=None) @@ -53,7 +53,7 @@ class TestCustomTypes: assert _query_value(2) is None def test_tz_datetime(self, db_session): - TestBase.metadata.create_all(db_session.bind) + TestBase.metadata.create_all(db_session.bind, checkfirst=False) dt1 = datetime.now(tz=timezone(timedelta(hours=8)))