Skip to content

Commit c8d1628

Browse files
authored
Import message trees from jsonl file (LAION-AI#964)
* add new backlog_ranking tree state * add first version of import script * allow activation of trees during import * add min_active_rankings_per_lang config param * add settings docstring
1 parent b2eb949 commit c8d1628

12 files changed

+377
-49
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""add origin column to message_tree_state
2+
3+
Revision ID: 49d8445b4c90
4+
Revises: f856bf19d32b
5+
Create Date: 2023-01-28 11:57:45.580027
6+
7+
"""
8+
import sqlalchemy as sa
9+
from alembic import op
10+
11+
# revision identifiers, used by Alembic.
12+
revision = "49d8445b4c90"
13+
down_revision = "f856bf19d32b"
14+
branch_labels = None
15+
depends_on = None
16+
17+
18+
def upgrade() -> None:
    """Apply this revision.

    Adds two columns to ``message`` (``synthetic``, ``model_name``) and one
    column to ``message_tree_state`` (``origin``).
    """
    # ``synthetic`` is NOT NULL, so a server-side default of false is supplied
    # for rows that already exist.
    synthetic_column = sa.Column("synthetic", sa.Boolean(), server_default=sa.text("false"), nullable=False)
    model_name_column = sa.Column("model_name", sa.String(length=1024), nullable=True)
    origin_column = sa.Column("origin", sa.String(length=1024), nullable=True)

    op.add_column("message", synthetic_column)
    op.add_column("message", model_name_column)
    op.add_column("message_tree_state", origin_column)
24+
25+
26+
def downgrade() -> None:
    """Revert this revision by dropping the columns added in ``upgrade``.

    Columns are removed in the reverse order of their creation.
    """
    columns_to_drop = (
        ("message_tree_state", "origin"),
        ("message", "model_name"),
        ("message", "synthetic"),
    )
    for table_name, column_name in columns_to_drop:
        op.drop_column(table_name, column_name)

backend/import.py

+187
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import argparse
2+
import json
3+
from pathlib import Path
4+
from typing import Optional
5+
from uuid import UUID
6+
7+
import oasst_backend.models.db_payload as db_payload
8+
import oasst_backend.utils.database_utils as db_utils
9+
import pydantic
10+
from loguru import logger
11+
from oasst_backend.api.deps import create_api_client
12+
from oasst_backend.models import ApiClient, Message
13+
from oasst_backend.models.message_tree_state import MessageTreeState
14+
from oasst_backend.models.message_tree_state import State as TreeState
15+
from oasst_backend.models.payload_column_type import PayloadContainer
16+
from oasst_backend.prompt_repository import PromptRepository
17+
from oasst_backend.user_repository import UserRepository
18+
from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree
19+
from sqlmodel import Session
20+
21+
# well known id
22+
IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363")
23+
24+
25+
class Importer:
    """Writes exported message trees into the database.

    All imported messages are attributed to a dedicated API client (looked up
    by the well-known ``IMPORT_API_CLIENT_ID`` and created on first use) and to
    the system user named ``"import"``.
    """

    def __init__(self, db: Session, origin: str, model_name: Optional[str] = None):
        """Prepare the importer.

        Args:
            db: Session used for all queries and inserts.
            origin: Value stored in ``MessageTreeState.origin`` for every imported tree.
            model_name: Fallback model name for messages that do not carry one.
        """
        self.db = db
        self.origin = origin
        self.model_name = model_name

        # Get (or lazily create) the import API client.
        # NOTE(review): create_api_client appears to commit the session itself, so a
        # newly created client would persist even when the surrounding import
        # transaction is rolled back (dry run) - confirm against deps.create_api_client.
        api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first()
        if not api_client:
            api_client = create_api_client(
                session=db,
                description="API client used for importing data",
                frontend_type="import",
                force_id=IMPORT_API_CLIENT_ID,
            )

        ur = UserRepository(db, api_client)
        self.import_user = ur.lookup_system_user(username="import")
        self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur)
        self.api_client = api_client

    def fetch_message_tree_state(self, message_tree_id: UUID) -> Optional[MessageTreeState]:
        """Return the stored state row for ``message_tree_id``, or ``None`` if there is none."""
        return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none()

    def import_message(
        self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None
    ) -> Message:
        """Insert ``message`` and, recursively, all of its replies.

        Args:
            message: Exported node to store; its ``message_id`` is reused as the row id.
            message_tree_id: Id of the tree the message belongs to.
            parent_id: Id of the parent message, or ``None`` for the root of a tree.

        Returns:
            The stored ``Message``, refreshed from the database.
        """
        payload = db_payload.MessagePayload(text=message.text)
        msg = Message(
            id=message.message_id,
            message_tree_id=message_tree_id,
            frontend_message_id=message.message_id,
            parent_id=parent_id,
            review_count=message.review_count or 0,
            lang=message.lang or "en",
            review_result=True,
            # Messages without an explicit flag are treated as synthetic.
            synthetic=message.synthetic if message.synthetic is not None else True,
            model_name=message.model_name or self.model_name,
            role=message.role,
            api_client_id=self.api_client.id,
            payload_type=type(payload).__name__,
            payload=PayloadContainer(payload=payload),
            user_id=self.import_user.id,
        )
        self.db.add(msg)
        if message.replies:
            for r in message.replies:
                self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id)
        self.db.flush()
        if parent_id is None:
            # Only the root triggers the recount, after the whole subtree was flushed.
            # The root's id is passed as the tree id.
            self.pr.update_children_counts(msg.id)
        self.db.refresh(msg)
        return msg

    def import_tree(
        self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING
    ) -> tuple[MessageTreeState, Message]:
        """Insert a complete tree together with its ``MessageTreeState`` row.

        Args:
            tree: Exported tree; its ``message_tree_id`` must equal the id of its prompt.
            state: Initial tree state. Only ``BACKLOG_RANKING`` (inactive) and
                ``RANKING`` (active) are supported.

        Returns:
            A ``(message_tree_state, root_message)`` tuple.

        Raises:
            ValueError: If the tree id is missing or differs from the prompt id, or
                if ``state`` is not supported for import.
        """
        # Validate everything up front so that nothing is added to the session
        # for a tree that is going to be rejected.
        if tree.message_tree_id is None or tree.message_tree_id != tree.prompt.message_id:
            raise ValueError(
                f"message_tree_id ({tree.message_tree_id}) must be set and equal to the prompt's message_id"
            )
        if state not in (TreeState.BACKLOG_RANKING, TreeState.RANKING):
            raise ValueError(f"{state} not supported for import")

        root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id)
        active = state == TreeState.RANKING
        mts = MessageTreeState(
            message_tree_id=root_msg.id,
            goal_tree_size=0,
            max_depth=0,
            max_children_count=0,
            state=state,
            origin=self.origin,
            active=active,
        )
        self.db.add(mts)
        return mts, root_msg
97+
98+
99+
def import_file(
    input_file_path: Path,
    origin: str,
    *,
    model_name: Optional[str] = None,
    num_activate: int = 0,
    max_count: Optional[int] = None,
    dry_run: bool = False,
) -> int:
    """Import message trees from a JSON-lines file inside a single transaction.

    Each non-empty line must hold one JSON object that validates as an
    ``ExportMessageTree``. Trees whose id already has a ``MessageTreeState``
    row are skipped.

    Args:
        input_file_path: Path of the UTF-8 encoded ``.jsonl`` file to read.
        origin: Value recorded as origin of every imported tree.
        model_name: Fallback model name for messages that do not specify one.
        num_activate: The first ``num_activate`` imported trees are created in
            ``RANKING`` state; all later ones in ``BACKLOG_RANKING``.
        max_count: Stop once this many trees were imported (``None`` or ``0``
            means no limit).
        dry_run: When true the transaction is rolled back instead of committed.

    Returns:
        Number of trees imported (skipped trees are not counted).
    """

    @db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT)
    def import_tx(db: Session) -> int:
        importer = Importer(db, origin=origin, model_name=model_name)
        i = 0
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with input_file_path.open(encoding="utf-8") as file_in:
            # read line tree object
            for line in file_in:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue
                dict_tree = json.loads(line)

                # validate data
                tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree)
                existing_mts = importer.fetch_message_tree_state(tree.message_tree_id)
                if existing_mts:
                    logger.info(f"Skipping existing message tree: {tree.message_tree_id}")
                else:
                    state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING
                    mts, root_msg = importer.import_tree(tree, state=state)
                    i += 1
                    logger.info(
                        f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}"
                    )

                if max_count and i >= max_count:
                    logger.info(f"Reached max count {max_count} of trees to import.")
                    break
        return i

    if dry_run:
        logger.info("DRY RUN with rollback")
    return import_tx()
138+
139+
140+
def parse_args() -> argparse.Namespace:
    """Parse the command line of the import script.

    Returns:
        Namespace with ``input_file_path``, ``origin``, ``model_name``,
        ``num_activate``, ``max_count`` and ``dry_run``.
    """

    def str2bool(v):
        """Convert common textual spellings of true/false to ``bool``."""
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_file_path",
        help="Input file path",
    )
    parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees")
    parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)")
    parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state")
    parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import")
    # nargs="?" with const=True lets a bare `--dry_run` enable the dry run while
    # the explicit `--dry_run true|false` form keeps working.
    parser.add_argument(
        "--dry_run",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Roll back the transaction instead of committing (accepts an optional true/false value)",
    )
    args = parser.parse_args()
    return args
163+
164+
165+
def main():
    """Script entry point: parse arguments, import the file and log a summary.

    Raises:
        SystemExit: With exit code 1 when the input path is not an existing regular file.
    """
    args = parse_args()

    input_file_path = Path(args.input_file_path)
    # is_file() is already False for a missing path, so no separate exists() check is needed.
    if not input_file_path.is_file():
        print("Invalid input file:", args.input_file_path)
        # The builtin exit() is injected by the `site` module for interactive use and
        # may be unavailable; raising SystemExit always works.
        raise SystemExit(1)

    dry_run = args.dry_run
    num_imported = import_file(
        input_file_path,
        # Fall back to the file name when no explicit origin was given.
        origin=args.origin or input_file_path.name,
        model_name=args.model_name,
        num_activate=args.num_activate,
        max_count=args.max_count,
        dry_run=dry_run,
    )

    logger.info(f"Done ({num_imported=}, {dry_run=})")
184+
185+
186+
if __name__ == "__main__":
187+
main()

backend/main.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class DummyMessage(BaseModel):
191191
review_count=5,
192192
review_result=True,
193193
check_tree_state=False,
194+
check_duplicate=False,
194195
)
195196
if message.parent_id is None:
196197
tm._insert_default_state(
@@ -215,7 +216,8 @@ def ensure_tree_states():
215216
try:
216217
logger.info("Startup: TreeManager.ensure_tree_states()")
217218
with Session(engine) as db:
218-
tm = TreeManager(db, None)
219+
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
220+
tm = TreeManager(db, PromptRepository(db, api_client=api_client))
219221
tm.ensure_tree_states()
220222

221223
except Exception:

backend/oasst_backend/api/deps.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from http import HTTPStatus
22
from secrets import token_hex
3-
from typing import Generator, NamedTuple
3+
from typing import Generator, NamedTuple, Optional
4+
from uuid import UUID
45

56
from fastapi import Depends, Request, Response, Security
67
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -67,6 +68,7 @@ def create_api_client(
6768
trusted: bool | None = False,
6869
admin_email: str | None = None,
6970
api_key: str | None = None,
71+
force_id: Optional[UUID] = None,
7072
) -> ApiClient:
7173
if api_key is None:
7274
api_key = token_hex(32)
@@ -79,6 +81,8 @@ def create_api_client(
7981
trusted=trusted,
8082
admin_email=admin_email,
8183
)
84+
if force_id:
85+
api_client.id = force_id
8286
session.add(api_client)
8387
session.commit()
8488
session.refresh(api_client)

backend/oasst_backend/config.py

+9
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ class TreeManagerConfiguration(BaseModel):
4646
num_required_rankings: int = 3
4747
"""Number of rankings in which the message participated."""
4848

49+
p_activate_backlog_tree: float = 0.8
50+
"""Probability to activate a message tree in BACKLOG_RANKING state when another tree enters
51+
a terminal state. Use this setting to control ratio of initial prompts and backlog tree
52+
activations."""
53+
54+
min_active_rankings_per_lang: int = 2
55+
"""When the number of active ranking tasks is below this value when a tree enters a terminal
56+
state an available tree in BACKLOG_RANKING will be activated (i.e. enters the RANKING state)."""
57+
4958
labels_initial_prompt: list[TextLabel] = [
5059
TextLabel.spam,
5160
TextLabel.quality,

backend/oasst_backend/models/message.py

+5
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def __new__(cls, *args: Any, **kwargs: Any):
5757

5858
rank: Optional[int] = Field(nullable=True)
5959

60+
synthetic: Optional[bool] = Field(
61+
sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False)
62+
)
63+
model_name: Optional[str] = Field(sa_column=sa.Column(sa.String(1024), nullable=True))
64+
6065
emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False)
6166
_user_emojis: Optional[list[str]] = PrivateAttr(default=None)
6267

backend/oasst_backend/models/message_tree_state.py

+5
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class State(str, Enum):
4343
HALTED_BY_MODERATOR = "halted_by_moderator"
4444
"""A moderator decided to manually halt the message tree construction process."""
4545

46+
BACKLOG_RANKING = "backlog_ranking"
47+
"""Imported tree ready to be activated and ranked by users (currently inactive)."""
48+
4649

4750
VALID_STATES = (
4851
State.INITIAL_PROMPT_REVIEW,
@@ -51,6 +54,7 @@ class State(str, Enum):
5154
State.READY_FOR_SCORING,
5255
State.READY_FOR_EXPORT,
5356
State.ABORTED_LOW_GRADE,
57+
State.BACKLOG_RANKING,
5458
)
5559

5660
TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR)
@@ -67,3 +71,4 @@ class MessageTreeState(SQLModel, table=True):
6771
max_children_count: int = Field(nullable=False)
6872
state: str = Field(nullable=False, max_length=128, index=True)
6973
active: bool = Field(nullable=False, index=True)
74+
origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True))

backend/oasst_backend/prompt_repository.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def store_text_reply(
177177
review_count: int = 0,
178178
review_result: bool = False,
179179
check_tree_state: bool = True,
180+
check_duplicate: bool = True,
180181
) -> Message:
181182
self.ensure_user_is_enabled()
182183

@@ -199,7 +200,7 @@ def store_text_reply(
199200
logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.")
200201
raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG)
201202

202-
if self.check_users_recent_replies_for_duplicates(text):
203+
if check_duplicate and self.check_users_recent_replies_for_duplicates(text):
203204
raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED)
204205

205206
if task.parent_message_id:
@@ -909,8 +910,7 @@ def update_children_counts(self, message_tree_id: UUID):
909910
) AS cc
910911
WHERE message.id = cc.id;
911912
"""
912-
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
913-
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
913+
self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
914914

915915
@managed_tx_method(CommitMode.COMMIT)
916916
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):

0 commit comments

Comments
 (0)