Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/saving/saveable_object.py: 36%
28 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Types for specifying saving and loading behavior."""
class SaveSpec:
    """Class used to describe tensor slices that need to be saved."""

    def __init__(self, tensor, slice_spec, name, dtype=None, device=None):
        """Creates a `SaveSpec` object.

        Args:
          tensor: the tensor to save or callable that produces a tensor to save.
            If the value is `None`, the `SaveSpec` is ignored.
          slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
          name: the name to save the tensor under.
          dtype: The data type of the Tensor. Required if `tensor` is callable.
            Used for error checking in the restore op. When `tensor` is a
            concrete tensor, its own `dtype` is used and this argument is
            ignored. When `tensor` is `None`, this value is stored as-is.
          device: The device generating and consuming this tensor. Required if
            `tensor` is callable. Used to group objects to save by device. When
            `tensor` is a concrete tensor and this is `None`, `tensor.device`
            is used instead.

        Raises:
          AssertionError: if `tensor` is callable and either `dtype` or
            `device` is missing.
        """
        self._tensor = tensor
        self.slice_spec = slice_spec
        self.name = name
        if callable(tensor):
            # The tensor is produced lazily, so its dtype/device cannot be
            # inspected now; the caller has to supply them up front.
            if dtype is None or device is None:
                raise AssertionError(
                    "When passing a callable `tensor` to a SaveSpec, an explicit "
                    "dtype and device must be provided.")
            self.dtype = dtype
            self.device = device
        elif tensor is None:
            # The documented "ignored" case: there is no tensor to read
            # `.dtype`/`.device` from, so previously this raised
            # AttributeError. Fall back to whatever the caller passed.
            self.dtype = dtype
            self.device = device
        else:
            self.dtype = tensor.dtype
            # An explicit device overrides the tensor's own placement.
            self.device = device if device is not None else tensor.device

    @property
    def tensor(self):
        """The tensor to save, invoking the stored callable if there is one."""
        return self._tensor() if callable(self._tensor) else self._tensor
class SaveableObject:
    """Base class for saving and restoring saveable objects.

    A subclass wraps some producer object, describes the tensors it
    contributes to a checkpoint through a list of `SaveSpec`s, and overrides
    `restore` (the base implementation only raises).
    """

    def __init__(self, op, specs, name):
        """Creates a `SaveableObject` object.

        Args:
          op: the "producer" object being wrapped; it yields the tensors to
            save, e.g. a "Variable" object saving its backing tensor.
          specs: list of `SaveSpec`, one per tensor saved under this object.
            All Tensors must be on the same device.
          name: the name under which this object is saved.
        """
        self.op, self.specs, self.name = op, specs, name

    @property
    def device(self):
        """Device of the SaveSpec Tensors, read from the first spec.

        Only the first spec is consulted because all specs are required to
        share one device.
        """
        leading_spec = self.specs[0]
        return leading_spec.device

    def restore(self, restored_tensors, restored_shapes):
        """Restores this object from `restored_tensors`.

        Args:
          restored_tensors: the tensors that were loaded from a checkpoint.
          restored_shapes: the shapes this object should conform to after
            restore, or None.

        Returns:
          An operation that restores the state of the object.

        Raises:
          ValueError: If the object cannot be restored using the provided
            parameters. The base class always raises, since it is abstract.
        """
        # Arguments are only meaningful to subclass overrides.
        del restored_tensors, restored_shapes
        raise ValueError("Calling an abstract method.")