forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathuser_repository.py
88 lines (75 loc) · 3.33 KB
/
user_repository.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Message, User
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class UserRepository:
def __init__(self, db: Session, api_client: ApiClient):
self.db = db
self.api_client = api_client
def query_frontend_user(
self, auth_method: str, username: str, api_client_id: Optional[UUID] = None
) -> Optional[User]:
if not api_client_id:
api_client_id = self.api_client.id
if not self.api_client.trusted and api_client_id != self.api_client.id:
# Unprivileged API client asks for foreign user
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
user: User = (
self.db.query(User)
.filter(User.auth_method == auth_method, User.username == username, User.api_client_id == api_client_id)
.first()
)
if user is None:
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
return user
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
if not client_user:
return None
user: User = (
self.db.query(User)
.filter(
User.api_client_id == self.api_client.id,
User.username == client_user.id,
User.auth_method == client_user.auth_method,
)
.first()
)
if user is None:
if create_missing:
# user is unknown, create new record
user = User(
username=client_user.id,
display_name=client_user.display_name,
api_client_id=self.api_client.id,
auth_method=client_user.auth_method,
)
self.db.add(user)
self.db.commit()
self.db.refresh(user)
elif client_user.display_name and client_user.display_name != user.display_name:
# we found the user but the display name changed
user.display_name = client_user.display_name
self.db.add(user)
self.db.commit()
return user
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
"""
Get leaderboard stats for Messages created,
separate leaderboard for prompts & assistants
"""
query = (
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
.join(User, User.id == Message.user_id, isouter=True)
.filter(Message.deleted is not True, Message.role == role)
.group_by(Message.user_id, User.username, User.display_name)
.order_by(func.count(Message.user_id).desc())
)
result = [
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
for i, j in enumerate(query.all(), start=1)
]
return LeaderboardStats(leaderboard=result)