Commit 472f4bc

danielhanchen, Datta0, shimmyshimmer, jeromeku, and mmathew23 authored
AMD fixes (#3467)
* Upcast layernorms * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update save.py * Update rl.py * Update pyproject.toml * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update _utils.py * Update __init__.py * Torch 2.8 * Update rl_replacements.py * Update loader.py * UNSLOTH_ENABLE_CCE * Fix * Update loader.py * Update loader.py * Update __init__.py * Update __init__.py * Update __init__.py * Update __init__.py * Import fixes * Update loader.py * Fix aimv2 issue * Update loader.py * Update import_fixes.py * Update import_fixes.py * Update loader.py * Update loader.py * Update loader.py * Upgrade * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update vision.py * Update vision.py * custom_datatype * recheck * Float16 * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Bug fix * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * torch_dtype * Update rl.py * Fix CE Loss * Versioning * Update loader.py * Update loader.py * extract_model_type_from_config * Model types * Update loader.py * get_transformers_model_type * Update loader.py * Update loader.py * Update loader.py * Update rl.py * Update pyproject.toml * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Versioning * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update vision.py * Update vision.py * Fix DataParallel * Update _utils.py * Update rl.py * Update synthetic.py * Update synthetic.py * Update 
synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update mapper.py * Versioning * Update loader.py * Update loader.py * Update rl.py * Versioning * Update _utils.py * Fix auto_mapping * Update loader.py * Update loader.py * Update vision.py * Update vision.py * Update loader.py * Message * Update vision.py * Update loader.py * Update vision.py * cache_implementation * Update vision.py * Update loader.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Save max_seq_length * Update _utils.py * Update rl.py * Update vision.py * Update llama.py * Mistral3 vllm (#3349) * [WIP] use vLLM for vision language models * Update README.md Editing icon sizes * Update README.md Updating icon sizes * Update README.md (#2885) * MoE kernels AGPLv3 * versioning * Many bug fixes (#2908) * add deepseek v3 * add deepseek r1 base * add deepseek r1 zero * add deepseek distill llama * add deepseek distill models * remove redundant code when constructing model names * add mistral small to registry * rename model registration methods * rename deepseek registration methods * refactor naming for mistral and phi * add global register models * refactor model registration tests for new registry apis * add model search method * remove deprecated registration api * add quant type test * add registry readme * make llama registration more specific * clear registry when executing individual model registration file * more registry readme updates * Update _auto_install.py * Llama4 * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Synthetic data * Update mapper.py * Xet and Synthetic * Update synthetic.py * 
Update loader.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update pyproject.toml * Delete .gitignore * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update _utils.py * Update pyproject.toml * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update chat_templates.py * Seasame force float16 / float32 * Fix Seasame * Update loader.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * is_multimodal * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update vision.py * Update vision.py * Update vision.py * UNSLOTH_DISABLE_STATIC_GENERATION * Update vision.py * Auto vision detection * Sesame * Whisper * Update loader.py * Update loader.py * Update loader.py * Update mapper.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update _utils.py * Update rl.py * versioning * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * logging * Update pyproject.toml 
* Update rl.py * versioning * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * logits / temperature * Update rl_replacements.py * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Debugging only * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Generic efficient GRPO * Update rl_replacements.py * Update rl_replacements.py * Remove debugging * Update rl_replacements.py * Update rl_replacements.py * Update vision.py * Update llama.py * Update rl_replacements.py * versioning * Update _utils.py * Update vision.py * Update mapper.py * Update loader.py * Update mapper.py * Update vision.py * Update loader.py * Update vision.py * Update loader.py * Update _utils.py * Update vision.py * gradient checkpointing * Gemma 3N fixes * Update loader.py * Versioning * Gemma 3N fixes * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Fix setup.py * setup.py * Prints * Update setup.py * Update setup.py * Update setup.py * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update vision.py * Update vision.py * Update pyproject.toml * Update vision.py * Update _utils.py * Update __init__.py * Update __init__.py --------- Co-authored-by: jeromeku <jerome.ku@gmail.com> Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> * silienty skip falcon h1 import is transformers_version < 4.53.0 (#2912) * Dynamically adjust get_per_token_logps function and patch as well (#2911) * add intel gpu with vllm support (#2903) * [bugs] fix for casual mask (#2868) * fix for casual mask * use un_casual in sdpa * add missing mask * fix for type * Explicitly check if xformers exists for attention (#2889) 
* Update __init__.py * Update llama.py * if mlp doesn't exist in layer module check for feed_forward name for falcon h1 (#2913) * Move inputs to right devices. (#2919) * Move tensors to right devices * fix multi gpu for non mistral models * multi GPU RoPE for gemma2 * Finish up multi GPU inference * Make multiGPU rope a list * Remove unnecessary transfer to CPU * Remove unnecessary move to CPU * Donot move inputs to device yet will be handled separately in another PR * Move inputs to appropriate decoder device * Make device count global variable * Cleanup RoPE device code * Fixup num_gpu to device count * Cleanup device counts * Use device index for RoPE get_cache * Donot typecast * Use tuple instead of list for tensors. Use device index directly * fixup move to device logic * WIP VLM vLLM * Make vLLM patch a function * Add save and load lora functions * Make fast_inference setup depend on the flag * Improve fast inference patching mechanism * Make vision setting depend on checks in fastbasemodel * Check LoRA and vLLM intercompatibility for vision models * Comment pointing to vLLM LoRA check * Improve lora validation on vLLM * Error out on no vLLM and increase max lora rank * Bug fixes (#3017) * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update pyproject.toml * Delete .gitignore * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update 
synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update _utils.py * Update pyproject.toml * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update chat_templates.py * Seasame force float16 / float32 * Fix Seasame * Update loader.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * is_multimodal * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update vision.py * Update vision.py * Update vision.py * UNSLOTH_DISABLE_STATIC_GENERATION * Update vision.py * Auto vision detection * Sesame * Whisper * Update loader.py * Update loader.py * Update loader.py * Update mapper.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update _utils.py * Update rl.py * versioning * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * logging * Update pyproject.toml * Update rl.py * versioning * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * logits / temperature * Update rl_replacements.py * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Debugging only * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Generic efficient GRPO * Update rl_replacements.py * Update rl_replacements.py * Remove debugging * Update rl_replacements.py * Update rl_replacements.py * Update vision.py * Update llama.py * Update rl_replacements.py * versioning * Update _utils.py * Update vision.py * Update mapper.py * Update loader.py * Update mapper.py * Update vision.py * Update loader.py * Update 
vision.py * Update loader.py * Update _utils.py * Update vision.py * gradient checkpointing * Gemma 3N fixes * Update loader.py * Versioning * Gemma 3N fixes * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Fix setup.py * setup.py * Prints * Update setup.py * Update setup.py * Update setup.py * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update vision.py * Update vision.py * Update pyproject.toml * Update vision.py * Update _utils.py * Update __init__.py * Update __init__.py * Small fixes * Update vision.py * Update vision.py * versioning * Update __init__.py * Update llama.py * Update rl.py * Update rl.py * Update _utils.py * Update vision.py * Update vision.py * compiler stance * Update _utils.py * Update pyproject.toml * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Revert "Revert "Add Qwen2.5-VL-32B-Instruct mapping to fix quantized model me…" (#2990) This reverts commit 204fc46. 
* skip_guard_eval_unsafe fix * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update llama.py * Update llama.py * Fix `quantization_method` * versioning * fix for casual mask (#3011) * [intel] add for intel path for llama.py (#3012) * fix for intel path * remove unuse code * Update unsloth/models/llama.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update llama.py * Fix Gemma 2 (#3024) * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update pyproject.toml * Delete .gitignore * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update _utils.py * Update pyproject.toml * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update chat_templates.py * Seasame force float16 / float32 * Fix Seasame * Update loader.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * is_multimodal * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update vision.py * Update vision.py * Update vision.py * UNSLOTH_DISABLE_STATIC_GENERATION * Update vision.py * Auto vision detection * Sesame * Whisper * Update loader.py * Update loader.py * Update loader.py * Update mapper.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update 
vision.py * Update vision.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update _utils.py * Update rl.py * versioning * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * logging * Update pyproject.toml * Update rl.py * versioning * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * logits / temperature * Update rl_replacements.py * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Debugging only * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Generic efficient GRPO * Update rl_replacements.py * Update rl_replacements.py * Remove debugging * Update rl_replacements.py * Update rl_replacements.py * Update vision.py * Update llama.py * Update rl_replacements.py * versioning * Update _utils.py * Update vision.py * Update mapper.py * Update loader.py * Update mapper.py * Update vision.py * Update loader.py * Update vision.py * Update loader.py * Update _utils.py * Update vision.py * gradient checkpointing * Gemma 3N fixes * Update loader.py * Versioning * Gemma 3N fixes * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Fix setup.py * setup.py * Prints * Update setup.py * Update setup.py * Update setup.py * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update vision.py * Update vision.py * Update pyproject.toml * Update vision.py * Update _utils.py * Update __init__.py * Update __init__.py * Small fixes * Update vision.py * Update vision.py * versioning * Update __init__.py * Update llama.py * Update rl.py * Update rl.py * Update _utils.py * Update vision.py * Update vision.py * compiler stance * Update _utils.py * Update 
pyproject.toml * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Revert "Revert "Add Qwen2.5-VL-32B-Instruct mapping to fix quantized model me…" (#2990) This reverts commit 204fc46. * skip_guard_eval_unsafe fix * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update llama.py * Update llama.py * Fix `quantization_method` * versioning * Update _utils.py * Update _utils.py * Update _utils.py * falcon force float32 on sm<75 machines (#3026) * Fix torch compile issues (#3028) * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update pyproject.toml * Delete .gitignore * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update _utils.py * Update pyproject.toml * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update chat_templates.py * Seasame force float16 / float32 * Fix Seasame * Update loader.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * is_multimodal * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update vision.py * Update vision.py * Update vision.py * 
UNSLOTH_DISABLE_STATIC_GENERATION * Update vision.py * Auto vision detection * Sesame * Whisper * Update loader.py * Update loader.py * Update loader.py * Update mapper.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update loader.py * Update loader.py * Update loader.py * Update loader.py * Update _utils.py * Update rl.py * versioning * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * logging * Update pyproject.toml * Update rl.py * versioning * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * logits / temperature * Update rl_replacements.py * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Debugging only * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Generic efficient GRPO * Update rl_replacements.py * Update rl_replacements.py * Remove debugging * Update rl_replacements.py * Update rl_replacements.py * Update vision.py * Update llama.py * Update rl_replacements.py * versioning * Update _utils.py * Update vision.py * Update mapper.py * Update loader.py * Update mapper.py * Update vision.py * Update loader.py * Update vision.py * Update loader.py * Update _utils.py * Update vision.py * gradient checkpointing * Gemma 3N fixes * Update loader.py * Versioning * Gemma 3N fixes * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Fix setup.py * setup.py * Prints * Update setup.py * Update setup.py * Update setup.py * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update vision.py * Update vision.py * Update pyproject.toml * Update vision.py * Update _utils.py * Update __init__.py * Update 
__init__.py * Small fixes * Update vision.py * Update vision.py * versioning * Update __init__.py * Update llama.py * Update rl.py * Update rl.py * Update _utils.py * Update vision.py * Update vision.py * compiler stance * Update _utils.py * Update pyproject.toml * Update pyproject.toml * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Revert "Revert "Add Qwen2.5-VL-32B-Instruct mapping to fix quantized model me…" (#2990) This reverts commit 204fc46. * skip_guard_eval_unsafe fix * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update synthetic.py * Update llama.py * Update llama.py * Fix `quantization_method` * versioning * Update _utils.py * Update _utils.py * Update _utils.py * check stride * Cleanup * Update rope_embedding.py * Update gemma2.py * Fix `set_stance` * Update pyproject.toml * Update _utils.py * Fixup patch vllm * Disable mllama * Use variables to decide VLM support * Better attn_impl handling * Patch TF protobuf incompatability * Torch 2.8 (#3186) * Fix mamba * Update loader.py * Update vision.py * Update loader.py * Filter vLLM standby logs (#3131) * filter vLLM standby logs * safeguard standby logger patch * Update unsloth/models/_utils.py * Update unsloth/models/_utils.py * Update unsloth/models/_utils.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update loader.py * Add scaler * Update llama.py * Update _utils.py * Versioning * GPT OSS fix * GPT OSS fix * Update loader.py * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Update vision.py * Update llama.py * Update llama.py * Update llama.py * Versioning * Update mapper.py * Update vision.py * Update vision.py * Update vision.py * Upcast 
norms * Update loader.py * Update vision.py * Upcast layernorms * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update save.py * Update rl.py * Update pyproject.toml * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update _utils.py * Update __init__.py * Torch 2.8 * Update rl_replacements.py --------- Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> * Update _auto_install.py * Update pyproject.toml * Update rl.py * Protobuf issue * Update pyproject.toml * Fix extras transformers typo in pyproject.toml * Update _utils.py * Bug fixes (#3195) * Fix mamba * Update loader.py * Update vision.py * Update loader.py * Filter vLLM standby logs (#3131) * filter vLLM standby logs * safeguard standby logger patch * Update unsloth/models/_utils.py * Update unsloth/models/_utils.py * Update unsloth/models/_utils.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update loader.py * Add scaler * Update llama.py * Update _utils.py * Versioning * GPT OSS fix * GPT OSS fix * Update loader.py * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Update vision.py * Update llama.py * Update llama.py * Update llama.py * Versioning * Update mapper.py * Update vision.py * Update vision.py * Update vision.py * Upcast norms * Update loader.py * Update vision.py * Upcast layernorms * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update save.py * Update rl.py * Update pyproject.toml * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update _utils.py * Update __init__.py * Torch 2.8 * Update rl_replacements.py * Update loader.py * UNSLOTH_ENABLE_CCE * Fix * Update loader.py * Update loader.py * Update __init__.py * Update __init__.py * Update __init__.py * Update __init__.py * Import fixes * Update loader.py * Fix aimv2 issue * Update 
loader.py * Update import_fixes.py * Update import_fixes.py * Update loader.py * Update loader.py * Update loader.py * Upgrade * Update loader.py * Update loader.py * Update loader.py * Update loader.py --------- Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> * adallow float32 dtype in FastLanguageModel (#3204) * Update loader.py * Update vision.py * Suppress message and use unsloth sampling params * Use trl sampling params for now * Improve error message * fixup quantized fast inference model name * Add mistral 3 support --------- Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: jeromeku <jerome.ku@gmail.com> Co-authored-by: DoubleMathew <mmathew23@gmail.com> Co-authored-by: Lei Zhenyuan <zhenyuan.lei@intel.com> Co-authored-by: parth2510 <parthguptapg7326@gmail.com> * Set padding to 0 * Fix patch * fixup patch (#3359) Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> * Update vision.py * Versioning * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * MXFP4 dequant * Update loader.py * Update vision.py * load_in_16bit * Update vision.py * Update vision.py * Update vision.py * Update rl.py * Update vision.py * offload_embedding * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update rl_replacements.py * Update loader.py * Fix padding issue * Update pyproject.toml * Update _utils.py * Update pyproject.toml * Update _utils.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * Update vision.py * New models * Update llama.py * Versioning * Update _utils.py * Update llama.py * Update _utils.py * Update llama.py * Fix AMD * Update _utils.py * Update llama.py * Update vision.py * 
DEVICE_TYPE_TORCH * Update __init__.py * Update __init__.py --------- Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: jeromeku <jerome.ku@gmail.com> Co-authored-by: DoubleMathew <mmathew23@gmail.com> Co-authored-by: Lei Zhenyuan <zhenyuan.lei@intel.com> Co-authored-by: parth2510 <parthguptapg7326@gmail.com>
1 parent b20a56a commit 472f4bc

5 files changed: +38 -24 lines changed

unsloth/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -96,6 +96,9 @@ def get_device_type():
     raise NotImplementedError("Unsloth currently only works on NVIDIA, AMD and Intel GPUs.")
 pass
 DEVICE_TYPE : str = get_device_type()
+# HIP fails for autocast and other torch functions. Use CUDA instead
+DEVICE_TYPE_TORCH = DEVICE_TYPE
+if DEVICE_TYPE_TORCH == "hip": DEVICE_TYPE_TORCH = "cuda"
 
 @functools.cache
 def get_device_count():
@@ -146,7 +149,9 @@ def get_device_count():
 # OutOfResources: out of resource: shared memory, Required: 98304, Hardware limit: 65536. Reducing block sizes or `num_stages`
 if (major_torch >= 2 and minor_torch >= 8) or (major_torch > 2):
     os.environ["UNSLOTH_ENABLE_CCE"] = "0"
-pass
+elif DEVICE_TYPE == "hip":
+    # CCE also fails in HIP / AMD
+    os.environ["UNSLOTH_ENABLE_CCE"] = "0"
 
 # Fix other issues
 import importlib.util
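
The two `unsloth/__init__.py` hunks above can be condensed into a small standalone sketch. This is not Unsloth's code: `torch_device_name`, `cce_should_be_disabled`, and `apply_cce_setting` are hypothetical helpers introduced only for illustration, and the torch version is passed in as a tuple instead of being parsed from the installed package. The sketch shows the two rules the diff adds: `"hip"` is aliased to `"cuda"` for device strings handed to torch, and `UNSLOTH_ENABLE_CCE` is forced to `"0"` on torch 2.8 and above or on HIP.

```python
import os

def torch_device_name(device_type: str) -> str:
    """Return the device string to hand to torch for a detected backend.

    ROCm builds of PyTorch expose AMD GPUs under the "cuda" device name,
    so "hip" is aliased to "cuda". Other backends pass through unchanged.
    (Hypothetical helper; the diff uses a module-level DEVICE_TYPE_TORCH.)
    """
    return "cuda" if device_type == "hip" else device_type

def cce_should_be_disabled(device_type: str, torch_version: tuple) -> bool:
    """Mirror the if/elif in the diff: torch >= 2.8, or any HIP device."""
    major, minor = torch_version[:2]
    torch_too_new = (major >= 2 and minor >= 8) or (major > 2)
    return torch_too_new or device_type == "hip"

def apply_cce_setting(device_type: str, torch_version: tuple, environ=os.environ) -> None:
    """Write UNSLOTH_ENABLE_CCE=0 into `environ` when CCE must be off."""
    if cce_should_be_disabled(device_type, torch_version):
        environ["UNSLOTH_ENABLE_CCE"] = "0"

if __name__ == "__main__":
    env = {}  # a plain dict stands in for os.environ in this demo
    apply_cce_setting("hip", (2, 7), env)
    print(torch_device_name("hip"), env)
```

Keeping `DEVICE_TYPE` as `"hip"` while using a separate alias for torch calls lets backend-specific branches (such as the CCE check) still distinguish AMD from NVIDIA.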

unsloth/models/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@
 import warnings, subprocess, re, inspect, psutil, os, math
 from unsloth_zoo.utils import Version
 from importlib.metadata import version as importlib_version
-from unsloth import DEVICE_TYPE, DEVICE_COUNT
+from unsloth import DEVICE_TYPE, DEVICE_COUNT, DEVICE_TYPE_TORCH
 from unsloth_zoo.log import logger
 from unsloth_zoo.tokenizer_utils import (
     patch_tokenizer as _patch_tokenizer,

unsloth/models/llama.py

Lines changed: 22 additions & 18 deletions
@@ -27,7 +27,7 @@
 from unsloth_zoo.utils import Version, _get_dtype
 from unsloth_zoo.hf_utils import dtype_from_config, add_dtype_kwargs, fix_lora_auto_mapping
 from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES
-from unsloth import DEVICE_TYPE, DEVICE_COUNT
+from unsloth import DEVICE_TYPE, DEVICE_COUNT, DEVICE_TYPE_TORCH
 
 transformers_version = Version(transformers_version)
 # Transformers moved rotary embeddings out of all attention layers
@@ -732,7 +732,7 @@ def LlamaModel_fast_forward(
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length,
             dtype = torch.int32,
-            device = f"{DEVICE_TYPE}:0",
+            device = f"{DEVICE_TYPE_TORCH}:0",
         )
         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     elif position_ids is not None:
@@ -905,13 +905,13 @@ def LlamaModel_fast_forward(
                     is_causal = True,
                     sliding_window = self.config.sliding_window,
                 )\
-                    .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE,)\
+                    .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE_TORCH,)\
                     .squeeze(0).squeeze(0)
 
                 self.GA_mask = AttentionMaskConverter(
                     is_causal = True,
                 )\
-                    .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE,)\
+                    .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE_TORCH,)\
                     .squeeze(0).squeeze(0)
             pass
         pass
@@ -1028,11 +1028,11 @@ def LlamaModel_fast_forward_inference_custom(
     bsz, q_len, hd = X.shape
     assert(q_len == 1)
     # Get saved buffers to reduce memory movement
-    residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
-    _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
+    residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
+    _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
     XX, XX2 = _XX[0], _XX[1]
-    variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
-    temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE}:0")
+    variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
+    temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE_TORCH}:0")
     temp_gates, temp_ups = tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)), tuple(temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT))
 
     seq_len = past_key_values[0][0].shape[-2]
@@ -1196,10 +1196,14 @@ def _CausalLM_fast_forward(
 else:
 RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
 # < 1024 Normal Unsloth uses less VRAM!
-if bsz*q_len <= 1024: RETURN_LOGITS = True
+if DEVICE_TYPE == "hip":
+# [TODO] AMD GPUs fail on chunked_cross_entropy loss!
+# RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
+RETURN_LOGITS = False
+elif bsz*q_len <= 1024:
+RETURN_LOGITS = True
 
 if not RETURN_LOGITS and labels is not None:
-
 n_items = kwargs.get("num_items_in_batch", None)
 if n_items is None: n_items = kwargs.get("n_items", None)

@@ -1374,7 +1378,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
 dim = getattr(config, "head_dim", None)
 if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
-device = DEVICE_TYPE
+device = DEVICE_TYPE_TORCH
 max_position_embeddings = config.max_position_embeddings
 pass

@@ -1486,7 +1490,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
 base = config.rope_theta
 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
 dim = int((config.hidden_size // config.num_attention_heads))
-device = DEVICE_TYPE
+device = DEVICE_TYPE_TORCH
 max_position_embeddings = config.max_position_embeddings
 pass

@@ -1606,7 +1610,7 @@ def __init__(self,
 base = config.rope_theta
 partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
 dim = int((config.hidden_size // config.num_attention_heads))
-device = DEVICE_TYPE
+device = DEVICE_TYPE_TORCH
 max_position_embeddings = config.max_position_embeddings
 pass

@@ -1760,7 +1764,7 @@ def unsloth_fast_generate(
 kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
 
 # Mixed precision autocast
-with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE, dtype = dtype):
+with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype):
 output = self._old_generate(*args, **kwargs)
 pass

@@ -2384,7 +2388,7 @@ def get_peft_model(
 pass
 
 model.get_input_embeddings().modules_to_save.default\
-.to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
+.to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
 model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
 
 # [TODO] Move old embed_tokens to CPU - should be disk!
@@ -2404,7 +2408,7 @@ def get_peft_model(
 pass
 
 model.get_output_embeddings().modules_to_save.default\
-.to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
+.to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
 model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
 
 # [TODO] Move old lm_head to CPU - should be disk!
@@ -2673,7 +2677,7 @@ def get_peft_model(
 pass
 
 model.get_input_embeddings().modules_to_save.default\
-.to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
+.to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
 model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
 pass

@@ -2689,7 +2693,7 @@ def get_peft_model(
 pass
 
 model.get_output_embeddings().modules_to_save.default\
-.to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
+.to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
 model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
 pass
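
Review note: the llama.py hunks swap `DEVICE_TYPE` for `DEVICE_TYPE_TORCH` wherever the value is handed to a PyTorch API (`torch.arange`, `torch.empty`, `torch.autocast`, `.to(device = ...)`), while the `DEVICE_TYPE == "hip"` check keeps the original name. The definition of `DEVICE_TYPE_TORCH` is not part of this excerpt. One plausible reading, stated here as an assumption, is that it translates the backend name into the string PyTorch expects, since ROCm builds of PyTorch address AMD GPUs through the `cuda` device type. The helper names below are hypothetical and only sketch that mapping:

```python
# Hypothetical helpers, not taken from the commit. They illustrate one way a
# backend name ("hip") could be separated from the device-type string that
# PyTorch APIs accept ("cuda" on ROCm builds).

def to_torch_device_type(device_type: str) -> str:
    """Return the device-type string to pass to PyTorch for a backend name."""
    # Assumption: only "hip" needs translating; other names pass through.
    return "cuda" if device_type == "hip" else device_type


def device_string(device_type: str, index: int = 0) -> str:
    """Build an indexed device string such as "cuda:0"."""
    return f"{to_torch_device_type(device_type)}:{index}"


if __name__ == "__main__":
    for backend in ("cuda", "hip", "xpu"):
        print(backend, "->", device_string(backend))
```

Keeping two names lets backend-specific branches (such as the `hip` check in `_CausalLM_fast_forward`) test the real backend while tensor placement and autocast use a string PyTorch recognises.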

unsloth/models/mistral.py

Lines changed: 6 additions & 1 deletion
@@ -298,7 +298,12 @@ def MistralForCausalLM_fast_forward(
 else:
 RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
 # < 1024 Normal Unsloth uses less VRAM!
-if bsz * q_len <= 1024: RETURN_LOGITS = True
+if DEVICE_TYPE == "hip":
+# [TODO] AMD GPUs fail on chunked_cross_entropy loss!
+# RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
+RETURN_LOGITS = False
+elif bsz*q_len <= 1024:
+RETURN_LOGITS = True
 
 if not RETURN_LOGITS and labels is not None:
 n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
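
Review note: the same `RETURN_LOGITS` change appears in both llama.py and mistral.py. The sketch below restates only the branch visible in these hunks as a standalone function, so the precedence is easier to check. The function name and the injectable `env` argument are hypothetical, and whatever condition guards the enclosing `else:` is outside this excerpt and is not modelled.

```python
import os
from typing import Mapping, Optional


def should_return_logits(
    device_type: str,
    bsz: int,
    q_len: int,
    env: Optional[Mapping[str, str]] = None,
) -> bool:
    """Mirror the RETURN_LOGITS branch shown in the diff (illustrative only)."""
    env = os.environ if env is None else env
    # Start from the user's opt-in via the environment variable.
    return_logits = env.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
    if device_type == "hip":
        # On AMD the flag is forced off, overriding the environment variable.
        return_logits = False
    elif bsz * q_len <= 1024:
        # Small batches take the logits-returning path.
        return_logits = True
    return return_logits


if __name__ == "__main__":
    print(should_return_logits("hip", 1, 8, {"UNSLOTH_RETURN_LOGITS": "1"}))
    print(should_return_logits("cuda", 2, 512, {}))
```

Two consequences follow directly from the branch order: on `hip`, setting `UNSLOTH_RETURN_LOGITS=1` has no effect, and the 1024-token shortcut no longer applies there.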

unsloth/models/vision.py

Lines changed: 3 additions & 3 deletions
@@ -71,7 +71,7 @@
 # Old HF Hub versions <= 0.0.25
 from huggingface_hub.utils._token import get_token
 pass
-from unsloth import DEVICE_TYPE, DEVICE_COUNT
+from unsloth import DEVICE_TYPE, DEVICE_COUNT, DEVICE_TYPE_TORCH
 
 __all__ = [
 "FastBaseModel",
@@ -204,10 +204,10 @@ def unsloth_base_fast_generate(
 
 # Mixed precision autocast
 if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
-autocaster = torch.autocast(device_type = "cuda", dtype = torch.float16)
+autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = torch.float16)
 dtype = torch.float16
 else:
-autocaster = torch.autocast(device_type = "cuda", dtype = dtype)
+autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype)
 
 # Prepare LoRA
 # state_dict = convert_lora_modules(self, dtype = dtype)
