Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,35 @@ def filter(self, x): return not (self.text in x.getMessage())

# Patch get_model_param_count to record correct 4bit / 8bit
from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled

def extract_approx_params_from_config(config):
"""
Extract approximate parameter count from model config's name_or_path
Returns int (param count) or None if not found.
"""
lowercase_b_families = ["gemma"] # gemma uses small 'b' : google/gemma-3-1b-it
model_name = getattr(config, "name_or_path", "")
import re
cleaned = re.sub(r"[-_]?bnb[-_]?4bit|[-_]?4bit|[-_]?8bit|[-_]?bnb", "", model_name, flags=re.IGNORECASE) # replace bnb and xbit
Copy link

Copilot AI Jun 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regex substitution comment mentions 'xbit', but the pattern does not match any 'xbit' strings. Please update either the comment or the regex to correctly handle 'xbit' if intended.

Suggested change
cleaned = re.sub(r"[-_]?bnb[-_]?4bit|[-_]?4bit|[-_]?8bit|[-_]?bnb", "", model_name, flags=re.IGNORECASE) # replace bnb and xbit
cleaned = re.sub(r"[-_]?bnb[-_]?4bit|[-_]?4bit|[-_]?8bit|[-_]?bnb|[-_]?xbit", "", model_name, flags=re.IGNORECASE) # replace bnb and xbit

Copilot uses AI. Check for mistakes.
match_B = re.search(r"([0-9]+(?:\.[0-9]+)?)\s*B", cleaned) # first prefer searching 'B'
if match_B:
# most model names would come in this flow
billions = float(match_B.group(1))
return int(1_000_000_000 * billions)
else:
if any(fam in cleaned.lower() for fam in lowercase_b_families):
match_b = re.search(r"([0-9]+(?:\.[0-9]+)?)\s*b", cleaned)
if match_b:
billions = float(match_b.group(1))
return int(1_000_000_000 * billions)
else:
match_any = re.search(r"([0-9]+(?:\.[0-9]+)?)\s*[bB]", cleaned)
if match_any:
billions = float(match_any.group(1))
return int(1_000_000_000 * billions)
return None


def get_model_param_count(model, trainable_only = False):
"""
Calculate model's total param count. If trainable_only is True then count only those requiring grads
Expand All @@ -215,12 +244,9 @@ def numel(p):
if (not trainable_only) and \
hasattr(model, "config") and \
hasattr(model.config, "quantization_config"):

billions = re.findall(r"([0-9]{1,})(?:b|B)", model.config.name_or_path)
if len(billions) != 0:
billions = int(billions[0])
s = 1_000_000_000 * billions
pass
approx = extract_approx_params_from_config(model.config)
if approx is not None:
s = approx
return s
pass
import transformers.trainer_pt_utils
Expand Down