Skip to content

Commit cb2eb46

Browse files
committed
add wandb logging
1 parent f57f46e commit cb2eb46

File tree

7 files changed

+37
-20
lines changed

7 files changed

+37
-20
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,4 @@ external/
9393

9494
draws/
9595
plot/
96-
96+
wandb/

experiments/infinity_coco/hrnet/w48_384x288_adam_lr1e-3.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ OUTPUT_DIR: "outputs/output_new_infinity"
99
LOG_DIR: "logs/output_new_infinity"
1010
WORKERS: 12
1111
PRINT_FREQ: 10
12+
LOG_WANDB: True
1213

1314
DATASET:
1415
COLOR_RGB: true

experiments/infinity_coco/hrnet/w48_384x288_adam_lr1e-3_local.yaml

+20-19
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ OUTPUT_DIR: "outputs/output_new_infinity"
99
LOG_DIR: "logs/output_new_infinity"
1010
WORKERS: 12
1111
PRINT_FREQ: 10
12+
LOG_WANDB: True
1213

1314
DATASET:
1415
COLOR_RGB: true
@@ -57,13 +58,13 @@ MODEL:
5758
- "conv2"
5859
- "bn2"
5960
- "layer1"
60-
# - "transition1"
61-
# - "stage2"
62-
# - "transition2"
63-
# - "stage3"
64-
# - "transition3"
65-
# - "stage4"
66-
PRETRAIN_FINAL_LAYER: false
61+
- "transition1"
62+
- "stage2"
63+
- "transition2"
64+
- "stage3"
65+
- "transition3"
66+
- "stage4"
67+
PRETRAIN_FINAL_LAYER: true
6768
FINAL_CONV_KERNEL: 1
6869
STAGE2:
6970
NUM_MODULES: 1
@@ -73,8 +74,8 @@ MODEL:
7374
- 4
7475
- 4
7576
NUM_CHANNELS:
76-
- 64
77-
- 128
77+
- 48
78+
- 96
7879
FUSE_METHOD: SUM
7980
STAGE3:
8081
NUM_MODULES: 4
@@ -85,9 +86,9 @@ MODEL:
8586
- 4
8687
- 4
8788
NUM_CHANNELS:
88-
- 64
89-
- 128
90-
- 256
89+
- 48
90+
- 96
91+
- 192
9192
FUSE_METHOD: SUM
9293
STAGE4:
9394
NUM_MODULES: 3
@@ -99,20 +100,20 @@ MODEL:
99100
- 4
100101
- 4
101102
NUM_CHANNELS:
102-
- 64
103-
- 128
104-
- 256
105-
- 512
103+
- 48
104+
- 96
105+
- 192
106+
- 384
106107
FUSE_METHOD: SUM
107108
LOSS:
108109
USE_TARGET_WEIGHT: true
109110
TRAIN:
110-
BATCH_SIZE_PER_GPU: 1
111+
BATCH_SIZE_PER_GPU: 2
111112
SHUFFLE: true
112113
BEGIN_EPOCH: 0
113114
END_EPOCH: 200
114115
OPTIMIZER: adam
115-
LR: 0.00001
116+
LR: 0.001
116117
LR_FACTOR: 0.1
117118
LR_STEP:
118119
- 170
@@ -123,7 +124,7 @@ TRAIN:
123124
MOMENTUM: 0.9
124125
NESTEROV: false
125126
TEST:
126-
BATCH_SIZE_PER_GPU: 1
127+
BATCH_SIZE_PER_GPU: 2
127128
COCO_BBOX_FILE: "data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json"
128129
BBOX_THRE: 1.0
129130
IMAGE_THRE: 0.0

experiments/infinity_coco/hrnet/w64_384x288_adam_lr1e-3.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ OUTPUT_DIR: "outputs/output_new_infinity_64"
99
LOG_DIR: "logs/output_new_infinity_64"
1010
WORKERS: 12
1111
PRINT_FREQ: 10
12+
LOG_WANDB: True
1213

1314
DATASET:
1415
COLOR_RGB: true

lib/config/default.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_C.AUTO_RESUME = False
2222
_C.PIN_MEMORY = True
2323
_C.RANK = 0
24+
_C.LOG_WANDB = False
2425

2526
# Cudnn related params
2627
_C.CUDNN = CN()

lib/core/function.py

+10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import numpy as np
1414
import torch
15+
import wandb
1516
from core.evaluate import accuracy
1617
from core.inference import get_final_preds
1718
from utils.transforms import flip_back
@@ -102,6 +103,15 @@ def train(
102103
)
103104
)
104105
logger.info(msg)
106+
if config.LOG_WANDB:
107+
wandb.log(
108+
{
109+
"epoch": epoch,
110+
"loss_avg": losses.avg,
111+
"accuracy_avg": acc.avg,
112+
"speed": input.size(0) / batch_time.val,
113+
}
114+
)
105115

106116
writer = writer_dict["writer"]
107117
global_steps = writer_dict["train_global_steps"]

tools/train.py

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch.utils.data
2121
import torch.utils.data.distributed
2222
import torchvision.transforms as transforms
23+
import wandb
2324
from config import cfg, update_config
2425
from core.function import train, validate
2526
from core.loss import JointsMSELoss
@@ -66,6 +67,8 @@ def main():
6667

6768
logger.info(pprint.pformat(args))
6869
logger.info(cfg)
70+
if cfg.LOG_WANDB:
71+
wandb.init(project="synthetic_finetuning", entity="yonigoz", config=cfg)
6972

7073
# cudnn related setting
7174
cudnn.benchmark = cfg.CUDNN.BENCHMARK

0 commit comments

Comments
 (0)