Skip to content

Commit 69e16ab

Browse files
nbroad1881 and sgugger authored
Switch from using sum for flattening lists of lists in group_texts (#14472)
* remove sum for list flattening
* change to chain(*)
* make chain object a list
* delete empty lines per sgugger's suggestions

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Nicholas Broad <nicholas@nmbroad.com>
1 parent 0b7d053 commit 69e16ab

File tree

15 files changed

+35
-20
lines changed

15 files changed

+35
-20
lines changed

examples/flax/language-modeling/run_clm_flax.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import sys
2828
import time
2929
from dataclasses import dataclass, field
30+
from itertools import chain
3031
from pathlib import Path
3132
from typing import Callable, Optional
3233

@@ -430,7 +431,7 @@ def tokenize_function(examples):
430431
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
431432
def group_texts(examples):
432433
# Concatenate all texts.
433-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
434+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
434435
total_length = len(concatenated_examples[list(examples.keys())[0]])
435436
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
436437
# customize this part to your needs.

examples/flax/language-modeling/run_mlm_flax.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import sys
2626
import time
2727
from dataclasses import dataclass, field
28+
from itertools import chain
2829

2930
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
3031
from pathlib import Path
@@ -453,7 +454,7 @@ def tokenize_function(examples):
453454
# max_seq_length.
454455
def group_texts(examples):
455456
# Concatenate all texts.
456-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
457+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
457458
total_length = len(concatenated_examples[list(examples.keys())[0]])
458459
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
459460
# customize this part to your needs.

examples/flax/language-modeling/run_t5_mlm_flax.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import sys
2626
import time
2727
from dataclasses import dataclass, field
28+
from itertools import chain
2829
from pathlib import Path
2930
from typing import Dict, List, Optional
3031

@@ -563,7 +564,7 @@ def tokenize_function(examples):
563564
# Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
564565
def group_texts(examples):
565566
# Concatenate all texts.
566-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
567+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
567568
total_length = len(concatenated_examples[list(examples.keys())[0]])
568569
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
569570
# customize this part to your needs.

examples/pytorch/language-modeling/run_clm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import os
2727
import sys
2828
from dataclasses import dataclass, field
29+
from itertools import chain
2930
from typing import Optional
3031

3132
import datasets
@@ -408,7 +409,7 @@ def tokenize_function(examples):
408409
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
409410
def group_texts(examples):
410411
# Concatenate all texts.
411-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
412+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
412413
total_length = len(concatenated_examples[list(examples.keys())[0]])
413414
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
414415
# customize this part to your needs.

examples/pytorch/language-modeling/run_clm_no_trainer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import math
2828
import os
2929
import random
30+
from itertools import chain
3031
from pathlib import Path
3132

3233
import datasets
@@ -366,7 +367,7 @@ def tokenize_function(examples):
366367
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
367368
def group_texts(examples):
368369
# Concatenate all texts.
369-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
370+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
370371
total_length = len(concatenated_examples[list(examples.keys())[0]])
371372
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
372373
# customize this part to your needs.

examples/pytorch/language-modeling/run_mlm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import os
2727
import sys
2828
from dataclasses import dataclass, field
29+
from itertools import chain
2930
from typing import Optional
3031

3132
import datasets
@@ -432,7 +433,7 @@ def tokenize_function(examples):
432433
# max_seq_length.
433434
def group_texts(examples):
434435
# Concatenate all texts.
435-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
436+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
436437
total_length = len(concatenated_examples[list(examples.keys())[0]])
437438
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
438439
# customize this part to your needs.

examples/pytorch/language-modeling/run_mlm_no_trainer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import math
2828
import os
2929
import random
30+
from itertools import chain
3031
from pathlib import Path
3132

3233
import datasets
@@ -406,7 +407,7 @@ def tokenize_function(examples):
406407
# max_seq_length.
407408
def group_texts(examples):
408409
# Concatenate all texts.
409-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
410+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
410411
total_length = len(concatenated_examples[list(examples.keys())[0]])
411412
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
412413
# customize this part to your needs.

examples/pytorch/language-modeling/run_plm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import os
2424
import sys
2525
from dataclasses import dataclass, field
26+
from itertools import chain
2627
from typing import Optional
2728

2829
import datasets
@@ -403,7 +404,7 @@ def tokenize_function(examples):
403404
# max_seq_length.
404405
def group_texts(examples):
405406
# Concatenate all texts.
406-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
407+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
407408
total_length = len(concatenated_examples[list(examples.keys())[0]])
408409
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
409410
# customize this part to your needs.

examples/pytorch/multiple-choice/run_swag.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import os
2323
import sys
2424
from dataclasses import dataclass, field
25+
from itertools import chain
2526
from typing import Optional, Union
2627

2728
import datasets
@@ -185,7 +186,7 @@ def __call__(self, features):
185186
flattened_features = [
186187
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
187188
]
188-
flattened_features = sum(flattened_features, [])
189+
flattened_features = list(chain(*flattened_features))
189190

190191
batch = self.tokenizer.pad(
191192
flattened_features,
@@ -333,8 +334,8 @@ def preprocess_function(examples):
333334
]
334335

335336
# Flatten out
336-
first_sentences = sum(first_sentences, [])
337-
second_sentences = sum(second_sentences, [])
337+
first_sentences = list(chain(*first_sentences))
338+
second_sentences = list(chain(*second_sentences))
338339

339340
# Tokenize
340341
tokenized_examples = tokenizer(

examples/pytorch/multiple-choice/run_swag_no_trainer.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import os
2525
import random
2626
from dataclasses import dataclass
27+
from itertools import chain
2728
from pathlib import Path
2829
from typing import Optional, Union
2930

@@ -224,7 +225,7 @@ def __call__(self, features):
224225
flattened_features = [
225226
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
226227
]
227-
flattened_features = sum(flattened_features, [])
228+
flattened_features = list(chain(*flattened_features))
228229

229230
batch = self.tokenizer.pad(
230231
flattened_features,
@@ -365,8 +366,8 @@ def preprocess_function(examples):
365366
labels = examples[label_column_name]
366367

367368
# Flatten out
368-
first_sentences = sum(first_sentences, [])
369-
second_sentences = sum(second_sentences, [])
369+
first_sentences = list(chain(*first_sentences))
370+
second_sentences = list(chain(*second_sentences))
370371

371372
# Tokenize
372373
tokenized_examples = tokenizer(

examples/research_projects/jax-projects/model_parallel/run_clm_mp.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import sys
2424
import time
2525
from dataclasses import dataclass, field
26+
from itertools import chain
2627
from pathlib import Path
2728
from typing import Callable, Optional
2829

@@ -364,7 +365,7 @@ def tokenize_function(examples):
364365
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
365366
def group_texts(examples):
366367
# Concatenate all texts.
367-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
368+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
368369
total_length = len(concatenated_examples[list(examples.keys())[0]])
369370
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
370371
# customize this part to your needs.

examples/tensorflow/language-modeling/run_clm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import sys
3131
from dataclasses import dataclass, field
3232
from functools import partial
33+
from itertools import chain
3334
from pathlib import Path
3435
from typing import Optional
3536

@@ -406,7 +407,7 @@ def tokenize_function(examples):
406407
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
407408
def group_texts(examples):
408409
# Concatenate all texts.
409-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
410+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
410411
total_length = len(concatenated_examples[list(examples.keys())[0]])
411412
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
412413
# customize this part to your needs.

examples/tensorflow/language-modeling/run_mlm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import sys
3333
from dataclasses import dataclass, field
3434
from functools import partial
35+
from itertools import chain
3536
from pathlib import Path
3637
from typing import Optional
3738

@@ -462,7 +463,7 @@ def tokenize_function(examples):
462463
# max_seq_length.
463464
def group_texts(examples):
464465
# Concatenate all texts.
465-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
466+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
466467
total_length = len(concatenated_examples[list(examples.keys())[0]])
467468
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
468469
# customize this part to your needs.

examples/tensorflow/multiple-choice/run_swag.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import os
2323
import sys
2424
from dataclasses import dataclass, field
25+
from itertools import chain
2526
from pathlib import Path
2627
from typing import Optional
2728

@@ -342,8 +343,8 @@ def preprocess_function(examples):
342343
]
343344

344345
# Flatten out
345-
first_sentences = sum(first_sentences, [])
346-
second_sentences = sum(second_sentences, [])
346+
first_sentences = list(chain(*first_sentences))
347+
second_sentences = list(chain(*second_sentences))
347348

348349
# Tokenize
349350
tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True, max_length=max_seq_length)

src/transformers/file_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from enum import Enum
3636
from functools import partial, wraps
3737
from hashlib import sha256
38+
from itertools import chain
3839
from pathlib import Path
3940
from types import ModuleType
4041
from typing import Any, BinaryIO, ContextManager, Dict, List, Optional, Tuple, Union
@@ -2129,7 +2130,7 @@ def __init__(self, name, module_file, import_structure, module_spec=None, extra_
21292130
for value in values:
21302131
self._class_to_module[value] = key
21312132
# Needed for autocompletion in an IDE
2132-
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
2133+
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
21332134
self.__file__ = module_file
21342135
self.__spec__ = module_spec
21352136
self.__path__ = [os.path.dirname(module_file)]

0 commit comments

Comments (0)