Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/base_serialization.py: 65%

31 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helper classes that list&validate all attributes to serialize to 

16SavedModel.""" 

17 

18from __future__ import absolute_import 

19from __future__ import division 

20from __future__ import print_function 

21 

22import abc 

23 

24from keras.src.saving.legacy.saved_model import json_utils 

25from keras.src.saving.legacy.saved_model import utils 

26 

27 

class SavedModelSaver(metaclass=abc.ABCMeta):
    """Saver defining the methods and properties used to serialize Keras
    objects.

    Concrete subclasses must provide `object_identifier`,
    `python_properties`, `objects_to_serialize` and
    `functions_to_serialize`.
    """

    def __init__(self, obj):
        """Stores the object that this saver serializes.

        Args:
            obj: The object to serialize; kept as `self.obj`.
        """
        self.obj = obj

    @property
    @abc.abstractmethod
    def object_identifier(self):
        """String stored in object identifier field in the SavedModel proto.

        Returns:
            A string with the object identifier, which is used at load time.
        """
        raise NotImplementedError

    @property
    def tracking_metadata(self):
        """String stored in metadata field in the SavedModel proto.

        Returns:
            A serialized JSON storing information necessary for recreating
            this layer.
        """
        # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
        # object is in the python property)
        return json_utils.Encoder().encode(self.python_properties)

    def trackable_children(self, serialization_cache):
        """Lists all Trackable children connected to this object.

        Args:
            serialization_cache: Dictionary passed to all objects in the same
                object graph during serialization.

        Returns:
            An empty dictionary when traces should not be saved. Otherwise a
            new dictionary holding the entries of `objects_to_serialize`
            updated with the entries of `functions_to_serialize`.
        """
        if not utils.should_save_traces():
            return {}

        # Merge into a copy: `objects_to_serialize` may hand back a dictionary
        # that the subclass keeps (e.g. in the serialization cache), and
        # updating it in place would leak the function entries into it.
        children = dict(self.objects_to_serialize(serialization_cache))
        children.update(self.functions_to_serialize(serialization_cache))
        return children

    @property
    @abc.abstractmethod
    def python_properties(self):
        """Returns dictionary of python properties to save in the metadata.

        This dictionary must be serializable and deserializable to/from JSON.

        When loading, the items in this dict are used to initialize the object
        and define attributes in the revived object.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def objects_to_serialize(self, serialization_cache):
        """Returns dictionary of extra checkpointable objects to serialize.

        See `functions_to_serialize` for an explanation of this function's
        effects.

        Args:
            serialization_cache: Dictionary passed to all objects in the same
                object graph during serialization.

        Returns:
            A dictionary mapping attribute names to checkpointable objects.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def functions_to_serialize(self, serialization_cache):
        """Returns extra functions to include when serializing a Keras object.

        Normally, when exporting an object to SavedModel, only the
        functions and objects defined by the user are saved. For example:

        ```
        obj = tf.Module()
        obj.v = tf.Variable(1.)

        @tf.function
        def foo(...): ...

        obj.foo = foo

        w = tf.Variable(1.)

        tf.saved_model.save(obj, 'path/to/saved/model')
        loaded = tf.saved_model.load('path/to/saved/model')

        loaded.v  # Variable with the same value as obj.v
        loaded.foo  # Equivalent to obj.foo
        loaded.w  # AttributeError
        ```

        Assigning trackable objects to attributes creates a graph, which is
        used for both checkpointing and SavedModel serialization.

        When the graph generated from attribute tracking is insufficient,
        extra objects and functions may be added at serialization time. For
        example, most models do not have their call function wrapped with a
        @tf.function decorator. This results in `model.call` not being saved.
        Since Keras objects should be revivable from the SavedModel format,
        the call function is added as an extra function to serialize.

        This function and `objects_to_serialize` are called multiple times
        when exporting to SavedModel. Please use the cache to avoid generating
        new functions and objects. A fresh cache is created for each
        SavedModel export.

        Args:
            serialization_cache: Dictionary passed to all objects in the same
                object graph during serialization.

        Returns:
            A dictionary mapping attribute names to `Function` or
            `ConcreteFunction`.
        """
        raise NotImplementedError

142