# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with name-based checkpoints.

While some of these symbols also work with the TF2 object-based checkpoints,
they are not recommended for TF2. Please check `tensorflow/python/checkpoint`
for newer utilities built to work with TF2 checkpoints.
"""

from collections import abc
import os
import time

from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "load_checkpoint", "load_variable", "list_variables",
    "checkpoints_iterator", "init_from_checkpoint"
]


@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.

  If `ckpt_dir_or_file` resolves to a directory with multiple checkpoints, a
  reader for the latest checkpoint is returned.

  Example usage:

  ```python
  import tensorflow as tf
  a = tf.Variable(1.0)
  b = tf.Variable(2.0)
  ckpt = tf.train.Checkpoint(var_list={'a': a, 'b': b})
  ckpt_path = ckpt.save('tmp-ckpt')
  reader = tf.train.load_checkpoint(ckpt_path)
  print(reader.get_tensor('var_list/a/.ATTRIBUTES/VARIABLE_VALUE'))  # 1.0
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  filename = _get_checkpoint_filename(ckpt_dir_or_file)
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  return py_checkpoint_reader.NewCheckpointReader(filename)


@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the tensor value of the given variable in the checkpoint.

  When the variable name is unknown, you can use `tf.train.list_variables` to
  inspect all the variable names.

  Example usage:

  ```python
  import tensorflow as tf
  a = tf.Variable(1.0)
  b = tf.Variable(2.0)
  ckpt = tf.train.Checkpoint(var_list={'a': a, 'b': b})
  ckpt_path = ckpt.save('tmp-ckpt')
  var = tf.train.load_variable(
      ckpt_path, 'var_list/a/.ATTRIBUTES/VARIABLE_VALUE')
  print(var)  # 1.0
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  if name.endswith(":0"):
    name = name[:-2]
  reader = load_checkpoint(ckpt_dir_or_file)
  return reader.get_tensor(name)


@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Lists the checkpoint keys and shapes of variables in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph.

  Example usage:

  ```python
  import tensorflow as tf
  import os
  ckpt_directory = "/tmp/training_checkpoints/ckpt"
  ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
  train_and_checkpoint(model, manager)
  tf.train.list_variables(manager.latest_checkpoint)
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(key, shape)`.
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result


def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.
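
  Example usage (an illustrative sketch; the directory below is a placeholder):

  ```python
  ckpt_path = wait_for_new_checkpoint('/tmp/train_dir', timeout=60)
  if ckpt_path is None:
    print('Timed out before a new checkpoint appeared.')
  ```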

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    A new checkpoint path, or `None` if the timeout was reached.
  """
  logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info("Found new checkpoint at %s", checkpoint_path)
      return checkpoint_path


@tf_export("train.checkpoints_iterator")
def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  """Continuously yield new checkpoint files as they appear.

  The iterator only checks for new checkpoints when control flow has been
  returned to it. This means it can miss checkpoints if your code takes longer
  to run between iterations than `min_interval_secs` or the interval at which
  new checkpoints are written.

  The `timeout` argument is the maximum number of seconds to block waiting for
  a new checkpoint. It is used in combination with the `timeout_fn` as
  follows:

  * If the timeout expires and no `timeout_fn` was specified, the iterator
    stops yielding.
  * If a `timeout_fn` was specified, that function is called and if it returns
    a true boolean value the iterator stops yielding.
  * If the function returns a false boolean value then the iterator resumes the
    wait for new checkpoints. At this point the timeout logic applies again.

  This behavior gives control to callers on what to do if checkpoints do not
  come fast enough or stop being generated. For example, if callers have a way
  to detect that the training has stopped and know that no new checkpoints
  will be generated, they can provide a `timeout_fn` that returns `True` when
  the training has stopped. If they know that the training is still going on
  they return `False` instead.
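
  Example usage (an illustrative sketch; the directory path and the 60-second
  timeout below are placeholders):

  ```python
  import tensorflow as tf

  for ckpt_path in tf.train.checkpoints_iterator('/tmp/train_dir', timeout=60):
    # Run an evaluation pass against each new checkpoint as it arrives.
    print('Found new checkpoint:', ckpt_path)
  ```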

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    min_interval_secs: The minimum number of seconds between yielding
      checkpoints.
    timeout: The maximum number of seconds to wait between checkpoints. If left
      as `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout. If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit. The function is called with no arguments.

  Yields:
    String paths to latest checkpoint files as they arrive.
  """
  checkpoint_path = None
  while True:
    new_checkpoint_path = wait_for_new_checkpoint(
        checkpoint_dir, checkpoint_path, timeout=timeout)
    if new_checkpoint_path is None:
      if not timeout_fn:
        # timed out
        logging.info("Timed-out waiting for a checkpoint.")
        return
      if timeout_fn():
        # The timeout_fn indicated that we are truly done.
        return
      else:
        # The timeout_fn indicated that more checkpoints may come.
        continue
    start = time.time()
    checkpoint_path = new_checkpoint_path
    yield checkpoint_path
    time_to_next_eval = start + min_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)


@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  @compatibility(TF2)
  `tf.compat.v1.train.init_from_checkpoint` is not recommended for restoring
  variable values in TF2.

  To restore checkpoints in TF2, please use
  `tf.keras.Model.load_weights` or `tf.train.Checkpoint.restore`. These APIs
  use an [object-based method of checkpointing]
  (https://www.tensorflow.org/guide/checkpoint#loading_mechanics), while
  `tf.compat.v1.init_from_checkpoint` relies on a more fragile variable-name
  based method of checkpointing. There is no object-based equivalent of
  `init_from_checkpoint` in TF2.

  Please re-write your checkpoints immediately using the object-based APIs,
  see the [migration guide]
  (https://www.tensorflow.org/guide/migrate#checkpoint_compatibility) for more
  details.

  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
  you may have to change the names of the variables in your model to match the
  variable names in the name-based checkpoint, which can be viewed with
  `tf.train.list_variables(path)`.

  Another option is to create an `assignment_map` that maps the name of the
  variables in the name-based checkpoint to the variables in your model, e.g.:

  ```
  {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1]
  }
  ```

  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
  restore the name-based checkpoint.

  After restoring, re-encode your checkpoint using `tf.train.Checkpoint.save`
  or `tf.keras.Model.save_weights`.

  @end_compatibility

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Assignment map can be a dict, or a list of pairs. The latter is
  necessary to initialize multiple variables in the current graph from
  the same variable in the checkpoint.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  # -- name='old_scope_1/var1', shape=[20, 2]
  # -- name='old_scope_1/var2', shape=[50, 4]
  # -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                                     initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                                     initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                                     initializer=tf.compat.v1.zeros_initializer(),
                                     partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, or a list of key-value pairs, where keys are names
      of the variables in the checkpoint and values are current variables or
      names of current variables (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.

  """
  init_from_checkpoint_fn = lambda _: _init_from_checkpoint(
      ckpt_dir_or_file, assignment_map)
  if distribute_lib.get_cross_replica_context():
    init_from_checkpoint_fn(None)
  else:
    distribute_lib.get_replica_context().merge_call(
        init_from_checkpoint_fn)


def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  if isinstance(assignment_map, abc.Mapping):
    assignment_map = assignment_map.items()

  # We only want to sort by tensor names.
  sort_key = lambda pair: pair[0]

  for tensor_name_in_ckpt, current_var_or_name in sorted(
      assignment_map, key=sort_key):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join(v.name for v in var)
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
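        # For illustration (an assumed case mirroring the docstring example):
        # with {'old_scope_1/': 'new_scope_1/'} and var_name 'new_scope_1/var1',
        # scopes is 'new_scope_1' and full_tensor_name becomes 'old_scope_1/var1'.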
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)


def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if isinstance(ckpt_dir_or_file, os.PathLike):
    ckpt_dir_or_file = os.fspath(ckpt_dir_or_file)
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file


def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
    # pylint:enable=protected-access


def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: if all objects in `variable_or_list` are not partitions of the
      same large variable.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      slice_info = v._save_slice_info  # pylint:disable=protected-access
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")


def _is_variable(x):
  return (isinstance(x, variables.Variable) or
          resource_variable_ops.is_resource_variable(x))


def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable."""
  if name + "/part_0" in all_vars:
    var = []
    i = 0
    while name + "/part_%d" % i in all_vars:
      var.append(all_vars[name + "/part_%d" % i])
      i += 1
    return var
  return None