Skip to content

Commit a37bf6b

Browse files
committed
added text labels to the API
1 parent db10c52 commit a37bf6b

File tree

8 files changed

+181
-2
lines changed

8 files changed

+181
-2
lines changed

backend/alembic/script.py.mako

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Create Date: ${create_date}
77
"""
88
from alembic import op
99
import sqlalchemy as sa
10+
import sqlmodel
1011
${imports if imports else ""}
1112

1213
# revision identifiers, used by Alembic.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# -*- coding: utf-8 -*-
2+
"""empty message
3+
4+
Revision ID: 067c4002f2d9
5+
Revises: 0daec5f8135f
6+
Create Date: 2022-12-25 17:05:21.208843
7+
8+
"""
9+
import sqlalchemy as sa
10+
import sqlmodel
11+
from alembic import op
12+
from sqlalchemy.dialects import postgresql
13+
14+
# revision identifiers, used by Alembic.
15+
# This migration's own id; matches "Revision ID" in the module docstring.
revision = "067c4002f2d9"
# Parent migration this one is applied on top of ("Revises" in the module docstring).
down_revision = "0daec5f8135f"
# No branch label and no cross-branch dependency are declared for this revision.
branch_labels = None
depends_on = None
19+
20+
21+
def upgrade() -> None:
    """Create the ``text_labels`` table.

    The table gets a UUID primary key and a creation timestamp that are both
    filled in by server-side defaults, required ``api_client_id`` and ``text``
    columns, and optional ``post_id`` and ``labels`` columns.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "text_labels",
        # Primary key; the database generates it with gen_random_uuid().
        sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
        # Row creation time; defaults to CURRENT_TIMESTAMP on the server.
        sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        # Optional reference to a row in "post" (see the foreign key below).
        sa.Column("post_id", postgresql.UUID(as_uuid=True), nullable=True),
        # Label data stored as JSONB; NULL is allowed by this schema.
        sa.Column("labels", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        # Required reference to a row in "api_client" (see the foreign key below).
        sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        # The labelled text itself, declared with a length of 65536.
        sa.Column("text", sqlmodel.sql.sqltypes.AutoString(length=65536), nullable=False),
        sa.ForeignKeyConstraint(
            ["api_client_id"],
            ["api_client.id"],
        ),
        sa.ForeignKeyConstraint(
            ["post_id"],
            ["post.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # ### end Alembic commands ###
42+
43+
44+
def downgrade() -> None:
    """Revert this migration by dropping the ``text_labels`` table.

    Any rows stored in the table are lost when it is dropped.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("text_labels")
    # ### end Alembic commands ###

backend/oasst_backend/api/v1/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
from fastapi import APIRouter
3-
from oasst_backend.api.v1 import tasks
3+
from oasst_backend.api.v1 import tasks, text_labels
44

55
# Top-level v1 router; each feature module's router is mounted under its own
# URL prefix and tagged with the same name.
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# -*- coding: utf-8 -*-
2+
import pydantic
3+
from fastapi import APIRouter, Depends, HTTPException
4+
from fastapi.security.api_key import APIKey
5+
from loguru import logger
6+
from oasst_backend.api import deps
7+
from oasst_backend.prompt_repository import PromptRepository
8+
from oasst_shared.schemas import protocol as protocol_schema
9+
from sqlmodel import Session
10+
from starlette.status import HTTP_400_BAD_REQUEST
11+
12+
router = APIRouter()
13+
14+
15+
# Request body accepted by the label_text endpoint.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class LabelTextRequest(pydantic.BaseModel):
    # The text and its label values; handed to PromptRepository.store_text_labels.
    text_labels: protocol_schema.TextLabels
    # User information forwarded to PromptRepository as its `user` argument.
    user: protocol_schema.User
18+
19+
20+
@router.post("/")  # work with Union once more types are added
def label_text(
    *,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
    request: LabelTextRequest,
) -> None:
    """
    Label a piece of text.
    """
    # Authentication runs outside the try-block, so whatever deps.api_auth
    # raises propagates unchanged instead of being turned into the 400 below.
    api_client = deps.api_auth(api_key, db)

    try:
        logger.info(f"Labeling text {request=}.")
        PromptRepository(db, api_client, user=request.user).store_text_labels(request.text_labels)
    except Exception:
        # Every failure while storing is logged with its traceback and then
        # reported to the caller as a bare 400 response.
        logger.exception("Failed to store label.")
        raise HTTPException(status_code=HTTP_400_BAD_REQUEST)

backend/oasst_backend/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .person_stats import PersonStats
55
from .post import Post
66
from .post_reaction import PostReaction
7+
from .text_labels import TextLabels
78
from .work_package import WorkPackage
89

910
__all__ = [
@@ -13,4 +14,5 @@
1314
"Post",
1415
"PostReaction",
1516
"WorkPackage",
17+
"TextLabels",
1618
]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# -*- coding: utf-8 -*-
2+
from datetime import datetime
3+
from typing import Optional
4+
from uuid import UUID, uuid4
5+
6+
import sqlalchemy as sa
7+
import sqlalchemy.dialects.postgresql as pg
8+
from sqlmodel import Field, SQLModel
9+
10+
11+
# Database model for the "text_labels" table: one row per labelled piece of text.
# (Documented with comments rather than a docstring so any schema generated
# from this model keeps its current description.)
class TextLabels(SQLModel, table=True):
    __tablename__ = "text_labels"

    # Primary key. uuid4 is the client-side default; gen_random_uuid() is the
    # server-side default declared on the column.
    id: Optional[UUID] = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
        ),
    )
    # Creation time, filled in by the database (CURRENT_TIMESTAMP).
    created_date: Optional[datetime] = Field(
        sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
    )
    # API client that submitted the labels (required).
    api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
    # The labelled text, limited to 2**16 characters.
    text: str = Field(nullable=False, max_length=2**16)
    # Optional link to the post the text belongs to.
    post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
    # Label name -> score, stored as JSONB.
    # - default_factory gives every instance its own dict instead of declaring
    #   a shared mutable literal as the default.
    # - The former `nullable=False` keyword on Field() was removed: SQLModel
    #   takes an explicit sa_column as-is, so the keyword never reached the
    #   column (the text_labels migration creates `labels` as nullable), and
    #   newer SQLModel releases reject Field options combined with sa_column.
    # NOTE(review): if NOT NULL is wanted, set nullable=False on sa.Column and
    # add a migration that alters the existing column.
    labels: dict[str, float] = Field(default_factory=dict, sa_column=sa.Column(pg.JSONB))

backend/oasst_backend/prompt_repository.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import oasst_backend.models.db_payload as db_payload
77
from loguru import logger
8-
from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
8+
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
99
from oasst_backend.models.payload_column_type import PayloadContainer
1010
from oasst_shared.schemas import protocol as protocol_schema
1111
from sqlmodel import Session
@@ -314,3 +314,17 @@ def insert_reaction(self, post_id: UUID, payload: db_payload.ReactionPayload) ->
314314
self.db.commit()
315315
self.db.refresh(reaction)
316316
return reaction
317+
318+
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
    """Persist a set of text labels and return the stored row.

    Args:
        text_labels: The text, its label scores and, optionally, the frontend
            id of the post the text belongs to.

    Returns:
        The newly inserted ``TextLabels`` row, refreshed from the database.

    Raises:
        Whatever ``fetch_post_by_frontend_post_id`` raises when a post id is
        given but no matching post exists (it is called with
        ``fail_if_missing=True``).
    """
    model = TextLabels(
        api_client_id=self.api_client.id,
        text=text_labels.text,
        labels=text_labels.labels,
    )
    if text_labels.has_post_id:
        # text_labels.post_id is the *frontend* post id (a string), whereas
        # TextLabels.post_id is a UUID foreign key to post.id. Store the id of
        # the post that was looked up, not the frontend identifier.
        # NOTE(review): assumes fetch_post_by_frontend_post_id returns the Post
        # row when fail_if_missing=True - confirm against its definition.
        post = self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
        model.post_id = post.id
    self.db.add(model)
    self.db.commit()
    # Reload so server-generated values (id, created_date) are populated.
    self.db.refresh(model)
    return model

oasst-shared/oasst_shared/schemas/protocol.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,51 @@ class PostRanking(Interaction):
204204
PostRating,
205205
PostRanking,
206206
]
207+
208+
209+
# The closed set of label names accepted as keys of TextLabels.labels.
# Members also subclass str, and every member's value is identical to its name.
class TextLabel(str, enum.Enum):
    """A label for a piece of text."""

    spam = "spam"
    violence = "violence"
    sexual_content = "sexual_content"
    toxicity = "toxicity"
    political_content = "political_content"
    humor = "humor"
    sarcasm = "sarcasm"
    hate_speech = "hate_speech"
    profanity = "profanity"
    ad_hominem = "ad_hominem"
    insult = "insult"
    threat = "threat"
    aggressive = "aggressive"
    misleading = "misleading"
    helpful = "helpful"
    formal = "formal"
    cringe = "cringe"
    creative = "creative"
    beautiful = "beautiful"
    informative = "informative"
    based = "based"
    slang = "slang"
235+
236+
class TextLabels(BaseModel):
    """A set of labels for a piece of text."""

    text: str
    labels: dict[TextLabel, float]
    post_id: str | None = None

    @property
    def has_post_id(self) -> bool:
        """True when ``post_id`` is set to a non-empty value."""
        return bool(self.post_id)

    # every label score must lie within the closed interval [0, 1]
    @pydantic.validator("labels")
    def check_label_values(cls, v):
        # Look for the first out-of-range entry. The `not 0 <= x <= 1` form is
        # kept deliberately: it also rejects NaN.
        offender = next(((key, value) for key, value in v.items() if not 0 <= value <= 1), None)
        if offender is not None:
            key, value = offender
            raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.")
        return v

0 commit comments

Comments
 (0)