Commit c28bc80

Generalize problem_type to all sequence classification models (#14180)
* Generalize problem_type to all classification models
* Missing import
* Deberta BC and fix tests
* Fix template
* Missing imports
* Revert change to reformer test
* Fix style
1 parent 4ab6a4a commit c28bc80

38 files changed: +474 −191 lines changed
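
After this change, every sequence classification model selects its loss from `config.problem_type`: MSELoss for "regression", CrossEntropyLoss for "single_label_classification", and BCEWithLogitsLoss for "multi_label_classification". When the field is unset, the models infer it from `num_labels` and the label dtype, as the diffs below show. A minimal usage sketch — the checkpoint name and label count are illustrative, not taken from this commit:

    from transformers import AutoConfig, AutoModelForSequenceClassification

    # Opt in to the multi-label loss explicitly rather than relying on
    # the dtype-based inference performed inside forward().
    config = AutoConfig.from_pretrained("facebook/bart-base", num_labels=3)
    config.problem_type = "multi_label_classification"
    model = AutoModelForSequenceClassification.from_pretrained(
        "facebook/bart-base", config=config
    )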

src/transformers/models/bart/modeling_bart.py (+18 −6)

@@ -22,7 +22,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1475,14 +1475,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                # regression
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
-
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
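
The same routing logic recurs in every file below; stripped of model specifics, it reduces to a sketch like this (standalone illustration, not code from the commit):

    import torch

    def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
        # Mirrors the inference block above: one output -> regression;
        # integer labels -> single-label; anything else -> multi-label.
        if num_labels == 1:
            return "regression"
        if labels.dtype in (torch.long, torch.int):
            return "single_label_classification"
        return "multi_label_classification"

    print(infer_problem_type(1, torch.tensor([0.5])))              # regression
    print(infer_problem_type(3, torch.tensor([0, 2])))             # single_label_classification
    print(infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])))  # multi_label_classification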

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py (+18 −6)

@@ -23,7 +23,7 @@
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -2680,14 +2680,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                # regression
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
-
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output

src/transformers/models/ctrl/modeling_ctrl.py (+18 −6)

@@ -20,7 +20,7 @@
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
@@ -690,14 +690,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[2:]
             return ((loss,) + output) if loss is not None else output
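
For the new multi-label branch, BCEWithLogitsLoss expects float multi-hot labels with the same shape as the logits, and no reshaping is applied. In isolation (tensors invented for the example):

    import torch
    from torch.nn import BCEWithLogitsLoss

    pooled_logits = torch.randn(2, 4)                 # (batch, num_labels)
    labels = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                           [0.0, 1.0, 1.0, 0.0]])     # float multi-hot, same shape
    loss = BCEWithLogitsLoss()(pooled_logits, labels)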

src/transformers/models/deberta/modeling_deberta.py (+37 −22)

@@ -19,7 +19,7 @@

 import torch
 from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -1194,31 +1194,46 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # regression task
-                loss_fn = nn.MSELoss()
-                logits = logits.view(-1).to(labels.dtype)
-                loss = loss_fn(logits, labels.view(-1))
-            elif labels.dim() == 1 or labels.size(-1) == 1:
-                label_index = (labels >= 0).nonzero()
-                labels = labels.long()
-                if label_index.size(0) > 0:
-                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
-                    labels = torch.gather(labels, 0, label_index.view(-1))
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    # regression task
+                    loss_fn = nn.MSELoss()
+                    logits = logits.view(-1).to(labels.dtype)
+                    loss = loss_fn(logits, labels.view(-1))
+                elif labels.dim() == 1 or labels.size(-1) == 1:
+                    label_index = (labels >= 0).nonzero()
+                    labels = labels.long()
+                    if label_index.size(0) > 0:
+                        labeled_logits = torch.gather(
+                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+                        )
+                        labels = torch.gather(labels, 0, label_index.view(-1))
+                        loss_fct = CrossEntropyLoss()
+                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+                    else:
+                        loss = torch.tensor(0).to(logits)
                 else:
-                    loss = torch.tensor(0).to(logits)
-            else:
-                log_softmax = nn.LogSoftmax(-1)
-                loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+                    log_softmax = nn.LogSoftmax(-1)
+                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            elif self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
-        else:
-            return SequenceClassifierOutput(
-                loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
-            )
+
+        return SequenceClassifierOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )


 @add_start_docstrings(
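
Note the backward-compatibility guard ("Deberta BC" in the commit message): when `problem_type` is unset, DeBERTa keeps its legacy behavior, including masking out labels below zero and a soft-label objective when labels arrive as per-class distributions. That last path computes -mean(sum(labels * log_softmax(logits))); in isolation it looks like this (tensors invented for the example):

    import torch
    from torch import nn

    logits = torch.randn(4, 3)
    labels = torch.tensor([[0.7, 0.2, 0.1]] * 4)  # one distribution per example
    log_softmax = nn.LogSoftmax(-1)
    loss = -((log_softmax(logits) * labels).sum(-1)).mean()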

src/transformers/models/deberta_v2/modeling_deberta_v2.py (+37 −22)

@@ -20,7 +20,7 @@
 import numpy as np
 import torch
 from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -1304,31 +1304,46 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # regression task
-                loss_fn = nn.MSELoss()
-                logits = logits.view(-1).to(labels.dtype)
-                loss = loss_fn(logits, labels.view(-1))
-            elif labels.dim() == 1 or labels.size(-1) == 1:
-                label_index = (labels >= 0).nonzero()
-                labels = labels.long()
-                if label_index.size(0) > 0:
-                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
-                    labels = torch.gather(labels, 0, label_index.view(-1))
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    # regression task
+                    loss_fn = nn.MSELoss()
+                    logits = logits.view(-1).to(labels.dtype)
+                    loss = loss_fn(logits, labels.view(-1))
+                elif labels.dim() == 1 or labels.size(-1) == 1:
+                    label_index = (labels >= 0).nonzero()
+                    labels = labels.long()
+                    if label_index.size(0) > 0:
+                        labeled_logits = torch.gather(
+                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+                        )
+                        labels = torch.gather(labels, 0, label_index.view(-1))
+                        loss_fct = CrossEntropyLoss()
+                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+                    else:
+                        loss = torch.tensor(0).to(logits)
                 else:
-                    loss = torch.tensor(0).to(logits)
-            else:
-                log_softmax = nn.LogSoftmax(-1)
-                loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+                    log_softmax = nn.LogSoftmax(-1)
+                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            elif self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
-        else:
-            return SequenceClassifierOutput(
-                loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
-            )
+
+        return SequenceClassifierOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )


 @add_start_docstrings(

src/transformers/models/fnet/modeling_fnet.py (+18 −6)

@@ -23,7 +23,7 @@
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...file_utils import is_scipy_available

@@ -927,14 +927,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
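
The generalized regression branch now also covers num_labels > 1 (vector regression): with a single label, logits and labels are squeezed to matching 1-D shapes; otherwise MSELoss runs over the full (batch, num_labels) tensors. In isolation (tensors invented for the example):

    import torch
    from torch.nn import MSELoss

    loss_fct = MSELoss()
    # num_labels == 1: squeeze both sides down to shape (batch,)
    loss_1d = loss_fct(torch.randn(8, 1).squeeze(), torch.randn(8).squeeze())
    # num_labels > 1: element-wise MSE over (batch, num_labels)
    loss_nd = loss_fct(torch.randn(8, 3), torch.randn(8, 3))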

src/transformers/models/gpt2/modeling_gpt2.py (+18 −6)

@@ -24,7 +24,7 @@
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss


 if version.parse(torch.__version__) >= version.parse("1.6"):
@@ -1406,14 +1406,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
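
The single-label branch keeps the long-standing reshape convention: logits flattened to (-1, num_labels), integer labels flattened to (-1). In isolation (tensors invented for the example):

    import torch
    from torch.nn import CrossEntropyLoss

    num_labels = 4
    pooled_logits = torch.randn(2, num_labels)
    labels = torch.tensor([3, 1])  # torch.long labels select this branch
    loss = CrossEntropyLoss()(pooled_logits.view(-1, num_labels), labels.view(-1))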

src/transformers/models/gpt_neo/modeling_gpt_neo.py (+18 −6)

@@ -21,7 +21,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -895,14 +895,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
