Skip to content

Commit eda275b

Browse files
authored
fix export & add emojis (LAION-AI#1004)
1 parent 29b540a commit eda275b

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

backend/main.py

+15-6
Original file line number | Diff line number | Diff line change
@@ -318,34 +318,43 @@ def main():
318318

319319
parser.add_argument(
320320
"--print-openapi-schema",
321+
default=False,
321322
help="Dumps the openapi schema to stdout",
322-
action=argparse.BooleanOptionalAction,
323+
action="store_true",
323324
)
324325
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
325326
parser.add_argument("--port", help="The port to run the server", default=8080)
326327
parser.add_argument(
327-
"--export", help="Export all trees which are ready for exporting.", action=argparse.BooleanOptionalAction
328+
"--export",
329+
default=False,
330+
help="Export all trees which are ready for exporting.",
331+
action="store_true",
328332
)
329333
parser.add_argument(
330334
"--export-file",
335+
type=str,
331336
help="Name of file to export trees to. If not provided when exporting, output will be send to STDOUT",
332337
)
333338
parser.add_argument(
334339
"--retry-scoring",
340+
default=False,
335341
help="Retry scoring failed message trees",
336-
action=argparse.BooleanOptionalAction,
342+
action="store_true",
337343
)
338344

339345
args = parser.parse_args()
340346

341347
if args.print_openapi_schema:
342348
print(get_openapi_schema())
343-
elif args.export:
349+
350+
if args.export:
344351
use_compression: bool = ".gz" in args.export_file
345352
export_ready_trees(file=args.export_file, use_compression=use_compression)
346-
elif args.retry_scoring:
353+
354+
if args.retry_scoring:
347355
retry_scoring_failed_message_trees()
348-
else:
356+
357+
if not (args.export or args.print_openapi_schema or args.retry_scoring):
349358
uvicorn.run(app, host=args.host, port=args.port)
350359

351360

backend/oasst_backend/utils/tree_export.py

+16-17
Original file line number | Diff line number | Diff line change
@@ -20,11 +20,12 @@ class ExportMessageNode(BaseModel):
2020
rank: int | None
2121
synthetic: bool | None
2222
model_name: str | None
23+
emojis: dict[str, int] | None
2324
replies: list[ExportMessageNode] | None
2425

25-
@classmethod
26-
def prep_message_export(cls, message: Message) -> ExportMessageNode:
27-
return cls(
26+
@staticmethod
27+
def prep_message_export(message: Message) -> ExportMessageNode:
28+
return ExportMessageNode(
2829
message_id=str(message.id),
2930
parent_id=str(message.parent_id) if message.parent_id else None,
3031
text=str(message.payload.payload.text),
@@ -33,6 +34,7 @@ def prep_message_export(cls, message: Message) -> ExportMessageNode:
3334
review_count=message.review_count,
3435
synthetic=message.synthetic,
3536
model_name=message.model_name,
37+
emojis=message.emojis,
3638
rank=message.rank,
3739
)
3840

@@ -43,23 +45,20 @@ class ExportMessageTree(BaseModel):
4345

4446

4547
def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree:
46-
export_tree = ExportMessageTree(message_tree_id=str(message_tree_id))
47-
export_tree_data = [ExportMessageNode.prep_message_export(m) for m in messages]
48+
export_messages = [ExportMessageNode.prep_message_export(m) for m in messages]
4849

49-
message_parents = defaultdict(list)
50-
for message in export_tree_data:
51-
message_parents[message.parent_id].append(message)
50+
messages_by_parent = defaultdict(list)
51+
for message in export_messages:
52+
messages_by_parent[message.parent_id].append(message)
5253

53-
def build_tree(tree: dict, parent: Optional[str], messages: list[Message]):
54-
children = message_parents[parent]
55-
tree.replies = children
54+
def assign_replies(node: ExportMessageNode) -> ExportMessageNode:
55+
node.replies = messages_by_parent[node.message_id]
56+
for child in node.replies:
57+
assign_replies(child)
58+
return node
5659

57-
for idx, child in enumerate(tree.replies):
58-
build_tree(tree.replies[idx], child.message_id, messages)
59-
60-
build_tree(export_tree, None, export_tree_data)
61-
62-
return export_tree
60+
prompt = assign_replies(messages_by_parent[None][0])
61+
return ExportMessageTree(message_tree_id=str(message_tree_id), prompt=prompt)
6362

6463

6564
def write_trees_to_file(file, trees: list[ExportMessageTree], use_compression: bool = True) -> None:

0 commit comments

Comments
 (0)