Skip to content

Commit 032a748

Browse files
677 - Add tree message export (LAION-AI#808)
* Added - Basic functions to export trees for users, export-ready trees and specific tree ids to files * Added print to logger by default for no file specified * linting to remove extra imports * Added cli for exporting trees which are ready to export Fixed some accidental removal Updated message lookup to use dict for better perf * removed unused imports * changed export flag for including deleted prompts back to include_deleted for better understandability * Use native collection types list, tuple, dict * pre-commit fix Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
1 parent ffaf5c4 commit 032a748

File tree

4 files changed

+170
-13
lines changed

4 files changed

+170
-13
lines changed

backend/main.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,24 @@ def get_openapi_schema():
273273
return json.dumps(app.openapi())
274274

275275

276+
def export_ready_trees(file: Optional[str] = None, use_compression: bool = False):
    """Export every message tree that is ready for export.

    Opens a database session, authenticates with the official web API key,
    wires up the repositories on behalf of a placeholder user and hands the
    actual export over to ``TreeManager.export_all_ready_trees``.

    Args:
        file: Export destination, forwarded unchanged to the tree manager
            (``None`` is forwarded as-is).
        use_compression: Forwarded unchanged to the tree manager.

    Any exception raised along the way is logged and not re-raised.
    """
    try:
        with Session(engine) as db:
            api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
            dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")

            # Every repository is bound to the same session and API client.
            shared = {"db": db, "api_client": api_client}
            user_repo = UserRepository(**shared)
            task_repo = TaskRepository(client_user=dummy_user, user_repository=user_repo, **shared)
            prompt_repo = PromptRepository(
                client_user=dummy_user, user_repository=user_repo, task_repository=task_repo, **shared
            )

            TreeManager(db, prompt_repo).export_all_ready_trees(file, use_compression=use_compression)
    except Exception:
        logger.exception("Error exporting trees.")
292+
293+
276294
def main():
277295
# Importing here so we don't import packages unnecessarily if we're
278296
# importing main as a module.
@@ -289,11 +307,21 @@ def main():
289307
)
290308
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
291309
parser.add_argument("--port", help="The port to run the server", default=8080)
310+
parser.add_argument(
311+
"--export", help="Export all trees which are ready for exporting.", action=argparse.BooleanOptionalAction
312+
)
313+
parser.add_argument(
314+
"--export-file",
315+
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
316+
)
292317

293318
args = parser.parse_args()
294319

295320
if args.print_openapi_schema:
296321
print(get_openapi_schema())
322+
elif args.export:
323+
use_compression: bool = ".gz" in args.export_file
324+
export_ready_trees(file=args.export_file, use_compression=use_compression)
297325
else:
298326
uvicorn.run(app, host=args.host, port=args.port)
299327

backend/oasst_backend/prompt_repository.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import defaultdict
33
from datetime import datetime
44
from http import HTTPStatus
5-
from typing import List, Optional, Tuple
5+
from typing import Optional
66
from uuid import UUID, uuid4
77

88
import oasst_backend.models.db_payload as db_payload
@@ -255,7 +255,7 @@ def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction
255255
return reaction
256256

257257
@managed_tx_method(CommitMode.COMMIT)
258-
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]:
258+
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> tuple[MessageReaction, Task]:
259259
# fetch task
260260
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
261261
self._validate_task(task, frontend_message_id=ranking.message_id)
@@ -345,13 +345,13 @@ def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str
345345
return message_toxicity
346346

347347
@managed_tx_method(CommitMode.FLUSH)
348-
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
348+
def insert_message_embedding(self, message_id: UUID, model: str, embedding: list[float]) -> MessageEmbedding:
349349
"""Insert the embedding of a new message in the database.
350350
351351
Args:
352352
message_id (UUID): the identifier of the message we want to save its embedding
353353
model (str): the model used for creating the embedding
354-
embedding (List[float]): the values obtained from the message & model
354+
embedding (list[float]): the values obtained from the message & model
355355
356356
Raises:
357357
OasstError: if misses some of the before params
@@ -383,7 +383,7 @@ def insert_reaction(
383383
return reaction
384384

385385
@managed_tx_method(CommitMode.FLUSH)
386-
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]:
386+
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> tuple[TextLabels, Task, Message]:
387387

388388
valid_labels: Optional[list[str]] = None
389389
mandatory_labels: Optional[list[str]] = None
@@ -529,6 +529,22 @@ def fetch_message_tree(
529529
qry = qry.filter(not_(Message.deleted))
530530
return qry.all()
531531

532+
def fetch_user_message_trees(
    self, user_id: UUID, reviewed: bool = True, include_deleted: bool = False
) -> list[Message]:
    """Fetch the messages authored by a user.

    Despite the name, this returns ``Message`` rows rather than trees;
    callers read ``message_tree_id`` from each row to find the trees the
    user contributed to.

    Args:
        user_id: Value compared against ``Message.user_id``.
        reviewed: When true, keep only rows whose ``review_result`` is truthy.
        include_deleted: When false, rows flagged as deleted are filtered out.

    Returns:
        All matching messages.
    """
    qry = self.db.query(Message).filter(Message.user_id == user_id)
    if reviewed:
        qry = qry.filter(Message.review_result)
    if not include_deleted:
        qry = qry.filter(not_(Message.deleted))
    return qry.all()
541+
542+
def fetch_message_trees_ready_for_export(self) -> list[MessageTreeState]:
    """Return every tree-state row whose state is ``READY_FOR_EXPORT``."""
    ready = message_tree_state.State.READY_FOR_EXPORT
    return self.db.query(MessageTreeState).filter(MessageTreeState.state == ready).all()
547+
532548
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
533549
"""
534550
Fetch a conversation with multiple possible replies to it.

