Skip to content

Commit 1cc453d

Browse files
LysandreJiksgugger
andauthored
Allow per-version configurations (#14344)
* Allow per-version configurations * Update tests/test_configuration_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_configuration_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent 76d0d41 commit 1cc453d

File tree

2 files changed

+109
-6
lines changed

2 files changed

+109
-6
lines changed

src/transformers/configuration_utils.py

+70-6
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@
1919
import copy
2020
import json
2121
import os
22+
import re
2223
import warnings
23-
from typing import Any, Dict, Tuple, Union
24+
from typing import Any, Dict, Optional, Tuple, Union
25+
26+
from packaging import version
2427

2528
from . import __version__
2629
from .file_utils import (
2730
CONFIG_NAME,
2831
PushToHubMixin,
2932
cached_path,
3033
copy_func,
34+
get_list_of_files,
3135
hf_bucket_url,
3236
is_offline_mode,
3337
is_remote_url,
@@ -37,6 +41,8 @@
3741

3842

3943
logger = logging.get_logger(__name__)
44+
FULL_CONFIGURATION_FILE = "config.json"
45+
_re_configuration_file = re.compile(r"config\.(.*)\.json")
4046

4147

4248
class PretrainedConfig(PushToHubMixin):
@@ -536,15 +542,23 @@ def get_config_dict(
536542
local_files_only = True
537543

538544
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
539-
if os.path.isdir(pretrained_model_name_or_path):
540-
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
541-
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
545+
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
542546
config_file = pretrained_model_name_or_path
543547
else:
544-
config_file = hf_bucket_url(
545-
pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
548+
configuration_file = get_configuration_file(
549+
pretrained_model_name_or_path,
550+
revision=revision,
551+
use_auth_token=use_auth_token,
552+
local_files_only=local_files_only,
546553
)
547554

555+
if os.path.isdir(pretrained_model_name_or_path):
556+
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
557+
else:
558+
config_file = hf_bucket_url(
559+
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
560+
)
561+
548562
try:
549563
# Load from URL or cache if already cached
550564
resolved_config_file = cached_path(
@@ -796,6 +810,56 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
796810
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
797811

798812

813+
def get_configuration_file(
814+
path_or_repo: Union[str, os.PathLike],
815+
revision: Optional[str] = None,
816+
use_auth_token: Optional[Union[bool, str]] = None,
817+
local_files_only: bool = False,
818+
) -> str:
819+
"""
820+
Get the configuration file to use for this version of transformers.
821+
822+
Args:
823+
path_or_repo (:obj:`str` or :obj:`os.PathLike`):
824+
Can be either the id of a repo on huggingface.co or a path to a `directory`.
825+
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
826+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
827+
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
828+
identifier allowed by git.
829+
use_auth_token (:obj:`str` or `bool`, `optional`):
830+
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
831+
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
832+
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
833+
Whether or not to only rely on local files and not to attempt to download any files.
834+
835+
Returns:
836+
:obj:`str`: The configuration file to use.
837+
"""
838+
# Inspect all files from the repo/folder.
839+
all_files = get_list_of_files(
840+
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
841+
)
842+
configuration_files_map = {}
843+
for file_name in all_files:
844+
search = _re_configuration_file.search(file_name)
845+
if search is not None:
846+
v = search.groups()[0]
847+
configuration_files_map[v] = file_name
848+
available_versions = sorted(configuration_files_map.keys())
849+
850+
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
851+
configuration_file = FULL_CONFIGURATION_FILE
852+
transformers_version = version.parse(__version__)
853+
for v in available_versions:
854+
if version.parse(v) <= transformers_version:
855+
configuration_file = configuration_files_map[v]
856+
else:
857+
# No point going further since the versions are sorted.
858+
break
859+
860+
return configuration_file
861+
862+
799863
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
800864
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
801865
object="config", object_class="AutoConfig", object_files="configuration file"

tests/test_configuration_common.py

+39
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
import copy
1717
import json
1818
import os
19+
import shutil
1920
import tempfile
2021
import unittest
22+
import unittest.mock
2123

2224
from huggingface_hub import Repository, delete_repo, login
2325
from requests.exceptions import HTTPError
@@ -306,3 +308,40 @@ def test_config_common_kwargs_is_complete(self):
306308
"The following keys are set with the default values in `test_configuration_common.config_common_kwargs` "
307309
f"pick another value for them: {', '.join(keys_with_defaults)}."
308310
)
311+
312+
313+
class ConfigurationVersioningTest(unittest.TestCase):
314+
def test_local_versioning(self):
315+
configuration = AutoConfig.from_pretrained("bert-base-cased")
316+
317+
with tempfile.TemporaryDirectory() as tmp_dir:
318+
configuration.save_pretrained(tmp_dir)
319+
configuration.hidden_size = 2
320+
json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
321+
322+
# This should pick the new configuration file as the version of Transformers is > 4.0.0
323+
new_configuration = AutoConfig.from_pretrained(tmp_dir)
324+
self.assertEqual(new_configuration.hidden_size, 2)
325+
326+
# Will need to be adjusted if we reach v42 and this test is still here.
327+
# Should pick the old configuration file as the version of Transformers is < 4.42.0
328+
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
329+
new_configuration = AutoConfig.from_pretrained(tmp_dir)
330+
self.assertEqual(new_configuration.hidden_size, 768)
331+
332+
def test_repo_versioning_before(self):
333+
# This repo has two configuration files, one for v5.0.0 and above with an added token, one for versions lower.
334+
repo = "microsoft/layoutxlm-base"
335+
336+
import transformers as new_transformers
337+
338+
new_transformers.configuration_utils.__version__ = "v5.0.0"
339+
new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo)
340+
self.assertEqual(new_configuration.tokenizer_class, None)
341+
342+
# Testing an older version by monkey-patching the version in the module it's used.
343+
import transformers as old_transformers
344+
345+
old_transformers.configuration_utils.__version__ = "v3.0.0"
346+
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
347+
self.assertEqual(old_configuration.tokenizer_class, "XLMRobertaTokenizer")

0 commit comments

Comments
 (0)