Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py: 23%

160 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helpers for working with signatures in tf.saved_model.save.""" 

16 

17from absl import logging 

18 

19from tensorflow.python.eager import def_function 

20from tensorflow.python.eager import function as defun 

21from tensorflow.python.framework import composite_tensor 

22from tensorflow.python.framework import ops 

23from tensorflow.python.framework import tensor_spec 

24from tensorflow.python.ops import resource_variable_ops 

25from tensorflow.python.saved_model import function_serialization 

26from tensorflow.python.saved_model import revived_types 

27from tensorflow.python.saved_model import signature_constants 

28from tensorflow.python.trackable import base 

29from tensorflow.python.types import core 

30from tensorflow.python.util import compat 

31from tensorflow.python.util import nest 

32from tensorflow.python.util.compat import collections_abc 

33 

34 

# Name of the root child that `find_function_to_export` returns directly
# (skipping signature validation) when that child is a function.
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
# Name of the root child that `validate_augmented_graph_view` requires to be a
# `_SignatureMap`.
SIGNATURE_ATTRIBUTE_NAME = "signatures"
# Max number of signatures for which `canonicalize_signatures` logs that input
# names were normalized.
_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5

39 

40 

def _get_signature(function):
  """Returns the concrete function for `function`, or None if there is none.

  A `defun.Function` / `def_function.Function` whose `input_signature` is set
  is first converted with `_get_concrete_function_garbage_collected()`.
  Whatever is left at that point is returned only if it is a
  `defun.ConcreteFunction`.

  Args:
    function: The object to convert.

  Returns:
    A `defun.ConcreteFunction`, or None.
  """
  polymorphic_types = (defun.Function, def_function.Function)
  if isinstance(function, polymorphic_types):
    if function.input_signature is not None:
      function = function._get_concrete_function_garbage_collected()  # pylint: disable=protected-access
  return function if isinstance(function, defun.ConcreteFunction) else None

48 

49 

def _valid_signature(concrete_function):
  """Returns whether `concrete_function` can be converted to a signature.

  Args:
    concrete_function: The concrete function to check.

  Returns:
    True if the function has outputs and both `_validate_inputs` and
    `_normalize_outputs` accept it without raising ValueError; False otherwise.
  """
  if concrete_function.outputs:
    try:
      _validate_inputs(concrete_function)
      # Only the validation side of `_normalize_outputs` matters here; the
      # name arguments are used solely for its error messages.
      _normalize_outputs(
          concrete_function.structured_outputs, "unused", "unused")
    except ValueError:
      return False
    return True
  # Functions without outputs don't make sense as signatures. We just don't
  # have any way to run an Operation with no outputs as a SignatureDef in the
  # 1.x style.
  return False

63 

64 

def _validate_inputs(concrete_function):
  """Raises an error if any input of `concrete_function` is a tf.Variable.

  Args:
    concrete_function: The concrete function whose flattened
      `structured_input_signature` is inspected.

  Raises:
    ValueError: If any flattened input spec is a
      `resource_variable_ops.VariableSpec`.
  """
  if any(isinstance(inp, resource_variable_ops.VariableSpec)
         for inp in nest.flatten(
             concrete_function.structured_input_signature)):
    # The trailing space after the closing quote matters: adjacent string
    # literals are concatenated verbatim, with nothing inserted between them.
    raise ValueError(
        f"Unable to serialize concrete_function '{concrete_function.name}' "
        "with tf.Variable input. Functions that expect tf.Variable "
        "inputs cannot be exported as signatures.")

74 

75 

def _get_signature_name_changes(concrete_function):
  """Finds user-specified input names that differ from the signature's names.

  Args:
    concrete_function: The concrete function whose `function_def` signature
      arguments are compared, pairwise, with its graph inputs.

  Returns:
    A dict mapping each `_user_specified_name` attribute value to the
    corresponding signature argument name, for the pairs where the two differ.
  """
  signature_args = concrete_function.function_def.signature.input_arg
  graph_inputs = concrete_function.graph.inputs
  # {user-given name: normalized name}, only for names that are not identical.
  renamed = {}
  for signature_arg, graph_tensor in zip(signature_args, graph_inputs):
    try:
      given_name = compat.as_str(
          graph_tensor.op.get_attr("_user_specified_name"))
      normalized_name = signature_arg.name
      if normalized_name != given_name:
        renamed[given_name] = normalized_name
    except ValueError:
      # Signature input does not have a user-specified name.
      continue
  return renamed

92 

93 

def find_function_to_export(saveable_view):
  """Picks the function to export when the user gave no signatures.

  Args:
    saveable_view: Object providing `root` and `list_children(root)`, which
      yields `(name, child)` pairs.

  Returns:
    The child named `DEFAULT_SIGNATURE_ATTR` if it is a function; otherwise the
    single child that converts to a valid signature; otherwise None (also when
    more than one child qualifies).
  """
  # TODO(b/205014194): Discuss removing this behaviour. It can lead to WTFs when
  # a user decides to annotate more functions with tf.function and suddenly
  # serving that model way later in the process stops working.
  function_types = (def_function.Function, defun.ConcreteFunction)
  candidates = []
  for child_name, child in saveable_view.list_children(saveable_view.root):
    if isinstance(child, function_types):
      if child_name == DEFAULT_SIGNATURE_ATTR:
        # An explicitly attached default signature wins immediately.
        return child
      concrete = _get_signature(child)
      if concrete is None:
        continue
      if _valid_signature(concrete):
        candidates.append(concrete)

  # Only an unambiguous choice is exported.
  if len(candidates) != 1:
    return None
  signature = _get_signature(candidates[0])
  if signature and _valid_signature(signature):
    return signature
  return None

119 

120 

def canonicalize_signatures(signatures):
  """Converts `signatures` into dictionaries of signature functions.

  Args:
    signatures: None, a single function, or a mapping from signature key to
      function. A non-mapping value is stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A tuple `(concrete_signatures, wrapped_functions, defaults)` of dicts:
      concrete_signatures: signature key -> concrete function traced from a
        wrapper that passes the outputs through `_normalize_outputs`.
      wrapped_functions: concrete function returned by `_get_signature` -> the
        result of `function_serialization.wrap_cached_variables` for it.
      defaults: `(signature key, argument name)` -> default value, collected
        from the fullargspec of inputs that are `core.GenericFunction`s.
    All three are empty when `signatures` is None.

  Raises:
    ValueError: If `_get_signature` returns None for an entry, or if
      `_validate_inputs` rejects the function.
  """
  if signatures is None:
    return {}, {}, {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  logged_rename_count = 0
  concrete_signatures = {}
  wrapped_functions = {}
  defaults = {}
  for signature_key, function in signatures.items():
    original_function = _get_signature(function)
    if original_function is None:
      raise ValueError(
          "Expected a TensorFlow function for which to generate a signature, "
          f"but got {function}. Only `tf.functions` with an input signature or "
          "concrete functions can be used as a signature.")

    # Reuse the wrapped version if this concrete function was already seen
    # under another signature key.
    signature_function = wrapped_functions.get(original_function)
    if not signature_function:
      signature_function = function_serialization.wrap_cached_variables(
          original_function)
    wrapped_functions[original_function] = signature_function

    _validate_inputs(signature_function)
    if logged_rename_count < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
      renamed_inputs = _get_signature_name_changes(signature_function)
      if renamed_inputs:
        logged_rename_count += 1
        logging.info(
            "Function `%s` contains input name(s) %s with unsupported "
            "characters which will be renamed to %s in the SavedModel.",
            compat.as_str(signature_function.graph.name),
            ", ".join(renamed_inputs),
            ", ".join(renamed_inputs.values()))

    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures. The closure is traced below,
    # within this same iteration, so it sees this iteration's loop variables.
    # pylint: disable=cell-var-from-loop
    def signature_wrapper(**kwargs):
      structured_outputs = signature_function(**kwargs)
      return _normalize_outputs(
          structured_outputs, signature_function.name, signature_key)
    if hasattr(function, "__name__"):
      signature_wrapper.__name__ = "signature_wrapper_" + function.__name__
    wrapped_function = def_function.function(signature_wrapper)

    structured_signature = signature_function.structured_input_signature
    if structured_signature is None:
      # Structured input signature isn't always defined for some functions.
      inputs = signature_function.inputs
    else:
      # The structured input signature may contain other non-tensor arguments.
      inputs = (
          flat_input
          for flat_input in nest.flatten(
              structured_signature, expand_composites=True)
          if isinstance(flat_input, tensor_spec.TensorSpec))

    # One TensorSpec per keyword argument, named after the keyword.
    tensor_spec_signature = {}
    for raw_keyword, flat_input in zip(
        signature_function._arg_keywords,  # pylint: disable=protected-access
        inputs):
      keyword = compat.as_str(raw_keyword)
      if isinstance(flat_input, tensor_spec.TensorSpec):
        tensor_spec_signature[keyword] = tensor_spec.TensorSpec(
            flat_input.shape, flat_input.dtype, name=keyword)
      else:
        tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
            flat_input, name=keyword)
    final_concrete = wrapped_function._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
        **tensor_spec_signature)

    # pylint: disable=protected-access
    # If there is only one input to the signature, a very common case, then
    # ordering is unambiguous and we can let people pass a positional
    # argument. Since SignatureDefs are unordered (protobuf "map") multiple
    # arguments means we need to be keyword-only.
    final_concrete._num_positional_args = (
        1 if len(final_concrete._arg_keywords) == 1 else 0)
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
    # pylint: enable=cell-var-from-loop

    if isinstance(function, core.GenericFunction):
      full_arg_spec = function._function_spec.fullargspec  # pylint: disable=protected-access
      arg_defaults = full_arg_spec.defaults or []
      # Defaults line up with the trailing positional arguments. With no
      # defaults the slice covers every argument, but zip then yields nothing.
      trailing_args = full_arg_spec.args[-len(arg_defaults):]
      for arg, default in zip(trailing_args, arg_defaults):
        # NOTE(review): any falsy default (None, 0, False, "") is skipped here,
        # not just None -- confirm this is intended.
        if default:
          defaults[(signature_key, arg)] = default
  return concrete_signatures, wrapped_functions, defaults

208 

209 

def _normalize_outputs(outputs, function_name, signature_key):
  """Converts `outputs` to a mapping and checks its keys and values.

  Args:
    outputs: A mapping, a namedtuple-like object (anything with `_asdict`), a
      sequence, or a single value.
    function_name: Name of the function that produced `outputs`; used only in
      error messages.
    signature_key: Key of the signature being generated; used only in error
      messages.

  Returns:
    `outputs` itself if it is already a mapping; otherwise a dict built from
    `_asdict()`, or one keyed "output_0", "output_1", ... in sequence order.

  Raises:
    ValueError: If a key is not a string type, or a value is neither an
      `ops.Tensor` nor a `composite_tensor.CompositeTensor`.
  """
  if isinstance(outputs, collections_abc.Mapping):
    normalized = outputs
  elif hasattr(outputs, "_asdict"):
    # Namedtuples keep their field names as keys.
    normalized = outputs._asdict()
  else:
    if isinstance(outputs, collections_abc.Sequence):
      sequence = outputs
    else:
      sequence = [outputs]
    normalized = {
        f"output_{position}": element
        for position, element in enumerate(sequence)}

  # Check that the keys are strings and the values are Tensors.
  for key, value in normalized.items():
    if not isinstance(key, compat.bytes_or_text_types):
      raise ValueError(
          f"Got a dictionary with a non-string key {key!r} in the output of "
          f"the function {compat.as_str_any(function_name)} used to generate "
          f"the SavedModel signature {signature_key!r}.")
    if not isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
      raise ValueError(
          f"Got a non-Tensor value {value!r} for key {key!r} in the output of "
          f"the function {compat.as_str_any(function_name)} used to generate "
          f"the SavedModel signature {signature_key!r}. "
          "Outputs for functions used as signatures must be a single Tensor, "
          "a sequence of Tensors, or a dictionary from string to Tensor.")
  return normalized

238 

239 

240# _SignatureMap is immutable to ensure that users do not expect changes to be 

241# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the 

242# only way to create a _SignatureMap and there is no way to modify it. So we can 

243# safely ignore/overwrite ".signatures" attributes attached to objects being 

244# saved if they contain a _SignatureMap. A ".signatures" attribute containing 

245# any other type (e.g. a regular dict) will raise an exception asking the user 

246# to first "del obj.signatures" if they want it overwritten. 

class _SignatureMap(collections_abc.Mapping, base.Trackable):
  """A read-only mapping from signature name to signature function."""

  def __init__(self):
    # Backing store; only `_add_signature` writes to it.
    self._signatures = {}

  def _add_signature(self, name, concrete_function):
    """Stores `concrete_function` under `name`, replacing any existing entry."""
    # Ideally this object would be immutable, but restore is streaming so we do
    # need a private API for adding new signatures to an existing object.
    self._signatures[name] = concrete_function

  def __getitem__(self, key):
    return self._signatures[key]

  def __iter__(self):
    return iter(self._signatures)

  def __len__(self):
    return len(self._signatures)

  def __repr__(self):
    return f"_SignatureMap({self._signatures})"

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the function-valued signatures, but only for SavedModel saves.

    Args:
      save_type: A `base.SaveType` value; anything other than `SAVEDMODEL`
        yields an empty dict.
      **kwargs: Accepted and ignored.

    Returns:
      A dict of the entries whose value is a `def_function.Function` or a
      `defun.ConcreteFunction`.
    """
    if save_type != base.SaveType.SAVEDMODEL:
      return {}

    function_types = (def_function.Function, defun.ConcreteFunction)
    children = {}
    for key, value in self.items():
      if isinstance(value, function_types):
        children[key] = value
    return children

