Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/operators/py_builtins.py: 27%

284 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operators corresponding to Python builtin functions. 

16 

17List of built-in functions: https://docs.python.org/3/library/functions.html 

18""" 

19 

20import inspect 

21 

22from tensorflow.python.autograph.utils import tensors 

23from tensorflow.python.autograph.utils import type_registry 

24from tensorflow.python.framework import constant_op 

25from tensorflow.python.framework import dtypes 

26from tensorflow.python.framework import ops 

27from tensorflow.python.framework import tensor_util 

28from tensorflow.python.ops import array_ops 

29from tensorflow.python.ops import cond 

30from tensorflow.python.ops import control_flow_assert 

31from tensorflow.python.ops import gen_parsing_ops 

32from tensorflow.python.ops import gen_string_ops 

33from tensorflow.python.ops import list_ops 

34from tensorflow.python.ops import math_ops 

35 

36 

# Sentinel meaning "argument not supplied". Compared by identity so that any
# real value (including None) can still be passed explicitly.
UNSPECIFIED = object()

# One type registry per overloadable builtin. The matching overload in this
# module (abs_, len_, print_, ...) consults its registry through
# registry_lookup before falling back to the default implementation.
abs_registry = type_registry.TypeRegistry()
len_registry = type_registry.TypeRegistry()
print_registry = type_registry.TypeRegistry()
enumerate_registry = type_registry.TypeRegistry()
zip_registry = type_registry.TypeRegistry()
map_registry = type_registry.TypeRegistry()
filter_registry = type_registry.TypeRegistry()
any_registry = type_registry.TypeRegistry()
all_registry = type_registry.TypeRegistry()
sorted_registry = type_registry.TypeRegistry()
next_registry = type_registry.TypeRegistry()

50 

51 

def registry_lookup(reg, obj):
  """Looks up `obj` in `reg`, mapping a failed lookup to `None`.

  Args:
    reg: Object exposing `lookup(obj)`; a miss is signalled with `LookupError`.
    obj: The value to look up.

  Returns:
    The result of `reg.lookup(obj)`, or `None` if it raised `LookupError`.
  """
  try:
    found = reg.lookup(obj)
  except LookupError:
    found = None
  return found

58 

59 

def overload_of(f):
  """Returns the overload registered for builtin `f`, or `f` itself.

  Args:
    f: Any object, typically a callable.

  Returns:
    `BUILTIN_FUNCTIONS_MAP[f.__name__]` when `f` is one of
    `SUPPORTED_BUILTINS`; otherwise `f` unchanged.
  """
  if f not in SUPPORTED_BUILTINS:
    return f
  return BUILTIN_FUNCTIONS_MAP[f.__name__]

64 

65 

66def _find_originating_frame(caller_fn_scope, innermost=True): 

67 """Locates the frame in which `caller_fn_scope` was defined.""" 

68 ctx_frame = inspect.currentframe() 

69 result = None 

70 while ctx_frame is not None: 

71 # Note it should not be normally possible to get false positives this way 

72 # because the function scope object is not accessible to user code (barring 

73 # call stack introspection). 

74 if ctx_frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope: 

75 result = ctx_frame 

76 if innermost: 

77 break 

78 ctx_frame = ctx_frame.f_back 

79 

80 assert result is not None, ( 

81 'the conversion process should ensure the caller_fn_scope is always' 

82 ' found somewhere on the call stack') 

83 

84 return result 

85 

86 

def locals_in_original_context(caller_fn_scope):
  """Returns the locals of the innermost frame that owns `caller_fn_scope`."""
  frame = _find_originating_frame(caller_fn_scope, innermost=True)
  return frame.f_locals

90 

91 

def globals_in_original_context(caller_fn_scope):
  """Returns the globals of the innermost frame that owns `caller_fn_scope`."""
  frame = _find_originating_frame(caller_fn_scope, innermost=True)
  return frame.f_globals

95 

96 

def eval_in_original_context(f, args, caller_fn_scope):
  """Executes the eval function in the context of a specified function.

  Args:
    f: Callable, typically the eval builtin.
    args: The original call arguments: the expression, optionally followed by
      explicit globals and locals.
    caller_fn_scope: The function scope of the converted function in which the
      call was originally made.

  Returns:
    The result of `f(expression, globals, locals)`, where globals/locals that
    were not passed explicitly are taken from the originating frame.
  """
  # When control flow is rewritten using functions, eval should use the
  # variables found in the same block where it was called. That is equivalent
  # to the innermost function call.
  frame = _find_originating_frame(caller_fn_scope, innermost=True)

  argc = len(args)
  expression = args[0]
  eval_globals = args[1] if argc >= 2 else frame.f_globals
  eval_locals = args[2] if argc >= 3 else frame.f_locals
  return f(expression, eval_globals, eval_locals)

110 

111 

def super_in_original_context(f, args, caller_fn_scope):
  """Executes the super function in the context of a specified function.

  See https://docs.python.org/3/library/functions.html#super for the exact
  details

  Args:
    f: Callable, typically the super builtin
    args: List[Any], the original call arguments
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made

  Returns:
    The result of calling `f` as if it was called in the frame indicated by
    `caller_fn_scope`.
  """
  # Only the zero-argument form needs desugaring; explicit arguments are
  # forwarded untouched.
  if args:
    return f(*args)

  # Inner functions seem to include their closure in f_locals, so the outermost
  # matching frame is the one that belongs to the method itself.
  method_frame = _find_originating_frame(caller_fn_scope, innermost=False)

  # Zero-argument super() looks for the __class__ cell variable and the first
  # argument passed to the enclosing function, per
  # https://www.python.org/dev/peps/pep-3135/ .
  #
  # That `f_code.co_varnames[0]` is the first argument is not confirmed by an
  # official doc or PEP, but it is stable and well established:
  #  - An unofficial community doc mentions it.
  #    https://python-reference.readthedocs.io/en/latest/docs/code/varnames.html
  #  - CPython has tests checking that order, merged in 2008 and unchanged
  #    since then.
  #    https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157
  #    https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192
  #
  # Note: the name can be more reliably obtained by inspecting the calling
  # function's argspec.
  #
  # Methods declared with *args (def method(*args)) are rejected by super()
  # itself ("super() no arguments"), and methods taking only **kwargs are not
  # allowed at all, so for correct code self is always the first positional
  # argument.
  #
  # TODO(mdan): Consider additional checks in case the input code is incorrect.
  # For example, the error might be cryptic compared to what super() regularly
  # raises.
  frame_locals = method_frame.f_locals
  type_arg = frame_locals['__class__']
  first_arg_name = method_frame.f_code.co_varnames[0]
  return f(type_arg, frame_locals[first_arg_name])

168 

169 

def abs_(x):
  """Overload of the abs builtin.

  A handler registered in `abs_registry` for `x` takes precedence; otherwise
  TF types go to `_tf_abs` and everything else to `_py_abs`.
  """
  override = registry_lookup(abs_registry, x)
  if override is not None:
    return override(x)
  impl = _tf_abs if tensor_util.is_tf_type(x) else _py_abs
  return impl(x)

177 

178 

def _tf_abs(x):
  """Overload of abs_ for TF types; delegates to `math_ops.abs`."""
  result = math_ops.abs(x)
  return result

181 

182 

183def _py_abs(x): 

184 return abs(x) 

185 

186 

def float_(x=0):
  """Overload of the float builtin: `_tf_float` for TF types, else `_py_float`."""
  impl = _tf_float if tensor_util.is_tf_type(x) else _py_float
  return impl(x)

191 

192 

def _tf_float(x):
  """Overload of float_ for TF types; always produces float32.

  String inputs are parsed with `string_to_number`; anything else is cast.
  """
  # TODO(mdan): We shouldn't assume float32.
  target = dtypes.float32
  is_string = x.dtype == dtypes.string
  if is_string:
    return gen_parsing_ops.string_to_number(x, out_type=target)
  return math_ops.cast(x, dtype=target)

198 

199 

200def _py_float(x): 

201 return float(x) 

202 

203 

def int_(x=0, base=UNSPECIFIED):
  """Overload of the int builtin: `_tf_int` for TF types, else `_py_int`."""
  impl = _tf_int if tensor_util.is_tf_type(x) else _py_int
  return impl(x, base)

208 

209 

def _tf_int(x, base):
  """Overload of int_ for TF types; always produces int32.

  Args:
    x: Value exposing a `dtype`; strings are parsed, other dtypes are cast.
    base: Must be 10 or `UNSPECIFIED`.

  Raises:
    NotImplementedError: If any other base is requested.
  """
  supported_bases = (10, UNSPECIFIED)
  if base not in supported_bases:
    raise NotImplementedError('base {} not supported for int'.format(base))

  # TODO(mdan): We shouldn't assume int32.
  target = dtypes.int32
  is_string = x.dtype == dtypes.string
  if is_string:
    return gen_parsing_ops.string_to_number(x, out_type=target)
  return math_ops.cast(x, dtype=target)

218 

219 

def _py_int(x, base):
  """Overload of int_ for non-TF values.

  `base` is forwarded to the builtin `int` only when it was actually supplied.
  """
  return int(x) if base is UNSPECIFIED else int(x, base)

224 

225 

def len_(s):
  """Overload of the len builtin.

  Dispatch order: handler registered in `len_registry`, tensor array, tensor
  list, other TF type, and finally the builtin `len`.
  """
  override = registry_lookup(len_registry, s)
  if override is not None:
    return override(s)
  if tensors.is_tensor_array(s):
    return _tf_tensor_array_len(s)
  if tensors.is_tensor_list(s):
    return _tf_tensor_list_len(s)
  if tensor_util.is_tf_type(s):
    return _tf_tensor_len(s)
  return _py_len(s)

237 

238 

239def _tf_tensor_array_len(s): 

240 return s.size() 

241 

242 

def _tf_tensor_list_len(s):
  """Overload of len_ for tensor lists; uses `list_ops.tensor_list_length`."""
  length = list_ops.tensor_list_length(s)
  return length

245 

246 

247def _tf_tensor_len(s): 

248 """Overload of len_ for Tensor arguments.""" 

249 # Statically shaped tensors: length is known ahead of time. 

250 if s.shape.ndims and s.shape.dims[0].value is not None: 

251 return s.shape.dims[0].value 

252 

253 # Static shape of unknown dimensions: use dynamic shape but statically 

254 # check that it's a scalar. 

255 shape = array_ops.shape(s) 

256 

257 assert shape.shape, 'shape tensor of zero size? {}'.format(shape) 

258 

259 if shape.shape[0] == 0: 

260 raise ValueError( 

261 'len requires a non-scalar tensor, got one of shape {}'.format(shape)) 

262 

263 if shape.shape.dims[0].value is not None: 

264 return array_ops.shape(s)[0] 

265 

266 # Fully dynamic shape: use ops. 

267 rank = array_ops.rank(s) 

268 

269 def raise_zero_rank_error(): 

270 msg = gen_string_ops.string_join( 

271 ['len requires non-zero rank, got ', 

272 gen_string_ops.as_string(rank)]) 

273 with ops.control_dependencies([control_flow_assert.Assert(False, [msg])]): 

274 return constant_op.constant(0, dtype=dtypes.int32) 

275 

276 return cond.cond(rank > 0, lambda: array_ops.shape(s)[0], 

277 raise_zero_rank_error) 

278 

279 

280def _py_len(s): 

281 return len(s) 

282 

283 

def print_(*objects, **kwargs):
  """Overload of the print builtin.

  Args:
    *objects: The values to print.
    **kwargs: Only `sep`, `end`, `file` and `flush` are accepted.

  Returns:
    Whatever the selected print implementation returns.

  Raises:
    ValueError: If any other keyword argument is passed.
  """
  # Keywords are taken through **kwargs (Python 2.6 doesn't support explicit
  # keywords after starargs), so they are validated by hand.
  allowed = set(('sep', 'end', 'file', 'flush'))
  unknown_kwargs = tuple(set(kwargs.keys()) - allowed)
  if unknown_kwargs:
    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))

  # The first object with a registered override decides the implementation.
  print_fn = _py_print
  for candidate in objects:
    override = registry_lookup(print_registry, candidate)
    if override is None:
      continue
    print_fn = override
    break

  if print_fn is _py_print:
    # If this fails, ops/autograph_ops.py hasn't been imported.
    assert not any(tensor_util.is_tf_type(s) for s in objects)

  return print_fn(*objects, **kwargs)

304 

305 

306def _py_print(*objects, **kwargs): 

307 print(*objects, **kwargs) 

308 

309 

def min_(*args, **kwargs):
  """Overload of the min builtin.

  Uses `_tf_min` as soon as any positional argument is a TF type, otherwise
  `_py_min`.
  """
  impl = _tf_min if any(tensor_util.is_tf_type(s) for s in args) else _py_min
  return impl(*args, **kwargs)

314 

315 

316def _tf_min(*args, **kwargs): 

317 if len(kwargs): 

318 kwargs_tuple = tuple(set(kwargs.keys())) 

319 raise ValueError('These keyword arguments are ' 

320 'currently not supported: {}'.format(kwargs_tuple)) 

321 if len(args) == 1: 

322 rank = args[0].shape.rank 

323 if rank == 0: 

324 return args[0] 

325 if rank == 1: 

326 return math_ops.reduce_min(*args, axis=0) 

327 raise ValueError('min(arg) currently support only tensor with rank 1, ' 

328 'but got a tensor with rank {}'.format(rank)) 

329 for arg in args: 

330 rank = arg.shape.rank 

331 if rank != 0: 

332 raise ValueError('min(arg1, arg2, *args) currently support ' 

333 'only scalar tensor, but got a tensor ' 

334 'with shape {}'.format(rank)) 

335 return math_ops.reduce_min(args, axis=0) 

336 

337 

338def _py_min(*args, **kwargs): 

339 return min(*args, **kwargs) 

340 

341 

def max_(*args, **kwargs):
  """Overload of the max builtin.

  Uses `_tf_max` as soon as any positional argument is a TF type, otherwise
  `_py_max`.
  """
  impl = _tf_max if any(tensor_util.is_tf_type(s) for s in args) else _py_max
  return impl(*args, **kwargs)

346 

347 

348def _tf_max(*args, **kwargs): 

349 if len(kwargs): 

350 kwargs_tuple = tuple(set(kwargs.keys())) 

351 raise ValueError('These keyword arguments are ' 

352 'currently not supported: {}'.format(kwargs_tuple)) 

353 if len(args) == 1: 

354 rank = args[0].shape.rank 

355 if rank == 0: 

356 return args[0] 

357 if rank == 1: 

358 return math_ops.reduce_max(*args, axis=0) 

359 raise ValueError('max(arg) currently support only tensor with rank 1, ' 

360 'but got a tensor with rank {}'.format(rank)) 

361 for arg in args: 

362 rank = arg.shape.rank 

363 if rank != 0: 

364 raise ValueError('max(arg1, arg2, *args) currently support ' 

365 'only scalar tensor, but got a tensor ' 

366 'with shape {}'.format(rank)) 

367 return math_ops.reduce_max(args, axis=0) 

368 

369 

370def _py_max(*args, **kwargs): 

371 return max(*args, **kwargs) 

372 

373 

def range_(start_or_stop, stop=UNSPECIFIED, step=UNSPECIFIED):
  """Overload of the range builtin.

  Uses `_tf_range` if any of the three arguments is a TF type, otherwise
  `_py_range`.
  """
  bounds = (start_or_stop, stop, step)
  impl = _tf_range if any(tensor_util.is_tf_type(s) for s in bounds) else _py_range
  return impl(start_or_stop, stop, step)

378 

379 

def _tf_range(start_or_stop, stop, step):
  """Overload of range_ that generates a TF range tensor.

  Without a step, the upper bound is raised to at least the lower bound (or to
  at least 0 when only a stop is given) before calling `math_ops.range`.
  """
  # Note: for static inputs (e.g. constants), tf.range errors out at graph
  # construction time, instead of returning an empty tensor. Preventing the
  # graph construction error aligns the semantics with Python.

  # TODO(mdan): We should optimize this when a full tensor is not required.
  if step is not UNSPECIFIED:
    # TODO(mdan): Add argument coercion similar to other cases.
    return math_ops.range(start_or_stop, stop, step)
  if stop is UNSPECIFIED:
    # Single-argument form: the argument is the stop value.
    return math_ops.range(math_ops.maximum(start_or_stop, 0))
  return math_ops.range(start_or_stop, math_ops.maximum(start_or_stop, stop))

395 

396 

def _py_range(start_or_stop, stop, step):
  """Overload of range_ for non-TF values.

  Only the arguments that were actually supplied are forwarded to the builtin
  `range`.
  """
  if step is not UNSPECIFIED:
    bounds = (start_or_stop, stop, step)
  elif stop is not UNSPECIFIED:
    bounds = (start_or_stop, stop)
  else:
    bounds = (start_or_stop,)
  return range(*bounds)

403 

404 

def enumerate_(s, start=0):
  """Overload of the enumerate builtin.

  Uses the handler registered in `enumerate_registry` for `s` if there is one,
  otherwise `_py_enumerate`.
  """
  override = registry_lookup(enumerate_registry, s)
  impl = _py_enumerate if override is None else override
  return impl(s, start)

410 

411 

412def _py_enumerate(s, start=0): 

413 return enumerate(s, start) 

414 

415 

def zip_(*iterables, strict=False):
  """Overload of the zip builtin.

  An override from `zip_registry` is used only when every iterable resolves to
  the same override; in all other cases `_py_zip` is used.
  """
  chosen = _py_zip
  # If the overridden function is not the same across all iterables, use _py_zip
  for candidate in iterables:
    override = registry_lookup(zip_registry, candidate)
    if override is None:
      chosen = _py_zip
      break
    if chosen != _py_zip and override != chosen:  # pylint: disable=comparison-with-callable
      chosen = _py_zip
      break
    chosen = override
  return chosen(*iterables, strict=strict)

426 

427 

428def _py_zip(*iterables, strict=False): 

429 if strict: 

430 return zip(*iterables, strict=True) 

431 else: 

432 # Python < 3.10 doesn't have `strict` kwarg. 

433 return zip(*iterables) 

434 

435 

def map_(fn, *iterables):
  """Overload of the map builtin.

  An override from `map_registry` is used only when every iterable resolves to
  the same override; in all other cases `_py_map` is used.
  """
  chosen = _py_map
  # If the overridden function is not the same across all iterables, use _py_map
  for candidate in iterables:
    override = registry_lookup(map_registry, candidate)
    if override is None:
      chosen = _py_map
      break
    if chosen != _py_map and override != chosen:  # pylint: disable=comparison-with-callable
      chosen = _py_map
      break
    chosen = override
  return chosen(fn, *iterables)

446 

447 

448def _py_map(fn, *iterables): 

449 return map(fn, *iterables) 

450 

451 

def next_(iterator, default=UNSPECIFIED):
  """Overload of the next builtin.

  Uses the handler registered in `next_registry` for `iterator` if there is
  one, otherwise `next_py`.
  """
  override = registry_lookup(next_registry, iterator)
  impl = next_py if override is None else override
  return impl(iterator, default)

457 

458 

def next_py(iterator, default=UNSPECIFIED):
  """Overload of next_ that delegates to the builtin `next`.

  `default` is forwarded only when it was actually supplied, so that exhaustion
  still raises `StopIteration` otherwise.
  """
  call_args = (iterator,) if default is UNSPECIFIED else (iterator, default)
  return next(*call_args)

463 

464 

def filter_(function, iterable):
  """Overload of the filter builtin.

  Uses the handler registered in `filter_registry` for `iterable` if there is
  one, otherwise `_py_filter`.
  """
  override = registry_lookup(filter_registry, iterable)
  impl = _py_filter if override is None else override
  return impl(function, iterable)

470 

471 

472def _py_filter(function, iterable): 

473 return filter(function, iterable) 

474 

475 

def any_(iterable):
  """Overload of the any builtin.

  Uses the handler registered in `any_registry` for `iterable` if there is one,
  otherwise `_py_any`.
  """
  override = registry_lookup(any_registry, iterable)
  impl = _py_any if override is None else override
  return impl(iterable)

481 

482 

483def _py_any(iterable): 

484 return any(iterable) 

485 

486 

def all_(iterable):
  """Overload of the all builtin.

  Uses the handler registered in `all_registry` for `iterable` if there is one,
  otherwise `_py_all`.
  """
  override = registry_lookup(all_registry, iterable)
  impl = _py_all if override is None else override
  return impl(iterable)

492 

493 

494def _py_all(iterable): 

495 return all(iterable) 

496 

497 

def sorted_(iterable, key=UNSPECIFIED, reverse=UNSPECIFIED):
  """Overload of the sorted builtin.

  Uses the handler registered in `sorted_registry` for `iterable` if there is
  one, otherwise `_py_sorted`. `key` and `reverse` are passed on as received,
  including the `UNSPECIFIED` sentinel.
  """
  override = registry_lookup(sorted_registry, iterable)
  impl = _py_sorted if override is None else override
  return impl(iterable, key, reverse)

503 

504 

def _py_sorted(iterable, key, reverse):
  """Overload of sorted_ that delegates to the builtin `sorted`.

  `key` and `reverse` are forwarded as keyword arguments only when they are not
  the `UNSPECIFIED` sentinel.
  """
  options = {}
  if key is not UNSPECIFIED:
    options['key'] = key
  if reverse is not UNSPECIFIED:
    options['reverse'] = reverse
  return sorted(iterable, **options)

513 

514 

# Builtins that overload_of replaces with their overload from
# BUILTIN_FUNCTIONS_MAP (looked up by the builtin's __name__).
# NOTE(review): `next` has an entry in BUILTIN_FUNCTIONS_MAP but is not listed
# here, and min_/max_ appear in neither table, so overload_of never returns
# those overloads -- confirm whether that is intentional.
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map,
                      filter, any, all, sorted)

# Maps a builtin's name to the overload defined in this module.
BUILTIN_FUNCTIONS_MAP = {
    'abs': abs_,
    'any': any_,
    'all': all_,
    'enumerate': enumerate_,
    'filter': filter_,
    'float': float_,
    'int': int_,
    'len': len_,
    'map': map_,
    'next': next_,
    'print': print_,
    'range': range_,
    'sorted': sorted_,
    'zip': zip_,
}