Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/saving/saveable_object_util.py: 18%
396 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for working with and creating SaveableObjects."""
16import functools
18from tensorflow.python.checkpoint import saveable_compat
19from tensorflow.python.client import session
20from tensorflow.python.eager import context
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import device as pydev
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gen_control_flow_ops
30from tensorflow.python.ops import ref_variable
31from tensorflow.python.ops import resource_variable_ops
32from tensorflow.python.ops import state_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import tf_logging as logging
35from tensorflow.python.trackable import base as trackable
36from tensorflow.python.trackable import python_state
37from tensorflow.python.trackable import trackable_utils
38from tensorflow.python.training.saving import saveable_object
39from tensorflow.python.types import core
40from tensorflow.python.util import compat
41from tensorflow.python.util import nest
42from tensorflow.python.util import object_identity
43from tensorflow.python.util.tf_export import tf_export
45# Op names which identify variable reads which should be saved.
46_VARIABLE_OPS = set(["Variable",
47 "VariableV2",
48 "AutoReloadVariable",
49 "VarHandleOp",
50 "ReadVariableOp"])
52_REF_VARIABLE_OPS = frozenset(["Variable", "VariableV2", "AutoReloadVariable"])
def set_cpu0(device_string):
  """Returns `device_string` rewritten so that it targets /CPU:0.

  Custom devices are returned untouched. Otherwise the string is parsed into a
  `DeviceSpec`, its device type is forced to "CPU" and its device index to 0,
  and the spec is serialized back to a string.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  if not context.is_custom_device(device_string):
    spec = pydev.DeviceSpec.from_string(device_string)
    device_string = spec.replace(device_type="CPU", device_index=0).to_string()
  return device_string
class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables."""

  def __init__(self, var, slice_spec, name):
    # A single spec covering `var` (or the slice of it named by `slice_spec`).
    specs = [saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)]
    super(ReferenceVariableSaveable, self).__init__(var, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the variable.

    When `restored_shapes` is given, the tensor is reshaped to its first entry
    and shape validation on assignment is disabled. Otherwise the shape is
    validated only if the variable's static shape is fully defined.
    """
    value = restored_tensors[0]
    reshaped = restored_shapes is not None
    if reshaped:
      value = array_ops.reshape(value, restored_shapes[0])
    check_shape = (not reshaped) and self.op.get_shape().is_fully_defined()
    return state_ops.assign(self.op, value, validate_shape=check_shape)
class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables."""

  def __init__(self, var, slice_spec, name):
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      # `var` is a read of a resource variable; its first input is the handle.
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif resource_variable_ops.is_resource_variable(var):

      def _read_to_cpu():
        with ops.device(var.device):
          if context.executing_eagerly() and not var.is_initialized():
            # A SaveSpec tensor value of `None` indicates that the variable is
            # uninitialized.
            return None
          # Read without making a copy to limit memory usage, then copy the
          # value to CPU on the same machine so that variables placed on
          # non-CPU devices can be checkpointed.
          value = var.read_value_no_copy()
          with ops.device("/device:CPU:0"):
            return array_ops.identity(value)

      self.handle_op = var.handle
      # Passing a callable defers the read until the checkpoint is written.
      tensor = _read_to_cpu
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          f" Got: {var!r}")
    spec = saveable_object.SaveSpec(
        tensor, slice_spec, name, dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restores tensors. Raises ValueError if incompatible shape found."""
    value = restored_tensors[0]
    if restored_shapes is not None:
      value = array_ops.reshape(value, restored_shapes[0])
    # Copy the restored tensor onto the variable's device before assigning.
    with ops.device(self._var_device):
      value = array_ops.identity(value)
      try:
        return resource_variable_ops.shape_safe_assign_variable_handle(
            self.handle_op, self._var_shape, value)
      except ValueError as e:
        raise ValueError(
            f"Received incompatible tensor with shape {value.shape} "
            f"when attempting to restore variable with shape {self._var_shape} "
            f"and name {self.name}.") from e
def _tensor_comes_from_variable(v):
  """Returns True iff `v` is a Tensor whose producing op type is in `_VARIABLE_OPS`."""
  if not isinstance(v, ops.Tensor):
    return False
  return v.op.type in _VARIABLE_OPS
