Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/tf_export.py: 45%
148 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-05 06:32 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-05 06:32 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utilities for exporting TensorFlow symbols to the API.

Exporting a function or a class:

To export a function or a class, use the tf_export decorator. For example:
```python
@tf_export('foo', 'bar.foo')
def foo(...):
  ...
```

If a function is assigned to a variable, you can export it by calling
tf_export explicitly. For example:
```python
foo = get_foo(...)
tf_export('foo', 'bar.foo')(foo)
```

Exporting a constant:
```python
foo = 1
tf_export('consts.foo').export_constant(__name__, 'foo')
```
"""
40import collections
41import functools
42import sys
44from tensorflow.python.util import tf_decorator
45from tensorflow.python.util import tf_inspect
# Names of the APIs that symbols can be exported to.
ESTIMATOR_API_NAME = 'estimator'
KERAS_API_NAME = 'keras'
TENSORFLOW_API_NAME = 'tensorflow'

# Subpackage names owned by separate TensorFlow components. The TensorFlow
# core repo must not export any symbols under these names.
SUBPACKAGE_NAMESPACES = [ESTIMATOR_API_NAME]

# Pair of attribute names for one API: `names` is the attribute set on an
# exported symbol, `constants` the attribute set on a module that exports
# constants.
_Attributes = collections.namedtuple(
    'ExportedApiAttributes', ['names', 'constants'])

# TF 2.x export attributes, keyed by API name.
# Attribute values must be unique to each API.
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names',
        constants='_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names',
        constants='_estimator_api_constants'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names',
        constants='_keras_api_constants'),
}

# TF 1.x export attributes, keyed by API name (same names with a `_v1`
# suffix).
API_ATTRS_V1 = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names_v1',
        constants='_tf_api_constants_v1'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names_v1',
        constants='_estimator_api_constants_v1'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names_v1',
        constants='_keras_api_constants_v1'),
}
class SymbolAlreadyExposedError(Exception):
  """Raised when adding API names to a symbol that already has API names."""


class InvalidSymbolNameError(Exception):
  """Raised when exporting a symbol under an invalid or disallowed name."""
# Maps each exported API name to the symbol exported under it. V1 names are
# stored with a 'compat.v1.' prefix. Entries are added by `api_export.__call__`.
_NAME_TO_SYMBOL_MAPPING = {}


def get_symbol_from_name(name):
  """Looks up the symbol exported under the given API name.

  Args:
    name: Dotted API name, e.g. 'keras.optimizers.Adam' or 'compat.v1.foo'.

  Returns:
    The exported symbol, or None if nothing was exported under `name`.
  """
  return _NAME_TO_SYMBOL_MAPPING.get(name, None)
def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  Example:
  ```python
  from tensorflow.python.util import tf_export
  cls = tf_export.get_symbol_from_name('keras.optimizers.Adam')

  # Gives `<class 'keras.optimizer_v2.adam.Adam'>`
  print(cls)

  # Gives `keras.optimizers.Adam`
  print(tf_export.get_canonical_name_for_symbol(cls, api_name='keras'))
  ```

  Args:
    symbol: API function or class.
    api_name: API name (a key of `API_ATTRS`, e.g. tensorflow, estimator or
      keras).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with compat.v1.

  Returns:
    Canonical name for the API symbol (for e.g. initializers.zeros) if
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  # Check `__dict__` rather than hasattr so that a subclass does not pick up
  # its parent's exported names.
  if api_names_attr not in undecorated_symbol.__dict__:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  deprecated_api_names = undecorated_symbol.__dict__.get(
      '_tf_deprecated_api_names', [])

  canonical_name = get_canonical_name(api_names, deprecated_api_names)
  if canonical_name:
    return canonical_name

  # If there is no V2 canonical name, get V1 canonical name. A missing V1
  # attribute is treated as "no V1 names" instead of raising AttributeError.
  api_names_attr_v1 = API_ATTRS_V1[api_name].names
  api_names_v1 = getattr(undecorated_symbol, api_names_attr_v1, ())
  v1_canonical_name = get_canonical_name(api_names_v1, deprecated_api_names)
  if v1_canonical_name is None:
    # Without this guard the prefixing below would return 'compat.v1.None'.
    return None
  if add_prefix_to_v1_names:
    return 'compat.v1.%s' % v1_canonical_name
  return v1_canonical_name
def get_canonical_name(api_names, deprecated_api_names):
  """Get preferred endpoint name.

  Args:
    api_names: API names iterable (any iterable, including a one-shot
      iterator or generator).
    deprecated_api_names: Deprecated API names iterable.

  Returns:
    Returns one of the following in decreasing preference:
      - first non-deprecated endpoint
      - first endpoint
      - None
  """
  # Materialize both inputs once. `api_names` is scanned and then indexed,
  # which fails or silently skips entries for one-shot iterators. The
  # deprecated names are probed once per candidate, so a set avoids re-reading
  # (and exhausting) an iterator and gives O(1) membership tests.
  api_names = list(api_names)
  deprecated = set(deprecated_api_names)

  non_deprecated_name = next(
      (name for name in api_names if name not in deprecated), None)
  if non_deprecated_name:
    return non_deprecated_name
  if api_names:
    return api_names[0]
  return None
def get_v1_names(symbol):
  """Get a list of TF 1.* names for this symbol.

  Only names stored directly on `symbol` (in its `__dict__`) are reported, so
  a subclass does not report names inherited from its base class.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all V1 API names for this symbol: TensorFlow, then Estimator,
    then Keras names. Empty if `symbol` has no `__dict__`.
  """
  names_v1 = []
  if not hasattr(symbol, '__dict__'):
    return names_v1
  for api_name in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    api_names_attr_v1 = API_ATTRS_V1[api_name].names
    if api_names_attr_v1 in symbol.__dict__:
      names_v1.extend(getattr(symbol, api_names_attr_v1))
  return names_v1
def get_v2_names(symbol):
  """Get a list of TF 2.0 names for this symbol.

  Only names stored directly on `symbol` (in its `__dict__`) are reported, so
  a subclass does not report names inherited from its base class.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all V2 API names for this symbol: TensorFlow, then Estimator,
    then Keras names. Empty if `symbol` has no `__dict__`.
  """
  names_v2 = []
  if not hasattr(symbol, '__dict__'):
    return names_v2
  for api_name in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    api_names_attr = API_ATTRS[api_name].names
    if api_names_attr in symbol.__dict__:
      names_v2.extend(getattr(symbol, api_names_attr))
  return names_v2
def get_v1_constants(module):
  """Collect the TF 1.* constants recorded on a module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants, TensorFlow entries first.
  """
  # NOTE(review): Keras constants (`_keras_api_constants_v1`) are not
  # collected here, unlike `get_v1_names`, which includes Keras names.
  # Presumably intentional; confirm with callers.
  collected = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME):
    attr = API_ATTRS_V1[api].constants
    if hasattr(module, attr):
      collected.extend(getattr(module, attr))
  return collected
def get_v2_constants(module):
  """Collect the TF 2.0 constants recorded on a module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants, TensorFlow entries first.
  """
  # NOTE(review): Keras constants (`_keras_api_constants`) are not collected
  # here, unlike `get_v2_names`, which includes Keras names. Presumably
  # intentional; confirm with callers.
  collected = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME):
    attr = API_ATTRS[api].constants
    if hasattr(module, attr):
      collected.extend(getattr(module, attr))
  return collected
class api_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments.
        v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
          names both for TensorFlow V1 and V2 APIs.
        overrides: List of symbols that this is overriding
          (those overridden api exports will be removed). Note: passing
          overrides has no effect on exporting a constant.
        api_name: Name of the API you want to generate (e.g. `tensorflow` or
          `estimator`). Default is `tensorflow`.
        allow_multiple_exports: Allow symbol to be exported multiple times
          under different names.

    Raises:
      ValueError: If a `v2` keyword argument is passed.
      InvalidSymbolNameError: If a name lies outside the packages allowed for
        `api_name`.
    """
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    if 'v2' in kwargs:
      raise ValueError('You passed a "v2" argument to tf_export. This is not '
                       'what you want. Pass v2 names directly as positional '
                       'arguments instead.')
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

    self._validate_symbol_names()

  @staticmethod
  def _is_under_package(name, package):
    """Returns True if dotted `name` is `package` itself or lies inside it."""
    return name == package or name.startswith(package + '.')

  def _validate_symbol_names(self):
    """Validate you are exporting symbols under an allowed package.

    We need to ensure things exported by tf_export, estimator_export, etc.
    export symbols under disjoint top-level package names.

    For TensorFlow, we check that it does not export anything under subpackage
    names used by components (estimator, keras, etc.).

    For each component, we check that it exports everything under its own
    subpackage.

    Package matching is done on whole dotted components, so e.g.
    'estimator.Foo' is under 'estimator' but 'estimator_spec' is not.

    Raises:
      InvalidSymbolNameError: If you try to export symbol under disallowed name.
    """
    all_symbol_names = set(self._names) | set(self._names_v1)
    if self._api_name == TENSORFLOW_API_NAME:
      for subpackage in SUBPACKAGE_NAMESPACES:
        if any(self._is_under_package(n, subpackage)
               for n in all_symbol_names):
          raise InvalidSymbolNameError(
              '@tf_export is not allowed to export symbols under %s.*' % (
                  subpackage))
    else:
      if not all(self._is_under_package(n, self._api_name)
                 for n in all_symbol_names):
        raise InvalidSymbolNameError(
            'Can only export symbols under package name of component. '
            'e.g. tensorflow_estimator must export all symbols under '
            'tf.estimator')

  def __call__(self, func):
    """Calls this decorator.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input `func` itself. The API-name attributes for this API (V2 and
      V1) are set on the undecorated target returned by
      `tf_decorator.unwrap(func)`, and every exported name (V1 names prefixed
      with 'compat.v1.') is mapped to `func` for `get_symbol_from_name`.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
      AttributeError: If a symbol in `overrides` has no API-name attributes
        for this API (raised by `delattr`).
    """
    api_names_attr = API_ATTRS[self._api_name].names
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      delattr(undecorated_f, api_names_attr)
      delattr(undecorated_f, api_names_attr_v1)

    _, undecorated_func = tf_decorator.unwrap(func)
    self.set_attr(undecorated_func, api_names_attr, self._names)
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)

    for name in self._names:
      _NAME_TO_SYMBOL_MAPPING[name] = func
    for name_v1 in self._names_v1:
      _NAME_TO_SYMBOL_MAPPING['compat.v1.%s' % name_v1] = func

    return func

  def set_attr(self, func, api_names_attr, names):
    """Sets `names` as the `api_names_attr` attribute of `func`.

    Args:
      func: Undecorated symbol to annotate.
      api_names_attr: Attribute name to set (from `API_ATTRS`/`API_ATTRS_V1`).
      names: API names to store.

    Raises:
      SymbolAlreadyExposedError: If `func` already defines the attribute
        itself and multiple exports are not allowed.
    """
    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    if api_names_attr in func.__dict__:
      if not self._allow_multiple_exports:
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' %
            (func.__name__, getattr(func, api_names_attr)))  # pylint: disable=protected-access
    setattr(func, api_names_attr, names)

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    api_constants_attr = API_ATTRS[self._api_name].constants
    api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants

    # Each entry is a (api_names, constant_name) pair appended to a list kept
    # on the defining module.
    if not hasattr(module, api_constants_attr):
      setattr(module, api_constants_attr, [])
    # pylint: disable=protected-access
    getattr(module, api_constants_attr).append(
        (self._names, name))

    if not hasattr(module, api_constants_attr_v1):
      setattr(module, api_constants_attr_v1, [])
    getattr(module, api_constants_attr_v1).append(
        (self._names_v1, name))
def kwarg_only(f):
  """Wraps `f` so that it may only be called with keyword arguments.

  Args:
    f: Function to wrap.

  Returns:
    A decorated version of `f` (built with `tf_decorator.make_decorator`) that
    forwards keyword arguments to `f` and raises `TypeError` if any positional
    argument is supplied.
  """
  argspec = tf_inspect.getfullargspec(f)

  def wrapper(*args, **kwargs):
    if not args:
      return f(**kwargs)
    raise TypeError(
        '{f} only takes keyword args (possible keys: {kwargs}). '
        'Please pass these args as kwargs instead.'
        .format(f=f.__name__, kwargs=argspec.args))

  return tf_decorator.make_decorator(f, wrapper, decorator_argspec=argspec)
# Decorator/exporter for the core TensorFlow API, e.g. `@tf_export('foo')`.
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
# Decorator/exporter for the Keras API; `api_export` requires every exported
# name to start with `keras`.
keras_export = functools.partial(api_export, api_name=KERAS_API_NAME)