1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- __version__ = "2024.11.1 "
15+ __version__ = "2024.11.3 "
1616
1717__all__ = [
1818 "prepare_model_for_kbit_training" ,
6969 patch_compiling_bitsandbytes ,
7070 patch_layernorm ,
7171 patch_torch_compile ,
72- patch_regional_compilation ,
7372 patch_model_and_tokenizer ,
7473)
7574from unsloth_zoo .gradient_checkpointing import (
8887# Disable some warnings which can get annoying
8988warnings .filterwarnings (action = "ignore" , category = UserWarning , module = "torch" )
9089warnings .filterwarnings (action = "ignore" , category = UserWarning , module = "huggingface_hub" )
91- warnings .filterwarnings (action = "ignore" , category = UserWarning , module = "trl" )
9290warnings .filterwarnings (action = "ignore" , category = FutureWarning , module = "huggingface_hub" )
91+ warnings .filterwarnings (action = "ignore" , category = UserWarning , module = "trl" )
92+ warnings .filterwarnings (action = "ignore" , category = FutureWarning , module = "trl" )
9393warnings .filterwarnings (action = "ignore" , category = FutureWarning , module = "xformers" )
9494warnings .filterwarnings (action = "ignore" , category = RuntimeWarning , module = "subprocess" )
9595warnings .filterwarnings (action = "ignore" , category = UserWarning , module = "transformers" )
@@ -374,8 +374,9 @@ def _is_openai_available(): return False
374374
375375# =============================================
376376# Torch compile settings
377- UNSLOTH_COMPILE_DEBUG = "UNSLOTH_COMPILE_DEBUG" in os .environ
378- UNSLOTH_COMPILE_MAXIMUM = "UNSLOTH_COMPILE_MAXIMUM" in os .environ
377+ UNSLOTH_COMPILE_DEBUG = os .environ .get ("UNSLOTH_COMPILE_DEBUG" , "0" ) == "1"
378+ UNSLOTH_COMPILE_MAXIMUM = os .environ .get ("UNSLOTH_COMPILE_MAXIMUM" , "0" ) == "1"
379+ UNSLOTH_COMPILE_IGNORE_ERRORS = os .environ .get ("UNSLOTH_COMPILE_IGNORE_ERRORS" , "1" ) == "1"
379380# Just remove max_autotune_gemm warning
380381import functools
381382@functools .lru_cache (None )
@@ -387,7 +388,11 @@ def is_big_gpu(index):
387388 return True
388389import torch ._inductor .utils
389390torch ._inductor .utils .is_big_gpu = is_big_gpu
390- patch_torch_compile (debug = UNSLOTH_COMPILE_DEBUG , O3 = UNSLOTH_COMPILE_MAXIMUM )
391+ patch_torch_compile (
392+ debug = UNSLOTH_COMPILE_DEBUG ,
393+ O3 = UNSLOTH_COMPILE_MAXIMUM ,
394+ ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS ,
395+ )
391396
392397torch_compile_options = {
393398 "epilogue_fusion" : True ,
@@ -408,6 +413,26 @@ def torch_compile_kwargs(*args, **kwargs):
408413accelerate .accelerator .TorchDynamoPlugin .to_kwargs = torch_compile_kwargs
409414del accelerate
410415
def patch_regional_compilation():
    """Monkey-patch ``torch.nn.ModuleList`` for regional compilation.

    After patching, constructing a ``ModuleList`` directly from a plain
    ``list`` wraps every child module in ``torch.compile`` individually
    (dynamic shapes, no fullgraph) before delegating to the original class.

    No-op when the patch is already installed (detected via the replacement
    class name) or when torch is older than 2.5.0. Sets the
    ``UNSLOTH_PATCHED`` environment flag on installation.
    """
    # Idempotency guard: the replacement carries the name "UnslothModuleList".
    if torch.nn.ModuleList.__name__ == "UnslothModuleList":
        return
    # Regional (per-submodule) compilation requires torch >= 2.5.
    if Version(torch.__version__) < Version("2.5.0"):
        return

    original_module_list = torch.nn.ModuleList
    os.environ["UNSLOTH_PATCHED"] = "1"

    def UnslothModuleList(*args, **kwargs):
        # Intercept the common ModuleList([m1, m2, ...]) call shape and
        # compile each submodule individually; every other call form is
        # forwarded to the original class untouched.
        if len(args) == 1 and not kwargs and type(args[0]) is list:
            compiled = [
                torch.compile(m, dynamic=True, options=torch_compile_options, fullgraph=False)
                for m in args[0]
            ]
            args = [original_module_list(compiled)]
        return original_module_list(*args, **kwargs)

    UnslothModuleList.__doc__ = original_module_list.__doc__
    torch.nn.ModuleList = UnslothModuleList
435+
411436# =============================================
412437
413438def prepare_model_for_kbit_training (
0 commit comments