Skip to content

Commit 07384ba

Browse files
authored
AutoModelForTableQuestionAnswering (#9154)
* AutoModelForTableQuestionAnswering * Update src/transformers/models/auto/modeling_auto.py * Style
1 parent 3433466 commit 07384ba

File tree

8 files changed

+161
-5
lines changed

8 files changed

+161
-5
lines changed

docs/source/model_doc/auto.rst

+7
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ AutoModelForQuestionAnswering
114114
:members:
115115

116116

117+
AutoModelForTableQuestionAnswering
118+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
119+
120+
.. autoclass:: transformers.AutoModelForTableQuestionAnswering
121+
:members:
122+
123+
117124
TFAutoModel
118125
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
119126

src/transformers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@
358358
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
359359
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
360360
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
361+
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
361362
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
362363
MODEL_MAPPING,
363364
MODEL_WITH_LM_HEAD_MAPPING,
@@ -370,6 +371,7 @@
370371
AutoModelForQuestionAnswering,
371372
AutoModelForSeq2SeqLM,
372373
AutoModelForSequenceClassification,
374+
AutoModelForTableQuestionAnswering,
373375
AutoModelForTokenClassification,
374376
AutoModelWithLMHead,
375377
)

src/transformers/dependency_versions_table.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
4141
"sphinx": "sphinx==3.2.1",
4242
"starlette": "starlette",
43-
"tensorflow-cpu": "tensorflow-cpu>=2.0,<2.4",
44-
"tensorflow": "tensorflow>=2.0,<2.4",
43+
"tensorflow-cpu": "tensorflow-cpu>=2.0",
44+
"tensorflow": "tensorflow>=2.0",
4545
"timeout-decorator": "timeout-decorator",
4646
"tokenizers": "tokenizers==0.9.4",
4747
"torch": "torch>=1.0",

src/transformers/models/auto/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
3232
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
3333
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
34+
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
3435
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
3536
MODEL_MAPPING,
3637
MODEL_WITH_LM_HEAD_MAPPING,
@@ -43,6 +44,7 @@
4344
AutoModelForQuestionAnswering,
4445
AutoModelForSeq2SeqLM,
4546
AutoModelForSequenceClassification,
47+
AutoModelForTableQuestionAnswering,
4648
AutoModelForTokenClassification,
4749
AutoModelWithLMHead,
4850
)

src/transformers/models/auto/modeling_auto.py

+106
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,12 @@
467467
(FunnelConfig, FunnelForQuestionAnswering),
468468
(LxmertConfig, LxmertForQuestionAnswering),
469469
(MPNetConfig, MPNetForQuestionAnswering),
470+
]
471+
)
472+
473+
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = OrderedDict(
474+
[
475+
# Model for Table Question Answering mapping
470476
(TapasConfig, TapasForQuestionAnswering),
471477
]
472478
)
@@ -1384,6 +1390,106 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
13841390
)
13851391

13861392

