Skip to content

Commit b5bb5bb

Browse files
authored
Add leaderboard stats, periodic updates via fastapi-utils (LAION-AI#724)
* add leaderboard stats, periodic update via fastapi-utils
* count label tasks for assistant and prompter replies
* Daily stats update every 15 mins, simplify leaderboard endpoint
* add indices for some created_date columns
* make user stats update intervals configurable
* make sure intervals are positive
1 parent e01f2eb commit b5bb5bb

21 files changed

+555
-61
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""change user_stats ranking counts
2+
3+
Revision ID: 7c98102efbca
4+
Revises: 619255ae9076
5+
Create Date: 2023-01-15 00:02:45.622986
6+
7+
"""
8+
import sqlalchemy as sa
9+
import sqlmodel
10+
from alembic import op
11+
from sqlalchemy.dialects.postgresql import UUID
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "7c98102efbca"
15+
down_revision = "619255ae9076"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
    """Recreate the ``user_stats`` table with the combined ``reply_ranked_*`` counters.

    The existing table is dropped and created again from scratch, so any rows
    it holds are discarded by this migration.
    """
    # Non-nullable integer counters, in the exact order they appear in the table.
    counter_names = (
        "leader_score",
        "prompts",
        "replies_assistant",
        "replies_prompter",
        "labels_simple",
        "labels_full",
        "rankings_total",
        "rankings_good",
        "accepted_prompts",
        "accepted_replies_assistant",
        "accepted_replies_prompter",
        "reply_ranked_1",
        "reply_ranked_2",
        "reply_ranked_3",
    )
    counter_columns = [sa.Column(name, sa.Integer(), nullable=False) for name in counter_names]

    op.drop_table("user_stats")
    op.create_table(
        "user_stats",
        sa.Column("user_id", UUID(as_uuid=True), nullable=False),
        sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("base_date", sa.DateTime(), nullable=True),
        sa.Column("time_frame", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        *counter_columns,
        sa.Column("streak_last_day_date", sa.DateTime(), nullable=True),
        sa.Column("streak_days", sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
        # One stats row per user and time frame.
        sa.PrimaryKeyConstraint("user_id", "time_frame"),
    )
52+
53+
54+
def downgrade() -> None:
    """Revert ``user_stats`` to the per-role ranking counters.

    Re-adds the six ``reply_<role>_ranked_<n>`` columns (server default ``0``)
    and removes the columns that ``upgrade()`` created: the combined
    ``reply_ranked_*`` counters and ``base_date``.
    """
    # Order is kept identical to the auto-generated script so the resulting
    # column order of the table does not change.
    restored_columns = (
        "reply_prompter_ranked_3",
        "reply_assistant_ranked_1",
        "reply_assistant_ranked_2",
        "reply_prompter_ranked_2",
        "reply_prompter_ranked_1",
        "reply_assistant_ranked_3",
    )
    for column_name in restored_columns:
        op.add_column(
            "user_stats",
            sa.Column(column_name, sa.INTEGER(), server_default="0", autoincrement=False, nullable=False),
        )

    op.drop_column("user_stats", "reply_ranked_3")
    op.drop_column("user_stats", "reply_ranked_2")
    op.drop_column("user_stats", "reply_ranked_1")
    # upgrade() creates base_date, so it has to go as well for a full revert.
    # NOTE(review): assumes revision 619255ae9076 had no base_date column - confirm.
    op.drop_column("user_stats", "base_date")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""add indices for created_date
2+
3+
Revision ID: 423557e869e4
4+
Revises: 7c98102efbca
5+
Create Date: 2023-01-15 11:39:10.407859
6+
7+
"""
8+
from alembic import op
9+
10+
# revision identifiers, used by Alembic.
11+
revision = "423557e869e4"
12+
down_revision = "7c98102efbca"
13+
branch_labels = None
14+
depends_on = None
15+
16+
17+
def upgrade() -> None:
    """Create a non-unique ``created_date`` index on message, message_reaction and text_labels."""
    for table_name in ("message", "message_reaction", "text_labels"):
        op.create_index(
            op.f(f"ix_{table_name}_created_date"),
            table_name,
            ["created_date"],
            unique=False,
        )
23+
24+
25+
def downgrade() -> None:
    """Drop the ``created_date`` indices again, in reverse order of their creation."""
    for table_name in ("text_labels", "message_reaction", "message"):
        op.drop_index(op.f(f"ix_{table_name}_created_date"), table_name=table_name)

backend/main.py

+46
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import fastapi
1010
import redis.asyncio as redis
1111
from fastapi_limiter import FastAPILimiter
12+
from fastapi_utils.tasks import repeat_every
1213
from loguru import logger
1314
from oasst_backend.api.deps import get_dummy_api_client
1415
from oasst_backend.api.v1.api import api_router
@@ -18,6 +19,7 @@
1819
from oasst_backend.models import message_tree_state
1920
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
2021
from oasst_backend.tree_manager import TreeManager
22+
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
2123
from oasst_shared.exceptions import OasstError, OasstErrorCode
2224
from oasst_shared.schemas import protocol as protocol_schema
2325
from pydantic import BaseModel
@@ -195,6 +197,50 @@ def ensure_tree_states():
195197
logger.exception("TreeManager.ensure_tree_states() failed.")
196198

197199

200+
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
def update_leader_board_day() -> None:
    """Recompute the user stats for the ``day`` time frame.

    Registered as a startup handler and repeated every
    ``60 * settings.USER_STATS_INTERVAL_DAY`` seconds. Any exception is logged
    and swallowed so that a failed run does not propagate to the caller.
    """
    try:
        with Session(engine) as db:
            UserStatsRepository(db).update_stats(time_frame=UserStatsTimeFrame.day)
    except Exception:
        logger.exception("Error during leaderboard update (daily)")
209+
210+
211+
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
def update_leader_board_week() -> None:
    """Recompute the user stats for the ``week`` time frame.

    Registered as a startup handler and repeated every
    ``60 * settings.USER_STATS_INTERVAL_WEEK`` seconds. Any exception is logged
    and swallowed so that a failed run does not propagate to the caller.
    """
    try:
        with Session(engine) as session:
            usr = UserStatsRepository(session)
            usr.update_stats(time_frame=UserStatsTimeFrame.week)
    except Exception:
        # Message previously read "user states"; the job updates user stats.
        logger.exception("Error during user stats update (weekly)")
220+
221+
222+
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
def update_leader_board_month() -> None:
    """Recompute the user stats for the ``month`` time frame.

    Registered as a startup handler and repeated every
    ``60 * settings.USER_STATS_INTERVAL_MONTH`` seconds. Any exception is logged
    and swallowed so that a failed run does not propagate to the caller.
    """
    try:
        with Session(engine) as session:
            usr = UserStatsRepository(session)
            usr.update_stats(time_frame=UserStatsTimeFrame.month)
    except Exception:
        # Message previously read "user states"; the job updates user stats.
        logger.exception("Error during user stats update (monthly)")
231+
232+
233+
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
def update_leader_board_total() -> None:
    """Recompute the user stats for the ``total`` time frame.

    Registered as a startup handler and repeated every
    ``60 * settings.USER_STATS_INTERVAL_TOTAL`` seconds. Any exception is logged
    and swallowed so that a failed run does not propagate to the caller.
    """
    try:
        with Session(engine) as session:
            usr = UserStatsRepository(session)
            usr.update_stats(time_frame=UserStatsTimeFrame.total)
    except Exception:
        # Message previously read "user states"; the job updates user stats.
        logger.exception("Error during user stats update (total)")
242+
243+
198244
app.include_router(api_router, prefix=settings.API_V1_STR)
199245

200246

backend/oasst_backend/api/v1/api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@
1919
api_router.include_router(users.router, prefix="/users", tags=["users"])
2020
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
2121
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
22-
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
22+
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
2323
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
+11-16
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,21 @@
1-
from fastapi import APIRouter, Depends
1+
from typing import Optional
2+
3+
from fastapi import APIRouter, Depends, Query
24
from oasst_backend.api import deps
35
from oasst_backend.models import ApiClient
4-
from oasst_backend.user_repository import UserRepository
6+
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
57
from oasst_shared.schemas.protocol import LeaderboardStats
68
from sqlmodel import Session
79

810
router = APIRouter()
911

1012

11-
@router.get("/create/assistant")
12-
def get_assistant_leaderboard(
13-
db: Session = Depends(deps.get_db),
14-
api_client: ApiClient = Depends(deps.get_trusted_api_client),
15-
) -> LeaderboardStats:
16-
ur = UserRepository(db, api_client)
17-
return ur.get_user_leaderboard(role="assistant")
18-
19-
20-
@router.get("/create/prompter")
21-
def get_prompter_leaderboard(
13+
@router.get("/{time_frame}")
14+
def get_leaderboard_day(
15+
time_frame: UserStatsTimeFrame,
16+
max_count: Optional[int] = Query(100, gt=0, le=10000),
17+
api_client: ApiClient = Depends(deps.get_api_client),
2218
db: Session = Depends(deps.get_db),
23-
api_client: ApiClient = Depends(deps.get_trusted_api_client),
2419
) -> LeaderboardStats:
25-
ur = UserRepository(db, api_client)
26-
return ur.get_user_leaderboard(role="prompter")
20+
usr = UserStatsRepository(db)
21+
return usr.get_leader_board(time_frame, limit=max_count)

backend/oasst_backend/config.py

+16
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,22 @@ def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str
109109

110110
tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration()
111111

112+
USER_STATS_INTERVAL_DAY: int = 15 # minutes
113+
USER_STATS_INTERVAL_WEEK: int = 60 # minutes
114+
USER_STATS_INTERVAL_MONTH: int = 120 # minutes
115+
USER_STATS_INTERVAL_TOTAL: int = 240 # minutes
116+
117+
@validator(
118+
"USER_STATS_INTERVAL_DAY",
119+
"USER_STATS_INTERVAL_WEEK",
120+
"USER_STATS_INTERVAL_MONTH",
121+
"USER_STATS_INTERVAL_TOTAL",
122+
)
123+
def validate_user_stats_intervals(cls, v: int) -> int:
    """Reject non-positive user-stats update intervals.

    Args:
        v: Interval, in minutes, of one of the ``USER_STATS_INTERVAL_*`` settings.

    Returns:
        The interval unchanged when it is at least 1.

    Raises:
        ValueError: If the interval is smaller than 1.
    """
    if v < 1:
        # Spell out the constraint; a bare number gives no hint what is wrong.
        raise ValueError(f"user stats update interval must be at least 1 minute, got {v}")
    return v
127+
112128
class Config:
113129
env_file = ".env"
114130
env_file_encoding = "utf-8"

backend/oasst_backend/models/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from .task import Task
99
from .text_labels import TextLabels
1010
from .user import User
11-
from .user_stats import UserStats
11+
from .user_stats import UserStats, UserStatsTimeFrame
1212

1313
__all__ = [
1414
"ApiClient",
1515
"User",
1616
"UserStats",
17+
"UserStatsTimeFrame",
1718
"Message",
1819
"MessageEmbedding",
1920
"MessageReaction",

backend/oasst_backend/models/db_payload.py

+6
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,16 @@ class RankingReactionPayload(ReactionPayload):
6565
type: Literal["message_ranking"] = "message_ranking"
6666
ranking: list[int]
6767
ranked_message_ids: list[UUID]
68+
ranking_parent_id: Optional[UUID]
69+
message_tree_id: Optional[UUID]
6870

6971

7072
@payload_type
7173
class RankConversationRepliesPayload(TaskPayload):
7274
conversation: protocol_schema.Conversation # the conversation so far
7375
reply_messages: list[protocol_schema.ConversationMessage]
76+
ranking_parent_id: Optional[UUID]
77+
message_tree_id: Optional[UUID]
7478

7579

7680
@payload_type
@@ -104,6 +108,7 @@ class LabelInitialPromptPayload(TaskPayload):
104108
prompt: str
105109
valid_labels: list[str]
106110
mandatory_labels: Optional[list[str]]
111+
mode: Optional[protocol_schema.LabelTaskMode]
107112

108113

109114
@payload_type
@@ -115,6 +120,7 @@ class LabelConversationReplyPayload(TaskPayload):
115120
reply: str
116121
valid_labels: list[str]
117122
mandatory_labels: Optional[list[str]]
123+
mode: Optional[protocol_schema.LabelTaskMode]
118124

119125

120126
@payload_type

backend/oasst_backend/models/message.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class Message(SQLModel, table=True):
3030
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
3131
frontend_message_id: str = Field(max_length=200, nullable=False)
3232
created_date: Optional[datetime] = Field(
33-
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
33+
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
3434
)
3535
payload_type: str = Field(nullable=False, max_length=200)
3636
payload: Optional[PayloadContainer] = Field(

backend/oasst_backend/models/message_reaction.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class MessageReaction(SQLModel, table=True):
1919
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
2020
)
2121
created_date: Optional[datetime] = Field(
22-
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
22+
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
2323
)
2424
payload_type: str = Field(nullable=False, max_length=200)
2525
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))

backend/oasst_backend/models/task.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import sqlalchemy as sa
66
import sqlalchemy.dialects.postgresql as pg
7+
from oasst_shared.utils import utcnow
78
from sqlalchemy import false
89
from sqlmodel import Field, SQLModel
910

@@ -35,4 +36,4 @@ class Task(SQLModel, table=True):
3536

3637
@property
3738
def expired(self) -> bool:
38-
return self.expiry_date is not None and datetime.utcnow() > self.expiry_date
39+
return self.expiry_date is not None and utcnow() > self.expiry_date

backend/oasst_backend/models/text_labels.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TextLabels(SQLModel, table=True):
1717
)
1818
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
1919
created_date: Optional[datetime] = Field(
20-
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
20+
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True),
2121
)
2222
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
2323
text: str = Field(nullable=False, max_length=2**16)

