Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/saving/saveable_object.py: 36%

28 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Types for specifying saving and loading behavior.""" 

16 

17 

class SaveSpec:
  """Class used to describe tensor slices that need to be saved."""

  def __init__(self, tensor, slice_spec, name, dtype=None, device=None):
    """Creates a `SaveSpec` object.

    Args:
      tensor: the tensor to save or callable that produces a tensor to save.
        If the value is `None`, the `SaveSpec` is ignored.
      slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
      name: the name to save the tensor under.
      dtype: The data type of the Tensor. Required if `tensor` is callable.
        Used for error checking in the restore op. When `tensor` is `None`,
        the given value (possibly `None`) is stored unchanged.
      device: The device generating and consuming this tensor. Required if
        `tensor` is callable. Used to group objects to save by device. For a
        non-callable tensor, a non-`None` value overrides `tensor.device`.

    Raises:
      AssertionError: if `tensor` is callable and `dtype` or `device` is not
        provided.
    """
    self._tensor = tensor
    self.slice_spec = slice_spec
    self.name = name
    if callable(tensor):
      # A callable defers tensor creation, so its metadata cannot be read
      # from the tensor itself and must be supplied up front.
      if dtype is None or device is None:
        raise AssertionError(
            "When passing a callable `tensor` to a SaveSpec, an explicit "
            "dtype and device must be provided.")
      self.dtype = dtype
      self.device = device
    elif tensor is None:
      # The docstring promises that a `None` tensor is ignored. Without this
      # branch, construction crashed on `None.dtype`; keep whatever metadata
      # the caller supplied instead.
      self.dtype = dtype
      self.device = device
    else:
      self.dtype = tensor.dtype
      self.device = device if device is not None else tensor.device

  @property
  def tensor(self):
    """The value to save: the callable's result if callable, else the value."""
    return self._tensor() if callable(self._tensor) else self._tensor

54 

55 

class SaveableObject:
  """Base class for saving and restoring saveable objects."""

  def __init__(self, op, specs, name):
    """Creates a `SaveableObject` object.

    Args:
      op: the "producer" object that this class wraps; it produces a list of
        tensors to save. E.g., a "Variable" object saving its backing tensor.
      specs: a list of SaveSpec, each element of which describes one tensor to
        save under this object. All Tensors must be on the same device.
      name: the name to save the object under.
    """
    self.op, self.specs, self.name = op, specs, name

  @property
  def device(self):
    """The device of the first SaveSpec (all specs share one device)."""
    first_spec = self.specs[0]
    return first_spec.device

  def restore(self, restored_tensors, restored_shapes):
    """Restores this object from 'restored_tensors'.

    Subclasses must override this; the base implementation always raises.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint
      restored_shapes: the shapes this object should conform to after
        restore, or None.

    Returns:
      An operation that restores the state of the object.

    Raises:
      ValueError: If the object cannot be restored using the provided
        parameters.
    """
    del restored_tensors, restored_shapes  # Unused in the abstract base.
    raise ValueError("Calling an abstract method.")