# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow-related utilities."""

import collections
import contextlib
import copy
import platform
import random
import threading

import numpy as np
import tensorflow.compat.v2 as tf
from absl import logging

from keras.src import backend
from keras.src.engine import keras_tensor
from keras.src.utils import object_identity
from keras.src.utils import tf_contextlib

# isort: off
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python import pywrap_tfe


@keras_export("keras.utils.set_random_seed", v1=[])
def set_random_seed(seed):
    """Sets all random seeds for the program (Python, NumPy, and TensorFlow).

    You can use this utility to make almost any Keras program fully
    deterministic. Some limitations apply in cases where network communications
    are involved (e.g. parameter server distribution), which creates additional
    sources of randomness, or when certain non-deterministic cuDNN ops are
    involved.

    Calling this utility is equivalent to the following:

    ```python
    import random
    import numpy as np
    import tensorflow as tf
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    ```

    Arguments:
        seed: Integer, the random seed to use.
    """
    if not isinstance(seed, int):
        raise ValueError(
            "Expected `seed` argument to be an integer. "
            f"Received: seed={seed} (of type {type(seed)})"
        )
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    backend._SEED_GENERATOR.generator = random.Random(seed)


def get_random_seed():
    """Retrieve a seed value to seed a random generator.

    Returns:
        the random seed as an integer.
    """
    if getattr(backend._SEED_GENERATOR, "generator", None):
        return backend._SEED_GENERATOR.generator.randint(1, 1e9)
    else:
        return random.randint(1, 1e9)


def is_tensor_or_tensor_list(v):
    v = tf.nest.flatten(v)
    if v and isinstance(v[0], tf.Tensor):
        return True
    else:
        return False


def get_reachable_from_inputs(inputs, targets=None):
    """Returns the set of tensors/ops reachable from `inputs`.

    Stops if all targets have been found (`targets` is optional).

    Only valid in Symbolic mode, not Eager mode.

    Args:
        inputs: List of tensors.
        targets: List of tensors.

    Returns:
        A set of tensors reachable from the inputs (includes the inputs
        themselves).
    """
    inputs = tf.nest.flatten(inputs, expand_composites=True)
    reachable = object_identity.ObjectIdentitySet(inputs)
    if targets:
        remaining_targets = object_identity.ObjectIdentitySet(
            tf.nest.flatten(targets)
        )
    queue = collections.deque(inputs)

    while queue:
        x = queue.pop()
        if isinstance(x, tuple(_user_convertible_tensor_types)):
            # Can't find consumers of user-specific types.
            continue

        if isinstance(x, tf.Operation):
            outputs = x.outputs[:] or []
            outputs += x._control_outputs
        elif isinstance(x, tf.Variable):
            try:
                outputs = [x.op]
            except AttributeError:
                # Variables can be created in an Eager context.
                outputs = []
        elif tf.is_tensor(x):
            outputs = x.consumers()
        else:
            raise TypeError(
                "Expected tf.Operation, tf.Variable, or tf.Tensor. "
                f"Received: {x}"
            )

        for y in outputs:
            if y not in reachable:
                reachable.add(y)
                if targets:
                    remaining_targets.discard(y)
                queue.appendleft(y)

        if targets and not remaining_targets:
            return reachable

    return reachable
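
# Illustrative usage of `get_reachable_from_inputs` (a sketch, not part of
# this module's API; it assumes a TF1-style graph, since the function is only
# valid in symbolic mode):
#
#     with tf.Graph().as_default():
#         x = tf.compat.v1.placeholder(tf.float32, [None, 4])
#         y = x * 2.0
#         reachable = get_reachable_from_inputs([x])
#         # `reachable` contains `x`, the multiplication op, and `y`.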


# This function needs access to private functions of `nest`.


def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Maps the atomic elements of a nested structure.

    Args:
        is_atomic_fn: A function that determines if an element of `nested` is
            atomic.
        map_fn: The function to apply to atomic elements of `nested`.
        nested: A nested structure.

    Returns:
        The nested structure, with atomic elements mapped according to `map_fn`.

    Raises:
        ValueError: If an element that is neither atomic nor a sequence is
            encountered.
    """
    if is_atomic_fn(nested):
        return map_fn(nested)

    # Recursively convert.
    if not tf.nest.is_nested(nested):
        raise ValueError(
            f"Received non-atomic and non-sequence element: {nested} "
            f"of type {type(nested)}"
        )
    if tf.__internal__.nest.is_mapping(nested):
        values = [nested[k] for k in sorted(nested.keys())]
    elif tf.__internal__.nest.is_attrs(nested):
        values = _astuple(nested)
    else:
        values = nested
    mapped_values = [
        map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
    ]
    return tf.__internal__.nest.sequence_like(nested, mapped_values)
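
# A small sketch of `map_structure_with_atomic` (illustrative only): treat
# tuples as atomic leaves and double every dimension they hold.
#
#     nested = {"a": (1, 2), "b": [(3,), (4, 5)]}
#     map_structure_with_atomic(
#         lambda x: isinstance(x, tuple),
#         lambda t: tuple(2 * d for d in t),
#         nested,
#     )
#     # -> {"a": (2, 4), "b": [(6,), (8, 10)]}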


def get_shapes(tensors):
    """Gets shapes from tensors."""
    return tf.nest.map_structure(
        lambda x: x.shape if hasattr(x, "shape") else None, tensors
    )


def convert_shapes(input_shape, to_tuples=True):
    """Converts nested shape representations to desired format.

    Performs:

    TensorShapes -> tuples if `to_tuples=True`.
    tuples of int or None -> TensorShapes if `to_tuples=False`.

    Valid objects to be converted are:
    - TensorShapes
    - tuples with elements of type int or None.
    - ints
    - None

    Args:
        input_shape: A nested structure of objects to be converted to
            TensorShapes.
        to_tuples: If `True`, converts all TensorShape to tuples. Otherwise
            converts all tuples representing shapes to TensorShapes.

    Returns:
        Nested structure of shapes in desired format.

    Raises:
        ValueError: when the input tensor shape can't be converted to tuples,
            e.g. an unknown tensor shape.
226 """ 

227 

228 def _is_shape_component(value): 

229 return value is None or isinstance(value, (int, tf.compat.v1.Dimension)) 

230 

231 def _is_atomic_shape(input_shape): 

232 # Ex: TensorShape or (None, 10, 32) or 5 or `None` 

233 if _is_shape_component(input_shape): 

234 return True 

235 if isinstance(input_shape, tf.TensorShape): 

236 return True 

237 if isinstance(input_shape, (tuple, list)) and all( 

238 _is_shape_component(ele) for ele in input_shape 

239 ): 

240 return True 

241 return False 

242 

243 def _convert_shape(input_shape): 

244 input_shape = tf.TensorShape(input_shape) 

245 if to_tuples: 

246 input_shape = tuple(input_shape.as_list()) 

247 return input_shape 

248 

249 return map_structure_with_atomic( 

250 _is_atomic_shape, _convert_shape, input_shape 

251 ) 
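
# Quick sketch of `convert_shapes` round-tripping (illustrative only):
#
#     convert_shapes([(None, 3), tf.TensorShape([4, 5])], to_tuples=True)
#     # -> [(None, 3), (4, 5)]
#     convert_shapes((None, 3), to_tuples=False)
#     # -> TensorShape([None, 3])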


def validate_axis(axis, input_shape):
255 """Validate an axis value and returns its standardized form. 


    Args:
        axis: Value to validate. Can be an integer or a list/tuple of integers.
            Integers may be negative.
        input_shape: Reference input shape that the axis/axes refer to.

    Returns:
        Normalized form of `axis`, i.e. a list with all-positive values.
    """
    input_shape = tf.TensorShape(input_shape)
    rank = input_shape.rank
    if not rank:
        raise ValueError(
            f"Input has undefined rank. Received: input_shape={input_shape}"
        )

    # Convert axis to list and resolve negatives
    if isinstance(axis, int):
        axis = [axis]
    else:
        axis = list(axis)
    for idx, x in enumerate(axis):
        if x < 0:
            axis[idx] = rank + x

    # Validate axes
    for x in axis:
        if x < 0 or x >= rank:
            raise ValueError(
                "Invalid value for `axis` argument. "
                "Expected 0 <= axis < inputs.rank (with "
                f"inputs.rank={rank}). Received: axis={tuple(axis)}"
            )
    if len(axis) != len(set(axis)):
        raise ValueError(f"Duplicate axis: {tuple(axis)}")
    return axis
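
# Hedged example of `validate_axis` normalization (illustrative only):
#
#     validate_axis(-1, tf.TensorShape([None, 3, 4]))       # -> [2]
#     validate_axis([1, -1], tf.TensorShape([None, 3, 4]))  # -> [1, 2]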


class ListWrapper:
    """A wrapper for lists to be treated as elements for `nest`."""

    def __init__(self, list_to_wrap):
        self._list = list_to_wrap

    def as_list(self):
        return self._list


def convert_inner_node_data(nested, wrap=False):
    """Either wraps or unwraps innermost node data lists in `ListWrapper`
    objects.

    Args:
        nested: A nested data structure.
        wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If
            `False`, unwraps `ListWrapper` objects into lists.

    Returns:
        Structure of same type as nested, with lists wrapped/unwrapped.
    """

    def _is_serialized_node_data(nested):
        # Node data can be of form `[layer_name, node_id, tensor_id]` or
        # `[layer_name, node_id, tensor_id, kwargs]`.
        if (
            isinstance(nested, list)
            and (len(nested) in [3, 4])
            and isinstance(nested[0], str)
        ):
            return True
        return False

    def _is_atomic_nested(nested):
        """Returns `True` if `nested` is a list representing node data."""
        if isinstance(nested, ListWrapper):
            return True
        if _is_serialized_node_data(nested):
            return True
        return not tf.nest.is_nested(nested)

    def _convert_object_or_list(nested):
337 """Convert b/t `ListWrapper` object and list representations.""" 

        if wrap:
            if isinstance(nested, ListWrapper):
                return nested
            if _is_serialized_node_data(nested):
                return ListWrapper(nested)
            return nested
        else:
            if isinstance(nested, ListWrapper):
                return nested.as_list()
            return nested

    return map_structure_with_atomic(
        _is_atomic_nested, _convert_object_or_list, nested
    )
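
# Sketch of `convert_inner_node_data` wrapping node data (illustrative only):
#
#     convert_inner_node_data([["dense", 0, 0]], wrap=True)
#     # -> a list containing a `ListWrapper` around ["dense", 0, 0]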


def shape_type_conversion(fn):
    """Decorator that handles tuple/TensorShape conversion.

    Used in `compute_output_shape` and `build`.

    Args:
        fn: function to wrap.

    Returns:
        Wrapped function.
    """

    def wrapper(instance, input_shape):
        # Pass shapes as tuples to `fn`
        # This preserves compatibility with external Keras.
        if input_shape is not None:
            input_shape = convert_shapes(input_shape, to_tuples=True)
        output_shape = fn(instance, input_shape)
        # Return shapes from `fn` as TensorShapes.
        if output_shape is not None:
            output_shape = convert_shapes(output_shape, to_tuples=False)
        return output_shape

    return wrapper
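
# Hedged sketch of the decorator in use (illustrative; `Doubler` is a
# hypothetical layer, not part of this module):
#
#     class Doubler(tf.keras.layers.Layer):
#         @shape_type_conversion
#         def compute_output_shape(self, input_shape):
#             # `input_shape` arrives here as a plain tuple.
#             return input_shape[:-1] + (2 * input_shape[-1],)
#
#     # The returned tuple is converted back to a `TensorShape` by `wrapper`.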


def are_all_symbolic_tensors(tensors):
    return all(map(is_symbolic_tensor, tensors))


_user_convertible_tensor_types = set()


def is_extension_type(tensor):
    """Returns whether a tensor is of an ExtensionType.

    github.com/tensorflow/community/pull/269
    Currently it works by checking if `tensor` is a `CompositeTensor` instance,
    but this will be changed to use an appropriate extension-type protocol
    check once ExtensionType is made public.

    Args:
        tensor: An object to test

    Returns:
        True if the tensor is an extension type object, false if not.
    """
    return isinstance(tensor, tf.__internal__.CompositeTensor)


def is_symbolic_tensor(tensor):
    """Returns whether a tensor is symbolic (from a TF graph) or an eager
    tensor.

    A Variable can be seen as either kind: it is considered symbolic
    when we are in a graph scope, and eager when we are in an eager scope.

    Args:
        tensor: A tensor instance to test.

    Returns:
        True for symbolic tensors, False for eager tensors.
    """
    if isinstance(tensor, tf.Tensor):
        return hasattr(tensor, "graph")
    elif is_extension_type(tensor):
        component_tensors = tf.nest.flatten(tensor, expand_composites=True)
        return any(hasattr(t, "graph") for t in component_tensors)
    elif isinstance(tensor, tf.Variable):
        # Variables that are output of a Keras Layer in Functional API mode
        # should be considered symbolic.
        # TODO(omalleyt): We need a better way to check this in order to
        # enable `run_eagerly=True` for Models containing Layers that
        # return Variables as outputs.
        return (
            getattr(tensor, "_keras_history", False)
            or not tf.executing_eagerly()
        )
    elif isinstance(tensor, tuple(_user_convertible_tensor_types)):
        tensor = ops.convert_to_tensor_or_composite(tensor)
        return is_symbolic_tensor(tensor)
    else:
        return False
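
# Behavior sketch (illustrative; eager tensors lack a usable `.graph`):
#
#     is_symbolic_tensor(tf.constant(1.0))         # eager -> False
#     with tf.Graph().as_default():
#         is_symbolic_tensor(tf.constant(1.0))     # graph -> True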


@keras_export("keras.__internal__.utils.register_symbolic_tensor_type", v1=[])
def register_symbolic_tensor_type(cls):
    """Allows users to specify types regarded as symbolic `Tensor`s.

    Used in conjunction with `tf.register_tensor_conversion_function`, calling
    `tf.keras.__internal__.utils.register_symbolic_tensor_type(cls)`
    allows non-`Tensor` objects to be plumbed through Keras layers.

    Example:

    ```python
    # One-time setup.
    class Foo:
        def __init__(self, input_):
            self._input = input_

        def value(self):
            return tf.constant(42.0)

    tf.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value())

    tf.keras.__internal__.utils.register_symbolic_tensor_type(Foo)

    # User-land.
    layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
    ```

    Args:
        cls: A `class` type which shall be regarded as a symbolic `Tensor`.
    """
    global _user_convertible_tensor_types
    if cls not in _user_convertible_tensor_types:
        keras_tensor.register_keras_tensor_specialization(
            cls, keras_tensor.UserRegisteredTypeKerasTensor
        )
    _user_convertible_tensor_types.add(cls)


def type_spec_from_value(value):
    """Grab type_spec without converting array-likes to tensors."""
    if is_extension_type(value):
        return value._type_spec
    # Get a TensorSpec for array-like data without
    # converting the data to a Tensor
    if hasattr(value, "shape") and hasattr(value, "dtype"):
        return tf.TensorSpec(value.shape, value.dtype)
    else:
        return tf.type_spec_from_value(value)
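
# Sketch (illustrative): array-likes keep their spec without tensor conversion.
#
#     type_spec_from_value(np.zeros((2, 3), dtype="float32"))
#     # -> TensorSpec(shape=(2, 3), dtype=tf.float32, name=None)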


def is_ragged(tensor):
    """Returns true if `tensor` is a ragged tensor or ragged tensor value."""
    return isinstance(
        tensor, (tf.RaggedTensor, tf.compat.v1.ragged.RaggedTensorValue)
    )


def is_sparse(tensor):
    """Returns true if `tensor` is a sparse tensor or sparse tensor value."""
    return isinstance(tensor, (tf.SparseTensor, tf.compat.v1.SparseTensorValue))


def is_tensor_or_variable(x):
    return tf.is_tensor(x) or isinstance(x, tf.Variable)


def is_tensor_or_extension_type(x):
    """Returns true if 'x' is a TF-native type or an ExtensionType."""
    return tf.is_tensor(x) or is_extension_type(x)


def convert_variables_to_tensors(values):
    """Converts `Variable`s in `values` to `Tensor`s.

    This is a Keras version of `convert_variables_to_tensors` in TensorFlow
    variable_utils.py.

    If an object in `values` is an `ExtensionType` and it overrides its
    `_convert_variables_to_tensors` method, its `ResourceVariable` components
    will also be converted to `Tensor`s. Objects other than `ResourceVariable`s
    in `values` will be returned unchanged.

    Args:
        values: A nested structure of `ResourceVariable`s, or any other objects.

    Returns:
        A new structure with `ResourceVariable`s in `values` converted to
        `Tensor`s.
    """

    def _convert_resource_variable_to_tensor(x):
        if isinstance(x, tf.Variable):
            return tf.convert_to_tensor(x)
        elif is_extension_type(x):
            return x._convert_variables_to_tensors()
        else:
            return x

    return tf.nest.map_structure(_convert_resource_variable_to_tensor, values)
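
# Minimal sketch of `convert_variables_to_tensors` (illustrative only):
#
#     v = tf.Variable([1.0, 2.0])
#     convert_variables_to_tensors({"w": v, "name": "layer1"})
#     # `v` is replaced by a `tf.Tensor`; `"layer1"` passes through unchanged.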


def assert_no_legacy_layers(layers):
    """Prevent tf.layers.Layers from being used with Keras.

    Certain legacy layers inherit from their Keras analogs; however, they are
    not supported with Keras and can lead to subtle, hard-to-diagnose bugs.

    Args:
        layers: A list of layers to check

    Raises:
        TypeError: If any elements of layers are tf.layers.Layers
    """

    # isinstance check for tf.layers.Layer introduces a circular dependency.
    legacy_layers = [l for l in layers if getattr(l, "_is_legacy_layer", None)]
    if legacy_layers:
        layer_str = "\n".join(" " + str(l) for l in legacy_layers)
        raise TypeError(
            f"The following are legacy tf.layers.Layers:\n{layer_str}\n"
            "To use keras as a "
            "framework (for instance using the Network, Model, or Sequential "
            "classes), please use the tf.keras.layers implementation instead. "
            "(Or, if writing custom layers, subclass from tf.keras.layers "
            "rather than tf.layers)"
        )


@tf_contextlib.contextmanager
def maybe_init_scope(layer):
    """Open an `init_scope` if in V2 mode and using the keras graph.

    Args:
        layer: The Layer/Model that is currently active.

    Yields:
        None
    """
    # Don't open an init_scope in V1 mode, when using legacy tf.layers, or in a
    # local-variable scope.
    # The local-variable scope should ensure that created variables are local to
    # the function being executed, rather than lifted out of the graph by
    # `init_scope`. This way the variables are freely usable and mutable within
    # the function, which enables a visitation guarantee for model evaluation,
    # when the scope is applied to metric variable creation.
    if (
        tf.compat.v1.executing_eagerly_outside_functions()
        and getattr(layer, "_keras_style", True)
        and not in_local_vars_context()
    ):
        with tf.init_scope():
            yield
    else:
        yield


@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
597 """Returns graph context manager if any of the inputs is a symbolic 

598 tensor.""" 

    if any(is_symbolic_tensor(v) for v in list(args) + list(kwargs.values())):
        with backend.get_graph().as_default():
            yield
    else:
        yield


def dataset_is_infinite(dataset):
    """True if the passed dataset is infinite."""
    if tf.compat.v1.executing_eagerly_outside_functions():
        return tf.equal(
            tf.data.experimental.cardinality(dataset),
            tf.data.experimental.INFINITE_CARDINALITY,
        )
    else:
        dataset_size = backend.get_session().run(
            tf.data.experimental.cardinality(dataset)
        )
        return dataset_size == tf.data.experimental.INFINITE_CARDINALITY
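
# Usage sketch (illustrative; in TF2 this returns a boolean tensor):
#
#     ds = tf.data.Dataset.range(10).repeat()
#     dataset_is_infinite(ds)   # -> a boolean tensor equal to True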


def get_tensor_spec(t, dynamic_batch=False, name=None):
    """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""

    if isinstance(t, tf.TypeSpec):
        spec = t
    elif is_extension_type(t):
        # TODO(b/148821952): Should these specs have a name attr?
        spec = t._type_spec
    elif hasattr(t, "_keras_history") and hasattr(
        t._keras_history[0], "_type_spec"
    ):
        return t._keras_history[0]._type_spec
    elif isinstance(t, keras_tensor.KerasTensor):
        spec = t.type_spec
    elif hasattr(t, "shape") and hasattr(t, "dtype"):
        spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
    else:
        return None  # Allow non-Tensors to pass through.

    if not dynamic_batch:
        return spec

    shape = spec.shape
    if shape.rank is None or shape.rank == 0:
        return spec

    shape_list = shape.as_list()
    shape_list[0] = None
    # TODO(b/203201161) Remove this deepcopy once type_spec_with_shape has been
    # updated to not mutate spec.
    spec = copy.deepcopy(spec)
    return keras_tensor.type_spec_with_shape(spec, tf.TensorShape(shape_list))
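
# Sketch of `dynamic_batch` relaxing the leading dimension (illustrative):
#
#     get_tensor_spec(tf.ones([8, 4]), dynamic_batch=True)
#     # -> TensorSpec(shape=(None, 4), dtype=tf.float32, name=None)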


def sync_to_numpy_or_python_type(tensors):
    """Syncs and converts a structure of `Tensor`s to `NumPy` arrays or Python
    scalar types.

    For each tensor, it calls `tensor.numpy()`. If the result is a scalar value,
    it converts it to a Python type, such as a float or int, by calling
    `result.item()`.

    Numpy scalars are converted, as Python types are often more convenient to
    deal with. This is especially useful for bfloat16 Numpy scalars, which don't
    support as many operations as other Numpy values.

    Async strategies (such as `TPUStrategy` and `ParameterServerStrategy`) are
    forced to sync during this process.

    Args:
        tensors: A structure of tensors.

    Returns:
        `tensors`, but scalar tensors are converted to Python types and non-scalar
        tensors are converted to Numpy arrays.
    """
    if isinstance(tensors, tf.distribute.experimental.coordinator.RemoteValue):
        tensors = tensors.fetch()
    if isinstance(tensors, list) and isinstance(
        tensors[0], tf.distribute.experimental.coordinator.RemoteValue
    ):
        tensors = tf.nest.map_structure(lambda t: t.fetch(), tensors)

    def _to_single_numpy_or_python_type(t):
        # Don't turn ragged or sparse tensors to NumPy.
        if isinstance(t, tf.Tensor):
            t = t.numpy()
        # Strings, ragged and sparse tensors don't have .item(). Return them
        # as-is.
        if not isinstance(t, (np.ndarray, np.generic)):
            return t
        return t.item() if np.ndim(t) == 0 else t

    return tf.nest.map_structure(_to_single_numpy_or_python_type, tensors)
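
# Illustrative sketch (assumes eager execution):
#
#     sync_to_numpy_or_python_type(
#         {"loss": tf.constant(0.5), "preds": tf.constant([1, 2])}
#     )
#     # -> {"loss": 0.5, "preds": np.array([1, 2])}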


def _astuple(attrs):
    """Converts the given attrs to tuple non-recursively."""
    cls = type(attrs)
    fields = getattr(cls, "__attrs_attrs__", None)
    if fields is None:
        raise ValueError(f"{cls} is not an attrs-decorated class.")
    values = []
    for field in fields:
        values.append(getattr(attrs, field.name))
    return tuple(values)


def can_jit_compile(warn=False):
    """Returns True if TensorFlow XLA is available for the platform."""
    if platform.system() == "Darwin" and "arm" in platform.processor().lower():
        if warn:
            logging.warning(
                "XLA (`jit_compile`) is not yet supported on Apple M1/M2 ARM "
                "processors. Falling back to `jit_compile=False`."
            )
        return False
    if pywrap_tfe.TF_ListPluggablePhysicalDevices():
        if warn:
            logging.warning(
                "XLA (`jit_compile`) is not supported on your system. "
                "Falling back to `jit_compile=False`."
            )
        return False
    return True
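
# Hedged usage sketch (`model` is a hypothetical compiled-model placeholder):
#
#     jit = can_jit_compile(warn=True)
#     model.compile(optimizer="adam", loss="mse", jit_compile=jit)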


_metric_local_vars_scope = threading.local()


def get_metric_local_vars_scope():
    try:
        return _metric_local_vars_scope.current
    except AttributeError:
        return None


def in_local_vars_context():
    ctx = get_metric_local_vars_scope()
    return ctx is not None


@contextlib.contextmanager
def with_metric_local_vars_scope():
    previous_scope = getattr(_metric_local_vars_scope, "current", None)
    _metric_local_vars_scope.current = MetricLocalVarsScope()
    try:
        yield
    finally:
        # Restore the previous scope even if the body raises an exception.
        _metric_local_vars_scope.current = previous_scope
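
# Behavior sketch (illustrative; assumes no enclosing scope is active):
#
#     in_local_vars_context()            # -> False
#     with with_metric_local_vars_scope():
#         in_local_vars_context()        # -> True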


class MetricLocalVarsScope:
752 """Turn on local variable creation for Metrics. 

753 

754 No functionality is needed here, it just exists to modulate Metric's 

755 variable creation.""" 

756