Skip to content

Commit b693417

Browse files
authored
Inference: Associate chats with user IDs (LAION-AI#1826)
Closes LAION-AI#1788.
1 parent 8a6637f commit b693417

File tree

7 files changed

+68
-28
lines changed

7 files changed

+68
-28
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Add user id to chats table
2+
3+
Revision ID: b74c66553643
4+
Revises: b365a18db6fd
5+
Create Date: 2023-02-23 22:45:19.946188
6+
"""
7+
import sqlalchemy as sa
8+
from alembic import op
9+
10+
# revision identifiers, used by Alembic.
11+
revision = "b74c66553643"
12+
down_revision = "b365a18db6fd"
13+
branch_labels = None
14+
depends_on = None
15+
16+
17+
def upgrade() -> None:
    """Add a nullable ``user_id`` foreign-key column to ``chat`` and index it."""
    user_id_column = sa.Column(
        "user_id",
        sa.String(),
        sa.ForeignKey("user.id"),
        nullable=True,
    )
    op.add_column("chat", user_id_column)
    # Non-unique index on the new column.
    op.create_index(op.f("ix_chat_user_id"), "chat", ["user_id"], unique=False)
20+
21+
22+
def downgrade() -> None:
    """Undo the upgrade: drop the ``ix_chat_user_id`` index, then the ``chat.user_id`` column."""
    index_name = op.f("ix_chat_user_id")
    op.drop_index(index_name, table_name="chat")
    op.drop_column("chat", "user_id")

inference/server/main.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,9 @@ async def list_chats(
182182
) -> interface.ListChatsResponse:
183183
"""Lists all chats."""
184184
logger.info("Listing all chats.")
185-
chats = cr.get_chat_list()
186-
return interface.ListChatsResponse(chats=chats)
185+
chats = cr.get_chats(user_id)
186+
chats_list = [chat.to_list_read() for chat in chats]
187+
return interface.ListChatsResponse(chats=chats_list)
187188

188189

189190
@app.post("/chat")
@@ -194,7 +195,7 @@ async def create_chat(
194195
) -> interface.ChatListRead:
195196
"""Allows a client to create a new chat."""
196197
logger.info(f"Received {request=}")
197-
chat = cr.create_chat()
198+
chat = cr.create_chat(user_id)
198199
return chat.to_list_read()
199200

200201

@@ -206,6 +207,9 @@ async def get_chat(
206207
) -> interface.ChatRead:
207208
"""Allows a client to get the current state of a chat."""
208209
chat = cr.get_chat_by_id(id)
210+
# currently, user_id will be None if server auth is disabled
211+
if user_id and chat.user_id != user_id:
212+
raise HTTPException(status_code=404, detail="Chat not found")
209213
return chat.to_read()
210214

211215

inference/server/oasst_inference_server/__init__.py

Whitespace-only changes.

inference/server/oasst_inference_server/auth.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@
1111
oauth2_scheme = APIKeyCookie(name=settings.auth_cookie_name)
1212

1313

14+
def derive_key() -> bytes:
    """Derive a key from the auth secret.

    Runs HKDF with SHA-256 over ``settings.auth_secret``; the output length,
    salt and info come from ``settings.auth_length``, ``settings.auth_salt``
    and ``settings.auth_info``.
    """
    kdf = HKDF(
        algorithm=hashes.SHA256(),
        length=settings.auth_length,
        salt=settings.auth_salt,
        info=settings.auth_info,
    )
    return kdf.derive(settings.auth_secret)
24+
25+
1426
def create_access_token(data: dict) -> str:
1527
"""Create encoded JSON Web Token (JWT) using the given data."""
1628
expires_delta = timedelta(minutes=settings.auth_access_token_expire_minutes)
@@ -19,32 +31,20 @@ def create_access_token(data: dict) -> str:
1931
to_encode.update({"exp": expire})
2032

2133
# Generate a key from the auth secret
22-
hkdf = HKDF(
23-
algorithm=hashes.SHA256(),
24-
length=settings.auth_length,
25-
salt=settings.auth_salt,
26-
info=settings.auth_info,
27-
)
28-
key = hkdf.derive(settings.auth_secret)
34+
key = derive_key()
2935

3036
# Encrypt the payload using JWE
3137
token: bytes = jwe.encrypt(to_encode, key)
3238
return token.decode()
3339

3440

3541
def get_current_user_id(token: str = Security(oauth2_scheme)) -> str | None:
42+
"""Decode the current user JWT token and return the payload."""
3643
if not settings.use_auth:
3744
return None
3845

39-
"""Decode the current user JWT token and return the payload."""
4046
# Generate a key from the auth secret
41-
hkdf = HKDF(
42-
algorithm=hashes.SHA256(),
43-
length=settings.auth_length,
44-
salt=settings.auth_salt,
45-
info=settings.auth_info,
46-
)
47-
key = hkdf.derive(settings.auth_secret)
47+
key = derive_key()
4848

4949
# Decrypt the JWE token
5050
try:

inference/server/oasst_inference_server/chat_repository.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ def maybe_commit(self) -> None:
2020
if self.do_commit:
2121
self.session.commit()
2222

23-
def get_chats(self) -> list[models.DbChat]:
24-
return self.session.exec(sqlmodel.select(models.DbChat)).all()
23+
def get_chats(self, user_id: str | None = None) -> list[models.DbChat]:
    """Return stored chats, restricted to ``user_id`` when one is given.

    A falsy ``user_id`` (``None`` or an empty string) applies no restriction,
    so every chat is returned.
    """
    statement = sqlmodel.select(models.DbChat)
    if user_id:
        # ``where`` matches the user scoping used by ``get_chat_by_id``.
        statement = statement.where(models.DbChat.user_id == user_id)
    return self.session.exec(statement).all()
2528

2629
def get_pending_chats(self) -> list[models.DbChat]:
2730
return self.session.exec(
@@ -31,10 +34,6 @@ def get_pending_chats(self) -> list[models.DbChat]:
3134
)
3235
).all()
3336

34-
def get_chat_list(self) -> list[interface.ChatListRead]:
35-
chats = self.get_chats()
36-
return [chat.to_read() for chat in chats]
37-
3837
def get_prompter_message_by_id(self, message_id: str, for_update=False) -> models.DbMessage:
3938
query = sqlmodel.select(models.DbMessage).where(
4039
models.DbMessage.id == message_id, models.DbMessage.role == "prompter"
@@ -53,15 +52,17 @@ def get_assistant_message_by_id(self, message_id: str, for_update=False) -> mode
5352
message = self.session.exec(query).one()
5453
return message
5554

56-
def get_chat_by_id(self, chat_id: str, for_update=False) -> models.DbChat:
55+
def get_chat_by_id(self, chat_id: str, user_id: str | None = None, for_update=False) -> models.DbChat:
    """Fetch the single chat with id ``chat_id``, optionally scoped to a user.

    When ``user_id`` is truthy, the chat must also have that ``user_id``.
    When ``for_update`` is truthy, the select is issued with
    ``with_for_update()``. The result is read with ``.one()``, so this does
    not return ``None``; a lookup that matches no row raises instead
    (presumably ``sqlalchemy.exc.NoResultFound``).
    """
    conditions = [models.DbChat.id == chat_id]
    if user_id:
        conditions.append(models.DbChat.user_id == user_id)
    statement = sqlmodel.select(models.DbChat).where(*conditions)
    if for_update:
        statement = statement.with_for_update()
    return self.session.exec(statement).one()
6263

63-
def create_chat(self) -> models.DbChat:
64-
chat = models.DbChat()
64+
def create_chat(self, user_id: str | None = None) -> models.DbChat:
    """Create a chat associated with ``user_id`` and add it to the session.

    ``user_id`` may be ``None``: the auth dependency returns ``None`` when
    server auth is disabled, in which case the chat is stored without an
    owner. The default also keeps argument-less ``create_chat()`` calls
    working. The session is committed only if ``maybe_commit`` decides to
    (i.e. when ``do_commit`` is set).
    """
    chat = models.DbChat(user_id=user_id)
    self.session.add(chat)
    self.maybe_commit()
    return chat

inference/server/oasst_inference_server/client_handler.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
import fastapi
2+
from fastapi import Depends
23
from loguru import logger
3-
from oasst_inference_server import deps, interface, queueing
4+
from oasst_inference_server import auth, deps, interface, queueing
45
from oasst_shared.schemas import inference
6+
from sqlalchemy.exc import NoResultFound
57
from sse_starlette.sse import EventSourceResponse
68

79

810
async def handle_create_message(
911
chat_id: str,
1012
message_request: interface.MessageRequest,
1113
fastapi_request: fastapi.Request,
14+
user_id: str = Depends(auth.get_current_user_id),
1215
) -> EventSourceResponse:
1316
"""Allows the client to stream the results of a request."""
1417

1518
with deps.manual_chat_repository() as cr:
19+
try:
20+
# Ensure the user can access the chat
21+
cr.get_chat_by_id(chat_id, user_id=user_id)
22+
except NoResultFound:
23+
return fastapi.Response(status_code=404)
24+
1625
try:
1726
prompter_message = cr.add_prompter_message(
1827
chat_id=chat_id, parent_id=message_request.parent_id, content=message_request.content

inference/server/oasst_inference_server/models.py

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class DbChat(SQLModel, table=True):
4545

4646
id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
4747

48+
user_id: str = Field(foreign_key="user.id", index=True)
49+
4850
messages: list[DbMessage] = Relationship(back_populates="chat")
4951

5052
def to_list_read(self) -> interface.ChatListRead:

0 commit comments

Comments
 (0)