summarization_dataset.py
import argparse
import glob
import os
import json
import time
import logging
import random
import re
from itertools import chain
from string import punctuation
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
# NOTE: the `nlp` package has since been renamed to `datasets`; on newer
# environments use `from datasets import load_dataset` instead.
from nlp import load_dataset
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup,
)

class wikihow(Dataset):
    """WikiHow summarization dataset: tokenizes (article text, headline) pairs for T5."""

    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
        self.dataset = load_dataset('wikihow', 'all', data_dir='data/', split=type_path)
        if num_samples:
            self.dataset = self.dataset.select(list(range(0, num_samples)))
        self.input_length = input_length
        self.tokenizer = tokenizer
        self.output_length = output_length
        self.print_text = print_text

    def __len__(self):
        return self.dataset.shape[0]

    def clean_text(self, text):
        # Strip boilerplate markers, newlines, and stray quote characters from the raw text.
        text = text.replace('Example of text:', '')
        text = text.replace('Example of Summary:', '')
        text = text.replace('\n', '')
        text = text.replace('``', '')
        text = text.replace('"', '')
        return text

    def convert_to_features(self, example_batch):
        # Tokenize the cleaned article text (input) and its headline (summarization target).
        if self.print_text:
            print("Input Text: ", self.clean_text(example_batch['text']))
        # input_ = self.clean_text(example_batch['text']) + " </s>"
        # target_ = self.clean_text(example_batch['headline']) + " </s>"

        input_ = self.clean_text(example_batch['text'])
        target_ = self.clean_text(example_batch['headline'])

        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,
                                                  padding='max_length', truncation=True, return_tensors="pt")
        targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,
                                                   padding='max_length', truncation=True, return_tensors="pt")
        return source, targets

    def __getitem__(self, index):
        source, targets = self.convert_to_features(self.dataset[index])

        source_ids = source["input_ids"].squeeze()
        target_ids = targets["input_ids"].squeeze()

        src_mask = source["attention_mask"].squeeze()
        target_mask = targets["attention_mask"].squeeze()

        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}

def get_dataset(tokenizer, type_path, num_samples, args):
    # Take the sequence lengths from the parsed command-line args (assumed to
    # carry max_input_length / max_output_length) instead of undefined globals.
    return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples,
                   input_length=args.max_input_length, output_length=args.max_output_length)
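
# Minimal usage sketch (not part of the original file): shows how the dataset is
# built and batched. Assumes the WikiHow CSVs have been downloaded into `data/`,
# and the length settings below are hypothetical examples.
if __name__ == "__main__":
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    args = argparse.Namespace(max_input_length=512, max_output_length=150)

    train_dataset = get_dataset(tokenizer, type_path="train", num_samples=16, args=args)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    batch = next(iter(train_loader))
    print(batch["source_ids"].shape, batch["target_ids"].shape)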