Skip to content

Commit 0d0aada

Browse files
sgugger and julien-c authored
Use commit hash to look in cache instead of calling head (#18534)
* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address Julien's comments

Co-authored-by: Julien Chaumond <julien@huggingface.co>
1 parent 6eb5145 commit 0d0aada

15 files changed: +221 −23 lines changed

src/transformers/configuration_utils.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,15 @@
2727

2828
from . import __version__
2929
from .dynamic_module_utils import custom_object_save
30-
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
30+
from .utils import (
31+
CONFIG_NAME,
32+
PushToHubMixin,
33+
cached_file,
34+
copy_func,
35+
extract_commit_hash,
36+
is_torch_available,
37+
logging,
38+
)
3139

3240

3341
logger = logging.get_logger(__name__)
@@ -343,6 +351,8 @@ def __init__(self, **kwargs):
343351

344352
# Name or path to the pretrained checkpoint
345353
self._name_or_path = str(kwargs.pop("name_or_path", ""))
354+
# Config hash
355+
self._commit_hash = kwargs.pop("_commit_hash", None)
346356

347357
# Drop the transformers version info
348358
self.transformers_version = kwargs.pop("transformers_version", None)
@@ -539,6 +549,8 @@ def get_config_dict(
539549
original_kwargs = copy.deepcopy(kwargs)
540550
# Get config dict associated with the base config file
541551
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
552+
if "_commit_hash" in config_dict:
553+
original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
542554

543555
# That config file may point us toward another config file to use.
544556
if "configuration_files" in config_dict:
@@ -564,6 +576,7 @@ def _get_config_dict(
564576
subfolder = kwargs.pop("subfolder", "")
565577
from_pipeline = kwargs.pop("_from_pipeline", None)
566578
from_auto_class = kwargs.pop("_from_auto", False)
579+
commit_hash = kwargs.pop("_commit_hash", None)
567580

568581
if trust_remote_code is True:
569582
logger.warning(
@@ -599,7 +612,9 @@ def _get_config_dict(
599612
user_agent=user_agent,
600613
revision=revision,
601614
subfolder=subfolder,
615+
_commit_hash=commit_hash,
602616
)
617+
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
603618
except EnvironmentError:
604619
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
605620
# the original exception.
@@ -616,6 +631,7 @@ def _get_config_dict(
616631
try:
617632
# Load config dict
618633
config_dict = cls._dict_from_json_file(resolved_config_file)
634+
config_dict["_commit_hash"] = commit_hash
619635
except (json.JSONDecodeError, UnicodeDecodeError):
620636
raise EnvironmentError(
621637
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
@@ -648,6 +664,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
648664
# We remove them so they don't appear in `return_unused_kwargs`.
649665
kwargs.pop("_from_auto", None)
650666
kwargs.pop("_from_pipeline", None)
667+
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
668+
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
669+
kwargs["_commit_hash"] = config_dict["_commit_hash"]
651670

652671
config = cls(**config_dict)
653672

@@ -751,6 +770,8 @@ def to_dict(self) -> Dict[str, Any]:
751770
output["model_type"] = self.__class__.model_type
752771
if "_auto_class" in output:
753772
del output["_auto_class"]
773+
if "_commit_hash" in output:
774+
del output["_commit_hash"]
754775

755776
# Transformers version when serializing the model
756777
output["transformers_version"] = __version__

src/transformers/modeling_flax_utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ def from_pretrained(
595595
from_auto_class = kwargs.pop("_from_auto", False)
596596
_do_init = kwargs.pop("_do_init", True)
597597
subfolder = kwargs.pop("subfolder", "")
598+
commit_hash = kwargs.pop("_commit_hash", None)
598599

599600
if trust_remote_code is True:
600601
logger.warning(
@@ -625,11 +626,15 @@ def from_pretrained(
625626
revision=revision,
626627
_from_auto=from_auto_class,
627628
_from_pipeline=from_pipeline,
629+
_commit_hash=commit_hash,
628630
**kwargs,
629631
)
630632
else:
631633
model_kwargs = kwargs
632634

635+
if commit_hash is None:
636+
commit_hash = getattr(config, "_commit_hash", None)
637+
633638
# Add the dtype to model_kwargs
634639
model_kwargs["dtype"] = dtype
635640

@@ -682,6 +687,7 @@ def from_pretrained(
682687
revision=revision,
683688
subfolder=subfolder,
684689
_raise_exceptions_for_missing_entries=False,
690+
_commit_hash=commit_hash,
685691
)
686692
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
687693

@@ -748,6 +754,7 @@ def from_pretrained(
748754
use_auth_token=use_auth_token,
749755
user_agent=user_agent,
750756
revision=revision,
757+
_commit_hash=commit_hash,
751758
)
752759

753760
# init random models

src/transformers/modeling_tf_utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -2161,6 +2161,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
21612161
from_pipeline = kwargs.pop("_from_pipeline", None)
21622162
from_auto_class = kwargs.pop("_from_auto", False)
21632163
subfolder = kwargs.pop("subfolder", "")
2164+
commit_hash = kwargs.pop("_commit_hash", None)
21642165

21652166
if trust_remote_code is True:
21662167
logger.warning(
@@ -2191,11 +2192,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
21912192
revision=revision,
21922193
_from_auto=from_auto_class,
21932194
_from_pipeline=from_pipeline,
2195+
_commit_hash=commit_hash,
21942196
**kwargs,
21952197
)
21962198
else:
21972199
model_kwargs = kwargs
21982200

2201+
if commit_hash is None:
2202+
commit_hash = getattr(config, "_commit_hash", None)
2203+
21992204
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
22002205
# index of the files.
22012206
is_sharded = False
@@ -2253,6 +2258,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
22532258
revision=revision,
22542259
subfolder=subfolder,
22552260
_raise_exceptions_for_missing_entries=False,
2261+
_commit_hash=commit_hash,
22562262
)
22572263
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
22582264

@@ -2320,6 +2326,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
23202326
use_auth_token=use_auth_token,
23212327
user_agent=user_agent,
23222328
revision=revision,
2329+
_commit_hash=commit_hash,
23232330
)
23242331

23252332
config.name_or_path = pretrained_model_name_or_path

src/transformers/modeling_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -1840,6 +1840,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
18401840
load_in_8bit = kwargs.pop("load_in_8bit", False)
18411841
int8_threshold = kwargs.pop("int8_threshold", 6.0)
18421842
subfolder = kwargs.pop("subfolder", "")
1843+
commit_hash = kwargs.pop("_commit_hash", None)
18431844

18441845
if trust_remote_code is True:
18451846
logger.warning(
@@ -1918,6 +1919,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
19181919
else:
19191920
model_kwargs = kwargs
19201921

1922+
if commit_hash is None:
1923+
commit_hash = getattr(config, "_commit_hash", None)
1924+
19211925
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
19221926
# index of the files.
19231927
is_sharded = False
@@ -2004,6 +2008,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
20042008
revision=revision,
20052009
subfolder=subfolder,
20062010
_raise_exceptions_for_missing_entries=False,
2011+
_commit_hash=commit_hash,
20072012
)
20082013
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
20092014

@@ -2078,6 +2083,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
20782083
user_agent=user_agent,
20792084
revision=revision,
20802085
subfolder=subfolder,
2086+
_commit_hash=commit_hash,
20812087
)
20822088

20832089
# load pt weights early so that we know which dtype to init the model under

src/transformers/models/auto/tokenization_auto.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ...tokenization_utils import PreTrainedTokenizer
2626
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
2727
from ...tokenization_utils_fast import PreTrainedTokenizerFast
28-
from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
28+
from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
2929
from ..encoder_decoder import EncoderDecoderConfig
3030
from .auto_factory import _LazyAutoMapping
3131
from .configuration_auto import (
@@ -389,7 +389,8 @@ def get_tokenizer_config(
389389
tokenizer.save_pretrained("tokenizer-test")
390390
tokenizer_config = get_tokenizer_config("tokenizer-test")
391391
```"""
392-
resolved_config_file = get_file_from_repo(
392+
commit_hash = kwargs.get("_commit_hash", None)
393+
resolved_config_file = cached_file(
393394
pretrained_model_name_or_path,
394395
TOKENIZER_CONFIG_FILE,
395396
cache_dir=cache_dir,
@@ -399,13 +400,19 @@ def get_tokenizer_config(
399400
use_auth_token=use_auth_token,
400401
revision=revision,
401402
local_files_only=local_files_only,
403+
_raise_exceptions_for_missing_entries=False,
404+
_raise_exceptions_for_connection_errors=False,
405+
_commit_hash=commit_hash,
402406
)
403407
if resolved_config_file is None:
404408
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
405409
return {}
410+
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
406411

407412
with open(resolved_config_file, encoding="utf-8") as reader:
408-
return json.load(reader)
413+
result = json.load(reader)
414+
result["_commit_hash"] = commit_hash
415+
return result
409416

410417

411418
class AutoTokenizer:
@@ -532,6 +539,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
532539

533540
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
534541
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
542+
if "_commit_hash" in tokenizer_config:
543+
kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
535544
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
536545
tokenizer_auto_map = None
537546
if "auto_map" in tokenizer_config:

src/transformers/pipelines/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,12 @@ def pipeline(
557557
# Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
558558
# this is to keep BC).
559559
use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
560-
hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
560+
hub_kwargs = {
561+
"revision": revision,
562+
"use_auth_token": use_auth_token,
563+
"trust_remote_code": trust_remote_code,
564+
"_commit_hash": None,
565+
}
561566

562567
if task is None and model is None:
563568
raise RuntimeError(
@@ -583,8 +588,10 @@ def pipeline(
583588
# Instantiate config if needed
584589
if isinstance(config, str):
585590
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
591+
hub_kwargs["_commit_hash"] = config._commit_hash
586592
elif config is None and isinstance(model, str):
587593
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
594+
hub_kwargs["_commit_hash"] = config._commit_hash
588595

589596
custom_tasks = {}
590597
if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
@@ -639,6 +646,7 @@ def pipeline(
639646
)
640647
if config is None and isinstance(model, str):
641648
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
649+
hub_kwargs["_commit_hash"] = config._commit_hash
642650

643651
if device_map is not None:
644652
if "device_map" in model_kwargs:
@@ -672,6 +680,7 @@ def pipeline(
672680
)
673681

674682
model_config = model.config
683+
hub_kwargs["_commit_hash"] = model.config._commit_hash
675684

676685
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
677686
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None

src/transformers/testing_utils.py

+28
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from typing import Iterator, List, Union
3232
from unittest import mock
3333

34+
import huggingface_hub
3435
from transformers import logging as transformers_logging
3536

3637
from .deepspeed import is_deepspeed_available
@@ -1588,3 +1589,30 @@ def run_command(command: List[str], return_stdout=False):
15881589
raise SubprocessCallException(
15891590
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
15901591
) from e
1592+
1593+
1594+
class RequestCounter:
1595+
"""
1596+
Helper class that will count all requests made online.
1597+
"""
1598+
1599+
def __enter__(self):
1600+
self.head_request_count = 0
1601+
self.get_request_count = 0
1602+
self.other_request_count = 0
1603+
self.old_request = huggingface_hub.file_download.requests.request
1604+
huggingface_hub.file_download.requests.request = self.new_request
1605+
return self
1606+
1607+
def __exit__(self, *args, **kwargs):
1608+
huggingface_hub.file_download.requests.request = self.old_request
1609+
1610+
def new_request(self, method, **kwargs):
1611+
if method == "GET":
1612+
self.get_request_count += 1
1613+
elif method == "HEAD":
1614+
self.head_request_count += 1
1615+
else:
1616+
self.other_request_count += 1
1617+
1618+
return self.old_request(method=method, **kwargs)

0 commit comments

Comments (0)