Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py: 38%

45 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras SavedModel serialization.""" 

16 

17import os 

18 

19from tensorflow.python.keras import backend as K 

20from tensorflow.python.keras.protobuf import saved_metadata_pb2 

21from tensorflow.python.keras.protobuf import versions_pb2 

22from tensorflow.python.keras.saving import saving_utils 

23from tensorflow.python.keras.saving.saved_model import constants 

24from tensorflow.python.keras.saving.saved_model import save_impl 

25from tensorflow.python.keras.saving.saved_model import utils 

26from tensorflow.python.keras.utils.generic_utils import LazyLoader 

27from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 

28from tensorflow.python.platform import gfile 

29from tensorflow.python.saved_model import save as save_lib 

30 

31# To avoid circular dependencies between keras/engine and keras/saving, 

32# code in keras/saving must delay imports. 

33 

# These engine modules are bound through LazyLoader rather than a regular
# import statement so that loading this module does not immediately import
# keras/engine (see the circular-dependency note above).
base_layer = LazyLoader(
    "base_layer",
    globals(),
    "tensorflow.python.keras.engine.base_layer")
training_lib = LazyLoader(
    "training_lib",
    globals(),
    "tensorflow.python.keras.engine.training")

40 

41 

def save(model, filepath, overwrite, include_optimizer, signatures=None,
         options=None, save_traces=True):
  """Saves a model as a SavedModel to the filepath.

  When `include_optimizer` is False the model's optimizer is temporarily
  detached for the duration of the save and is always reattached afterwards,
  including when saving fails with an exception.

  Args:
    model: Keras model instance to be saved.
    filepath: String path to save the model.
    overwrite: whether to overwrite the existing filepath.
    include_optimizer: If True, save the model's optimizer state.
    signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
      format only. Please see the `signatures` argument in `tf.saved_model.save`
      for details.
    options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
      object that specifies options for saving to SavedModel.
    save_traces: (only applies to SavedModel format) When enabled, the
      SavedModel will store the function traces for each layer. This
      can be disabled, so that only the configs of each layer are stored.
      Defaults to `True`. Disabling this will decrease serialization time
      and reduce file size, but it requires that all custom layers/models
      implement a `get_config()` method.

  Returns:
    None. Returns early, without saving, if `filepath` already exists,
    `overwrite` is False and the user declines the overwrite prompt.

  Raises:
    ValueError: if the model's inputs have not been defined.
  """
  # If file exists and should not be overwritten.
  if not overwrite and os.path.exists(filepath):
    proceed = ask_to_proceed_with_overwrite(filepath)
    if not proceed:
      return

  if save_traces:
    if save_impl.should_skip_serialization(model):
      saving_utils.raise_model_input_error(model)

  if not include_optimizer:
    orig_optimizer = model.optimizer
    model.optimizer = None
    # TODO(b/180760306) Change to del model.optimizer if Layer's __delattr__
    # calls AutoTrackable's __delattr__.
    model._delete_tracking("optimizer")  # pylint: disable=protected-access

  try:
    # Trace all functions and signatures with `training=0` instead of using an
    # already-set learning phase placeholder.
    # This is needed for compatibility reasons until learning phase setting
    # is removed from the public apis.
    with K.deprecated_internal_learning_phase_scope(0):
      with utils.keras_option_scope(save_traces):
        saved_nodes, node_paths = save_lib.save_and_return_nodes(
            model, filepath, signatures, options)

      # Save all metadata to a separate file in the SavedModel directory.
      metadata = generate_keras_metadata(saved_nodes, node_paths)

    with gfile.GFile(
        os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
      w.write(metadata.SerializeToString(deterministic=True))
  finally:
    # Reattach the optimizer even if tracing, serialization or the metadata
    # write raised; otherwise a failed save would leave the caller's model
    # with `optimizer` set to None.
    if not include_optimizer:
      model.optimizer = orig_optimizer

101 

102 

def generate_keras_metadata(saved_nodes, node_paths):
  """Builds a SavedMetadata proto with one entry per Keras layer node.

  Args:
    saved_nodes: Sequence of saved objects; an object's position in this
      sequence is recorded as its `node_id`.
    node_paths: Mapping from a saved object to the list of references leading
      to it; each reference's `name` becomes one component of the dotted
      `node_path`.

  Returns:
    The populated `saved_metadata_pb2.SavedMetadata` message. Objects that
    are not `base_layer.Layer` instances are skipped.
  """
  result = saved_metadata_pb2.SavedMetadata()

  for index, obj in enumerate(saved_nodes):
    if not isinstance(obj, base_layer.Layer):
      continue

    refs = node_paths[obj]
    # An empty path means the object is the root of the saved object graph.
    dotted_path = (
        "root.{}".format(".".join([ref.name for ref in refs]))
        if refs else "root")

    version_def = versions_pb2.VersionDef(
        producer=1, min_consumer=1, bad_consumers=[])
    result.nodes.add(
        node_id=index,
        node_path=dotted_path,
        version=version_def,
        identifier=obj._object_identifier,  # pylint: disable=protected-access
        metadata=obj._tracking_metadata)  # pylint: disable=protected-access

  return result