Skip to content

Commit dfb00bf

Browse files
authored
Expand dynamic supported objects to configs and tokenizers (#14296)
* Dynamic configs
* Add config test
* Better tests
* Add tokenizer and test
* Add to from_config
* With save
1 parent de635af commit dfb00bf

7 files changed

+272
-10
lines changed

src/transformers/models/auto/auto_factory.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,24 @@ def __init__(self, *args, **kwargs):
378378

379379
@classmethod
380380
def from_config(cls, config, **kwargs):
381-
if type(config) in cls._model_mapping.keys():
381+
trust_remote_code = kwargs.pop("trust_remote_code", False)
382+
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
383+
if not trust_remote_code:
384+
raise ValueError(
385+
"Loading this model requires you to execute the modeling file in that repo "
386+
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
387+
"the option `trust_remote_code=True` to remove this error."
388+
)
389+
if kwargs.get("revision", None) is None:
390+
logger.warn(
391+
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
392+
"no malicious code has been contributed in a newer revision."
393+
)
394+
class_ref = config.auto_map[cls.__name__]
395+
module_file, class_name = class_ref.split(".")
396+
model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs)
397+
return model_class._from_config(config, **kwargs)
398+
elif type(config) in cls._model_mapping.keys():
382399
model_class = _get_model_class(config, cls._model_mapping)
383400
return model_class._from_config(config, **kwargs)
384401

@@ -394,7 +411,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
394411
kwargs["_from_auto"] = True
395412
if not isinstance(config, PretrainedConfig):
396413
config, kwargs = AutoConfig.from_pretrained(
397-
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
414+
pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
398415
)
399416
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
400417
if not trust_remote_code:

src/transformers/models/auto/configuration_auto.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121

2222
from ...configuration_utils import PretrainedConfig
2323
from ...file_utils import CONFIG_NAME
24+
from ...utils import logging
25+
from .dynamic import get_class_from_dynamic_module
2426

2527

