Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/function_deserialization.py: 16%

301 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tools for deserializing `Function`s.""" 

16 

17import collections 

18import pprint 

19import re 

20 

21from absl import logging 

22 

23from tensorflow.core.protobuf import saved_object_graph_pb2 

24from tensorflow.python.eager import def_function 

25from tensorflow.python.eager import function as function_lib 

26from tensorflow.python.eager.polymorphic_function import function_spec as function_spec_lib 

27from tensorflow.python.framework import func_graph as func_graph_lib 

28from tensorflow.python.framework import function_def_to_graph as function_def_lib 

29from tensorflow.python.framework import op_def_registry 

30from tensorflow.python.framework import ops 

31from tensorflow.python.framework import tensor_spec 

32from tensorflow.python.framework import type_spec 

33from tensorflow.python.ops import array_ops 

34from tensorflow.python.ops import custom_gradient 

35from tensorflow.python.ops import default_gradient 

36from tensorflow.python.ops import resource_variable_ops 

37from tensorflow.python.saved_model import nested_structure_coder 

38from tensorflow.python.util import compat 

39from tensorflow.python.util import nest 

40from tensorflow.python.util import tf_decorator 

41from tensorflow.python.util import tf_inspect 

42 

43 

def _is_tensor(t):
  """Returns whether `t` is an `ops.Tensor` or a `BaseResourceVariable`."""
  tensor_like_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
  return isinstance(t, tensor_like_types)

46 

47 

48# TODO(b/205016027): Update this to just use ConcreteFunction.__call__ with the 

49# structured signature. 

def _call_concrete_function(function, inputs):
  """Invokes a restored ConcreteFunction on structured `inputs`.

  Unlike `function.__call__`, inputs and outputs are structured, and every
  input whose expected spec is a `TensorSpec` is converted to a tensor.

  Note: non-tensor inputs are not checked against the signature here. That
  check should be done beforehand via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output, or None when the flat call returns an
    `ops.Operation`.
  """
  signature = function.graph.structured_input_signature
  flat_args = nest.flatten_up_to(signature, inputs, expand_composites=True)
  flat_specs = nest.flatten(signature, expand_composites=True)

  flat_tensors = []
  for value, spec in zip(flat_args, flat_specs):
    if isinstance(spec, tensor_spec.TensorSpec):
      value = ops.convert_to_tensor(value, dtype_hint=spec.dtype)
    elif not isinstance(spec, resource_variable_ops.VariableSpec):
      # Only tensor and variable entries are forwarded to the flat call.
      continue
    flat_tensors.append(value)

  outputs = function._call_flat(flat_tensors, function.captured_inputs)  # pylint: disable=protected-access
  return None if isinstance(outputs, ops.Operation) else outputs

82 

83 

def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns the TensorSpec of `arg` converted to a tensor, or None.

  None is returned when the conversion raises `TypeError` or `ValueError`.
  """
  try:
    # The conversion is attempted inside a throwaway FuncGraph so that it does
    # not pollute the current context.
    scratch_graph = func_graph_lib.FuncGraph(name="guess_conversion")
    with scratch_graph.as_default():
      converted = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(
          shape=converted.shape, dtype=converted.dtype)
  except (TypeError, ValueError):
    return None

93 

94 

