Skip to content

Commit 510cf72

Browse files
authored
Prepare oasst data export to HuggingFace (LAION-AI#2305)
Add conversion from trees to 'flat' message table jsonl file. - add oasst-data function to read messages & message tree files `read_message_trees`, `read_message_tree_list`, `read_messages`, `read_message_list` (+ lower level functions `open_jsonl_read`, `read_oasst_obj`, `read_oasst_jsonl`) - add oasst-data fucntios to write messages & message tree files `write_message_trees`, `write_messages` (+ lower level functions `open_jsonl_write`, `write_tree`, `write_message`) - add script used for data cleaning (`clean_dataset.py` in examples) - add examples: `filter_trees.py`, `filter_messages.py`, `split_dataset.py`, `tree_to_messages.py` - update pre-commit hook black-jupyter to rev: 23.3.0 and add args: ["--profile", "black", "--filter-files"] to isort config
1 parent b25f8f0 commit 510cf72

File tree

14 files changed

+586
-40
lines changed

14 files changed

+586
-40
lines changed

.pre-commit-config.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ repos:
5858
- id: end-of-file-fixer
5959

6060
- repo: https://github.com/psf/black
61-
rev: 23.1.0
61+
rev: 23.3.0
6262
hooks:
6363
- id: black-jupyter
6464

@@ -71,6 +71,7 @@ repos:
7171
rev: 5.12.0
7272
hooks:
7373
- id: isort
74+
args: ["--profile", "black", "--filter-files"]
7475

7576
- repo: https://github.com/pre-commit/mirrors-prettier
7677
rev: v2.7.1

backend/oasst_backend/utils/tree_export.py

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def prepare_export_message_node(
9292
message_id=message_id,
9393
parent_id=parent_id,
9494
user_id=user_id,
95+
created_date=message.created_date,
9596
text=str(message.payload.payload.text),
9697
role=message.role,
9798
lang=message.lang,

model/model_training/custom_datasets/oasst_dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
22
from typing import Literal, Optional
33

4-
from oasst_data import ExportMessageNode, load_trees, visit_threads_depth_first
4+
from oasst_data import ExportMessageNode, read_message_trees, visit_threads_depth_first
55
from torch import Generator
66
from torch.utils.data import Dataset, random_split
77

@@ -43,7 +43,7 @@ def load_oasst_export(
4343
input_file_path = data_path / input_file_path
4444

4545
threads_per_tree = []
46-
for tree in load_trees(input_file_path):
46+
for tree in read_message_trees(input_file_path):
4747
if tree.tree_state != "ready_for_export" or not tree.prompt.review_result or tree.prompt.lang not in lang_codes:
4848
continue
4949

model/model_training/tools/check_oasst_export.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
22

3-
from oasst_data import ExportMessageTree, load_tree_list, visit_messages_depth_first
3+
from oasst_data import ExportMessageTree, read_message_tree_list, visit_messages_depth_first
44

55

66
def parse_args():
@@ -31,7 +31,7 @@ def tree_filter(tree: ExportMessageTree) -> bool:
3131
and (lang_codes is None or tree.prompt.lang in lang_codes)
3232
)
3333

34-
trees = load_tree_list(args.input_file_path, filter=tree_filter)
34+
trees = read_message_tree_list(args.input_file_path, filter=tree_filter)
3535
print(f"{len(trees)} trees")
3636

3737
all_messages = []

oasst-data/examples/clean_dataset.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import argparse
2+
from collections import OrderedDict
3+
4+
import pandas
5+
from oasst_data.reader import read_message_trees
6+
from oasst_data.schemas import ExportMessageNode, ExportMessageTree
7+
from oasst_data.traversal import visit_messages_depth_first
8+
from oasst_data.writer import write_message_trees
9+
10+
11+
def parse_args():
12+
parser = argparse.ArgumentParser(description="filter_dataset")
13+
parser.add_argument(
14+
"input_file_name",
15+
type=str,
16+
help="path to input .jsonl or .jsonl.gz input file",
17+
)
18+
parser.add_argument(
19+
"output_file_name",
20+
type=str,
21+
help="path to output .jsonl or .jsonl.gz file",
22+
)
23+
parser.add_argument("--instructions", type=str, help="xlsx file with instructions")
24+
parser.add_argument("--exclude-nulls", action="store_true", default=False)
25+
args = parser.parse_args()
26+
return args
27+
28+
29+
def main():
30+
args = parse_args()
31+
32+
instructions_df = pandas.read_excel(args.instructions, na_filter=False)
33+
34+
# load dataset and index messages by id
35+
tree_by_id: dict[str, ExportMessageTree] = OrderedDict()
36+
message_by_id: dict[str, ExportMessageNode] = {}
37+
38+
print(f"Reading: {args.input_file_name}")
39+
for message_tree in read_message_trees(args.input_file_name):
40+
tree_by_id[message_tree.message_tree_id] = message_tree
41+
42+
def index_message(msg: ExportMessageNode):
43+
message_by_id[msg.message_id] = msg
44+
45+
visit_messages_depth_first(message_tree.prompt, index_message)
46+
47+
print(f"Loaded {len(tree_by_id)} trees with {len(message_by_id)} messages.")
48+
49+
def count_descendants(msg: ExportMessageNode):
50+
i = 1
51+
if msg.replies:
52+
for r in msg.replies:
53+
i += count_descendants(r)
54+
return i
55+
56+
def delete_message(msg: ExportMessageNode):
57+
if msg.parent_id is None:
58+
tree_by_id.pop(msg.message_id)
59+
print(f"Tree deleted: {msg.message_id}")
60+
else:
61+
parent_msg = message_by_id[msg.parent_id]
62+
parent_msg.replies.remove(msg)
63+
print(f"Branch deleted: {msg.message_id} ({count_descendants(msg)} messages)")
64+
65+
# cleaning
66+
print("Cleaning...")
67+
for index, row in instructions_df.iterrows():
68+
id = row["UUID"]
69+
msg = message_by_id.get(id)
70+
if msg is None:
71+
print(f"Not found: {id}")
72+
73+
action = row["Action"]
74+
if action == "Delete":
75+
print(f"deleting: {id}")
76+
delete_message(msg)
77+
elif action == "Replace":
78+
print(f"replace: {id}")
79+
replace = row["Replace"]
80+
msg.text = replace
81+
elif action == "Edit":
82+
print(f"edit: {id}")
83+
if row["Category"] == "Copy Code":
84+
find = "\nCopy code\n"
85+
replace = "\n\n"
86+
else:
87+
find = row["Find"]
88+
replace = row["Replace"]
89+
msg.text.index(find) # make sure text is present
90+
msg.text = msg.text.replace(find, replace)
91+
else:
92+
print(f"Unsupported action {action}")
93+
94+
print("Done")
95+
96+
# write cleaned dataset to output file
97+
print(f"Writing: {args.output_file_name}")
98+
write_message_trees(
99+
args.output_file_name,
100+
tree_by_id.values(),
101+
exclude_none=args.exclude_nulls,
102+
)
103+
104+
105+
if __name__ == "__main__":
106+
main()
+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import argparse
2+
import json
3+
4+
from oasst_data import read_message_list, write_messages
5+
from oasst_data.schemas import ExportMessageNode
6+
from oasst_data.writer import open_jsonl_write
7+
8+
9+
def parse_args():
10+
parser = argparse.ArgumentParser(description="filter_messages")
11+
parser.add_argument(
12+
"input_file_name",
13+
type=str,
14+
help="path to input .jsonl or .jsonl.gz input file",
15+
)
16+
parser.add_argument(
17+
"output_file_name",
18+
type=str,
19+
help="path to output .jsonl or .jsonl.gz file",
20+
)
21+
parser.add_argument(
22+
"--include-deleted",
23+
action="store_true",
24+
help="Include deleted messages in export",
25+
)
26+
parser.add_argument(
27+
"--deleted-only",
28+
action="store_true",
29+
help="Export only deleted messages (implies --include-deleted)",
30+
)
31+
parser.add_argument(
32+
"--include-spam",
33+
action="store_true",
34+
help="Export including messages with no review or negative review result.",
35+
)
36+
parser.add_argument(
37+
"--spam-only",
38+
action="store_true",
39+
help="Export only messages with negative review result (implies --include-spam).",
40+
)
41+
parser.add_argument(
42+
"--exclude-normal",
43+
action="store_true",
44+
help="exclude non-deleted non-synthetic messages with positive review",
45+
default=False,
46+
)
47+
parser.add_argument(
48+
"--include-synthetic",
49+
action="store_true",
50+
help="Include synthetic messages in export",
51+
)
52+
parser.add_argument(
53+
"--synthetic-only",
54+
action="store_true",
55+
help="Export only synthetic messages (implies --include-synth)",
56+
)
57+
parser.add_argument(
58+
"--user",
59+
type=str,
60+
help="Only export trees involving the user with the specified ID. Incompatible with --state.",
61+
)
62+
parser.add_argument(
63+
"--state",
64+
type=str,
65+
help="all|prompt_lottery_waiting|growing|ready_for_export|aborted_low_grade|halted_by_moderator|backlog_ranking",
66+
)
67+
parser.add_argument(
68+
"--lang",
69+
type=str,
70+
help="Filter message trees by language code (BCP 47)",
71+
)
72+
parser.add_argument(
73+
"--prompts-only",
74+
action="store_true",
75+
help="Export a list of initial prompt messages",
76+
)
77+
parser.add_argument(
78+
"--export-text-only",
79+
action="store_true",
80+
help="Write jsonl file with message text strings only",
81+
)
82+
parser.add_argument("--exclude-nulls", action="store_true", default=False)
83+
args = parser.parse_args()
84+
return args
85+
86+
87+
def main():
88+
args = parse_args()
89+
90+
deleted: bool | None = False
91+
spam: bool | None = False
92+
synthetic: bool | None = False
93+
langs: list[str] | None = None
94+
states: list[str] | None = None
95+
prompts_only: bool = args.prompts_only
96+
exclude_normal: bool = args.exclude_normal
97+
98+
if args.include_deleted:
99+
deleted = None
100+
elif args.deleted_only:
101+
deleted = True
102+
103+
if args.include_spam:
104+
spam = None
105+
elif args.spam_only:
106+
spam = True
107+
108+
if args.include_synthetic:
109+
synthetic = None
110+
elif args.synthetic_only:
111+
synthetic = True
112+
113+
if args.lang:
114+
langs = args.lang.split(",")
115+
116+
if args.state:
117+
states = args.state.split(",")
118+
119+
def approve_message(msg: ExportMessageNode) -> bool:
120+
if (
121+
(deleted is not None and msg.deleted != deleted)
122+
or (synthetic is not None and msg.synthetic != synthetic)
123+
or (prompts_only and msg.parent_id)
124+
or (langs is not None and msg.lang not in langs)
125+
or (states is not None and msg.tree_state not in states)
126+
):
127+
return False
128+
129+
if exclude_normal is True and not msg.deleted and not msg.synthetic and msg.review_result:
130+
return False
131+
132+
if spam is not None and spam != (not msg.review_result):
133+
return False
134+
135+
return True
136+
137+
print(f"Reading: {args.input_file_name}")
138+
messages = read_message_list(args.input_file_name, approve_message)
139+
140+
print(f"Found {len(messages)} matching messages.")
141+
142+
print(f"Writing: {args.output_file_name}")
143+
if args.export_text_only:
144+
with open_jsonl_write(args.output_file_name) as file:
145+
for msg in messages:
146+
json.dump(msg.text, file)
147+
file.write("\n")
148+
else:
149+
write_messages(args.output_file_name, messages, args.exclude_nulls)
150+
151+
152+
if __name__ == "__main__":
153+
main()

oasst-data/examples/filter_trees.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import argparse
2+
3+
from oasst_data import read_message_trees, write_message_trees
4+
from oasst_data.schemas import ExportMessageTree
5+
from oasst_data.traversal import visit_messages_depth_first
6+
7+
8+
def parse_args():
9+
parser = argparse.ArgumentParser(description="filter_tres")
10+
parser.add_argument(
11+
"input_file_name",
12+
type=str,
13+
help="path to input .jsonl or .jsonl.gz input file",
14+
)
15+
parser.add_argument(
16+
"output_file_name",
17+
type=str,
18+
help="path to output .jsonl or .jsonl.gz file",
19+
)
20+
parser.add_argument(
21+
"--states",
22+
type=str,
23+
default="ready_for_export",
24+
help="all|prompt_lottery_waiting|growing|ready_for_export|aborted_low_grade|halted_by_moderator|backlog_ranking",
25+
)
26+
parser.add_argument("--exclude-nulls", action="store_true", default=False)
27+
parser.add_argument("--allow-synth", action="store_true", default=False)
28+
args = parser.parse_args()
29+
return args
30+
31+
32+
def main():
33+
args = parse_args()
34+
35+
# load dataset and index messages by id
36+
trees: list[ExportMessageTree] = []
37+
38+
states = args.states.split(",")
39+
allow_synth = args.allow_synth
40+
41+
print(f"Reading: {args.input_file_name}")
42+
for message_tree in read_message_trees(args.input_file_name):
43+
msgs = []
44+
visit_messages_depth_first(message_tree.prompt, msgs.append)
45+
if message_tree.tree_state in states:
46+
if allow_synth or not any(x.synthetic for x in msgs):
47+
trees.append(message_tree)
48+
49+
print(f"Found {len(trees)} matching trees.")
50+
51+
print(f"Writing: {args.output_file_name}")
52+
write_message_trees(args.output_file_name, trees, exclude_none=args.exclude_nulls)
53+
54+
55+
if __name__ == "__main__":
56+
main()

0 commit comments

Comments
 (0)