forked from phizaz/diffae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunet_autoenc.py
283 lines (238 loc) · 8.96 KB
/
unet_autoenc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
from enum import Enum
import torch
from torch import Tensor
from torch.nn.functional import silu
from .latentnet import *
from .unet import *
from choices import *
@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    """Configuration for ``BeatGANsAutoencModel``.

    Extends the UNet config with options for the semantic encoder that the
    autoencoder model builds. Encoder options left as ``None`` fall back to
    the corresponding UNet option when the model is constructed.
    """
    # number of style channels (the encoder's output width)
    enc_out_channels: int = 512
    # attention resolutions for the encoder; None -> attention_resolutions
    enc_attn_resolutions: Tuple[int, ...] = None
    # pooling type passed to the encoder as ``pool``
    enc_pool: str = 'depthconv'
    # residual blocks per encoder level (passed as ``num_res_blocks``)
    enc_num_res_block: int = 2
    # encoder channel multipliers; None -> channel_mult
    enc_channel_mult: Tuple[int, ...] = None
    # gradient checkpointing for the encoder (OR-ed with use_checkpoint)
    enc_grad_checkpoint: bool = False
    # optional latent-network config; the model only builds ``latent_net``
    # when this is not None
    latent_net_conf: MLPSkipNetConfig = None

    def make_model(self):
        """Build a ``BeatGANsAutoencModel`` from this config."""
        return BeatGANsAutoencModel(self)
class BeatGANsAutoencModel(BeatGANsUNetModel):
    """UNet conditioned on a semantic vector ``cond`` from a separate encoder.

    ``__init__`` replaces the parent's ``time_embed`` with a
    ``TimeStyleSeperateEmbed`` and builds ``self.encoder`` from a
    ``BeatGANsEncoderConfig`` with ``use_time_condition=False``, i.e. the
    encoder does not see the timestep.
    """

    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # having only time, cond
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        # Semantic encoder. Encoder-specific options that are unset fall back
        # to the decoder's own settings via ``or``.
        self.encoder = BeatGANsEncoderConfig(
            image_size=conf.image_size,
            in_channels=conf.in_channels,
            model_channels=conf.model_channels,
            out_hid_channels=conf.enc_out_channels,
            out_channels=conf.enc_out_channels,
            num_res_blocks=conf.enc_num_res_block,
            attention_resolutions=(conf.enc_attn_resolutions
                                   or conf.attention_resolutions),
            dropout=conf.dropout,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
            use_time_condition=False,
            conv_resample=conf.conv_resample,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            resblock_updown=conf.resblock_updown,
            use_new_attention_order=conf.use_new_attention_order,
            pool=conf.enc_pool,
        ).make_model()

        # ``latent_net`` is only created when a latent-net config is given.
        if conf.latent_net_conf is not None:
            self.latent_net = conf.latent_net_conf.make_model()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        assert self.conf.is_stochastic
        # logvar is a log-variance, so std = exp(logvar / 2)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample_z(self, n: int, device):
        """Draw ``n`` standard-normal vectors of width ``enc_out_channels``."""
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        """Predict ``cond`` from noise. Not supported by this model.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def encode(self, x):
        """Run the semantic encoder on ``x``.

        Returns:
            dict with a single key ``'cond'`` holding the encoder output.
        """
        cond = self.encoder.forward(x)
        return {'cond': cond}

    def _iter_resblocks(self):
        """Yield every ResBlock in input, middle and output blocks, in order."""
        for part in (self.input_blocks, self.middle_block,
                     self.output_blocks):
            for module in part.modules():
                if isinstance(module, ResBlock):
                    yield module

    @property
    def stylespace_sizes(self):
        """Output width of each ResBlock's last ``cond_emb_layers`` layer.

        Listed in the same order in which ``encode_stylespace`` concatenates
        the per-block style vectors.
        """
        return [
            block.cond_emb_layers[-1].weight.shape[0]
            for block in self._iter_resblocks()
        ]

    def encode_stylespace(self, x, return_vector: bool = True):
        """Encode ``x`` to style space.

        Runs the encoder on ``x`` and passes its output through the
        ``cond_emb_layers`` of every ResBlock.

        Args:
            x: encoder input.
            return_vector: if True, concatenate the per-block outputs along
                dim 1; otherwise return them as a list.
        """
        # (n, c)
        cond = self.encoder.forward(x)
        # one (n, c') tensor per ResBlock
        styles = [
            block.cond_emb_layers.forward(cond)
            for block in self._iter_resblocks()
        ]
        if return_vector:
            # (n, sum_c)
            return torch.cat(styles, dim=1)
        return styles

    def forward(self,
                x,
                t,
                y=None,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x: UNet input; may be None, in which case the input/middle blocks
                are skipped and no lateral connections are used (happens when
                training only the autoencoder).
            t: timesteps for the UNet blocks; may be None.
            y: class labels; class-conditional models are not implemented.
            x_start: the original image to encode; required when ``cond`` is
                not given.
            cond: output of the encoder; computed from ``x_start`` if None.
            style: optional style override (not consumed further below).
            noise: random noise (to predict the cond); unsupported, see
                ``noise_to_cond``.
            t_cond: timesteps for the cond embedding; defaults to ``t``.
            **kwargs: ignored.

        Returns:
            AutoencReturn(pred=<UNet output>, cond=<conditioning used>).

        Raises:
            TypeError: if both ``cond`` and ``x_start`` are None.
            NotImplementedError: for ``noise``, for models without
                ``resnet_two_cond``, and for class-conditional models.
        """
        if t_cond is None:
            t_cond = t

        if noise is not None:
            # if the noise is given, we predict the cond from noise
            cond = self.noise_to_cond(noise)

        if cond is None:
            if x_start is None:
                raise TypeError('x_start is required when cond is not given')
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
            cond = self.encode(x_start)['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            # this happens when training only autoenc
            _t_emb = None
            _t_cond_emb = None

        # Only the two-condition ResBlock variant is supported.
        if not self.conf.resnet_two_cond:
            raise NotImplementedError()
        res = self.time_embed.forward(
            time_emb=_t_emb,
            cond=cond,
            time_cond_emb=_t_cond_emb,
        )
        # two cond: first = time emb, second = cond_emb
        emb = res.time_emb
        cond_emb = res.emb

        # override the style if given; ``style or ...`` would evaluate the
        # truthiness of a Tensor, which raises for multi-element tensors
        style = style if style is not None else res.style

        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        # where in the model to supply time conditions
        enc_time_emb = emb
        mid_time_emb = emb
        dec_time_emb = emb
        # where in the model to supply style conditions
        enc_cond_emb = cond_emb
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        # skip activations, one list per resolution level
        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)

            # input blocks
            k = 0
            for i, n_blocks in enumerate(self.input_num_blocks):
                for _ in range(n_blocks):
                    h = self.input_blocks[k](h,
                                             emb=enc_time_emb,
                                             cond=enc_cond_emb)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
        else:
            # no lateral connections
            # happens when training only the autoencoder
            h = None

        # output blocks
        k = 0
        for i, n_blocks in enumerate(self.output_num_blocks):
            for _ in range(n_blocks):
                # take the lateral connection from the same level (in
                # reverse); once there is no more, use None
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h,
                                          emb=dec_time_emb,
                                          cond=dec_cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)
        return AutoencReturn(pred=pred, cond=cond)
