check_dataset_counts.py
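"""Count train/eval examples for the datasets selected on the command line.

Reads the dataset list from the mode-specific config (config.yaml, config_rm.yaml
or config_rl.yaml), counts train and eval examples per dataset, optionally runs
langdetect over the eval data (--detect_language), and writes an overview CSV to
--output_path (default: dataset_counts.csv) plus language_counts.csv.
"""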
import argparse
from collections import Counter
from enum import Enum
from pathlib import Path
from typing import Any

import pandas as pd
import yaml
from langdetect import DetectorFactory, detect
from model_training.custom_datasets.formatting import DatasetEntrySft
from model_training.utils.utils import _strtobool, get_dataset


class Mode(str, Enum):
    sft = "sft"
    rm = "rm"
    rl = "rl"

    def config_name(self) -> str:
        match self:
            case Mode.sft:
                return "config.yaml"
            case Mode.rm:
                return "config_rm.yaml"
            case Mode.rl:
                return "config_rl.yaml"

    def default_config(self) -> str:
        match self:
            case Mode.sft:
                return "defaults"
            case Mode.rm:
                return "defaults_rm"
            case Mode.rl:
                return "defaults_rlhf"


def read_yaml(dir: str | Path, config_file: str) -> dict[str, Any]:
    with open(Path(dir) / config_file, "r") as f:
        return yaml.safe_load(f)

def argument_parsing(notebook=False, notebook_args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--datasets",
        nargs="+",
        required=True,
        help="""
            Multiple datasets can be passed to set different options.
            For example, run as:

                ./check_dataset_counts.py --datasets math oasst_export_eu

            to check the counts of the math and the oasst_export_eu dataset.
        """,
    )
    parser.add_argument("--mode", dest="mode", type=Mode, choices=list(Mode))
    parser.add_argument("--output_path", dest="output_path", default="dataset_counts.csv")
    parser.add_argument("--detect_language", default=False, action="store_true")
    if notebook:
        args, remaining = parser.parse_known_args(notebook_args)
    else:
        args, remaining = parser.parse_known_args()

    # Config from YAML
    mode: Mode = args.mode
    configs = read_yaml("./configs", config_file=mode.config_name())
    conf = configs[mode.default_config()]
    if "all" in args.datasets:
        conf["datasets"] = configs[mode.default_config()]["datasets"] + configs[mode.default_config()]["datasets_extra"]
    else:
        # Reset the datasets so that only the ones requested on the command line are kept,
        # dropping the ones listed in the default config.
        datasets_list = list()
        for name in args.datasets:
            # A comma-separated name refers to multiple config keys.
            if "," in name:
                datasets_value = []
                for n in name.split(","):
                    datasets_value.extend(configs[n].get("datasets") or configs[n]["datasets_extra"])
            # Check whether the dataset is a separate key in the config.
            elif name in configs:
                datasets_value = configs[name].get("datasets") or configs[name]["datasets_extra"]
            # Check the default config.
            elif name in configs[mode.default_config()]["datasets"]:
                datasets_value = [name]
            else:
                raise ValueError(
                    f'Error: Could not find the dataset "{name}" in {mode.config_name()}. '
                    f"Tried to look for this dataset within the key {mode.default_config()} "
                    "and as a separate key."
                )
            datasets_list.extend(datasets_value)
        conf["datasets"] = datasets_list
    conf["mode"] = mode
    conf["output_path"] = args.output_path
    conf["datasets_extra"] = []
    conf["detect_language"] = args.detect_language

    # Override config from command line
    parser = argparse.ArgumentParser()
    for key, value in conf.items():
        type_ = type(value) if value is not None else str
        if type_ == bool:
            type_ = _strtobool
        parser.add_argument(f"--{key}", type=type_, default=value)
        # Allow --no-{key} to remove it completely
        parser.add_argument(f"--no-{key}", dest=key, action="store_const", const=None)
    args = parser.parse_args(remaining)
    print(args)
    return args

if __name__ == "__main__":
    args = argument_parsing()
    train, evals = get_dataset(args, mode=args.mode.value)
    overview_df = pd.DataFrame(columns=["dataset_name", "train_counts", "eval_counts", "total_counts"])
    language_df = pd.DataFrame()
    if args.detect_language:
        DetectorFactory.seed = 0
    for idx, (dataset_name, dataset) in enumerate(evals.items()):
        train_lang = Counter()
        if args.detect_language:
            # Detect the language of every question/answer in the eval split.
            length = len(dataset)
            for idx1, row in enumerate(dataset):
                if idx1 % 1000 == 0:
                    print(f"{idx1} of {length} of ds {dataset_name}.")
                try:
                    if isinstance(row, (list, tuple)):
                        train_lang += Counter([detect(k) for k in row])
                    elif isinstance(row, DatasetEntrySft):
                        train_lang += Counter([detect(k) for k in row.questions if k])
                        if isinstance(row.answers[0], list):
                            for answers in row.answers:
                                train_lang += Counter([detect(k) for k in answers if k])
                        else:
                            train_lang += Counter([detect(k) for k in row.answers if k])
                    else:
                        raise ValueError(
                            f"Did not expect the type {type(row)}. Should be either list, tuple or DatasetEntrySft."
                        )
                except Exception as e:
                    print(e)
            train_lang = dict(train_lang)
            train_lang["dataset_name"] = dataset_name
            language_df = pd.concat([language_df, pd.DataFrame([train_lang])])
        eval_count = len(evals.get(dataset_name, []))
        overview_df.loc[idx] = [
            dataset_name,
            len(train.datasets[idx]),
            eval_count,
            len(train.datasets[idx]) + eval_count,
        ]
    print(overview_df)
    print(language_df)
    overview_df.to_csv(args.output_path, index=False)
    language_df.to_csv("language_counts.csv", index=False)
# python check_dataset_counts.py --datasets joke webgpt gpt4all alpaca code_alpaca vicuna minimath humaneval_mbpp_codegen_qa humaneval_mbpp_testgen_qa grade_school_math_instructions recipes cmu_wiki_qa oa_wiki_qa_bart_10000row prosocial_dialogue explain_prosocial soda oa_leet10k --mode sft
# python check_dataset_counts.py --datasets joke webgpt alpaca code_alpaca vicuna minimath humaneval_mbpp_codegen_qa humaneval_mbpp_testgen_qa grade_school_math_instructions recipes cmu_wiki_qa oa_wiki_qa_bart_10000row prosocial_dialogue oa_leet10k --mode sft
# python check_dataset_counts.py --datasets joke webgpt --mode sft