Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/method_name_updater.py: 32%

38 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""SignatureDef method name utility functions. 

16 

17Utility functions for manipulating the method_name field of signature_defs. 

18""" 

19 

20from tensorflow.python.lib.io import file_io 

21from tensorflow.python.platform import tf_logging 

22from tensorflow.python.saved_model import constants 

23from tensorflow.python.saved_model import loader_impl as loader 

24from tensorflow.python.util import compat 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

28# TODO(jdchung): Consider integrating this into the saved_model_cli so that users 

29# could do this from the command line directly. 

@tf_export(v1=["saved_model.signature_def_utils.MethodNameUpdater"])
class MethodNameUpdater(object):
  """Updates the method name(s) of the SavedModel stored in the given path.

  The `MethodNameUpdater` class provides the functionality to update the method
  name field in the signature_defs of the given SavedModel. For example, it
  can be used to replace the `predict` `method_name` with `regress`.

  Typical usage of the `MethodNameUpdater`:
  ```python
  ...
  updater = tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater(
      export_dir)
  # Update all signature_defs with key "foo" in all meta graph defs.
  updater.replace_method_name(signature_key="foo", method_name="regress")
  # Update a single signature_def with key "bar" in the meta graph def with
  # tags ["serve"]
  updater.replace_method_name(signature_key="bar", method_name="classify",
                              tags="serve")
  updater.save(new_export_dir)
  ```

  Note: This class is only available through the v1 compatibility
  library as tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater.
  """

  def __init__(self, export_dir):
    """Creates a MethodNameUpdater object.

    Args:
      export_dir: Directory containing the SavedModel files.

    Raises:
      IOError: If the saved model file does not exist, or cannot be successfully
        parsed.
    """
    self._export_dir = export_dir
    # Parsed SavedModel proto: `replace_method_name` edits it in memory and
    # `save` serializes it back to disk.
    self._saved_model = loader.parse_saved_model(export_dir)

  def replace_method_name(self, signature_key, method_name, tags=None):
    """Replaces the method_name in the specified signature_def.

    This will match and replace multiple sig defs iff tags is None (i.e. when
    multiple `MetaGraph`s have a signature_def with the same key).
    If tags is not None, this will only replace the signature_def in the
    `MetaGraph` whose tag set equals `tags`.

    The update is all-or-nothing: every matching `MetaGraph` is validated
    before any of them is modified, so a `ValueError` leaves the in-memory
    SavedModel unchanged.

    Args:
      signature_key: Key of the signature_def to be updated.
      method_name: new method_name to replace the existing one.
      tags: A single tag, or a list/tuple/set of tags, identifying the
        `MetaGraph` to update. If None, all meta graphs will be updated.

    Raises:
      ValueError: if signature_key or method_name are not defined or
        if no metagraphs were found with the associated tags or
        if a matching meta graph has no signature_def with key signature_key.
    """
    if not signature_key:
      raise ValueError("`signature_key` must be defined.")
    if not method_name:
      raise ValueError("`method_name` must be defined.")

    if tags is not None:
      if isinstance(tags, (list, tuple, set, frozenset)):
        # A collection of tags. Only `list` used to be recognized here, so a
        # tuple such as ("serve", "gpu") was treated as one tag and never
        # matched any MetaGraphDef.
        tags = list(tags)
      else:
        # A single tag.
        tags = [tags]
    wanted_tags = None if tags is None else set(tags)

    matching_meta_graphs = [
        meta_graph_def for meta_graph_def in self._saved_model.meta_graphs
        if wanted_tags is None or
        wanted_tags == set(meta_graph_def.meta_info_def.tags)
    ]

    if not matching_meta_graphs:
      raise ValueError(
          f"MetaGraphDef associated with tags {tags} could not be found in "
          "SavedModel. This means either you specified invalid tags or your "
          "SavedModel does not have a MetaGraphDef with the specified tags.")

    # Validate every matching meta graph before mutating any of them, so a
    # missing key cannot leave the SavedModel partially updated.
    for meta_graph_def in matching_meta_graphs:
      if signature_key not in meta_graph_def.signature_def:
        raise ValueError(
            f"MetaGraphDef associated with tags {tags} "
            f"does not have a signature_def with key: '{signature_key}'. "
            "This means either you specified the wrong signature key or "
            "forgot to put the signature_def with the corresponding key in "
            "your SavedModel.")

    for meta_graph_def in matching_meta_graphs:
      meta_graph_def.signature_def[signature_key].method_name = method_name

  def save(self, new_export_dir=None):
    """Saves the updated `SavedModel`.

    The output keeps the input's serialization format: `saved_model.pbtxt`
    when the model was read from a text proto, `saved_model.pb` otherwise.

    Args:
      new_export_dir: Path where the updated `SavedModel` will be saved. If
        None, the input `SavedModel` will be overwritten with the updates.

    Raises:
      errors.OpError: If there are errors during the file save operation.
    """
    # Normalize os.PathLike directories (accepted by `loader.parse_saved_model`
    # in `__init__`) to str/bytes before handing them to `compat.as_bytes`.
    export_dir = compat.as_bytes(compat.path_to_str(self._export_dir))
    input_pbtxt_path = file_io.join(
        export_dir, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
    input_pb_path = file_io.join(
        export_dir, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
    # NOTE: `loader.parse_saved_model` checks for saved_model.pb before
    # saved_model.pbtxt, so when both exist the proto was parsed from the .pb.
    # Writing only the .pbtxt in that case would leave the .pb stale and the
    # update would be ignored on the next load.
    is_input_text_proto = (
        file_io.file_exists(input_pbtxt_path) and
        not file_io.file_exists(input_pb_path))

    if not new_export_dir:
      new_export_dir = self._export_dir
    output_dir = compat.as_bytes(compat.path_to_str(new_export_dir))

    if is_input_text_proto:
      # TODO(jdchung): Add a util for the path creation below.
      path = file_io.join(
          output_dir, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = file_io.join(
          output_dir, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      file_io.write_string_to_file(
          path, self._saved_model.SerializeToString(deterministic=True))
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))