# infogain_selector.py
import numpy as np
from scipy.special import gammaln, psi
from scipy.stats import dirichlet

'''
Legacy numerical solution.
Should not be used, as it is probably broken.


def make_range(*x):
    """
    Constructs the leftover integration range for the simplex given the
    first k entries: (0, x_k) = (0, 1 - (x_1 + ... + x_(k-1))).
    """
    return (0, max(0, 1 - sum(x)))


def relative_entropy(p, q):
    """
    Relative entropy of the two given Dirichlet distributions.
    """

    def tmp(*x):
        """
        First appends the always-forced last entry to the input
        (x_(N+1) = 1 - (x_1 + ... + x_N)), then computes the relative
        entropy of posterior and prior at that point.
        """
        x_new = np.append(x, 1 - sum(x))
        return p(x_new) * np.log2(p(x_new) / q(x_new))

    return tmp


def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
    # Sample points from the simplex by sorting uniform draws (padded with 0
    # and 1) and taking consecutive differences.
    s = np.random.rand(dim - 1, samples)
    s = np.sort(np.concatenate((np.zeros((1, samples)), s, np.ones((1, samples)))), 0)
    pos = np.diff(s, axis=0)
    res = fun(pos)
    return np.mean(res)


def infogain(a_post, a_prior):
    raise RuntimeError(
        """For the love of god, don't use this:
        it's insanely poorly conditioned, the worst numerical code I have ever written,
        and it's slow as molasses. Use the analytic solution instead.
        Maybe remove.
        """
    )
    args = len(a_prior)
    p = dirichlet(a_post).pdf
    q = dirichlet(a_prior).pdf
    # nquad would need `from scipy.integrate import nquad`.
    (info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
    # info = naive_monte_carlo_integral(relative_entropy(p, q), len(a_post))
    return info
'''


def analytic_solution(a_post, a_prior):
    """
    Analytic solution to the KL divergence between two Dirichlet distributions
    (in nats). Proof is in the Notion design doc. With a0 = sum(a_post) and
    b0 = sum(a_prior):

        KL(Dir(a_post) || Dir(a_prior))
            = gammaln(a0) - gammaln(b0)
              - sum(gammaln(a_post)) + sum(gammaln(a_prior))
              + sum((a_post - a_prior) * (psi(a_post) - psi(a0)))
    """
    post_sum = np.sum(a_post)
    prior_sum = np.sum(a_prior)
    info = (
        gammaln(post_sum)
        - gammaln(prior_sum)
        - np.sum(gammaln(a_post))
        + np.sum(gammaln(a_prior))
        + np.sum((a_post - a_prior) * (psi(a_post) - psi(post_sum)))
    )
    return info
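

# A minimal Monte Carlo sanity-check sketch, not part of the original module:
# it estimates the same KL divergence by sampling and should roughly match
# analytic_solution. The helpers `_log_dirichlet_pdf` and `_mc_kl_check` are
# illustrative names introduced here, not from the original code.
def _log_dirichlet_pdf(x, a):
    # Log-density of Dirichlet(a) at x, with components of x along the last axis.
    return gammaln(np.sum(a)) - np.sum(gammaln(a)) + np.sum((a - 1) * np.log(x), axis=-1)


def _mc_kl_check(a_post, a_prior, n=100_000, seed=0):
    # KL(p || q) = E_p[log p(x) - log q(x)], estimated with samples from p.
    rng = np.random.default_rng(seed)
    xs = rng.dirichlet(np.asarray(a_post, dtype=float), size=n)
    return np.mean(_log_dirichlet_pdf(xs, a_post) - _log_dirichlet_pdf(xs, a_prior))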


def uniform_expected_infogain(a_prior):
    """
    Expected KL divergence between posterior and prior after one additional
    observation, where the next observation lands in category i with
    probability equal to the prior mean weight of i.
    """
    mean_weight = dirichlet.mean(a_prior)
    results = []
    for i, w in enumerate(mean_weight):
        # Posterior after observing one extra count in category i.
        a_post = a_prior.copy()
        a_post[i] = a_post[i] + 1
        results.append(w * analytic_solution(a_post, a_prior))
    return np.sum(results)
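

# Illustrative usage sketch, an assumption about the intended use rather than
# code from the original repo: given Dirichlet vote counts for several
# candidate items, pick the one whose next vote is expected to be most
# informative.
def _pick_most_informative(priors):
    gains = [uniform_expected_infogain(a) for a in priors]
    return int(np.argmax(gains))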


if __name__ == "__main__":
    a_prior = np.array([1, 1, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    a_post = np.array([1, 1, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    print("algebraic", analytic_solution(a_post, a_prior))
    # print("raw", infogain(a_post, a_prior))
    print("large infogain", uniform_expected_infogain(a_prior))
    print("post infogain", uniform_expected_infogain(a_post))
    # A concentrated prior should yield a small expected infogain:
    # a_prior = np.array([1, 1, 1000])
    # print("small infogain", uniform_expected_infogain(a_prior))