Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/save.py: 35%

54 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras legacy SavedModel saving.""" 

16 

17import os 

18 

19import tensorflow.compat.v2 as tf 

20from absl import logging 

21 

22from keras.src import backend 

23from keras.protobuf import saved_metadata_pb2 

24from keras.protobuf import versions_pb2 

25from keras.src.saving.legacy import saving_utils 

26from keras.src.saving.legacy import serialization 

27from keras.src.saving.legacy.saved_model import constants 

28from keras.src.saving.legacy.saved_model import save_impl 

29from keras.src.saving.legacy.saved_model import utils 

30from keras.src.utils.generic_utils import LazyLoader 

31from keras.src.utils.io_utils import ask_to_proceed_with_overwrite 

32 

33# isort: off 

34from tensorflow.python.saved_model import save as save_lib 

35 

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.
#
# The first argument of each LazyLoader is the global name the loader is bound
# to, and `globals()` is this module's namespace. Keep the two in sync when
# renaming.
# NOTE(review): presumably LazyLoader imports the target module on first
# attribute access and rebinds the global; confirm in
# keras.src.utils.generic_utils.

base_layer = LazyLoader("base_layer", globals(), "keras.src.engine.base_layer")
training_lib = LazyLoader("training_lib", globals(), "keras.src.engine.training")

41 

42 

def save(
    model,
    filepath,
    overwrite,
    include_optimizer,
    signatures=None,
    options=None,
    save_traces=True,
):
    """Saves a model as a SavedModel to the filepath.

    Besides the SavedModel itself, a serialized Keras metadata proto is written
    to `constants.SAVED_METADATA_PATH` inside `filepath`.

    Args:
      model: Keras model instance to be saved.
      filepath: String path to save the model.
      overwrite: whether to overwrite the existing filepath. If `False` and
        `filepath` already exists, `ask_to_proceed_with_overwrite` is
        consulted. If it declines, this function returns without saving.
      include_optimizer: If True, save the model's optimizer state. If False,
        the optimizer is detached from `model` while saving and re-attached
        afterwards. Re-attachment also happens when saving raises.
      signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
        format only. Please see the `signatures` argument in
        `tf.saved_model.save` for details.
      options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
        object that specifies options for saving to SavedModel.
      save_traces: (only applies to SavedModel format) When enabled, the
        SavedModel will store the function traces for each layer. This
        can be disabled, so that only the configs of each layer are stored.
        Defaults to `True`. Disabling this will decrease serialization time
        and reduce file size, but it requires that all custom layers/models
        implement a `get_config()` method.

    Raises:
      ValueError: if the model's inputs have not been defined.
    """
    # If file exists and should not be overwritten.
    if not overwrite and os.path.exists(filepath):
        proceed = ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    # Tracing functions requires known model inputs; fail early otherwise.
    if save_traces and save_impl.should_skip_serialization(model):
        saving_utils.raise_model_input_error(model)

    if not include_optimizer:
        orig_optimizer = model.optimizer
        model.optimizer = None
        # TODO(b/180760306) Change to del model.optimizer if Layer's __delattr__
        # calls AutoTrackable's __delattr__.
        model._delete_tracking("optimizer")

    try:
        # Trace all functions and signatures with `training=0` instead of
        # using an already-set learning phase placeholder.
        # This is needed for compatibility reasons until learning phase setting
        # is removed from the public apis.
        with serialization.SharedObjectSavingScope():
            with backend.deprecated_internal_learning_phase_scope(0):
                with utils.keras_option_scope(save_traces):
                    saved_nodes, node_paths = save_lib.save_and_return_nodes(
                        model, filepath, signatures, options
                    )

            # Save all metadata to a separate file in the SavedModel
            # directory. Generated while the shared-object scope is still
            # active.
            metadata = generate_keras_metadata(saved_nodes, node_paths)

        with tf.io.gfile.GFile(
            tf.io.gfile.join(filepath, constants.SAVED_METADATA_PATH), "wb"
        ) as w:
            w.write(metadata.SerializeToString(deterministic=True))
    finally:
        # Restore the optimizer even if saving failed. Otherwise a failed save
        # would leave the caller's model without its optimizer.
        if not include_optimizer:
            model.optimizer = orig_optimizer

112 

113 

def generate_keras_metadata(saved_nodes, node_paths):
    """Constructs a KerasMetadata proto with the metadata of each object.

    Only nodes that are instances of `base_layer.Layer` get an entry; all
    other saved nodes are skipped. A warning is logged for any layer whose
    class name matches a built-in Keras layer it does not subclass.

    Args:
      saved_nodes: Sequence of saved objects, e.g. as returned by
        `save_lib.save_and_return_nodes`. An object's position in this
        sequence becomes its `node_id`.
      node_paths: Mapping from a saved node to the sequence of references that
        lead to it from the root. Each reference's `.name` is one path
        component, and an empty path denotes the root object itself.

    Returns:
      A `saved_metadata_pb2.SavedMetadata` proto with one node entry per
      layer.
    """
    # Deferred import, hoisted out of the loop so it runs once per call.
    # NOTE(review): presumably deferred for the same circular-import reason as
    # the module-level LazyLoaders; confirm before moving it to file scope.
    from keras.src.layers import serialization as layers_serialization

    metadata = saved_metadata_pb2.SavedMetadata()
    for node_id, node in enumerate(saved_nodes):
        if not isinstance(node, base_layer.Layer):
            continue

        path = node_paths[node]
        if path:
            node_path = "root." + ".".join(ref.name for ref in path)
        else:
            node_path = "root"

        metadata.nodes.add(
            node_id=node_id,
            node_path=node_path,
            version=versions_pb2.VersionDef(
                producer=2, min_consumer=1, bad_consumers=[]
            ),
            identifier=node._object_identifier,
            metadata=node._tracking_metadata,
        )

        # Log warning if the node's class name conflicts with a Keras
        # built-in object.
        class_name = node.__class__.__name__
        builtin_layer = layers_serialization.get_builtin_layer(class_name)
        if builtin_layer and not isinstance(node, builtin_layer):
            logging.warning(
                "%s has the same name '%s' as a built-in Keras "
                "object. Consider renaming %s to avoid naming "
                "conflicts when loading with "
                "`tf.keras.models.load_model`. "
                "If renaming is not possible, pass "
                "the object in the `custom_objects` "
                "parameter of the load "
                "function.",
                node,
                class_name,
                node.__class__,
            )

    return metadata

158