forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsplit_dataset.py
64 lines (49 loc) · 1.79 KB
/
split_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import random
from oasst_data import read_message_list, write_messages
def parse_args(argv=None):
    """Parse the command-line options for the train/validation split.

    Args:
        argv: Optional list of argument strings to parse. The default
            ``None`` makes argparse read ``sys.argv[1:]``, which is what
            the script did before this parameter existed.

    Returns:
        An ``argparse.Namespace`` with the attributes ``val_percent``,
        ``input_file_name``, ``val_output``, ``train_output`` and
        ``exclude_nulls``.

    Raises:
        SystemExit: raised by argparse when a required option is missing,
            a value cannot be converted, or ``--val_percent`` is outside
            the range 0..100.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--val_percent",
        type=int,
        default=5,
        help="integer percentage (0-100) of message trees to use for validation",
    )
    parser.add_argument(
        "input_file_name",
        type=str,
        help="path to input .jsonl or .jsonl.gz input file",
    )
    parser.add_argument(
        "--val_output",
        type=str,
        help="path to validation output .jsonl or .jsonl.gz file",
        required=True,
    )
    parser.add_argument(
        "--train_output",
        type=str,
        help="path to train output .jsonl or .jsonl.gz file",
        required=True,
    )
    # argparse stores this flag as ``args.exclude_nulls`` ('-' becomes '_').
    parser.add_argument("--exclude-nulls", action="store_true", default=False)

    args = parser.parse_args(argv)

    # A percentage outside 0..100 would turn into a negative or oversized
    # slice index when the split is computed, silently producing a
    # meaningless split; reject it up front with a normal usage error.
    if not 0 <= args.val_percent <= 100:
        parser.error("--val_percent must be between 0 and 100")
    return args
def main():
    """Write the input messages to separate train and validation files.

    The split is made per ``message_tree_id``: the distinct tree ids are
    shuffled, ``val_percent`` percent of them (rounded down) are assigned
    to validation and the remainder to train, so every message of a tree
    ends up in the same output file.
    """
    args = parse_args()

    print(f"Reading: {args.input_file_name}")
    messages = read_message_list(args.input_file_name)
    print(f"Found {len(messages)} matching messages.")

    # Randomize the order of the distinct tree ids before cutting the list.
    unique_tree_ids = list({msg.message_tree_id for msg in messages})
    random.shuffle(unique_tree_ids)

    # Floor division: small inputs can legitimately yield zero validation trees.
    num_val_trees = len(unique_tree_ids) * args.val_percent // 100
    val_tree_ids = set(unique_tree_ids[:num_val_trees])
    train_tree_ids = set(unique_tree_ids[num_val_trees:])

    # Single pass that keeps the original message order within each split.
    train_messages = []
    val_messages = []
    for msg in messages:
        tree_id = msg.message_tree_id
        if tree_id in train_tree_ids:
            train_messages.append(msg)
        if tree_id in val_tree_ids:
            val_messages.append(msg)

    print(f"Writing train {len(train_messages)} messages: {args.train_output}")
    write_messages(args.train_output, train_messages, args.exclude_nulls)

    print(f"Writing valid {len(val_messages)} messages: {args.val_output}")
    write_messages(args.val_output, val_messages, args.exclude_nulls)
# Run the split only when executed as a script, not when imported.
if __name__ == "__main__":
    main()