Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py: 61%

28 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helper classes that list&validate all attributes to serialize to SavedModel.""" 

16 

17import abc 

18 

19from tensorflow.python.keras.saving.saved_model import json_utils 

20from tensorflow.python.keras.saving.saved_model import utils 

21 

22 

class SavedModelSaver(metaclass=abc.ABCMeta):
  """Saver defining the methods and properties used to serialize Keras objects.

  Subclasses must implement `object_identifier`, `python_properties`,
  `objects_to_serialize` and `functions_to_serialize`.

  Attributes:
    obj: The object being serialized, exactly as passed to the constructor.
  """

  def __init__(self, obj):
    self.obj = obj

  @property
  @abc.abstractmethod
  def object_identifier(self):
    """String stored in object identifier field in the SavedModel proto.

    Returns:
      A string with the object identifier, which is used at load time.
    """
    raise NotImplementedError

  @property
  def tracking_metadata(self):
    """String stored in metadata field in the SavedModel proto.

    Returns:
      `python_properties` serialized to a JSON string with
      `json_utils.Encoder`, storing information necessary for recreating this
      layer.
    """
    # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
    # object is in the python property)
    return json_utils.Encoder().encode(self.python_properties)

  def trackable_children(self, serialization_cache):
    """Lists all Trackable children connected to this object.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      An empty dictionary when `utils.should_save_traces()` is false.
      Otherwise a new dictionary holding the entries of
      `objects_to_serialize` merged with those of `functions_to_serialize`
      (function entries win on key collisions).
    """
    if not utils.should_save_traces():
      return {}

    # Copy before merging. Implementations are asked to reuse results kept in
    # `serialization_cache`, so the returned dict may be a cached object;
    # updating it in place would leak the functions into that cached mapping.
    children = dict(self.objects_to_serialize(serialization_cache))
    children.update(self.functions_to_serialize(serialization_cache))
    return children

  @property
  @abc.abstractmethod
  def python_properties(self):
    """Returns dictionary of python properties to save in the metadata.

    This dictionary must be serializable and deserializable to/from JSON.

    When loading, the items in this dict are used to initialize the object and
    define attributes in the revived object.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def objects_to_serialize(self, serialization_cache):
    """Returns dictionary of extra checkpointable objects to serialize.

    See `functions_to_serialize` for an explanation of this function's
    effects.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A dictionary mapping attribute names to checkpointable objects.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def functions_to_serialize(self, serialization_cache):
    """Returns extra functions to include when serializing a Keras object.

    Normally, when exporting an object to SavedModel, only the functions and
    objects defined by the user are saved. For example:

    ```
    obj = tf.Module()
    obj.v = tf.Variable(1.)

    @tf.function
    def foo(...): ...

    obj.foo = foo

    w = tf.Variable(1.)

    tf.saved_model.save(obj, 'path/to/saved/model')
    loaded = tf.saved_model.load('path/to/saved/model')

    loaded.v  # Variable with the same value as obj.v
    loaded.foo  # Equivalent to obj.foo
    loaded.w  # AttributeError
    ```

    Assigning trackable objects to attributes creates a graph, which is used for
    both checkpointing and SavedModel serialization.

    When the graph generated from attribute tracking is insufficient, extra
    objects and functions may be added at serialization time. For example,
    most models do not have their call function wrapped with a @tf.function
    decorator. This results in `model.call` not being saved. Since Keras objects
    should be revivable from the SavedModel format, the call function is added
    as an extra function to serialize.

    This function and `objects_to_serialize` are called multiple times when
    exporting to SavedModel. Please use the cache to avoid generating new
    functions and objects. A fresh cache is created for each SavedModel export.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A dictionary mapping attribute names to `Function` or
      `ConcreteFunction`.
    """
    raise NotImplementedError