Skip to content

Commit bbf0386

Browse files
authored
Add terms of service acceptance date to user table (LAION-AI#1046)
* add tos_acceptance_date column to user * send 451 UNAVAILABLE_FOR_LEGAL_REASONS status * add create user REST endpoint * adapt text-frontend to ToS requirements * set DEBUG_IGNORE_TOS_ACCEPTANCE default to True (temporary change) * update down revision to f60958968ff8
1 parent e0df9f0 commit bbf0386

File tree

12 files changed

+127
-8
lines changed

12 files changed

+127
-8
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""add tos_acceptance_date to user
2+
3+
Revision ID: 55361f323d12
4+
Revises: 7b8f0011e0b0
5+
Create Date: 2023-02-01 00:22:08.280251
6+
7+
"""
8+
import sqlalchemy as sa
9+
from alembic import op
10+
from sqlalchemy.dialects import postgresql
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "55361f323d12"
14+
down_revision = "f60958968ff8"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True))
22+
op.drop_column("user_stats", "streak_days")
23+
op.drop_column("user_stats", "streak_last_day_date")
24+
# ### end Alembic commands ###
25+
26+
27+
def downgrade() -> None:
28+
# ### commands auto generated by Alembic - please adjust! ###
29+
op.add_column(
30+
"user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True)
31+
)
32+
op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True))
33+
op.drop_column("user", "tos_acceptance_date")
34+
# ### end Alembic commands ###

backend/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class DummyMessage(BaseModel):
147147

148148
ur = UserRepository(db=session, api_client=api_client)
149149
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
150+
ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True)
150151
pr = PromptRepository(
151152
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
152153
)

backend/oasst_backend/api/v1/frontend_users.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,37 @@ def query_frontend_user(
5959
return user.to_protocol_frontend_user()
6060

6161

62+
@router.post("/", response_model=protocol.FrontEndUser)
63+
def create_frontend_user(
64+
*,
65+
create_user: protocol.CreateFrontendUserRequest,
66+
api_client: ApiClient = Depends(deps.get_api_client),
67+
db: Session = Depends(deps.get_db),
68+
):
69+
ur = UserRepository(db, api_client)
70+
user = ur.lookup_client_user(create_user, create_missing=True)
71+
72+
def changed(a, b) -> bool:
73+
return a is not None and a != b
74+
75+
# only call update_user if something changed
76+
if (
77+
changed(create_user.enabled, user.enabled)
78+
or changed(create_user.show_on_leaderboard, user.show_on_leaderboard)
79+
or changed(create_user.notes, user.notes)
80+
or (create_user.tos_acceptance and user.tos_acceptance_date is None)
81+
):
82+
user = ur.update_user(
83+
user.id,
84+
enabled=create_user.enabled,
85+
show_on_leaderboard=create_user.show_on_leaderboard,
86+
tos_acceptance=create_user.tos_acceptance,
87+
notes=create_user.notes,
88+
)
89+
90+
return user.to_protocol_frontend_user()
91+
92+
6293
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
6394
def query_frontend_user_messages(
6495
auth_method: str,

backend/oasst_backend/api/v1/users.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,15 @@ def update_user(
191191
enabled: Optional[bool] = None,
192192
notes: Optional[str] = None,
193193
show_on_leaderboard: Optional[bool] = None,
194+
tos_acceptance: Optional[bool] = None,
194195
db: Session = Depends(deps.get_db),
195196
api_client: ApiClient = Depends(deps.get_trusted_api_client),
196197
):
197198
"""
198199
Update a user by global user ID. Only trusted clients can update users.
199200
"""
200201
ur = UserRepository(db, api_client)
201-
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
202+
ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance)
202203

203204

204205
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)

backend/oasst_backend/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ class Settings(BaseSettings):
158158
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
159159
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
160160
DEBUG_DATABASE_ECHO: bool = False
161+
DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS
162+
True # TODO: set False after ToS acceptance UI was added to web-frontend
163+
)
161164

162165
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
163166

backend/oasst_backend/models/user.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ class User(SQLModel, table=True):
4141
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
4242
)
4343

44+
# terms of service acceptance date
45+
tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
46+
4447
def to_protocol_frontend_user(self):
4548
return protocol.FrontEndUser(
4649
user_id=self.id,
@@ -55,4 +58,5 @@ def to_protocol_frontend_user(self):
5558
streak_days=self.streak_days,
5659
streak_last_day_date=self.streak_last_day_date,
5760
last_activity_date=self.last_activity_date,
61+
tos_acceptance_date=self.tos_acceptance_date,
5862
)

backend/oasst_backend/prompt_repository.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from sqlalchemy.orm import Query
3636
from sqlalchemy.orm.attributes import flag_modified
3737
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
38-
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
3938

4039

4140
class PromptRepository:
@@ -77,7 +76,14 @@ def ensure_user_is_enabled(self):
7776
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
7877

7978
if self.user.deleted or not self.user.enabled:
80-
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
79+
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE)
80+
81+
if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE:
82+
raise OasstError(
83+
"User has not accepted terms of service.",
84+
OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS,
85+
HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS,
86+
)
8187

8288
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
8389
validate_frontend_message_id(frontend_message_id)
@@ -90,7 +96,7 @@ def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if
9096
raise OasstError(
9197
f"Message with frontend_message_id {frontend_message_id} not found.",
9298
OasstErrorCode.MESSAGE_NOT_FOUND,
93-
HTTP_404_NOT_FOUND,
99+
HTTPStatus.NOT_FOUND,
94100
)
95101
return message
96102

@@ -675,7 +681,7 @@ def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optio
675681

676682
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
677683
if fail_if_missing and not message:
678-
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
684+
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
679685
return message
680686

681687
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
@@ -874,7 +880,7 @@ def query_messages_ordered_by_created_date(
874880

875881
if api_client_id != self.api_client.id:
876882
# Unprivileged api client asks for foreign messages
877-
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
883+
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN)
878884

879885
qry = self.db.query(Message)
880886
if user_id:

backend/oasst_backend/user_repository.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def update_user(
7373
enabled: Optional[bool] = None,
7474
notes: Optional[str] = None,
7575
show_on_leaderboard: Optional[bool] = None,
76-
) -> None:
76+
tos_acceptance: Optional[bool] = None,
77+
) -> User:
7778
"""
7879
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
7980
@@ -94,8 +95,11 @@ def update_user(
9495
user.notes = notes
9596
if show_on_leaderboard is not None:
9697
user.show_on_leaderboard = show_on_leaderboard
98+
if tos_acceptance:
99+
user.tos_acceptance_date = utcnow()
97100

98101
self.db.add(user)
102+
return user
99103

100104
@managed_tx_method(CommitMode.COMMIT)
101105
def mark_user_deleted(self, id: UUID) -> None:
@@ -143,8 +147,10 @@ def _lookup_user_tx(
143147
display_name=display_name,
144148
api_client_id=self.api_client.id,
145149
auth_method=auth_method,
146-
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
147150
)
151+
if auth_method == "system":
152+
user.show_on_leaderboard = False # don't show system users, e.g. import user
153+
user.tos_acceptance_date = utcnow()
148154
self.db.add(user)
149155
elif display_name and display_name != user.display_name:
150156
# we found the user but the display name changed
@@ -156,6 +162,10 @@ def _lookup_user_tx(
156162
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
157163
if not client_user:
158164
return None
165+
166+
if not (client_user.auth_method and client_user.id):
167+
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
168+
159169
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
160170
for i in range(num_retries):
161171
try:

oasst-shared/oasst_shared/exceptions/oasst_api_error.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class OasstErrorCode(IntEnum):
8080
USER_NOT_SPECIFIED = 4000
8181
USER_DISABLED = 4001
8282
USER_NOT_FOUND = 4002
83+
USER_HAS_NOT_ACCEPTED_TOS = 4003
8384

8485
EMOJI_OP_UNSUPPORTED = 5000
8586

oasst-shared/oasst_shared/schemas/protocol.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class FrontEndUser(User):
3939
streak_days: Optional[int] = None
4040
streak_last_day_date: Optional[datetime] = None
4141
last_activity_date: Optional[datetime] = None
42+
tos_acceptance_date: Optional[datetime] = None
4243

4344

4445
class PageResult(BaseModel):
@@ -499,3 +500,10 @@ class MessageEmojiRequest(BaseModel):
499500
user: User
500501
op: EmojiOp = EmojiOp.togggle
501502
emoji: EmojiCode
503+
504+
505+
class CreateFrontendUserRequest(User):
506+
show_on_leaderboard: bool = True
507+
enabled: bool = True
508+
tos_acceptance: Optional[bool] = None
509+
notes: Optional[str] = None

0 commit comments

Comments
 (0)