Commit 7d47557

fix: config.torch_dtype in LlamaModel_fast_forward_inference (#2091)
* fix: config.torch_dtype in LlamaModel_fast_forward_inference
* Update llama.py
* update for consistency

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
1 parent cca0d38 commit 7d47557

1 file changed (+2, −8 lines)

unsloth/models/llama.py

Lines changed: 2 additions & 8 deletions
@@ -652,13 +652,7 @@ def LlamaModel_fast_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)

-    # inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
-    torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
-    if torch_dtype is not None:
-        inputs_embeds = inputs_embeds.to(torch_dtype)
-    else:
-        raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!")
-    pass
+    inputs_embeds = inputs_embeds.to(_get_dtype(self.config.torch_dtype))

     # Normalized from Gemma
     IS_GEMMA = self.config.model_type.startswith("gemma")
@@ -924,7 +918,7 @@ def LlamaModel_fast_forward_inference(
     mlp_size = self.config.intermediate_size

     X = self.model.embed_tokens(input_ids)
-    X = X.to(self.config.torch_dtype)
+    X = X.to(_get_dtype(self.config.torch_dtype))
     bsz, q_len, hd = X.shape
     assert(q_len == 1)
     # Get saved buffers to reduce memory movement
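Both hunks replace a direct use of `self.config.torch_dtype` with a call to `_get_dtype(...)`, which is what the "update for consistency" bullet refers to: the training path previously did an inline `__DTYPE_MAP.get(...)` lookup while the inference path passed the config value straight to `Tensor.to`, and a config round-trip can leave `torch_dtype` as a string such as "bfloat16" rather than a `torch.dtype`. The real `_get_dtype` helper lives elsewhere in Unsloth and is not shown in this diff; the sketch below is only an assumed approximation of its behavior, inferred from the removed `__DTYPE_MAP` lookup and its `TypeError` message.

```python
import torch

# Hypothetical sketch only — the actual _get_dtype is defined inside Unsloth.
# Assumption: it normalizes config.torch_dtype, which may be either a string
# (e.g. "bfloat16") or a torch.dtype, into a torch.dtype, and rejects anything
# outside {float16, bfloat16, float32}.
_DTYPE_MAP = {
    "float32"      : torch.float32,
    torch.float32  : torch.float32,
    "float16"      : torch.float16,
    torch.float16  : torch.float16,
    "bfloat16"     : torch.bfloat16,
    torch.bfloat16 : torch.bfloat16,
}

def _get_dtype(dtype):
    # Return the canonical torch.dtype, or raise the same kind of error
    # the removed inline code raised for unsupported dtypes.
    try:
        return _DTYPE_MAP[dtype]
    except KeyError:
        raise TypeError(f"Unsloth: {dtype} is not bfloat16, float16 or float32!")

# Usage mirroring the patched lines: works whether the config stored a
# string or a torch.dtype.
x = torch.zeros(2, 2)
x = x.to(_get_dtype("bfloat16"))
```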
