Skip to content

Commit cd16d9c

Browse files
authored
added async postgres to inference (LAION-AI#1961)
1 parent feae209 commit cd16d9c

File tree

13 files changed

+335
-260
lines changed

13 files changed

+335
-260
lines changed

inference/server/alembic.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
5656
# output_encoding = utf-8
5757

5858
# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
59-
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
59+
sqlalchemy.url = postgresql+asyncpg://postgres:postgres@localhost:5432/postgres
6060

6161
[post_write_hooks]
6262
# post_write_hooks defines scripts or Python functions that are run

inference/server/alembic/env.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import asyncio
12
from logging.config import fileConfig
23

34
import sqlmodel
45
from alembic import context
6+
from loguru import logger
57
from oasst_inference_server import models # noqa: F401
6-
from sqlalchemy import engine_from_config, pool
8+
from sqlalchemy import engine_from_config, pool, text
9+
from sqlalchemy.ext.asyncio import AsyncEngine
710

811
# this is the Alembic Config object, which provides
912
# access to the values within the .ini file in use.
@@ -50,7 +53,16 @@ def run_migrations_offline() -> None:
5053
context.run_migrations()
5154

5255

53-
def run_migrations_online() -> None:
56+
def do_run_migrations(connection):
57+
context.configure(connection=connection, target_metadata=target_metadata)
58+
59+
with context.begin_transaction():
60+
context.get_context()._ensure_version_table()
61+
connection.execute(text("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE"))
62+
context.run_migrations()
63+
64+
65+
async def run_async_migrations() -> None:
5466
"""Run migrations in 'online' mode.
5567
5668
In this scenario we need to create an Engine
@@ -61,18 +73,27 @@ def run_migrations_online() -> None:
6173
config.get_section(config.config_ini_section),
6274
prefix="sqlalchemy.",
6375
poolclass=pool.NullPool,
76+
future=True,
6477
)
6578

66-
with connectable.connect() as connection:
67-
context.configure(connection=connection, target_metadata=target_metadata)
79+
connectable = AsyncEngine(connectable)
80+
81+
logger.info(f"Running migrations on {connectable.url}")
6882

69-
with context.begin_transaction():
70-
context.get_context()._ensure_version_table()
71-
connection.execute("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE")
72-
context.run_migrations()
83+
async with connectable.connect() as connection:
84+
logger.info("Connected to database")
85+
await connection.run_sync(do_run_migrations)
86+
logger.info("Migrations complete")
87+
logger.info("Disconnecting from database")
88+
await connectable.dispose()
89+
logger.info("Disconnected from database")
7390

7491

7592
if context.is_offline_mode():
7693
run_migrations_offline()
7794
else:
78-
run_migrations_online()
95+
connection = config.attributes.get("connection", None)
96+
if connection is None:
97+
asyncio.run(run_async_migrations())
98+
else:
99+
do_run_migrations(connection)

inference/server/main.py

+51-48
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
import time
2-
from pathlib import Path
1+
import asyncio
2+
import signal
3+
import sys
34

45
import aiohttp
5-
import alembic.command
6-
import alembic.config
76
import fastapi
87
import sqlmodel
98
from fastapi import Depends, HTTPException
109
from fastapi.middleware.cors import CORSMiddleware
1110
from loguru import logger
12-
from oasst_inference_server import auth, client_handler, deps, models, worker_handler
11+
from oasst_inference_server import auth, client_handler, database, deps, models, worker_handler
1312
from oasst_inference_server.schemas import chat as chat_schema
1413
from oasst_inference_server.schemas import worker as worker_schema
1514
from oasst_inference_server.settings import settings
@@ -67,19 +66,23 @@ def get_root_token(token: str = Depends(get_bearer_token)) -> str:
6766
)
6867

6968

69+
def terminate_server(signum, frame):
70+
logger.info(f"Signal {signum}. Terminating server...")
71+
sys.exit(0)
72+
73+
7074
@app.on_event("startup")
71-
def alembic_upgrade():
75+
async def alembic_upgrade():
76+
signal.signal(signal.SIGINT, terminate_server)
7277
if not settings.update_alembic:
7378
logger.info("Skipping alembic upgrade on startup (update_alembic is False)")
7479
return
7580
logger.info("Attempting to upgrade alembic on startup")
7681
retry = 0
7782
while True:
7883
try:
79-
alembic_ini_path = Path(__file__).parent / "alembic.ini"
80-
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
81-
alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri)
82-
alembic.command.upgrade(alembic_cfg, "head")
84+
async with database.make_engine().begin() as conn:
85+
await conn.run_sync(database.alembic_upgrade)
8386
logger.info("Successfully upgraded alembic on startup")
8487
break
8588
except Exception:
@@ -90,28 +93,26 @@ def alembic_upgrade():
9093

9194
timeout = settings.alembic_retry_timeout * 2**retry
9295
logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
93-
time.sleep(timeout)
96+
await asyncio.sleep(timeout)
97+
signal.signal(signal.SIGINT, signal.SIG_DFL)
9498

9599

96100
@app.on_event("startup")
97-
def maybe_add_debug_api_keys():
101+
async def maybe_add_debug_api_keys():
98102
if not settings.debug_api_keys:
99103
logger.info("No debug API keys configured, skipping")
100104
return
101105
try:
102106
logger.info("Adding debug API keys")
103-
with deps.manual_create_session() as session:
107+
async with deps.manual_create_session() as session:
104108
for api_key in settings.debug_api_keys:
105109
logger.info(f"Checking if debug API key {api_key} exists")
106110
if (
107-
session.exec(
108-
sqlmodel.select(models.DbWorker).where(models.DbWorker.api_key == api_key)
109-
).one_or_none()
110-
is None
111-
):
111+
await session.exec(sqlmodel.select(models.DbWorker).where(models.DbWorker.api_key == api_key))
112+
).one_or_none() is None:
112113
logger.info(f"Adding debug API key {api_key}")
113114
session.add(models.DbWorker(api_key=api_key, name="Debug API Key"))
114-
session.commit()
115+
await session.commit()
115116
else:
116117
logger.info(f"Debug API key {api_key} already exists")
117118
except Exception:
@@ -129,7 +130,7 @@ async def login_discord():
129130
@app.get("/auth/callback/discord", response_model=protocol.Token)
130131
async def callback_discord(
131132
code: str,
132-
db: sqlmodel.Session = Depends(deps.create_session),
133+
db: database.AsyncSession = Depends(deps.create_session),
133134
):
134135
redirect_uri = f"{settings.api_root}/auth/callback/discord"
135136

@@ -166,15 +167,15 @@ async def callback_discord(
166167
raise HTTPException(status_code=400, detail="Invalid user info response from Discord")
167168

168169
# Try to find a user in our DB linked to the Discord user
169-
user: models.DbUser = query_user_by_provider_id(db, discord_id=discord_id)
170+
user: models.DbUser = await query_user_by_provider_id(db, discord_id=discord_id)
170171

171172
# Create if no user exists
172173
if not user:
173174
user = models.DbUser(provider="discord", provider_account_id=discord_id, display_name=discord_username)
174175

175176
db.add(user)
176-
db.commit()
177-
db.refresh(user)
177+
await db.commit()
178+
await db.refresh(user)
178179

179180
# Discord account is authenticated and linked to a user; create JWT
180181
access_token = auth.create_access_token({"user_id": user.id})
@@ -188,7 +189,7 @@ async def list_chats(
188189
) -> chat_schema.ListChatsResponse:
189190
"""Lists all chats."""
190191
logger.info("Listing all chats.")
191-
chats = ucr.get_chats()
192+
chats = await ucr.get_chats()
192193
chats_list = [chat.to_list_read() for chat in chats]
193194
return chat_schema.ListChatsResponse(chats=chats_list)
194195

@@ -200,7 +201,7 @@ async def create_chat(
200201
) -> chat_schema.ChatListRead:
201202
"""Allows a client to create a new chat."""
202203
logger.info(f"Received {request=}")
203-
chat = ucr.create_chat()
204+
chat = await ucr.create_chat()
204205
return chat.to_list_read()
205206

206207

@@ -210,7 +211,7 @@ async def get_chat(
210211
ucr: UserChatRepository = Depends(deps.create_user_chat_repository),
211212
) -> chat_schema.ChatRead:
212213
"""Allows a client to get the current state of a chat."""
213-
chat = ucr.get_chat_by_id(id)
214+
chat = await ucr.get_chat_by_id(id)
214215
return chat.to_read()
215216

216217

@@ -225,45 +226,45 @@ async def get_chat(
225226

226227

227228
@app.put("/worker")
228-
def create_worker(
229+
async def create_worker(
229230
request: worker_schema.CreateWorkerRequest,
230231
root_token: str = Depends(get_root_token),
231-
session: sqlmodel.Session = Depends(deps.create_session),
232-
):
232+
session: database.AsyncSession = Depends(deps.create_session),
233+
) -> worker_schema.WorkerRead:
233234
"""Allows a client to register a worker."""
234235
worker = models.DbWorker(name=request.name)
235236
session.add(worker)
236-
session.commit()
237-
session.refresh(worker)
238-
return worker
237+
await session.commit()
238+
await session.refresh(worker)
239+
return worker_schema.WorkerRead.from_orm(worker)
239240

240241

241242
@app.get("/worker")
242-
def list_workers(
243+
async def list_workers(
243244
root_token: str = Depends(get_root_token),
244-
session: sqlmodel.Session = Depends(deps.create_session),
245-
):
245+
session: database.AsyncSession = Depends(deps.create_session),
246+
) -> list[worker_schema.WorkerRead]:
246247
"""Lists all workers."""
247-
workers = session.exec(sqlmodel.select(models.DbWorker)).all()
248-
return list(workers)
248+
workers = (await session.exec(sqlmodel.select(models.DbWorker))).all()
249+
return [worker_schema.WorkerRead.from_orm(worker) for worker in workers]
249250

250251

251252
@app.delete("/worker/{worker_id}")
252-
def delete_worker(
253+
async def delete_worker(
253254
worker_id: str,
254255
root_token: str = Depends(get_root_token),
255-
session: sqlmodel.Session = Depends(deps.create_session),
256+
session: database.AsyncSession = Depends(deps.create_session),
256257
):
257258
"""Deletes a worker."""
258-
worker = session.get(models.DbWorker, worker_id)
259+
worker = await session.get(models.DbWorker, worker_id)
259260
session.delete(worker)
260-
session.commit()
261+
await session.commit()
261262
return fastapi.Response(status_code=200)
262263

263264

264-
def query_user_by_provider_id(db: sqlmodel.Session, discord_id: str | None = None) -> models.DbUser | None:
265+
async def query_user_by_provider_id(db: database.AsyncSession, discord_id: str | None = None) -> models.DbUser | None:
265266
"""Returns the user associated with a given provider ID if any."""
266-
user_qry = db.query(models.DbUser)
267+
user_qry = sqlmodel.select(models.DbUser)
267268

268269
if discord_id:
269270
user_qry = user_qry.filter(models.DbUser.provider == "discord").filter(
@@ -273,12 +274,12 @@ def query_user_by_provider_id(db: sqlmodel.Session, discord_id: str | None = Non
273274
else:
274275
return None
275276

276-
user: models.DbUser = user_qry.first()
277+
user: models.DbUser = (await db.exec(user_qry)).first()
277278
return user
278279

279280

280281
@app.get("/auth/login/debug")
281-
async def login_debug(username: str, db: sqlmodel.Session = Depends(deps.create_session)):
282+
async def login_debug(username: str, db: database.AsyncSession = Depends(deps.create_session)):
282283
"""Login using a debug username, which the system will accept unconditionally."""
283284

284285
if not settings.allow_debug_auth:
@@ -288,14 +289,16 @@ async def login_debug(username: str, db: sqlmodel.Session = Depends(deps.create_
288289
raise HTTPException(status_code=400, detail="Username is required")
289290

290291
# Try to find the user
291-
user: models.DbUser = db.exec(sqlmodel.select(models.DbUser).where(models.DbUser.id == username)).one_or_none()
292+
user: models.DbUser = (
293+
await db.exec(sqlmodel.select(models.DbUser).where(models.DbUser.id == username))
294+
).one_or_none()
292295

293296
if user is None:
294297
logger.info(f"Creating new debug user {username=}")
295298
user = models.DbUser(id=username, display_name=username, provider="debug", provider_account_id=username)
296299
db.add(user)
297-
db.commit()
298-
db.refresh(user)
300+
await db.commit()
301+
await db.refresh(user)
299302

300303
# Discord account is authenticated and linked to a user; create JWT
301304
access_token = auth.create_access_token({"user_id": user.id})

0 commit comments

Comments
 (0)