Skip to content

Commit 002a078

Browse files
sgugger and LysandreJik authored
Dynamically load model code from the Hub (#13467)
* Dynamic model * Use defensive flag * Style * Doc and arg rename * Arg rename * Add tests * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
1 parent aeb2dac commit 002a078

File tree

5 files changed

+389
-3
lines changed

5 files changed

+389
-3
lines changed

src/transformers/file_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,8 @@
248248
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
249249
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
250250
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
251+
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
252+
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
251253
SESSION_ID = uuid4().hex
252254
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
253255

src/transformers/models/auto/auto_factory.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ...file_utils import copy_func
2121
from ...utils import logging
2222
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
23+
from .dynamic import get_class_from_dynamic_module
2324

2425

2526
logger = logging.get_logger(__name__)
@@ -122,6 +123,10 @@
122123
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
123124
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
124125
identifier allowed by git.
126+
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
127+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
128+
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
129+
will execute code present on the Hub on your local machine.
125130
kwargs (additional keyword arguments, `optional`):
126131
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
127132
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -211,6 +216,10 @@
211216
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
212217
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
213218
identifier allowed by git.
219+
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
220+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
221+
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
222+
will execute code present on the Hub on your local machine.
214223
kwargs (additional keyword arguments, `optional`):
215224
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
216225
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -300,6 +309,10 @@
300309
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
301310
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
302311
identifier allowed by git.
312+
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
313+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
314+
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
315+
will execute code present on the Hub on your local machine.
303316
kwargs (additional keyword arguments, `optional`):
304317
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
305318
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -377,13 +390,31 @@ def from_config(cls, config, **kwargs):
377390
@classmethod
378391
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
379392
config = kwargs.pop("config", None)
393+
trust_remote_code = kwargs.pop("trust_remote_code", False)
380394
kwargs["_from_auto"] = True
381395
if not isinstance(config, PretrainedConfig):
382396
config, kwargs = AutoConfig.from_pretrained(
383397
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
384398
)
385-
386-
if type(config) in cls._model_mapping.keys():
399+
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
400+
if not trust_remote_code:
401+
raise ValueError(
402+
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
403+
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
404+
"the option `trust_remote_code=True` to remove this error."
405+
)
406+
if kwargs.get("revision", None) is None:
407+
logger.warn(
408+
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
409+
"no malicious code has been contributed in a newer revision."
410+
)
411+
class_ref = config.auto_map[cls.__name__]
412+
module_file, class_name = class_ref.split(".")
413+
model_class = get_class_from_dynamic_module(
414+
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
415+
)
416+
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
417+
elif type(config) in cls._model_mapping.keys():
387418
model_class = _get_model_class(config, cls._model_mapping)
388419
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
389420
raise ValueError(
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# coding=utf-8
2+
# Copyright 2021 The HuggingFace Inc. team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Utilities to dynamically load model and tokenizer from the Hub."""
16+
17+
import importlib
18+
import os
19+
import re
20+
import shutil
21+
import sys
22+
from pathlib import Path
23+
from typing import Dict, Optional, Union
24+
25+
from ...file_utils import (
26+
HF_MODULES_CACHE,
27+
TRANSFORMERS_DYNAMIC_MODULE_NAME,
28+
cached_path,
29+
hf_bucket_url,
30+
is_offline_mode,
31+
)
32+
from ...utils import logging
33+
34+
35+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36+
37+
38+
def init_hf_modules():
    """
    Ensure the dynamic-module cache directory exists, is a Python package, and is importable.

    Safe to call repeatedly: the presence of :obj:`HF_MODULES_CACHE` on ``sys.path`` signals
    that setup has already run, in which case this is a no-op.
    """
    if HF_MODULES_CACHE not in sys.path:
        sys.path.append(HF_MODULES_CACHE)
        os.makedirs(HF_MODULES_CACHE, exist_ok=True)
        # An __init__.py marker makes the cache directory importable as a package.
        init_file = Path(HF_MODULES_CACHE) / "__init__.py"
        if not init_file.exists():
            init_file.touch()
51+
52+
53+
def create_dynamic_module(name: Union[str, os.PathLike]):
    """
    Create the package directory (with an ``__init__.py`` marker) for a dynamic module
    named ``name`` inside the modules cache.
    """
    init_hf_modules()
    module_dir = Path(HF_MODULES_CACHE) / name
    # Recursively materialize any missing ancestor packages first, so every level of the
    # dotted module path stays importable.
    if not module_dir.parent.exists():
        create_dynamic_module(module_dir.parent)
    os.makedirs(module_dir, exist_ok=True)
    marker = module_dir / "__init__.py"
    if not marker.exists():
        marker.touch()
66+
67+
68+
def check_imports(filename):
    """
    Check that the current Python environment contains all the libraries imported in a module file.

    Args:
        filename: Path to the Python source file to scan for ``import``/``from ... import`` statements.

    Raises:
        ImportError: If one or more imported top-level packages cannot be imported in this environment.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    # Imports of the form `import xxx`, optionally aliased as `import xxx as yyy`.
    # Raw strings avoid invalid-escape warnings; the non-greedy group + optional alias
    # means aliased imports are checked too (the old pattern silently skipped them).
    imports = re.findall(r"^\s*import\s+(\S+?)(?:\s+as\s+\S+)?\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # Only keep the top-level module; relative imports (leading dot) resolve inside the repo itself.
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]

    # Unique-ify and test we got them all
    missing_packages = []
    for imp in set(imports):
        try:
            importlib.import_module(imp)
        except ImportError:
            missing_packages.append(imp)

    if len(missing_packages) > 0:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )
96+
97+
98+
def get_class_in_module(class_name, module_path):
    """
    Import the module located at ``module_path`` (a filesystem-style path relative to the
    modules cache) and return the attribute named ``class_name`` from it.
    """
    # Convert the relative path into a dotted module name before importing.
    dotted_name = module_path.replace(os.path.sep, ".")
    return getattr(importlib.import_module(dotted_name), class_name)
105+
106+
107+
def get_class_from_dynamic_module(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    class_name: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    .. warning::

        Calling this function will execute the code in the module file found locally or downloaded from the Hub. It
        should therefore only be called on trusted repos.

    Args:
        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
            This can be either:

            - a string, the `model id` of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
              namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
            - a path to a `directory` containing a configuration file saved using the
              :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``.

        module_file (:obj:`str`):
            The name of the module file containing the class to look for.
        class_name (:obj:`str`):
            The name of the class to import in the module.
        cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (:obj:`str` or `bool`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, will only try to load the tokenizer configuration from local files.

    .. note::

        Passing :obj:`use_auth_token=True` is required when you want to use a private model.


    Returns:
        :obj:`type`: The class, dynamically imported from the module.

    Examples::

        # Download module `modeling.py` from huggingface.co and cache, then extract the class `MyBertModel` from this
        # module.
        cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
    """
    # NOTE(review): extra **kwargs are accepted and silently ignored — presumably so callers can
    # forward a superset of model kwargs; confirm this is intentional.
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path`, or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
        module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
        submodule = "local"
    else:
        # Remote repo: resolve a download URL and namespace the cached module under the repo id
        # (slashes become path separators so `user/model` maps to a nested package).
        module_file_or_url = hf_bucket_url(
            pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
        )
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_path(
            module_file_or_url,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )

    except EnvironmentError:
        # Log for context, then re-raise so the caller sees the original failure.
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment; raises ImportError with the
    # missing packages before any downloaded code is executed.
    check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == "local":
        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
        # that hash, to only copy when there is a modification but it seems overkill for now).
        # The only reason we do the copy is to avoid putting too many folders in sys.path.
        module_name = module_file
        shutil.copy(resolved_module_file, submodule_path / module_file)
    else:
        # The module file will end up being named module_file + the etag. This way we get the benefit of versioning:
        # a new revision on the Hub produces a new etag and therefore a fresh cached module file.
        resolved_module_file_name = Path(resolved_module_file).name
        module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".")
        module_name = "_".join(module_name_parts) + ".py"
        if not (submodule_path / module_name).exists():
            shutil.copy(resolved_module_file, submodule_path / module_name)

    # And lastly we get the class inside our newly created module
    final_module = os.path.join(full_submodule, module_name.replace(".py", ""))
    return get_class_in_module(class_name, final_module)

0 commit comments

Comments
 (0)