def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Checks whether the concrete `function` can be called with `inputs`.

  Args:
    function: ConcreteFunction whose
      `function.graph.structured_input_signature` is matched against.
    inputs: Structured inputs to test.
    allow_conversion: If true, arguments expected to be tensors are first
      passed through `_try_convert_to_tensor_spec` before being compared.

  Returns:
    True if every flattened input is compatible with the corresponding entry
    of the signature, False otherwise.
  """
  signature = function.graph.structured_input_signature
  try:
    flat_args = nest.flatten_up_to(signature, inputs)
  except (TypeError, ValueError):
    # `inputs` cannot be flattened against the signature's structure.
    return False

  flat_specs = nest.flatten(signature)
  for value, spec in zip(flat_args, flat_specs):
    if isinstance(spec, tensor_spec.TensorSpec):
      if allow_conversion:
        value = _try_convert_to_tensor_spec(value, dtype_hint=spec.dtype)
      if not (_is_tensor(value) or isinstance(value, tensor_spec.TensorSpec)):
        return False
      if value.dtype != spec.dtype:
        return False
      if not spec.shape.is_compatible_with(value.shape):
        return False
    elif isinstance(spec, type_spec.TypeSpec):
      if not spec.is_compatible_with(value):
        return False
    elif _is_tensor(value):
      # Tensors found outside of a spec position must be the very same object.
      if value is not spec:
        return False
    elif value != spec:
      return False
  return True

123 

124 

def _deserialize_function_spec_as_nonmethod(function_spec_proto):
  """Builds a non-method FunctionSpec from its proto representation.

  If the proto is flagged as a method, or its first argument is named "self",
  that first argument is dropped from the resulting argspec.

  Args:
    function_spec_proto: proto providing `fullargspec`, `input_signature`,
      `is_method` and `jit_compile`.

  Returns:
    The result of `FunctionSpec.from_fullargspec_and_signature`.

  Raises:
    NotImplementedError: if the proto is flagged as a method but declares no
      positional arguments.
  """
  decoded_argspec = nested_structure_coder.decode_proto(
      function_spec_proto.fullargspec)
  arg_names = decoded_argspec.args

  # Convert a method function into a non method.
  treat_as_method = function_spec_proto.is_method or (
      arg_names and arg_names[0] == "self")
  if treat_as_method:
    if not arg_names:
      raise NotImplementedError(
          "Cannot deserialize a method function without a named "
          "'self' argument.")
    arg_names = arg_names[1:]

  fullargspec = tf_inspect.FullArgSpec(
      args=arg_names,
      varargs=decoded_argspec.varargs,
      varkw=decoded_argspec.varkw,
      defaults=decoded_argspec.defaults,
      kwonlyargs=decoded_argspec.kwonlyargs,
      kwonlydefaults=decoded_argspec.kwonlydefaults,
      annotations=decoded_argspec.annotations)
  input_signature = nested_structure_coder.decode_proto(
      function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details. Values missing
  # from the table map to None.
  jit_compile_enum = saved_object_graph_pb2.FunctionSpec.JitCompile
  jit_compile_by_enum = {
      jit_compile_enum.DEFAULT: None,
      jit_compile_enum.ON: True,
      jit_compile_enum.OFF: False,
  }
  jit_compile = jit_compile_by_enum.get(function_spec_proto.jit_compile)

  return function_spec_lib.FunctionSpec.from_fullargspec_and_signature(
      fullargspec=fullargspec,
      input_signature=input_signature,
      jit_compile=jit_compile)

164 

165 

166# TODO(b/205016761): The fact that we can't derive ConcreteFunction calling 

167# conventions from the serialized input spec right now is unfortunate. Merging 

168# these would be good, maybe by adding TensorSpec names to cache keys so renamed 

169# keyword arguments would yield different ConcreteFunctions. 

def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable.

  Args:
    saved_bare_concrete_function: proto naming the concrete function and
      carrying its argument keywords, allowed positional argument count and,
      optionally, a function spec.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    The configured `ConcreteFunction` looked up in `concrete_functions`.
  """
  proto = saved_bare_concrete_function
  restored = concrete_functions[proto.concrete_function_name]
  # pylint: disable=protected-access
  restored._arg_keywords = proto.argument_keywords
  restored._num_positional_args = proto.allowed_positional_arguments
  if proto.HasField("function_spec"):
    restored._set_function_spec(
        _deserialize_function_spec_as_nonmethod(proto.function_spec))
  # pylint: enable=protected-access
  restored.add_to_graph()
  return restored

187 

188 

