|
19 | 19 | import copy
|
20 | 20 | import json
|
21 | 21 | import os
|
| 22 | +import re |
22 | 23 | import warnings
|
23 |
| -from typing import Any, Dict, Tuple, Union |
| 24 | +from typing import Any, Dict, Optional, Tuple, Union |
| 25 | + |
| 26 | +from packaging import version |
24 | 27 |
|
25 | 28 | from . import __version__
|
26 | 29 | from .file_utils import (
|
27 | 30 | CONFIG_NAME,
|
28 | 31 | PushToHubMixin,
|
29 | 32 | cached_path,
|
30 | 33 | copy_func,
|
| 34 | + get_list_of_files, |
31 | 35 | hf_bucket_url,
|
32 | 36 | is_offline_mode,
|
33 | 37 | is_remote_url,
|
|
37 | 41 |
|
38 | 42 |
|
39 | 43 | logger = logging.get_logger(__name__)
|
| 44 | +FULL_CONFIGURATION_FILE = "config.json" |
| 45 | +_re_configuration_file = re.compile(r"config\.(.*)\.json") |
40 | 46 |
|
41 | 47 |
|
42 | 48 | class PretrainedConfig(PushToHubMixin):
|
@@ -536,15 +542,23 @@ def get_config_dict(
|
536 | 542 | local_files_only = True
|
537 | 543 |
|
538 | 544 | pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
539 |
| - if os.path.isdir(pretrained_model_name_or_path): |
540 |
| - config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) |
541 |
| - elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): |
| 545 | + if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): |
542 | 546 | config_file = pretrained_model_name_or_path
|
543 | 547 | else:
|
544 |
| - config_file = hf_bucket_url( |
545 |
| - pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None |
| 548 | + configuration_file = get_configuration_file( |
| 549 | + pretrained_model_name_or_path, |
| 550 | + revision=revision, |
| 551 | + use_auth_token=use_auth_token, |
| 552 | + local_files_only=local_files_only, |
546 | 553 | )
|
547 | 554 |
|
| 555 | + if os.path.isdir(pretrained_model_name_or_path): |
| 556 | + config_file = os.path.join(pretrained_model_name_or_path, configuration_file) |
| 557 | + else: |
| 558 | + config_file = hf_bucket_url( |
| 559 | + pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None |
| 560 | + ) |
| 561 | + |
548 | 562 | try:
|
549 | 563 | # Load from URL or cache if already cached
|
550 | 564 | resolved_config_file = cached_path(
|
@@ -796,6 +810,56 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
|
796 | 810 | d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
797 | 811 |
|
798 | 812 |
|
| 813 | +def get_configuration_file( |
| 814 | + path_or_repo: Union[str, os.PathLike], |
| 815 | + revision: Optional[str] = None, |
| 816 | + use_auth_token: Optional[Union[bool, str]] = None, |
| 817 | + local_files_only: bool = False, |
| 818 | +) -> str: |
| 819 | + """ |
| 820 | + Get the configuration file to use for this version of transformers. |
| 821 | +
|
| 822 | + Args: |
| 823 | + path_or_repo (:obj:`str` or :obj:`os.PathLike`): |
| 824 | + Can be either the id of a repo on huggingface.co or a path to a `directory`. |
| 825 | + revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): |
| 826 | + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a |
| 827 | + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any |
| 828 | + identifier allowed by git. |
| 829 | + use_auth_token (:obj:`str` or `bool`, `optional`): |
| 830 | + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token |
| 831 | + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). |
| 832 | + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): |
| 833 | + Whether or not to only rely on local files and not to attempt to download any files. |
| 834 | +
|
| 835 | + Returns: |
| 836 | + :obj:`str`: The configuration file to use. |
| 837 | + """ |
| 838 | + # Inspect all files from the repo/folder. |
| 839 | + all_files = get_list_of_files( |
| 840 | + path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only |
| 841 | + ) |
| 842 | + configuration_files_map = {} |
| 843 | + for file_name in all_files: |
| 844 | + search = _re_configuration_file.search(file_name) |
| 845 | + if search is not None: |
| 846 | + v = search.groups()[0] |
| 847 | + configuration_files_map[v] = file_name |
| 848 | + available_versions = sorted(configuration_files_map.keys()) |
| 849 | + |
| 850 | + # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions. |
| 851 | + configuration_file = FULL_CONFIGURATION_FILE |
| 852 | + transformers_version = version.parse(__version__) |
| 853 | + for v in available_versions: |
| 854 | + if version.parse(v) <= transformers_version: |
| 855 | + configuration_file = configuration_files_map[v] |
| 856 | + else: |
| 857 | + # No point going further since the versions are sorted. |
| 858 | + break |
| 859 | + |
| 860 | + return configuration_file |
| 861 | + |
| 862 | + |
799 | 863 | PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
|
800 | 864 | PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
|
801 | 865 | object="config", object_class="AutoConfig", object_files="configuration file"
|
|
0 commit comments