@@ -888,14 +888,27 @@ def fetch_message_with_max_children(self, message: Message | UUID) -> tuple[Mess
888
888
max_message = max (tree , key = lambda m : m .children_count )
889
889
return max_message , [m for m in tree if m .parent_id == max_message .id ]
890
890
891
- def _add_user_emojis_all (self , qry : Query ) -> list [Message ]:
891
+ def _add_user_emojis_all (self , qry : Query , include_user : bool = False ) -> list [Message ]:
892
892
if self .user_id is None :
893
- return qry .all ()
893
+ if not include_user :
894
+ return qry .all ()
895
+
896
+ messages : list [Message ] = []
897
+
898
+ for element in qry :
899
+ message = element ["Message" ]
900
+ user = element ["User" ]
901
+ message ._user = user
902
+ messages .append (message )
903
+ return messages
894
904
895
905
order_by_clauses = qry ._order_by_clauses
896
906
sq = qry .subquery ("m" )
907
+ select_entities = [Message , func .string_agg (MessageEmoji .emoji , literal_column ("','" )).label ("user_emojis" )]
908
+ if include_user :
909
+ select_entities .append (User )
897
910
qry = (
898
- self .db .query (Message , func . string_agg ( MessageEmoji . emoji , literal_column ( "','" )). label ( "user_emojis" ) )
911
+ self .db .query (* select_entities )
899
912
.select_entity_from (sq )
900
913
.outerjoin (
901
914
MessageEmoji ,
@@ -915,7 +928,10 @@ def _add_user_emojis_all(self, qry: Query) -> list[Message]:
915
928
if user_emojis :
916
929
m ._user_emojis = user_emojis .split ("," )
917
930
m ._user_is_author = self .user_id and self .user_id == m .user_id
931
+ if include_user :
932
+ m ._user = x ["User" ]
918
933
messages .append (m )
934
+
919
935
return messages
920
936
921
937
def query_messages_ordered_by_created_date (
@@ -934,6 +950,7 @@ def query_messages_ordered_by_created_date(
934
950
desc : bool = False ,
935
951
limit : Optional [int ] = 100 ,
936
952
lang : Optional [str ] = None ,
953
+ include_user : Optional [bool ] = None ,
937
954
) -> list [Message ]:
938
955
if not self .api_client .trusted :
939
956
if not api_client_id :
@@ -945,12 +962,15 @@ def query_messages_ordered_by_created_date(
945
962
raise OasstError ("Forbidden" , OasstErrorCode .API_CLIENT_NOT_AUTHORIZED , HTTPStatus .FORBIDDEN )
946
963
947
964
qry = self .db .query (Message )
965
+ if include_user :
966
+ qry = self .db .query (Message , User )
948
967
if user_id :
949
968
qry = qry .filter (Message .user_id == user_id )
969
+ if username or auth_method or include_user :
970
+ qry = qry .join (User )
950
971
if username or auth_method :
951
972
if not (username and auth_method ):
952
973
raise OasstError ("Auth method or username missing." , OasstErrorCode .AUTH_AND_USERNAME_REQUIRED )
953
- qry = qry .join (User )
954
974
qry = qry .filter (User .username == username , User .auth_method == auth_method )
955
975
if api_client_id :
956
976
qry = qry .filter (Message .api_client_id == api_client_id )
@@ -1004,7 +1024,7 @@ def query_messages_ordered_by_created_date(
1004
1024
if limit is not None :
1005
1025
qry = qry .limit (limit )
1006
1026
1007
- return self ._add_user_emojis_all (qry )
1027
+ return self ._add_user_emojis_all (qry , include_user = include_user )
1008
1028
1009
1029
def update_children_counts (self , message_tree_id : UUID ):
1010
1030
sql_update_children_count = """
0 commit comments