class RestoredFunction(def_function.Function):
  """A `def_function.Function` rebuilt from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(b/205016819): We may enable autograph once exceptions are supported.
    super().__init__(
        python_function,
        name,
        autograph=False,
        jit_compile=function_spec.jit_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

    # Keeps RestoredFunction from spamming users with frequent tracing
    # warnings.
    self._omit_frequent_tracing_warning = True

  @property
  def _run_functions_eagerly(self):
    # The original python function is not available, so the only meaningful
    # thing to do is to call the concrete function graphs under the hood.
    #
    # Calling the bespoke python function (i.e. `restored_function_body`)
    # works as long as the user passes in all required and optional arguments,
    # but breaks when an optional argument is missing. The eager call path is
    # therefore skipped altogether, even if the user has enabled eager
    # function execution via `tf.config.run_functions_eagerly`.
    return False

  def _list_all_concrete_functions(self):
    """Returns the concrete functions this object was constructed with."""
    return self.concrete_functions

  def _list_all_concrete_functions_for_serialization(self):
    """Returns the concrete functions this object was constructed with."""
    return self.concrete_functions

  def _compiler_with_scope(self, scope):
    """Returns the parent's compiler for `scope` carrying our function spec."""
    compiler = super()._compiler_with_scope(scope)
    compiler._function_spec = self._function_spec  # pylint: disable=protected-access
    return compiler

233 

234 

