Skip to content

Commit bcb62cb

Browse files
goldsboroughfacebook-github-bot
authored andcommitted
Lazily create tensors in optim_baseline (pytorch#12301)
Summary: Tensors cannot be created globally because of static initialization order issues. So tensors for the optim_baseline test must be created lazily instead. This is fine because these functions will only be called once (in the respective test). ezyang Pull Request resolved: pytorch#12301 Differential Revision: D10201008 Pulled By: goldsborough fbshipit-source-id: 59a041f437354e7c6600e5655b3e2d0647dbde9e
1 parent 1962646 commit bcb62cb

File tree

3 files changed

+758
-727
lines changed

3 files changed

+758
-727
lines changed

test/cpp/api/optim.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -191,83 +191,84 @@ TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
191191
}
192192

193193
TEST(OptimTest, ProducesPyTorchValues_Adam) {
194-
check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam);
194+
check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
195195
}
196196

197197
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
198198
check_exact_values<Adam>(
199199
AdamOptions(1.0).weight_decay(1e-2),
200-
expected_parameters::Adam_with_weight_decay);
200+
expected_parameters::Adam_with_weight_decay());
201201
}
202202

203203
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
204204
check_exact_values<Adam>(
205205
AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
206-
expected_parameters::Adam_with_weight_decay_and_amsgrad);
206+
expected_parameters::Adam_with_weight_decay_and_amsgrad());
207207
}
208208

209209
TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
210210
check_exact_values<Adagrad>(
211-
AdagradOptions(1.0), expected_parameters::Adagrad);
211+
AdagradOptions(1.0), expected_parameters::Adagrad());
212212
}
213213

214214
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
215215
check_exact_values<Adagrad>(
216216
AdagradOptions(1.0).weight_decay(1e-2),
217-
expected_parameters::Adagrad_with_weight_decay);
217+
expected_parameters::Adagrad_with_weight_decay());
218218
}
219219

220220
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
221221
check_exact_values<Adagrad>(
222222
AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
223-
expected_parameters::Adagrad_with_weight_decay_and_lr_decay);
223+
expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
224224
}
225225

226226
TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
227227
check_exact_values<RMSprop>(
228-
RMSpropOptions(0.1), expected_parameters::RMSprop);
228+
RMSpropOptions(0.1), expected_parameters::RMSprop());
229229
}
230230

231231
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
232232
check_exact_values<RMSprop>(
233233
RMSpropOptions(0.1).weight_decay(1e-2),
234-
expected_parameters::RMSprop_with_weight_decay);
234+
expected_parameters::RMSprop_with_weight_decay());
235235
}
236236

237237
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
238238
check_exact_values<RMSprop>(
239239
RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
240-
expected_parameters::RMSprop_with_weight_decay_and_centered);
240+
expected_parameters::RMSprop_with_weight_decay_and_centered());
241241
}
242242

243243
TEST(
244244
OptimTest,
245245
ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
246246
check_exact_values<RMSprop>(
247247
RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
248-
expected_parameters::RMSprop_with_weight_decay_and_centered_and_momentum);
248+
expected_parameters::
249+
RMSprop_with_weight_decay_and_centered_and_momentum());
249250
}
250251

251252
TEST(OptimTest, ProducesPyTorchValues_SGD) {
252-
check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD);
253+
check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
253254
}
254255

255256
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
256257
check_exact_values<SGD>(
257258
SGDOptions(0.1).weight_decay(1e-2),
258-
expected_parameters::SGD_with_weight_decay);
259+
expected_parameters::SGD_with_weight_decay());
259260
}
260261

261262
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
262263
check_exact_values<SGD>(
263264
SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
264-
expected_parameters::SGD_with_weight_decay_and_momentum);
265+
expected_parameters::SGD_with_weight_decay_and_momentum());
265266
}
266267

267268
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
268269
check_exact_values<SGD>(
269270
SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
270-
expected_parameters::SGD_with_weight_decay_and_nesterov_momentum);
271+
expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
271272
}
272273

273274
TEST(OptimTest, ZeroGrad) {

0 commit comments

Comments
 (0)