forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfrontend_users.py
172 lines (155 loc) · 5.54 KB
/
frontend_users.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import datetime
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.api.v1.messages import get_messages_cursor
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
# Router collecting the frontend-user endpoints declared in this module
# through the ``@router.<method>(...)`` decorators that follow.
router = APIRouter()
@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True)
def get_users_ordered_by_username(
    api_client_id: Optional[UUID] = None,
    gte_username: Optional[str] = None,
    gt_id: Optional[UUID] = None,
    lte_username: Optional[str] = None,
    lt_id: Optional[UUID] = None,
    search_text: Optional[str] = None,
    auth_method: Optional[str] = None,
    max_count: Optional[int] = Query(100, gt=0, le=10000),
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """List frontend users ordered by username (deprecated endpoint).

    Every filter/bound argument is forwarded unchanged to
    ``UserRepository.query_users_ordered_by_username``. The
    ``gte_username``/``gt_id`` and ``lte_username``/``lt_id`` pairs presumably
    act as lower/upper pagination bounds -- see the repository for the exact
    semantics. ``max_count`` (validated by ``Query`` to 1..10000, default 100)
    is passed as the query ``limit``. Each resulting user is converted with
    ``to_protocol_frontend_user()``.
    """
    user_repository = UserRepository(db, api_client)
    matching_users = user_repository.query_users_ordered_by_username(
        gte_username=gte_username,
        gt_id=gt_id,
        lte_username=lte_username,
        lt_id=lt_id,
        api_client_id=api_client_id,
        auth_method=auth_method,
        search_text=search_text,
        limit=max_count,
    )
    frontend_users = [user.to_protocol_frontend_user() for user in matching_users]
    return frontend_users
@router.get("/{auth_method}/{username}", response_model=protocol.FrontEndUser)
def query_frontend_user(
    auth_method: str,
    username: str,
    api_client_id: Optional[UUID] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Query frontend user.

    Resolves the user identified by ``auth_method`` and ``username`` (with
    ``api_client_id`` forwarded as a further, presumably optional, scope) via
    ``UserRepository.query_frontend_user`` and returns it as a
    ``protocol.FrontEndUser``. What happens when no user matches is decided
    by the repository and is not visible here.
    """
    found_user = UserRepository(db, api_client).query_frontend_user(auth_method, username, api_client_id)
    return found_user.to_protocol_frontend_user()
@router.post("/", response_model=protocol.FrontEndUser)
def create_frontend_user(
    *,
    create_user: protocol.CreateFrontendUserRequest,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Look up (or create) a frontend user and apply requested setting changes.

    The user is fetched with ``UserRepository.lookup_client_user(...,
    create_missing=True)``. ``update_user`` is then called only if the request
    would change something: ``enabled``, ``show_on_leaderboard`` or ``notes``
    differ from the stored value, or ToS acceptance is requested while no
    acceptance date is recorded yet. A request field that is ``None`` never
    counts as a change. How ``update_user`` itself treats ``None`` values is up
    to the repository.

    Returns:
        The (possibly updated) user as a ``protocol.FrontEndUser``.
    """
    user_repository = UserRepository(db, api_client)
    user = user_repository.lookup_client_user(create_user, create_missing=True)

    def differs(requested, current) -> bool:
        # ``None`` in the request means the caller did not ask for this field.
        if requested is None:
            return False
        return requested != current

    # Skip the write entirely when nothing would change.
    needs_update = (
        differs(create_user.enabled, user.enabled)
        or differs(create_user.show_on_leaderboard, user.show_on_leaderboard)
        or differs(create_user.notes, user.notes)
        # ToS acceptance only matters if it has not been recorded before.
        or (create_user.tos_acceptance and user.tos_acceptance_date is None)
    )
    if needs_update:
        user = user_repository.update_user(
            user.id,
            enabled=create_user.enabled,
            show_on_leaderboard=create_user.show_on_leaderboard,
            tos_acceptance=create_user.tos_acceptance,
            notes=create_user.notes,
        )
    return user.to_protocol_frontend_user()
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
def query_frontend_user_messages(
    auth_method: str,
    username: str,
    # Explicit Optional: ``X = None`` (implicit Optional) is disallowed by
    # PEP 484 and is inconsistent with the other endpoints in this module.
    api_client_id: Optional[UUID] = None,
    max_count: int = Query(10, gt=0, le=1000),
    start_date: Optional[datetime.datetime] = None,
    end_date: Optional[datetime.datetime] = None,
    only_roots: bool = False,
    desc: bool = True,
    include_deleted: bool = False,
    lang: Optional[str] = None,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Query frontend user messages.

    Returns up to ``max_count`` messages (1..1000, default 10) of the frontend
    user identified by ``auth_method``/``username``. They come from
    ``PromptRepository.query_messages_ordered_by_created_date`` and are
    formatted with ``utils.prepare_message_list``.

    Args:
        api_client_id: forwarded to the repository query unchanged.
        start_date: forwarded as ``gte_created_date``.
        end_date: forwarded as ``lte_created_date``.
        only_roots: forwarded as ``only_roots``.
        desc: forwarded as ``desc`` (ordering direction by created date).
        include_deleted: when False, passes ``deleted=False``. When True, passes
            ``deleted=None``, presumably disabling the deleted filter (confirm in
            the repository).
        lang: forwarded as ``lang``.
    """
    pr = PromptRepository(db, api_client, auth_method=auth_method, username=username)
    messages = pr.query_messages_ordered_by_created_date(
        auth_method=auth_method,
        username=username,
        api_client_id=api_client_id,
        desc=desc,
        limit=max_count,
        gte_created_date=start_date,
        lte_created_date=end_date,
        only_roots=only_roots,
        deleted=None if include_deleted else False,
        lang=lang,
    )
    return utils.prepare_message_list(messages)
@router.get("/{auth_method}/{username}/messages/cursor", response_model=protocol.MessagePage)
def query_frontend_user_messages_cursor(
    auth_method: str,
    username: str,
    before: Optional[str] = None,
    after: Optional[str] = None,
    only_roots: Optional[bool] = False,
    include_deleted: Optional[bool] = False,
    max_count: Optional[int] = Query(10, gt=0, le=1000),
    desc: Optional[bool] = False,
    lang: Optional[str] = None,
    frontend_user: deps.FrontendUserId = Depends(deps.get_frontend_user_id),
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """Cursor-paginated variant of the frontend-user message query.

    Thin wrapper: every argument is forwarded unchanged to
    ``get_messages_cursor`` (from ``oasst_backend.api.v1.messages``). Its result
    is returned as the response, declared as ``protocol.MessagePage``.
    ``before``/``after`` are presumably the opaque paging cursors. Note that
    ``desc`` defaults to False here, unlike the non-cursor endpoint.
    """
    return get_messages_cursor(
        db=db,
        api_client=api_client,
        frontend_user=frontend_user,
        auth_method=auth_method,
        username=username,
        before=before,
        after=after,
        max_count=max_count,
        desc=desc,
        only_roots=only_roots,
        include_deleted=include_deleted,
        lang=lang,
    )
@router.delete("/{auth_method}/{username}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_frontend_user_messages_deleted(
    auth_method: str,
    username: str,
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
    db: Session = Depends(deps.get_db),
):
    """Mark a frontend user's messages as deleted (responds 204 No Content).

    The caller is resolved through ``deps.get_trusted_api_client``, presumably
    rejecting untrusted clients. Messages are queried with ``limit=None`` and
    ``api_client_id`` set to the calling client's own id. Every message returned
    is then handed to ``PromptRepository.mark_messages_deleted``.
    """
    prompt_repository = PromptRepository(db, api_client)
    user_messages = prompt_repository.query_messages_ordered_by_created_date(
        limit=None,
        api_client_id=api_client.id,
        auth_method=auth_method,
        username=username,
    )
    prompt_repository.mark_messages_deleted(user_messages)