Commit ee8e80a

fix FSDP version related issues (#22489)

fix fsdp

1 parent: c7ec71b

1 file changed: src/transformers/trainer.py (+6, -3)
@@ -1481,16 +1481,19 @@ def _wrap_model(self, model, training=True, dataloader=None):
                 mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
             if type(model) != FSDP:
                 # XXX: Breaking the self.model convention but I see no way around it for now.
+                signature = inspect.signature(FSDP.__init__).parameters.keys()
+                kwargs = {}
+                for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
+                    if arg in signature:
+                        kwargs[arg] = getattr(self, arg)
                 self.model = model = FSDP(
                     model,
                     sharding_strategy=self.fsdp,
                     cpu_offload=cpu_offload,
                     auto_wrap_policy=auto_wrap_policy,
                     mixed_precision=mixed_precision_policy,
                     device_id=self.args.device,
-                    backward_prefetch=self.backward_prefetch,
-                    forward_prefetch=self.forword_prefetch,
-                    limit_all_gathers=self.limit_all_gathers,
+                    **kwargs,
                 )
             else:
                 try:
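On older PyTorch releases, the FSDP constructor does not accept `limit_all_gathers`, `forward_prefetch`, or `backward_prefetch`, so passing them unconditionally raises a TypeError. The patch instead inspects `FSDP.__init__`'s signature and forwards only the arguments that this version supports. Below is a minimal, self-contained sketch of the same pattern; the helper `filter_supported_kwargs` and the `WrapperV1` stand-in are hypothetical names for illustration, not transformers or PyTorch APIs.

import inspect

def filter_supported_kwargs(fn, candidate_kwargs):
    # Hypothetical helper mirroring the patch: keep only the kwargs that
    # this version of `fn` actually accepts, so options added in newer
    # releases are silently dropped on older ones instead of raising.
    accepted = inspect.signature(fn).parameters.keys()
    return {name: value for name, value in candidate_kwargs.items() if name in accepted}

# Stand-in for an "old" wrapper whose __init__ predates forward_prefetch.
class WrapperV1:
    def __init__(self, module, limit_all_gathers=False):
        self.module = module
        self.limit_all_gathers = limit_all_gathers

kwargs = filter_supported_kwargs(
    WrapperV1.__init__,
    {"limit_all_gathers": True, "forward_prefetch": True, "backward_prefetch": None},
)
assert kwargs == {"limit_all_gathers": True}  # unsupported options were dropped
wrapped = WrapperV1(object(), **kwargs)

The trade-off of this approach is that it degrades silently: a user-requested option is ignored on an older release rather than producing an explicit error, in exchange for the trainer working unchanged across PyTorch versions.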
