Skip to content

Commit 3cbea43

Browse files
authored
add periodically updated cached message stats (LAION-AI#1464)
* add periodically updated cached message stats
1 parent 66f1aa9 commit 3cbea43

File tree

9 files changed

+231
-0
lines changed

9 files changed

+231
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""add cached_stats
2+
3+
Revision ID: ba40d055714a
4+
Revises: caee1e8ee0bc
5+
Create Date: 2023-02-11 10:30:21.996198
6+
7+
"""
8+
import sqlalchemy as sa
9+
import sqlmodel
10+
from alembic import op
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "ba40d055714a"
15+
down_revision = "caee1e8ee0bc"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.create_table(
23+
"cached_stats",
24+
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
25+
sa.Column(
26+
"modified_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
27+
),
28+
sa.Column("stats", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
29+
sa.PrimaryKeyConstraint("name"),
30+
)
31+
# ### end Alembic commands ###
32+
33+
34+
def downgrade() -> None:
35+
# ### commands auto generated by Alembic - please adjust! ###
36+
op.drop_table("cached_stats")
37+
# ### end Alembic commands ###

backend/main.py

+12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from oasst_backend.api.deps import api_auth, create_api_client
1616
from oasst_backend.api.v1.api import api_router
1717
from oasst_backend.api.v1.utils import prepare_conversation
18+
from oasst_backend.cached_stats_repository import CachedStatsRepository
1819
from oasst_backend.config import settings
1920
from oasst_backend.database import engine
2021
from oasst_backend.models import message_tree_state
@@ -334,6 +335,17 @@ def cronjob_delete_expired_tasks(session: Session) -> None:
334335
delete_expired_tasks(session)
335336

336337

338+
@app.on_event("startup")
339+
@repeat_every(seconds=60 * settings.CACHED_STATS_UPDATE_INTERVAL, wait_first=True)
340+
@managed_tx_function(auto_commit=CommitMode.COMMIT)
341+
def update_cached_stats(session: Session) -> None:
342+
try:
343+
csr = CachedStatsRepository(session)
344+
csr.update_all_cached_stats()
345+
except Exception:
346+
logger.exception("Error during cached stats update")
347+
348+
337349
app.include_router(api_router, prefix=settings.API_V1_STR)
338350

339351

backend/oasst_backend/api/v1/stats.py

+38
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from fastapi import APIRouter, Depends
22
from oasst_backend.api import deps
3+
from oasst_backend.cached_stats_repository import CachedStatsRepository
34
from oasst_backend.models import ApiClient
45
from oasst_backend.prompt_repository import PromptRepository
56
from oasst_backend.tree_manager import TreeManager, TreeManagerStats, TreeMessageCountStats
7+
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
68
from oasst_shared.schemas import protocol
79
from sqlmodel import Session
10+
from starlette.status import HTTP_204_NO_CONTENT
811

912
router = APIRouter()
1013

@@ -47,3 +50,38 @@ def get_tree_manager__stats(
4750
pr = PromptRepository(db, api_client)
4851
tm = TreeManager(db, pr)
4952
return tm.stats()
53+
54+
55+
@router.get("/cached/{name}", response_model=protocol.CachedStatsResponse)
56+
def get_cached_stats(
57+
*,
58+
name: protocol.CachedStatsName,
59+
db: Session = Depends(deps.get_db),
60+
api_client: ApiClient = Depends(deps.get_api_client),
61+
):
62+
csr = CachedStatsRepository(db)
63+
return csr.get_stats(name)
64+
65+
66+
@router.get("/cached", response_model=protocol.AllCachedStatsResponse)
67+
def get_cached_stats_all(
68+
*,
69+
db: Session = Depends(deps.get_db),
70+
api_client: ApiClient = Depends(deps.get_api_client),
71+
):
72+
csr = CachedStatsRepository(db)
73+
return csr.get_stats_all()
74+
75+
76+
@router.post("/cached/update", response_model=None, status_code=HTTP_204_NO_CONTENT)
77+
def update_cached_stats(
78+
*,
79+
db: Session = Depends(deps.get_db),
80+
api_client: ApiClient = Depends(deps.get_trusted_api_client),
81+
):
82+
@managed_tx_function(CommitMode.COMMIT)
83+
def update_tx(db: deps.Session) -> None:
84+
csr = CachedStatsRepository(db)
85+
csr.update_all_cached_stats()
86+
87+
update_tx()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from oasst_backend.models import CachedStats, Message, MessageTreeState, User
2+
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
3+
from oasst_shared.schemas.protocol import AllCachedStatsResponse, CachedStatsName, CachedStatsResponse
4+
from oasst_shared.utils import log_timing, utcnow
5+
from sqlalchemy.orm.attributes import flag_modified
6+
from sqlmodel import Session, func, not_
7+
8+
9+
def row_to_dict(r) -> dict:
10+
return {k: r[k] for k in r.keys()}
11+
12+
13+
class CachedStatsRepository:
14+
def __init__(self, db: Session):
15+
self.db = db
16+
17+
def qry_human_messages_by_lang(self) -> dict[str, int]:
18+
qry = (
19+
self.db.query(Message.lang, func.count(Message.id).label("count"))
20+
.filter(not_(Message.deleted), Message.review_result, not_(Message.synthetic))
21+
.group_by(Message.lang)
22+
)
23+
return {r["lang"]: r["count"] for r in qry}
24+
25+
def qry_human_messages_by_role(self) -> dict[str, int]:
26+
qry = (
27+
self.db.query(Message.role, func.count(Message.id).label("count"))
28+
.filter(not_(Message.deleted), Message.review_result, not_(Message.synthetic))
29+
.group_by(Message.role)
30+
)
31+
return {r["role"]: r["count"] for r in qry}
32+
33+
def qry_message_trees_by_state(self) -> dict[str, int]:
34+
qry = self.db.query(
35+
MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")
36+
).group_by(MessageTreeState.state)
37+
return {r["state"]: r["count"] for r in qry}
38+
39+
def qry_message_trees_states_by_lang(self) -> list:
40+
qry = (
41+
self.db.query(
42+
Message.lang, MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")
43+
)
44+
.select_from(MessageTreeState)
45+
.join(Message, MessageTreeState.message_tree_id == Message.id)
46+
.group_by(MessageTreeState.state, Message.lang)
47+
.order_by(Message.lang, MessageTreeState.state)
48+
)
49+
return [row_to_dict(r) for r in qry]
50+
51+
def qry_users_accepted_tos(self) -> dict[str, int]:
52+
qry = self.db.query(func.count(User.id)).filter(User.enabled, User.tos_acceptance_date.is_not(None))
53+
return {"count": qry.scalar()}
54+
55+
@log_timing(level="INFO")
56+
def update_all_cached_stats(self):
57+
v = self.qry_human_messages_by_lang()
58+
self._insert_cached_stats(CachedStatsName.human_messages_by_lang, v)
59+
60+
v = self.qry_human_messages_by_role()
61+
self._insert_cached_stats(CachedStatsName.human_messages_by_role, v)
62+
63+
v = self.qry_message_trees_by_state()
64+
self._insert_cached_stats(CachedStatsName.message_trees_by_state, v)
65+
66+
v = self.qry_message_trees_states_by_lang()
67+
self._insert_cached_stats(CachedStatsName.message_trees_states_by_lang, v)
68+
69+
v = self.qry_users_accepted_tos()
70+
self._insert_cached_stats(CachedStatsName.users_accepted_tos, v)
71+
72+
def _insert_cached_stats(self, name: CachedStatsName, stats: dict | list):
73+
row: CachedStats | None = self.db.query(CachedStats).filter(CachedStats.name == name).one_or_none()
74+
if row:
75+
row.modified_date = utcnow()
76+
row.stats = stats
77+
flag_modified(row, "stats")
78+
else:
79+
row = CachedStats(name=name, modified_date=utcnow(), stats=stats)
80+
self.db.add(row)
81+
82+
def get_stats(self, name: CachedStatsName) -> CachedStatsResponse:
83+
row: CachedStats | None = self.db.query(CachedStats).filter(CachedStats.name == name).one_or_none()
84+
if not row:
85+
raise OasstError(f"Cached stats '{name.value}' not found.", OasstErrorCode.CACHED_STATS_NOT_AVAILABLE)
86+
return CachedStatsResponse(name=row.name, last_updated=row.modified_date, stats=row.stats)
87+
88+
def get_stats_all(self) -> AllCachedStatsResponse:
89+
by_name: dict[CachedStatsName, CachedStatsResponse] = {}
90+
qry = self.db.query(CachedStats)
91+
for row in qry:
92+
by_name[row.name] = CachedStatsResponse(name=row.name, last_updated=row.modified_date, stats=row.stats)
93+
return AllCachedStatsResponse(stats_by_name=by_name)
94+
95+
96+
if __name__ == "__main__":
97+
# from oasst_backend.api.deps import create_api_client
98+
from oasst_backend.database import engine
99+
100+
with Session(engine) as db:
101+
csr = CachedStatsRepository(db)
102+
csr.update_all_cached_stats()()
103+
db.commit()

backend/oasst_backend/config.py

+2
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ def validate_user_stats_intervals(cls, v: int):
243243
raise ValueError(v)
244244
return v
245245

246+
CACHED_STATS_UPDATE_INTERVAL: int = 60 # minutes
247+
246248
RATE_LIMIT_TASK_USER_TIMES: int = 60
247249
RATE_LIMIT_TASK_USER_MINUTES: int = 5
248250
RATE_LIMIT_TASK_API_TIMES: int = 10_000

backend/oasst_backend/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .api_client import ApiClient
2+
from .cached_stats import CachedStats
23
from .flagged_message import FlaggedMessage
34
from .journal import Journal, JournalIntegration
45
from .message import Message
@@ -30,4 +31,5 @@
3031
"MessageEmoji",
3132
"TrollStats",
3233
"FlaggedMessage",
34+
"CachedStats",
3335
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from datetime import datetime
2+
3+
import sqlalchemy as sa
4+
import sqlalchemy.dialects.postgresql as pg
5+
from sqlmodel import AutoString, Field, SQLModel
6+
7+
8+
class CachedStats(SQLModel, table=True):
9+
__tablename__ = "cached_stats"
10+
11+
name: str = Field(sa_column=sa.Column(AutoString(length=128), primary_key=True))
12+
13+
modified_date: datetime | None = Field(
14+
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
15+
)
16+
17+
stats: dict | list | None = Field(None, sa_column=sa.Column(pg.JSONB, nullable=False))

oasst-shared/oasst_shared/exceptions/oasst_api_error.py

+2
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ class OasstErrorCode(IntEnum):
8686

8787
EMOJI_OP_UNSUPPORTED = 5000
8888

89+
CACHED_STATS_NOT_AVAILABLE = 6000
90+
8991

9092
class OasstError(Exception):
9193
"""Base class for Open-Assistant exceptions."""

oasst-shared/oasst_shared/schemas/protocol.py

+18
Original file line numberDiff line numberDiff line change
@@ -571,3 +571,21 @@ class CreateFrontendUserRequest(User):
571571
enabled: bool = True
572572
tos_acceptance: Optional[bool] = None
573573
notes: Optional[str] = None
574+
575+
576+
class CachedStatsName(str, enum.Enum):
577+
human_messages_by_lang = "human_messages_by_lang"
578+
human_messages_by_role = "human_messages_by_role"
579+
message_trees_by_state = "message_trees_by_state"
580+
message_trees_states_by_lang = "message_trees_states_by_lang"
581+
users_accepted_tos = "users_accepted_tos"
582+
583+
584+
class CachedStatsResponse(BaseModel):
585+
name: CachedStatsName | str
586+
last_updated: datetime
587+
stats: dict | list
588+
589+
590+
class AllCachedStatsResponse(BaseModel):
591+
stats_by_name: dict[CachedStatsName | str, CachedStatsResponse]

0 commit comments

Comments
 (0)