Skip to content

Commit 6d71c53

Browse files
authored
Fix sgd (#2777)
* Only update momentum buffers for SGD if momentum is enabled * update stock PT version
1 parent c1dc7ae commit 6d71c53

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

dependency_version.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ oneCCL:
2828
protobuf:
2929
version: 3.20.3
3030
pytorch:
31-
version: 2.4.0.dev20240401+cpu
31+
version: 2.4.0.dev20240417+cpu
3232
torch-ccl:
3333
commit: ccl_torch_dev_0131
3434
repo: https://github.com/intel/torch-ccl.git
3535
version: 2.3.0+cpu
3636
torchaudio:
37-
version: 2.2.0.dev20240401+cpu
37+
version: 2.2.0.dev20240417+cpu
3838
torchvision:
39-
version: 0.19.0.dev20240401+cpu
39+
version: 0.19.0.dev20240417+cpu
4040
transformers:
4141
version: 4.38.1

intel_extension_for_pytorch/optim/_functional.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,10 +526,11 @@ def sgd_step(self, closure=None):
526526
fused=self.fused,
527527
)
528528

529-
# update momentum_buffers in state
530-
for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
531-
state = self.state[p]
532-
state["momentum_buffer"] = momentum_buffer
529+
if group["momentum"] != 0:
530+
# update momentum_buffers in state
531+
for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
532+
state = self.state[p]
533+
state["momentum_buffer"] = momentum_buffer
533534

534535
return loss
535536

0 commit comments

Comments
 (0)