forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path text_labels.py
97 lines (85 loc) · 3.12 KB
/
text_labels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.config import settings
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.exceptions import OasstError
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TextLabel
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
# Router collecting the text-label endpoints declared below in this module.
router: APIRouter = APIRouter()
@router.post("/", status_code=HTTP_204_NO_CONTENT)
def label_text(
    *,
    api_key: APIKey = Depends(deps.get_api_key),
    text_labels: protocol_schema.TextLabels,
) -> None:
    """
    Persist the labels submitted for a piece of text.

    The API key is resolved to an API client and the labels are written through
    ``PromptRepository.store_text_labels``, all inside a function wrapped by
    ``managed_tx_function(CommitMode.COMMIT)``, which supplies the ``session``.

    Raises:
        OasstError: re-raised unchanged.
        HTTPException: status 400 for any other exception raised while storing.
    """

    @managed_tx_function(CommitMode.COMMIT)
    def _persist_labels(session: deps.Session):
        # Authentication and the repository both use the session handed in by
        # the decorator (the function is called below with no arguments).
        client = deps.api_auth(api_key, session)
        repository = PromptRepository(session, client, client_user=text_labels.user)
        repository.store_text_labels(text_labels)

    try:
        logger.info(f"Labeling text {text_labels=}.")
        _persist_labels()
    except OasstError:
        # Let OasstError propagate as-is instead of converting it to a 400.
        raise
    except Exception:
        logger.exception("Failed to store label.")
        raise HTTPException(status_code=HTTP_400_BAD_REQUEST)
@router.get("/valid_labels")
def get_valid_lables(
    *,
    message_id: Optional[UUID] = None,
    db: Session = Depends(deps.get_db),
    api_client: ApiClient = Depends(deps.get_api_client),
) -> ValidLabelsResponse:
    """
    Return the labels that may be applied to a message.

    Without ``message_id``, every ``TextLabel`` except ``TextLabel.fails_task``
    is returned. With ``message_id``, the message is fetched and the label set
    is read from ``settings.tree_manager`` based on the message:
    no parent -> ``labels_initial_prompt``; role ``"assistant"`` ->
    ``labels_assistant_reply``; otherwise ``labels_prompter_reply``.
    """
    # NOTE: the function name is misspelled ("lables"); it is kept unchanged so
    # existing references to this name keep working.
    if message_id is None:
        labels = [label for label in TextLabel if label != TextLabel.fails_task]
    else:
        repository = PromptRepository(db, api_client=api_client)
        msg = repository.fetch_message(message_id=message_id)
        tree_cfg = settings.tree_manager
        if msg.parent_id is None:
            labels = tree_cfg.labels_initial_prompt
        elif msg.role == "assistant":
            labels = tree_cfg.labels_assistant_reply
        else:
            labels = tree_cfg.labels_prompter_reply

    descriptions = [
        LabelDescription(
            name=label.value,
            widget=label.widget.value,
            display_text=label.display_text,
            help_text=label.help_text,
        )
        for label in labels
    ]
    return ValidLabelsResponse(valid_labels=descriptions)
@router.get("/report_labels")
def get_report_lables() -> ValidLabelsResponse:
    """
    Return the fixed set of labels served at ``/report_labels``.

    Each label is described by its value, widget type, display text and
    help text, in the order listed below.
    """
    # NOTE: the function name is misspelled ("lables"); it is kept unchanged so
    # existing references to this name keep working.
    offered = (
        TextLabel.spam,
        TextLabel.not_appropriate,
        TextLabel.pii,
        TextLabel.hate_speech,
        TextLabel.sexual_content,
        TextLabel.moral_judgement,
        TextLabel.political_content,
        TextLabel.toxicity,
        TextLabel.violence,
        TextLabel.quality,
    )
    descriptions = [
        LabelDescription(
            name=entry.value,
            widget=entry.widget.value,
            display_text=entry.display_text,
            help_text=entry.help_text,
        )
        for entry in offered
    ]
    return ValidLabelsResponse(valid_labels=descriptions)