def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  `op` may be one of the following:

  * a SaveableObject, which is yielded unchanged;
  * a list, tuple or PartitionedVariable of slices that all belong to the same
    full tensor;
  * a Trackable that is not a Variable, expanded through its saveable
    factories;
  * a single variable or a variable-backed tensor.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or if `op` converts to a Tensor that
      was not produced by a variable op.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  if not isinstance(name, str):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        f"trackable operations. Name is not a string: {name}")

  if isinstance(op, saveable_object.SaveableObject):
    yield op
    return

  # pylint: disable=protected-access
  if isinstance(op, (list, tuple, variables.PartitionedVariable)):
    # A set of slices which must all come from the same full tensor.
    parts = list(op) if isinstance(op, variables.PartitionedVariable) else op
    expected_full_name = None
    for part in parts:
      if isinstance(part, saveable_object.SaveableObject):
        yield part
        continue
      if not isinstance(part, variables.Variable):
        raise ValueError(f"Slices must all be Variables: {part}")
      if not part._save_slice_info:
        raise ValueError(f"Slices must all be slices: {part}")
      slice_info = part._save_slice_info
      if expected_full_name is None:
        expected_full_name = slice_info.full_name
      elif expected_full_name != slice_info.full_name:
        raise ValueError(
            f"Slices must all be from the same tensor: {expected_full_name} != "
            f"{slice_info.full_name}")
      if part.op.type in _REF_VARIABLE_OPS:
        yield ReferenceVariableSaveable(part, slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(part, slice_info.spec, name)
    return

  if (isinstance(op, trackable.Trackable) and
      not isinstance(op, variables.Variable)):
    factories = saveable_objects_from_trackable(op, tf1_saver=True)
    for attr, factory in factories.items():
      # Keep the original name for classes masquerading as variables and for
      # Trackables that define _serialize_to_tensors.
      if (attr == trackable.VARIABLE_VALUE_KEY or
          attr == trackable_utils.SERIALIZE_TO_TENSORS_NAME):
        saveable_name = name
      else:
        saveable_name = name + "_" + attr
      saveable = factory(saveable_name) if callable(factory) else factory
      for converted in saveable_objects_for_op(saveable, saveable.name):
        yield converted
    return

  # A single variable or tensor.
  if isinstance(op, resource_variable_ops.BaseResourceVariable):
    variable = op._graph_element if op._in_graph_mode else op
    yield ResourceVariableSaveable(variable, "", name)
    return

  if context.executing_eagerly():
    raise ValueError("Can only save/restore ResourceVariables when "
                     f"executing eagerly, got type: {type(op)}.")

  variable = ops.convert_to_tensor(op, as_ref=True)
  if not _tensor_comes_from_variable(variable):
    raise TypeError(
        "names_to_saveables must be a dict mapping string "
        f"names to Tensors/Variables. Not a variable: {variable}")
  saveable_cls = (ReferenceVariableSaveable
                  if variable.op.type in _REF_VARIABLE_OPS
                  else ResourceVariableSaveable)
  yield saveable_cls(variable, "", name)
  # pylint: enable=protected-access
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  This method is only used when the variable name matters (e.g. when saving
  or restoring from a TF1 name-based checkpoint). In TF2, this can be called
  from `tf.train.Checkpoint.restore` when loading from a name-based checkpoint.

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name. Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name. Note that
      SaveableObjects, and entries produced from Trackables, are stored
      without this check: a later entry silently replaces an earlier one.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    f"list. Got {op_list}")
  # `list()` lets sets be flattened. Converting ResourceVariables to Tensors
  # adds read ops to the graph, so sorting by name keeps graph construction
  # deterministic.
  flat_ops = nest.flatten(list(op_list))
  flat_ops.sort(key=lambda item: item.name)

  result = {}
  # pylint: disable=protected-access
  for var in flat_ops:
    is_plain_variable = isinstance(
        var, (resource_variable_ops.BaseResourceVariable,
              ref_variable.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      result[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in result:
        raise ValueError(
            f"At least two variables have the same name: {var.name}")
      result[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      # All slices of one tensor share a list keyed by the full tensor name.
      full_name = var._save_slice_info.full_name
      if full_name not in result:
        result[full_name] = [var]
      elif isinstance(result[full_name], list):
        result[full_name].append(var)
      else:
        raise ValueError("Mixing slices and non-slices with the same name: "
                         f"{full_name}")
    elif isinstance(var, trackable.Trackable) and not is_plain_variable:
      factories = saveable_objects_from_trackable(var, tf1_saver=True)
      trackable_saveables = [
          factory() if callable(factory) else factory
          for factory in factories.values()
      ]
      result.update(op_list_to_dict(trackable_saveables))
    elif not getattr(var, "_in_graph_mode", True):
      # Variables carry `_in_graph_mode`; Tensors seen while graph building do
      # not, so they default to the graph branch below.
      if not isinstance(var, resource_variable_ops.BaseResourceVariable):
        raise ValueError(
            "Can only save/restore ResourceVariables when eager execution "
            f"is enabled. Got type: {type(var)}.")
      registered = result.setdefault(var._shared_name, var)
      if registered is not var:
        raise ValueError(
            "Two different ResourceVariable objects with the same "
            f"shared_name '{var._shared_name}' were passed to the Saver. This"
            " likely means that they were created in different Graphs or "
            "isolated contexts, and may not be checkpointed together.")
    else:
      if convert_variable_to_tensor:
        if isinstance(var, resource_variable_ops.BaseResourceVariable):
          var = var._graph_element
        else:
          var = ops.convert_to_tensor(var, as_ref=True)
        if not _tensor_comes_from_variable(var):
          raise TypeError(f"Variable to save is not a Variable: {var}")
      # A ReadVariableOp is keyed by the name of the op feeding its first
      # input (the variable handle); anything else by its own op name.
      if var.op.type == "ReadVariableOp":
        source_op = var.op.inputs[0].op
      else:
        source_op = var.op
      var_name = source_op.name
      if var_name in result:
        raise ValueError(
            f"At least two variables have the same name: {var_name}")
      result[var_name] = var
  # pylint: enable=protected-access
  return result
336def _add_saveable(saveables, seen_ops, saveable):
337 """Adds the saveable to the saveables list.
339 Args:
340 saveables: List to append the SaveableObject to.
341 seen_ops: Set of the ops of the saveables already processed. Used to
342 check that each saveable is only saved once.
343 saveable: The saveable.
345 Raises:
346 ValueError: If the saveable has already been processed.
347 """
348 if saveable.op is not None and saveable.op in seen_ops:
349 raise ValueError("The same saveable will be restored with two names: "
350 f"{saveable.name}")
351 saveables.append(saveable)
352 seen_ops.add(saveable.op)
def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Entries are processed in order of their names. Each value is converted with
  `saveable_objects_for_op`, and every resulting saveable is checked so that
  no op is saved twice.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
      v is an operation to save or a BaseSaverBuilder.Saver.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  result = []
  seen_ops = object_identity.ObjectIdentitySet()
  # Sort by name only; the ops themselves are not comparable.
  ordered = sorted(names_to_saveables.items(), key=lambda pair: pair[0])
  for name, op in ordered:
    for saveable in saveable_objects_for_op(op, name):
      _add_saveable(result, seen_ops, saveable)
  return result
def validate_saveables_for_saved_model(saveables, obj):
  """Makes sure SaveableObjects are compatible with SavedModel.

  Args:
    saveables: The SaveableObjects generated for `obj`.
    obj: The object the saveables belong to.

  Returns:
    `saveables` unchanged, or an empty list when `obj` is a `PythonState` or
    when any saveable is a `trackable.NoRestoreSaveable`.
  """
  if isinstance(obj, python_state.PythonState):
    # `warning` rather than the deprecated `warn` alias; the message is
    # formatted lazily by the logger.
    logging.warning(
        "Note that object %s stores python values into the checkpoint. "
        "These values will not be restored when loading the SavedModel "
        "into python.", obj)
    return []
  if any(isinstance(saveable, trackable.NoRestoreSaveable)
         for saveable in saveables):
    return []
  return saveables
class RestoredSaveableObject(saveable_object.SaveableObject):
  """SaveableObject restored from SavedModel using the traced save/restore."""

  def __init__(self, names_and_slices, save_function, restore_function, name):
    self.save_function = save_function
    self.restore_function = restore_function

    # The traced save function receives the checkpoint key prefix as a tensor.
    if tensor_util.is_tf_type(name):
      name_tensor = name
    else:
      with ops.init_scope():
        name_tensor = constant_op.constant(name)
    tensors = save_function(name_tensor)
    # Pair each saved tensor with its (name suffix, slice spec) entry; zip
    # stops at the shorter of the two sequences.
    specs = [
        saveable_object.SaveSpec(tensor_info["tensor"], str_slice,
                                 name + str_name)
        for (str_name, str_slice), tensor_info in zip(names_and_slices, tensors)
    ]
    super(RestoredSaveableObject, self).__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Calls the traced restore function with one tensor per spec."""
    del restored_shapes  # unused
    num_specs = len(self.specs)
    return self.restore_function(
        *(restored_tensors[i] for i in range(num_specs)))
def recreate_saveable_objects(saveable_fn_by_name, temp_session):
  """Returns a dict of SaveableObject factories generated from loaded fns.

  Args:
    saveable_fn_by_name: Dict mapping a saveable name to a
      `(save_fn, restore_fn)` pair.
    temp_session: A one-element list holding a Session or `None`. When not
      executing eagerly and no default session exists, this session is used
      to evaluate tensor names and slice specs. If it is `None`, a new Session
      is created and stored back into `temp_session[0]`.

  Returns:
    A dict mapping each name in `saveable_fn_by_name` to a `functools.partial`
    of `RestoredSaveableObject` that still needs its `name` argument.
  """

  def _get_session():
    # Prefer the default session; otherwise lazily create/reuse a temporary
    # one that is shared through `temp_session`.
    sess = ops.get_default_session()
    if sess is None:
      if temp_session[0] is None:
        temp_session[0] = session.Session()
      sess = temp_session[0]
    return sess

  # Names/slices must be tracked per saveable. RestoredSaveableObject zips
  # this list with the tensors returned by *its own* save function, so a list
  # shared across saveables would give every saveable after the first the
  # names and slice specs of another saveable.
  names_and_slices_by_name = {}
  with ops.init_scope():
    for saveable_name, (save_fn, _) in saveable_fn_by_name.items():
      names_and_slices = []
      for tensor_info in save_fn(""):
        name = tensor_info["name"]
        slice_spec = tensor_info["slice_spec"]
        if not context.executing_eagerly():
          name, slice_spec = _get_session().run([name, slice_spec])
        names_and_slices.append(
            (_convert_to_string(name), _convert_to_string(slice_spec)))
      names_and_slices_by_name[saveable_name] = names_and_slices

  saveable_factories = {}
  for saveable_name, (save_fn, restore_fn) in saveable_fn_by_name.items():
    saveable_factories[saveable_name] = functools.partial(
        RestoredSaveableObject,
        names_and_slices=names_and_slices_by_name[saveable_name],
        save_function=save_fn,
        restore_function=restore_fn)
  return saveable_factories
def create_saveable_object(name, key, factory, call_with_mapped_captures):
  """Creates a SaveableObject while potentially in a different graph.

  When creating the frozen saver for SavedModel, the save and restore ops are
  placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
  save and restore, the function captures must be mapped to the new graph.

  Args:
    name: Name of SaveableObject factory.
    key: Checkpoint key of this SaveableObject.
    factory: Factory method for creating the SaveableObject.
    call_with_mapped_captures: Helper that calls a tf.function while remapping
      the captures.

  Returns:
    a SaveableObject.
  """
  if call_with_mapped_captures is not None:
    if name == trackable_utils.SERIALIZE_TO_TENSORS_NAME:
      # `_serialize_to_tensors` factories accept the helper directly.
      return factory(name=key,
                     call_with_mapped_captures=call_with_mapped_captures)
    if is_factory_for_restored_saveable_object(factory):
      # Wrap the traced functions so their captures are remapped on call.
      concrete_save_fn = factory.keywords["save_function"]
      concrete_restore_fn = factory.keywords["restore_function"]

      def save_fn(name):
        return call_with_mapped_captures(concrete_save_fn, [name])

      def restore_fn(*restored_tensors):
        return call_with_mapped_captures(concrete_restore_fn, restored_tensors)

      return factory(save_function=save_fn, restore_function=restore_fn,
                     name=key)
  return factory(name=key)
def is_factory_for_restored_saveable_object(factory):
  """Returns True iff `factory` is a `functools.partial` of RestoredSaveableObject."""
  if not isinstance(factory, functools.partial):
    return False
  return factory.func is RestoredSaveableObject
@tf_export("__internal__.tracking.saveable_objects_from_trackable", v1=[])
def saveable_objects_from_trackable(obj, tf1_saver=False):
  """Returns SaveableObject factory dict from a Trackable.

  Args:
    obj: A `Trackable`
    tf1_saver: Boolean, whether this is being called from a TF1 Saver (
      `tf.compat.v1.train.Saver`). When this is True, the SaveableObject will
      be generated from `obj`'s legacy `_gather_saveables_for_checkpoint` fn
      if that returns any factories. When saving with TF2,
      `Trackable._serialize_to_tensors` is preferred.

  Returns:
    A dict mapping attribute names to SaveableObject factories (callables that
    produce a SaveableObject).
  """
  if isinstance(obj, python_state.PythonState):
    return {
        python_state.PYTHON_STATE: functools.partial(
            _PythonStringStateSaveable,
            state_callback=obj.serialize,
            restore_callback=obj.deserialize)
    }

  if tf1_saver:
    legacy_factories = obj._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    if legacy_factories:
      return legacy_factories

  if not trackable_has_serialize_to_tensor(obj):
    return obj._gather_saveables_for_checkpoint()  # pylint: disable=protected-access

  def create_saveable(name="", call_with_mapped_captures=None):
    save_fn = obj._serialize_to_tensors  # pylint: disable=protected-access
    if (call_with_mapped_captures and
        isinstance(save_fn, core.ConcreteFunction)):
      tensor_dict = call_with_mapped_captures(save_fn, [])
    else:
      tensor_dict = save_fn()

    specs = []
    local_names = []
    for tensor_name, value in tensor_dict.items():
      local_names.append(tensor_name)
      # Normalize to a {slice_spec: tensor} mapping.
      by_slice = value if isinstance(value, dict) else {"": value}
      spec_name = name + trackable_utils.escape_local_name(tensor_name)
      # Create a separate spec for each slice spec.
      for slice_spec, tensor in by_slice.items():
        if isinstance(tensor, saveable_object.SaveSpec):
          # Reuse the given SaveSpec, overwriting its name and slice spec.
          tensor.name = spec_name
          tensor.slice_spec = slice_spec
          specs.append(tensor)
        else:
          specs.append(saveable_object.SaveSpec(tensor, slice_spec, spec_name))

    return TrackableSaveable(
        obj=obj,
        specs=specs,
        name=name,
        local_names=local_names,
        prefix=saveable_compat.get_saveable_name(obj) or "",
        call_with_mapped_captures=call_with_mapped_captures)

  return {trackable_utils.SERIALIZE_TO_TENSORS_NAME: create_saveable}
class TrackableSaveable(saveable_object.SaveableObject):
  """A SaveableObject that defines `Trackable` checkpointing steps."""

  def __init__(self, obj, specs, name, local_names, prefix,
               call_with_mapped_captures=None):
    self._prefix = prefix
    self._local_names = local_names
    self._trackable = obj
    self._call_with_mapped_captures = call_with_mapped_captures
    super(TrackableSaveable, self).__init__(obj, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restores via the Trackable's `_restore_from_tensors`.

    The i-th restored tensor is keyed by the i-th local name. Returns the
    restore function's result, or a no-op when it returns None. RefVariable
    specs outside eager mode short-circuit and return the restore function's
    result directly.
    """
    del restored_shapes  # Unused.
    restored_tensor_dict = {
        local_name: restored_tensors[i]
        for i, local_name in enumerate(self._local_names)
    }
    restore_fn = self._trackable._restore_from_tensors  # pylint: disable=protected-access

    # When restoring a RefVariable, call the restore function directly.
    # pylint: disable=protected-access
    if not ops.executing_eagerly_outside_functions():
      ref_flags = [spec._tensor.op.type in _REF_VARIABLE_OPS
                   for spec in self.specs
                   if isinstance(spec._tensor, ops.Tensor)]
      if any(ref_flags):
        return restore_fn(restored_tensor_dict)
    # pylint: enable=protected-access

    if (self._call_with_mapped_captures and
        isinstance(restore_fn, core.ConcreteFunction)):
      result = self._call_with_mapped_captures(restore_fn,
                                               [restored_tensor_dict])
    else:
      result = restore_fn(restored_tensor_dict)
    return gen_control_flow_ops.no_op() if result is None else result

  def get_proto_names_and_checkpoint_keys(self):
    """Returns (prefixed local name, spec name) pairs, one per zipped entry."""
    pairs = zip(self._local_names, self.specs)
    return [(self._prefix + local_name, spec.name)
            for local_name, spec in pairs]
class _PythonStringStateSaveable(saveable_object.SaveableObject):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state.
    """

    def _state_callback_wrapper():
      # Run the user callback under `ops.init_scope()`.
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    # Placeholder string constant on CPU; its fresh value is supplied through
    # `feed_dict_additions` when running a graph.
    with ops.device("/cpu:0"):
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(_PythonStringStateSaveable, self).__init__(
        self._save_string, [spec], name)

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      current = self._state_callback()
      return constant_op.constant(current, dtype=dtypes.string)

    return trackable.NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")
def trackable_has_serialize_to_tensor(obj):
  """Returns whether obj's class has `_serialize_to_tensors` defined.

  An instance attribute named `_serialize_to_tensors` counts as well (e.g. on
  restored objects). Otherwise the MRO is walked. The first class that defines
  `_serialize_to_tensors` gives True. The first class that defines
  `_gather_saveables_for_checkpoint`, or reaching `Trackable` itself, gives
  False.
  """
  try:
    on_instance = "_serialize_to_tensors" in obj.__dict__
  except (AttributeError, TypeError):
    # Data structure proxy wrappers don't have __dict__.
    on_instance = False
  if on_instance:
    return True

  # Walking the MRO means a subclass that has not been migrated keeps using
  # its own `_gather_saveables_for_checkpoint` even if a parent class defines
  # `_serialize_to_tensors`.
  for klass in type(obj).mro():
    if klass is trackable.Trackable:
      # The base implementation of _serialize_to_tensors raises
      # NotImplementedError.
      return False
    if "_serialize_to_tensors" in klass.__dict__:
      return True
    if "_gather_saveables_for_checkpoint" in klass.__dict__:
      return False
  return False
def _convert_to_string(x):
  """Evaluates `x` with `tensor_util.constant_value` and returns it as `str`."""
  value = tensor_util.constant_value(x)
  return compat.as_str(value)
class SaveableCompatibilityConverter(trackable.Trackable):
  """Converts object's `SaveableObjects` to functions used in TF2 checkpointing.

  A class that converts a Trackable object's `SaveableObjects` to save and
  restore functions with the same signatures as
  `Trackable._serialize_to_tensors` and `Trackable._restore_from_tensors`.
  This class also produces a method for filling the object proto.
  """

  __slots__ = ("_obj", "_saveables")

  def __init__(self, obj, saveables):
    """Constructor.

    Args:
      obj: A Trackable object.
      saveables: A list of saveables for `obj`.
    """
    self._obj = obj
    self._saveables = saveables

  @property
  def obj(self):
    """The Trackable object the saveables were generated from."""
    return self._obj

  @property
  def saveables(self):
    """Returns a list of SaveableObjects generated from the Trackable object."""
    return self._saveables

  def _serialize_to_tensors(self):
    """Returns a dict of tensors to serialize."""
    return saveable_object_to_tensor_dict(self.saveables)

  def _restore_from_tensors(self, restored_tensors):
    """Returns the restore ops defined in the Saveables.

    Raises:
      ValueError: If the keys of `restored_tensors` differ from the local
        names of the saveables' specs.
    """
    # The restored tensors must match the expected attributes exactly before
    # they are dispatched to each SaveableObject's restore.
    expected_keys = [
        trackable_utils.extract_local_name(_convert_to_string(spec.name))
        for saveable in self.saveables
        for spec in saveable.specs
    ]
    if set(expected_keys) != restored_tensors.keys():
      raise ValueError(f"Could not restore object {self._obj} because not all "
                       "expected tensors were in the checkpoint."
                       f"\n\tExpected: {expected_keys}"
                       f"\n\tGot: {list(restored_tensors.keys())}")

    restore = saveable_object_to_restore_fn(self.saveables)
    return restore(restored_tensors)
def saveable_object_to_tensor_dict(saveables):
  """Converts a list of SaveableObjects to a tensor dictionary.

  Each spec is keyed by its name. Specs with a non-empty slice spec are nested
  as `{name: {slice_spec: value}}`.
  """
  result = {}
  for saveable in saveables:
    for spec in saveable.specs:
      key = _convert_to_string(spec.name)
      slice_spec = _convert_to_string(spec.slice_spec)
      raw = spec._tensor  # pylint: disable=protected-access
      # The tensor dict cannot hold callable tensor values (needed for
      # uninitialized variables), so keep the SaveSpec itself for those.
      value = spec if callable(raw) else raw
      if not slice_spec:
        result[key] = value
      else:
        result.setdefault(key, {})[slice_spec] = value
  return result
def saveable_object_to_restore_fn(saveables):
  """Generates `Trackable._restore_from_tensors` from SaveableObjects.

  The returned function maps each saveable's name to the result of calling its
  `restore` with the tensors looked up by local name and slice spec.
  """

  def _restore_from_tensors(restored_tensors):
    restore_ops = {}
    for saveable in saveables:
      tensors_for_saveable = []
      for spec in saveable.specs:
        local_name = trackable_utils.extract_local_name(
            _convert_to_string(spec.name))
        slice_spec = _convert_to_string(spec.slice_spec)
        value = restored_tensors[local_name]
        # Unsliced values are stored directly; normalize to {slice: tensor}.
        by_slice = value if isinstance(value, dict) else {"": value}
        tensors_for_saveable.append(by_slice[slice_spec])
      restore_ops[saveable.name] = saveable.restore(
          tensors_for_saveable, restored_shapes=None)
    return restore_ops

  return _restore_from_tensors
def serialized_tensors_to_saveable_cache(serialized_tensors):
  """Converts a tensor dict to a SaveableObject cache.

  Args:
    serialized_tensors: Map from Trackable to a tensor dict. The tensor dict
      maps checkpoint key (-> slice_spec) -> Tensor

  Returns:
    A dict mapping Trackable objects to a map from local savable name to
    SaveableObject.
  """
  saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()

  for obj, tensor_dict in serialized_tensors.items():
    if not tensor_dict:
      continue

    if isinstance(obj, SaveableCompatibilityConverter):
      # Reuse the converter's original SaveableObjects, keyed by local name.
      saveables_cache[obj.obj] = {
          trackable_utils.extract_local_name(saveable.name): [saveable]
          for saveable in obj.saveables
      }
      continue

    # The local names and prefix let the generated TrackableSaveable call
    # `Trackable._restore_from_tensors()`.
    prefix = saveable_compat.get_saveable_name(obj) or ""
    specs = []
    local_names = []
    for checkpoint_key, value in tensor_dict.items():
      # Normalize to a {slice_spec: tensor} mapping.
      by_slice = value if isinstance(value, dict) else {"": value}
      for slice_spec, tensor in by_slice.items():
        if isinstance(tensor, saveable_object.SaveSpec):
          specs.append(tensor)
        else:
          specs.append(
              saveable_object.SaveSpec(tensor, slice_spec, checkpoint_key))
      local_names.append(
          trackable_utils.extract_local_name(checkpoint_key, prefix))

    # The object name is derived from the first checkpoint key.
    first_key = next(iter(tensor_dict.keys()))
    object_name = trackable_utils.extract_object_name(first_key)
    saveable = TrackableSaveable(
        obj, specs, object_name, local_names=local_names, prefix=prefix)
    saveables_cache[obj] = {
        trackable_utils.SERIALIZE_TO_TENSORS_NAME: [saveable]}
  return saveables_cache