Skip to content

Commit a28c246

Browse files
committed
update
1 parent 3671aff commit a28c246

4 files changed

+56
-29
lines changed

config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def make_T_sampler(self):
269269
raise NotImplementedError()
270270

271271
def make_diffusion_conf(self):
272-
return self._make_diffusion_conf(T=self.T)
272+
return self._make_diffusion_conf(self.T)
273273

274274
def make_eval_diffusion_conf(self):
275275
return self._make_diffusion_conf(T=self.T_eval)

experiment.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,14 @@ def denormalize(self, cond):
9292
self.device)
9393
return cond
9494

95-
def sample(self, N, device):
95+
def sample(self, N, device, T=None, T_latent=None):
96+
if T is None:
97+
sampler = self.eval_sampler
98+
latent_sampler = self.latent_sampler
99+
else:
100+
sampler = self.conf._make_diffusion_conf(T).make_sampler()
101+
latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler()
102+
96103
noise = torch.randn(N,
97104
3,
98105
self.conf.img_size,
@@ -102,26 +109,31 @@ def sample(self, N, device):
102109
self.conf,
103110
self.ema_model,
104111
noise,
105-
sampler=self.eval_sampler,
106-
latent_sampler=self.eval_latent_sampler,
112+
sampler=sampler,
113+
latent_sampler=latent_sampler,
107114
conds_mean=self.conds_mean,
108115
conds_std=self.conds_std,
109116
)
110117
pred_img = (pred_img + 1) / 2
111118
return pred_img
112119

113-
def render(self, noise, cond=None):
120+
def render(self, noise, cond=None, T=None):
121+
if T is None:
122+
sampler = self.eval_sampler
123+
else:
124+
sampler = self.conf._make_diffusion_conf(T).make_sampler()
125+
114126
if cond is not None:
115127
pred_img = render_condition(self.conf,
116128
self.ema_model,
117129
noise,
118-
sampler=self.eval_sampler,
130+
sampler=sampler,
119131
cond=cond)
120132
else:
121133
pred_img = render_uncondition(self.conf,
122134
self.ema_model,
123135
noise,
124-
sampler=self.eval_sampler,
136+
sampler=sampler,
125137
latent_sampler=None)
126138
pred_img = (pred_img + 1) / 2
127139
return pred_img
@@ -132,9 +144,14 @@ def encode(self, x):
132144
cond = self.ema_model.encoder.forward(x)
133145
return cond
134146

135-
def encode_stochastic(self, x, cond):
136-
out = self.eval_sampler.ddim_reverse_sample_loop(
137-
self.ema_model, x, model_kwargs={'cond': cond})
147+
def encode_stochastic(self, x, cond, T=None):
148+
if T is None:
149+
sampler = self.eval_sampler
150+
else:
151+
sampler = self.conf._make_diffusion_conf(T).make_sampler()
152+
out = sampler.ddim_reverse_sample_loop(self.ema_model,
153+
x,
154+
model_kwargs={'cond': cond})
138155
return out['sample']
139156

140157
def forward(self, noise=None, x_start=None, ema_model: bool = False):

install_requirements_for_colab.sh

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
!pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 pytorch-lightning==1.2.2 torchtext==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
2+
!pip install scipy==1.5.4
3+
!pip install numpy==1.19.5
4+
!pip install tqdm
5+
!pip install pytorch-fid==0.2.0
6+
!pip install pandas==1.1.5
7+
!pip install lpips==0.1.4
8+
!pip install lmdb==1.2.1
9+
!pip install ftfy
10+
!pip install regex
11+
!pip install dlib requests

manipulate.ipynb

+18-19
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)