1
- import time
2
- from pathlib import Path
1
+ import asyncio
2
+ import signal
3
+ import sys
3
4
4
5
import aiohttp
5
- import alembic .command
6
- import alembic .config
7
6
import fastapi
8
7
import sqlmodel
9
8
from fastapi import Depends , HTTPException
10
9
from fastapi .middleware .cors import CORSMiddleware
11
10
from loguru import logger
12
- from oasst_inference_server import auth , client_handler , deps , models , worker_handler
11
+ from oasst_inference_server import auth , client_handler , database , deps , models , worker_handler
13
12
from oasst_inference_server .schemas import chat as chat_schema
14
13
from oasst_inference_server .schemas import worker as worker_schema
15
14
from oasst_inference_server .settings import settings
@@ -67,19 +66,23 @@ def get_root_token(token: str = Depends(get_bearer_token)) -> str:
67
66
)
68
67
69
68
69
+ def terminate_server (signum , frame ):
70
+ logger .info (f"Signal { signum } . Terminating server..." )
71
+ sys .exit (0 )
72
+
73
+
70
74
@app .on_event ("startup" )
71
- def alembic_upgrade ():
75
+ async def alembic_upgrade ():
76
+ signal .signal (signal .SIGINT , terminate_server )
72
77
if not settings .update_alembic :
73
78
logger .info ("Skipping alembic upgrade on startup (update_alembic is False)" )
74
79
return
75
80
logger .info ("Attempting to upgrade alembic on startup" )
76
81
retry = 0
77
82
while True :
78
83
try :
79
- alembic_ini_path = Path (__file__ ).parent / "alembic.ini"
80
- alembic_cfg = alembic .config .Config (str (alembic_ini_path ))
81
- alembic_cfg .set_main_option ("sqlalchemy.url" , settings .database_uri )
82
- alembic .command .upgrade (alembic_cfg , "head" )
84
+ async with database .make_engine ().begin () as conn :
85
+ await conn .run_sync (database .alembic_upgrade )
83
86
logger .info ("Successfully upgraded alembic on startup" )
84
87
break
85
88
except Exception :
@@ -90,28 +93,26 @@ def alembic_upgrade():
90
93
91
94
timeout = settings .alembic_retry_timeout * 2 ** retry
92
95
logger .warning (f"Retrying alembic upgrade in { timeout } seconds" )
93
- time .sleep (timeout )
96
+ await asyncio .sleep (timeout )
97
+ signal .signal (signal .SIGINT , signal .SIG_DFL )
94
98
95
99
96
100
@app .on_event ("startup" )
97
- def maybe_add_debug_api_keys ():
101
+ async def maybe_add_debug_api_keys ():
98
102
if not settings .debug_api_keys :
99
103
logger .info ("No debug API keys configured, skipping" )
100
104
return
101
105
try :
102
106
logger .info ("Adding debug API keys" )
103
- with deps .manual_create_session () as session :
107
+ async with deps .manual_create_session () as session :
104
108
for api_key in settings .debug_api_keys :
105
109
logger .info (f"Checking if debug API key { api_key } exists" )
106
110
if (
107
- session .exec (
108
- sqlmodel .select (models .DbWorker ).where (models .DbWorker .api_key == api_key )
109
- ).one_or_none ()
110
- is None
111
- ):
111
+ await session .exec (sqlmodel .select (models .DbWorker ).where (models .DbWorker .api_key == api_key ))
112
+ ).one_or_none () is None :
112
113
logger .info (f"Adding debug API key { api_key } " )
113
114
session .add (models .DbWorker (api_key = api_key , name = "Debug API Key" ))
114
- session .commit ()
115
+ await session .commit ()
115
116
else :
116
117
logger .info (f"Debug API key { api_key } already exists" )
117
118
except Exception :
@@ -129,7 +130,7 @@ async def login_discord():
129
130
@app .get ("/auth/callback/discord" , response_model = protocol .Token )
130
131
async def callback_discord (
131
132
code : str ,
132
- db : sqlmodel . Session = Depends (deps .create_session ),
133
+ db : database . AsyncSession = Depends (deps .create_session ),
133
134
):
134
135
redirect_uri = f"{ settings .api_root } /auth/callback/discord"
135
136
@@ -166,15 +167,15 @@ async def callback_discord(
166
167
raise HTTPException (status_code = 400 , detail = "Invalid user info response from Discord" )
167
168
168
169
# Try to find a user in our DB linked to the Discord user
169
- user : models .DbUser = query_user_by_provider_id (db , discord_id = discord_id )
170
+ user : models .DbUser = await query_user_by_provider_id (db , discord_id = discord_id )
170
171
171
172
# Create if no user exists
172
173
if not user :
173
174
user = models .DbUser (provider = "discord" , provider_account_id = discord_id , display_name = discord_username )
174
175
175
176
db .add (user )
176
- db .commit ()
177
- db .refresh (user )
177
+ await db .commit ()
178
+ await db .refresh (user )
178
179
179
180
# Discord account is authenticated and linked to a user; create JWT
180
181
access_token = auth .create_access_token ({"user_id" : user .id })
@@ -188,7 +189,7 @@ async def list_chats(
188
189
) -> chat_schema .ListChatsResponse :
189
190
"""Lists all chats."""
190
191
logger .info ("Listing all chats." )
191
- chats = ucr .get_chats ()
192
+ chats = await ucr .get_chats ()
192
193
chats_list = [chat .to_list_read () for chat in chats ]
193
194
return chat_schema .ListChatsResponse (chats = chats_list )
194
195
@@ -200,7 +201,7 @@ async def create_chat(
200
201
) -> chat_schema .ChatListRead :
201
202
"""Allows a client to create a new chat."""
202
203
logger .info (f"Received { request = } " )
203
- chat = ucr .create_chat ()
204
+ chat = await ucr .create_chat ()
204
205
return chat .to_list_read ()
205
206
206
207
@@ -210,7 +211,7 @@ async def get_chat(
210
211
ucr : UserChatRepository = Depends (deps .create_user_chat_repository ),
211
212
) -> chat_schema .ChatRead :
212
213
"""Allows a client to get the current state of a chat."""
213
- chat = ucr .get_chat_by_id (id )
214
+ chat = await ucr .get_chat_by_id (id )
214
215
return chat .to_read ()
215
216
216
217
@@ -225,45 +226,45 @@ async def get_chat(
225
226
226
227
227
228
@app .put ("/worker" )
228
- def create_worker (
229
+ async def create_worker (
229
230
request : worker_schema .CreateWorkerRequest ,
230
231
root_token : str = Depends (get_root_token ),
231
- session : sqlmodel . Session = Depends (deps .create_session ),
232
- ):
232
+ session : database . AsyncSession = Depends (deps .create_session ),
233
+ ) -> worker_schema . WorkerRead :
233
234
"""Allows a client to register a worker."""
234
235
worker = models .DbWorker (name = request .name )
235
236
session .add (worker )
236
- session .commit ()
237
- session .refresh (worker )
238
- return worker
237
+ await session .commit ()
238
+ await session .refresh (worker )
239
+ return worker_schema . WorkerRead . from_orm ( worker )
239
240
240
241
241
242
@app .get ("/worker" )
242
- def list_workers (
243
+ async def list_workers (
243
244
root_token : str = Depends (get_root_token ),
244
- session : sqlmodel . Session = Depends (deps .create_session ),
245
- ):
245
+ session : database . AsyncSession = Depends (deps .create_session ),
246
+ ) -> list [ worker_schema . WorkerRead ] :
246
247
"""Lists all workers."""
247
- workers = session .exec (sqlmodel .select (models .DbWorker )).all ()
248
- return list ( workers )
248
+ workers = ( await session .exec (sqlmodel .select (models .DbWorker ) )).all ()
249
+ return [ worker_schema . WorkerRead . from_orm ( worker ) for worker in workers ]
249
250
250
251
251
252
@app .delete ("/worker/{worker_id}" )
252
- def delete_worker (
253
+ async def delete_worker (
253
254
worker_id : str ,
254
255
root_token : str = Depends (get_root_token ),
255
- session : sqlmodel . Session = Depends (deps .create_session ),
256
+ session : database . AsyncSession = Depends (deps .create_session ),
256
257
):
257
258
"""Deletes a worker."""
258
- worker = session .get (models .DbWorker , worker_id )
259
+ worker = await session .get (models .DbWorker , worker_id )
259
260
session .delete (worker )
260
- session .commit ()
261
+ await session .commit ()
261
262
return fastapi .Response (status_code = 200 )
262
263
263
264
264
- def query_user_by_provider_id (db : sqlmodel . Session , discord_id : str | None = None ) -> models .DbUser | None :
265
+ async def query_user_by_provider_id (db : database . AsyncSession , discord_id : str | None = None ) -> models .DbUser | None :
265
266
"""Returns the user associated with a given provider ID if any."""
266
- user_qry = db . query (models .DbUser )
267
+ user_qry = sqlmodel . select (models .DbUser )
267
268
268
269
if discord_id :
269
270
user_qry = user_qry .filter (models .DbUser .provider == "discord" ).filter (
@@ -273,12 +274,12 @@ def query_user_by_provider_id(db: sqlmodel.Session, discord_id: str | None = Non
273
274
else :
274
275
return None
275
276
276
- user : models .DbUser = user_qry .first ()
277
+ user : models .DbUser = ( await db . exec ( user_qry )) .first ()
277
278
return user
278
279
279
280
280
281
@app .get ("/auth/login/debug" )
281
- async def login_debug (username : str , db : sqlmodel . Session = Depends (deps .create_session )):
282
+ async def login_debug (username : str , db : database . AsyncSession = Depends (deps .create_session )):
282
283
"""Login using a debug username, which the system will accept unconditionally."""
283
284
284
285
if not settings .allow_debug_auth :
@@ -288,14 +289,16 @@ async def login_debug(username: str, db: sqlmodel.Session = Depends(deps.create_
288
289
raise HTTPException (status_code = 400 , detail = "Username is required" )
289
290
290
291
# Try to find the user
291
- user : models .DbUser = db .exec (sqlmodel .select (models .DbUser ).where (models .DbUser .id == username )).one_or_none ()
292
+ user : models .DbUser = (
293
+ await db .exec (sqlmodel .select (models .DbUser ).where (models .DbUser .id == username ))
294
+ ).one_or_none ()
292
295
293
296
if user is None :
294
297
logger .info (f"Creating new debug user { username = } " )
295
298
user = models .DbUser (id = username , display_name = username , provider = "debug" , provider_account_id = username )
296
299
db .add (user )
297
- db .commit ()
298
- db .refresh (user )
300
+ await db .commit ()
301
+ await db .refresh (user )
299
302
300
303
# Discord account is authenticated and linked to a user; create JWT
301
304
access_token = auth .create_access_token ({"user_id" : user .id })
0 commit comments