28+
logger = logging.get_logger(__name__)
29+
2630
CONFIG_MAPPING_NAMES = OrderedDict(
2731
[
2832
# Add configs here
@@ -523,6 +527,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
523527
If :obj:`True`, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
524528
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
525529
the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
530+
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
531+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
532+
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
533+
will execute code present on the Hub on your local machine.
526534
kwargs(additional keyword arguments, `optional`):
527535
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
528536
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
@@ -555,8 +563,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
555563
{'foo': False}
556564
"""
557565
kwargs["_from_auto"] = True
566+
kwargs["name_or_path"] = pretrained_model_name_or_path
567+
trust_remote_code = kwargs.pop("trust_remote_code", False)
558568
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
559-
if "model_type" in config_dict:
569+
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
570+
if not trust_remote_code:
571+
raise ValueError(
572+
f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
573+
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
574+
"the option `trust_remote_code=True` to remove this error."
575+
)
576+
if kwargs.get("revision", None) is None:
577+
logger.warn(
578+
"Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
579+
"ensure no malicious code has been contributed in a newer revision."
580+
)
581+
class_ref = config_dict["auto_map"]["AutoConfig"]
582+
module_file, class_name = class_ref.split(".")
583+
config_class = get_class_from_dynamic_module(
584+
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
585+
)
586+
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
587+
elif "model_type" in config_dict:
560588
config_class = CONFIG_MAPPING[config_dict["model_type"]]
561589
return config_class.from_dict(config_dict, **kwargs)
562590
else:

src/transformers/models/auto/tokenization_auto.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
model_type_to_module_name,
4242
replace_list_option_in_docstrings,
4343
)
44+
from .dynamic import get_class_from_dynamic_module
4445

4546

4647
logger = logging.get_logger(__name__)
@@ -412,6 +413,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
412413
Whether or not to try to load the fast version of the tokenizer.
413414
tokenizer_type (:obj:`str`, `optional`):
414415
Tokenizer type to be loaded.
416+
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
417+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
418+
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
419+
will execute code present on the Hub on your local machine.
415420
kwargs (additional keyword arguments, `optional`):
416421
Will be passed to the Tokenizer ``__init__()`` method. Can be used to set special tokens like
417422
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
@@ -436,6 +441,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
436441

437442
use_fast = kwargs.pop("use_fast", True)
438443
tokenizer_type = kwargs.pop("tokenizer_type", None)
444+
trust_remote_code = kwargs.pop("trust_remote_code", False)
439445

440446
# First, let's see whether the tokenizer_type is passed so that we can leverage it
441447
if tokenizer_type is not None:
@@ -464,17 +470,45 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
464470
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
465471
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
466472
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
473+
tokenizer_auto_map = tokenizer_config.get("auto_map")
467474

468475
# If that did not work, let's try to use the config.
469476
if config_tokenizer_class is None:
470477
if not isinstance(config, PretrainedConfig):
471-
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
478+
config = AutoConfig.from_pretrained(
479+
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
480+
)
472481
config_tokenizer_class = config.tokenizer_class
482+
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
483+
tokenizer_auto_map = config.auto_map["AutoTokenizer"]
473484

474485
# If we have the tokenizer class from the tokenizer config or the model config we're good!
475486
if config_tokenizer_class is not None:
476487
tokenizer_class = None
477-
if use_fast and not config_tokenizer_class.endswith("Fast"):
488+
if tokenizer_auto_map is not None:
489+
if not trust_remote_code:
490+
raise ValueError(
491+
f"Loading {pretrained_model_name_or_path} requires you to execute the tokenizer file in that repo "
492+
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
493+
"the option `trust_remote_code=True` to remove this error."
494+
)
495+
if kwargs.get("revision", None) is None:
496+
logger.warn(
497+
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
498+
"no malicious code has been contributed in a newer revision."
499+
)
500+
501+
if use_fast and tokenizer_auto_map[1] is not None:
502+
class_ref = tokenizer_auto_map[1]
503+
else:
504+
class_ref = tokenizer_auto_map[0]
505+
506+
module_file, class_name = class_ref.split(".")
507+
tokenizer_class = get_class_from_dynamic_module(
508+
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
509+
)
510+
511+
elif use_fast and not config_tokenizer_class.endswith("Fast"):
478512
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
479513
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
480514
if tokenizer_class is None:

src/transformers/tokenization_utils_base.py

+3
Original file line numberDiff line numberDiff line change
@@ -1784,6 +1784,7 @@ def _from_pretrained(
17841784
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
17851785
config_tokenizer_class = init_kwargs.get("tokenizer_class")
17861786
init_kwargs.pop("tokenizer_class", None)
1787+
init_kwargs.pop("auto_map", None)
17871788
saved_init_inputs = init_kwargs.pop("init_inputs", ())
17881789
if not init_inputs:
17891790
init_inputs = saved_init_inputs
@@ -2028,6 +2029,8 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
20282029
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
20292030
tokenizer_class = tokenizer_class[:-4]
20302031
tokenizer_config["tokenizer_class"] = tokenizer_class
2032+
if getattr(self, "_auto_map", None) is not None:
2033+
tokenizer_config["auto_map"] = self._auto_map
20312034

20322035
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
20332036
f.write(json.dumps(tokenizer_config, ensure_ascii=False))

tests/test_configuration_common.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import tempfile
2020
import unittest
2121

22-
from huggingface_hub import delete_repo, login
22+
from huggingface_hub import Repository, delete_repo, login
2323
from requests.exceptions import HTTPError
24-
from transformers import BertConfig, GPT2Config, is_torch_available
24+
from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
2525
from transformers.configuration_utils import PretrainedConfig
2626
from transformers.testing_utils import PASS, USER, is_staging_test
2727

@@ -190,6 +190,23 @@ def run_common_tests(self):
190190
self.check_config_arguments_init()
191191

192192

193+
class FakeConfig(PretrainedConfig):
194+
def __init__(self, attribute=1, **kwargs):
195+
self.attribute = attribute
196+
super().__init__(**kwargs)
197+
198+
199+
# Make sure this is synchronized with the config above.
200+
FAKE_CONFIG_CODE = """
201+
from transformers import PretrainedConfig
202+
203+
class FakeConfig(PretrainedConfig):
204+
def __init__(self, attribute=1, **kwargs):
205+
self.attribute = attribute
206+
super().__init__(**kwargs)
207+
"""
208+
209+
193210
@is_staging_test
194211
class ConfigPushToHubTester(unittest.TestCase):
195212
@classmethod
@@ -208,6 +225,11 @@ def tearDownClass(cls):
208225
except HTTPError:
209226
pass
210227

228+
try:
229+
delete_repo(token=cls._token, name="test-dynamic-config")
230+
except HTTPError:
231+
pass
232+
211233
def test_push_to_hub(self):
212234
config = BertConfig(
213235
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -238,6 +260,23 @@ def test_push_to_hub_in_organization(self):
238260
if k != "transformers_version":
239261
self.assertEqual(v, getattr(new_config, k))
240262

263+
def test_push_to_hub_dynamic_config(self):
264+
config = FakeConfig(attribute=42)
265+
config.auto_map = {"AutoConfig": "configuration.FakeConfig"}
266+
267+
with tempfile.TemporaryDirectory() as tmp_dir:
268+
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
269+
config.save_pretrained(tmp_dir)
270+
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
271+
f.write(FAKE_CONFIG_CODE)
272+
273+
repo.push_to_hub()
274+
275+
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
276+
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
277+
self.assertEqual(new_config.__class__.__name__, "FakeConfig")
278+
self.assertEqual(new_config.attribute, 42)
279+
241280

242281
class ConfigTestUtils(unittest.TestCase):
243282
def test_config_from_string(self):

tests/test_modeling_common.py

+72-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@
3030
import transformers
3131
from huggingface_hub import Repository, delete_repo, login
3232
from requests.exceptions import HTTPError
33-
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
33+
from transformers import (
34+
AutoConfig,
35+
AutoModel,
36+
AutoModelForSequenceClassification,
37+
PretrainedConfig,
38+
is_torch_available,
39+
logging,
40+
)
3441
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
3542
from transformers.models.auto import get_values
3643
from transformers.testing_utils import (
@@ -67,7 +74,6 @@
6774
AdaptiveEmbedding,
6875
BertConfig,
6976
BertModel,
70-
PretrainedConfig,
7177
PreTrainedModel,
7278
T5Config,
7379
T5ForConditionalGeneration,
@@ -2078,6 +2084,23 @@ def test_model_from_pretrained_torch_dtype(self):
20782084
self.assertEqual(model.dtype, torch.float16)
20792085

20802086

2087+
class FakeConfig(PretrainedConfig):
2088+
def __init__(self, attribute=1, **kwargs):
2089+
self.attribute = attribute
2090+
super().__init__(**kwargs)
2091+
2092+
2093+
# Make sure this is synchronized with the config above.
2094+
FAKE_CONFIG_CODE = """
2095+
from transformers import PretrainedConfig
2096+
2097+
class FakeConfig(PretrainedConfig):
2098+
def __init__(self, attribute=1, **kwargs):
2099+
self.attribute = attribute
2100+
super().__init__(**kwargs)
2101+
"""
2102+
2103+
20812104
if is_torch_available():
20822105

20832106
class FakeModel(PreTrainedModel):
@@ -2140,6 +2163,11 @@ def tearDownClass(cls):
21402163
except HTTPError:
21412164
pass
21422165

2166+
try:
2167+
delete_repo(token=cls._token, name="test-dynamic-model-config")
2168+
except HTTPError:
2169+
pass
2170+
21432171
def test_push_to_hub(self):
21442172
config = BertConfig(
21452173
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2185,5 +2213,47 @@ def test_push_to_hub_dynamic_model(self):
21852213
repo.push_to_hub()
21862214

21872215
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
2216+
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
2217+
self.assertEqual(new_model.__class__.__name__, "FakeModel")
2218+
for p1, p2 in zip(model.parameters(), new_model.parameters()):
2219+
self.assertTrue(torch.equal(p1, p2))
2220+
2221+
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
2222+
new_model = AutoModel.from_config(config, trust_remote_code=True)
2223+
self.assertEqual(new_model.__class__.__name__, "FakeModel")
2224+
2225+
def test_push_to_hub_dynamic_model_and_config(self):
2226+
config = FakeConfig(
2227+
attribute=42,
2228+
vocab_size=99,
2229+
hidden_size=32,
2230+
num_hidden_layers=5,
2231+
num_attention_heads=4,
2232+
intermediate_size=37,
2233+
)
2234+
config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
2235+
model = FakeModel(config)
2236+
2237+
with tempfile.TemporaryDirectory() as tmp_dir:
2238+
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model-config", use_auth_token=self._token)
2239+
model.save_pretrained(tmp_dir)
2240+
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
2241+
f.write(FAKE_CONFIG_CODE)
2242+
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
2243+
f.write(FAKE_MODEL_CODE)
2244+
2245+
repo.push_to_hub()
2246+
2247+
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model-config", trust_remote_code=True)
2248+
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
2249+
self.assertEqual(new_model.config.__class__.__name__, "FakeConfig")
2250+
self.assertEqual(new_model.config.attribute, 42)
2251+
2252+
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
2253+
self.assertEqual(new_model.__class__.__name__, "FakeModel")
21882254
for p1, p2 in zip(model.parameters(), new_model.parameters()):
21892255
self.assertTrue(torch.equal(p1, p2))
2256+
2257+
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
2258+
new_model = AutoModel.from_config(config, trust_remote_code=True)
2259+
self.assertEqual(new_model.__class__.__name__, "FakeModel")

0 commit comments

Comments (0)