forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
120 lines (99 loc) · 3.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import asyncio
import signal
import sys
import fastapi
import sqlmodel
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_inference_server import database, deps, models
from oasst_inference_server.routes import account, admin, auth, chats, configs, workers
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.middleware.sessions import SessionMiddleware
app = fastapi.FastAPI()

# Allow CORS for the origins configured in settings. Middleware added first ends up
# innermost, so keep this call before the session middleware below.
_cors_options = dict(
    allow_origins=settings.inference_cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
del _cors_options

# Session middleware for authlib (signed with the configured secret key).
app.add_middleware(SessionMiddleware, secret_key=settings.session_middleware_secret_key)
@app.middleware("http")
async def log_exceptions(request: fastapi.Request, call_next):
    """HTTP middleware that logs any exception escaping the downstream handler.

    The exception is re-raised unchanged after logging, so normal error handling
    further up the stack still applies. On success the downstream response is returned as-is.
    """
    try:
        return await call_next(request)
    except Exception:
        logger.exception("Exception in request")
        raise
# Add Prometheus metrics at /metrics
@app.on_event("startup")
async def enable_prom_metrics():
    """On startup, instrument the app with Prometheus and expose the metrics endpoint."""
    instrumentator = Instrumentator()
    instrumentator.instrument(app)
    instrumentator.expose(app)
@app.on_event("startup")
async def log_inference_protocol_version():
    """Log, at startup, which shared inference protocol version this server speaks."""
    protocol_version = inference.INFERENCE_PROTOCOL_VERSION
    logger.warning(f"Inference protocol version: {protocol_version}")
def terminate_server(signum, frame):
    """Signal handler: log the received signal, then exit the process with status 0.

    ``frame`` is unused; it is required by the ``signal.signal`` handler signature.
    """
    logger.warning(f"Signal {signum}. Terminating server...")
    # Equivalent to sys.exit(0): raises SystemExit with exit code 0.
    raise SystemExit(0)
async def _run_alembic_upgrade_once():
    """Run ``database.alembic_upgrade`` once inside ``engine.begin()`` on a fresh engine."""
    # NOTE(review): assumes make_engine() returns a SQLAlchemy AsyncEngine, as implied by
    # `async with engine.begin()` + `conn.run_sync(...)`.
    engine = database.make_engine()
    try:
        async with engine.begin() as conn:
            await conn.run_sync(database.alembic_upgrade)
    finally:
        # A new engine is created per attempt; release its pooled connections so
        # retries (and the successful run) don't leave connections open.
        await engine.dispose()


@app.on_event("startup")
async def alembic_upgrade():
    """Run alembic migrations on startup, retrying with exponential backoff.

    Skipped (with a warning) when ``settings.update_alembic`` is false. Each failed
    attempt is logged; once ``settings.alembic_retries`` attempts have failed, the last
    exception is re-raised. Between attempts the coroutine sleeps
    ``settings.alembic_retry_timeout * 2**retry`` seconds.

    While the upgrade/retry loop runs, SIGINT is routed to ``terminate_server`` so that
    Ctrl+C can abort it. The SIGINT handler that was active beforehand is restored on
    every exit path: success, exhausted retries, or termination.
    """
    if not settings.update_alembic:
        logger.warning("Skipping alembic upgrade on startup (update_alembic is False)")
        return

    # Remember whatever SIGINT handler was installed before us (e.g. by the ASGI
    # server) so it can be put back instead of forcing SIG_DFL.
    previous_sigint_handler = signal.signal(signal.SIGINT, terminate_server)
    try:
        logger.warning("Attempting to upgrade alembic on startup")
        retry = 0
        while True:
            try:
                await _run_alembic_upgrade_once()
                logger.warning("Successfully upgraded alembic on startup")
                break
            except Exception:
                logger.exception("Alembic upgrade failed on startup")
                retry += 1
                if retry >= settings.alembic_retries:
                    raise
                # Exponential backoff; retry is already 1 here, so the first wait is 2x the base.
                timeout = settings.alembic_retry_timeout * 2**retry
                logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
                await asyncio.sleep(timeout)
    finally:
        # signal.signal() returns None when the previous handler was not installed
        # from Python; fall back to the default action in that case.
        if previous_sigint_handler is None:
            previous_sigint_handler = signal.SIG_DFL
        signal.signal(signal.SIGINT, previous_sigint_handler)
def _mask_api_key(api_key) -> str:
    """Return a log-safe form of ``api_key`` revealing at most its last four characters."""
    key = str(api_key)
    if len(key) <= 4:
        return "****"
    return f"****{key[-4:]}"


@app.on_event("startup")
async def maybe_add_debug_api_keys():
    """Ensure every key in ``settings.debug_api_keys_list`` exists as a ``DbWorker`` row.

    If no debug keys are configured, only a warning is logged. For each configured key,
    an existing ``DbWorker`` with that ``api_key`` is looked up. A missing key is
    inserted with the name "Debug API Key" and committed immediately, while keys that
    already exist are left untouched. Any exception is logged and re-raised.

    API keys are credentials, so only a masked form of each key is written to the log.
    """
    debug_api_keys = settings.debug_api_keys_list
    if not debug_api_keys:
        logger.warning("No debug API keys configured, skipping")
        return
    try:
        logger.warning("Adding debug API keys")
        async with deps.manual_create_session() as session:
            for api_key in debug_api_keys:
                masked_key = _mask_api_key(api_key)
                logger.info(f"Checking if debug API key {masked_key} exists")
                query = sqlmodel.select(models.DbWorker).where(models.DbWorker.api_key == api_key)
                existing_worker = (await session.exec(query)).one_or_none()
                if existing_worker is None:
                    logger.info(f"Adding debug API key {masked_key}")
                    session.add(models.DbWorker(api_key=api_key, name="Debug API Key"))
                    # Commit per key, as before: earlier keys persist even if a later one fails.
                    await session.commit()
                else:
                    logger.info(f"Debug API key {masked_key} already exists")
        logger.warning("Finished adding debug API keys")
    except Exception:
        logger.exception("Failed to add debug API keys")
        raise
# Register the API routers on the app, in this fixed order.
for _router_module in (account, auth, admin, chats, workers, configs):
    app.include_router(_router_module.router)
del _router_module
@app.on_event("startup")
async def welcome_message():
    """Announce at startup that the server is up and how to stop it."""
    for message in ("Inference server started", "To stop the server, press Ctrl+C"):
        logger.warning(message)