Skip to content

Commit 7d1ae7c

Browse files
authored
add include_user params in /messages/cursor` endpoint (LAION-AI#2021)
Needed to display user information in message table (admin interface).
1 parent de2a50f commit 7d1ae7c

File tree

5 files changed

+35
-5
lines changed

5 files changed

+35
-5
lines changed

backend/oasst_backend/api/v1/messages.py

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def get_messages_cursor(
7070
max_count: Optional[int] = Query(10, gt=0, le=1000),
7171
desc: Optional[bool] = False,
7272
lang: Optional[str] = None,
73+
include_user: Optional[bool] = None,
7374
frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
7475
api_client: ApiClient = Depends(deps.get_api_client),
7576
db: Session = Depends(deps.get_db),
@@ -115,6 +116,7 @@ def split_cursor(x: str | None) -> tuple[datetime, UUID]:
115116
desc=query_desc,
116117
limit=qry_max_count,
117118
lang=lang,
119+
include_user=include_user,
118120
)
119121

120122
num_rows = len(items)

backend/oasst_backend/api/v1/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def prepare_message(m: Message) -> protocol.Message:
2626
model_name=m.model_name,
2727
message_tree_id=m.message_tree_id,
2828
rank=m.rank,
29+
user=m.user.to_protocol_frontend_user() if m.user else None,
2930
)
3031

3132

backend/oasst_backend/models/message.py

+6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sqlalchemy as sa
77
import sqlalchemy.dialects.postgresql as pg
88
from oasst_backend.models.db_payload import MessagePayload
9+
from oasst_backend.models.user import User
910
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
1011
from pydantic import PrivateAttr
1112
from sqlalchemy import false
@@ -65,6 +66,7 @@ def __new__(cls, *args: Any, **kwargs: Any):
6566
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
6667
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
6768
_user_is_author: Optional[bool] = PrivateAttr(default=None)
69+
_user: Optional[bool] = PrivateAttr(default=None)
6870

6971
def ensure_is_message(self) -> None:
7072
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
@@ -88,3 +90,7 @@ def user_emojis(self) -> str:
8890
@property
8991
def user_is_author(self) -> str:
9092
return self._user_is_author
93+
94+
@property
95+
def user(self) -> User:
96+
return self._user

backend/oasst_backend/prompt_repository.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -888,14 +888,27 @@ def fetch_message_with_max_children(self, message: Message | UUID) -> tuple[Mess
888888
max_message = max(tree, key=lambda m: m.children_count)
889889
return max_message, [m for m in tree if m.parent_id == max_message.id]
890890

891-
def _add_user_emojis_all(self, qry: Query) -> list[Message]:
891+
def _add_user_emojis_all(self, qry: Query, include_user: bool = False) -> list[Message]:
892892
if self.user_id is None:
893-
return qry.all()
893+
if not include_user:
894+
return qry.all()
895+
896+
messages: list[Message] = []
897+
898+
for element in qry:
899+
message = element["Message"]
900+
user = element["User"]
901+
message._user = user
902+
messages.append(message)
903+
return messages
894904

895905
order_by_clauses = qry._order_by_clauses
896906
sq = qry.subquery("m")
907+
select_entities = [Message, func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis")]
908+
if include_user:
909+
select_entities.append(User)
897910
qry = (
898-
self.db.query(Message, func.string_agg(MessageEmoji.emoji, literal_column("','")).label("user_emojis"))
911+
self.db.query(*select_entities)
899912
.select_entity_from(sq)
900913
.outerjoin(
901914
MessageEmoji,
@@ -915,7 +928,10 @@ def _add_user_emojis_all(self, qry: Query) -> list[Message]:
915928
if user_emojis:
916929
m._user_emojis = user_emojis.split(",")
917930
m._user_is_author = self.user_id and self.user_id == m.user_id
931+
if include_user:
932+
m._user = x["User"]
918933
messages.append(m)
934+
919935
return messages
920936

921937
def query_messages_ordered_by_created_date(
@@ -934,6 +950,7 @@ def query_messages_ordered_by_created_date(
934950
desc: bool = False,
935951
limit: Optional[int] = 100,
936952
lang: Optional[str] = None,
953+
include_user: Optional[bool] = None,
937954
) -> list[Message]:
938955
if not self.api_client.trusted:
939956
if not api_client_id:
@@ -945,12 +962,15 @@ def query_messages_ordered_by_created_date(
945962
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN)
946963

947964
qry = self.db.query(Message)
965+
if include_user:
966+
qry = self.db.query(Message, User)
948967
if user_id:
949968
qry = qry.filter(Message.user_id == user_id)
969+
if username or auth_method or include_user:
970+
qry = qry.join(User)
950971
if username or auth_method:
951972
if not (username and auth_method):
952973
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
953-
qry = qry.join(User)
954974
qry = qry.filter(User.username == username, User.auth_method == auth_method)
955975
if api_client_id:
956976
qry = qry.filter(Message.api_client_id == api_client_id)
@@ -1004,7 +1024,7 @@ def query_messages_ordered_by_created_date(
10041024
if limit is not None:
10051025
qry = qry.limit(limit)
10061026

1007-
return self._add_user_emojis_all(qry)
1027+
return self._add_user_emojis_all(qry, include_user=include_user)
10081028

10091029
def update_children_counts(self, message_tree_id: UUID):
10101030
sql_update_children_count = """

oasst-shared/oasst_shared/schemas/protocol.py

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class Message(ConversationMessage):
108108
message_tree_id: Optional[UUID]
109109
ranking_count: Optional[int]
110110
rank: Optional[int]
111+
user: Optional[FrontEndUser]
111112

112113

113114
class MessagePage(PageResult):

0 commit comments

Comments
 (0)