Skip to content
This repository was archived by the owner on Jan 2, 2024. It is now read-only.

Commit 0e2ae31

Browse files
committed
fix: only create SQL engine when init the SQLRepository
1 parent 64df417 commit 0e2ae31

13 files changed

+148
-223
lines changed

.gitignore

-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ dataset
8282
# Filesystem default local storage
8383
.data/
8484
.my_data/
85-
None
8685

8786
# python notebook
8887
*.ipynb

src/taipy/core/_repository/_sql_repository.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from ..common.typing import Converter, Entity, ModelType
1919
from ..exceptions import ModelNotFound
2020
from ._abstract_repository import _AbstractRepository
21-
from .db._init_db import init_db
22-
from .db._sql_session import SessionLocal
21+
from .db._sql_session import _SQLSession
2322

2423

2524
class _SQLRepository(_AbstractRepository[ModelType, Entity]):
26-
def __init__(self, model_type: Type[ModelType], converter: Type[Converter], session=SessionLocal()):
25+
def __init__(self, model_type: Type[ModelType], converter: Type[Converter], session=None):
2726
"""
2827
Holds common methods to be used and extended when the need for saving
2928
dataclasses in a SqlLite database.
@@ -36,10 +35,10 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter], sess
3635
converter: A class that handles conversion to and from a database backend
3736
db: An SQLAlchemy session object
3837
"""
38+
SessionLocal = _SQLSession.init_db()
39+
self.db = session or SessionLocal()
3940
self.model_type = model_type
40-
self.db = session
4141
self.converter = converter
42-
init_db()
4342

4443
###############################
4544
# ## Inherited methods ## #

src/taipy/core/_repository/db/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@
88
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
99
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
1010
# specific language governing permissions and limitations under the License.
11-
from ._init_db import engine

src/taipy/core/_repository/db/_init_db.py

-39
This file was deleted.

src/taipy/core/_repository/db/_sql_session.py

+40-13
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from functools import lru_cache
1313

1414
from sqlalchemy import create_engine
15+
from sqlalchemy.engine import Engine
1516
from sqlalchemy.orm import sessionmaker
1617
from sqlalchemy.pool import StaticPool
1718

@@ -22,22 +23,48 @@
2223
from .._encoder import dumps
2324

2425

26+
class _SQLSession:
27+
_engine = None
28+
_SessionLocal = None
29+
30+
@classmethod
31+
def init_db(cls):
32+
if cls._SessionLocal:
33+
return cls._SessionLocal
34+
35+
cls._engine = _build_engine()
36+
cls._SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=cls._engine)
37+
38+
from ....core._version._version_model import _VersionModel
39+
from ....core.cycle._cycle_model import _CycleModel
40+
from ....core.data._data_model import _DataNodeModel
41+
from ....core.job._job_model import _JobModel
42+
from ....core.scenario._scenario_model import _ScenarioModel
43+
from ....core.task._task_model import _TaskModel
44+
45+
_CycleModel.__table__.create(bind=cls._engine, checkfirst=True)
46+
_DataNodeModel.__table__.create(bind=cls._engine, checkfirst=True)
47+
_JobModel.__table__.create(bind=cls._engine, checkfirst=True)
48+
_ScenarioModel.__table__.create(bind=cls._engine, checkfirst=True)
49+
_TaskModel.__table__.create(bind=cls._engine, checkfirst=True)
50+
_VersionModel.__table__.create(bind=cls._engine, checkfirst=True)
51+
52+
return cls._SessionLocal
53+
54+
2555
@lru_cache
26-
def _build_engine():
56+
def _build_engine() -> Engine:
2757
properties = Config.core.repository_properties
2858
try:
29-
# More sql databases can be easily added in the future
30-
engine = create_engine(
31-
f"sqlite:///{properties.get('db_location')}?check_same_thread=False",
32-
poolclass=StaticPool,
33-
json_serializer=dumps,
34-
json_deserializer=loads,
35-
)
36-
return engine
37-
59+
db_location = properties["db_location"]
3860
except KeyError:
3961
raise MissingRequiredProperty("Missing property db_location")
4062

41-
42-
engine = _build_engine()
43-
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
63+
# More sql databases can be easily added in the future
64+
engine = create_engine(
65+
f"sqlite:///{db_location}?check_same_thread=False",
66+
poolclass=StaticPool,
67+
json_serializer=dumps,
68+
json_deserializer=loads,
69+
)
70+
return engine

tests/conftest.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sqlalchemy import create_engine, text
2222

2323
from src.taipy.core._orchestrator._orchestrator_factory import _OrchestratorFactory
24-
from src.taipy.core._repository.db import engine
24+
from src.taipy.core._repository.db._sql_session import _build_engine
2525
from src.taipy.core._version._version import _Version
2626
from src.taipy.core._version._version_manager_factory import _VersionManagerFactory
2727
from src.taipy.core._version._version_model import _VersionModel
@@ -188,9 +188,6 @@ def default_multi_sheet_data_frame():
188188

189189
@pytest.fixture(scope="session", autouse=True)
190190
def cleanup_files():
191-
if os.path.exists("None"):
192-
os.remove("None")
193-
194191
yield
195192

196193
if os.path.exists(".data"):
@@ -439,28 +436,29 @@ def sql_engine():
439436
return create_engine("sqlite:///:memory:")
440437

441438

442-
@pytest.fixture()
439+
@pytest.fixture
443440
def init_sql_repo(tmp_sqlite):
444441
Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})
445-
return tmp_sqlite
446442

