7
7
from uuid import UUID , uuid4
8
8
9
9
import oasst_backend .models .db_payload as db_payload
10
- import sqlalchemy as sa
11
10
from loguru import logger
12
11
from oasst_backend .api .deps import FrontendUserId
13
12
from oasst_backend .config import settings
@@ -62,13 +61,11 @@ def __init__(
62
61
63
62
if user_id :
64
63
self .user = self .user_repository .get_user (id = user_id )
65
- self .user_id = self .user .id
66
64
elif auth_method and username :
67
65
self .user = self .user_repository .query_frontend_user (auth_method = auth_method , username = username )
68
- self .user_id = self .user .id
69
66
else :
70
67
self .user = self .user_repository .lookup_client_user (client_user , create_missing = True )
71
- self .user_id = self .user .id if self .user else None
68
+ self .user_id = self .user .id if self .user else None
72
69
logger .debug (f"PromptRepository(api_client_id={ self .api_client .id } , { self .user_id = } )" )
73
70
self .task_repository = task_repository or TaskRepository (
74
71
db , api_client , client_user , user_repository = self .user_repository
@@ -215,6 +212,14 @@ def store_text_reply(
215
212
OasstErrorCode .TREE_NOT_IN_GROWING_STATE ,
216
213
)
217
214
215
+ if check_duplicate and not settings .DEBUG_ALLOW_DUPLICATE_TASKS :
216
+ siblings = self .fetch_message_children (task .parent_message_id , review_result = None , deleted = False )
217
+ if any (m .user_id == self .user_id for m in siblings ):
218
+ raise OasstError (
219
+ "User cannot reply twice to the same message." ,
220
+ OasstErrorCode .TASK_MESSAGE_DUPLICATE_REPLY ,
221
+ )
222
+
218
223
parent_message .message_tree_id
219
224
parent_message .children_count += 1
220
225
self .db .add (parent_message )
@@ -419,6 +424,7 @@ def insert_reaction(
419
424
420
425
@managed_tx_method (CommitMode .FLUSH )
421
426
def store_text_labels (self , text_labels : protocol_schema .TextLabels ) -> tuple [TextLabels , Task , Message ]:
427
+ self .ensure_user_is_enabled ()
422
428
423
429
valid_labels : Optional [list [str ]] = None
424
430
mandatory_labels : Optional [list [str ]] = None
@@ -484,6 +490,8 @@ def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[Te
484
490
message : Message = None
485
491
if message_id :
486
492
if not task :
493
+ # free labeling case
494
+
487
495
if text_labels .is_report is True :
488
496
message = self .handle_message_emoji (
489
497
message_id , protocol_schema .EmojiOp .add , protocol_schema .EmojiCode .red_flag
@@ -496,7 +504,21 @@ def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[Te
496
504
model = existing_text_label
497
505
498
506
else :
499
- message = self .fetch_message (message_id )
507
+ # task based labeling case
508
+
509
+ message = self .fetch_message (message_id , fail_if_missing = True )
510
+ if not settings .DEBUG_ALLOW_SELF_LABELING and message .user_id == self .user_id :
511
+ raise OasstError (
512
+ "Labeling own message is not allowed." , OasstErrorCode .TEXT_LABELS_NO_SELF_LABELING
513
+ )
514
+
515
+ existing_labels = self .fetch_message_text_labels (message_id , self .user_id )
516
+ if not settings .DEBUG_ALLOW_DUPLICATE_TASKS and any (l .task_id for l in existing_labels ):
517
+ raise OasstError (
518
+ "Message was already labeled by same user before." ,
519
+ OasstErrorCode .TEXT_LABELS_DUPLICATE_TASK_REPLY ,
520
+ )
521
+
500
522
message .review_count += 1
501
523
self .db .add (message )
502
524
@@ -666,6 +688,12 @@ def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optiona
666
688
text_label = query .one_or_none ()
667
689
return text_label
668
690
691
+ def fetch_message_text_labels (self , message_id : UUID , user_id : Optional [UUID ] = None ) -> list [TextLabels ]:
692
+ query = self .db .query (TextLabels ).filter (TextLabels .message_id == message_id )
693
+ if user_id is not None :
694
+ query = query .filter (TextLabels .user_id == user_id )
695
+ return query .all ()
696
+
669
697
@staticmethod
670
698
def trace_conversation (messages : list [Message ] | dict [UUID , Message ], last_message : Message ) -> list [Message ]:
671
699
"""
@@ -712,7 +740,10 @@ def fetch_tree_from_message(self, message: Message | UUID) -> list[Message]:
712
740
return self .fetch_message_tree (message .message_tree_id )
713
741
714
742
def fetch_message_children (
715
- self , message : Message | UUID , reviewed : bool = True , exclude_deleted : bool = True
743
+ self ,
744
+ message : Message | UUID ,
745
+ review_result : Optional [bool ] = True ,
746
+ deleted : Optional [bool ] = False ,
716
747
) -> list [Message ]:
717
748
"""
718
749
Get all direct children of this message
@@ -721,26 +752,31 @@ def fetch_message_children(
721
752
message = message .id
722
753
723
754
qry = self .db .query (Message ).filter (Message .parent_id == message )
724
- if reviewed :
725
- qry = qry .filter (Message .review_result )
726
- if exclude_deleted :
727
- qry = qry .filter (Message .deleted == sa . false () )
755
+ if review_result is not None :
756
+ qry = qry .filter (Message .review_result == review_result )
757
+ if deleted is not None :
758
+ qry = qry .filter (Message .deleted == deleted )
728
759
children = self ._add_user_emojis_all (qry )
729
760
return children
730
761
731
762
def fetch_message_siblings (
732
- self , message : Message | UUID , reviewed : Optional [bool ] = True , deleted : Optional [bool ] = False
763
+ self ,
764
+ message : Message | UUID ,
765
+ review_result : Optional [bool ] = True ,
766
+ deleted : Optional [bool ] = False ,
733
767
) -> list [Message ]:
734
768
"""
735
769
Get siblings of a message (other messages with the same parent_id)
736
770
"""
771
+ qry = self .db .query (Message )
737
772
if isinstance (message , Message ):
738
- message = message .id
773
+ qry = qry .filter (Message .parent_id == message .parent_id )
774
+ else :
775
+ parent_qry = self .db .query (Message .parent_id ).filter (Message .id == message ).subquery ()
776
+ qry = qry .filter (Message .parent_id == parent_qry .c .parent_id )
739
777
740
- parent_qry = self .db .query (Message .parent_id ).filter (Message .id == message ).subquery ()
741
- qry = self .db .query (Message ).filter (Message .parent_id == parent_qry .c .parent_id )
742
- if reviewed is not None :
743
- qry = qry .filter (Message .review_result == reviewed )
778
+ if review_result is not None :
779
+ qry = qry .filter (Message .review_result == review_result )
744
780
if deleted is not None :
745
781
qry = qry .filter (Message .deleted == deleted )
746
782
siblings = self ._add_user_emojis_all (qry )
0 commit comments