class AutoencReturn(NamedTuple):
    """Result of ``BeatGANsAutoencModel.forward``."""
    # output of the model's final ``self.out`` head
    pred: Tensor
    # conditioning vector used to produce ``pred``; defaults to None
    cond: Tensor = None
class EmbedReturn(NamedTuple):
    """Embeddings produced by a time/style embedding module.

    All fields default to None.
    """
    # style and time (TimeStyleSeperateEmbed sets this to the style alone)
    emb: Tensor = None
    # time only
    time_emb: Tensor = None
    # style only (but could depend on time)
    style: Tensor = None
class TimeStyleSeperateEmbed(nn.Module):
    """Embed the timestep with a small MLP; pass the style through unchanged.

    Args:
        time_channels: width of the incoming timestep embedding.
        time_out_channels: width of the embedded timestep.
    """

    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        mlp_layers = [
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        ]
        self.time_embed = nn.Sequential(*mlp_layers)
        # identity map: the given ``cond`` is used directly as the style
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        """Embed ``time_emb`` (if any) and return it with the style.

        Args:
            time_emb: timestep embedding, or None (autoencoder-only training);
                a None input yields a None ``time_emb`` in the result.
            cond: conditioning vector, returned unchanged as the style.
            **kwargs: ignored.

        Returns:
            EmbedReturn whose ``emb`` and ``style`` are both the style.
        """
        embedded_time = None if time_emb is None else self.time_embed(time_emb)
        style_vec = self.style(cond)
        return EmbedReturn(emb=style_vec, time_emb=embedded_time,
                           style=style_vec)