backend/oasst_backend/tree_manager.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import json
12
import random
3+
import sys
24
from datetime import datetime
35
from enum import Enum
46
from http import HTTPStatus
@@ -7,11 +9,13 @@
79

810
import numpy as np
911
import pydantic
12+
from fastapi.encoders import jsonable_encoder
1013
from loguru import logger
1114
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
1215
from oasst_backend.config import TreeManagerConfiguration, settings
1316
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
1417
from oasst_backend.prompt_repository import PromptRepository
18+
from oasst_backend.utils import tree_export
1519
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
1620
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
1721
from oasst_backend.utils.ranking import ranked_pairs
@@ -1184,14 +1188,55 @@ def purge_user(self, user_id: UUID, ban: bool = True) -> None:
11841188
if ban:
11851189
self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False))
11861190

1191+
def export_trees_to_file(
    self,
    message_tree_ids: list[str],
    file=None,
    reviewed: bool = True,
    include_deleted: bool = False,
    use_compression: bool = False,
) -> None:
    """Export the given message trees to a file or to standard output.

    Args:
        message_tree_ids: Ids of the trees to export, in output order.
        file: Destination handed to ``tree_export.write_trees_to_file``.
            When falsy, the trees are written to stdout as one indented
            JSON array instead.
        reviewed: Forwarded to ``fetch_message_tree`` for each tree.
        include_deleted: Forwarded to ``fetch_message_tree`` for each tree.
        use_compression: Forwarded to ``tree_export.write_trees_to_file``;
            ignored when writing to stdout.
    """
    trees_to_export: list[tree_export.ExportMessageTree] = [
        tree_export.build_export_tree(
            message_tree_id,
            self.pr.fetch_message_tree(message_tree_id, reviewed, include_deleted),
        )
        for message_tree_id in message_tree_ids
    ]

    if file:
        tree_export.write_trees_to_file(file, trees_to_export, use_compression)
    else:
        sys.stdout.write(json.dumps(jsonable_encoder(trees_to_export), indent=2))
1209+
1210+
def export_all_ready_trees(
    self, file: str, reviewed: bool = True, include_deleted: bool = False, use_compression: bool = False
) -> None:
    """Export every tree reported as ready for export.

    Args:
        file: Destination forwarded to ``export_trees_to_file``. A falsy
            value (for example ``None``) is forwarded as-is.
        reviewed: Forwarded to ``export_trees_to_file``.
        include_deleted: Forwarded to ``export_trees_to_file``.
        use_compression: Forwarded to ``export_trees_to_file``.
    """
    # The repository returns a list of state rows, one entry per tree id used below.
    message_tree_states: list[MessageTreeState] = self.pr.fetch_message_trees_ready_for_export()
    message_tree_ids = [state.message_tree_id for state in message_tree_states]
    self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
1216+
1217+
def export_all_user_trees(
    self,
    user_id: str,
    file: str,
    reviewed: bool = True,
    include_deleted: bool = False,
    use_compression: bool = False,
) -> None:
    """Export every tree that contains at least one message by the user.

    Args:
        user_id: String form of the user's UUID.
        file: Destination forwarded to ``export_trees_to_file``.
        reviewed: Applied both when selecting the user's messages and when
            fetching each tree.
        include_deleted: Applied both when selecting the user's messages and
            when fetching each tree.
        use_compression: Forwarded to ``export_trees_to_file``.

    Raises:
        ValueError: If ``user_id`` is not a valid UUID string.
    """
    messages = self.pr.fetch_user_message_trees(
        UUID(user_id), reviewed=reviewed, include_deleted=include_deleted
    )
    # A user usually has several messages in the same tree; dict.fromkeys
    # drops the repeated ids while keeping first-seen order, so each tree is
    # exported exactly once.
    message_tree_ids = list(dict.fromkeys(message.message_tree_id for message in messages))
    self.export_trees_to_file(message_tree_ids, file, reviewed, include_deleted, use_compression)
