@@ -92,7 +92,14 @@ def denormalize(self, cond):
92
92
self .device )
93
93
return cond
94
94
95
- def sample (self , N , device ):
95
+ def sample (self , N , device , T = None , T_latent = None ):
96
+ if T is None :
97
+ sampler = self .eval_sampler
98
+ latent_sampler = self .latent_sampler
99
+ else :
100
+ sampler = self .conf ._make_diffusion_conf (T ).make_sampler ()
101
+ latent_sampler = self .conf ._make_latent_diffusion_conf (T_latent ).make_sampler ()
102
+
96
103
noise = torch .randn (N ,
97
104
3 ,
98
105
self .conf .img_size ,
@@ -102,26 +109,31 @@ def sample(self, N, device):
102
109
self .conf ,
103
110
self .ema_model ,
104
111
noise ,
105
- sampler = self . eval_sampler ,
106
- latent_sampler = self . eval_latent_sampler ,
112
+ sampler = sampler ,
113
+ latent_sampler = latent_sampler ,
107
114
conds_mean = self .conds_mean ,
108
115
conds_std = self .conds_std ,
109
116
)
110
117
pred_img = (pred_img + 1 ) / 2
111
118
return pred_img
112
119
113
- def render (self , noise , cond = None ):
120
+ def render (self , noise , cond = None , T = None ):
121
+ if T is None :
122
+ sampler = self .eval_sampler
123
+ else :
124
+ sampler = self .conf ._make_diffusion_conf (T ).make_sampler ()
125
+
114
126
if cond is not None :
115
127
pred_img = render_condition (self .conf ,
116
128
self .ema_model ,
117
129
noise ,
118
- sampler = self . eval_sampler ,
130
+ sampler = sampler ,
119
131
cond = cond )
120
132
else :
121
133
pred_img = render_uncondition (self .conf ,
122
134
self .ema_model ,
123
135
noise ,
124
- sampler = self . eval_sampler ,
136
+ sampler = sampler ,
125
137
latent_sampler = None )
126
138
pred_img = (pred_img + 1 ) / 2
127
139
return pred_img
@@ -132,9 +144,14 @@ def encode(self, x):
132
144
cond = self .ema_model .encoder .forward (x )
133
145
return cond
134
146
135
- def encode_stochastic (self , x , cond ):
136
- out = self .eval_sampler .ddim_reverse_sample_loop (
137
- self .ema_model , x , model_kwargs = {'cond' : cond })
147
+ def encode_stochastic (self , x , cond , T = None ):
148
+ if T is None :
149
+ sampler = self .eval_sampler
150
+ else :
151
+ sampler = self .conf ._make_diffusion_conf (T ).make_sampler ()
152
+ out = sampler .ddim_reverse_sample_loop (self .ema_model ,
153
+ x ,
154
+ model_kwargs = {'cond' : cond })
138
155
return out ['sample' ]
139
156
140
157
def forward (self , noise = None , x_start = None , ema_model : bool = False ):
0 commit comments