Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
f6dfa80
add model registry
jeromeku Mar 28, 2025
a5e7b3a
move hf hub utils to unsloth/utils
jeromeku Mar 28, 2025
dc8f34e
refactor global model info dicts to dataclasses
jeromeku Mar 30, 2025
7cd2763
fix dataclass init
jeromeku Mar 30, 2025
9899a72
fix llama registration
jeromeku Mar 30, 2025
310c598
remove deprecated key function
jeromeku Mar 30, 2025
e70d035
start registry reog
jeromeku Mar 30, 2025
de1fe25
add llama vision
jeromeku Mar 30, 2025
7e2207c
quant types -> Enum
jeromeku Mar 30, 2025
c3a1aff
remap literal quant types to QuantType Enum
jeromeku Mar 30, 2025
03de6df
add llama model registration
jeromeku Mar 30, 2025
fa95aa0
fix quant tag mapping
jeromeku Mar 30, 2025
fdafa78
add qwen2.5 models to registry
jeromeku Mar 31, 2025
6049310
add option to include original model in registry
jeromeku Mar 31, 2025
8dc3d66
handle quant types per model size
jeromeku Mar 31, 2025
1237075
separate registration of base and instruct llama3.2
jeromeku Mar 31, 2025
baab018
add QwenQVQ to registry
jeromeku Mar 31, 2025
6b08fc3
add gemma3 to registry
jeromeku Mar 31, 2025
44e227b
add phi
jeromeku Mar 31, 2025
d633179
add deepseek v3
jeromeku Mar 31, 2025
0755b45
add deepseek r1 base
jeromeku Mar 31, 2025
17358e6
add deepseek r1 zero
jeromeku Mar 31, 2025
975d263
add deepseek distill llama
jeromeku Mar 31, 2025
229ae10
add deepseek distill models
jeromeku Mar 31, 2025
6439e88
remove redundant code when constructing model names
jeromeku Mar 31, 2025
4e1df71
add mistral small to registry
jeromeku Mar 31, 2025
6d4ede4
rename model registration methods
jeromeku Apr 1, 2025
a774726
rename deepseek registration methods
jeromeku Apr 1, 2025
a2a4366
refactor naming for mistral and phi
jeromeku Apr 1, 2025
02fbb87
add global register models
jeromeku Apr 1, 2025
7fbde42
refactor model registration tests for new registry apis
jeromeku Apr 1, 2025
a2d3ad9
add model search method
jeromeku Apr 1, 2025
13a1126
remove deprecated registration api
jeromeku Apr 1, 2025
4840a32
add quant type test
jeromeku Apr 1, 2025
7d64639
add registry readme
jeromeku Apr 1, 2025
12b0d32
make llama registration more specific
jeromeku Apr 1, 2025
ea75001
clear registry when executing individual model registration file
jeromeku Apr 1, 2025
d854070
more registry readme updates
jeromeku Apr 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# unsloth compiled cache
unsloth_compiled_cache
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ include-package-data = false
exclude = ["images*", "tests*"]

[project.optional-dependencies]
dev = [
"pytest",
]

triton = [
"triton-windows ; platform_system == 'Windows'",
]
Expand Down
Empty file added tests/__init__.py
Empty file.
91 changes: 91 additions & 0 deletions tests/test_model_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""

Test model registration methods
Checks that model registration methods work for respective models as well as all models
The check is performed
- by registering the models
- checking that the instantiated models can be found on huggingface hub by querying for the model id

