@@ -20,7 +20,7 @@
 import numpy as np
 import torch
 from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -1304,31 +1304,46 @@ def forward(
 
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # regression task
-                loss_fn = nn.MSELoss()
-                logits = logits.view(-1).to(labels.dtype)
-                loss = loss_fn(logits, labels.view(-1))
-            elif labels.dim() == 1 or labels.size(-1) == 1:
-                label_index = (labels >= 0).nonzero()
-                labels = labels.long()
-                if label_index.size(0) > 0:
-                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
-                    labels = torch.gather(labels, 0, label_index.view(-1))
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    # regression task
+                    loss_fn = nn.MSELoss()
+                    logits = logits.view(-1).to(labels.dtype)
+                    loss = loss_fn(logits, labels.view(-1))
+                elif labels.dim() == 1 or labels.size(-1) == 1:
+                    label_index = (labels >= 0).nonzero()
+                    labels = labels.long()
+                    if label_index.size(0) > 0:
+                        labeled_logits = torch.gather(
+                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+                        )
+                        labels = torch.gather(labels, 0, label_index.view(-1))
+                        loss_fct = CrossEntropyLoss()
+                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+                    else:
+                        loss = torch.tensor(0).to(logits)
                 else:
-                    loss = torch.tensor(0).to(logits)
-            else:
-                log_softmax = nn.LogSoftmax(-1)
-                loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+                    log_softmax = nn.LogSoftmax(-1)
+                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            elif self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
-        else:
-            return SequenceClassifierOutput(
-                loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
-            )
+
+        return SequenceClassifierOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )
 
 
 @add_start_docstrings(
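
For context, a minimal sketch of how the new `problem_type` branches would be selected from user code, assuming the usual transformers sequence-classification API; the checkpoint name, label count, and label values below are illustrative and not taken from this commit:

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # Illustrative checkpoint; any DeBERTa sequence-classification model is configured the same way.
    model = AutoModelForSequenceClassification.from_pretrained(
        "microsoft/deberta-base",
        num_labels=3,
        problem_type="multi_label_classification",  # routes the loss to BCEWithLogitsLoss
    )
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")

    inputs = tokenizer("an example sentence", return_tensors="pt")
    # Multi-label targets are float multi-hot vectors of shape (batch_size, num_labels).
    labels = torch.tensor([[1.0, 0.0, 1.0]])
    outputs = model(**inputs, labels=labels)
    print(outputs.loss)

Leaving `problem_type` unset falls back to the previous heuristic (MSE regression for a single label, otherwise cross-entropy over non-negative labels or a soft-label log-softmax loss), so existing code keeps its old behavior.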