Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/operators/py_builtins.py: 27%
284 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators corresponding to Python builtin functions.
17List of built-in functions: https://docs.python.org/3/library/functions.html
18"""
20import inspect
22from tensorflow.python.autograph.utils import tensors
23from tensorflow.python.autograph.utils import type_registry
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import cond
30from tensorflow.python.ops import control_flow_assert
31from tensorflow.python.ops import gen_parsing_ops
32from tensorflow.python.ops import gen_string_ops
33from tensorflow.python.ops import list_ops
34from tensorflow.python.ops import math_ops
# Sentinel default used by the overloads below to tell an omitted argument
# apart from any value a caller could pass explicitly (including None).
UNSPECIFIED = object()

# One TypeRegistry per overloaded builtin. Each overload below consults its
# registry through registry_lookup before falling back to its TF or Python
# implementation. NOTE(review): no registration is visible in this code, so
# entries are presumably added by other modules -- confirm.
abs_registry = type_registry.TypeRegistry()
len_registry = type_registry.TypeRegistry()
print_registry = type_registry.TypeRegistry()
enumerate_registry = type_registry.TypeRegistry()
zip_registry = type_registry.TypeRegistry()
map_registry = type_registry.TypeRegistry()
filter_registry = type_registry.TypeRegistry()
any_registry = type_registry.TypeRegistry()
all_registry = type_registry.TypeRegistry()
sorted_registry = type_registry.TypeRegistry()
next_registry = type_registry.TypeRegistry()
def registry_lookup(reg, obj):
  """Returns the entry `reg` holds for `obj`, or None when there is none.

  A miss is signalled by `reg.lookup` raising LookupError (or a subclass such
  as KeyError); that is translated to None. Other exceptions propagate.
  """
  try:
    found = reg.lookup(obj)
  except LookupError:
    found = None
  return found
def overload_of(f):
  """Returns the AutoGraph overload of builtin `f`, or `f` itself.

  Only callables found in SUPPORTED_BUILTINS are replaced; their overload is
  looked up in BUILTIN_FUNCTIONS_MAP by the builtin's `__name__`.
  """
  if f not in SUPPORTED_BUILTINS:
    return f
  return BUILTIN_FUNCTIONS_MAP[f.__name__]
def _find_originating_frame(caller_fn_scope, innermost=True):
  """Returns the stack frame that holds `caller_fn_scope` as a local.

  Walks outward from the current frame. A frame matches when its local named
  `caller_fn_scope.name` is the very same object as `caller_fn_scope`.

  Args:
    caller_fn_scope: object with a `name` attribute giving the local variable
      name under which it is stored in the frame being searched for.
    innermost: if True, stop at the first (innermost) match; otherwise keep
      walking and return the last (outermost) match.

  Returns:
    The matching frame object.
  """
  # False positives are not expected: user code normally cannot reach the
  # function scope object (barring call stack introspection).
  match = None
  frame = inspect.currentframe()
  while frame is not None:
    if frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope:
      match = frame
      if innermost:
        break
    frame = frame.f_back

  assert match is not None, (
      'the conversion process should ensure the caller_fn_scope is always'
      ' found somewhere on the call stack')

  return match
def locals_in_original_context(caller_fn_scope):
  """Returns the locals of the frame where the call was originally made.

  That frame is the innermost one on the call stack holding
  `caller_fn_scope`, so this emulates `locals()` called at that point.
  """
  origin = _find_originating_frame(caller_fn_scope, innermost=True)
  return origin.f_locals
def globals_in_original_context(caller_fn_scope):
  """Executes the globals function in the context of a specified function.

  Returns the globals of the innermost frame on the call stack that holds
  `caller_fn_scope`, i.e. what `globals()` would have returned at the point
  where the call was originally made.

  Args:
    caller_fn_scope: the function scope of the converted function in which
      this call was originally made.

  Returns:
    The `f_globals` mapping of that frame.
  """
  return _find_originating_frame(caller_fn_scope, innermost=True).f_globals
def eval_in_original_context(f, args, caller_fn_scope):
  """Executes the eval function in the context of a specified function.

  Omitted or None namespace arguments are resolved the way the `eval` builtin
  resolves them. The only difference is that "the caller" means the frame in
  which the call was originally made (the innermost frame holding
  `caller_fn_scope`), rather than this helper's own frame:

    * globals omitted/None: the originating frame's globals are used. If
      locals is also omitted/None, the originating frame's locals are used.
    * globals given, locals omitted/None: locals default to the given globals,
      as with the builtin.

  Args:
    f: Callable, typically the eval builtin.
    args: Sequence[Any], the original positional call arguments.
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.

  Returns:
    The result of calling `f` with the resolved namespaces.
  """
  if not args:
    # Let `f` raise its own "expected at least 1 argument" error instead of an
    # IndexError from this helper.
    return f()

  # When control flow is rewritten using functions, eval should use the
  # variables found in the same block where it was called. That is equivalent
  # to the innermost function call.
  ctx_frame = _find_originating_frame(caller_fn_scope, innermost=True)

  expression = args[0]
  globals_ = args[1] if len(args) > 1 else None
  locals_ = args[2] if len(args) > 2 else None

  if globals_ is None:
    # Passing None through would make eval resolve *this* module's globals,
    # so substitute the originating frame's namespaces instead.
    globals_ = ctx_frame.f_globals
    if locals_ is None:
      locals_ = ctx_frame.f_locals
  elif locals_ is None:
    # Builtin semantics: when only globals is given, it also serves as locals.
    locals_ = globals_

  # Forward any extra positional arguments so `f` rejects them itself rather
  # than having them silently dropped.
  return f(expression, globals_, locals_, *args[3:])
def super_in_original_context(f, args, caller_fn_scope):
  """Executes the super function in the context of a specified function.

  See https://docs.python.org/3/library/functions.html#super for the exact
  details

  Args:
    f: Callable, typically the super builtin
    args: List[Any], the original call arguments
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made

  Returns:
    The result of calling `f` as if it was called in the frame indicated by
    `caller_fn_scope`.

  Raises:
    RuntimeError: for the zero-argument form, when the originating frame does
      not provide what `super()` needs. The messages mirror CPython's.
  """

  # Only the no-arg call is desugared.
  if args:
    return f(*args)

  # Inner functions seem to include their closure in f_locals, so we need
  # to find the outermost frame.
  ctx_frame = _find_originating_frame(caller_fn_scope, innermost=False)

  # When super(..) is called without arguments, it looks for __class__ cell
  # variable and the first argument passed in the enclosing function according
  # to the spec https://www.python.org/dev/peps/pep-3135/ .
  #
  # We couldn't verify if `inspect.currentframe().f_code.co_varnames[0]` is
  # guaranteed to be the first argument from an official doc or PEP, however,
  # it's fairly stable and well established:
  # - An unofficial community doc mentions it.
  #   https://python-reference.readthedocs.io/en/latest/docs/code/varnames.html
  # - CPython has tests checking that order, which was merged in 2008, and
  #   unchanged since then.
  #   https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157
  #   https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192
  #
  # Note: the name can be more reliably obtained by inspecting the calling
  # function's argspec.
  #
  # Even though methods can be declared using *args (def method(*args)),
  # that pattern is disallowed by super() -- it raises super() no arguments.
  # Method definitions using **kwargs are not allowed at all.
  # In other words, we can always assume that self is on the first positional
  # argument (for correct code).
  frame_locals = ctx_frame.f_locals
  frame_code = ctx_frame.f_code

  # For incorrect code, fail with the same errors, in the same order, as
  # CPython's zero-argument super(), rather than a bare KeyError/IndexError or
  # an unrelated TypeError from calling `f` with the wrong object.
  if frame_code.co_argcount == 0:
    raise RuntimeError('super(): no arguments')
  self_arg_name = frame_code.co_varnames[0]
  if self_arg_name not in frame_locals:
    raise RuntimeError('super(): arg[0] deleted')
  if '__class__' not in frame_locals:
    raise RuntimeError('super(): __class__ cell not found')

  type_arg = frame_locals['__class__']
  self_arg = frame_locals[self_arg_name]
  return f(type_arg, self_arg)
def abs_(x):
  """AutoGraph overload of the `abs` builtin.

  Precedence: a handler registered in `abs_registry` for `x`, then the TF
  implementation for TF values, then the Python builtin.
  """
  override = registry_lookup(abs_registry, x)
  if override is None:
    return _tf_abs(x) if tensor_util.is_tf_type(x) else _py_abs(x)
  return override(x)


def _tf_abs(x):
  """Computes the absolute value of a TF value with `math_ops.abs`."""
  return math_ops.abs(x)


def _py_abs(x):
  """Calls the Python `abs` builtin."""
  return abs(x)
def float_(x=0):
  """AutoGraph overload of the `float` builtin."""
  impl = _tf_float if tensor_util.is_tf_type(x) else _py_float
  return impl(x)


def _tf_float(x):
  """Converts a TF value to float32, parsing it if it is a string tensor."""
  # TODO(mdan): We shouldn't assume float32.
  if x.dtype != dtypes.string:
    return math_ops.cast(x, dtype=dtypes.float32)
  return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)


def _py_float(x):
  """Calls the Python `float` builtin."""
  return float(x)
def int_(x=0, base=UNSPECIFIED):
  """AutoGraph overload of the `int` builtin.

  TF values go to the TF implementation and everything else goes to the Python
  builtin. `base` is forwarded unchanged in both cases.
  """
  impl = _tf_int if tensor_util.is_tf_type(x) else _py_int
  return impl(x, base)


def _tf_int(x, base):
  """Converts a TF value to int32, parsing it if it is a string tensor.

  Raises:
    NotImplementedError: if `base` is given and is not 10.
  """
  if base not in (10, UNSPECIFIED):
    raise NotImplementedError('base {} not supported for int'.format(base))

  # TODO(mdan): We shouldn't assume int32.
  if x.dtype != dtypes.string:
    return math_ops.cast(x, dtype=dtypes.int32)
  return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)


def _py_int(x, base):
  """Calls the Python `int` builtin, passing `base` only when it was given."""
  return int(x) if base is UNSPECIFIED else int(x, base)
def len_(s):
  """AutoGraph overload of the `len` builtin.

  Precedence: a handler registered in `len_registry` for `s`, then TensorArray,
  then TensorList, then any other TF value, then the Python builtin.
  """
  len_override = registry_lookup(len_registry, s)
  if len_override is not None:
    return len_override(s)
  if tensors.is_tensor_array(s):
    return _tf_tensor_array_len(s)
  elif tensors.is_tensor_list(s):
    return _tf_tensor_list_len(s)
  elif tensor_util.is_tf_type(s):
    return _tf_tensor_len(s)
  return _py_len(s)


def _tf_tensor_array_len(s):
  """Returns the size of a TensorArray via its `size()` method."""
  return s.size()


def _tf_tensor_list_len(s):
  """Returns the length of a TensorList via `list_ops.tensor_list_length`."""
  return list_ops.tensor_list_length(s)


def _tf_tensor_len(s):
  """Overload of len_ for Tensor arguments.

  Returns the size of the first dimension. This is a Python int when it is
  statically known; otherwise it is a scalar tensor computed at run time.

  Raises:
    ValueError: if `s` is statically known to be a scalar.
  """
  # Statically shaped tensors: length is known ahead of time.
  if s.shape.ndims and s.shape.dims[0].value is not None:
    return s.shape.dims[0].value

  # Static shape of unknown dimensions: use dynamic shape but statically
  # check that it's a scalar.
  shape = array_ops.shape(s)

  assert shape.shape, 'shape tensor of zero size? {}'.format(shape)

  if shape.shape[0] == 0:
    # Report the offending tensor's static shape, not the Shape op tensor.
    raise ValueError(
        'len requires a non-scalar tensor, got one of shape {}'.format(s.shape))

  if shape.shape.dims[0].value is not None:
    # Rank is known and non-zero: reuse the shape tensor computed above
    # instead of emitting a second Shape op.
    return shape[0]

  # Fully dynamic shape: use ops.
  rank = array_ops.rank(s)

  def raise_zero_rank_error():
    msg = gen_string_ops.string_join(
        ['len requires non-zero rank, got ',
         gen_string_ops.as_string(rank)])
    with ops.control_dependencies([control_flow_assert.Assert(False, [msg])]):
      return constant_op.constant(0, dtype=dtypes.int32)

  return cond.cond(rank > 0, lambda: array_ops.shape(s)[0],
                   raise_zero_rank_error)


def _py_len(s):
  """Calls the Python `len` builtin."""
  return len(s)
def print_(*objects, **kwargs):
  """Overload of the print builtin.

  Uses the first handler registered in `print_registry` for any of `objects`
  (checked in order). If there is none, it uses the Python builtin.

  Raises:
    ValueError: if a keyword other than sep, end, file or flush is passed.
  """
  # Keywords are accepted via **kwargs and validated against print's own set.
  allowed = set(('sep', 'end', 'file', 'flush'))
  unknown_kwargs = tuple(set(kwargs.keys()) - allowed)
  if unknown_kwargs:
    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))

  # Lazily look up overrides so the search stops at the first hit.
  overrides = (registry_lookup(print_registry, x) for x in objects)
  print_fn = next((o for o in overrides if o is not None), _py_print)

  if print_fn is _py_print:
    # If this fails, ops/autograph_ops.py hasn't been imported.
    assert not any(tensor_util.is_tf_type(s) for s in objects)

  return print_fn(*objects, **kwargs)


def _py_print(*objects, **kwargs):
  """Calls the Python `print` builtin."""
  print(*objects, **kwargs)
def min_(*args, **kwargs):
  """AutoGraph overload of the `min` builtin.

  Uses the TF implementation when any positional argument is a TF value, and
  the Python builtin otherwise.
  """
  if any(tensor_util.is_tf_type(s) for s in args):
    return _tf_min(*args, **kwargs)
  return _py_min(*args, **kwargs)


def _tf_min(*args, **kwargs):
  """TF implementation of `min`.

  Supports either a single tensor of rank 0 or 1, or several scalar arguments.
  The scalars may be tensors or plain Python values mixed with tensors.

  Raises:
    ValueError: if keyword arguments are given, or an argument has an
      unsupported rank.
  """
  if kwargs:
    kwargs_tuple = tuple(set(kwargs.keys()))
    raise ValueError('These keyword arguments are '
                     'currently not supported: {}'.format(kwargs_tuple))
  if len(args) == 1:
    rank = args[0].shape.rank
    if rank == 0:
      return args[0]
    if rank == 1:
      return math_ops.reduce_min(*args, axis=0)
    raise ValueError('min(arg) currently support only tensor with rank 1, '
                     'but got a tensor with rank {}'.format(rank))
  for arg in args:
    # min_ routes here when *any* argument is a TF value, so plain Python
    # values (e.g. `min(t, 0)`) may be mixed in. They have no `.shape`, so
    # convert them for the rank check only. `reduce_min` below still packs the
    # original values, which lets them adopt the tensors' dtype.
    arg_shape = (arg if tensor_util.is_tf_type(arg)
                 else ops.convert_to_tensor(arg)).shape
    if arg_shape.rank != 0:
      raise ValueError('min(arg1, arg2, *args) currently support '
                       'only scalar tensor, but got a tensor '
                       'with shape {}'.format(arg_shape))
  return math_ops.reduce_min(args, axis=0)


def _py_min(*args, **kwargs):
  """Calls the Python `min` builtin."""
  return min(*args, **kwargs)
def max_(*args, **kwargs):
  """AutoGraph overload of the `max` builtin.

  Uses the TF implementation when any positional argument is a TF value, and
  the Python builtin otherwise.
  """
  if any(tensor_util.is_tf_type(s) for s in args):
    return _tf_max(*args, **kwargs)
  return _py_max(*args, **kwargs)


def _tf_max(*args, **kwargs):
  """TF implementation of `max`.

  Supports either a single tensor of rank 0 or 1, or several scalar arguments.
  The scalars may be tensors or plain Python values mixed with tensors.

  Raises:
    ValueError: if keyword arguments are given, or an argument has an
      unsupported rank.
  """
  if kwargs:
    kwargs_tuple = tuple(set(kwargs.keys()))
    raise ValueError('These keyword arguments are '
                     'currently not supported: {}'.format(kwargs_tuple))
  if len(args) == 1:
    rank = args[0].shape.rank
    if rank == 0:
      return args[0]
    if rank == 1:
      return math_ops.reduce_max(*args, axis=0)
    raise ValueError('max(arg) currently support only tensor with rank 1, '
                     'but got a tensor with rank {}'.format(rank))
  for arg in args:
    # max_ routes here when *any* argument is a TF value, so plain Python
    # values (e.g. `max(t, 0)`) may be mixed in. They have no `.shape`, so
    # convert them for the rank check only. `reduce_max` below still packs the
    # original values, which lets them adopt the tensors' dtype.
    arg_shape = (arg if tensor_util.is_tf_type(arg)
                 else ops.convert_to_tensor(arg)).shape
    if arg_shape.rank != 0:
      raise ValueError('max(arg1, arg2, *args) currently support '
                       'only scalar tensor, but got a tensor '
                       'with shape {}'.format(arg_shape))
  return math_ops.reduce_max(args, axis=0)


def _py_max(*args, **kwargs):
  """Calls the Python `max` builtin."""
  return max(*args, **kwargs)
def range_(start_or_stop, stop=UNSPECIFIED, step=UNSPECIFIED):
  """AutoGraph overload of the `range` builtin.

  Produces a TF range tensor if any argument is a TF value, and a Python
  `range` otherwise.
  """
  values = (start_or_stop, stop, step)
  if any(tensor_util.is_tf_type(v) for v in values):
    return _tf_range(*values)
  return _py_range(*values)


def _tf_range(start_or_stop, stop, step):
  """Overload of range_ that generates a TF range tensor."""
  # Note: for static inputs (e.g. constants), tf.range errors out at graph
  # construction time, instead of returning an empty tensor. Clamping the
  # limit below prevents that error and aligns the semantics with Python.

  # TODO(mdan): We should optimize this when a full tensor is not required.
  if step is not UNSPECIFIED:
    # TODO(mdan): Add argument coercion similar to other cases.
    return math_ops.range(start_or_stop, stop, step)
  if stop is UNSPECIFIED:
    return math_ops.range(math_ops.maximum(start_or_stop, 0))
  return math_ops.range(start_or_stop, math_ops.maximum(start_or_stop, stop))


def _py_range(start_or_stop, stop, step):
  """Calls the Python `range` builtin with only the arguments that were given."""
  if step is not UNSPECIFIED:
    return range(start_or_stop, stop, step)
  if stop is UNSPECIFIED:
    return range(start_or_stop)
  return range(start_or_stop, stop)
def enumerate_(s, start=0):
  """AutoGraph overload of the `enumerate` builtin.

  A handler registered in `enumerate_registry` for `s` takes precedence over
  the Python builtin.
  """
  override = registry_lookup(enumerate_registry, s)
  if override is None:
    return _py_enumerate(s, start)
  return override(s, start)


def _py_enumerate(s, start=0):
  """Calls the Python `enumerate` builtin."""
  return enumerate(s, start)
def zip_(*iterables, strict=False):
  """AutoGraph overload of the `zip` builtin.

  A handler from `zip_registry` is used only when every iterable resolves to
  the same handler. If any iterable has no handler, or the handlers differ,
  the Python builtin is used.
  """
  chosen = _py_zip
  for it in iterables:
    override = registry_lookup(zip_registry, it)
    if override is None:
      chosen = _py_zip
      break
    if chosen != _py_zip and override != chosen:  # pylint: disable=comparison-with-callable
      chosen = _py_zip
      break
    chosen = override
  return chosen(*iterables, strict=strict)


def _py_zip(*iterables, strict=False):
  """Calls the Python `zip` builtin."""
  # Python < 3.10 doesn't have `strict` kwarg, so only forward it when set.
  if not strict:
    return zip(*iterables)
  return zip(*iterables, strict=True)
def map_(fn, *iterables):
  """AutoGraph overload of the `map` builtin.

  A handler from `map_registry` is used only when every iterable resolves to
  the same handler. If any iterable has no handler, or the handlers differ,
  the Python builtin is used.
  """
  chosen = _py_map
  for it in iterables:
    override = registry_lookup(map_registry, it)
    if override is None:
      chosen = _py_map
      break
    if chosen != _py_map and override != chosen:  # pylint: disable=comparison-with-callable
      chosen = _py_map
      break
    chosen = override
  return chosen(fn, *iterables)


def _py_map(fn, *iterables):
  """Calls the Python `map` builtin."""
  return map(fn, *iterables)
def next_(iterator, default=UNSPECIFIED):
  """AutoGraph overload of the `next` builtin.

  A handler registered in `next_registry` for `iterator` takes precedence
  over the Python builtin.
  """
  override = registry_lookup(next_registry, iterator)
  if override is None:
    return next_py(iterator, default)
  return override(iterator, default)


def next_py(iterator, default=UNSPECIFIED):
  """Calls the Python `next` builtin, passing `default` only if it was given."""
  if default is not UNSPECIFIED:
    return next(iterator, default)
  return next(iterator)
def filter_(function, iterable):
  """AutoGraph overload of the `filter` builtin.

  A handler registered in `filter_registry` for `iterable` takes precedence
  over the Python builtin.
  """
  override = registry_lookup(filter_registry, iterable)
  if override is None:
    return _py_filter(function, iterable)
  return override(function, iterable)


def _py_filter(function, iterable):
  """Calls the Python `filter` builtin."""
  return filter(function, iterable)
def any_(iterable):
  """AutoGraph overload of the `any` builtin.

  A handler registered in `any_registry` for `iterable` takes precedence over
  the Python builtin.
  """
  override = registry_lookup(any_registry, iterable)
  return _py_any(iterable) if override is None else override(iterable)


def _py_any(iterable):
  """Calls the Python `any` builtin."""
  return any(iterable)
def all_(iterable):
  """AutoGraph overload of the `all` builtin.

  A handler registered in `all_registry` for `iterable` takes precedence over
  the Python builtin.
  """
  override = registry_lookup(all_registry, iterable)
  return _py_all(iterable) if override is None else override(iterable)


def _py_all(iterable):
  """Calls the Python `all` builtin."""
  return all(iterable)
def sorted_(iterable, key=UNSPECIFIED, reverse=UNSPECIFIED):
  """AutoGraph overload of the `sorted` builtin.

  A handler registered in `sorted_registry` for `iterable` takes precedence
  over the Python builtin.
  """
  override = registry_lookup(sorted_registry, iterable)
  if override is None:
    return _py_sorted(iterable, key, reverse)
  return override(iterable, key, reverse)


def _py_sorted(iterable, key, reverse):
  """Calls the Python `sorted` builtin, forwarding only the given options."""
  options = {}
  if key is not UNSPECIFIED:
    options['key'] = key
  if reverse is not UNSPECIFIED:
    options['reverse'] = reverse
  return sorted(iterable, **options)
# Builtins that overload_of swaps for their AutoGraph overloads. Any other
# callable is returned unchanged by overload_of.
SUPPORTED_BUILTINS = (
    abs, float, int, len, print, range, enumerate, zip, map, filter, any, all,
    sorted,
)

# Overloads keyed by builtin `__name__`; overload_of indexes this mapping.
# NOTE(review): 'next' is mapped here, but `next` is absent from
# SUPPORTED_BUILTINS, so overload_of(next) returns `next` itself. min_ and
# max_ are not mapped at all. Confirm whether both are intentional.
BUILTIN_FUNCTIONS_MAP = dict(
    abs=abs_,
    any=any_,
    all=all_,
    enumerate=enumerate_,
    filter=filter_,
    float=float_,
    int=int_,
    len=len_,
    map=map_,
    next=next_,
    print=print_,
    range=range_,
    sorted=sorted_,
    zip=zip_,
)