Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/function_deserialization.py: 16% of 301 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing `Function`s."""

import collections
import pprint
import re

from absl import logging

from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager.polymorphic_function import function_spec as function_spec_lib
from tensorflow.python.framework import func_graph as func_graph_lib
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_tensor(t):
  return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))


# TODO(b/205016027): Update this to just use ConcreteFunction.__call__ with the
# structured signature.
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  This differs from `function.__call__` in that inputs and outputs are
  structured and that it casts inputs to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be
  done before via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output.
  """
  expected_structure = function.graph.structured_input_signature
  flatten_inputs = nest.flatten_up_to(
      expected_structure, inputs, expand_composites=True)
  flatten_expected = nest.flatten(expected_structure, expand_composites=True)
  tensor_inputs = []
  for arg, expected in zip(flatten_inputs, flatten_expected):
    if isinstance(expected, tensor_spec.TensorSpec):
      tensor_inputs.append(
          ops.convert_to_tensor(arg, dtype_hint=expected.dtype))
    elif isinstance(expected, resource_variable_ops.VariableSpec):
      tensor_inputs.append(arg)
  result = function._call_flat(tensor_inputs, function.captured_inputs)  # pylint: disable=protected-access
  if isinstance(result, ops.Operation):
    return None
  return result


def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns None or TensorSpec obtained if `arg` is converted to tensor."""
  try:
    # Note: try conversion in a FuncGraph to avoid polluting current context.
    with func_graph_lib.FuncGraph(name="guess_conversion").as_default():
      result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype)
  except (TypeError, ValueError):
    return None


def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      if not expected.is_compatible_with(arg):
        return False
    elif _is_tensor(arg):
      if id(arg) != id(expected):
        return False
    else:
      if arg != expected:
        return False
  return True
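
# Illustrative sketch (not executed): how the two helpers above are paired when
# dispatching a call to one of several saved traces. `candidates` and
# `example_inputs` are hypothetical names; the real dispatch loop lives in
# `restored_function_body` inside `recreate_function` below.
#
#   example_inputs = ((ops.convert_to_tensor([1.0, 2.0]),), {})  # (args, kwargs)
#   for allow_conversion in (False, True):
#     for fn in candidates:  # candidate ConcreteFunctions
#       if _concrete_function_callable_with(fn, example_inputs, allow_conversion):
#         result = _call_concrete_function(fn, example_inputs)
#         break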


def _deserialize_function_spec_as_nonmethod(function_spec_proto):
  """Deserialize a FunctionSpec object from its proto representation."""
  typeless_fullargspec = nested_structure_coder.decode_proto(
      function_spec_proto.fullargspec)

  # Convert a method function into a non method.
  if function_spec_proto.is_method or (
      typeless_fullargspec.args and typeless_fullargspec.args[0] == "self"
  ):
    if not typeless_fullargspec.args:
      raise NotImplementedError(
          "Cannot deserialize a method function without a named "
          "'self' argument.")
    args = typeless_fullargspec.args[1:]
  else:
    args = typeless_fullargspec.args

  fullargspec = tf_inspect.FullArgSpec(
      args=args,
      varargs=typeless_fullargspec.varargs,
      varkw=typeless_fullargspec.varkw,
      defaults=typeless_fullargspec.defaults,
      kwonlyargs=typeless_fullargspec.kwonlyargs,
      kwonlydefaults=typeless_fullargspec.kwonlydefaults,
      annotations=typeless_fullargspec.annotations)
  input_signature = nested_structure_coder.decode_proto(
      function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details.
  jit_compile = {
      saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None,
      saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True,
      saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False,
  }.get(function_spec_proto.jit_compile)

  return function_spec_lib.FunctionSpec.from_fullargspec_and_signature(
      fullargspec=fullargspec,
      input_signature=input_signature,
      jit_compile=jit_compile)


# TODO(b/205016761): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so renamed
# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable."""
  concrete_function = concrete_functions[
      saved_bare_concrete_function.concrete_function_name]
  # pylint: disable=protected-access
  concrete_function._arg_keywords = (
      saved_bare_concrete_function.argument_keywords)
  concrete_function._num_positional_args = (
      saved_bare_concrete_function.allowed_positional_arguments)
  if saved_bare_concrete_function.HasField("function_spec"):
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_bare_concrete_function.function_spec)
    concrete_function._set_function_spec(function_spec)
  # pylint: enable=protected-access
  concrete_function.add_to_graph()
  return concrete_function


class RestoredFunction(def_function.Function):
  """Wrapper class for a function that has been restored from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(b/205016819): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function,
        name,
        autograph=False,
        jit_compile=function_spec.jit_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

    # Prevent RestoredFunction from spamming users with frequent tracing
    # warnings.
    self._omit_frequent_tracing_warning = True

  @property
  def _run_functions_eagerly(self):
    # We do not have access to the original python function, and thus, we
    # cannot meaningfully do anything but call our concrete function graphs
    # under the hood.
    #
    # Attempting to call our bespoke python function (i.e.
    # `restored_function_body`) will work so long as the user passes in all
    # required and optional arguments. If an optional argument is missing,
    # however, the call will break. For this reason, we instead skip the
    # eager call path altogether if a user has enabled eager function execution
    # via `tf.config.run_functions_eagerly`.
    return False

  def _list_all_concrete_functions(self):
    return self.concrete_functions

  def _list_all_concrete_functions_for_serialization(self):
    return self.concrete_functions

  def _compiler_with_scope(self, scope):
    func = super(RestoredFunction, self)._compiler_with_scope(scope)
    func._function_spec = self._function_spec  # pylint: disable=protected-access
    return func


def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`. As a side
      effect of this function, the `FunctionSpec` from `saved_function` is added
      to each `ConcreteFunction` in this map.

  Returns:
    A `Function`.
  """
  # TODO(b/205017389): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. The current approach nests functions deeper for each
  # serialization cycle.

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, since restored functions do
  # not behave as methods, i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    if not saved_function.concrete_functions:
      raise ValueError("Found zero restored functions for caller function.")
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in [False, True]:
      for function_name in saved_function.concrete_functions:
        function = concrete_functions[function_name]
        if any(inp is None for inp in function.captured_inputs):
          raise ValueError("Looks like you are trying to run a loaded "
                           "non-Keras model that was trained using "
                           "tf.distribute.experimental.ParameterServerStrategy "
                           "with variable partitioning, which is not currently "
                           "supported. Try using Keras to define your model "
                           "if possible.")
        if _concrete_function_callable_with(function, inputs, allow_conversion):
          return _call_concrete_function(function, inputs)

    signature_descriptions = []

    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n * {}".format(
          len(positional),
          "\n * ".join(pprint.pformat(a) for a in positional))

    for index, function_name in enumerate(saved_function.concrete_functions):
      concrete_function = concrete_functions[function_name]
      positional, keyword = concrete_function.structured_input_signature
      signature_descriptions.append(
          "Option {}:\n {}\n Keyword arguments: {}".format(
              index + 1, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching concrete function to call loaded from the "
        f"SavedModel. Got:\n {_pretty_format_positional(args)}\n Keyword "
        f"arguments: {kwargs}\n\n Expected these arguments to match one of the "
        f"following {len(saved_function.concrete_functions)} option(s):\n\n"
        f"{(chr(10)+chr(10)).join(signature_descriptions)}")

  concrete_function_objects = []
  for concrete_function_name in saved_function.concrete_functions:
    concrete_function_objects.append(concrete_functions[concrete_function_name])

  for cf in concrete_function_objects:
    cf._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(restored_function_body,
                                       restored_function_body.__name__,
                                       function_spec, concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)
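
# Illustrative sketch (not executed): a function rebuilt by `recreate_function`
# is what a user calls after `tf.saved_model.load`. The path and attribute name
# below are hypothetical.
#
#   import tensorflow as tf
#   loaded = tf.saved_model.load("/tmp/exported_model")
#   # Calling the restored attribute runs `restored_function_body`, which picks
#   # the first saved ConcreteFunction whose signature matches the inputs.
#   loaded.serve(tf.constant([[1.0, 2.0]]))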


def load_function_def_library(library,
                              saved_object_graph=None,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    saved_object_graph: SavedObjectGraph proto message. If not passed in,
      concrete function structured signatures and outputs will not be set.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.
    wrapper_function: A callable used to wrap each newly created function.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  library_function_names = set(fdef.signature.name for fdef in library.function)
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(b/205023033): Make this Graph creation unnecessary when executing
  # eagerly by fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # Maps old op type to old function name.
  new_gradient_op_types = {}  # Maps old gradient op type to new op type.
  gradients_to_register = {}  # Maps old function name to new op type.
  for gdef in library.registered_gradients:
    if gdef.registered_op_type:
      new_op_type = custom_gradient.generate_name()
      old_op_type = compat.as_bytes(gdef.registered_op_type)

      library_gradient_names[old_op_type] = gdef.gradient_func
      new_gradient_op_types[old_op_type] = new_op_type
      gradients_to_register[gdef.gradient_func] = new_op_type

  function_deps = {}
  for fdef in library.function:
    function_deps[fdef.signature.name] = _list_function_deps(
        fdef, library_function_names, library_gradient_names)

  loaded_gradients = {}
  for fdef in _sort_function_defs(library, function_deps):
    orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix,
                                   new_gradient_op_types)

    # Set up function signatures and outputs.
    #
    # When concrete functions are created normally (i.e. when they're originally
    # created and not loaded via saved model), the inputs and outputs are
    # calculated based on the values passed in by the user and returned from the
    # original function, respectively. We don't have access to those anymore at
    # restore time, so we must instead pass them to the FuncGraph explicitly.
    structured_input_signature = None
    structured_outputs = None
    if (saved_object_graph is not None and
        orig_name in saved_object_graph.concrete_functions):
      # TODO(b/204324043): Offload the deserialization of the protos to the
      # first class objects by passing the actual protos. This is blocked on
      # importing `nested_structure_coder` in function.py causing a circular
      # dependency.
      proto = saved_object_graph.concrete_functions[orig_name]
      structured_input_signature = nested_structure_coder.decode_proto(
          proto.canonicalized_input_signature)
      structured_outputs = nested_structure_coder.decode_proto(
          proto.output_signature)

    # There is no need to copy all functions into the function def graph. It
    # leads to an O(n^2) increase in memory when importing functions, and the
    # extra function definitions are a no-op since they were already imported as
    # functions before and are passed in explicitly (due to the topological sort
    # import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(
          fdef,
          structured_input_signature=structured_input_signature,
          structured_outputs=structured_outputs)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients).
    _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients)

    for dep in function_deps[orig_name]:
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). ConcreteFunctions that are part of a saved
    # function are set up later by recreate_function(); bare ConcreteFunctions
    # are set up by setup_bare_concrete_function().
    # However, we copy the FunctionDef attributes to the new ConcreteFunction,
    # excluding "_input_shapes", which may cause an error during input shape
    # initialization at a later stage.
    if "_input_shapes" in fdef.attr:
      del fdef.attr["_input_shapes"]
    func = function_lib.ConcreteFunction(func_graph, attrs=fdef.attr)
    if wrapper_function:
      func = wrapper_function(func)
    func.add_to_graph(graph)

    functions[orig_name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      func.add_to_graph(ops.get_default_graph())

    if orig_name in gradients_to_register:
      gradient_op_type = gradients_to_register[orig_name]
      loaded_gradients[compat.as_bytes(gradient_op_type)] = func
      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

  return functions


def _gen_gradient_func(func):
  """Wraps a deserialized function."""

  def gradient_func(unused_op, *result_grads):
    # Replace all `None` arguments, because the traced custom gradient function
    # expects tensors. Replacing with zeros is correct since the `None` values
    # occur when the gradient is unconnected, and thus the gradient is
    # "statically proven to be zero." See `tf.UnconnectedGradients` for details.

    def none_to_zero(x, t):
      if x is not None:
        return x

      shape, dtype = default_gradient.shape_and_dtype(t)

      if shape.is_fully_defined():
        return default_gradient.zeros_like(t)

      dims = []
      if shape.rank is not None:
        dims = [1 if d is None else d for d in shape.as_list()]

      return array_ops.zeros(dims, dtype)

    result_grads = [
        none_to_zero(x, t) for (x, t) in zip(result_grads, func.graph.inputs)
    ]

    return func(*result_grads)

  return gradient_func


def _restore_gradient_functions(func_graph, renamed_functions,
                                loaded_gradients):
  """Populate function op's _gradient_function with default gradient."""
  for op in func_graph.get_operations():
    # TODO(b/205024208): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type in ["StatefulPartitionedCall", "PartitionedCall"]:
      function = renamed_functions[compat.as_bytes(
          op.node_def.attr["f"].func.name)]
      op._gradient_function = function._get_gradient_function()  # pylint: disable=protected-access
    try:
      gradient_op_type = op.get_attr("_gradient_op_type")
    except ValueError:
      pass
    else:
      if gradient_op_type in loaded_gradients:
        grad_fn = loaded_gradients[gradient_op_type]
        grad_fn._num_positional_args = len(op.inputs)  # pylint: disable=protected-access
        grad_fn._arg_keywords = [inp.name for inp in op.inputs]  # pylint: disable=protected-access


def _sort_function_defs(library, function_deps):
  """Return a topological sort of FunctionDefs in a library."""
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fname, deps in function_deps.items():
    for dep in deps:
      edges[dep].append(fname)
      in_count[fname] += 1
  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    raise ValueError("There is a cyclic dependency between functions. "
                     f"Could not resolve {failed_to_resolve}.")

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]
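
# Illustrative sketch (not executed): the Kahn-style topological sort above
# processes callees before callers. With a hypothetical dependency map
#
#   function_deps = {"outer": {"middle"}, "middle": {"inner"}, "inner": set()}
#
# the FunctionDefs come back in the order "inner", "middle", "outer", so every
# dependency is already present in `functions` when `load_function_def_library`
# imports the function that calls it.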


def _get_gradient_op_type(node_def):
  """Returns the custom gradient op type."""
  if ("_gradient_op_type" in node_def.attr and
      node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]):
    return node_def.attr["_gradient_op_type"].s
  return None


def fix_node_def(node_def, functions, shared_name_suffix):
  """Replace function calls and shared names in `node_def`."""
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    if attr_value.WhichOneof("value") == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif attr_value.WhichOneof("value") == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if op_def:
    attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
    if attr:
      shared_name = None
      if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
        shared_name = node_def.attr["shared_name"].s
      elif attr.default_value.s:
        shared_name = compat.as_bytes(attr.default_value.s)
      if not shared_name:
        shared_name = compat.as_bytes(node_def.name)

      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))
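
# Illustrative sketch (not executed): effect of the shared_name handling above
# on a hypothetical NodeDef. For a "HashTableV2" node whose shared_name attr is
# b"table", fix_node_def(node_def, {}, "_load_7") turns on node name sharing and
# rewrites the attr to b"table_load_7_<uid>" (the extra uid comes from the
# HashTableV2 branch), so tables from different loads never share a resource.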


def _fix_fdef_in_place(fdef, functions, shared_name_suffix,
                       new_gradient_op_types):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existing functions.

  Args:
    fdef: FunctionDef proto to fix. It is mutated in-place.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.
    new_gradient_op_types: map from old gradient op type to newly generated op
      type.

  Returns:
    orig_name: original value of fdef.signature.name
  """
  orig_name = fdef.signature.name
  contains_unsaved_custom_gradients = False

  for node_def in fdef.node_def:
    fix_node_def(node_def, functions, shared_name_suffix)
    op_type = _get_gradient_op_type(node_def)
    if op_type is not None:
      if op_type in new_gradient_op_types:
        node_def.attr["_gradient_op_type"].s = compat.as_bytes(
            new_gradient_op_types[op_type])
      else:
        contains_unsaved_custom_gradients = True
  if contains_unsaved_custom_gradients:
    logging.warning(
        "Importing a function (%s) with ops with unsaved custom gradients. Will"
        " likely fail if a gradient is requested.", fdef.signature.name)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return orig_name


def _list_function_deps(fdef, library_function_names, library_gradient_names):
  """Find functions referenced in `fdef`."""
  # TODO(b/205023953): Recurse into list attributes and into NameAttrList attrs
  # both when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  deps = set()
  for node_def in fdef.node_def:
    grad_op_type = _get_gradient_op_type(node_def)
    if node_def.op in library_function_names:
      deps.add(node_def.op)
    elif grad_op_type and grad_op_type in library_gradient_names:
      deps.add(library_gradient_names[grad_op_type])
    else:
      for _, attr_value in node_def.attr.items():
        if attr_value.WhichOneof("value") == "func":
          deps.add(attr_value.func.name)
        elif attr_value.WhichOneof("value") == "list":
          for fn in attr_value.list.func:
            deps.add(fn.name)

  return deps


_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (
    function_lib._INFERENCE_PREFIX)  # pylint:disable=protected-access


def _clean_function_name(name):
  """Vanity function to keep the function names comprehensible."""
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  if match:
    return match.group(1)
  else:
    return name
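
# Illustrative sketch (not executed): assuming `function_lib._INFERENCE_PREFIX`
# is "__inference_" (its usual value), a wrapped trace name such as
# "__inference_serving_fn_1234" is cleaned back to "serving_fn" by
# `_clean_function_name`, while names that do not match the wrapper pattern are
# returned unchanged.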