run_eval.py
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Script that iteratively applies the unsupervised update rule and evaluates the
meta-objective performance.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl import app

from learning_unsupervised_learning import evaluation
from learning_unsupervised_learning import datasets
from learning_unsupervised_learning import architectures
from learning_unsupervised_learning import summary_utils
from learning_unsupervised_learning import meta_objective

import tensorflow as tf
import sonnet as snt

from tensorflow.contrib.framework.python.framework import checkpoint_utils

flags.DEFINE_string("checkpoint_dir", None,
                    "Dir to load pretrained update rule from")
flags.DEFINE_string("train_log_dir", None, "Training log directory")

FLAGS = flags.FLAGS


def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
  """Applies the learned update rule for num_steps and logs meta-objective summaries."""
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

  # Meta-objectives used to score the quality of the unsupervised representation.
  meta_objectives = []
  meta_objectives.append(
      meta_objective.linear_regression.LinearRegressionMetaObjective)
  meta_objectives.append(meta_objective.sklearn.LogisticRegression)

  checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)
  batch = dataset()
  pre_logit, outputs = base_model(batch)

  global_step = tf.train.get_or_create_global_step()
  var_list = list(
      snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))

  tf.logging.info("all vars")
  for v in tf.all_variables():
    tf.logging.info("   %s" % str(v))

  global_step = tf.train.get_global_step()
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

  if checkpoint_dir:
    # Build a Saver restricted to the variables that actually exist in the
    # pretrained update-rule checkpoint.
    str_var_list = checkpoint_utils.list_variables(checkpoint_dir)

    name_to_v_map = {v.op.name: v for v in tf.all_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]

    saver = tf.train.Saver(var_list)

    # Every theta (update-rule) variable must be covered by the checkpoint.
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert len(missed_variables) == 0, "Missed a theta variable."

  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

    # The global step should be restored from the eval job's checkpoint, or be
    # zero for a fresh run.
    step = sess.run(global_step)

    if step == 0 and checkpoint_dir:
      tf.logging.info("force restore")
      saver.restore(sess, checkpoint_dir)
      tf.logging.info("force restore done")

      sess.run(reset_global_step)
      step = sess.run(global_step)

    # Apply the learned update rule, emitting summaries every
    # eval_every_n_steps steps.
    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])


def main(argv):
  train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)


if __name__ == "__main__":
  app.run(main)