Skip to content

Commit e87505f

Browse files
[Flax] Add other BERT classes (#10977)
* add first code structures * add all bert models * add to init and docs * correct docs * make style
1 parent e031162 commit e87505f

File tree

7 files changed

+624
-21
lines changed

7 files changed

+624
-21
lines changed

docs/source/model_doc/bert.rst

+42
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,50 @@ FlaxBertModel
209209
:members: __call__
210210

211211

212+
FlaxBertForPreTraining
213+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
214+
215+
.. autoclass:: transformers.FlaxBertForPreTraining
216+
:members: __call__
217+
218+
212219
FlaxBertForMaskedLM
213220
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
214221

215222
.. autoclass:: transformers.FlaxBertForMaskedLM
216223
:members: __call__
224+
225+
226+
FlaxBertForNextSentencePrediction
227+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
228+
229+
.. autoclass:: transformers.FlaxBertForNextSentencePrediction
230+
:members: __call__
231+
232+
233+
FlaxBertForSequenceClassification
234+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
235+
236+
.. autoclass:: transformers.FlaxBertForSequenceClassification
237+
:members: __call__
238+
239+
240+
FlaxBertForMultipleChoice
241+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
242+
243+
.. autoclass:: transformers.FlaxBertForMultipleChoice
244+
:members: __call__
245+
246+
247+
FlaxBertForTokenClassification
248+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
249+
250+
.. autoclass:: transformers.FlaxBertForTokenClassification
251+
:members: __call__
252+
253+
254+
FlaxBertForQuestionAnswering
255+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
256+
257+
.. autoclass:: transformers.FlaxBertForQuestionAnswering
258+
:members: __call__

src/transformers/__init__.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -1290,7 +1290,19 @@
12901290
if is_flax_available():
12911291
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
12921292
_import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"])
1293-
_import_structure["models.bert"].extend(["FlaxBertForMaskedLM", "FlaxBertModel"])
1293+
_import_structure["models.bert"].extend(
1294+
[
1295+
"FlaxBertForMaskedLM",
1296+
"FlaxBertForMultipleChoice",
1297+
"FlaxBertForNextSentencePrediction",
1298+
"FlaxBertForPreTraining",
1299+
"FlaxBertForQuestionAnswering",
1300+
"FlaxBertForSequenceClassification",
1301+
"FlaxBertForTokenClassification",
1302+
"FlaxBertModel",
1303+
"FlaxBertPreTrainedModel",
1304+
]
1305+
)
12941306
_import_structure["models.roberta"].append("FlaxRobertaModel")
12951307
else:
12961308
from .utils import dummy_flax_objects
@@ -2372,7 +2384,17 @@
23722384
if is_flax_available():
23732385
from .modeling_flax_utils import FlaxPreTrainedModel
23742386
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
2375-
from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
2387+
from .models.bert import (
2388+
FlaxBertForMaskedLM,
2389+
FlaxBertForMultipleChoice,
2390+
FlaxBertForNextSentencePrediction,
2391+
FlaxBertForPreTraining,
2392+
FlaxBertForQuestionAnswering,
2393+
FlaxBertForSequenceClassification,
2394+
FlaxBertForTokenClassification,
2395+
FlaxBertModel,
2396+
FlaxBertPreTrainedModel,
2397+
)
23762398
from .models.roberta import FlaxRobertaModel
23772399
else:
23782400
# Import the same objects as dummies to get them in the namespace.

src/transformers/models/bert/__init__.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,17 @@
7070
]
7171

7272
if is_flax_available():
73-
_import_structure["modeling_flax_bert"] = ["FlaxBertForMaskedLM", "FlaxBertModel"]
74-
73+
_import_structure["modeling_flax_bert"] = [
74+
"FlaxBertForMaskedLM",
75+
"FlaxBertForMultipleChoice",
76+
"FlaxBertForNextSentencePrediction",
77+
"FlaxBertForPreTraining",
78+
"FlaxBertForQuestionAnswering",
79+
"FlaxBertForSequenceClassification",
80+
"FlaxBertForTokenClassification",
81+
"FlaxBertModel",
82+
"FlaxBertPreTrainedModel",
83+
]
7584

7685
if TYPE_CHECKING:
7786
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
@@ -115,7 +124,17 @@
115124
)
116125

117126
if is_flax_available():
118-
from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
127+
from .modeling_flax_bert import (
128+
FlaxBertForMaskedLM,
129+
FlaxBertForMultipleChoice,
130+
FlaxBertForNextSentencePrediction,
131+
FlaxBertForPreTraining,
132+
FlaxBertForQuestionAnswering,
133+
FlaxBertForSequenceClassification,
134+
FlaxBertForTokenClassification,
135+
FlaxBertModel,
136+
FlaxBertPreTrainedModel,
137+
)
119138

120139
else:
121140
import importlib

0 commit comments

Comments
 (0)