Skip to content

Commit 558b207

Browse files
authored
Add /messages/{message_id}/emoji endpoint to toggle, add, remove message emojis (LAION-AI#925)
* add endpoint to set message emojis * make refresh result optional in db utils
1 parent 4146930 commit 558b207

File tree

10 files changed

+187
-0
lines changed

10 files changed

+187
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""add message_emoji
2+
3+
Revision ID: 40ed93df0ed5
4+
Revises: 8ba17b5f467a
5+
Create Date: 2023-01-24 22:56:28.229408
6+
7+
"""
8+
import sqlalchemy as sa
9+
import sqlmodel
10+
from alembic import op
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "40ed93df0ed5"
15+
down_revision = "8ba17b5f467a"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.create_table(
23+
"message_emoji",
24+
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
25+
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
26+
sa.Column(
27+
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
28+
),
29+
sa.Column("emoji", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
30+
sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
31+
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
32+
sa.PrimaryKeyConstraint("message_id", "user_id", "emoji"),
33+
)
34+
op.create_index("ix_message_emoji__user_id__message_id", "message_emoji", ["user_id", "message_id"], unique=False)
35+
op.add_column("message", sa.Column("emojis", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
36+
# ### end Alembic commands ###
37+
38+
39+
def downgrade() -> None:
40+
# ### commands auto generated by Alembic - please adjust! ###
41+
op.drop_column("message", "emojis")
42+
op.drop_index("ix_message_emoji__user_id__message_id", table_name="message_emoji")
43+
op.drop_table("message_emoji")
44+
# ### end Alembic commands ###

backend/oasst_backend/api/v1/messages.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from oasst_backend.api.v1 import utils
88
from oasst_backend.models import ApiClient
99
from oasst_backend.prompt_repository import PromptRepository
10+
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
1011
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
1112
from oasst_shared.schemas import protocol
1213
from sqlmodel import Session
@@ -229,3 +230,22 @@ def mark_message_deleted(
229230
):
230231
pr = PromptRepository(db, api_client)
231232
pr.mark_messages_deleted(message_id)
233+
234+
235+
@router.post("/{message_id}/emoji", response_model=protocol.Message)
236+
def post_message_emoji(
237+
*,
238+
message_id: UUID,
239+
request: protocol.MessageEmojiRequest,
240+
api_client: ApiClient = Depends(deps.get_api_client),
241+
) -> protocol.Message:
242+
"""
243+
Toggle, add or remove message emoji.
244+
"""
245+
246+
@managed_tx_function(CommitMode.COMMIT)
247+
def emoji_tx(session: deps.Session):
248+
pr = PromptRepository(session, api_client, client_user=request.user)
249+
return pr.handle_message_emoji(message_id, request.op, request.emoji)
250+
251+
return utils.prepare_message(emoji_tx())

backend/oasst_backend/api/v1/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def prepare_message(m: Message) -> protocol.Message:
1414
lang=m.lang,
1515
is_assistant=(m.role == "assistant"),
1616
created_date=m.created_date,
17+
emojis=m.emojis,
1718
)
1819

1920

backend/oasst_backend/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .journal import Journal, JournalIntegration
33
from .message import Message
44
from .message_embedding import MessageEmbedding
5+
from .message_emoji import MessageEmoji
56
from .message_reaction import MessageReaction
67
from .message_toxicity import MessageToxicity
78
from .message_tree_state import MessageTreeState
@@ -24,4 +25,5 @@
2425
"TextLabels",
2526
"Journal",
2627
"JournalIntegration",
28+
"MessageEmoji",
2729
]

backend/oasst_backend/models/message.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class Message(SQLModel, table=True):
4949

5050
rank: Optional[int] = Field(nullable=True)
5151

52+
emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
53+
5254
def ensure_is_message(self) -> None:
5355
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
5456
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from datetime import datetime
2+
from typing import Optional
3+
from uuid import UUID
4+
5+
import sqlalchemy as sa
6+
import sqlalchemy.dialects.postgresql as pg
7+
from sqlmodel import Field, Index, SQLModel
8+
9+
10+
class MessageEmoji(SQLModel, table=True):
11+
__tablename__ = "message_emoji"
12+
__table_args__ = (Index("ix_message_emoji__user_id__message_id", "user_id", "message_id", unique=False),)
13+
14+
message_id: Optional[UUID] = Field(
15+
sa_column=sa.Column(
16+
pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True
17+
)
18+
)
19+
user_id: UUID = Field(
20+
sa_column=sa.Column(
21+
pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), nullable=False, primary_key=True
22+
)
23+
)
24+
emoji: str = Field(nullable=False, max_length=128, primary_key=True)
25+
created_date: Optional[datetime] = Field(
26+
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
27+
)

backend/oasst_backend/prompt_repository.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ApiClient,
1414
Message,
1515
MessageEmbedding,
16+
MessageEmoji,
1617
MessageReaction,
1718
MessageToxicity,
1819
MessageTreeState,
@@ -29,6 +30,7 @@
2930
from oasst_shared.schemas import protocol as protocol_schema
3031
from oasst_shared.schemas.protocol import SystemStats
3132
from oasst_shared.utils import unaware_to_utc
33+
from sqlalchemy.orm.attributes import flag_modified
3234
from sqlmodel import Session, and_, func, not_, or_, text, update
3335
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
3436

@@ -843,3 +845,62 @@ def get_stats(self) -> SystemStats:
843845
deleted=result.get(True, 0),
844846
message_trees=result.get(None, 0),
845847
)
848+
849+
def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
850+
self.ensure_user_is_enabled()
851+
852+
message = self.fetch_message(message_id)
853+
854+
# check if emoji exists
855+
existing_emoji = (
856+
self.db.query(MessageEmoji)
857+
.filter(
858+
MessageEmoji.message_id == message_id, MessageEmoji.user_id == self.user_id, MessageEmoji.emoji == emoji
859+
)
860+
.one_or_none()
861+
)
862+
863+
if existing_emoji:
864+
if op == protocol_schema.EmojiOp.add:
865+
logger.info(f"Emoji record already exists {message_id=}, {emoji=}, {self.user_id=}")
866+
return message
867+
elif op == protocol_schema.EmojiOp.togggle:
868+
op = protocol_schema.EmojiOp.remove
869+
870+
if existing_emoji is None:
871+
if op == protocol_schema.EmojiOp.remove:
872+
logger.info(f"Emoji record not found {message_id=}, {emoji=}, {self.user_id=}")
873+
return message
874+
elif op == protocol_schema.EmojiOp.togggle:
875+
op = protocol_schema.EmojiOp.add
876+
877+
if op == protocol_schema.EmojiOp.add:
878+
# insert emoji record & increment count
879+
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
880+
self.db.add(message_emoji)
881+
emoji_counts = message.emojis
882+
if not emoji_counts:
883+
message.emojis = {emoji.value: 1}
884+
else:
885+
count = emoji_counts.get(emoji.value) or 0
886+
emoji_counts[emoji.value] = count + 1
887+
elif op == protocol_schema.EmojiOp.remove:
888+
# remove emoji record and & decrement count
889+
message = self.fetch_message(message_id)
890+
self.db.delete(existing_emoji)
891+
emoji_counts = message.emojis
892+
count = emoji_counts.get(emoji.value)
893+
if count is not None:
894+
if count == 1:
895+
del emoji_counts[emoji.value]
896+
else:
897+
emoji_counts[emoji.value] = count - 1
898+
flag_modified(message, "emojis")
899+
self.db.add(message)
900+
else:
901+
raise OasstError("Emoji op not supported", OasstErrorCode.EMOJI_OP_UNSUPPORTED)
902+
903+
flag_modified(message, "emojis")
904+
self.db.add(message)
905+
self.db.flush()
906+
return message

backend/oasst_backend/utils/database_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def managed_tx_function(
107107
auto_commit: CommitMode = CommitMode.COMMIT,
108108
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
109109
session_factory: Callable[..., Session] = default_session_factor,
110+
refresh_result: bool = True,
110111
):
111112
"""Passes Session object as first argument to wrapped function."""
112113

@@ -124,6 +125,8 @@ def wrapped_f(*args, **kwargs):
124125
session.flush()
125126
elif auto_commit == CommitMode.ROLLBACK:
126127
session.rollback()
128+
if refresh_result and isinstance(result, SQLModel):
129+
session.refresh(result)
127130
return result
128131
except OperationalError:
129132
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")

oasst-shared/oasst_shared/exceptions/oasst_api_error.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class OasstErrorCode(IntEnum):
7676
USER_DISABLED = 4001
7777
USER_NOT_FOUND = 4002
7878

79+
EMOJI_OP_UNSUPPORTED = 5000
80+
7981

8082
class OasstError(Exception):
8183
"""Base class for Open-Assistant exceptions."""

oasst-shared/oasst_shared/schemas/protocol.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def is_prompter_turn(self) -> bool:
8080
class Message(ConversationMessage):
8181
parent_id: Optional[UUID] = None
8282
created_date: Optional[datetime] = None
83+
emojis: Optional[dict] = None
8384

8485

8586
class MessagePage(PageResult):
@@ -432,3 +433,27 @@ class OasstErrorResponse(BaseModel):
432433

433434
error_code: OasstErrorCode
434435
message: str
436+
437+
438+
class EmojiCode(str, enum.Enum):
439+
thumbs_up = "+1" # 👍
440+
thumbs_down = "-1" # 👎
441+
red_flag = "red_flag" # 🚩
442+
hundred = "100" # 💯
443+
rofl = "rofl" # 🤣"
444+
heart_eyes = "heart_eyes" # 😍
445+
disappointed = "disappointed" # 😞
446+
poop = "poop" # 💩
447+
skull = "skull" # 💀
448+
449+
450+
class EmojiOp(str, enum.Enum):
451+
togggle = "toggle"
452+
add = "add"
453+
remove = "remove"
454+
455+
456+
class MessageEmojiRequest(BaseModel):
457+
user: User
458+
op: EmojiOp = EmojiOp.togggle
459+
emoji: EmojiCode

0 commit comments

Comments
 (0)