def extract_approx_params_from_config(config):
    """
    Extract an approximate parameter count from the config's `name_or_path`.

    The size is read from a "<number>B" / "<number>b" token in the model name
    (e.g. "Llama-3.1-8B-Instruct" -> 8_000_000_000, "gemma-3-1b-it" ->
    1_000_000_000). An uppercase "B" anywhere in the name is preferred over a
    lowercase "b", since some families (e.g. gemma) only use the lowercase form.

    Args:
        config: Any object; its `name_or_path` attribute is read if present.

    Returns:
        The approximate parameter count as an int, or None when the name is
        missing, empty, or contains no recognizable size token.
    """
    import re  # local import keeps this helper self-contained

    model_name = getattr(config, "name_or_path", None)
    if not model_name:
        # Covers a missing attribute, an explicit None and an empty string.
        return None
    model_name = str(model_name)

    # Strip quantization suffixes such as "-bnb-4bit" / "_8bit" so that their
    # digits are never mistaken for a model size.
    cleaned = re.sub(
        r"[-_]?bnb[-_]?4bit|[-_]?4bit|[-_]?8bit|[-_]?bnb",
        "",
        model_name,
        flags = re.IGNORECASE,
    )

    number = r"([0-9]+(?:\.[0-9]+)?)"
    # The negative lookahead rejects a "b"/"B" that merely starts a longer word
    # ("16bit", "3Base"), which would otherwise be read as 16B / 3B.
    not_word_start = r"(?![A-Za-z])"

    # Uppercase first (most model names), then the lowercase fallback.
    for suffix in ("B", "b"):
        match = re.search(number + r"\s*" + suffix + not_word_start, cleaned)
        if match:
            # round() instead of int() so float error (x.999999) cannot
            # truncate the count down by one.
            return round(float(match.group(1)) * 1_000_000_000)
    return None
If trainable_only is True then count only those requiring grads @@ -215,12 +244,9 @@ def numel(p): if (not trainable_only) and \ hasattr(model, "config") and \ hasattr(model.config, "quantization_config"): - - billions = re.findall(r"([0-9]{1,})(?:b|B)", model.config.name_or_path) - if len(billions) != 0: - billions = int(billions[0]) - s = 1_000_000_000 * billions - pass + approx = extract_approx_params_from_config(model.config) + if approx is not None: + s = approx return s pass import transformers.trainer_pt_utils