Skip to content

Commit 14fa08e

Browse files
Message tree state machine (LAION-AI#555)
* add query_incomplete_rankings() * Add SQL queries for TreeManager task selection * first working version of TreeManager.next_task() * remove old generate_task(), add mandatory_labels to text_labels task * Add ConversationMessage list to Ranking tasks * add more sophisticated sql queries to find extendible trees * add TreeManager.query_extendible_parents() * fix task validation, seed data insertion (reviewed) * provide user for task selection in text-frontend * enter 'growing' state * enter 'aborted_low_grade' state * enter 'ranking' state * check tree 'growing' state upon relpy insertion * exclude user from labeling their own messages (added DEBUG_ALLOW_SELF_LABELING setting) * add DEBUG_ALLOW_SELF_LABELING to docker-compose.yaml * fix ranking submission * add query_tree_ranking_results() * add ranked_message_ids to RankingReactionPayload * fix reply_messages instead of prompt_messages * incorment 'ranking_count' of ranked replies * added logic to check_condition_for_scoring_state * changes to msg_tree_state_machine * pre-commit changes * enter 'ready_for_scoring' state * re-add HF embedding call (lost during merge) * use prepare_conversation() helper for seed-data creation * Partially add user specified task selection Co-authored-by: Daniel Hug <danielpatrickhug@gmail.com>
1 parent 23ff01c commit 14fa08e

19 files changed

+1209
-320
lines changed

ansible/dev.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
REDIS_HOST: oasst-redis
8080
DEBUG_ALLOW_ANY_API_KEY: "true"
8181
DEBUG_USE_SEED_DATA: "true"
82+
DEBUG_ALLOW_SELF_LABELING: "true"
8283
MAX_WORKERS: "1"
8384
RATE_LIMIT: "false"
8485
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""restructure message_tree_state table
2+
3+
Revision ID: 92a367bb9f40
4+
Revises: ba61fe17fb6e
5+
Create Date: 2023-01-08 22:08:46.458195
6+
7+
"""
8+
import sqlalchemy as sa
9+
import sqlmodel
10+
from alembic import op
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "92a367bb9f40"
15+
down_revision = "aac6b2f66006"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.drop_table("message_tree_state")
23+
op.create_table(
24+
"message_tree_state",
25+
sa.Column("message_tree_id", postgresql.UUID(as_uuid=True), nullable=False),
26+
sa.Column("goal_tree_size", sa.Integer(), nullable=False),
27+
sa.Column("max_depth", sa.Integer(), nullable=False),
28+
sa.Column("max_children_count", sa.Integer(), nullable=False),
29+
sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
30+
sa.Column("active", sa.Boolean(), nullable=False),
31+
sa.Column("accepted_messages", sa.Integer(), nullable=False),
32+
sa.ForeignKeyConstraint(
33+
["message_tree_id"],
34+
["message.id"],
35+
),
36+
sa.PrimaryKeyConstraint("message_tree_id"),
37+
)
38+
op.create_index(op.f("ix_message_tree_state_active"), "message_tree_state", ["active"], unique=False)
39+
op.create_index(op.f("ix_message_tree_state_state"), "message_tree_state", ["state"], unique=False)
40+
41+
# ### end Alembic commands ###
42+
43+
44+
def downgrade() -> None:
45+
# ### commands auto generated by Alembic - please adjust! ###
46+
op.drop_index(op.f("ix_message_tree_state_state"), table_name="message_tree_state")
47+
op.drop_index(op.f("ix_message_tree_state_active"), table_name="message_tree_state")
48+
op.drop_table("message_tree_state")
49+
op.create_table(
50+
"message_tree_state",
51+
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
52+
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
53+
sa.Column("state", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
54+
sa.Column("goal_tree_size", sa.Integer(), nullable=False),
55+
sa.Column("current_num_non_filtered_messages", sa.Integer(), nullable=False),
56+
sa.Column("max_depth", sa.Integer(), nullable=False),
57+
sa.PrimaryKeyConstraint("id"),
58+
)
59+
op.create_index(
60+
op.f("ix_message_tree_state_message_tree_id"), "message_tree_state", ["message_tree_id"], unique=False
61+
)
62+
op.create_index("ix_message_tree_state_tree_id", "message_tree_state", ["message_tree_id"], unique=True)
63+
# ### end Alembic commands ###
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""add review_count & ranking_count to message
2+
3+
Revision ID: 05975b274a81
4+
Revises: 92a367bb9f40
5+
Create Date: 2023-01-09 00:47:25.496036
6+
7+
"""
8+
import sqlalchemy as sa
9+
from alembic import op
10+
11+
# revision identifiers, used by Alembic.
12+
revision = "05975b274a81"
13+
down_revision = "92a367bb9f40"
14+
branch_labels = None
15+
depends_on = None
16+
17+
18+
def upgrade() -> None:
19+
# ### commands auto generated by Alembic - please adjust! ###
20+
op.add_column("message", sa.Column("review_count", sa.Integer(), server_default=sa.text("0"), nullable=False))
21+
op.add_column("message", sa.Column("review_result", sa.Boolean(), server_default=sa.text("false"), nullable=False))
22+
op.add_column("message", sa.Column("ranking_count", sa.Integer(), server_default=sa.text("0"), nullable=False))
23+
# ### end Alembic commands ###
24+
25+
26+
def downgrade() -> None:
27+
# ### commands auto generated by Alembic - please adjust! ###
28+
op.drop_column("message", "ranking_count")
29+
op.drop_column("message", "review_result")
30+
op.drop_column("message", "review_count")
31+
# ### end Alembic commands ###

backend/main.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
from loguru import logger
1313
from oasst_backend.api.deps import get_dummy_api_client
1414
from oasst_backend.api.v1.api import api_router
15+
from oasst_backend.api.v1.utils import prepare_conversation
1516
from oasst_backend.config import settings
1617
from oasst_backend.database import engine
18+
from oasst_backend.models import message_tree_state
1719
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
20+
from oasst_backend.tree_manager import TreeManager, TreeManagerConfiguration
1821
from oasst_shared.exceptions import OasstError, OasstErrorCode
1922
from oasst_shared.schemas import protocol as protocol_schema
2023
from pydantic import BaseModel
@@ -116,6 +119,7 @@ class DummyMessage(BaseModel):
116119
pr = PromptRepository(
117120
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
118121
)
122+
tm = TreeManager(db, pr, TreeManagerConfiguration())
119123

120124
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
121125
dummy_messages_raw = json.load(f)
@@ -138,24 +142,19 @@ class DummyMessage(BaseModel):
138142
msg.parent_message_id, fail_if_missing=True
139143
)
140144
conversation_messages = pr.fetch_message_conversation(parent_message)
141-
conversation = protocol_schema.Conversation(
142-
messages=[
143-
protocol_schema.ConversationMessage(
144-
text=cmsg.text,
145-
is_assistant=cmsg.role == "assistant",
146-
message_id=cmsg.id,
147-
fronend_message_id=cmsg.frontend_message_id,
148-
)
149-
for cmsg in conversation_messages
150-
]
151-
)
145+
conversation = prepare_conversation(conversation_messages)
152146
task = tr.store_task(
153147
protocol_schema.AssistantReplyTask(conversation=conversation),
154148
message_tree_id=parent_message.message_tree_id,
155149
parent_message_id=parent_message.id,
156150
)
157151
tr.bind_frontend_message_id(task.id, msg.task_message_id)
158-
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
152+
message = pr.store_text_reply(
153+
msg.text, msg.task_message_id, msg.user_message_id, review_count=5, review_result=True
154+
)
155+
if message.parent_id is None:
156+
tm._insert_default_state(root_message_id=message.id, state=message_tree_state.State.GROWING)
157+
db.commit()
159158