def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`. As a side
      effect of this function, the `FunctionSpec` from `saved_function` is added
      to each `ConcreteFunction` in this map that `saved_function` references.

  Returns:
    A `Function`.
  """
  # TODO(b/205017389): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, since restored functions do
  # not behave as methods (i.e. they always use the same captured tensors
  # independent of the object they are bound to), there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec)

  def _describe_positional(positional):
    # One bullet per positional argument, pretty-printed.
    bullets = "\n * ".join(pprint.pformat(item) for item in positional)
    return "Positional arguments ({} total):\n * {}".format(
        len(positional), bullets)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    candidate_names = saved_function.concrete_functions
    if not candidate_names:
      raise ValueError("Found zero restored functions for caller function.")
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # The first pass looks for a concrete function callable without input
    # conversions, so that a more specific trace wins over a more expensive
    # one that supports tensors. The second pass allows conversions.
    for allow_conversion in (False, True):
      for candidate_name in candidate_names:
        candidate = concrete_functions[candidate_name]
        if any(captured is None for captured in candidate.captured_inputs):
          raise ValueError("Looks like you are trying to run a loaded "
                           "non-Keras model that was trained using "
                           "tf.distribute.experimental.ParameterServerStrategy "
                           "with variable partitioning, which is not currently "
                           "supported. Try using Keras to define your model "
                           "if possible.")
        if _concrete_function_callable_with(candidate, inputs,
                                            allow_conversion):
          return _call_concrete_function(candidate, inputs)

    # Nothing matched: describe every available signature in the error.
    signature_descriptions = []
    for option_number, candidate_name in enumerate(candidate_names, start=1):
      positional, keyword = (
          concrete_functions[candidate_name].structured_input_signature)
      signature_descriptions.append(
          "Option {}:\n {}\n Keyword arguments: {}".format(
              option_number, _describe_positional(positional), keyword))
    options_text = "\n\n".join(signature_descriptions)
    raise ValueError(
        "Could not find matching concrete function to call loaded from the "
        f"SavedModel. Got:\n {_describe_positional(args)}\n Keyword "
        f"arguments: {kwargs}\n\n Expected these arguments to match one of the "
        f"following {len(candidate_names)} option(s):\n\n"
        f"{options_text}")

  concrete_function_objects = [
      concrete_functions[name] for name in saved_function.concrete_functions
  ]
  for concrete_function in concrete_function_objects:
    concrete_function._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(restored_function_body,
                                       restored_function_body.__name__,
                                       function_spec, concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)

323 

324 

def load_function_def_library(library,
                              saved_object_graph=None,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    saved_object_graph: SavedObjectGraph proto message. If not passed in,
      concrete function structured signatures and outputs will not be set.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.
    wrapper_function: An object that will be wrapped on newly created functions.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  library_function_names = {fdef.signature.name for fdef in library.function}
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(b/205023033): Make this Graph creation unnecessary when executing
  # eagerly by fixing function_def_to_graph_def.
  graph = (
      ops.Graph() if ops.executing_eagerly_outside_functions()
      else ops.get_default_graph())

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # Maps old op type to old function name
  new_gradient_op_types = {}  # Maps old gradient op type to new op type.
  gradients_to_register = {}  # Maps old function name to new op type
  for gdef in library.registered_gradients:
    if not gdef.registered_op_type:
      continue
    new_op_type = custom_gradient.generate_name()
    old_op_type = compat.as_bytes(gdef.registered_op_type)

    library_gradient_names[old_op_type] = gdef.gradient_func
    new_gradient_op_types[old_op_type] = new_op_type
    gradients_to_register[gdef.gradient_func] = new_op_type

  function_deps = {
      fdef.signature.name: _list_function_deps(
          fdef, library_function_names, library_gradient_names)
      for fdef in library.function
  }

  loaded_gradients = {}
  for fdef in _sort_function_defs(library, function_deps):
    orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix,
                                   new_gradient_op_types)

    # Setup function signatures and outputs
    #
    # When concrete functions are created normally (i.e. when they're originally
    # created and not loaded via saved model), the inputs and outputs are
    # calculated based on the values passed in by the user and returned from the
    # original function, respectively. We don't have access to those anymore at
    # restore time, so we must instead pass them to the FuncGraph explicitly.
    structured_input_signature = None
    structured_outputs = None
    has_saved_signature = (
        saved_object_graph is not None and
        orig_name in saved_object_graph.concrete_functions)
    if has_saved_signature:
      # TODO(b/204324043): Offload the deserialization of the protos to the
      # first class objects by passing the actual protos. This is blocked on
      # importing `nested_structure_coder` in function.py causing a circular
      # dependency.
      proto = saved_object_graph.concrete_functions[orig_name]
      structured_input_signature = nested_structure_coder.decode_proto(
          proto.canonicalized_input_signature)
      structured_outputs = nested_structure_coder.decode_proto(
          proto.output_signature)

    # There is no need to copy all functions into the function def graph. It
    # leads to a O(n^2) increase of memory when importing functions and the
    # extra function definitions are a no-op since they were already imported
    # as a function before and passed in explicitly (due to the topologic sort
    # import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(
          fdef,
          structured_input_signature=structured_input_signature,
          structured_outputs=structured_outputs)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients)
    _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients)

    for dep in function_deps[orig_name]:
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). A ConcreteFunction that is part of a saved
    # function is set up later by recreate_function(); a bare ConcreteFunction
    # is set up by setup_bare_concrete_function().
    # However, we copy the FunctionDef attributes to the new ConcreteFunction,
    # excluding the "_input_shapes", which may cause an error during input shape
    # initialization at a later stage.
    if "_input_shapes" in fdef.attr:
      del fdef.attr["_input_shapes"]
    func = function_lib.ConcreteFunction(func_graph, attrs=fdef.attr)
    if wrapper_function:
      func = wrapper_function(func)
    func.add_to_graph(graph)

    functions[orig_name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      func.add_to_graph(ops.get_default_graph())

    if orig_name in gradients_to_register:
      gradient_op_type = gradients_to_register[orig_name]
      loaded_gradients[compat.as_bytes(gradient_op_type)] = func
      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

  return functions

462 

463 

def _gen_gradient_func(func):
  """Wraps a deserialized function so it can be used as a gradient function.

  Args:
    func: the deserialized function to wrap; its `graph.inputs` are paired
      with the incoming gradients.

  Returns:
    A callable `gradient_func(unused_op, *result_grads)` that substitutes
    zeros for `None` gradients and then calls `func`.
  """

  def gradient_func(unused_op, *result_grads):
    # All `None` arguments are replaced because the traced custom gradient
    # function expects tensors. Zeros are correct since `None` values occur
    # when the gradient is unconnected, and thus the gradient is "statically
    # proven to be zero." See `tf.UnconnectedGradients` for details.

    def none_to_zero(grad, graph_input):
      if grad is not None:
        return grad

      shape, dtype = default_gradient.shape_and_dtype(graph_input)
      if shape.is_fully_defined():
        return default_gradient.zeros_like(graph_input)

      # Unknown dimensions are replaced by 1; an unknown rank yields `[]`.
      if shape.rank is None:
        dims = []
      else:
        dims = [1 if dim is None else dim for dim in shape.as_list()]
      return array_ops.zeros(dims, dtype)

    dense_grads = [
        none_to_zero(grad, graph_input)
        for grad, graph_input in zip(result_grads, func.graph.inputs)
    ]
    return func(*dense_grads)

  return gradient_func

495 

496 

def _restore_gradient_functions(func_graph, renamed_functions,
                                loaded_gradients):
  """Populate function op's _gradient_function with default gradient.

  Args:
    func_graph: graph whose operations are visited.
    renamed_functions: map keyed by function name as bytes; the function
      called by a (Stateful)PartitionedCall op is looked up here.
    loaded_gradients: map from gradient op type to the loaded gradient
      function; matching entries get their positional-argument count and
      argument keywords set from the op's inputs.
  """
  function_call_op_types = ("StatefulPartitionedCall", "PartitionedCall")
  for op in func_graph.get_operations():
    # TODO(b/205024208): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type in function_call_op_types:
      called_name = compat.as_bytes(op.node_def.attr["f"].func.name)
      called_function = renamed_functions[called_name]
      op._gradient_function = called_function._get_gradient_function()  # pylint: disable=protected-access

    try:
      gradient_op_type = op.get_attr("_gradient_op_type")
    except ValueError:
      # The op has no "_gradient_op_type" attribute; nothing more to do.
      continue

    if gradient_op_type in loaded_gradients:
      grad_fn = loaded_gradients[gradient_op_type]
      # pylint: disable=protected-access
      grad_fn._num_positional_args = len(op.inputs)
      grad_fn._arg_keywords = [inp.name for inp in op.inputs]
      # pylint: enable=protected-access

517 

518 

519def _sort_function_defs(library, function_deps): 

520 """Return a topologic sort of FunctionDefs in a library.""" 

521 edges = collections.defaultdict(list) 

522 in_count = collections.defaultdict(lambda: 0) 

523 

524 for fname, deps in function_deps.items(): 

525 for dep in deps: 

526 edges[dep].append(fname) 

527 in_count[fname] += 1 

528 ready = [ 

529 fdef.signature.name 

530 for fdef in library.function 

531 if in_count[fdef.signature.name] == 0 

532 ] 

533 output = [] 

534 while ready: 

535 node = ready.pop() 

536 output.append(node) 

537 for dest in edges[node]: 

538 in_count[dest] -= 1 

539 if not in_count[dest]: 

540 ready.append(dest) 

541 

542 if len(output) != len(library.function): 

543 failed_to_resolve = sorted(set(in_count.keys()) - set(output)) 

544 raise ValueError("There is a cyclic dependency between functions. ", 

545 f"Could not resolve {failed_to_resolve}.") 

546 

547 reverse = {fdef.signature.name: fdef for fdef in library.function} 

548 return [reverse[x] for x in output] 

549 

550 

551def _get_gradient_op_type(node_def): 

552 """Returns the custom gradient op type.""" 

553 if ("_gradient_op_type" in node_def.attr and 

554 node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]): 

555 return node_def.attr["_gradient_op_type"].s 

556 return None 

557 

558 

def fix_node_def(node_def, functions, shared_name_suffix):
  """Replace function calls and shared names in `node_def`.

  Args:
    node_def: node to fix. It is mutated in-place.
    functions: map from original function name to an object whose `name` is
      the name to use instead.
    shared_name_suffix: suffix appended to the node's "shared_name" attribute
      when the node's op declares such an attribute.
  """
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    value_kind = attr_value.WhichOneof("value")
    if value_kind == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif value_kind == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2" and (
      "use_node_name_sharing" not in node_def.attr or
      not node_def.attr["use_node_name_sharing"].b):
    node_def.attr["use_node_name_sharing"].b = True
    # We are turning on node name sharing, so have to make sure we don't
    # accidentally share a table resource.
    shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if not op_def:
    return
  attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
  if not attr:
    return

  # Preference order: explicit attribute value, then the op's default value,
  # then the node's own name.
  shared_name = None
  if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
    shared_name = node_def.attr["shared_name"].s
  elif attr.default_value.s:
    shared_name = compat.as_bytes(attr.default_value.s)
  if not shared_name:
    shared_name = compat.as_bytes(node_def.name)

  node_def.attr["shared_name"].s = (
      shared_name + compat.as_bytes(shared_name_suffix))

597 

598 

def _fix_fdef_in_place(fdef, functions, shared_name_suffix,
                       new_gradient_op_types):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    fdef: FunctionDef proto to fix. It is mutated in-place.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.
    new_gradient_op_types: map from old gradient op type to newly generated op
      type.

  Returns:
    orig_name: original value of fdef.signature.name
  """
  orig_name = fdef.signature.name
  contains_unsaved_custom_gradients = False

  for node_def in fdef.node_def:
    fix_node_def(node_def, functions, shared_name_suffix)
    op_type = _get_gradient_op_type(node_def)
    if op_type is None:
      continue
    if op_type not in new_gradient_op_types:
      # The custom gradient was not saved with the library.
      contains_unsaved_custom_gradients = True
      continue
    node_def.attr["_gradient_op_type"].s = compat.as_bytes(
        new_gradient_op_types[op_type])

  if contains_unsaved_custom_gradients:
    logging.warning(
        "Importing a function (%s) with ops with unsaved custom gradients. Will"
        " likely fail if a gradient is requested.", fdef.signature.name)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return orig_name

638 

639 

def _list_function_deps(fdef, library_function_names, library_gradient_names):
  """Find functions referenced in `fdef`.

  Args:
    fdef: FunctionDef whose nodes are inspected.
    library_function_names: collection of the function names in the library.
    library_gradient_names: map from gradient op type to the name of the
      function implementing that gradient.

  Returns:
    A set with the names of the functions `fdef` references.
  """
  # TODO(b/205023953): Recurse into list attributes and into NameAttrList attrs
  # both when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  deps = set()
  for node_def in fdef.node_def:
    grad_op_type = _get_gradient_op_type(node_def)
    if node_def.op in library_function_names:
      deps.add(node_def.op)
      continue
    if grad_op_type and grad_op_type in library_gradient_names:
      deps.add(library_gradient_names[grad_op_type])
      continue
    # Otherwise collect the functions referenced through the node's attrs.
    for _, attr_value in node_def.attr.items():
      value_kind = attr_value.WhichOneof("value")
      if value_kind == "func":
        deps.add(attr_value.func.name)
      elif value_kind == "list":
        deps.update(fn.name for fn in attr_value.list.func)

  return deps

661 

662 

# Matches "<inference prefix><orig>_<digits>" and captures "<orig>".
# pylint: disable=protected-access
_FUNCTION_WRAPPER_NAME_REGEX = (
    r"^%s(.*)_\d+$" % function_lib._INFERENCE_PREFIX)
# pylint: enable=protected-access

665 

666 

def _clean_function_name(name):
  """Vanity function to keep the function names comprehensible.

  Returns the part of `name` captured by `_FUNCTION_WRAPPER_NAME_REGEX`, or
  `name` unchanged when the pattern does not match.
  """
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  return match.group(1) if match else name