Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/function_utils.py: 25%
56 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility to retrieve function args."""
17import functools
19import six
21from tensorflow.core.protobuf import config_pb2
22from tensorflow.python.util import tf_decorator
23from tensorflow.python.util import tf_inspect
def _is_bound_method(fn):
  """Returns whether the unwrapped `fn` is a method with a non-None `__self__`.

  `fn` is first passed through `tf_decorator.unwrap`; the check is applied to
  the second element of the returned pair.
  """
  _, target = tf_decorator.unwrap(fn)
  if not tf_inspect.ismethod(target):
    return False
  return target.__self__ is not None
def _is_callable_object(obj):
  """Returns whether `obj` has a `__call__` attribute that is a method."""
  if not hasattr(obj, '__call__'):
    return False
  return tf_inspect.ismethod(obj.__call__)
def fn_args(fn):
  """Lists the argument names accepted by a function-like object.

  For a `functools.partial`, the argument names of the wrapped callable are
  computed first. The leading names consumed by the partial's positional
  arguments are skipped, and names bound through its keywords are left out.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.
  """
  if isinstance(fn, functools.partial):
    wrapped_names = fn_args(fn.func)
    # Drop the names already taken by positionally bound arguments.
    unbound = wrapped_names[len(fn.args):]
    return tuple(
        name for name in unbound if name not in (fn.keywords or []))

  # Objects whose `__call__` is a method are inspected through that method.
  target = fn.__call__ if _is_callable_object(fn) else fn
  names = tf_inspect.getfullargspec(target).args
  if _is_bound_method(target) and names:
    # A bound method may or may not list a self/cls first argument; for
    # example, self could be captured in *args. When a positional argument is
    # present, the first one is self/cls, so it is skipped.
    names = names[1:]
  return tuple(names)
def has_kwargs(fn):
  """Tells whether the given callable declares **kwargs in its signature.

  A `functools.partial` is inspected through its `func`; an object whose
  `__call__` is a method is inspected through that method.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: True when the inspected signature has a **kwargs parameter.

  Raises:
    `TypeError`: If fn is not a Function, or function-like object.
  """
  if isinstance(fn, functools.partial):
    target = fn.func
  elif _is_callable_object(fn):
    target = fn.__call__
  elif callable(fn):
    target = fn
  else:
    raise TypeError(
        'Argument `fn` should be a callable. '
        f'Received: fn={fn} (of type {type(fn)})')
  spec = tf_inspect.getfullargspec(target)
  return spec.varkw is not None
def get_func_name(func):
  """Returns a name for the passed callable.

  The callable is first passed through `tf_decorator.unwrap`. Functions yield
  their `__name__`, methods yield `<ClassName>.<function name>`, and any other
  callable yields `str(type(...))`.

  Raises:
    ValueError: If the unwrapped `func` is not callable.
  """
  _, target = tf_decorator.unwrap(func)
  if not callable(target):
    raise ValueError(
        'Argument `func` must be a callable. '
        f'Received func={target} (of type {type(target)})')

  if tf_inspect.isfunction(target):
    return target.__name__
  if tf_inspect.ismethod(target):
    owner_name = six.get_method_self(target).__class__.__name__
    function_name = six.get_method_function(target).__name__
    return '%s.%s' % (owner_name, function_name)
  # Probably a class instance with __call__.
  return str(type(target))
def get_func_code(func):
  """Returns func_code of passed callable, or None if not available.

  The callable is first passed through `tf_decorator.unwrap`.

  Raises:
    ValueError: If the unwrapped `func` is not callable.
  """
  _, target = tf_decorator.unwrap(func)
  if not callable(target):
    raise ValueError(
        'Argument `func` must be a callable. '
        f'Received func={target} (of type {type(target)})')

  if tf_inspect.isfunction(target) or tf_inspect.ismethod(target):
    return six.get_function_code(target)

  # The object is callable but is neither a function nor a method, so try to
  # read the code object from its `__call__` attribute. This works with
  # callable classes but fails with `functools.partial` objects despite their
  # `__call__` attribute.
  try:
    return six.get_function_code(target.__call__)
  except AttributeError:
    return None
# Holds the serialized config produced by `get_disabled_rewriter_config`;
# stays None until that function is first called.
_rewriter_config_optimizer_disabled = None


def get_disabled_rewriter_config():
  """Returns a serialized `ConfigProto` with `disable_meta_optimizer` set.

  The serialized value is built on the first call and stored in the
  module-level `_rewriter_config_optimizer_disabled`; later calls return the
  stored value without rebuilding it.
  """
  global _rewriter_config_optimizer_disabled
  if _rewriter_config_optimizer_disabled is not None:
    return _rewriter_config_optimizer_disabled

  proto = config_pb2.ConfigProto()
  proto.graph_options.rewrite_options.disable_meta_optimizer = True
  _rewriter_config_optimizer_disabled = proto.SerializeToString()
  return _rewriter_config_optimizer_disabled