Skip to content

Commit 410036d

Browse files
authored
Cancelling pending chat messages of user on submission of a new message (LAION-AI#2141)
1 parent 5f492e4 commit 410036d

File tree

5 files changed

+38
-0
lines changed

5 files changed

+38
-0
lines changed

inference/server/oasst_inference_server/chat_repository.py

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sqlmodel
77
from loguru import logger
88
from oasst_inference_server import database, models
9+
from oasst_inference_server.schemas import chat as chat_schema
910
from oasst_shared.schemas import inference
1011

1112

@@ -30,6 +31,9 @@ async def start_work(
3031
logger.debug(f"Starting work on message {message_id}")
3132
message = await self.get_assistant_message_by_id(message_id)
3233

34+
if message.state == inference.MessageState.cancelled:
35+
raise chat_schema.MessageCancelledException(message_id=message_id)
36+
3337
if message.state != inference.MessageState.pending:
3438
raise fastapi.HTTPException(status_code=400, detail="Message is not pending")
3539

inference/server/oasst_inference_server/routes/workers.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from fastapi import Depends
99
from loguru import logger
1010
from oasst_inference_server import chat_repository, database, deps, models, queueing, worker_utils
11+
from oasst_inference_server.schemas import chat as chat_schema
1112
from oasst_inference_server.settings import settings
1213
from oasst_shared.schemas import inference
1314

@@ -148,6 +149,8 @@ def _add_receive(ftrs: set):
148149
work_request_map[work_request.id] = WorkRequestContainer(
149150
work_request=work_request, message_id=message_id
150151
)
152+
except chat_schema.MessageCancelledException as e:
153+
logger.warning(f"Message was cancelled before work could be initiated: {e.message_id=}")
151154
finally:
152155
_add_dequeue(pending_futures)
153156
else:

inference/server/oasst_inference_server/schemas/chat.py

+6
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,9 @@ class ChatRead(ChatListRead):
7070

7171
class ListChatsResponse(pydantic.BaseModel):
7272
chats: list[ChatListRead]
73+
74+
75+
class MessageCancelledException(Exception):
    """Signal that a message was cancelled.

    The exception text is ``"Message <message_id> was cancelled"``.

    Attributes:
        message_id: ID of the cancelled message.
    """

    def __init__(self, message_id: str):
        super().__init__(f"Message {message_id} was cancelled")
        self.message_id = message_id

    def __reduce__(self):
        # The default reduction rebuilds the exception from ``self.args``,
        # which holds the formatted text, not the ID. Re-running ``__init__``
        # with that text would wrap the message a second time, so rebuild
        # from the raw ``message_id`` instead.
        return (type(self), (self.message_id,))

inference/server/oasst_inference_server/user_chat_repository.py

+24
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ async def create_chat(self) -> models.DbChat:
5959

6060
async def add_prompter_message(self, chat_id: str, parent_id: str | None, content: str) -> models.DbMessage:
6161
logger.info(f"Adding prompter message {len(content)=} to chat {chat_id}")
62+
6263
chat: models.DbChat = (
6364
await self.session.exec(
6465
sqlmodel.select(models.DbChat)
@@ -110,6 +111,29 @@ async def initiate_assistant_message(
110111
self, parent_id: str, work_parameters: inference.WorkParameters
111112
) -> models.DbMessage:
112113
logger.info(f"Adding stub assistant message to {parent_id=}")
114+
115+
# find and cancel all pending messages by this user
116+
pending_msg_query = (
117+
sqlmodel.select(models.DbMessage)
118+
.where(
119+
models.DbMessage.role == "assistant",
120+
models.DbMessage.state == inference.MessageState.pending,
121+
)
122+
.join(models.DbChat)
123+
.where(
124+
models.DbChat.user_id == self.user_id,
125+
)
126+
)
127+
128+
pending_msgs: list[models.DbMessage] = (await self.session.exec(pending_msg_query)).all()
129+
for pending_msg in pending_msgs:
130+
logger.warning(
131+
f"User {self.user_id} has a pending message {pending_msg.id} in chat {pending_msg.chat_id}. Cancelling..."
132+
)
133+
pending_msg.state = inference.MessageState.cancelled
134+
await self.session.commit()
135+
logger.debug(f"Cancelled message {pending_msg.id} in chat {pending_msg.chat_id}.")
136+
113137
query = (
114138
sqlmodel.select(models.DbMessage)
115139
.options(sqlalchemy.orm.selectinload(models.DbMessage.chat))

oasst-shared/oasst_shared/schemas/inference.py

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class MessageState(str, enum.Enum):
144144
in_progress = "in_progress"
145145
complete = "complete"
146146
aborted_by_worker = "aborted_by_worker"
147+
cancelled = "cancelled"
147148

148149

149150
class MessageRead(pydantic.BaseModel):

0 commit comments

Comments
 (0)