@@ -111,7 +111,8 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
 
         # No ALBERT model currently handles the next sentence prediction task
         if "seq_relationship" in name:
-            continue
+            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
+            name = name.replace("weights", "weight")
 
         name = name.split("/")
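As a quick sanity check on the renaming above, here is a minimal standalone sketch (the TF checkpoint variable names are illustrative assumptions, not read from a real checkpoint) of how a former ``seq_relationship`` variable now maps onto the new ``sop_classifier`` module instead of being skipped::

    # Hypothetical TF variable names following the pattern handled above.
    tf_names = [
        "cls/seq_relationship/output_weights",
        "cls/seq_relationship/output_bias",
    ]

    for name in tf_names:
        # Same two replacements as in load_tf_weights_in_albert after this change.
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")
        print(name)
    # Prints:
    #   cls/sop_classifier/classifier/weight
    #   cls/sop_classifier/classifier/bias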
@@ -568,6 +569,115 @@ def forward(
         return outputs
 
 
+@add_start_docstrings(
+    """Albert Model with two heads on top as done during pre-training: a `masked language modeling` head and
+    a `sentence order prediction (classification)` head.""",
+    ALBERT_START_DOCSTRING,
+)
+class AlbertForPreTraining(AlbertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.albert = AlbertModel(config)
+        self.predictions = AlbertMLMHead(config)
+        self.sop_classifier = AlbertSOPHead(config)
+
+        self.init_weights()
+        self.tie_weights()
+
+    def tie_weights(self):
+        self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
+
+    def get_output_embeddings(self):
+        return self.predictions.decoder
+
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        masked_lm_labels=None,
+        sentence_order_label=None,
+    ):
+        r"""
+        masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
+            Labels for computing the masked language modeling loss.
+            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
+            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
+            in ``[0, ..., config.vocab_size]``.
+        sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
+            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence pair
+            (see :obj:`input_ids` docstring). Indices should be in ``[0, 1]``.
+            ``0`` indicates original order (sequence A, then sequence B),
+            ``1`` indicates switched order (sequence B, then sequence A).
+
+        Returns:
+            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
+            loss (`optional`, returned when both ``masked_lm_labels`` and ``sentence_order_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
+                Total loss as the sum of the masked language modeling loss and the sentence order prediction (classification) loss.
+            prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+            sop_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
+                Prediction scores of the sentence order prediction (classification) head (scores of original/switched
+                order before SoftMax).
+            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+                Attention weights after the attention softmax, used to compute the weighted average in the
+                self-attention heads.
+
+        Examples::
+
+            from transformers import AlbertTokenizer, AlbertForPreTraining
+            import torch
+
+            tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+            model = AlbertForPreTraining.from_pretrained('albert-base-v2')
+
+            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+            outputs = model(input_ids)
+
+            prediction_scores, sop_scores = outputs[:2]
+
+        """
+        outputs = self.albert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+        )
+
+        sequence_output, pooled_output = outputs[:2]
+
+        prediction_scores = self.predictions(sequence_output)
+        sop_scores = self.sop_classifier(pooled_output)
+
+        outputs = (prediction_scores, sop_scores,) + outputs[2:]  # add hidden states and attentions if they are here
+
+        if masked_lm_labels is not None and sentence_order_label is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
+            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
+            total_loss = masked_lm_loss + sentence_order_loss
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), prediction_scores, sop_scores, (hidden_states), (attentions)
+
+
 class AlbertMLMHead(nn.Module):
     def __init__(self, config):
         super().__init__()
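The ``tie_weights`` call in the new class shares the MLM decoder projection with the input embedding matrix. A quick illustrative check, under the assumption that ``_tie_or_clone_weights`` assigns the same parameter object (as in the PyTorch implementation) and using a deliberately tiny, hypothetical configuration::

    from transformers import AlbertConfig, AlbertForPreTraining

    # Small illustrative config so the model can be instantiated locally.
    config = AlbertConfig(
        embedding_size=128,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
    )
    model = AlbertForPreTraining(config)

    # After tying, the MLM decoder and the word embeddings point at the same tensor.
    assert model.predictions.decoder.weight is model.albert.embeddings.word_embeddings.weight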
@@ -592,6 +702,19 @@ def forward(self, hidden_states):
         return prediction_scores
 
 
+class AlbertSOPHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.dropout = nn.Dropout(config.classifier_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, pooled_output):
+        dropout_pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(dropout_pooled_output)
+        return logits
+
+
 @add_start_docstrings(
     "Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING,
 )
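Beyond the label-free example in the docstring above, a minimal training-style sketch of the new ``AlbertForPreTraining`` head (the sentence pair and label values are illustrative only, and it assumes the ``albert-base-v2`` checkpoint is available) would look like this::

    import torch
    from transformers import AlbertTokenizer, AlbertForPreTraining

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForPreTraining.from_pretrained("albert-base-v2")

    # Encode an (illustrative) sentence pair; the SOP label refers to the pair order.
    input_ids = torch.tensor(
        tokenizer.encode("Hello, my dog is cute", "He likes to play", add_special_tokens=True)
    ).unsqueeze(0)  # Batch size 1

    # For illustration, reuse the inputs as MLM labels and mark the pair as being in
    # its original order (label 0); a real pre-training loop would mask tokens and
    # swap segment order instead.
    masked_lm_labels = input_ids.clone()
    sentence_order_label = torch.tensor([0])

    outputs = model(
        input_ids,
        masked_lm_labels=masked_lm_labels,
        sentence_order_label=sentence_order_label,
    )
    total_loss, prediction_scores, sop_scores = outputs[:3]
    total_loss.backward()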