Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/functional_saver.py: 20%
210 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Saves and restore variables inside traced @tf.functions."""
17from tensorflow.core.protobuf import saver_pb2
18from tensorflow.python.checkpoint import checkpoint_options
19from tensorflow.python.eager import context
20from tensorflow.python.eager import def_function
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import gen_io_ops
28from tensorflow.python.ops import io_ops
29from tensorflow.python.ops import string_ops
30from tensorflow.python.saved_model import registration
31from tensorflow.python.trackable import trackable_utils
32from tensorflow.python.training.saving import saveable_object
33from tensorflow.python.training.saving import saveable_object_util
34from tensorflow.python.util import nest
35from tensorflow.python.util import object_identity
class _SingleDeviceSaver(object):
  """Saves and restores checkpoints from the current device."""

  __slots__ = ["_tensor_slice_dict"]

  def __init__(self, tensor_slice_dict):
    """Wraps a tensor dict for saving/restoring on a single device.

    Args:
      tensor_slice_dict: A dict mapping checkpoint key -> slice_spec -> tensor.
    """
    self._tensor_slice_dict = tensor_slice_dict

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensors = []
    slice_specs = []
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, maybe_spec in tensor_slices.items():
        if isinstance(maybe_spec, saveable_object.SaveSpec):
          saved_value = maybe_spec.tensor
          # A tensor value of `None` indicates that this SaveableObject gets
          # recorded in the object graph, but that no value is saved in the
          # checkpoint.
          if saved_value is None:
            continue
          tensor_names.append(maybe_spec.name)
          tensors.append(saved_value)
          slice_specs.append(maybe_spec.slice_spec)
        else:
          tensor_names.append(checkpoint_key)
          tensors.append(maybe_spec)
          slice_specs.append(slice_spec)
    # Prefer the user-requested io_device; otherwise co-locate the write with
    # the host CPU of the first tensor, falling back to cpu:0 when empty.
    save_device = options.experimental_io_device or (
        len(tensors) and saveable_object_util.set_cpu0(tensors[0].device))
    with ops.device(save_device or "cpu:0"):
      return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors)

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensor_dtypes = []
    slice_specs = []

    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, maybe_spec in tensor_slices.items():
        tensor_dtypes.append(maybe_spec.dtype)
        if isinstance(maybe_spec, saveable_object.SaveSpec):
          slice_specs.append(maybe_spec.slice_spec)
          tensor_names.append(maybe_spec.name)
        else:
          slice_specs.append(slice_spec)
          tensor_names.append(checkpoint_key)

    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, slice_specs, tensor_dtypes)

    # Re-walk the tensor dict in the same order used to build the name list,
    # pairing each restored tensor back with its checkpoint key and slice.
    restored_tensor_dict = {}
    restored_iter = iter(restored_tensors)
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec in tensor_slices:
        restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = (
            next(restored_iter))
    return restored_tensor_dict
def sharded_filename(filename_tensor, shard, num_shards):
  """Append sharding information to a filename.

  Delegates to the `ShardedFilename` op, which produces names of the form
  `<filename>-<shard>-of-<num_shards>` (zero-padded).

  Args:
    filename_tensor: A string tensor.
    shard: Integer. The shard for the filename.
    num_shards: An int Tensor for the number of shards.

  Returns:
    A string tensor.
  """
  return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
def registered_saver_filename(filename_tensor, saver_name):
  """Returns the checkpoint filename used by a registered saver.

  Args:
    filename_tensor: A string tensor containing the checkpoint prefix.
    saver_name: Name of the registered saver.

  Returns:
    A string tensor of the form `<filename>-<saver_name>`.
  """
  suffix = constant_op.constant(f"-{saver_name}")
  return string_ops.string_join([filename_tensor, suffix])
def _get_mapped_registered_save_fn(fn, trackables, call_with_mapped_captures):
  """Converts the function to a python or tf.function with a single file arg."""

  def save_fn(file_prefix):
    return fn(trackables=trackables, file_prefix=file_prefix)

  # Without capture remapping, the plain Python closure suffices.
  if call_with_mapped_captures is None:
    return save_fn

  # Otherwise trace the function once so its captures can be remapped when
  # the concrete function is invoked.
  tf_fn = def_function.function(save_fn, autograph=False)
  concrete = tf_fn.get_concrete_function(
      file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

  def save_fn_with_replaced_captures(file_prefix):
    return call_with_mapped_captures(concrete, [file_prefix])

  return save_fn_with_replaced_captures
def _get_mapped_registered_restore_fn(fn, trackables,
                                      call_with_mapped_captures):
  """Converts the function to a python or tf.function with a single file arg."""

  def restore_fn(merged_prefix):
    return fn(trackables=trackables, merged_prefix=merged_prefix)

  # Without capture remapping, the plain Python closure suffices.
  if call_with_mapped_captures is None:
    return restore_fn

  # Otherwise trace the function once so its captures can be remapped when
  # the concrete function is invoked.
  tf_fn = def_function.function(restore_fn, autograph=False)
  concrete = tf_fn.get_concrete_function(
      merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

  def restore_fn_with_replaced_captures(merged_prefix):
    return call_with_mapped_captures(concrete, [merged_prefix])

  return restore_fn_with_replaced_captures
182_restore_noop = lambda *args, **kwargs: None
class MultiDeviceSaver(object):
  """Saves checkpoints directly from multiple devices.

  Note that this is a low-level utility which stores Tensors in the keys
  specified by `SaveableObject`s. Higher-level utilities for object-based
  checkpointing are built on top of it.
  """

  def __init__(self,
               serialized_tensors,
               registered_savers=None,
               call_with_mapped_captures=None):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      serialized_tensors: A dictionary mapping `Trackable` to a tensor dict,
        which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. The
        `Trackable` key is used to get the `restore_from_tensors` function,
        and may be `None` if the tensor is not meant to be restored.
      registered_savers: A dictionary mapping `registration.RegisteredSaver`
        namedtuples to a dictionary of named Trackables. The keys of the
        Trackable dictionary are string names that uniquely identify the
        Trackable in the checkpoint.
      call_with_mapped_captures: Optional callable used when tracing the
        registered save/restore functions; presumably it invokes a concrete
        function with its captures remapped (e.g. for distributed saving).
        TODO(review): confirm the exact contract with callers.
    """
    # Keep these two data structures so that we can map restored tensors to
    # the Trackable restore functions.
    self._keys_to_restore_fn = {}
    self._restore_fn_to_keys = {}

    # Extract serialized tensors and separate by device.
    tensors_by_device = {}  # device -> checkpoint key -> (slice_spec ->) tensor

    for obj, tensor_dict in serialized_tensors.items():
      restore_fn = _restore_noop if obj is None else obj._restore_from_tensors

      # Divide tensor_dict by device.
      for checkpoint_key, maybe_tensor in tensor_dict.items():
        if not isinstance(maybe_tensor, dict):
          # Make sure that maybe_tensor is structured as {slice_spec -> tensor}.
          maybe_tensor = {"": maybe_tensor}

        for slice_spec, tensor in maybe_tensor.items():
          if (checkpoint_key, slice_spec) in self._keys_to_restore_fn:
            raise ValueError(
                "Received multiple tensors with the same checkpoint key and "
                "slice spec. This is invalid because one will overwrite the "
                "other in the checkpoint. This indicates a bug in the "
                "Checkpoint key-generation.")
          self._keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn
          self._restore_fn_to_keys.setdefault(restore_fn, []).append(
              (checkpoint_key, slice_spec))

          # Group tensors by the host CPU of the device they live on, so each
          # group can be handled by one `_SingleDeviceSaver`.
          host_device = saveable_object_util.set_cpu0(tensor.device)
          (tensors_by_device
           .setdefault(host_device, {})
           .setdefault(checkpoint_key, {})[slice_spec]) = tensor
    self._single_device_savers = {
        device: _SingleDeviceSaver(tensor_slice_dict)
        for device, tensor_slice_dict in tensors_by_device.items()}

    self._registered_savers = {}
    if registered_savers:
      for registered_name, trackables in registered_savers.items():
        save_fn = _get_mapped_registered_save_fn(
            registration.get_save_function(registered_name),
            trackables, call_with_mapped_captures)
        restore_fn = _get_mapped_registered_restore_fn(
            registration.get_restore_function(registered_name),
            trackables, call_with_mapped_captures)
        self._registered_savers[registered_name] = (save_fn, restore_fn)

  @classmethod
  def from_saveables(cls, saveables, registered_savers=None,
                     call_with_mapped_captures=None):
    """Constructs a MultiDeviceSaver from a list of `SaveableObject`s."""
    serialized_tensors = object_identity.ObjectIdentityDictionary()
    for saveable in saveables:
      trackable = saveable_object_util.SaveableCompatibilityConverter(
          saveable, saveables=[saveable])
      serialized_tensors[trackable] = trackable._serialize_to_tensors()  # pylint: disable=protected-access
    return cls(serialized_tensors, registered_savers, call_with_mapped_captures)

  def to_proto(self):
    """Serializes to a SaverDef referencing the current graph."""
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    save_tensor = self._traced_save(filename_tensor)
    restore_op = self._traced_restore(filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_save(self, file_prefix):
    # Traced wrapper so `to_proto` can reference a graph save tensor.
    save_op = self.save(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies([save_op]):
        return array_ops.identity(file_prefix)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_restore(self, file_prefix):
    # Traced wrapper so `to_proto` can reference a graph restore op.
    restore_ops = self.restore(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops.values()):
        return array_ops.identity(file_prefix)

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore(). Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    # During runtime, a temporary directory is first created, which contains
    # files
    #
    # <train dir>/myckpt_temp/
    # part-?????-of-?????{.index, .data-00000-of-00001}
    #
    # Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    # <train dir>/
    # myckpt{.index, .data-?????-of-?????}
    #
    # Filesystems with eventual consistency (such as S3), don't need a
    # temporary location. Using a temporary directory in those cases might
    # cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case. Save() and Restore() work with the
    # prefix directly, instead of any physical pathname. (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn():
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due to
      # the API contract.
      for saver_name, (registered_save_fn, _) in (
          self._registered_savers.items()):
        maybe_saved_prefixes = registered_save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      # (Default saver) Save with single device savers.
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        saved_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step. Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to the
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers. Note that the retrace is needed to ensure we pickup the
    # latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      When not run eagerly or when saving on a single device, returns a
      dictionary mapping from SaveableObject names to restore operations;
      otherwise, returns an empty dict.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn():
      restore_fn_inputs = {}
      restore_fn_input_count = {
          fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()}

      restore_ops = {}
      # Sort by device name to avoid propagating non-deterministic dictionary
      # ordering in some Python versions.
      for device, saver in sorted(self._single_device_savers.items()):
        with ops.device(device):
          # Load values from checkpoint
          restored_tensor_dict = saver.restore(file_prefix, options)

        # Map restored tensors to the corresponding restore_fn, and see if all
        # inputs have all been loaded. Call the restore function if that is
        # the case.
        for checkpoint_key, slice_and_tensor in restored_tensor_dict.items():
          for slice_spec, tensor in slice_and_tensor.items():
            trackable_restore_fn = self._keys_to_restore_fn[(checkpoint_key,
                                                             slice_spec)]

            # Processing the returned restored_tensor_dict to prepare for the
            # Trackable `restore` function. The `restore` function expects a
            # map of `string name (checkpoint_key) -> Tensor`. Unless there is
            # a slice_spec, in which case the map will be of
            # `string name (checkpoint_key)-> slice_spec -> Tensor`.
            if slice_spec:
              (restore_fn_inputs.setdefault(trackable_restore_fn, {})
               .setdefault(checkpoint_key, {})[slice_spec]) = tensor
            else:
              restore_fn_inputs.setdefault(trackable_restore_fn,
                                           {})[checkpoint_key] = tensor
            restore_fn_input_count[trackable_restore_fn] -= 1

            if restore_fn_input_count[trackable_restore_fn] == 0:
              restored_tensors = {}
              # Extracts the substring after the "/.ATTRIBUTES/" in the
              # ckpt_key from restore_fn_inputs[trackable_restore_fn] to
              # restored_tensors. For example, if the input is
              # dict {"/.ATTRIBUTES/a": Tensor}, restored_tensors will be
              # changed to dict {"a": Tensor}.
              for ckpt_key, restored in (
                  restore_fn_inputs[trackable_restore_fn].items()):
                restored_tensors[trackable_utils.extract_local_name(
                    ckpt_key)] = restored
              ret = trackable_restore_fn(restored_tensors)
              if isinstance(ret, dict):
                restore_ops.update(ret)
      # Run registered restore methods after the default restore ops.
      for _, (_, registered_restore_fn) in self._registered_savers.items():
        registered_restore_fn(file_prefix)
      return restore_ops

    has_custom_device_saver = any(
        context.is_custom_device(d) for d in self._single_device_savers)
    # Since this will cause a function re-trace on each restore, limit this to
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers or any single device saver is a custom device. Note that the
    # retrace is needed to ensure we pickup the latest values of options like
    # experimental_io_device.
    #
    # We run in a function when there is a custom device saver because custom
    # devices, such as DTensor, usually do a sharded save and restore.
    # Doing a sharded save and restore requires knowledge about what shards
    # of variables we are restoring to. In practice, this means that custom
    # devices need the AssignVariableOps along with the Restore op within the
    # same graph to infer shapes and shard specs for Restore op.
    if context.executing_eagerly() and (len(self._single_device_savers) > 1 or
                                        has_custom_device_saver):
      @def_function.function(jit_compile=False, autograph=False)
      def tf_function_restore():
        restore_fn()
        return {}

      restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    return restore_ops