1228+
11871229

11881230
if __name__ == "__main__":
11891231
from oasst_backend.api.deps import api_auth
1232+
1233+
# from oasst_backend.api.deps import create_api_client
11901234
from oasst_backend.database import engine
11911235
from oasst_backend.prompt_repository import PromptRepository
11921236

11931237
with Session(engine) as db:
11941238
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
1239+
# api_client = create_api_client(session=db, description="test", frontend_type="bot")
11951240
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
11961241

11971242
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
@@ -1200,25 +1245,22 @@ def purge_user(self, user_id: UUID, ban: bool = True) -> None:
12001245
tm = TreeManager(db, pr, cfg)
12011246
tm.ensure_tree_states()
12021247

1203-
tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
1248+
# tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
12041249
# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
12051250
# db.commit()
12061251

12071252
# print("query_num_active_trees", tm.query_num_active_trees())
12081253
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
12091254
# print("query_replies_need_review", tm.query_replies_need_review())
1255+
# print("query_incomplete_reply_reviews", tm.query_replies_need_review())
12101256
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
12111257
# print("query_extendible_trees", tm.query_extendible_trees())
12121258
# print("query_extendible_parents", tm.query_extendible_parents())
1213-
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
1214-
1215-
# print(
1216-
# "query_reviews_for_message",
1217-
# tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
1218-
# )
12191259

12201260
# print("next_task:", tm.next_task())
12211261

12221262
# print(
1223-
# "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))
1263+
# ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
12241264
# )
1265+
1266+
print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl"))
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from __future__ import annotations
2+
3+
import gzip
4+
import json
5+
from collections import defaultdict
6+
from typing import Optional, TextIO
7+
8+
from fastapi.encoders import jsonable_encoder
9+
from oasst_backend.models import Message
10+
from pydantic import BaseModel
11+
12+
13+
class ExportMessageNode(BaseModel):
    """One message in an exported tree, with its replies nested below it."""

    message_id: str
    parent_id: Optional[str]
    text: Optional[str]
    role: str
    review_count: Optional[int]
    rank: Optional[int]
    replies: Optional[list[ExportMessageNode]]

    @classmethod
    def prep_message_export(cls, message: Message) -> ExportMessageNode:
        """Build an export node from a database message.

        Ids are converted to strings; a falsy ``parent_id`` becomes ``None``.
        ``replies`` is not set here.
        """
        parent = message.parent_id
        values = {
            "message_id": str(message.id),
            "parent_id": str(parent) if parent else None,
            "text": str(message.payload.payload.text),
            "role": message.role,
            "review_count": message.review_count,
            "rank": message.rank,
        }
        return cls(**values)
32+
33+
34+
class ExportMessageTree(BaseModel):
    """Top-level export record for one message tree."""

    message_tree_id: str
    # Root-level nodes of the tree. build_export_tree assigns a list here, so
    # the field is declared as a list rather than a single node.
    replies: Optional[list[ExportMessageNode]]
37+
38+
39+
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
    """Assemble a flat list of messages into a nested export tree.

    Args:
        message_tree_id: Id of the tree; stored on the result as a string.
        messages: All messages to place into the tree.

    Returns:
        An ``ExportMessageTree`` whose ``replies`` holds the nodes without a
        parent, each carrying its own ``replies`` recursively. Leaf nodes get
        an empty list. Nodes that cannot be reached from a parentless node
        (their parent is not among ``messages``) are not attached.
    """
    export_tree = ExportMessageTree(message_tree_id=str(message_tree_id))

    # Group nodes by parent id once so each lookup below is a dict access.
    children_by_parent: defaultdict[Optional[str], list[ExportMessageNode]] = defaultdict(list)
    for message in messages:
        node = ExportMessageNode.prep_message_export(message)
        children_by_parent[node.parent_id].append(node)

    # Explicit stack instead of recursion: depth is not bounded by the
    # interpreter's recursion limit. Sibling order is unaffected because each
    # child list is assigned as a whole.
    pending = [(export_tree, None)]
    while pending:
        parent_node, parent_id = pending.pop()
        parent_node.replies = children_by_parent[parent_id]
        pending.extend((child, child.message_id) for child in parent_node.replies)

    return export_tree
57+
58+
59+
def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None:
    """Write the trees to ``file`` as UTF-8 text, one JSON document per line.

    Args:
        file: Destination passed to ``gzip.open`` or ``open``.
        trees: Trees to serialise, written in the given order.
        use_compression: When true the output is gzip-compressed.
    """
    # Both openers accept the same mode and encoding arguments.
    opener = gzip.open if use_compression else open
    out_buff: TextIO = opener(file, "wt", encoding="UTF-8")

    with out_buff as f:
        for tree in trees:
            json.dump(jsonable_encoder(tree), f)
            f.write("\n")

0 commit comments

Comments
 (0)