forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatabase_utils.py
142 lines (121 loc) · 5.13 KB
/
database_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from enum import IntEnum
from functools import wraps
from http import HTTPStatus
from typing import Callable
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_shared.exceptions import OasstError, OasstErrorCode
from sqlalchemy.exc import OperationalError
from sqlmodel import Session, SQLModel
class CommitMode(IntEnum):
    """What a managed-transaction decorator does with the session once the wrapped call returns."""

    NONE = 0  # leave the session exactly as the wrapped call left it
    FLUSH = 1  # session.flush(): send pending changes without committing
    COMMIT = 2  # session.commit(): make the transaction durable
    ROLLBACK = 3  # session.rollback(): discard the transaction's changes
"""
* managed_tx_method and async_managed_tx_method are decorator functions
* meant to be applied to class methods. They expect the class instance to have
* an initialised 'db' Session attribute.
"""
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT):
    """Decorate a method so it runs inside a managed, retried DB transaction.

    The decorated method's instance must expose a ``db`` Session attribute.
    After the method returns, ``auto_commit`` decides what happens to
    ``self.db``: ``COMMIT`` commits, ``FLUSH`` flushes, ``ROLLBACK`` rolls back
    and ``NONE`` leaves it untouched. A returned ``SQLModel`` instance is then
    refreshed from the session.

    Args:
        auto_commit: Action applied to ``self.db`` after a successful call.
        num_retries: Maximum number of attempts. An ``OperationalError`` raised
            by the method, the commit/flush/rollback step or the refresh rolls
            the session back and triggers another attempt.

    Raises:
        OasstError: ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            attempt failed with ``OperationalError``; chained to the last one.
        Exception: Any other exception propagates unchanged. A failure of the
            rollback itself is logged with its traceback before propagating.
    """

    def decorator(f):
        @wraps(f)
        def wrapped_f(self, *args, **kwargs):
            last_error = None
            for i in range(num_retries):
                try:
                    result = f(self, *args, **kwargs)
                    if auto_commit == CommitMode.COMMIT:
                        self.db.commit()
                    elif auto_commit == CommitMode.FLUSH:
                        self.db.flush()
                    elif auto_commit == CommitMode.ROLLBACK:
                        self.db.rollback()
                    # NOTE(review): the refresh is inside the retried section, so an
                    # OperationalError raised here would re-run f even after a
                    # successful commit — confirm that is acceptable for callers.
                    if isinstance(result, SQLModel):
                        self.db.refresh(result)
                    return result
                except OperationalError as e:
                    # Keep a reference: the name `e` is unbound when the except block ends.
                    last_error = e
                    logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
                    try:
                        self.db.rollback()
                    except Exception:
                        # Only a failing rollback is a rollback failure; log it with
                        # its traceback and let it propagate.
                        logger.exception("DB Rollback Failure")
                        raise
            raise OasstError(
                "DATABASE_MAX_RETRIES_EXHAUSTED",
                error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            ) from last_error

        return wrapped_f

    return decorator
def async_managed_tx_method(
    auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT
):
    """Decorate a coroutine method so it runs inside a managed, retried DB transaction.

    Async counterpart of ``managed_tx_method``: the wrapped method is awaited,
    while the session operations on ``self.db`` are called synchronously. After
    the method returns, ``auto_commit`` decides what happens to ``self.db``:
    ``COMMIT`` commits, ``FLUSH`` flushes, ``ROLLBACK`` rolls back and ``NONE``
    leaves it untouched. A returned ``SQLModel`` instance is then refreshed.

    Args:
        auto_commit: Action applied to ``self.db`` after a successful call.
        num_retries: Maximum number of attempts. An ``OperationalError`` raised
            by the method, the commit/flush/rollback step or the refresh rolls
            the session back and triggers another attempt.

    Raises:
        OasstError: ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            attempt failed with ``OperationalError``; chained to the last one.
        Exception: Any other exception propagates unchanged. A failure of the
            rollback itself is logged with its traceback before propagating.
    """

    def decorator(f):
        @wraps(f)
        async def wrapped_f(self, *args, **kwargs):
            last_error = None
            for i in range(num_retries):
                try:
                    result = await f(self, *args, **kwargs)
                    if auto_commit == CommitMode.COMMIT:
                        self.db.commit()
                    elif auto_commit == CommitMode.FLUSH:
                        self.db.flush()
                    elif auto_commit == CommitMode.ROLLBACK:
                        self.db.rollback()
                    # NOTE(review): the refresh is inside the retried section, so an
                    # OperationalError raised here would re-run f even after a
                    # successful commit — confirm that is acceptable for callers.
                    if isinstance(result, SQLModel):
                        self.db.refresh(result)
                    return result
                except OperationalError as e:
                    # Keep a reference: the name `e` is unbound when the except block ends.
                    last_error = e
                    logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
                    try:
                        self.db.rollback()
                    except Exception:
                        # Only a failing rollback is a rollback failure; log it with
                        # its traceback and let it propagate.
                        logger.exception("DB Rollback Failure")
                        raise
            raise OasstError(
                "DATABASE_MAX_RETRIES_EXHAUSTED",
                error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            ) from last_error

        return wrapped_f

    return decorator
def default_session_factor() -> Session:
    """Open and return a fresh ``Session`` bound to the module-level ``engine``.

    Used as the default ``session_factory`` of ``managed_tx_function``.
    NOTE: the name is a historical misspelling of "factory"; it is kept as-is
    because callers may import it by this name.
    """
    new_session = Session(engine)
    return new_session
def managed_tx_function(
    auto_commit: CommitMode = CommitMode.COMMIT,
    num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
    session_factory: Callable[..., Session] = default_session_factor,
):
    """Passes Session object as first argument to wrapped function.

    Each attempt opens a new session via ``session_factory()`` in a ``with``
    block and calls the wrapped function as ``f(session, *args, **kwargs)``.
    After the function returns, ``auto_commit`` decides what happens to the
    session: ``COMMIT`` commits, ``FLUSH`` flushes, ``ROLLBACK`` rolls back and
    ``NONE`` leaves it untouched. The session's ``with`` block is exited before
    the result is handed back (for a SQLAlchemy Session this closes it).

    Args:
        auto_commit: Action applied to the session after a successful call.
        num_retries: Maximum number of attempts, each with a fresh session. An
            ``OperationalError`` raised by the function or by the
            commit/flush/rollback step rolls the session back and triggers
            another attempt.
        session_factory: Zero-argument callable returning a context-managed
            ``Session``.

    Raises:
        OasstError: ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            attempt failed with ``OperationalError``; chained to the last one.
        Exception: Any other exception propagates unchanged. A failure of the
            rollback itself is logged with its traceback before propagating.
    """

    def decorator(f):
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            last_error = None
            for i in range(num_retries):
                with session_factory() as session:
                    try:
                        result = f(session, *args, **kwargs)
                        if auto_commit == CommitMode.COMMIT:
                            session.commit()
                        elif auto_commit == CommitMode.FLUSH:
                            session.flush()
                        elif auto_commit == CommitMode.ROLLBACK:
                            session.rollback()
                        return result
                    except OperationalError as e:
                        # Keep a reference: the name `e` is unbound when the except block ends.
                        last_error = e
                        logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
                        try:
                            session.rollback()
                        except Exception:
                            # Only a failing rollback is a rollback failure; log it
                            # with its traceback and let it propagate.
                            logger.exception("DB Rollback Failure")
                            raise
            raise OasstError(
                "DATABASE_MAX_RETRIES_EXHAUSTED",
                error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            ) from last_error

        return wrapped_f

    return decorator