forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat_repository.py
91 lines (78 loc) · 3.83 KB
/
chat_repository.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import datetime
import fastapi
import pydantic
import sqlalchemy.orm
import sqlmodel
from loguru import logger
from oasst_inference_server import database, models
from oasst_inference_server.schemas import chat as chat_schema
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
class ChatRepository(pydantic.BaseModel):
    """Database access for assistant messages and their work lifecycle.

    The state-changing methods all follow the same shape: load the assistant
    message by id, update its fields, commit the session, refresh the row and
    return it.
    """

    # Async database session used for every query, commit and refresh below.
    session: database.AsyncSession

    class Config:
        # Lets pydantic accept field types it has no validator for
        # (presumably needed for the session field above).
        arbitrary_types_allowed = True

    async def get_assistant_message_by_id(self, message_id: str) -> models.DbMessage:
        """Return the message with id ``message_id`` whose role is "assistant".

        The ``reports`` relationship is loaded with ``selectinload``. The row
        is obtained with ``.one()``, so the lookup is expected to match
        exactly one row.
        """
        statement = sqlmodel.select(models.DbMessage)
        statement = statement.options(sqlalchemy.orm.selectinload(models.DbMessage.reports))
        statement = statement.where(
            models.DbMessage.id == message_id,
            models.DbMessage.role == "assistant",
        )
        result = await self.session.exec(statement)
        return result.one()

    async def start_work(
        self, *, message_id: str, worker_id: str, worker_config: inference.WorkerConfig
    ) -> models.DbMessage:
        """Mark a pending assistant message as in progress for a worker.

        Args:
            message_id: Id of the assistant message to start.
            worker_id: Id stored on the message as its worker.
            worker_config: Worker configuration stored on the message.

        Returns:
            The refreshed message, now in the ``in_progress`` state.

        Raises:
            chat_schema.MessageTimeoutException: The timeout setting is
                positive and the message is older than it. The message is
                committed in the ``timeout`` state before raising.
            chat_schema.MessageCancelledException: The message is cancelled.
            fastapi.HTTPException: (400) The message is in any other
                non-pending state.
        """
        logger.debug(f"Starting work on message {message_id}")
        db_message = await self.get_assistant_message_by_id(message_id)

        # NOTE: the age check runs before the state checks, so it applies to
        # the message whatever state it is currently in.
        timeout_seconds = settings.assistant_message_timeout
        if timeout_seconds > 0:
            age = datetime.datetime.utcnow() - db_message.created_at
            if age.total_seconds() > timeout_seconds:
                db_message.state = inference.MessageState.timeout
                await self.session.commit()
                await self.session.refresh(db_message)
                raise chat_schema.MessageTimeoutException(message=db_message.to_read())

        current_state = db_message.state
        if current_state == inference.MessageState.cancelled:
            raise chat_schema.MessageCancelledException(message_id=message_id)
        if current_state != inference.MessageState.pending:
            raise fastapi.HTTPException(status_code=400, detail="Message is not pending")

        # Record which worker took the message and when it started.
        db_message.state = inference.MessageState.in_progress
        db_message.work_begin_at = datetime.datetime.utcnow()
        db_message.worker_id = worker_id
        db_message.worker_config = worker_config

        await self.session.commit()
        logger.debug(f"Started work on message {message_id}")
        await self.session.refresh(db_message)
        return db_message

    async def reset_work(self, message_id: str) -> models.DbMessage:
        """Put an assistant message back into the ``pending`` state.

        Clears the work start time and all worker fields (id, compat hash,
        config), commits, and returns the refreshed message. The current
        state of the message is not checked.
        """
        logger.warning(f"Resetting work on message {message_id}")
        db_message = await self.get_assistant_message_by_id(message_id)

        db_message.state = inference.MessageState.pending
        db_message.work_begin_at = None
        db_message.worker_id = None
        db_message.worker_compat_hash = None
        db_message.worker_config = None

        await self.session.commit()
        logger.debug(f"Reset work on message {message_id}")
        await self.session.refresh(db_message)
        return db_message

    async def abort_work(self, message_id: str, reason: str) -> models.DbMessage:
        """Mark an assistant message as aborted by the worker.

        Sets the state to ``aborted_by_worker``, stamps the work end time,
        stores ``reason`` as the message error, commits, and returns the
        refreshed message.
        """
        logger.warning(f"Aborting work on message {message_id}")
        db_message = await self.get_assistant_message_by_id(message_id)

        db_message.state = inference.MessageState.aborted_by_worker
        db_message.work_end_at = datetime.datetime.utcnow()
        db_message.error = reason

        await self.session.commit()
        logger.debug(f"Aborted work on message {message_id}")
        await self.session.refresh(db_message)
        return db_message

    async def complete_work(self, message_id: str, content: str) -> models.DbMessage:
        """Mark an assistant message as complete and store its content.

        Sets the state to ``complete``, stamps the work end time, stores
        ``content`` on the message, commits, and returns the refreshed
        message.
        """
        logger.debug(f"Completing work on message {message_id}")
        db_message = await self.get_assistant_message_by_id(message_id)

        db_message.state = inference.MessageState.complete
        db_message.work_end_at = datetime.datetime.utcnow()
        db_message.content = content

        await self.session.commit()
        logger.debug(f"Completed work on message {message_id}")
        await self.session.refresh(db_message)
        return db_message