1393+
class AutoModelForTableQuestionAnswering:
1394+
r"""
1395+
This is a generic model class that will be instantiated as one of the model classes of the library---with a table
1396+
question answering head---when created with the
1397+
:meth:`~transformers.AutoModelForTableQuestionAnswering.from_pretrained` class method or the
1398+
:meth:`~transformers.AutoModelForTableQuestionAnswering.from_config` class method.
1399+
1400+
This class cannot be instantiated directly using ``__init__()`` (throws an error).
1401+
"""
1402+
1403+
def __init__(self):
1404+
raise EnvironmentError(
1405+
"AutoModelForTableQuestionAnswering is designed to be instantiated "
1406+
"using the `AutoModelForTableQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
1407+
"`AutoModelForTableQuestionAnswering.from_config(config)` methods."
1408+
)
1409+
1410+
@classmethod
1411+
@replace_list_option_in_docstrings(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, use_model_types=False)
1412+
def from_config(cls, config):
1413+
r"""
1414+
Instantiates one of the model classes of the library---with a table question answering head---from a
1415+
configuration.
1416+
1417+
Note:
1418+
Loading a model from its configuration file does **not** load the model weights. It only affects the
1419+
model's configuration. Use :meth:`~transformers.AutoModelForTableQuestionAnswering.from_pretrained` to load
1420+
the model weights.
1421+
1422+
Args:
1423+
config (:class:`~transformers.PretrainedConfig`):
1424+
The model class to instantiate is selected based on the configuration class:
1425+
1426+
List options
1427+
1428+
Examples::
1429+
1430+
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
1431+
>>> # Download configuration from huggingface.co and cache.
1432+
>>> config = AutoConfig.from_pretrained('google/tapas-base-finetuned-wtq')
1433+
>>> model = AutoModelForTableQuestionAnswering.from_config(config)
1434+
"""
1435+
if type(config) in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys():
1436+
return MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING[type(config)](config)
1437+
1438+
raise ValueError(
1439+
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
1440+
"Model type should be one of {}.".format(
1441+
config.__class__,
1442+
cls.__name__,
1443+
", ".join(c.__name__ for c in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys()),
1444+
)
1445+
)
1446+
1447+
@classmethod
1448+
@replace_list_option_in_docstrings(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING)
1449+
@add_start_docstrings(
1450+
"Instantiate one of the model classes of the library---with a table question answering head---from a "
1451+
"pretrained model.",
1452+
AUTO_MODEL_PRETRAINED_DOCSTRING,
1453+
)
1454+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1455+
r"""
1456+
Examples::
1457+
1458+
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
1459+
1460+
>>> # Download model and configuration from huggingface.co and cache.
1461+
>>> model = AutoModelForTableQuestionAnswering.from_pretrained('google/tapas-base-finetuned-wtq')
1462+
1463+
>>> # Update configuration during loading
1464+
>>> model = AutoModelForTableQuestionAnswering.from_pretrained('google/tapas-base-finetuned-wtq', output_attentions=True)
1465+
>>> model.config.output_attentions
1466+
True
1467+
1468+
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
1469+
>>> config = AutoConfig.from_json_file('./tf_model/tapas_tf_checkpoint.json')
1470+
>>> model = AutoModelForTableQuestionAnswering.from_pretrained('./tf_model/tapas_tf_checkpoint.ckpt.index', from_tf=True, config=config)
1471+
"""
1472+
config = kwargs.pop("config", None)
1473+
if not isinstance(config, PretrainedConfig):
1474+
config, kwargs = AutoConfig.from_pretrained(
1475+
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
1476+
)
1477+
1478+
if type(config) in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys():
1479+
return MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained(
1480+
pretrained_model_name_or_path, *model_args, config=config, **kwargs
1481+
)
1482+
1483+
raise ValueError(
1484+
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
1485+
"Model type should be one of {}.".format(
1486+
config.__class__,
1487+
cls.__name__,
1488+
", ".join(c.__name__ for c in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys()),
1489+
)
1490+
)
1491+
1492+
13871493
class AutoModelForTokenClassification:
13881494
r"""
13891495
This is a generic model class that will be instantiated as one of the model classes of the library---with a token

src/transformers/utils/dummy_pt_objects.py

+12
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
303303
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
304304

305305

306+
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
307+
308+
306309
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
307310

308311

@@ -393,6 +396,15 @@ def from_pretrained(self, *args, **kwargs):
393396
requires_pytorch(self)
394397

395398

399+
class AutoModelForTableQuestionAnswering:
400+
def __init__(self, *args, **kwargs):
401+
requires_pytorch(self)
402+
403+
@classmethod
404+
def from_pretrained(self, *args, **kwargs):
405+
requires_pytorch(self)
406+
407+
396408
class AutoModelForTokenClassification:
397409
def __init__(self, *args, **kwargs):
398410
requires_pytorch(self)

tests/test_modeling_auto.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
import unittest
1818

1919
from transformers import is_torch_available
20-
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
20+
from transformers.testing_utils import (
21+
DUMMY_UNKWOWN_IDENTIFIER,
22+
SMALL_MODEL_IDENTIFIER,
23+
require_scatter,
24+
require_torch,
25+
slow,
26+
)
2127

2228

2329
if is_torch_available():
@@ -30,6 +36,7 @@
3036
AutoModelForQuestionAnswering,
3137
AutoModelForSeq2SeqLM,
3238
AutoModelForSequenceClassification,
39+
AutoModelForTableQuestionAnswering,
3340
AutoModelForTokenClassification,
3441
AutoModelWithLMHead,
3542
BertConfig,
@@ -44,6 +51,8 @@
4451
RobertaForMaskedLM,
4552
T5Config,
4653
T5ForConditionalGeneration,
54+
TapasConfig,
55+
TapasForQuestionAnswering,
4756
)
4857
from transformers.models.auto.modeling_auto import (
4958
MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -52,13 +61,15 @@
5261
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
5362
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
5463
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
64+
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
5565
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
5666
MODEL_MAPPING,
5767
MODEL_WITH_LM_HEAD_MAPPING,
5868
)
5969
from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
6070
from transformers.models.gpt2.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
6171
from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
72+
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
6273

6374

6475
@require_torch
@@ -168,6 +179,21 @@ def test_question_answering_model_from_pretrained(self):
168179
self.assertIsNotNone(model)
169180
self.assertIsInstance(model, BertForQuestionAnswering)
170181

182+
@slow
183+
@require_scatter
184+
def test_table_question_answering_model_from_pretrained(self):
185+
for model_name in TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST[5:6]:
186+
config = AutoConfig.from_pretrained(model_name)
187+
self.assertIsNotNone(config)
188+
self.assertIsInstance(config, TapasConfig)
189+
190+
model = AutoModelForTableQuestionAnswering.from_pretrained(model_name)
191+
model, loading_info = AutoModelForTableQuestionAnswering.from_pretrained(
192+
model_name, output_loading_info=True
193+
)
194+
self.assertIsNotNone(model)
195+
self.assertIsInstance(model, TapasForQuestionAnswering)
196+
171197
@slow
172198
def test_token_classification_model_from_pretrained(self):
173199
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -200,6 +226,7 @@ def test_parents_and_children_in_mappings(self):
200226
MODEL_MAPPING,
201227
MODEL_FOR_PRETRAINING_MAPPING,
202228
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
229+
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
203230
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
204231
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
205232
MODEL_WITH_LM_HEAD_MAPPING,

tests/test_modeling_tapas.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
MODEL_FOR_MASKED_LM_MAPPING,
2626
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
2727
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
28-
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
2928
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
3029
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
30+
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
3131
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
3232
is_torch_available,
3333
)
@@ -436,7 +436,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
436436
if return_labels:
437437
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
438438
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
439-
elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
439+
elif model_class in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.values():
440440
inputs_dict["labels"] = torch.zeros(
441441
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
442442
)

0 commit comments

Comments
 (0)