vae_gaussian.py
import torch
from torch.nn import Module

from .common import *
from .encoders import *
from .diffusion import *


class GaussianVAE(Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        # Encoder maps a point cloud to the mean and log-variance of a latent Gaussian.
        self.encoder = PointNetEncoder(args.latent_dim)
        # Latent-conditioned diffusion model serves as the decoder.
        self.diffusion = DiffusionPoint(
            net=PointwiseNet(point_dim=3, context_dim=args.latent_dim, residual=args.residual),
            var_sched=VarianceSchedule(
                num_steps=args.num_steps,
                beta_1=args.beta_1,
                beta_T=args.beta_T,
                mode=args.sched_mode,
            ),
        )
    def get_loss(self, x, writer=None, it=None, kl_weight=1.0):
        """
        Args:
            x:  Input point clouds, (B, N, d).
            writer:  Optional TensorBoard SummaryWriter for logging.
            it:  Current iteration, used as the logging step.
            kl_weight:  Weight of the prior (KL) term in the total loss.
        """
        batch_size, _, _ = x.size()
        # Encode to a latent Gaussian and sample via the reparameterization trick.
        z_mu, z_sigma = self.encoder(x)  # z_sigma is the log-variance.
        z = reparameterize_gaussian(mean=z_mu, logvar=z_sigma)  # (B, F)

        # KL(q(z|x) || N(0, I)) estimated as E_q[-log p(z)] - H(q):
        # a single-sample cross-entropy term plus the analytic entropy.
        log_pz = standard_normal_logprob(z).sum(dim=1)  # (B, ), independence assumption
        entropy = gaussian_entropy(logvar=z_sigma)      # (B, )
        loss_prior = (-log_pz - entropy).mean()

        # Reconstruction term: denoising loss of the diffusion decoder conditioned on z.
        loss_recons = self.diffusion.get_loss(x, z)

        loss = kl_weight * loss_prior + loss_recons

        if writer is not None:
            writer.add_scalar('train/loss_entropy', -entropy.mean(), it)
            writer.add_scalar('train/loss_prior', -log_pz.mean(), it)
            writer.add_scalar('train/loss_recons', loss_recons, it)

        return loss
    def sample(self, z, num_points, flexibility, truncate_std=None):
        """
        Args:
            z:  Input latents, standard normal samples with mean=0, std=1, (B, F).
            num_points:  Number of points to generate per cloud.
            flexibility:  Variance flexibility of the reverse diffusion process.
            truncate_std:  If given, resample z from a truncated normal with this bound.
        """
        if truncate_std is not None:
            z = truncated_normal_(z, mean=0, std=1, trunc_std=truncate_std)
        # Decode: run the reverse diffusion process conditioned on the latent code.
        samples = self.diffusion.sample(num_points, context=z, flexibility=flexibility)
        return samples