
Commit e787220

Allow running both with and without tensorflow, removing the redundant code (train_tb.py).

1 parent e4d4811

File tree: 3 files changed, +24 -196 lines


README.md

Lines changed: 1 addition & 4 deletions
@@ -5,9 +5,6 @@ There's something difference compared to neuraltalk2.
 - Instead of including the convnet in the model, we use preprocessed features. (finetuneable cnn version is in the branch **with_finetune**)
 - Use resnet101; the same way as in self-critical (the preprocessing code may have bug, haven't tested yet)
 
-# TODO:
-- Other models
-
 # Requirements
 Python 2.7 (no [coco-caption](https://github.com/tylin/coco-caption) version for python 3), pytorch
 

@@ -49,7 +46,7 @@ $ python train.py --input_json coco/cocotalk.json --input_json --input_fc_dir da
 
 The train script will take over, and start dumping checkpoints into the folder specified by `checkpoint_path` (default = current folder). For more options, see `opts.py`.
 
-If you have tensorflow, you can run train.py instead of `train_tb.py`. `train_tb.py` saves learning curves by summary writer, and can be visualized using tensorboard.
+If you have tensorflow, the loss histories are automatically dumped into checkpoint_path, and can be visualized using tensorboard.
 
 The current command use scheduled sampling, you can also set scheduled_sampling_start to -1 to turn off scheduled sampling.
 
train.py

Lines changed: 23 additions & 1 deletion
@@ -19,13 +19,23 @@
 import eval_utils
 import misc.utils as utils
 
-import os
+try:
+    import tensorflow as tf
+except ImportError:
+    print("Tensorflow not installed; No tensorboard logging.")
+    tf = None
+
+def add_summary_value(writer, key, value, iteration):
+    summary = tf.Summary(value=[tf.Summary.Value(tag=key, simple_value=value)])
+    writer.add_summary(summary, iteration)
 
 def train(opt):
     loader = DataLoader(opt)
     opt.vocab_size = loader.vocab_size
     opt.seq_length = loader.seq_length
 
+    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)
+
     infos = {}
     if opt.start_from is not None:
         # open old infos and check if models are compatible

@@ -111,6 +121,12 @@ def train(opt):
 
         # Write the training loss summary
         if (iteration % opt.losses_log_every == 0):
+            if tf is not None:
+                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
+                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
+                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
+                tf_summary_writer.flush()
+
             loss_history[iteration] = train_loss
             lr_history[iteration] = opt.current_lr
             ss_prob_history[iteration] = model.ss_prob

@@ -123,6 +139,12 @@ def train(opt):
             eval_kwargs.update(vars(opt))
             val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)
 
+            # Write validation result into summary
+            if tf is not None:
+                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
+                for k,v in lang_stats.items():
+                    add_summary_value(tf_summary_writer, k, v, iteration)
+                tf_summary_writer.flush()
+
             val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}
 
             # Save model if is improving on validation result
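
The key idiom above is the guarded import plus the short-circuit `tf and tf.summary.FileWriter(...)`: when tensorflow is missing, `tf` is `None`, the writer is never constructed, and every logging site is skipped via `if tf is not None`, so one train.py covers both environments and train_tb.py becomes redundant. A self-contained sketch of the same pattern, assuming TensorFlow 1.x when present (the `log_scalar` helper, `./log` path, and sample values are illustrative, not from the commit):

```python
try:
    import tensorflow as tf  # TensorFlow 1.x summary API
except ImportError:
    tf = None  # tensorboard logging silently becomes a no-op

# Short-circuit 'and': evaluates to None when tf is None.
writer = tf and tf.summary.FileWriter('./log')

def log_scalar(key, value, iteration):
    # Guarded like the calls in train.py: do nothing without tensorflow.
    if tf is None:
        return
    summary = tf.Summary(value=[tf.Summary.Value(tag=key, simple_value=value)])
    writer.add_summary(summary, iteration)
    writer.flush()

log_scalar('train_loss', 0.42, 100)  # runs with or without tensorflow installed
```
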

train_tb.py

Lines changed: 0 additions & 191 deletions
This file was deleted.
