forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhyperparams_builder.py
169 lines (137 loc) · 5.93 KB
/
hyperparams_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builder function to construct tf-slim arg_scope for convolution, fc ops."""
import tensorflow as tf
from object_detection.protos import hyperparams_pb2
slim = tf.contrib.slim
def build(hyperparams_config, is_training):
"""Builds tf-slim arg_scope for convolution ops based on the config.
Returns an arg_scope to use for convolution ops containing weights
initializer, weights regularizer, activation function, batch norm function
and batch norm parameters based on the configuration.
Note that if the batch_norm parameteres are not specified in the config
(i.e. left to default) then batch norm is excluded from the arg_scope.
The batch norm parameters are set for updates based on `is_training` argument
and conv_hyperparams_config.batch_norm.train parameter. During training, they
are updated only if batch_norm.train parameter is true. However, during eval,
no updates are made to the batch norm variables. In both cases, their current
values are used during forward pass.
Args:
hyperparams_config: hyperparams.proto object containing
hyperparameters.
is_training: Whether the network is in training mode.
Returns:
arg_scope: tf-slim arg_scope containing hyperparameters for ops.
Raises:
ValueError: if hyperparams_config is not of type hyperparams.Hyperparams.
"""
if not isinstance(hyperparams_config,
hyperparams_pb2.Hyperparams):
raise ValueError('hyperparams_config not of type '
'hyperparams_pb.Hyperparams.')
batch_norm = None
batch_norm_params = None
if hyperparams_config.HasField('batch_norm'):
batch_norm = slim.batch_norm
batch_norm_params = _build_batch_norm_params(
hyperparams_config.batch_norm, is_training)
affected_ops = [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose]
if hyperparams_config.HasField('op') and (
hyperparams_config.op == hyperparams_pb2.Hyperparams.FC):
affected_ops = [slim.fully_connected]
with slim.arg_scope(
affected_ops,
weights_regularizer=_build_regularizer(
hyperparams_config.regularizer),
weights_initializer=_build_initializer(
hyperparams_config.initializer),
activation_fn=_build_activation_fn(hyperparams_config.activation),
normalizer_fn=batch_norm,
normalizer_params=batch_norm_params) as sc:
return sc
def _build_activation_fn(activation_fn):
"""Builds a callable activation from config.
Args:
activation_fn: hyperparams_pb2.Hyperparams.activation
Returns:
Callable activation function.
Raises:
ValueError: On unknown activation function.
"""
if activation_fn == hyperparams_pb2.Hyperparams.NONE:
return None
if activation_fn == hyperparams_pb2.Hyperparams.RELU:
return tf.nn.relu
if activation_fn == hyperparams_pb2.Hyperparams.RELU_6:
return tf.nn.relu6
raise ValueError('Unknown activation function: {}'.format(activation_fn))
def _build_regularizer(regularizer):
"""Builds a tf-slim regularizer from config.
Args:
regularizer: hyperparams_pb2.Hyperparams.regularizer proto.
Returns:
tf-slim regularizer.
Raises:
ValueError: On unknown regularizer.
"""
regularizer_oneof = regularizer.WhichOneof('regularizer_oneof')
if regularizer_oneof == 'l1_regularizer':
return slim.l1_regularizer(scale=float(regularizer.l1_regularizer.weight))
if regularizer_oneof == 'l2_regularizer':
return slim.l2_regularizer(scale=float(regularizer.l2_regularizer.weight))
raise ValueError('Unknown regularizer function: {}'.format(regularizer_oneof))
def _build_initializer(initializer):
"""Build a tf initializer from config.
Args:
initializer: hyperparams_pb2.Hyperparams.regularizer proto.
Returns:
tf initializer.
Raises:
ValueError: On unknown initializer.
"""
initializer_oneof = initializer.WhichOneof('initializer_oneof')
if initializer_oneof == 'truncated_normal_initializer':
return tf.truncated_normal_initializer(
mean=initializer.truncated_normal_initializer.mean,
stddev=initializer.truncated_normal_initializer.stddev)
if initializer_oneof == 'variance_scaling_initializer':
enum_descriptor = (hyperparams_pb2.VarianceScalingInitializer.
DESCRIPTOR.enum_types_by_name['Mode'])
mode = enum_descriptor.values_by_number[initializer.
variance_scaling_initializer.
mode].name
return slim.variance_scaling_initializer(
factor=initializer.variance_scaling_initializer.factor,
mode=mode,
uniform=initializer.variance_scaling_initializer.uniform)
raise ValueError('Unknown initializer function: {}'.format(
initializer_oneof))
def _build_batch_norm_params(batch_norm, is_training):
"""Build a dictionary of batch_norm params from config.
Args:
batch_norm: hyperparams_pb2.ConvHyperparams.batch_norm proto.
is_training: Whether the models is in training mode.
Returns:
A dictionary containing batch_norm parameters.
"""
batch_norm_params = {
'decay': batch_norm.decay,
'center': batch_norm.center,
'scale': batch_norm.scale,
'epsilon': batch_norm.epsilon,
'fused': True,
'is_training': is_training and batch_norm.train,
}
return batch_norm_params