* Fix FP8 for models whose weight dimensions are not multiples of 8
* Patch FP8 forward methods for compiled models
* Patch the HF quantizer for FP8
* Failsafe import of FbgemmFp8Linear and FP8Linear (sketched below)
* Beautify
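The "failsafe import" item above guards the FP8 linear imports so that transformers builds which do not ship these classes still import cleanly. A minimal sketch of the idea, assuming the classes live under `transformers.integrations` (the exact module paths vary across transformers versions):

```python
# Guarded imports: fall back to None when the installed transformers build does not ship
# the FP8 linear layers, so the rest of the package can still be imported.
try:
    from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear
except ImportError:
    FbgemmFp8Linear = None

try:
    from transformers.integrations.finegrained_fp8 import FP8Linear
except ImportError:
    FP8Linear = None

# Hypothetical flag the patching code can check before touching any forward methods.
HAS_FP8_LINEAR = (FbgemmFp8Linear is not None) or (FP8Linear is not None)
```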
unsloth/kernels/__init__.py (1 addition, 0 deletions)
@@ -44,6 +44,7 @@
     apply_lora_o,
     fast_lora_forward,
 )
+from .fp8 import *  # Patch FbgemmFp8Linear's and FP8Linear's forward methods before any model is created, so the patch also applies to compiled, non fast-inference models
 from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora
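The comment on the added import describes a patch-on-import pattern: the `forward` method is replaced on the class itself when `unsloth.kernels` is imported, so every model built afterwards, including `torch.compile`'d ones, already sees the patched method. A minimal sketch of that pattern, where `patched_fp8_forward` is a hypothetical stand-in for the actual replacement forward:

```python
def patch_fbgemm_fp8_forward(patched_fp8_forward):
    """Replace FbgemmFp8Linear.forward at the class level so that every instance,
    including ones inside compiled models built later, uses the patched forward."""
    try:
        from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear
    except ImportError:
        return  # transformers build without fbgemm FP8 support: nothing to patch
    FbgemmFp8Linear.forward = patched_fp8_forward
```

Running this at import time, before any model is created, matches the comment above: compiled models trace whatever `forward` the class carries when they are built. The second diff hunk below then reworks the FP8 matmul's shape handling.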
             # This is generally the case when we do the backward pass. The only option is to dequantize, since there is no column-wise fp8 matmul.
-            W_deq = weight_dequant(weight, weight_scale).T
-            x = torch_matmul(x, W_deq)
-            del W_deq
-            return x
-        else:
-            raise ValueError(f"Shapes are incompatible {weight.shape=}, {weight_scale.shape=}, {x.shape=}")
-    else:
+
+    if weight.shape[0] == weight_scale.shape[0] and (weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0):
+        # Edit: the kernel expects weight dimensions divisible by 8, otherwise it throws `RuntimeError: cutlass cannot implement`.
+        # One option is to pad the weight and weight scale to a multiple of 8 and perform an F8F8BF16 operation.
+        # Benchmarking that showed dequantize + bf16 matmul is significantly faster than padding + f8f8bf16 matmul, so we go that route instead.
+        # So essentially, f8f8bf16_rowwise only happens when shapes are proper (no transposes) and divisible by 8.
+
         # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
         output_shape = (*x.shape[:-1], -1)
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
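Taken together, the change gates the fast FP8 GEMM on the weight shape and otherwise falls back to dequantize + bf16 matmul. Below is a minimal, self-contained sketch of that dispatch; `fp8_matmul_or_fallback` is a hypothetical helper, `weight` is assumed to already be FP8 (e4m3) with a per-row `weight_scale`, and the `torch.ops.fbgemm.*` calls mirror the ops used by transformers' FbgemmFp8Linear, so treat the exact signatures as an assumption:

```python
import torch

def fp8_matmul_or_fallback(x: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
    # Fast path: row-wise FP8 GEMM. The cutlass kernel rejects weights whose dimensions
    # are not multiples of 8, hence the explicit shape check.
    if (
        weight.shape[0] == weight_scale.shape[0]
        and weight.shape[0] % 8 == 0
        and weight.shape[1] % 8 == 0
    ):
        # quantize_fp8_per_row squashes the leading dimensions, so remember the target shape.
        output_shape = (*x.shape[:-1], -1)
        x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x.reshape(-1, x.shape[-1]))
        out = torch.ops.fbgemm.f8f8bf16_rowwise(x_quantized, weight, x_scale, weight_scale, use_fast_accum=True)
        return out.reshape(output_shape)

    # Fallback: dequantize to bf16 and use a plain matmul. This is also the only option for
    # the backward pass, since there is no column-wise fp8 matmul.
    W_deq = weight.to(torch.bfloat16) * weight_scale.reshape(-1, 1).to(torch.bfloat16)
    return torch.matmul(x.to(torch.bfloat16), W_deq.t())
```

Padding the weight and scale up to the next multiple of 8 and calling the FP8 GEMM anyway would also work, but per the comments above it benchmarked slower than the dequantize + bf16 matmul fallback.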