443+
# Clean SQLite database
444+
engine = _build_engine()
447445

448-
@pytest.fixture(autouse=True)
449-
def clean_sql_db():
450-
_CycleModel.__table__.drop(engine, checkfirst=True)
446+
_CycleModel.__table__.drop(bind=engine, checkfirst=True)
451447
_DataNodeModel.__table__.drop(bind=engine, checkfirst=True)
452448
_JobModel.__table__.drop(bind=engine, checkfirst=True)
453449
_ScenarioModel.__table__.drop(bind=engine, checkfirst=True)
454450
_TaskModel.__table__.drop(bind=engine, checkfirst=True)
455451
_VersionModel.__table__.drop(bind=engine, checkfirst=True)
456452

457-
_CycleModel.__table__.create(engine, checkfirst=True)
453+
_CycleModel.__table__.create(bind=engine, checkfirst=True)
458454
_DataNodeModel.__table__.create(bind=engine, checkfirst=True)
459455
_JobModel.__table__.create(bind=engine, checkfirst=True)
460456
_ScenarioModel.__table__.create(bind=engine, checkfirst=True)
461457
_TaskModel.__table__.create(bind=engine, checkfirst=True)
462458
_VersionModel.__table__.create(bind=engine, checkfirst=True)
463459

460+
return tmp_sqlite
461+
464462

465463
@pytest.fixture
466464
def entities_for_migration():

tests/core/cycle/test_cycle_repositories.py

+9-18
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,24 @@
2121

2222
class TestCycleRepositories:
2323
@pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
24-
def test_save_and_load(self, tmpdir, cycle, repo):
24+
def test_save_and_load(self, cycle, repo, init_sql_repo):
2525
repository = repo()
26-
repository.base_path = tmpdir
2726
repository._save(cycle)
2827

2928
obj = repository._load(cycle.id)
3029
assert isinstance(obj, Cycle)
3130

3231
@pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
33-
def test_exists(self, tmpdir, cycle, repo):
32+
def test_exists(self, cycle, repo, init_sql_repo):
3433
repository = repo()
35-
repository.base_path = tmpdir
3634
repository._save(cycle)
3735

3836
assert repository._exists(cycle.id)
3937
assert not repository._exists("not-existed-cycle")
4038

4139
@pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
42-
def test_load_all(self, tmpdir, cycle, repo):
40+
def test_load_all(self, cycle, repo, init_sql_repo):
4341
repository = repo()
44-
repository.base_path = tmpdir
4542
for i in range(10):
4643
cycle.id = CycleId(f"cycle-{i}")
4744
repository._save(cycle)
@@ -50,9 +47,8 @@ def test_load_all(self, tmpdir, cycle, repo):
5047
assert len(data_nodes) == 10
5148

5249
@pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
53-
def test_load_all_with_filters(self, tmpdir, cycle, repo):
50+
def test_load_all_with_filters(self, cycle, repo, init_sql_repo):
5451
repository = repo()
55-
repository.base_path = tmpdir
5652

5753
for i in range(10):
5854
cycle.id = CycleId(f"cycle-{i}")
@@ -63,9 +59,8 @@ def test_load_all_with_filters(self, tmpdir, cycle, repo):
6359
assert len(objs) == 1
6460

6561
@pytest.mark.parametrize("repo", [_CycleSQLRepository])
66-
def test_delete(self, tmpdir, cycle, repo):
62+
def test_delete(self, cycle, repo, init_sql_repo):
6763
repository = repo()
68-
repository.base_path = tmpdir
6964
repository._save(cycle)
7065

7166
repository._delete(cycle.id)
@@ -74,9 +69,8 @@ def test_delete(self, tmpdir, cycle, repo):
7469
repository._load(cycle.id)
7570

7671
@pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
77-
def test_delete_all(self, tmpdir, cycle, repo):
72+
def test_delete_all(self, cycle, repo, init_sql_repo):
7873
repository = repo()
79-
repository.base_path = tmpdir
8074

8175
for i in range(10):
8276
cycle.id = CycleId(f"cycle-{i}")
@@ -89,9 +83,8 @@ def test_delete_all(self, tmpdir, cycle, repo):
8983
assert len(repository._load_all()) == 0
9084

9185
@pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
92-
def test_delete_many(self, tmpdir, cycle, repo):
86+
def test_delete_many(self, cycle, repo, init_sql_repo):
9387
repository = repo()
94-
repository.base_path = tmpdir
9588

9689
for i in range(10):
9790
cycle.id = CycleId(f"cycle-{i}")
@@ -105,9 +98,8 @@ def test_delete_many(self, tmpdir, cycle, repo):
10598
assert len(repository._load_all()) == 7
10699

107100
@pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
108-
def test_search(self, tmpdir, cycle, repo):
101+
def test_search(self, cycle, repo, init_sql_repo):
109102
repository = repo()
110-
repository.base_path = tmpdir
111103

112104
for i in range(10):
113105
cycle.id = CycleId(f"cycle-{i}")
@@ -121,9 +113,8 @@ def test_search(self, tmpdir, cycle, repo):
121113
assert isinstance(objs[0], Cycle)
122114

123115
@pytest.mark.parametrize("repo", [_CycleFSRepository, _CycleSQLRepository])
124-
def test_export(self, tmpdir, cycle, repo):
116+
def test_export(self, tmpdir, cycle, repo, init_sql_repo):
125117
repository = repo()
126-
repository.base_path = tmpdir
127118
repository._save(cycle)
128119

129120
repository._export(cycle.id, tmpdir.strpath)

0 commit comments

Comments
 (0)