Skip to content

Commit 70708cc

Browse files
fix t5 token type ids (#8437)
1 parent 9fd1f56 commit 70708cc

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

src/transformers/tokenization_t5.py

+22
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,28 @@ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
187187
else:
188188
return token_ids + [self.eos_token_id]
189189

190+
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
    use of token type ids, therefore a list of zeros is returned.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: List of zeros.
    """
    # One EOS token is appended after each sequence (see
    # build_inputs_with_special_tokens), hence the "+ 1" per sequence.
    # Compute the length arithmetically instead of materializing throwaway
    # concatenated lists just to call len() on them.
    total_len = len(token_ids_0) + 1
    if token_ids_1 is not None:
        total_len += len(token_ids_1) + 1
    return [0] * total_len
211+
190212
def build_inputs_with_special_tokens(
191213
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
192214
) -> List[int]:

src/transformers/tokenization_t5_fast.py

+22
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,28 @@ def build_inputs_with_special_tokens(
191191
token_ids_1 = token_ids_1 + [self.eos_token_id]
192192
return self.prefix_tokens + token_ids_0 + token_ids_1
193193

194+
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
    use of token type ids, therefore a list of zeros is returned.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: List of zeros.
    """
    # One EOS token is appended after each sequence (see
    # build_inputs_with_special_tokens), hence the "+ 1" per sequence.
    # Compute the length arithmetically instead of materializing throwaway
    # concatenated lists just to call len() on them.
    total_len = len(token_ids_0) + 1
    if token_ids_1 is not None:
        total_len += len(token_ids_1) + 1
    return [0] * total_len
215+
194216
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
195217
def prepare_seq2seq_batch(
196218
self,

tests/test_tokenization_t5.py

+14
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,20 @@ def test_eos_in_input(self):
223223
self.assertEqual(expected_src_tokens, src_ids)
224224
self.assertEqual(expected_tgt_tokens, tgt_ids)
225225

226+
def test_token_type_ids(self):
    # Slow and fast T5 tokenizers must agree on token type ids for a
    # sequence pair; both should emit all zeros of the combined length.
    first_batch = ["A first paragraph for summarization."]
    second_batch = ["A second paragraph for summarization."]

    fast_encoding = self.t5_base_tokenizer_fast(
        first_batch, second_batch, add_special_tokens=True, return_token_type_ids=True
    )
    slow_encoding = self.t5_base_tokenizer(
        first_batch, second_batch, add_special_tokens=True, return_token_type_ids=True
    )

    self.assertEqual(slow_encoding.token_type_ids, fast_encoding.token_type_ids)
    # 18 = both tokenized sequences plus one EOS token after each.
    self.assertEqual(len(slow_encoding.token_type_ids[0]), 18)
239+
226240
def test_fast_and_slow_same_result(self):
227241
src_text = "<pad> Today is <unk> nice day </s>"
228242
tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]

0 commit comments

Comments
 (0)