Commit 0f3f045

Add GPTJForQuestionAnswering (#14503)
* Add GPTJForQuestionAnswering
* Reformat for GPTJForQuestionAnswering
* Fix isort error
* make style for GPTJForQA
* Add _keys_to_ignore_on_load_missing
* Change the sequence of qa and classification

Co-authored-by: Suraj Patil <surajp815@gmail.com>
1 parent 1ccc033 commit 0f3f045

7 files changed: +141 −2 lines changed

docs/source/model_doc/gptj.rst

+7
@@ -121,6 +121,13 @@ GPTJForSequenceClassification
     :members: forward


+GPTJForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.GPTJForQuestionAnswering
+    :members: forward
+
+
 FlaxGPTJModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

src/transformers/__init__.py

+2
@@ -951,6 +951,7 @@
         [
             "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
             "GPTJForCausalLM",
+            "GPTJForQuestionAnswering",
             "GPTJForSequenceClassification",
             "GPTJModel",
             "GPTJPreTrainedModel",
@@ -2833,6 +2834,7 @@
     from .models.gptj import (
         GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
         GPTJForCausalLM,
+        GPTJForQuestionAnswering,
         GPTJForSequenceClassification,
         GPTJModel,
         GPTJPreTrainedModel,
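With the top-level export in place, the new head can be imported directly from transformers. A minimal smoke-test sketch, assuming PyTorch is installed; the tiny config values are illustrative and do not correspond to any released checkpoint:

from transformers import GPTJConfig, GPTJForQuestionAnswering

# Build a small illustrative config so nothing needs to be downloaded.
config = GPTJConfig(vocab_size=1000, n_embd=64, n_layer=2, n_head=4, rotary_dim=16)
model = GPTJForQuestionAnswering(config)
print(model.qa_outputs)  # span head: Linear(in_features=64, out_features=2, bias=True)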

src/transformers/models/auto/modeling_auto.py

+1
@@ -385,6 +385,7 @@
         # Model for Question Answering mapping
         ("qdqbert", "QDQBertForQuestionAnswering"),
         ("fnet", "FNetForQuestionAnswering"),
+        ("gptj", "GPTJForQuestionAnswering"),
         ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
         ("rembert", "RemBertForQuestionAnswering"),
         ("canine", "CanineForQuestionAnswering"),

src/transformers/models/gptj/__init__.py

+2
@@ -28,6 +28,7 @@
     _import_structure["modeling_gptj"] = [
         "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
         "GPTJForCausalLM",
+        "GPTJForQuestionAnswering",
         "GPTJForSequenceClassification",
         "GPTJModel",
         "GPTJPreTrainedModel",
@@ -48,6 +49,7 @@
     from .modeling_gptj import (
         GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
         GPTJForCausalLM,
+        GPTJForQuestionAnswering,
         GPTJForSequenceClassification,
         GPTJModel,
         GPTJPreTrainedModel,

src/transformers/models/gptj/modeling_gptj.py

+111 −1
@@ -23,7 +23,12 @@

 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from ...utils.model_parallel_utils import assert_device_map, get_device_map
@@ -967,3 +972,108 @@ def forward(
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like
+    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    GPTJ_START_DOCSTRING,
+)
+class GPTJForQuestionAnswering(GPTJPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = GPTJModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=QuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for the position (index) of the start of the labelled span, used to compute the token classification
+            loss. Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of
+            the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for the position (index) of the end of the labelled span, used to compute the token classification
+            loss. Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of
+            the sequence are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # Sometimes the start/end positions fall outside the model inputs; ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
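For context, a usage sketch of the new head. This is illustrative only: the public EleutherAI/gpt-j-6B checkpoint ships no fine-tuned QA head, so qa_outputs is randomly initialized until the model is fine-tuned on a QA dataset, loading the 6B weights needs substantial memory, and the question/context strings below are made up.

import torch
from transformers import AutoTokenizer, GPTJForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForQuestionAnswering.from_pretrained("EleutherAI/gpt-j-6B")  # QA head randomly initialized

question = "Who released GPT-J?"
context = "GPT-J is a 6 billion parameter language model released by EleutherAI."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Greedy span decoding: take the most likely start and end token positions.
start_index = outputs.start_logits.argmax(dim=-1).item()
end_index = outputs.end_logits.argmax(dim=-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start_index : end_index + 1])
print(answer)  # only meaningful once the span head has been fine-tuned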

src/transformers/utils/dummy_pt_objects.py

+12
@@ -2494,6 +2494,18 @@ def forward(self, *args, **kwargs):
         requires_backends(self, ["torch"])


+class GPTJForQuestionAnswering:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    def forward(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class GPTJForSequenceClassification:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
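The dummy object keeps the import of GPTJForQuestionAnswering from the top-level package working in environments without PyTorch and defers the failure to the point of use. A rough sketch of what that looks like (the exact error text comes from requires_backends and is not reproduced here):

# Only meaningful where PyTorch is NOT installed; with torch present, the real
# model class is exported instead of this dummy.
from transformers import GPTJForQuestionAnswering

try:
    GPTJForQuestionAnswering()  # the dummy __init__ calls requires_backends(self, ["torch"])
except ImportError as err:
    print(err)  # explains that PyTorch is required for this object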

tests/test_modeling_gptj.py

+6 −1
@@ -32,6 +32,7 @@
         GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
         AutoTokenizer,
         GPTJForCausalLM,
+        GPTJForQuestionAnswering,
         GPTJForSequenceClassification,
         GPTJModel,
     )
@@ -356,7 +357,11 @@ def prepare_config_and_inputs_for_common(self):
 @require_torch
 class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

-    all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification, GPTJForQuestionAnswering)
+        if is_torch_available()
+        else ()
+    )
     all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
     fx_ready_model_classes = all_model_classes
     test_pruning = False
