test: conftest database clean-up

This commit is contained in:
283375 2024-05-21 21:01:35 +08:00
parent ce715bfccc
commit a27afca8a7
Signed by: 283375
SSH Key Fingerprint: SHA256:UcX0qg6ZOSDOeieKPGokA5h7soykG61nz2uxuQgVLSk
2 changed files with 37 additions and 11 deletions

View File

@ -1,27 +1,53 @@
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
# region sqlalchemy fixtures # region sqlalchemy fixtures
# from https://medium.com/@vittorio.camisa/agile-database-integration-tests-with-python-sqlalchemy-and-factory-boy-6824e8fe33a1
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")
Session = sessionmaker() Session = sessionmaker()
@pytest.fixture(scope="module") @pytest.fixture(scope="session")
def db_conn(): def db_conn():
connection = engine.connect() conn = engine.connect()
yield connection yield conn
connection.close() conn.close()
@pytest.fixture(scope="function") @pytest.fixture()
def db_session(db_conn): def db_session(db_conn):
transaction = db_conn.begin()
session = Session(bind=db_conn) session = Session(bind=db_conn)
yield session yield session
session.close() session.close()
transaction.rollback()
# drop everything
query_tables = db_conn.execute(
text("SELECT name FROM sqlite_master WHERE type='table'")
).fetchall()
for row in query_tables:
table_name = row[0]
db_conn.execute(text(f"DROP TABLE {table_name}"))
query_views = db_conn.execute(
text("SELECT name FROM sqlite_master WHERE type='view'")
).fetchall()
for row in query_views:
view_name = row[0]
db_conn.execute(text(f"DROP VIEW {view_name}"))
query_indexes = db_conn.execute(
text("SELECT name FROM sqlite_master WHERE type='index'")
).fetchall()
for row in query_indexes:
index_name = row[0]
db_conn.execute(text(f"DROP INDEX {index_name}"))
query_triggers = db_conn.execute(
text("SELECT name FROM sqlite_master WHERE type='trigger'")
).fetchall()
for row in query_triggers:
trigger_name = row[0]
db_conn.execute(text(f"DROP TRIGGER {trigger_name}"))
# endregion # endregion

View File

@ -41,7 +41,7 @@ class TestCustomTypes:
) )
).one()[0] ).one()[0]
TestBase.metadata.create_all(db_session.bind) TestBase.metadata.create_all(db_session.bind, checkfirst=False)
basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO) basic_obj = IntEnumTestModel(id=1, value=TestIntEnum.TWO)
null_obj = IntEnumTestModel(id=2, value=None) null_obj = IntEnumTestModel(id=2, value=None)
@ -53,7 +53,7 @@ class TestCustomTypes:
assert _query_value(2) is None assert _query_value(2) is None
def test_tz_datetime(self, db_session): def test_tz_datetime(self, db_session):
TestBase.metadata.create_all(db_session.bind) TestBase.metadata.create_all(db_session.bind, checkfirst=False)
dt1 = datetime.now(tz=timezone(timedelta(hours=8))) dt1 = datetime.now(tz=timezone(timedelta(hours=8)))