Skip to content

Commit 45a62da

Browse files
committed
top_k
1 parent d63589d commit 45a62da

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

labml_nn/sampling/experiment.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from labml_nn.sampling import Sampler
88
from labml_nn.sampling.greedy import GreedySampler
99
from labml_nn.sampling.temperature import TemperatureSampler
10+
from labml_nn.sampling.top_k import TopKSampler
1011
from labml_nn.sampling.utils import get_model_dataset
1112

1213

@@ -39,13 +40,20 @@ def main():
3940
model, ds = get_model_dataset('074d4004cc6b11ecad7a0242ac1c0002')
4041
model.eval()
4142

42-
# main(GreedySampler(), 16, 16, 128, 'It is')
43-
with monit.section('temperature=1.'):
44-
sample(model, ds, TemperatureSampler(1.), 4, 32, 128, 'It is')
45-
with monit.section('temperature=.1'):
46-
sample(model, ds, TemperatureSampler(.1), 4, 32, 128, 'It is')
47-
with monit.section('temperature=10.'):
48-
sample(model, ds, TemperatureSampler(10.), 4, 32, 128, 'It is')
43+
with monit.section('greedy'):
44+
sample(model, ds, GreedySampler(), 4, 32, 128, 'It is')
45+
#
46+
# with monit.section('temperature=1.'):
47+
# sample(model, ds, TemperatureSampler(1.), 4, 32, 128, 'It is')
48+
# with monit.section('temperature=.1'):
49+
# sample(model, ds, TemperatureSampler(.1), 4, 32, 128, 'It is')
50+
# with monit.section('temperature=10.'):
51+
# sample(model, ds, TemperatureSampler(10.), 4, 32, 128, 'It is')
52+
53+
with monit.section('top_k=5'):
54+
sample(model, ds, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, 'It is')
55+
56+
4957

5058

5159
if __name__ == '__main__':

labml_nn/sampling/top_k.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import torch
2+
from torch import nn
3+
from torch.distributions import Categorical
4+
5+
from labml_nn.sampling import Sampler
6+
7+
8+
class TopKSampler(Sampler):
9+
def __init__(self, k: int, sampler: Sampler):
10+
self.k = k
11+
self.sampler = sampler
12+
13+
def __call__(self, logits: torch.Tensor):
14+
zeros = logits.new_ones(logits.shape) * float('-inf')
15+
values, indices = torch.topk(logits, self.k, dim=-1)
16+
zeros.scatter_(-1, indices, values)
17+
18+
return self.sampler(zeros)

0 commit comments

Comments
 (0)