Skip to content

Commit d407a1f

Browse files
authored
Store null in review_result column during review phase (LAION-AI#1599)
1 parent 7958422 commit d407a1f

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""message review_result nullable
2+
3+
Revision ID: 8cd0c34d0c3c
4+
Revises: 165b55de5a94
5+
Create Date: 2023-02-15 17:54:58.029278
6+
7+
"""
8+
import sqlalchemy as sa
9+
from alembic import op
10+
11+
# revision identifiers, used by Alembic.
12+
revision = "8cd0c34d0c3c"
13+
down_revision = "165b55de5a94"
14+
branch_labels = None
15+
depends_on = None
16+
17+
18+
def upgrade() -> None:
19+
# ### commands auto generated by Alembic - please adjust! ###
20+
op.alter_column(
21+
"message",
22+
"review_result",
23+
existing_type=sa.BOOLEAN(),
24+
nullable=True,
25+
server_default=None,
26+
existing_server_default=sa.text("false"),
27+
)
28+
# ### end Alembic commands ###
29+
30+
31+
def downgrade() -> None:
32+
# ### commands auto generated by Alembic - please adjust! ###
33+
op.alter_column(
34+
"message", "review_result", existing_type=sa.BOOLEAN(), nullable=False, server_default=sa.text("false")
35+
)
36+
# ### end Alembic commands ###

backend/oasst_backend/models/message.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __new__(cls, *args: Any, **kwargs: Any):
5252
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
5353

5454
review_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
55-
review_result: bool = Field(sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False))
55+
review_result: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=True))
5656
ranking_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
5757

5858
rank: Optional[int] = Field(nullable=True)

backend/oasst_backend/prompt_repository.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def insert_message(
132132
payload_type: str = None,
133133
depth: int = 0,
134134
review_count: int = 0,
135-
review_result: bool = False,
135+
review_result: bool = None,
136136
) -> Message:
137137
if payload_type is None:
138138
if payload is None:
@@ -198,7 +198,7 @@ def store_text_reply(
198198
frontend_message_id: str,
199199
user_frontend_message_id: str,
200200
review_count: int = 0,
201-
review_result: bool = False,
201+
review_result: bool = None,
202202
check_tree_state: bool = True,
203203
check_duplicate: bool = True,
204204
) -> Message:

backend/oasst_backend/tree_manager.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,8 @@ async def handle_interaction(self, interaction: protocol_schema.AnyInteraction)
742742
f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
743743
)
744744
else:
745+
msg.review_result = False
746+
self.db.add(msg)
745747
self.enter_low_grade_state(msg.message_tree_id)
746748
self.check_condition_for_prompt_lottery(msg.message_tree_id)
747749
elif msg.review_count >= self.cfg.num_reviews_reply:
@@ -751,6 +753,9 @@ async def handle_interaction(self, interaction: protocol_schema.AnyInteraction)
751753
logger.info(
752754
f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
753755
)
756+
else:
757+
msg.review_result = False
758+
self.db.add(msg)
754759

755760
self.check_condition_for_ranking_state(msg.message_tree_id)
756761

@@ -965,7 +970,7 @@ def _query_need_review(
965970
.filter(
966971
MessageTreeState.active,
967972
MessageTreeState.state == state,
968-
not_(Message.review_result),
973+
or_(Message.review_result.is_(None), not_(Message.review_result)),
969974
not_(Message.deleted),
970975
Message.review_count < required_reviews,
971976
Message.lang == lang,
@@ -1183,7 +1188,10 @@ def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
11831188
MessageTreeState.goal_tree_size.label("goal_tree_size"),
11841189
func.count(Message.id).filter(Message.review_result).label("tree_size"),
11851190
func.count(Message.id)
1186-
.filter(not_(Message.review_result), Message.review_count < required_reviews)
1191+
.filter(
1192+
or_(Message.review_result.is_(None), not_(Message.review_result)),
1193+
Message.review_count < required_reviews,
1194+
)
11871195
.label("awaiting_review"),
11881196
)
11891197
.select_from(MessageTreeState)

0 commit comments

Comments
 (0)