|
7 | 7 | from labml_nn.sampling import Sampler |
8 | 8 | from labml_nn.sampling.greedy import GreedySampler |
9 | 9 | from labml_nn.sampling.temperature import TemperatureSampler |
| 10 | +from labml_nn.sampling.top_k import TopKSampler |
10 | 11 | from labml_nn.sampling.utils import get_model_dataset |
11 | 12 |
|
12 | 13 |
|
@@ -39,13 +40,20 @@ def main(): |
39 | 40 | model, ds = get_model_dataset('074d4004cc6b11ecad7a0242ac1c0002') |
40 | 41 | model.eval() |
41 | 42 |
|
42 | | - # main(GreedySampler(), 16, 16, 128, 'It is') |
43 | | - with monit.section('temperature=1.'): |
44 | | - sample(model, ds, TemperatureSampler(1.), 4, 32, 128, 'It is') |
45 | | - with monit.section('temperature=.1'): |
46 | | - sample(model, ds, TemperatureSampler(.1), 4, 32, 128, 'It is') |
47 | | - with monit.section('temperature=10.'): |
48 | | - sample(model, ds, TemperatureSampler(10.), 4, 32, 128, 'It is') |
| 43 | + with monit.section('greedy'): |
| 44 | + sample(model, ds, GreedySampler(), 4, 32, 128, 'It is') |
| 45 | + # |
| 46 | + # with monit.section('temperature=1.'): |
| 47 | + # sample(model, ds, TemperatureSampler(1.), 4, 32, 128, 'It is') |
| 48 | + # with monit.section('temperature=.1'): |
| 49 | + # sample(model, ds, TemperatureSampler(.1), 4, 32, 128, 'It is') |
| 50 | + # with monit.section('temperature=10.'): |
| 51 | + # sample(model, ds, TemperatureSampler(10.), 4, 32, 128, 'It is') |
| 52 | + |
| 53 | + with monit.section('top_k=5'): |
| 54 | + sample(model, ds, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, 'It is') |
| 55 | + |
| 56 | + |
49 | 57 |
|
50 | 58 |
|
51 | 59 | if __name__ == '__main__': |
|
0 commit comments