23 | 23 |
24 | 24 | from ...activations import ACT2FN
25 | 25 | from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
26 | | -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
| 26 | +from ...modeling_outputs import (
| 27 | +    BaseModelOutputWithPast,
| 28 | +    CausalLMOutputWithPast,
| 29 | +    QuestionAnsweringModelOutput,
| 30 | +    SequenceClassifierOutputWithPast,
| 31 | +)
27 | 32 | from ...modeling_utils import PreTrainedModel
28 | 33 | from ...utils import logging
29 | 34 | from ...utils.model_parallel_utils import assert_device_map, get_device_map
@@ -967,3 +972,108 @@ def forward(
967 | 972 |             hidden_states=transformer_outputs.hidden_states,
968 | 973 |             attentions=transformer_outputs.attentions,
969 | 974 |         )
| 975 | +
| 976 | +
| 977 | +@add_start_docstrings(
| 978 | +    """
| 979 | +    The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like
| 980 | +    SQuAD (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
| 981 | +    """,
| 982 | +    GPTJ_START_DOCSTRING,
| 983 | +)
| 984 | +class GPTJForQuestionAnswering(GPTJPreTrainedModel):
| 985 | +    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
| 986 | +
| 987 | +    def __init__(self, config):
| 988 | +        super().__init__(config)
| 989 | +        self.num_labels = config.num_labels
| 990 | +        self.transformer = GPTJModel(config)
| 991 | +        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
| 992 | +
| 993 | +        # Model parallel
| 994 | +        self.model_parallel = False
| 995 | +        self.device_map = None
| 996 | +
| 997 | +        # Initialize weights and apply final processing
| 998 | +        self.post_init()
| 999 | +
| 1000 | +    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
| 1001 | +    @add_code_sample_docstrings(
| 1002 | +        processor_class=_TOKENIZER_FOR_DOC,
| 1003 | +        checkpoint=_CHECKPOINT_FOR_DOC,
| 1004 | +        output_type=QuestionAnsweringModelOutput,
| 1005 | +        config_class=_CONFIG_FOR_DOC,
| 1006 | +    )
| 1007 | +    def forward(
| 1008 | +        self,
| 1009 | +        input_ids=None,
| 1010 | +        attention_mask=None,
| 1011 | +        token_type_ids=None,
| 1012 | +        position_ids=None,
| 1013 | +        head_mask=None,
| 1014 | +        inputs_embeds=None,
| 1015 | +        start_positions=None,
| 1016 | +        end_positions=None,
| 1017 | +        output_attentions=None,
| 1018 | +        output_hidden_states=None,
| 1019 | +        return_dict=None,
| 1020 | +    ):
| 1021 | +        r"""
| 1022 | +        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
| 1023 | +            Labels for position (index) of the start of the labelled span for computing the token classification loss.
| 1024 | +            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
| 1025 | +            sequence are not taken into account for computing the loss.
| 1026 | +        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
| 1027 | +            Labels for position (index) of the end of the labelled span for computing the token classification loss.
| 1028 | +            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
| 1029 | +            sequence are not taken into account for computing the loss.
| 1030 | +        """
| 1031 | +        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
| 1032 | +
| 1033 | +        outputs = self.transformer(
| 1034 | +            input_ids,
| 1035 | +            attention_mask=attention_mask,
| 1036 | +            token_type_ids=token_type_ids,
| 1037 | +            position_ids=position_ids,
| 1038 | +            head_mask=head_mask,
| 1039 | +            inputs_embeds=inputs_embeds,
| 1040 | +            output_attentions=output_attentions,
| 1041 | +            output_hidden_states=output_hidden_states,
| 1042 | +            return_dict=return_dict,
| 1043 | +        )
| 1044 | +
| 1045 | +        sequence_output = outputs[0]
| 1046 | +
| 1047 | +        logits = self.qa_outputs(sequence_output)
| 1048 | +        start_logits, end_logits = logits.split(1, dim=-1)
| 1049 | +        start_logits = start_logits.squeeze(-1).contiguous()
| 1050 | +        end_logits = end_logits.squeeze(-1).contiguous()
| 1051 | +
| 1052 | +        total_loss = None
| 1053 | +        if start_positions is not None and end_positions is not None:
| 1054 | +            # If we are on multi-GPU, split adds an extra dimension; squeeze it away
| 1055 | +            if len(start_positions.size()) > 1:
| 1056 | +                start_positions = start_positions.squeeze(-1)
| 1057 | +            if len(end_positions.size()) > 1:
| 1058 | +                end_positions = end_positions.squeeze(-1)
| 1059 | +            # sometimes the start/end positions are outside our model inputs, we ignore these terms
| 1060 | +            ignored_index = start_logits.size(1)
| 1061 | +            start_positions = start_positions.clamp(0, ignored_index)
| 1062 | +            end_positions = end_positions.clamp(0, ignored_index)
| 1063 | +
| 1064 | +            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
| 1065 | +            start_loss = loss_fct(start_logits, start_positions)
| 1066 | +            end_loss = loss_fct(end_logits, end_positions)
| 1067 | +            total_loss = (start_loss + end_loss) / 2
| 1068 | +
| 1069 | +        if not return_dict:
| 1070 | +            output = (start_logits, end_logits) + outputs[2:]
| 1071 | +            return ((total_loss,) + output) if total_loss is not None else output
| 1072 | +
| 1073 | +        return QuestionAnsweringModelOutput(
| 1074 | +            loss=total_loss,
| 1075 | +            start_logits=start_logits,
| 1076 | +            end_logits=end_logits,
| 1077 | +            hidden_states=outputs.hidden_states,
| 1078 | +            attentions=outputs.attentions,
| 1079 | +        )
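For context, a minimal usage sketch of the GPTJForQuestionAnswering head this diff adds. The checkpoint name, the top-level import, and the example texts are assumptions for illustration only; the `qa_outputs` head is randomly initialized unless a checkpoint fine-tuned for extractive QA is loaded.

```python
# Hypothetical usage sketch; checkpoint and import path are assumptions.
import torch
from transformers import AutoTokenizer, GPTJForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForQuestionAnswering.from_pretrained("EleutherAI/gpt-j-6B")

question = "Who released GPT-J?"
context = "GPT-J is a 6-billion-parameter model released by EleutherAI in 2021."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Greedy span decoding: argmax over the start and end logits, then decode the span.
start = int(outputs.start_logits.argmax(dim=-1))
end = int(outputs.end_logits.argmax(dim=-1))
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
print(answer)
```

Training uses the same forward pass with `start_positions`/`end_positions` labels, in which case the returned loss is the average of the start and end cross-entropy terms, as in the diff above.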