"""

from dataclasses import dataclass

import pytest
from huggingface_hub import ModelInfo as HfModelInfo

from unsloth.registry import register_models, search_models
from unsloth.registry._deepseek import register_deepseek_models
from unsloth.registry._gemma import register_gemma_models
from unsloth.registry._llama import register_llama_models
from unsloth.registry._mistral import register_mistral_models
from unsloth.registry._phi import register_phi_models
from unsloth.registry._qwen import register_qwen_models
from unsloth.registry.registry import MODEL_REGISTRY, QUANT_TAG_MAP, QuantType
from unsloth.utils.hf_hub import get_model_info

MODEL_NAMES = [
"llama",
"qwen",
"mistral",
"phi",
"gemma",
"deepseek",
]
MODEL_REGISTRATION_METHODS = [
register_llama_models,
register_qwen_models,
register_mistral_models,
register_phi_models,
register_gemma_models,
register_deepseek_models,
]


@dataclass
class ModelTestParam:
name: str
register_models: callable


def _test_model_uploaded(model_ids: list[str]):
missing_models = []
for _id in model_ids:
model_info: HfModelInfo = get_model_info(_id)
if not model_info:
missing_models.append(_id)

return missing_models


TestParams = [
ModelTestParam(name, models)
for name, models in zip(MODEL_NAMES, MODEL_REGISTRATION_METHODS)
]


# Test that model registration methods register respective models
@pytest.mark.parametrize("model_test_param", TestParams, ids=lambda param: param.name)
def test_model_registration(model_test_param: ModelTestParam):
MODEL_REGISTRY.clear()
registration_method = model_test_param.register_models
registration_method()
registered_models = MODEL_REGISTRY.keys()
missing_models = _test_model_uploaded(registered_models)
assert not missing_models, (
f"{model_test_param.name} missing following models: {missing_models}"
)


def test_all_model_registration():
register_models()
registered_models = MODEL_REGISTRY.keys()
missing_models = _test_model_uploaded(registered_models)
assert not missing_models, f"Missing following models: {missing_models}"

def test_quant_type():
# Test that the quant_type is correctly set for model paths
# NOTE: for models registered under org="unsloth" with QuantType.NONE aliases QuantType.UNSLOTH
dynamic_quant_models = search_models(quant_types=[QuantType.UNSLOTH])
assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models)
quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH]
assert all(quant_tag in m.model_path for m in dynamic_quant_models)
110 changes: 110 additions & 0 deletions unsloth/registry/REGISTRY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
## Model Registry

### Structure
```
unsloth
-registry
__init__.py
registry.py
_llama.py
_mistral.py
_phi.py
...
```

Each model is registered in a separate file within the `registry` module (e.g. `registry/_llama.py`).

Within each model registration file, a high-level `ModelMeta` is created for each model version, with the following structure:
```python
@dataclass
class ModelMeta:
org: str
base_name: str
model_version: str
model_info_cls: type[ModelInfo]
model_sizes: list[str] = field(default_factory=list)
instruct_tags: list[str] = field(default_factory=list)
quant_types: list[QuantType] | dict[str, list[QuantType]] = field(default_factory=list)
is_multimodal: bool = False
```

Each model then instantiates a global `ModelMeta` for its specific model version, defining how the model path (e.g. `unsloth/Llama-3.1-8B-Instruct`) is constructed since each model type has a different naming convention.
```python
LlamaMeta_3_1 = ModelMeta(
org="meta-llama",
base_name="Llama",
instruct_tags=[None, "Instruct"],
model_version="3.1",
model_sizes=["8"],
model_info_cls=LlamaModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
```

`LlamaModelInfo` is a subclass of `ModelInfo` that defines the model path for each model size and quant type.
```python
class LlamaModelInfo(ModelInfo):
@classmethod
def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
key = f"{base_name}-{version}-{size}B"
return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
```

Once these constructs are defined, the model is registered by writing a register_xx_models function.
```python
def register_llama_3_1_models(include_original_model: bool = False):
global _IS_LLAMA_3_1_REGISTERED
if _IS_LLAMA_3_1_REGISTERED:
return
_register_models(LlamaMeta_3_1, include_original_model=include_original_model)
_IS_LLAMA_3_1_REGISTERED = True
```

`_register_models` is a helper function that registers the model with the registry. The global `_IS_XX_REGISTERED` is used to prevent duplicate registration.

Once a model is registered, registry.registry.MODEL_REGISTRY is updated with the model info and can be searched with `registry.search_models`.

### Tests

The `tests/test_model_registry.py` file contains tests for the model registry.

Also, each model registration file is an executable module that checks that all registered models are available on `huggingface_hub`.
```python
python unsloth.registry._llama.py
```

Prints the following (abridged) output:
```bash
✓ unsloth/Llama-3.1-8B
✓ unsloth/Llama-3.1-8B-bnb-4bit
✓ unsloth/Llama-3.1-8B-unsloth-bnb-4bit
✓ meta-llama/Llama-3.1-8B
✓ unsloth/Llama-3.1-8B-Instruct
✓ unsloth/Llama-3.1-8B-Instruct-bnb-4bit
✓ unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit
✓ meta-llama/Llama-3.1-8B-Instruct
✓ unsloth/Llama-3.2-1B
✓ unsloth/Llama-3.2-1B-bnb-4bit
✓ unsloth/Llama-3.2-1B-unsloth-bnb-4bit
✓ meta-llama/Llama-3.2-1B
...
```

### TODO
- Model Collections
- [x] Gemma3
- [ ] Llama3.1
- [x] Llama3.2
- [x] MistralSmall
- [x] Qwen2.5
- [x] Qwen2.5-VL
- [ ] Qwen2.5 Coder
- [x] QwenQwQ-32B
- [x] Deepseek v3
- [x] Deepseek R1
- [x] Phi-4
- [ ] Unsloth 4-bit Dynamic Quants
- [ ] Vision/multimodal models
- Sync model uploads with registry
- Add utility methods for tracking model stats
Loading