Commit fc178b5

FP8 training enhancements (#3496)
* Fix FP8 for models whose weight dimensions are not multiples of 8
* Patch FP8 forward methods for compiled models
* Patch the HF quantizers for FP8
* Fail-safe imports of FbgemmFp8Linear and FP8Linear
* Beautify
1 parent fe8d426 commit fc178b5

File tree

3 files changed, +73 -16 lines


unsloth/kernels/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
     apply_lora_o,
     fast_lora_forward,
 )
+from .fp8 import *  # Ensure FbgemmFp8Linear's and FP8Linear's forward methods are patched before model creation, so the patch also applies to compiled, non fast-inference models
 from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora

 from .flex_attention import (
unsloth/kernels/fp8.py

Lines changed: 50 additions & 16 deletions
@@ -17,6 +17,20 @@
 import triton.language as tl
 from torch.nn import functional as F
 import math
+from unsloth_zoo.log import logger
+
+try:
+    from transformers.integrations.finegrained_fp8 import FP8Linear
+except ImportError:
+    FP8Linear = None
+    logger.log("Unsloth: FP8 models need `FP8Linear` from `transformers.integrations.finegrained_fp8`, but it could not be imported.")
+
+try:
+    from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear
+except ImportError:
+    FbgemmFp8Linear = None
+    logger.log("Unsloth: FP8 models need `FbgemmFp8Linear` from `transformers.integrations.fbgemm_fp8`, but it could not be imported.")
+
 try:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block
 except ImportError:
@@ -329,7 +343,7 @@ def forward(ctx, X, weight, weight_scale):
     @staticmethod
     def backward(ctx, grad_output):
         W_deq = weight_dequant(ctx.weight, ctx.weight_scale)
-        grad_X = torch_matmul(grad_output, W_deq.t())
+        grad_X = torch_matmul(grad_output, W_deq)
         del W_deq
         return grad_X, None, None
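
For illustration only (not part of the diff): the backward changes in this commit all drop the `.t()` on the dequantized weight. Assuming the FP8 weight is stored as [out_features, in_features] (the nn.Linear convention) and the forward computes Y = X @ W_deq.T, the input gradient is dX = dY @ W_deq, with no transpose. A quick numerical check of that identity:

import torch

X = torch.randn(2, 8, requires_grad = True)
W_deq = torch.randn(16, 8)        # assumed [out_features, in_features] layout
Y = X @ W_deq.t()                 # forward pass under that assumption
Y.backward(torch.ones_like(Y))

# Matches the patched backward: grad_X = torch_matmul(grad_output, W_deq)
print(torch.allclose(X.grad, torch.ones(2, 16) @ W_deq))   # True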

@@ -338,20 +352,17 @@ def fp8_block_quant_forward(X, weight, weight_scale):
     return FP8BlockQuantLinear.apply(X, weight, weight_scale)


-class FbgemmFp8Linear(torch.autograd.Function):
+class FbgemmFp8Linear_matmul(torch.autograd.Function):

     @staticmethod
     def forward(ctx, x, weight, weight_scale, bias=None):
-        if weight.shape[0] != weight_scale.shape[0]:
-            if weight.shape[1] == weight_scale.shape[0]:
-                # This is generally the case when we do backward pass. The only way is to dequantize as there is no column wise fp8 matmul
-                W_deq = weight_dequant(weight, weight_scale).T
-                x = torch_matmul(x, W_deq)
-                del W_deq
-                return x
-            else:
-                raise ValueError(f"Shapes are incompatible {weight.shape=}, {weight_scale.shape=}, {x.shape=}")
-        else:
+
+        if weight.shape[0] == weight_scale.shape[0] and (weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0):
+            # The kernel seems to expect weight dimensions divisible by 8; otherwise it throws `RuntimeError: cutlass cannot implement`.
+            # One option is to pad the weight and weight scale to a multiple of 8 and perform an f8f8bf16 matmul,
+            # but benchmarking showed that dequantize + bf16 matmul is significantly faster than padding + f8f8bf16 matmul, so we go that route instead.
+            # So essentially, f8f8bf16_rowwise only runs when the shapes are proper (no transposes) and divisible by 8.
+
             # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
             output_shape = (*x.shape[:-1], -1)
             # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
@@ -378,6 +389,16 @@ def forward(ctx, x, weight, weight_scale, bias=None):
             output = output.to(x.device, x.dtype)
             output = output.reshape(output_shape)
             del x_quantized, x_scale
+        elif (weight.shape[0] != weight_scale.shape[0] and weight.shape[1] == weight_scale.shape[0]) or (weight.shape[0] % 8 != 0 or weight.shape[1] % 8 != 0):
+            # Either the weight/scale is transposed or the weight's shape is not divisible by 8. In both cases, dequantizing is the preferred path.
+            # The transposed case generally shows up in the backward pass, where we compute dY @ W instead of X @ W.T as in the forward pass.
+            # The shape case shows up, for example, in the MLP of Qwen 2.5 VL 7B, where the gate proj has shape (3420, 1280) and 3420 / 8 = 427.5.
+
+            W_deq = weight_dequant(weight, weight_scale).T
+            output = torch_matmul(x, W_deq)
+            del W_deq
+        else:
+            raise ValueError(f"Shapes are incompatible {weight.shape=}, {weight_scale.shape=}, {x.shape=}")

         ctx.weight = weight
         ctx.weight_scale = weight_scale
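
To make the new dispatch concrete: per the comments, a row-quantized weight of shape (out_features, in_features) carries one scale per output row, so weight.shape[0] == weight_scale.shape[0] in the forward direction, while the backward-style call sees the scale matching weight.shape[1] instead. The helper below is illustrative only (it is not in the codebase); it mirrors the branch selection above on the shapes mentioned in the comments:

def pick_path(weight_shape, scale_rows):
    # Illustrative stand-in for the if/elif/else in FbgemmFp8Linear_matmul.forward
    rows, cols = weight_shape
    if rows == scale_rows and rows % 8 == 0 and cols % 8 == 0:
        return "f8f8bf16 rowwise kernel"
    elif (rows != scale_rows and cols == scale_rows) or (rows % 8 != 0 or cols % 8 != 0):
        return "dequantize + bf16 matmul"
    else:
        return "incompatible shapes"

print(pick_path((4096, 4096), 4096))   # f8f8bf16 rowwise kernel
print(pick_path((3420, 1280), 3420))   # dequantize + bf16 matmul (3420 is not a multiple of 8)
print(pick_path((1280, 3420), 3420))   # dequantize + bf16 matmul (transposed, backward-style call)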
@@ -386,13 +407,13 @@ def forward(ctx, x, weight, weight_scale, bias=None):
     @staticmethod
     def backward(ctx, grad_output):
         W_deq = weight_dequant(ctx.weight, ctx.weight_scale)
-        grad_X = torch_matmul(grad_output, W_deq.t())
+        grad_X = torch_matmul(grad_output, W_deq)
         del W_deq
         return grad_X, None, None, None, None

 @torch_compile
-def fbgemm_fp8_linear(X, weight, weight_scale, bias=None, ):
-    return FbgemmFp8Linear.apply(X, weight, weight_scale, bias)
+def fbgemm_fp8_linear(X, weight, weight_scale, bias=None):
+    return FbgemmFp8Linear_matmul.apply(X, weight, weight_scale, bias)


 class FP8_torch_linear(torch.autograd.Function):
@@ -437,7 +458,7 @@ def forward(ctx, X, weight, weight_scale, bias=None):
     @staticmethod
     def backward(ctx, grad_output):
         W_deq = weight_dequant(ctx.weight, ctx.weight_scale)
-        grad_X = torch_matmul(grad_output, W_deq.t())
+        grad_X = torch_matmul(grad_output, W_deq)
         del W_deq
         return grad_X, None, None, None, None

@@ -459,3 +480,16 @@ def fp8_linear(X, weight, weight_scale, bias=None):
         # Row quantized FP8
         out = fbgemm_fp8_linear(X, weight, weight_scale, bias)
     return out
+
+
+def module_forward_patch(forward_function, scale_attr='weight_scale'):
+    def patched_forward(self, X):
+        return forward_function(X, self.weight, getattr(self, scale_attr))
+    return patched_forward
+
+
+# Patch the forward functions of the layers (for compiled models)
+if FbgemmFp8Linear is not None:
+    FbgemmFp8Linear.forward = module_forward_patch(fbgemm_fp8_linear, 'weight_scale')
+if FP8Linear is not None:
+    FP8Linear.forward = module_forward_patch(fp8_block_quant_forward, 'weight_scale_inv')
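
What module_forward_patch does in practice: it builds an unbound forward that reads the layer's weight plus the named scale attribute ('weight_scale' for FbgemmFp8Linear, 'weight_scale_inv' for FP8Linear) and hands them to the corresponding Unsloth kernel. A self-contained sketch with a toy module and toy kernel standing in for the real classes and kernels:

import torch

def module_forward_patch(forward_function, scale_attr='weight_scale'):
    # Same closure as in unsloth/kernels/fp8.py
    def patched_forward(self, X):
        return forward_function(X, self.weight, getattr(self, scale_attr))
    return patched_forward

def toy_kernel(X, weight, scale):
    # Stand-in for fp8_block_quant_forward / fbgemm_fp8_linear
    return X @ (weight * scale).t()

class ToyFP8Linear(torch.nn.Module):
    # Stand-in for transformers' FP8Linear, whose scale attribute is `weight_scale_inv`
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 3))
        self.weight_scale_inv = torch.nn.Parameter(torch.ones(4, 1))

ToyFP8Linear.forward = module_forward_patch(toy_kernel, 'weight_scale_inv')
print(ToyFP8Linear()(torch.randn(2, 3)).shape)   # torch.Size([2, 4])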

unsloth/models/_utils.py

Lines changed: 22 additions & 0 deletions
@@ -73,6 +73,7 @@
     "patch_peft_fast_inference",
     "error_out_no_vllm",
     "dequantize_module_weight",
+    "patch_hf_quantizer",
 ]

 import torch
@@ -1814,3 +1815,24 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: Union[str, TorchA
     quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
     return model
 pass
+
+def patch_hf_quantizer():
+    # Tell the HF Trainer that the quantized model is trainable
+    def make_trainable(self):
+        return True
+    try:
+        from transformers.quantizers.quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
+        FineGrainedFP8HfQuantizer.is_trainable = property(make_trainable)
+        FineGrainedFP8HfQuantizer.is_qat_trainable = property(make_trainable)
+    except Exception as e:
+        logger.warning(f"Failed to patch FineGrainedFP8HfQuantizer. Error: {e}")
+
+    try:
+        from transformers.quantizers.quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
+        FbgemmFp8HfQuantizer.is_trainable = property(make_trainable)
+        FbgemmFp8HfQuantizer.is_qat_trainable = property(make_trainable)
+    except Exception as e:
+        logger.warning(f"Failed to patch FbgemmFp8HfQuantizer. Error: {e}")
+    pass
+
+patch_hf_quantizer()
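
The quantizer patch works by rebinding is_trainable (and is_qat_trainable) as a property on the quantizer classes, which, per the comment above, is what tells the HF Trainer that the quantized model can be trained. A minimal sketch of that property override, using a hypothetical ToyQuantizer in place of the real classes:

class ToyQuantizer:
    # Hypothetical stand-in for FineGrainedFP8HfQuantizer / FbgemmFp8HfQuantizer
    @property
    def is_trainable(self):
        return False

def make_trainable(self):
    return True

print(ToyQuantizer().is_trainable)                     # False: training would be refused
ToyQuantizer.is_trainable = property(make_trainable)   # the same rebinding patch_hf_quantizer applies
print(ToyQuantizer().is_trainable)                     # True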
