Skip to content

Commit a5bc9bf

Browse files
authored
Add dupe checks to store_text_reply() & store_text_labels() in PromptRepository (LAION-AI#1018)
* add dupe checks to store_text_reply() & store_text_labels * remove test export file * add user_id to protocol.ConversationMessage * add show_on_leaderboard ot protocol.FrontEndUser
1 parent b8a62e5 commit a5bc9bf

File tree

8 files changed

+70
-22
lines changed

8 files changed

+70
-22
lines changed

backend/oasst_backend/api/v1/frontend_messages.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_children_by_frontend_id(
5858
"""
5959
pr = PromptRepository(db, api_client)
6060
message = pr.fetch_message_by_frontend_message_id(message_id)
61-
messages = pr.fetch_message_children(message.id)
61+
messages = pr.fetch_message_children(message.id, review_result=None)
6262
return utils.prepare_message_list(messages)
6363

6464

backend/oasst_backend/api/v1/messages.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def get_children(
201201
Get all messages belonging to the same message tree.
202202
"""
203203
pr = PromptRepository(db, api_client, frontend_user=frontend_user)
204-
messages = pr.fetch_message_children(message_id)
204+
messages = pr.fetch_message_children(message_id, review_result=None)
205205
return utils.prepare_message_list(messages)
206206

207207

backend/oasst_backend/api/v1/utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@ def prepare_message(m: Message) -> protocol.Message:
1010
id=m.id,
1111
frontend_message_id=m.frontend_message_id,
1212
parent_id=m.parent_id,
13+
user_id=m.user_id,
1314
text=m.text,
1415
lang=m.lang,
1516
is_assistant=(m.role == "assistant"),
1617
created_date=m.created_date,
1718
emojis=m.emojis or {},
1819
user_emojis=m.user_emojis or [],
20+
review_result=m.review_result,
21+
review_count=m.review_count,
1922
)
2023

2124

@@ -26,6 +29,7 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
2629
def prepare_conversation_message(message: Message) -> protocol.ConversationMessage:
2730
return protocol.ConversationMessage(
2831
id=message.id,
32+
user_id=message.user_id,
2933
frontend_message_id=message.frontend_message_id,
3034
text=message.text,
3135
lang=message.lang,

backend/oasst_backend/models/user.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@ def to_protocol_frontend_user(self):
4242
deleted=self.deleted,
4343
notes=self.notes,
4444
created_date=self.created_date,
45+
show_on_leaderboard=self.show_on_leaderboard,
4546
)

backend/oasst_backend/prompt_repository.py

+52-16
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from uuid import UUID, uuid4
88

99
import oasst_backend.models.db_payload as db_payload
10-
import sqlalchemy as sa
1110
from loguru import logger
1211
from oasst_backend.api.deps import FrontendUserId
1312
from oasst_backend.config import settings
@@ -62,13 +61,11 @@ def __init__(
6261

6362
if user_id:
6463
self.user = self.user_repository.get_user(id=user_id)
65-
self.user_id = self.user.id
6664
elif auth_method and username:
6765
self.user = self.user_repository.query_frontend_user(auth_method=auth_method, username=username)
68-
self.user_id = self.user.id
6966
else:
7067
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
71-
self.user_id = self.user.id if self.user else None
68+
self.user_id = self.user.id if self.user else None
7269
logger.debug(f"PromptRepository(api_client_id={self.api_client.id}, {self.user_id=})")
7370
self.task_repository = task_repository or TaskRepository(
7471
db, api_client, client_user, user_repository=self.user_repository
@@ -215,6 +212,14 @@ def store_text_reply(
215212
OasstErrorCode.TREE_NOT_IN_GROWING_STATE,
216213
)
217214

215+
if check_duplicate and not settings.DEBUG_ALLOW_DUPLICATE_TASKS:
216+
siblings = self.fetch_message_children(task.parent_message_id, review_result=None, deleted=False)
217+
if any(m.user_id == self.user_id for m in siblings):
218+
raise OasstError(
219+
"User cannot reply twice to the same message.",
220+
OasstErrorCode.TASK_MESSAGE_DUPLICATE_REPLY,
221+
)
222+
218223
parent_message.message_tree_id
219224
parent_message.children_count += 1
220225
self.db.add(parent_message)
@@ -419,6 +424,7 @@ def insert_reaction(
419424

420425
@managed_tx_method(CommitMode.FLUSH)
421426
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[TextLabels, Task, Message]:
427+
self.ensure_user_is_enabled()
422428

423429
valid_labels: Optional[list[str]] = None
424430
mandatory_labels: Optional[list[str]] = None
@@ -484,6 +490,8 @@ def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[Te
484490
message: Message = None
485491
if message_id:
486492
if not task:
493+
# free labeling case
494+
487495
if text_labels.is_report is True:
488496
message = self.handle_message_emoji(
489497
message_id, protocol_schema.EmojiOp.add, protocol_schema.EmojiCode.red_flag
@@ -496,7 +504,21 @@ def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[Te
496504
model = existing_text_label
497505

498506
else:
499-
message = self.fetch_message(message_id)
507+
# task based labeling case
508+
509+
message = self.fetch_message(message_id, fail_if_missing=True)
510+
if not settings.DEBUG_ALLOW_SELF_LABELING and message.user_id == self.user_id:
511+
raise OasstError(
512+
"Labeling own message is not allowed.", OasstErrorCode.TEXT_LABELS_NO_SELF_LABELING
513+
)
514+
515+
existing_labels = self.fetch_message_text_labels(message_id, self.user_id)
516+
if not settings.DEBUG_ALLOW_DUPLICATE_TASKS and any(l.task_id for l in existing_labels):
517+
raise OasstError(
518+
"Message was already labeled by same user before.",
519+
OasstErrorCode.TEXT_LABELS_DUPLICATE_TASK_REPLY,
520+
)
521+
500522
message.review_count += 1
501523
self.db.add(message)
502524

@@ -666,6 +688,12 @@ def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optiona
666688
text_label = query.one_or_none()
667689
return text_label
668690

691+
def fetch_message_text_labels(self, message_id: UUID, user_id: Optional[UUID] = None) -> list[TextLabels]:
692+
query = self.db.query(TextLabels).filter(TextLabels.message_id == message_id)
693+
if user_id is not None:
694+
query = query.filter(TextLabels.user_id == user_id)
695+
return query.all()
696+
669697
@staticmethod
670698
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
671699
"""
@@ -712,7 +740,10 @@ def fetch_tree_from_message(self, message: Message | UUID) -> list[Message]:
712740
return self.fetch_message_tree(message.message_tree_id)
713741

714742
def fetch_message_children(
715-
self, message: Message | UUID, reviewed: bool = True, exclude_deleted: bool = True
743+
self,
744+
message: Message | UUID,
745+
review_result: Optional[bool] = True,
746+
deleted: Optional[bool] = False,
716747
) -> list[Message]:
717748
"""
718749
Get all direct children of this message
@@ -721,26 +752,31 @@ def fetch_message_children(
721752
message = message.id
722753

723754
qry = self.db.query(Message).filter(Message.parent_id == message)
724-
if reviewed:
725-
qry = qry.filter(Message.review_result)
726-
if exclude_deleted:
727-
qry = qry.filter(Message.deleted == sa.false())
755+
if review_result is not None:
756+
qry = qry.filter(Message.review_result == review_result)
757+
if deleted is not None:
758+
qry = qry.filter(Message.deleted == deleted)
728759
children = self._add_user_emojis_all(qry)
729760
return children
730761

731762
def fetch_message_siblings(
732-
self, message: Message | UUID, reviewed: Optional[bool] = True, deleted: Optional[bool] = False
763+
self,
764+
message: Message | UUID,
765+
review_result: Optional[bool] = True,
766+
deleted: Optional[bool] = False,
733767
) -> list[Message]:
734768
"""
735769
Get siblings of a message (other messages with the same parent_id)
736770
"""
771+
qry = self.db.query(Message)
737772
if isinstance(message, Message):
738-
message = message.id
773+
qry = qry.filter(Message.parent_id == message.parent_id)
774+
else:
775+
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
776+
qry = qry.filter(Message.parent_id == parent_qry.c.parent_id)
739777

740-
parent_qry = self.db.query(Message.parent_id).filter(Message.id == message).subquery()
741-
qry = self.db.query(Message).filter(Message.parent_id == parent_qry.c.parent_id)
742-
if reviewed is not None:
743-
qry = qry.filter(Message.review_result == reviewed)
778+
if review_result is not None:
779+
qry = qry.filter(Message.review_result == review_result)
744780
if deleted is not None:
745781
qry = qry.filter(Message.deleted == deleted)
746782
siblings = self._add_user_emojis_all(qry)

backend/oasst_backend/tree_manager.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def next_task(
319319
ranking_parent = messages[-1]
320320
assert not ranking_parent.deleted and ranking_parent.review_result
321321
conversation = prepare_conversation(messages)
322-
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
322+
replies = self.pr.fetch_message_children(ranking_parent_id, review_result=True, deleted=False)
323323

324324
assert len(replies) > 1
325325
random.shuffle(replies) # hand out replies in random order
@@ -756,7 +756,7 @@ def update_message_ranks(
756756
logger.debug(f"CONSENSUS: {consensus}\n\n")
757757

758758
# fetch all siblings and clear ranks
759-
siblings = self.pr.fetch_message_siblings(consensus[0], reviewed=None, deleted=None)
759+
siblings = self.pr.fetch_message_siblings(consensus[0], review_result=None, deleted=None)
760760
for m in siblings:
761761
m.rank = None
762762
self.db.add(m)

oasst-shared/oasst_shared/exceptions/oasst_api_error.py

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class OasstErrorCode(IntEnum):
4040
TASK_MESSAGE_TOO_LONG = 1008
4141
TASK_MESSAGE_DUPLICATED = 1009
4242
TASK_MESSAGE_TEXT_EMPTY = 1010
43+
TASK_MESSAGE_DUPLICATE_REPLY = 1011
4344

4445
# 2000-3000: prompt_repository
4546
INVALID_FRONTEND_MESSAGE_ID = 2000
@@ -59,6 +60,8 @@ class OasstErrorCode(IntEnum):
5960
TEXT_LABELS_WRONG_MESSAGE_ID = 2050
6061
TEXT_LABELS_INVALID_LABEL = 2051
6162
TEXT_LABELS_MANDATORY_LABEL_MISSING = 2052
63+
TEXT_LABELS_NO_SELF_LABELING = 2053
64+
TEXT_LABELS_DUPLICATE_TASK_REPLY = 2053
6265

6366
TASK_NOT_FOUND = 2100
6467
TASK_EXPIRED = 2101

oasst-shared/oasst_shared/schemas/protocol.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class FrontEndUser(User):
3535
deleted: bool
3636
notes: str
3737
created_date: Optional[datetime] = None
38+
show_on_leaderboard: bool
3839

3940

4041
class PageResult(BaseModel):
@@ -53,6 +54,7 @@ class ConversationMessage(BaseModel):
5354
"""Represents a message in a conversation between the user and the assistant."""
5455

5556
id: Optional[UUID] = None
57+
user_id: Optional[UUID]
5658
frontend_message_id: Optional[str] = None
5759
text: str
5860
lang: Optional[str] # BCP 47
@@ -80,8 +82,10 @@ def is_prompter_turn(self) -> bool:
8082

8183

8284
class Message(ConversationMessage):
83-
parent_id: Optional[UUID] = None
84-
created_date: Optional[datetime] = None
85+
parent_id: Optional[UUID]
86+
created_date: Optional[datetime]
87+
review_result: Optional[bool]
88+
review_count: Optional[int]
8589

8690

8791
class MessagePage(PageResult):

0 commit comments

Comments
 (0)