# modeling_highway_roberta.py
# (156 lines, 132 loc, 6.63 KB)
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import RobertaConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.models.roberta.modeling_roberta import (
ROBERTA_INPUTS_DOCSTRING,
ROBERTA_START_DOCSTRING,
RobertaEmbeddings,
)
from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayException, entropy
@add_start_docstrings(
    "The RoBERTa Model transformer with early exiting (DeeRoBERTa). ",
    ROBERTA_START_DOCSTRING,
)
class DeeRobertaModel(DeeBertModel):
    # Early-exit model built on top of DeeBertModel: the only thing this
    # subclass changes is the embedding module, which becomes RobertaEmbeddings.

    # Configuration class and checkpoint prefix used for this model.
    config_class = RobertaConfig
    base_model_prefix = "roberta"

    def __init__(self, config):
        """Run the DeeBertModel constructor, then install RoBERTa embeddings.

        Args:
            config: model configuration, forwarded unchanged to the parent
                constructor and to :class:`RobertaEmbeddings`.
        """
        super().__init__(config)

        # Assigned after the parent constructor has run, so this is the
        # embedding module the instance ends up with.
        self.embeddings = RobertaEmbeddings(config)

        # Weight initialisation is triggered last, once the RoBERTa
        # embeddings are in place.
        self.init_weights()
@add_start_docstrings(
    """RoBERTa Model (with early exiting - DeeRoBERTa) with a classifier on top,
also takes care of multi-layer training. """,
    ROBERTA_START_DOCSTRING,
)
class DeeRobertaForSequenceClassification(BertPreTrainedModel):
    # Configuration class and checkpoint prefix used for this model.
    config_class = RobertaConfig
    base_model_prefix = "roberta"

    def __init__(self, config):
        """Create the early-exit encoder, the dropout layer and the final classification head.

        Args:
            config: model configuration; ``num_labels``, ``num_hidden_layers``,
                ``hidden_dropout_prob`` and ``hidden_size`` are read from it.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_layers = config.num_hidden_layers

        self.roberta = DeeRobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

    def _compute_loss(self, logits, labels):
        """Return the loss of ``logits`` against ``labels``.

        Mean-squared error is used when ``self.num_labels == 1`` (regression),
        cross-entropy otherwise (classification). Shared by the final
        classifier and by every highway so both are scored the same way.
        """
        if self.num_labels == 1:
            # We are doing regression
            return MSELoss()(logits.view(-1), labels.view(-1))
        return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_layer=-1,
        train_highway=False,
    ):
        r"""
            labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
                Labels for computing the sequence classification/regression loss.
                Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
                If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
                If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
            output_layer (:obj:`int`, `optional`, defaults to -1):
                Only used when the model is not in training mode. If non-negative, the returned logits are
                replaced by the logits of the highway with that index.
            train_highway (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Only used when :obj:`labels` is provided. If :obj:`True`, the returned loss is the sum of the
                losses of all highways except the last one; otherwise it is the loss of the final classifier.

        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
                Classification (or regression if config.num_labels==1) loss.
            logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.
            highway_exits (:obj:`tuple(tuple(torch.Tensor))`):
                Tuple of each early exit's results (total length: number of layers)
                Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
            entropies (:obj:`tuple`, returned when the model is not in training mode):
                Pair ``(original_entropy, highway_entropy)``: the entropy of the returned classifier logits and the
                list of per-highway entropies.
            exit_layer (:obj:`int`, returned when the model is not in training mode):
                ``config.num_hidden_layers`` when no early exit happened, otherwise the exit layer reported by
                the :class:`HighwayException`.
        """
        exit_layer = self.num_layers
        try:
            outputs = self.roberta(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        except HighwayException as e:
            # An early exit was taken inside the encoder: the exception carries
            # the outputs to use and the layer at which the exit happened.
            outputs = e.message
            exit_layer = e.exit_layer
            logits = outputs[0]

        evaluating = not self.training
        has_labels = labels is not None

        if evaluating:
            original_entropy = entropy(logits)
            # Collected whether or not labels were given, so that
            # `output_layer` and the returned entropies also work for
            # label-free inference (previously this left both lists empty and
            # indexing `highway_logits_all` raised IndexError).
            # NOTE(review): assumes, as the labelled path always did, that the
            # last element of `outputs` holds the highway exits and that in
            # eval mode each exit carries its entropy at index 2.
            highway_logits_all = [highway_exit[0] for highway_exit in outputs[-1]]
            highway_entropy = [highway_exit[2] for highway_exit in outputs[-1]]

        if has_labels:
            if train_highway:
                # exclude the final highway, of course
                highway_losses = [self._compute_loss(highway_exit[0], labels) for highway_exit in outputs[-1][:-1]]
                outputs = (sum(highway_losses),) + outputs
            else:
                outputs = (self._compute_loss(logits, labels),) + outputs

        if evaluating:
            outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
            if output_layer >= 0:
                # The logits sit right after the loss when there is one, and
                # at the front of the tuple otherwise.
                logits_index = 1 if has_labels else 0
                outputs = (
                    outputs[:logits_index] + (highway_logits_all[output_layer],) + outputs[logits_index + 1 :]
                )  # use the highway of the requested layer

        return outputs  # (loss), logits, (hidden_states), (attentions), highway_exits, (entropy), (exit_layer)