# download_cnn_daily_mail.py
# Downloads the CNN/DailyMail test split with tensorflow_datasets and writes
# the articles and the reference summaries to two separate text files.
import argparse
from pathlib import Path
import tensorflow_datasets as tfds
def main(input_path, reference_path, data_dir):
    """Write the CNN/DailyMail test split to an articles file and a summaries file.

    Loads the ``test`` split of the ``cnn_dailymail`` dataset through
    ``tensorflow_datasets`` (unshuffled) and writes one example per line:
    the ``article`` field to ``input_path`` and the ``highlights`` field to
    ``reference_path``. Newlines inside ``highlights`` are replaced by spaces
    so that each summary occupies exactly one line.

    Args:
        input_path: Destination file for the articles; overwritten if present.
        reference_path: Destination file for the reference summaries;
            overwritten if present.
        data_dir: Directory handed to ``tfds.load`` as ``data_dir``.
    """
    cnn_ds = tfds.load("cnn_dailymail", split="test", shuffle_files=False, data_dir=data_dir)
    cnn_ds_iter = tfds.as_numpy(cnn_ds)

    # Both files are managed by a single ``with`` so they are flushed and
    # closed even if iteration or decoding fails part-way through. The text
    # was decoded as UTF-8, so it is written back as UTF-8 explicitly rather
    # than in the platform's default locale encoding.
    with Path(input_path).open("w", encoding="utf-8") as test_articles_file, \
            Path(reference_path).open("w", encoding="utf-8") as test_summaries_file:
        for example in cnn_ds_iter:
            # NOTE(review): the article is written as-is; this assumes it
            # contains no embedded newlines, otherwise the two files would
            # not stay line-aligned - confirm against the dataset.
            test_articles_file.write(example["article"].decode("utf-8") + "\n")
            # Collapse multi-line highlights into a single line per example.
            test_summaries_file.write(example["highlights"].decode("utf-8").replace("\n", " ") + "\n")
if __name__ == "__main__":
    # Script entry point: two positional output paths plus an optional
    # dataset directory, all forwarded to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("input_path", type=str, help="where to save the articles input data")
    cli.add_argument("reference_path", type=str, help="where to save the reference summaries")
    cli.add_argument(
        "--data_dir",
        type=str,
        default="~/tensorflow_datasets",
        help="where to save the tensorflow datasets.",
    )
    options = cli.parse_args()
    main(
        input_path=options.input_path,
        reference_path=options.reference_path,
        data_dir=options.data_dir,
    )