# text_labels.py
# Source: fork of LAION-AI/Open-Assistant.
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
# Router on which the endpoints defined in this module are registered.
router = APIRouter()
@router.post("/", status_code=HTTP_204_NO_CONTENT)
def label_text(
    *,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
    text_labels: protocol_schema.TextLabels,
) -> None:
    """
    Label a piece of text.

    Resolves the API client for the supplied key with ``deps.api_auth`` and then
    stores the submitted labels through a ``PromptRepository`` created for
    ``text_labels.user``. Responds with 204 (no body) on success.

    Raises:
        HTTPException: status 400 when creating the repository or storing the
            labels fails. The underlying error is logged and attached as the
            exception's ``__cause__``. Errors raised by ``deps.api_auth``
            happen outside the ``try`` and propagate unchanged.
    """
    api_client = deps.api_auth(api_key, db)

    try:
        logger.info(f"Labeling text {text_labels=}.")
        pr = PromptRepository(db, api_client, client_user=text_labels.user)
        pr.store_text_labels(text_labels)
    except Exception as err:
        logger.exception("Failed to store label.")
        # Chain explicitly so the root cause is kept as __cause__ instead of
        # appearing as an unrelated error "during handling" of another one.
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
        ) from err
@router.get("/valid_labels")
def get_valid_lables() -> ValidLabelsResponse:
    """Return one ``LabelOption`` per member of ``protocol_schema.TextLabel``."""
    options = []
    for label in protocol_schema.TextLabel:
        # Each option carries the member's value plus its display and help text.
        option = LabelOption(
            name=label.value,
            display_text=label.display_text,
            help_text=label.help_text,
        )
        options.append(option)
    return ValidLabelsResponse(valid_labels=options)