-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgumbel_rao.py
83 lines (65 loc) · 3.24 KB
/
gumbel_rao.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Implementation of the straight-through gumbel-rao estimator.
Paper: "Rao-Blackwellizing the Straight-Through Gumbel-Softmax
Gradient Estimator" <https://arxiv.org/abs/2010.04838>.
Note: the implementation here differs from the paper in that we DO NOT
propagate gradients through the conditional G_k | D
reparameterization. The paper states that:
Note that the total derivative d(softmax_τ(θ + Gk))/dθ is taken
through both θ and Gk. For the case K = 1, our estimator reduces to
the standard ST-GS estimator.
I believe that this is a mistake - the expectation of this estimator
is only equal to that of ST-GS if the derivative is *not* taken
through G_k, as the ST-GS estimator ∂f(D)/dD d(softmax_τ(θ + G))/dθ
does not.
With the derivative ignored through G_k, the value of this estimator
with k=1 is numerically equal to that of ST-GS, and as k->∞ the
estimator for any given outcome D converges to the expectation of
ST-GS over G conditional on D.
"""
import torch
@torch.no_grad()
def conditional_gumbel(logits, D, k=1):
    """Outputs k samples of Q = StandardGumbel(), such that argmax(logits
    + Q) is given by D (one hot vector).

    Uses the exponential-race construction. With E_j ~ Exp(1) iid, Z the
    partition function and i the class selected by D:
      * the chosen class gets   adjusted_i = log Z - log E_i
      * every other class gets  adjusted_j = -log(E_j / exp(logits_j) + E_i / Z)
    Then Q = adjusted - logits.

    Everything is evaluated in log space (logsumexp / logaddexp). The direct
    exp() form overflows to inf once |logits| exceeds the float range (about
    88 in float32). Masking by multiplication then produces 0 * inf = NaN.

    Args:
        logits: unnormalized log-probabilities over the last axis.
        D: one-hot tensor with the same shape as logits marking the class
            that must be the argmax.
        k: number of independent samples to draw.

    Returns:
        Tensor of shape (k, *logits.shape) holding the Gumbel noise Q.
    """
    # iid. exponential, one independent set per sample: (k, *logits.shape)
    E = torch.distributions.exponential.Exponential(
        rate=torch.ones_like(logits)).sample([k])
    # E of the chosen class; taking the log after the masked sum avoids
    # 0 * log(0) = NaN should any unchosen E_j be exactly zero.
    log_Ei = torch.log((D * E).sum(dim=-1, keepdim=True))
    # log of the partition function (normalization constant), overflow-safe
    log_Z = torch.logsumexp(logits, dim=-1, keepdim=True)
    # Chosen class: -log(E_i) + log(Z)
    chosen = log_Z - log_Ei
    # Other classes: -log(E_j / exp(logits_j) + E_i / Z), in log space
    others = -torch.logaddexp(torch.log(E) - logits, log_Ei - log_Z)
    # Select per class instead of multiplying by D / (1 - D), so an infinite
    # value in the unused branch cannot turn into NaN.
    adjusted = torch.where(D.bool(), chosen, others)
    return adjusted - logits
def exact_conditional_gumbel(logits, D, k=1):
    """Same as conditional_gumbel but uses rejection sampling.

    Draws k independent samples of Q = StandardGumbel() conditioned on
    argmax(logits + Q) equal to argmax(D) along the last axis. Every row
    (each position along the leading axes of logits) is accepted or redrawn
    independently, so batched logits are supported.

    The acceptance rate for a row is the probability of the forced class,
    so rare classes need many redraws. A class with probability zero never
    terminates. Useful as a slow reference for checking conditional_gumbel.

    Args:
        logits: unnormalized log-probabilities over the last axis.
        D: one-hot tensor with the same shape as logits.
        k: number of independent samples to draw.

    Returns:
        Tensor of shape (k, *logits.shape).
    """
    idx = D.argmax(dim=-1)

    def draw():
        # Standard Gumbel via inverse CDF: -log(-log(U)), U ~ Uniform[0, 1).
        return torch.rand_like(logits).log().neg().log().neg()

    def mismatched(gumbel):
        # Per-row test; the previous flattened argmax() broke for batch > 1.
        return (logits + gumbel).argmax(dim=-1) != idx

    samples = []
    for _ in range(k):
        gumbel = draw()
        reject = mismatched(gumbel)
        while reject.any():
            # Redraw only the failing rows; accepted rows keep their sample.
            gumbel = torch.where(reject.unsqueeze(-1), draw(), gumbel)
            reject = mismatched(gumbel)
        samples.append(gumbel)
    return torch.stack(samples)
def replace_gradient(value, surrogate):
    """Returns `value` but backpropagates gradients through `surrogate`.

    The forward result is surrogate + (value - surrogate), which equals
    `value` up to floating-point rounding. Only `surrogate` contributes to
    the autograd graph, because the correction term is detached.
    """
    # Constant (gradient-free) offset that shifts surrogate onto value.
    offset = (value - surrogate).detach()
    return surrogate + offset
def gumbel_rao(logits, k, temp=1.0, I=None):
    """Returns a categorical sample from logits (over axis=-1) as a
    one-hot vector, with gumbel-rao gradient.

    k: integer number of samples to use in the rao-blackwellization.
        1 sample reduces to straight-through gumbel-softmax.
    temp: softmax temperature applied to the perturbed logits in the
        surrogate.
    I: optional, categorical sample to use instead of drawing a new
        sample. Should be a tensor(shape=logits.shape[:-1], dtype=int64).

    The forward value is the one-hot sample D (up to rounding). The gradient
    w.r.t. logits is that of the mean over k draws of
    softmax((logits + G) / temp), where each G is drawn conditionally on D.
    G itself is treated as a constant (it is sampled under no_grad).
    """
    num_classes = logits.shape[-1]
    if I is None:
        I = torch.distributions.categorical.Categorical(logits=logits).sample()
    # Build the one-hot in the logits' dtype. A hard-coded float32 mask would
    # mix with float64/half logits and promote the returned tensor.
    D = torch.nn.functional.one_hot(I, num_classes).to(logits.dtype)
    # (k, *logits.shape): logits perturbed by Gumbel noise consistent with D;
    # gradients flow only through `logits`.
    adjusted = logits + conditional_gumbel(logits, D, k=k)
    # Rao-Blackwellized surrogate: average the k relaxed samples.
    surrogate = torch.nn.functional.softmax(adjusted / temp, dim=-1).mean(dim=0)
    return replace_gradient(D, surrogate)
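# Sanity check: with class 1 forced to be the argmax, the per-class standard
# deviations of 10000 samples from the rejection sampler and from the
# closed-form sampler agree to within sampling noise.
# >>> exact_conditional_gumbel(torch.tensor([[1.0,2.0, 3.0]]), torch.tensor([[0.0, 1.0, 0.0]]), k=10000).std(dim=0)
# tensor([[0.9952, 1.2695, 0.8132]])
# >>> conditional_gumbel(torch.tensor([[1.0,2.0, 3.0]]), torch.tensor([[0.0, 1.0, 0.0]]), k=10000).std(dim=0)
# tensor([[0.9905, 1.2951, 0.8148]])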