Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/impl/conversion.py: 18%
99 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Core conversion logic, serves as main point of access."""
17import functools
18import inspect
19import sys
20import unittest
22from tensorflow.python.autograph.core import config
23from tensorflow.python.autograph.pyct import cache
24from tensorflow.python.autograph.pyct import inspect_utils
25from tensorflow.python.autograph.utils import ag_logging as logging
26from tensorflow.python.eager.polymorphic_function import tf_method_target
27from tensorflow.python.util import tf_inspect
# Cache of positive allowlisting decisions, indexed first by entity and then
# by conversion options: is_in_allowlist_cache queries it via
# `.has(entity, options)` and cache_allowlisted fills it via
# `[entity][options] = True`.
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()
def _is_of_known_loaded_module(f, module_name):
  """Tests whether `f` is one of the objects bound in an already-loaded module.

  Membership is checked by identity against the values of the module's
  namespace. The module is only consulted if it is already present in
  `sys.modules`; nothing is imported here.

  Args:
    f: The entity to look for.
    module_name: Key to look up in `sys.modules`.

  Returns:
    True if `f` is not None and is identical to some value in the module's
    `__dict__`; False otherwise, including when the module is not loaded.
  """
  module = sys.modules.get(module_name, None)
  if module is None:
    return False
  members = module.__dict__.values()
  # None is never reported as a member of any module.
  return f is not None and any(member is f for member in members)
def _is_known_loaded_type(f, module_name, entity_name):
  """Tests whether the function or method is an instance of a known type.

  The type is looked up as attribute `entity_name` of the already-imported
  module `module_name`; nothing is imported here.

  Args:
    f: The function, method or callable object to test.
    module_name: Key to look up in `sys.modules`.
    entity_name: Name of the type attribute on that module.

  Returns:
    True if `f` itself, or the function underlying the bound method `f`, is an
    instance of that type. False if the module is not loaded, lacks the
    attribute, or neither check matches.
  """
  if module_name not in sys.modules:
    return False
  missing = object()
  type_entity = getattr(sys.modules[module_name], entity_name, missing)
  if type_entity is missing:
    return False

  # The object itself is of this type. Example:
  #
  #   o = ClassType()
  #   function(o.method)()
  if isinstance(f, type_entity):
    return True

  # A bound method whose underlying function is of this type. Example:
  #
  #   class ClassType:
  #     @function
  #     def method(self):
  #       ...
  #   o = ClassType()
  #   o.method()
  #
  # Note: inspect is required here, to avoid unpacking tf.function decorators.
  return inspect.ismethod(f) and isinstance(f.__func__, type_entity)
def is_unsupported(o):
  """Checks whether an entity is supported by AutoGraph at all.

  Each positive match below is logged as "Permanently allowed". The categories
  are: wrapt-decorated callables, `functools.lru_cache` wrappers,
  constructors, members of a few standard-library modules, and custom ops and
  kernels whose defining module carries the `_IS_TENSORFLOW_PLUGIN` marker
  (see tensorflow.framework.load_library).

  Args:
    o: A Python entity.

  Returns:
    True if `o` falls into one of the categories above, False otherwise.
  """
  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
    logging.warning(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(o))
    logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
    return True

  if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently allowed: %s: lru_cache', o)
    return True

  # Constructors are permanently allowed.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(o):
    logging.log(2, 'Permanently allowed: %s: constructor', o)
    return True

  # Other built-in modules are permanently allowed.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      _is_of_known_loaded_module(o, m)
      for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
    logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
    return True

  # Custom ops and kernels are also permanently allowed.
  # See tensorflow.framework.load_library.
  # `o.__module__` is normally the *name* of the defining module (a str). A str
  # never has an `_IS_TENSORFLOW_PLUGIN` attribute, so probing it directly
  # could never match. Resolve the name through sys.modules first. A non-str
  # `__module__` value is still probed directly.
  owner_module = getattr(o, '__module__', None)
  if isinstance(owner_module, str):
    owner_module = sys.modules.get(owner_module)
  if hasattr(owner_module, '_IS_TENSORFLOW_PLUGIN'):
    logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
    return True

  return False
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
def is_allowlisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Checks whether an entity is allowed for use in graph mode.

  Examples of allowed entities include all members of the tensorflow
  package.

  The checks below run in order, and the first one that reaches a decision
  returns:
    1. the module-level `config.CONVERSION_RULES`;
    2. generator functions;
    3. callable objects whose `__call__` is allowed;
    4. methods whose owning class is allowed or is a `unittest.TestCase`
       subclass;
    5. namedtuple types.
  Anything that matches none of these is not allowed.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which callable objects (other than
      classes) are allowed if their __call__ method is allowed.
    allow_namedtuple_subclass: Reserved for internal use. Despite the name,
      when `True`, namedtuple subclasses are not allowed: only namedtuple
      types with no namedtuple base are. When `False`, namedtuples and their
      subclasses are both allowed.

  Returns:
    Boolean: True if `o` is allowlisted. False otherwise, including when a
    conversion rule explicitly requests conversion.
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  # Examples of callables that lack a __module__ property include builtins.
  if hasattr(m, '__name__'):
    # The first rule that yields CONVERT or DO_NOT_CONVERT for this module
    # decides the result. Any other action falls through to the next rule,
    # and then to the checks below.
    for rule in config.CONVERSION_RULES:
      action = rule.get_action(m)
      if action == config.Action.CONVERT:
        logging.log(2, 'Not allowed: %s: %s', o, rule)
        return False
      elif action == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Allowlisted: %s: %s', o, rule)
        return True

  # The check for __code__ below is because isgeneratorfunction crashes
  # without one.
  if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
    logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
    return True

  if (check_call_override and not tf_inspect.isclass(o) and
      hasattr(o, '__call__')):
    # Callable objects: allowed if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_allowlisted(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
      return True

  owner_class = None
  if tf_inspect.ismethod(o):
    # Methods of allowed classes are also allowed, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
    # allowed.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be allowed.

    owner_class = inspect_utils.getmethodclass(o)
    # When the method is bound to a TfMethodTarget (presumably the wrapper
    # that tf.function uses for methods; confirm), use the class that the
    # target records instead.
    if owner_class is tf_method_target.TfMethodTarget:
      owner_class = o.__self__.target_class
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
        return True

      # Narrow down to the class that actually defines the method, then ask
      # whether that class is allowed. The __call__ rule is disabled for this
      # query, and namedtuple subclasses are rejected.
      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_allowlisted(
          owner_class,
          check_call_override=False,
          allow_namedtuple_subclass=True):
        logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if allow_namedtuple_subclass:
      # Only namedtuples that do not derive from another namedtuple pass.
      if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
        logging.log(2, 'Allowlisted: %s: named tuple', o)
        return True
    else:
      logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
      return True

  logging.log(2, 'Not allowed: %s: default rule', o)
  return False
def is_in_allowlist_cache(entity, options):
  """Looks up `entity` under `options` in the module-level allowlist cache.

  Entities that cannot be used as cache keys (the lookup raises TypeError)
  are reported as not cached.

  Args:
    entity: The entity previously passed to `cache_allowlisted`.
    options: The conversion options the decision was recorded under.

  Returns:
    The result of `_ALLOWLIST_CACHE.has(entity, options)`, or False if that
    lookup raises TypeError.
  """
  try:
    found = _ALLOWLIST_CACHE.has(entity, options)
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    found = False
  return found
def cache_allowlisted(entity, options):
  """Records `entity` as allowlisted for `options` in the module-level cache.

  This is best-effort: entities that cannot be used as cache keys (the store
  raises TypeError) are silently skipped.

  Args:
    entity: The entity to record.
    options: The conversion options the decision applies to.
  """
  try:
    entity_entry = _ALLOWLIST_CACHE[entity]
    entity_entry[options] = True
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    pass