forked from openai/improved-diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfp16_util.py
76 lines (60 loc) · 2.23 KB
/
fp16_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
Helpers to train with 16-bit precision.
"""
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    Only ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Conv3d`` instances are touched;
    any other module is left unchanged. The weight (and the bias, when the
    layer has one) is replaced in place by its half-precision copy.

    :param l: the module to convert; intended for use with ``Module.apply``.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        # Layers built with bias=False have ``bias is None``; skip them
        # instead of failing on ``None.data``.
        if l.bias is not None:
            l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().

    Only ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Conv3d`` instances are touched;
    any other module is left unchanged. The weight (and the bias, when the
    layer has one) is replaced in place by its single-precision copy.

    :param l: the module to convert; intended for use with ``Module.apply``.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        # Layers built with bias=False have ``bias is None``; skip them
        # instead of failing on ``None.data``.
        if l.bias is not None:
            l.bias.data = l.bias.data.float()
def make_master_params(model_params):
    """
    Build the full-precision master copy of the model parameters.

    Every parameter is detached, cast to float32 and concatenated into one
    flat tensor, so the result does not share the shapes of the inputs.

    :param model_params: an iterable of model parameters (tensors).
    :return: a single-element list holding the flat, trainable nn.Parameter.
    """
    full_precision = [p.detach().float() for p in model_params]
    flat = nn.Parameter(_flatten_dense_tensors(full_precision))
    flat.requires_grad = True
    return [flat]
def model_grads_to_master_grads(model_params, master_params):
    """
    Transfer the model gradients onto the flat master parameter.

    Each model gradient is detached, cast to float32 and flattened; the
    concatenation is stored as the ``.grad`` of ``master_params[0]``.

    :param model_params: an iterable of parameters whose ``.grad`` is set.
    :param master_params: the list returned by make_master_params().
    """
    float_grads = []
    for p in model_params:
        float_grads.append(p.grad.data.detach().float())
    master_params[0].grad = _flatten_dense_tensors(float_grads)
def master_params_to_model_params(model_params, master_params):
    """
    Write the values held by the master parameters back into the model.

    :param model_params: an iterable of model parameters to overwrite.
    :param master_params: the list returned by make_master_params().
    """
    # The parameters are iterated twice (once to unflatten, once to copy),
    # so a generator must be materialized first or nothing would be copied.
    targets = list(model_params)
    sources = unflatten_master_params(targets, master_params)
    for target, source in zip(targets, sources):
        target.detach().copy_(source)
def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.

    :param model_params: an iterable of tensors giving the target shapes;
                         generators are accepted.
    :param master_params: the list returned by make_master_params(); only
                          its first element is read.
    :return: a sequence of detached views of the master tensor, one per
             model parameter and shaped like it.
    """
    # Materialize first so a generator (e.g. ``model.parameters()``) works
    # here just as it does in master_params_to_model_params().
    shapes_like = list(model_params)
    return _unflatten_dense_tensors(master_params[0].detach(), shapes_like)
def zero_grad(model_params):
    """
    Reset, in place, the gradient of every parameter that has one.

    :param model_params: an iterable of parameters; entries whose ``.grad``
                         is None are skipped.
    """
    for p in model_params:
        grad = p.grad
        if grad is None:
            continue
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        grad.detach_()
        grad.zero_()