forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjournal_writer.py
120 lines (99 loc) · 3.63 KB
/
journal_writer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import enum
from typing import Literal, Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Journal, Task, User
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.utils import utcnow
from pydantic import BaseModel
from sqlmodel import Session
class JournalEventType(str, enum.Enum):
    """Event types that ``JournalWriter`` stores in a journal entry's ``event_type``.

    Members subclass ``str``, so each compares equal to its plain string value.
    """

    # Presumably logged when a user is created; not referenced elsewhere in this file.
    user_created = "user_created"
    # Used by ``JournalWriter.log_text_reply`` together with a ``TextReplyEvent`` payload.
    text_reply_to_message = "text_reply_to_message"
    # Used by ``JournalWriter.log_rating`` together with a ``RatingEvent`` payload.
    message_rating = "message_rating"
    # Used by ``JournalWriter.log_ranking`` together with a ``RankingEvent`` payload.
    message_ranking = "message_ranking"
@payload_type
class JournalEvent(BaseModel):
    """Fields shared by every journal event payload.

    Subclasses narrow ``type`` to a single ``JournalEventType`` literal and give it a
    default. The remaining fields may be left unset when the payload is built;
    ``JournalWriter.log`` fills in any of them that are still ``None``.
    """

    type: str
    # Explicit ``= None`` defaults: pydantic v1 implied them for ``Optional`` fields,
    # but pydantic v2 treats an ``Optional`` field without a default as *required*,
    # which would break constructions such as ``RatingEvent(rating=...)``.
    user_id: Optional[UUID] = None
    message_id: Optional[UUID] = None
    task_id: Optional[UUID] = None
    task_type: Optional[str] = None
@payload_type
class TextReplyEvent(JournalEvent):
    """Journal payload for a text reply; built by ``JournalWriter.log_text_reply``."""

    type: Literal[JournalEventType.text_reply_to_message] = JournalEventType.text_reply_to_message
    # Reply length as passed by the caller; presumably a character count — confirm at call sites.
    length: int
    # Role string as passed by the caller (presumably the author's role in the conversation — confirm).
    role: str
@payload_type
class RatingEvent(JournalEvent):
    """Journal payload for a message rating; built by ``JournalWriter.log_rating``."""

    type: Literal[JournalEventType.message_rating] = JournalEventType.message_rating
    # Rating value as passed by the caller; its scale is not defined in this file.
    rating: int
@payload_type
class RankingEvent(JournalEvent):
    """Journal payload for a message ranking; built by ``JournalWriter.log_ranking``."""

    type: Literal[JournalEventType.message_ranking] = JournalEventType.message_ranking
    # Ranking as passed by the caller; presumably indices of the ranked candidates in
    # preference order — confirm against the ranking task definition.
    ranking: list[int]
class JournalWriter:
    """Record journal events on behalf of one API client and (optionally) one user.

    The ``log_*`` helpers build the matching payload model for a task and delegate
    to :meth:`log`, which completes the payload's common fields and adds a
    ``Journal`` row to the session.
    """

    def __init__(self, db: Session, api_client: ApiClient, user: Optional[User]):
        """Bind the writer to a session and the acting client/user.

        Args:
            db: Session that new ``Journal`` rows are added to.
            api_client: Client whose ``id`` is stored as ``api_client_id`` on every entry.
            user: Acting user. When falsy (e.g. ``None``), entries are written with no user id.
        """
        self.db = db
        self.api_client = api_client
        self.user = user
        self.user_id = self.user.id if self.user else None

    def log_text_reply(self, task: Task, message_id: Optional[UUID], role: str, length: int) -> Journal:
        """Log a ``text_reply_to_message`` event for ``task`` and return the new entry."""
        return self.log(
            task_type=task.payload_type,
            event_type=JournalEventType.text_reply_to_message,
            payload=TextReplyEvent(role=role, length=length),
            task_id=task.id,
            message_id=message_id,
        )

    def log_rating(self, task: Task, message_id: Optional[UUID], rating: int) -> Journal:
        """Log a ``message_rating`` event for ``task`` and return the new entry."""
        return self.log(
            task_type=task.payload_type,
            event_type=JournalEventType.message_rating,
            payload=RatingEvent(rating=rating),
            task_id=task.id,
            message_id=message_id,
        )

    def log_ranking(self, task: Task, message_id: Optional[UUID], ranking: list[int]) -> Journal:
        """Log a ``message_ranking`` event for ``task`` and return the new entry."""
        return self.log(
            task_type=task.payload_type,
            event_type=JournalEventType.message_ranking,
            payload=RankingEvent(ranking=ranking),
            task_id=task.id,
            message_id=message_id,
        )

    @managed_tx_method(CommitMode.FLUSH)
    def log(
        self,
        *,
        payload: Optional[JournalEvent],
        task_type: str,
        event_type: Optional[str] = None,
        task_id: Optional[UUID] = None,
        message_id: Optional[UUID] = None,
        commit: bool = True,
    ) -> Journal:
        """Add a ``Journal`` entry for ``payload`` to the session and return it.

        Any of ``payload.user_id``, ``payload.message_id``, ``payload.task_id`` and
        ``payload.task_type`` that are still ``None`` are filled in from this writer
        and the arguments. ``payload`` is mutated in place.

        Args:
            payload: Event data, wrapped in a ``PayloadContainer`` as ``event_payload``.
            task_type: Copied into ``payload.task_type`` if that is unset.
            event_type: Value stored as the entry's ``event_type``. When omitted, the
                payload's class name is used (e.g. ``"RatingEvent"``), or ``"null"``
                if ``payload`` is None.
            task_id: Copied into ``payload.task_id`` if that is unset.
            message_id: Stored on the entry, and copied into ``payload.message_id``
                if that is unset.
            commit: Not read in this body; presumably consumed by the
                ``managed_tx_method`` decorator — confirm.

        Returns:
            The ``Journal`` row added to ``self.db``. Transaction handling is delegated
            to ``managed_tx_method(CommitMode.FLUSH)``.
        """
        if event_type is None:
            event_type = "null" if payload is None else type(payload).__name__

        # Only complete the common fields when there is a payload; dereferencing a
        # None payload here would turn the "null" case above into an AttributeError.
        if payload is not None:
            if payload.user_id is None:
                payload.user_id = self.user_id
            if payload.message_id is None:
                payload.message_id = message_id
            if payload.task_id is None:
                payload.task_id = task_id
            if payload.task_type is None:
                payload.task_type = task_type

        entry = Journal(
            user_id=self.user_id,
            api_client_id=self.api_client.id,
            created_date=utcnow(),
            event_type=event_type,
            event_payload=PayloadContainer(payload=payload),
            message_id=message_id,
        )
        self.db.add(entry)
        return entry