backend/oasst_backend/models/user_stats.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class UserStats(SQLModel, table=True):
2222
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
2323
)
2424
time_frame: Optional[str] = Field(nullable=False, primary_key=True)
25+
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
2526

2627
leader_score: int = 0
2728
modified_date: Optional[datetime] = Field(
@@ -40,14 +41,27 @@ class UserStats(SQLModel, table=True):
4041
accepted_replies_assistant: int = 0
4142
accepted_replies_prompter: int = 0
4243

43-
reply_assistant_ranked_1: int = 0
44-
reply_assistant_ranked_2: int = 0
45-
reply_assistant_ranked_3: int = 0
46-
47-
reply_prompter_ranked_1: int = 0
48-
reply_prompter_ranked_2: int = 0
49-
reply_prompter_ranked_3: int = 0
44+
reply_ranked_1: int = 0
45+
reply_ranked_2: int = 0
46+
reply_ranked_3: int = 0
5047

5148
# only used for time span "total"
5249
streak_last_day_date: Optional[datetime] = Field(nullable=True)
5350
streak_days: Optional[int] = Field(nullable=True)
51+
52+
def compute_leader_score(self) -> int:
    """Return the leaderboard score as a weighted sum of this row's counters."""
    # (counter value, weight) pairs; weights mirror the scoring formula.
    weighted_counts = (
        (self.prompts, 1),
        (self.replies_assistant, 4),
        (self.replies_prompter, 1),
        (self.labels_simple, 1),
        (self.labels_full, 2),
        (self.rankings_total, 1),
        (self.rankings_good, 1),
        (self.accepted_prompts, 1),
        (self.accepted_replies_assistant, 4),
        (self.accepted_replies_prompter, 1),
        (self.reply_ranked_1, 9),
        (self.reply_ranked_2, 3),
        (self.reply_ranked_3, 1),
    )
    return sum(count * weight for count, weight in weighted_counts)

backend/oasst_backend/prompt_repository.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,10 @@ def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[Messag
260260
self.db.add(message)
261261

262262
reaction_payload = db_payload.RankingReactionPayload(
263-
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
263+
ranking=ranking.ranking,
264+
ranked_message_ids=ranked_message_ids,
265+
ranking_parent_id=task_payload.ranking_parent_id,
266+
message_tree_id=task_payload.message_tree_id,
264267
)
265268
reaction = self.insert_reaction(task.id, reaction_payload)
266269
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)

0 commit comments

Comments
 (0)