160159
logger.info(
161160
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
@@ -168,14 +167,27 @@ class DummyMessage(BaseModel):
168167
logger.exception("Seed data insertion failed")
169168

170169

170+
@app.on_event("startup")
171+
def ensure_tree_states():
172+
try:
173+
logger.info("Startup: TreeManager.ensure_tree_states()")
174+
cfg = TreeManagerConfiguration() # TODO: decide where config is stored, e.g. load form json/yaml file
175+
with Session(engine) as db:
176+
tm = TreeManager(db, None, configuration=cfg)
177+
tm.ensure_tree_states()
178+
179+
except Exception:
180+
logger.exception("TreeManager.ensure_tree_states() failed.")
181+
182+
171183
app.include_router(api_router, prefix=settings.API_V1_STR)
172184

173185

174186
def get_openapi_schema():
175187
return json.dumps(app.openapi())
176188

177189

178-
if __name__ == "__main__":
190+
def main():
179191
# Importing here so we don't import packages unnecessarily if we're
180192
# importing main as a module.
181193
import argparse
@@ -198,3 +210,7 @@ def get_openapi_schema():
198210
print(get_openapi_schema())
199211
else:
200212
uvicorn.run(app, host=args.host, port=args.port)
213+
214+
215+
if __name__ == "__main__":
216+
main()

0 commit comments

Comments
 (0)