279 

280 

# Register `_SignatureMap` with `revived_types` under the identifier
# "signature_map". The predicate selects `_SignatureMap` instances, the factory
# creates an empty map, and `_add_signature` is supplied as the setter.
# NOTE(review): presumably the setter is how entries get re-attached on load;
# confirm against `revived_types`.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _SignatureMap(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_SignatureMap._add_signature  # pylint: disable=protected-access
    )])

293 

294 

def create_signature_map(signatures):
  """Creates a `_SignatureMap` containing `signatures`.

  Args:
    signatures: Mapping from signature name to concrete function.

  Returns:
    A new `_SignatureMap` holding every `(name, function)` pair.

  Raises:
    AssertionError: If an entry is not a `defun.ConcreteFunction`, its
      `structured_outputs` is not a mapping, or its `_num_positional_args`
      does not match the expected value (1 for a single keyword, else 0).
  """
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # These hold for any signature that came from canonicalize_signatures.
    # They are checked here, at save time, because crashing on load (e.g. in
    # _add_signature) would be more problematic if future export changes
    # violated them.
    assert isinstance(func, defun.ConcreteFunction)
    assert isinstance(func.structured_outputs, collections_abc.Mapping)
    # pylint: disable=protected-access
    expected_positional_args = 1 if len(func._arg_keywords) == 1 else 0
    assert func._num_positional_args == expected_positional_args
    signature_map._add_signature(name, func)
    # pylint: enable=protected-access
  return signature_map

313 

314 

def validate_augmented_graph_view(augmented_graph_view):
  """Performs signature-related sanity checks on `augmented_graph_view`.

  Args:
    augmented_graph_view: Object providing `root` and `list_children(root)`,
      which yields `(name, child)` pairs.

  Raises:
    ValueError: If the first child of the root named
      `SIGNATURE_ATTRIBUTE_NAME` is not a `_SignatureMap`.
  """
  for child_name, child in augmented_graph_view.list_children(
      augmented_graph_view.root):
    if child_name != SIGNATURE_ATTRIBUTE_NAME:
      continue
    if isinstance(child, _SignatureMap):
      # Only the first matching child is checked.
      return
    raise ValueError(
        f"Exporting an object {augmented_graph_view.root} which has an attribute "
        f"named '{SIGNATURE_ATTRIBUTE_NAME}'. This is a reserved attribute "
        "used to store SavedModel signatures in objects which come from "
        "`tf.saved_model.load`. Delete this attribute "
        f"(e.g. `del obj.{SIGNATURE_ATTRIBUTE_NAME}`) before saving if "
        "this shadowing is acceptable.")