27
27
28
28
from . import __version__
29
29
from .dynamic_module_utils import custom_object_save
30
- from .utils import CONFIG_NAME , PushToHubMixin , cached_file , copy_func , is_torch_available , logging
30
+ from .utils import (
31
+ CONFIG_NAME ,
32
+ PushToHubMixin ,
33
+ cached_file ,
34
+ copy_func ,
35
+ extract_commit_hash ,
36
+ is_torch_available ,
37
+ logging ,
38
+ )
31
39
32
40
33
41
logger = logging .get_logger (__name__ )
@@ -343,6 +351,8 @@ def __init__(self, **kwargs):
343
351
344
352
# Name or path to the pretrained checkpoint
345
353
self ._name_or_path = str (kwargs .pop ("name_or_path" , "" ))
354
+ # Config hash
355
+ self ._commit_hash = kwargs .pop ("_commit_hash" , None )
346
356
347
357
# Drop the transformers version info
348
358
self .transformers_version = kwargs .pop ("transformers_version" , None )
@@ -539,6 +549,8 @@ def get_config_dict(
539
549
original_kwargs = copy .deepcopy (kwargs )
540
550
# Get config dict associated with the base config file
541
551
config_dict , kwargs = cls ._get_config_dict (pretrained_model_name_or_path , ** kwargs )
552
+ if "_commit_hash" in config_dict :
553
+ original_kwargs ["_commit_hash" ] = config_dict ["_commit_hash" ]
542
554
543
555
# That config file may point us toward another config file to use.
544
556
if "configuration_files" in config_dict :
@@ -564,6 +576,7 @@ def _get_config_dict(
564
576
subfolder = kwargs .pop ("subfolder" , "" )
565
577
from_pipeline = kwargs .pop ("_from_pipeline" , None )
566
578
from_auto_class = kwargs .pop ("_from_auto" , False )
579
+ commit_hash = kwargs .pop ("_commit_hash" , None )
567
580
568
581
if trust_remote_code is True :
569
582
logger .warning (
@@ -599,7 +612,9 @@ def _get_config_dict(
599
612
user_agent = user_agent ,
600
613
revision = revision ,
601
614
subfolder = subfolder ,
615
+ _commit_hash = commit_hash ,
602
616
)
617
+ commit_hash = extract_commit_hash (resolved_config_file , commit_hash )
603
618
except EnvironmentError :
604
619
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
605
620
# the original exception.
@@ -616,6 +631,7 @@ def _get_config_dict(
616
631
try :
617
632
# Load config dict
618
633
config_dict = cls ._dict_from_json_file (resolved_config_file )
634
+ config_dict ["_commit_hash" ] = commit_hash
619
635
except (json .JSONDecodeError , UnicodeDecodeError ):
620
636
raise EnvironmentError (
621
637
f"It looks like the config file at '{ resolved_config_file } ' is not a valid JSON file."
@@ -648,6 +664,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
648
664
# We remove them so they don't appear in `return_unused_kwargs`.
649
665
kwargs .pop ("_from_auto" , None )
650
666
kwargs .pop ("_from_pipeline" , None )
667
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
668
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict :
669
+ kwargs ["_commit_hash" ] = config_dict ["_commit_hash" ]
651
670
652
671
config = cls (** config_dict )
653
672
@@ -751,6 +770,8 @@ def to_dict(self) -> Dict[str, Any]:
751
770
output ["model_type" ] = self .__class__ .model_type
752
771
if "_auto_class" in output :
753
772
del output ["_auto_class" ]
773
+ if "_commit_hash" in output :
774
+ del output ["_commit_hash" ]
754
775
755
776
# Transformers version when serializing the model
756
777
output ["transformers_version" ] = __version__
0 commit comments