"""Test for cli.py"""

from math import cos
import os
import subprocess
from unittest.mock import Mock
from unittest.mock import patch

import numpy as np
import pytest
from sklearn.pipeline import FeatureUnion
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from torch import nn
from torch.nn import RReLU


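# skorch's CLI helpers build on the optional fire library; the tests
# below are skipped when it is not installed.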
fire_installed = True
try:
    import fire
except ImportError:
    fire_installed = False


@pytest.mark.skipif(not fire_installed, reason='fire library not installed')
class TestCli:
    @pytest.fixture
    def resolve_dotted_name(self):
        from skorch.cli import _resolve_dotted_name
        return _resolve_dotted_name

    @pytest.mark.parametrize('name, expected', [
        (0, 0),
        (1.23, 1.23),
        ('foo', 'foo'),
        ('math.cos', cos),
        ('torch.nn', nn),
        ('torch.nn.ReLU', nn.ReLU),
    ])
    def test_resolve_dotted_name(self, resolve_dotted_name, name, expected):
        result = resolve_dotted_name(name)
        assert result == expected

    def test_resolve_dotted_name_instantiated(self, resolve_dotted_name):
        result = resolve_dotted_name('torch.nn.RReLU(0.123, upper=0.456)')
        assert isinstance(result, RReLU)
        assert np.isclose(result.lower, 0.123)
        assert np.isclose(result.upper, 0.456)

    @pytest.fixture
    def parse_net_kwargs(self):
        from skorch.cli import parse_net_kwargs
        return parse_net_kwargs

    def test_parse_net_kwargs(self, parse_net_kwargs):
        kwargs = {
            'lr': 0.05,
            'max_epochs': 5,
            'module__num_units': 10,
            'module__nonlin': 'torch.nn.RReLU(0.123, upper=0.456)',
        }
        parsed_kwargs = parse_net_kwargs(kwargs)

        assert len(parsed_kwargs) == 4
        assert np.isclose(parsed_kwargs['lr'], 0.05)
        assert parsed_kwargs['max_epochs'] == 5
        assert parsed_kwargs['module__num_units'] == 10
        assert isinstance(parsed_kwargs['module__nonlin'], RReLU)
        assert np.isclose(parsed_kwargs['module__nonlin'].lower, 0.123)
        assert np.isclose(parsed_kwargs['module__nonlin'].upper, 0.456)

    @pytest.fixture
    def net_cls(self):
        from skorch import NeuralNetClassifier
        return NeuralNetClassifier

    @pytest.fixture
    def net(self, net_cls, classifier_module):
        return net_cls(classifier_module)

    @pytest.fixture
    def pipe(self, net):
        return Pipeline([
            ('features', FeatureUnion([
                ('scale', MinMaxScaler()),
            ])),
            ('net', net),
        ])

    @pytest.fixture
    def yield_estimators(self):
        from skorch.cli import _yield_estimators
        return _yield_estimators

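    # _yield_estimators walks an estimator (or a pipeline containing one)
    # and yields (name prefix, estimator) pairs, including a net's module.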
    def test_yield_estimators_net(self, yield_estimators, net):
        result = list(yield_estimators(net))

        assert result[0][0] == ''
        assert result[0][1] is net
        assert result[1][0] == 'module'
        assert result[1][1] is net.module

    def test_yield_estimators_pipe(self, yield_estimators, pipe):
        result = list(yield_estimators(pipe))
        scaler = pipe.named_steps['features'].transformer_list[0][1]
        net = pipe.named_steps['net']
        module = net.module

        assert result[0][0] == 'features__scale'
        assert result[0][1] is scaler
        assert result[1][0] == 'net'
        assert result[1][1] is net
        assert result[2][0] == 'net__module'
        assert result[2][1] is module

    @pytest.fixture
    def substitute_default(self):
        from skorch.cli import _substitute_default
        return _substitute_default

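    # Each case: a docstring parameter description, the new default value to
    # substitute, and the expected description after substitution.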
    @pytest.mark.parametrize('s, new_value, expected', [
        ('', '', ''),
        ('', 'foo', ''),
        ('bar', 'foo', 'bar'),
        ('int (default=128)', '', 'int (default=)'),
        ('int (default=128)', None, 'int (default=128)'),
        ('int (default=128)', '""', 'int (default="")'),
        ('int (default=128)', '128', 'int (default=128)'),
        ('int (default=128)', '256', 'int (default=256)'),
        ('int (default=128)', 256, 'int (default=256)'),
        ('with_parens (default=(1, 2))', (3, 4), 'with_parens (default=(3, 4))'),
        ('int (default =128)', '256', 'int (default =256)'),
        ('int (default= 128)', '256', 'int (default= 256)'),
        ('int (default = 128)', '256', 'int (default = 256)'),
        (
            'nonlin (default = ReLU())',
            nn.Hardtanh(1, 2),
            'nonlin (default = {})'.format(nn.Hardtanh(1, 2))
        ),
        (
            # from sklearn MinMaxScaler
            'tuple (min, max), default=(0, 1)',
            (-1, 1),
            'tuple (min, max), default=(-1, 1)'
        ),
        (
            # from sklearn MinMaxScaler
            'boolean, optional, default True',
            False,
            'boolean, optional, default False'
        ),
        (
            # from sklearn Normalizer
            "'l1', 'l2', or 'max', optional ('l2' by default)",
            'l1',
            "'l1', 'l2', or 'max', optional ('l1' by default)"
        ),
        (
            # same but double ticks
            '"l1", "l2", or "max", optional ("l2" by default)',
            'l1',
            '"l1", "l2", or "max", optional ("l1" by default)'
        ),
        (
            # same but no ticks
            "l1, l2, or max, optional (l2 by default)",
            'l1',
            "l1, l2, or max, optional (l1 by default)"
        ),
        (
            "tuple, optional ((1, 1) by default)",
            (2, 2),
            "tuple, optional ((2, 2) by default)"
        ),
        (
            "nonlin (ReLU() by default)",
            nn.Tanh(),
            "nonlin (Tanh() by default)"
        ),
    ])
    def test_replace_default(self, substitute_default, s, new_value, expected):
        result = substitute_default(s, new_value)
        assert result == expected

    @pytest.fixture
    def print_help(self):
        from skorch.cli import print_help
        return print_help

    def test_print_help_net(self, print_help, net, capsys):
        print_help(net)
        out = capsys.readouterr()[0]

        expected_snippets = [
            '-- --help',
            '<NeuralNetClassifier> options',
            '--module : torch module (class or instance)',
            '--batch_size : int (default=128)',
            '<MLPModule> options',
            '--module__hidden_units : int (default=10)'
        ]
        for snippet in expected_snippets:
            assert snippet in out

    def test_print_help_net_custom_defaults(self, print_help, net, capsys):
        defaults = {'batch_size': 256, 'module__hidden_units': 55}
        print_help(net, defaults)
        out = capsys.readouterr()[0]

        expected_snippets = [
            '-- --help',
            '<NeuralNetClassifier> options',
            '--module : torch module (class or instance)',
            '--batch_size : int (default=256)',
            '<MLPModule> options',
            '--module__hidden_units : int (default=55)'
        ]
        for snippet in expected_snippets:
            assert snippet in out

    def test_print_help_pipeline(self, print_help, pipe, capsys):
        print_help(pipe)
        out = capsys.readouterr()[0]

        expected_snippets = [
            '-- --help',
            '<MinMaxScaler> options',
            '--features__scale__feature_range',
            '<NeuralNetClassifier> options',
            '--net__module : torch module (class or instance)',
            '--net__batch_size : int (default=128)',
            '<MLPModule> options',
            '--net__module__hidden_units : int (default=10)'
        ]
        for snippet in expected_snippets:
            assert snippet in out

    def test_print_help_pipeline_custom_defaults(
            self, print_help, pipe, capsys):
        defaults = {'net__batch_size': 256, 'net__module__hidden_units': 55}
        print_help(pipe, defaults=defaults)
        out = capsys.readouterr()[0]

        expected_snippets = [
            '-- --help',
            '<MinMaxScaler> options',
            '--features__scale__feature_range',
            '<NeuralNetClassifier> options',
            '--net__module : torch module (class or instance)',
            '--net__batch_size : int (default=256)',
            '<MLPModule> options',
            '--net__module__hidden_units : int (default=55)'
        ]
        for snippet in expected_snippets:
            assert snippet in out

    @pytest.fixture
    def parse_args(self):
        from skorch.cli import parse_args
        return parse_args

    @pytest.fixture
    def estimator(self, net_cls):
        mock = Mock(net_cls)
        return mock

    def test_parse_args_help(self, parse_args, estimator):
        with patch('skorch.cli.sys.exit') as exit:
            with patch('skorch.cli.print_help') as help:
                parsed = parse_args({'help': True, 'foo': 'bar'})
                parsed(estimator)

        assert estimator.set_params.call_count == 0  # neither defaults nor kwargs set
        assert help.call_count == 1
        assert exit.call_count == 1

    def test_parse_args_run(self, parse_args, estimator):
        kwargs = {'foo': 'bar', 'baz': 'math.cos'}
        with patch('skorch.cli.sys.exit') as exit:
            with patch('skorch.cli.print_help') as help:
                parsed = parse_args(kwargs)
                parsed(estimator)

        assert estimator.set_params.call_count == 2  # defaults and kwargs

        defaults_set_params = estimator.set_params.call_args_list[0][1]
        assert not defaults_set_params  # no defaults specified

        kwargs_set_params = estimator.set_params.call_args_list[1][1]
        assert kwargs_set_params['foo'] == 'bar'
        assert kwargs_set_params['baz'] == cos

        assert help.call_count == 0
        assert exit.call_count == 0

    def test_parse_args_net_custom_defaults(self, parse_args, net):
        defaults = {'batch_size': 256, 'module__hidden_units': 55}
        kwargs = {'batch_size': 123, 'module__nonlin': nn.Hardtanh(1, 2)}
        parsed = parse_args(kwargs, defaults)
        net = parsed(net)

        # cmd line args have precedence over defaults
        assert net.batch_size == 123
        assert net.module_.hidden_units == 55
        assert isinstance(net.module_.nonlin, nn.Hardtanh)
        assert net.module_.nonlin.min_val == 1
        assert net.module_.nonlin.max_val == 2

    def test_parse_args_pipe_custom_defaults(self, parse_args, pipe):
        defaults = {'net__batch_size': 256, 'net__module__hidden_units': 55}
        kwargs = {'net__batch_size': 123, 'net__module__nonlin': nn.Hardtanh(1, 2)}
        parsed = parse_args(kwargs, defaults)
        pipe = parsed(pipe)
        net = pipe.steps[-1][1]

        # cmd line args have precedence over defaults
        assert net.batch_size == 123
        assert net.module_.hidden_units == 55
        assert isinstance(net.module_.nonlin, nn.Hardtanh)
        assert net.module_.nonlin.min_val == 1
        assert net.module_.nonlin.max_val == 2