forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmessage_tree_state.py
89 lines (70 loc) · 3.47 KB
/
message_tree_state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from datetime import datetime
from enum import Enum
from typing import Optional
from uuid import UUID
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
class State(str, Enum):
    """States of the Open-Assistant message tree state machine.

    Members mix in ``str``, so each state compares equal to its string value
    (e.g. ``State.GROWING == "growing"``).
    """

    INITIAL_PROMPT_REVIEW = "initial_prompt_review"
    """In this state the message tree consists only of a single initial prompt root node.
    Initial prompt labeling tasks will determine whether the tree proceeds towards `growing`
    (see `prompt_lottery_waiting`) or enters the `aborted_low_grade` state."""

    GROWING = "growing"
    """Assistant & prompter human demonstrations are collected. Concurrently, labeling tasks
    are handed out to check whether the quality of the replies surpasses the minimum acceptable
    quality.
    When the required number of messages passing the initial labeling-quality check has been
    collected, the tree will enter `ranking`. If too many poor-quality labeling responses
    are received, the tree can also enter the `aborted_low_grade` state."""

    RANKING = "ranking"
    """The tree has been successfully populated with the desired number of messages. Ranking
    tasks are now handed out for all nodes with more than one child."""

    READY_FOR_SCORING = "ready_for_scoring"
    """The required ranking responses have been collected and the scoring algorithm can now
    compute the aggregated ranking scores that will appear in the dataset."""

    READY_FOR_EXPORT = "ready_for_export"
    """The scoring algorithm has computed ranking scores for all children. The message tree can
    be exported as part of an Open-Assistant message tree dataset."""

    SCORING_FAILED = "scoring_failed"
    """An exception occurred in the scoring algorithm."""

    ABORTED_LOW_GRADE = "aborted_low_grade"
    """The system received too many bad reviews and stopped handing out tasks for this message tree."""

    HALTED_BY_MODERATOR = "halted_by_moderator"
    """A moderator decided to manually halt the message tree construction process."""

    BACKLOG_RANKING = "backlog_ranking"
    """Imported tree ready to be activated and ranked by users (currently inactive)."""

    PROMPT_LOTTERY_WAITING = "prompt_lottery_waiting"
    """Initial prompt has passed the spam check and is waiting to be drawn to grow.
    This is an intermediate (non-terminal) state on the way to `growing`."""


# Every state that does not signal a failure or a manual stop, i.e. all `State` members
# except `SCORING_FAILED` and `HALTED_BY_MODERATOR`, listed in enum declaration order.
# `PROMPT_LOTTERY_WAITING` is a regular step between initial prompt review and `growing`,
# so it belongs here.
VALID_STATES = (
    State.INITIAL_PROMPT_REVIEW,
    State.GROWING,
    State.RANKING,
    State.READY_FOR_SCORING,
    State.READY_FOR_EXPORT,
    State.ABORTED_LOW_GRADE,
    State.BACKLOG_RANKING,
    State.PROMPT_LOTTERY_WAITING,
)

# States in which a tree is not expected to advance further on its own (finished, failed,
# aborted, halted, or parked inactive in the backlog). `PROMPT_LOTTERY_WAITING` is deliberately
# absent: trees waiting in the lottery are still drawn to enter `growing`.
TERMINAL_STATES = (
    State.READY_FOR_EXPORT,
    State.ABORTED_LOW_GRADE,
    State.SCORING_FAILED,
    State.HALTED_BY_MODERATOR,
    State.BACKLOG_RANKING,
)
class MessageTreeState(SQLModel, table=True):
    """Database row tracking the state-machine progress of a single message tree.

    Stored in the ``message_tree_state`` table. The ``state`` column presumably holds one
    of the ``State`` string values defined in this module (it is typed as plain ``str``).
    """

    __tablename__ = "message_tree_state"
    # Non-unique composite index over the columns (state, lang).
    # NOTE(review): the index name lists `lang` before `state`, but the indexed column order is
    # ("state", "lang"); renaming or reordering it would require a schema migration.
    __table_args__ = (Index("ix_message_tree_state__lang__state", "state", "lang", unique=False),)

    # Primary key, also a foreign key to `message.id` (presumably the tree's root /
    # initial-prompt message — confirm against callers).
    message_tree_id: UUID = Field(
        sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), primary_key=True)
    )
    # Presumably the target number of messages the tree should reach while growing — TODO confirm.
    goal_tree_size: int = Field(nullable=False)
    # Presumably structural limits applied while growing the tree (maximum depth and maximum
    # number of replies per node) — TODO confirm.
    max_depth: int = Field(nullable=False)
    max_children_count: int = Field(nullable=False)
    # Current state of the tree as a string (max_length=128).
    state: str = Field(nullable=False, max_length=128)
    # Indexed boolean flag; presumably marks whether the tree is currently being worked on
    # (cf. `backlog_ranking`, described as "currently inactive").
    active: bool = Field(nullable=False, index=True)
    # Optional origin label, up to 1024 characters; the column allows NULL.
    # NOTE(review): the annotation is `str` although the column is nullable; `Optional[str]`
    # would match, but changing it alters SQLModel/pydantic field handling, so verify first.
    origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
    # Timezone-aware timestamp, NULL allowed; presumably set when the tree is drawn out of
    # `prompt_lottery_waiting`.
    won_prompt_lottery_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
    # Language of the tree (presumably a language code), up to 32 characters, NOT NULL;
    # part of the (state, lang) index above.
    lang: str = Field(sa_column=sa.Column(sa.String(32), nullable=False))