Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/training_util.py: 30%
96 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for training."""
16from tensorflow.python.eager import context
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import graph_io
19from tensorflow.python.framework import ops
20from tensorflow.python.ops import control_flow_ops
21from tensorflow.python.ops import init_ops
22from tensorflow.python.ops import resource_variable_ops
23from tensorflow.python.ops import state_ops
24from tensorflow.python.ops import variable_scope
25from tensorflow.python.ops import variable_v1
26from tensorflow.python.ops import variables
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.util.tf_export import tf_export
# Collection key under which the cached global-step read tensor is stored.
# The value is deliberately long so that it is unlikely to collide with a
# collection key chosen by user code.
GLOBAL_STEP_READ_KEY = "global_step_read_op_cache"

# Legacy alias kept so older callers of `training_util.write_graph` still work.
# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph
@tf_export(v1=['train.global_step'])
def global_step(sess, global_step_tensor):
  """Small helper that fetches the current value of the global step.

  ```python
  # Create a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Create a session.
  sess = tf.compat.v1.Session()
  # Initialize the variable
  sess.run(global_step_tensor.initializer)
  # Get the variable value.
  print('global_step: %s' % tf.compat.v1.train.global_step(sess,
                                                           global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object. It is not used when executing
      eagerly.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step. When executing eagerly it must be an object exposing
      `.numpy()`, so a name string is not accepted there.

  Returns:
    The global step value as a Python `int`.
  """
  if not context.executing_eagerly():
    # Graph mode: evaluate the tensor (or tensor/op name) through the session.
    return int(sess.run(global_step_tensor))
  # Eager mode: the value is already materialized, so `sess` is ignored.
  return int(global_step_tensor.numpy())
@tf_export(v1=['train.get_global_step'])
def get_global_step(graph=None):
  """Returns the global step tensor of `graph`, or `None` if there is none.

  The global step is looked up first in the `GLOBAL_STEP` graph collection.
  Only when that collection is empty is it looked up by the tensor name
  `global_step:0`. Whatever is found is validated with `assert_global_step`.

  Args:
    graph: The graph to search. If missing, the default graph is used.

  Returns:
    The global step `Variable` or `Tensor`. Returns `None` if nothing was
    found, or if the `GLOBAL_STEP` collection holds more than one item (an
    error is logged in that case).

  Raises:
    TypeError: If the global step found is not a `Variable` or `Tensor`, does
      not have an integer dtype, or has a fully defined non-scalar shape.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.get_or_create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Using `get_global_step`:

  >>> with g.as_default():
  ...   print(sess.run(tf.compat.v1.train.get_global_step()))
  1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  candidates = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)

  # An ambiguous collection is reported, not resolved.
  if len(candidates) > 1:
    logging.error('Multiple tensors in global_step collection.')
    return None

  if candidates:
    step = candidates[0]
  else:
    # Nothing registered in the collection: fall back to the canonical name.
    try:
      step = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None

  assert_global_step(step)
  return step
@tf_export(v1=['train.create_global_step'])
def create_global_step(graph=None):
  """Create the global step variable in `graph`.

  The global step is a non-trainable scalar `int64` variable named after
  `GraphKeys.GLOBAL_STEP`. It is zero-initialized and registered in the
  `GLOBAL_VARIABLES` and `GLOBAL_STEP` collections. When executing eagerly it
  is placed on `cpu:0`. Otherwise it is created in `graph` under the root
  name scope.

  Args:
    graph: The graph in which to create the global step variable. If missing,
      the default graph is used. When executing eagerly, `graph` is only used
      for the "already exists" check.

  Returns:
    The newly created global step variable.

  Raises:
    ValueError: if a global step is already defined in `graph`.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')

  def _make_global_step_variable():
    # One definition shared by the eager and graph branches, so the two code
    # paths cannot drift apart. It is called inside the caller's device or
    # graph context, so the initializer is also built in that context.
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])

  if context.executing_eagerly():
    with ops.device('cpu:0'):
      return _make_global_step_variable()
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return _make_global_step_variable()
@tf_export(v1=['train.get_or_create_global_step'])
def get_or_create_global_step(graph=None):
  """Returns the global step tensor, creating it first if it does not exist.

  Args:
    graph: The graph in which to look up or create the global step. If
      missing, the default graph is used.

  Returns:
    The existing global step found by `get_global_step`. Otherwise, the
    variable newly created by `create_global_step`.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.get_or_create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  existing = get_global_step(graph)
  return existing if existing is not None else create_global_step(graph)
@tf_export(v1=['train.assert_global_step'])
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor`, `Variable` or resource variable to test.

  Raises:
    TypeError: If `global_step_tensor` is not a `Variable`, `Tensor` or
      resource variable, if its dtype is not an integer type, or if its shape
      is fully defined and is not a scalar.
  """
  if not (isinstance(global_step_tensor, variables.Variable) or
          isinstance(global_step_tensor, ops.Tensor) or
          resource_variable_ops.is_resource_variable(global_step_tensor)):
    # Wrap the value in a 1-tuple. Otherwise `%` unpacks a tuple argument and
    # raises a misleading formatting error or prints the wrong value.
    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
                    (global_step_tensor,))

  if not global_step_tensor.dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    (global_step_tensor.dtype,))

  shape = global_step_tensor.get_shape()
  # Only fully defined, non-scalar shapes are rejected. Shapes with unknown
  # rank or with unknown dimensions are accepted.
  if shape.ndims != 0 and shape.is_fully_defined():
    raise TypeError('Existing "global_step" is not scalar: %s' % (shape,))
def _get_global_step_read(graph=None):
  """Returns the cached global-step read tensor of `graph`, if there is one.

  Args:
    graph: Graph whose `GLOBAL_STEP_READ_KEY` collection is inspected. If
      missing, the default graph is used.

  Returns:
    The single item stored under `GLOBAL_STEP_READ_KEY`, or `None` when that
    collection is empty.

  Raises:
    RuntimeError: if multiple items are found in collection
      GLOBAL_STEP_READ_KEY.
  """
  graph = graph or ops.get_default_graph()
  cached = graph.get_collection(GLOBAL_STEP_READ_KEY)
  if not cached:
    return None
  if len(cached) == 1:
    return cached[0]
  raise RuntimeError('There are multiple items in collection {}. '
                     'There should be only one.'.format(GLOBAL_STEP_READ_KEY))
def _get_or_create_global_step_read(graph=None):
  """Returns the global-step read tensor of `graph`, creating it if needed.

  The read tensor is `<global step value> + 0`, which yields a copy of the
  value as a plain `Tensor`. It is built under the root name scope, then
  under the global step op's own name scope, and it is cached in the
  `GLOBAL_STEP_READ_KEY` collection.

  Args:
    graph: The graph in which to find or create the read tensor. If missing,
      the default graph is used.

  Returns:
    The cached or newly created read tensor, or `None` if `graph` has no
    global step.
  """
  graph = graph or ops.get_default_graph()

  cached_read = _get_global_step_read(graph)
  if cached_read is not None:
    return cached_read

  step = get_global_step(graph)
  if step is None:
    return None

  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(step.op.name + '/'):
      if isinstance(step, variables.Variable):
        # Read the variable when it is initialized; otherwise fall back to its
        # initial value. Per the original comment, this is needed because
        # e.g. Estimator builds model_fn under a dependency on this read
        # tensor.
        value = control_flow_ops.cond(
            variable_v1.is_variable_initialized(step),
            step.read_value,
            lambda: step.initial_value)
      else:
        value = step
      # Adding zero forces a copy of the value as a Tensor.
      ops.add_to_collection(GLOBAL_STEP_READ_KEY, value + 0)

  return _get_global_step_read(graph)
def _increment_global_step(increment, graph=None):
  """Returns an op that adds `increment` to the global step of `graph`.

  The assign-add is built under the global step op's name scope. It carries a
  control dependency on the global-step read tensor, so that read runs before
  the increment is applied. The read tensor is created first if it does not
  exist yet.

  Args:
    increment: Amount to add to the global step.
    graph: The graph containing the global step. If missing, the default
      graph is used.

  Returns:
    The result of `state_ops.assign_add` on the global step.

  Raises:
    ValueError: If `graph` has no global step.
  """
  graph = graph or ops.get_default_graph()
  step = get_global_step(graph)
  if step is None:
    raise ValueError(
        'Global step tensor should be created by '
        'tf.train.get_or_create_global_step before calling increment.')
  read_before = _get_or_create_global_step_read(graph)
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(step.op.name + '/'), ops.control_dependencies(
        [read_before]):
      return state_ops.assign_add(step, increment)