Commit 3dbe0ae
Authored by MattAlexMiracle (Alexander Mattick) and co-authors

Implement task selection (LAION-AI#383)

* commented out legacy numerical solver
* added comments and task_scheduling for selecting which task to serve to users
* removed standalone task weighting
* pre-commit hook rerun

Co-authored-by: Alexander Mattick <alex.mattick@fau.de>

1 parent 8942194 commit 3dbe0ae

File tree

3 files changed: +102 -24 lines


scripts/postprocessing/infogain_selector.py: 21 additions, 20 deletions

@@ -1,9 +1,11 @@
 import numpy as np
-from scipy import log2
-from scipy.integrate import nquad
 from scipy.special import gammaln, psi
 from scipy.stats import dirichlet

+'''
+Legacy numerical solution.
+Should not be used as it is probably broken
+

 def make_range(*x):
     """
@@ -38,6 +40,23 @@ def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
     res = fun(pos)
     return np.mean(res)

+def infogain(a_post, a_prior):
+    raise (
+        """For the love of good don't use this:
+        it's insanely poorly conditioned, the worst numerical code I have ever written
+        and it's slow as molasses. Use the analytic solution instead.
+
+        Maybe remove
+        """
+    )
+    args = len(a_prior)
+    p = dirichlet(a_post).pdf
+    q = dirichlet(a_prior).pdf
+    (info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
+    # info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
+    return info
+'''
+

 def analytic_solution(a_post, a_prior):
     """
@@ -57,26 +76,8 @@ def analytic_solution(a_post, a_prior):
     return info


-def infogain(a_post, a_prior):
-    raise (
-        """For the love of good don't use this:
-        it's insanely poorly conditioned, the worst numerical code I have ever written
-        and it's slow as molasses. Use the analytic solution instead.
-
-        Maybe remove
-        """
-    )
-    args = len(a_prior)
-    p = dirichlet(a_post).pdf
-    q = dirichlet(a_prior).pdf
-    (info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
-    # info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
-    return info
-
-
 def uniform_expected_infogain(a_prior):
     mean_weight = dirichlet.mean(a_prior)
-    print("weight", mean_weight)
     results = []
     for i, w in enumerate(mean_weight):
         a_post = a_prior.copy()
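
The analytic route this file keeps has a well-known closed form: the KL divergence between two Dirichlet distributions can be evaluated directly with gammaln and psi, the two special functions the imports retain. A minimal sketch of that formula, assuming analytic_solution computes this standard expression (its body is not shown in the hunk); the name dirichlet_kl and the example values are illustrative, not the commit's:

import numpy as np
from scipy.special import gammaln, psi


def dirichlet_kl(a_post, a_prior):
    # Closed-form KL(Dir(a_post) || Dir(a_prior)); no quadrature required.
    a_post = np.asarray(a_post, dtype=float)
    a_prior = np.asarray(a_prior, dtype=float)
    s_post, s_prior = a_post.sum(), a_prior.sum()
    return (
        gammaln(s_post) - gammaln(a_post).sum()
        - gammaln(s_prior) + gammaln(a_prior).sum()
        + np.sum((a_post - a_prior) * (psi(a_post) - psi(s_post)))
    )


# One extra observation of the first class on top of a uniform prior:
print(dirichlet_kl([2.0, 1.0, 1.0], [1.0, 1.0, 1.0]))  # ~0.265 nats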

scripts/postprocessing/scoring.py: 6 additions, 4 deletions

@@ -87,8 +87,9 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
     """
     This function returns the gain of points for a given prompt's votes

-    This function is only to be run when archiving a question
-    i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
+    In contrast to the other score updating functions, we can run this online as new votes come in.
+    i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information.
+

     Parameters:
     consensus (ArrayLike): all votes cast for this question
@@ -100,7 +101,8 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
     # produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
     # since 100 is the lowest, 300 the highest and 200 the middle value
     consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1
-    delta_votes = np.sum(consensus_ranking * consensus)
+    # expected consenus ranking (i.e. normalize the votes and multiply-sum with weightings)
+    delta_votes = np.sum(consensus_ranking * consensus / sum(consensus))
     new_points = delta_votes + voter_data.prompt_points

     # we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
@@ -133,7 +135,7 @@ def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.Arr
     "research design and statistical analyses, second edition, 2003"
     the authors note that at least from an significance test POV they will yield the same p-values

-    Parameters:
+    Parameters:
     user_ranking (ArrayLike): ranking produced by the user
     consensus (ArrayLike): ranking produced after running the voting algorithm to merge into the consensus ranking
     voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
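
To see what the added normalization in score_update_prompts changes, here is a quick illustration with hypothetical vote counts (the Voter bookkeeping is omitted):

import numpy as np

consensus = np.array([100, 300, 200])  # hypothetical vote tallies for one question
consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1  # [0, 1, 2]

old_delta = np.sum(consensus_ranking * consensus)                    # 700: grows with raw vote volume
new_delta = np.sum(consensus_ranking * consensus / consensus.sum())  # ~1.17: an expectation over votes
print(old_delta, new_delta)

Dividing by the vote total turns the weighted sum into an expected ranking value, so a prompt's point gain no longer scales with how many people happened to vote on it.
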
New file: 75 additions, 0 deletions

@@ -0,0 +1,75 @@
+from enum import Enum
+
+import numpy as np
+from scipy import optimize
+
+
+class Task(Enum):
+    RANKING = 0
+    ANSWER = 1
+    PROMPT = 2
+    VOTE = 3
+
+
+def task_selection(
+    num_ranking_tasks: int, current_prompts: int, target_num_prompts: int, p: float, answers_per_prompt: int
+) -> Task:
+    """
+    This computes which task to serve to the user.
+    In general, this method aims to get rankable tasks out of the active pool ASAP.
+    Before checking anything else, we first have a p% probability of running a ranking task.
+    After that, we can dynamically determine which task to serve by balancing the number of active tasks.
+
+    Parameters:
+    num_ranking_tasks (int): number of prompts that are ready to do ranking (i.e. have "answers_per_prompt" many answers)
+    current_prompts (int): how many prompts are currently in the active pool
+    target_num_prompts (int): how many prompts _should_ be in the active pool
+    p (float): probability to serve a ranking task, if one is available
+    answers_per_prompt (int): number of answers we want to have per prompt
+    Returns:
+    task (Task): the task Enum that corresponds to one of the four tasks
+    """
+    if num_ranking_tasks > 0 and np.random.rand() < p:
+        return Task.RANKING
+    rate = 50 / (current_prompts * 2)
+    prob_prompt_task = 0.5 + (target_num_prompts - current_prompts) * rate
+    # Yes, I'm too lazy to solve this analytically...
+    prob_unfinished_prompt = optimize.linprog(
+        np.array([1, 1]), A_eq=np.array([[1, 1], [1, -answers_per_prompt]]), b_eq=np.array([1, 0]), bounds=(0, None)
+    ).x[0]
+    if np.random.rand() < prob_prompt_task:
+        if np.random.rand() < prob_unfinished_prompt:
+            return Task.ANSWER
+        else:
+            return Task.PROMPT
+    else:
+        return Task.VOTE
+
+
+def next_answer_task(possible_prompts, answers_per_prompt):
+    """
+    If the `task_selection` method returns "answer", you can use this method to decide which
+    prompt should get an answer next.
+    The goal of this is to finish off the prompts that have almost enough answers collected already:
+    i.e. if we want 5 answers, this is going to give preferential sampling to those prompts that already
+    have 4/5 answers.
+    This helps to not have too many close-to-finished prompts in the active set.
+
+    Parameters:
+    possible_prompts (dict[prompt_id, num_answers]): a dictionary containing all open prompts and the number of answers these prompts currently have.
+    answers_per_prompt (int): number of answers to target per prompt
+    Returns:
+    prompt_id (int): the prompt_id corresponding to the next prompt that should get a new answer
+    """
+    nums = list(set(possible_prompts.values()))
+    p = np.array([max(x / answers_per_prompt, 1 / answers_per_prompt) for x in nums])
+    idx = np.random.choice(nums, p=p / p.sum())
+    sample = np.random.choice([k for k, v in possible_prompts.items() if v == idx])
+    return sample
+
+
+if __name__ == "__main__":
+    x = task_selection(1, 500, 1000, 0.1, 5)
+    print(x)
+    y = next_answer_task({"this": 2, "is": 4, "a": 1, "test": 4}, 5)
+    print(y)
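
The linear program in task_selection has a simple closed form: its two equality constraints are x + y = 1 and x = answers_per_prompt * y, which pin down x = answers_per_prompt / (answers_per_prompt + 1). A quick check of that reading (illustrative, not part of the commit):

import numpy as np
from scipy import optimize

a = 5  # answers_per_prompt
lp = optimize.linprog(
    np.array([1, 1]), A_eq=np.array([[1, 1], [1, -a]]), b_eq=np.array([1, 0]), bounds=(0, None)
)
# x + y = 1 and x - a*y = 0  =>  x = a / (a + 1)
assert np.isclose(lp.x[0], a / (a + 1))  # 5/6 for five answers per prompt

So with five answers per prompt, a user who lands in the prompt branch is served an answer task five times out of six.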

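The sampling weights in next_answer_task are equally easy to sanity-check: each distinct answer count x gets weight max(x, 1) / answers_per_prompt before normalization, so nearly finished count classes dominate the draw. A small illustration, reusing the hypothetical pool from the __main__ block:

import numpy as np

possible_prompts = {"this": 2, "is": 4, "a": 1, "test": 4}
answers_per_prompt = 5

nums = list(set(possible_prompts.values()))  # distinct answer counts, here {1, 2, 4}
p = np.array([max(x / answers_per_prompt, 1 / answers_per_prompt) for x in nums])
print(dict(zip(nums, p / p.sum())))  # a 4/5-finished class is drawn 4x as often as a 1/5 one

Note that the draw is over count classes rather than individual prompts; a concrete prompt is then chosen uniformly from the selected class.
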