Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py: 25%

219 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""TensorFlow-related utilities.""" 

16 

17import collections 

18import copy 

19import numpy as np 

20 

21from tensorflow.python.data.experimental.ops import cardinality 

22from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import composite_tensor 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import sparse_tensor 

27from tensorflow.python.framework import tensor_shape 

28from tensorflow.python.framework import tensor_spec 

29from tensorflow.python.framework import tensor_util 

30from tensorflow.python.framework import type_spec 

31from tensorflow.python.keras import backend as K 

32from tensorflow.python.keras.engine import keras_tensor 

33from tensorflow.python.keras.utils import object_identity 

34from tensorflow.python.keras.utils import tf_contextlib 

35from tensorflow.python.ops import math_ops 

36from tensorflow.python.ops import variables 

37from tensorflow.python.ops.ragged import ragged_tensor 

38from tensorflow.python.ops.ragged import ragged_tensor_value 

39from tensorflow.python.util import nest 

40from tensorflow.python.util.tf_export import keras_export 

41 

42 

def is_tensor_or_tensor_list(v):
  """Returns whether `v` is a tensor or a structure that starts with one.

  Only the first element of the flattened structure is inspected.

  Args:
    v: A tensor or an arbitrary nested structure.

  Returns:
    `True` if `nest.flatten(v)` is non-empty and its first element is an
    `ops.Tensor`, `False` otherwise.
  """
  flat = nest.flatten(v)
  if not flat:
    return False
  return isinstance(flat[0], ops.Tensor)

49 

50 

def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Walks the graph breadth-first starting at `inputs`. It follows tensor
  consumers, operation outputs and control outputs, and the `op` of each
  variable. If `targets` is given, the walk stops as soon as every target has
  been found. A target that is itself one of `inputs` counts as found
  immediately.

  Only valid in Symbolic mode, not Eager mode.

  Args:
    inputs: List of tensors. Composite tensors are expanded into their
      component tensors.
    targets: Optional list of tensors.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs
    themselves). When `targets` is given and all of them are found, the walk
    stops early. In that case the set contains every target but not
    necessarily everything reachable.

  Raises:
    TypeError: If the walk reaches an object that is not an Operation,
      Variable, Tensor, or user-registered symbolic tensor type.
  """
  inputs = nest.flatten(inputs, expand_composites=True)
  reachable = object_identity.ObjectIdentitySet(inputs)
  if targets:
    remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
    # Inputs are in `reachable` from the start. Discard them here; otherwise a
    # target that is also an input is never removed and the early exit below
    # can never trigger.
    for x in inputs:
      remaining_targets.discard(x)
    if not remaining_targets:
      return reachable

  # Breadth-first: pop from the right, enqueue on the left.
  queue = collections.deque(inputs)
  while queue:
    x = queue.pop()
    if isinstance(x, tuple(_user_convertible_tensor_types)):
      # Can't find consumers of user-specific types.
      continue

    if isinstance(x, ops.Operation):
      outputs = list(x.outputs)
      outputs.extend(x._control_outputs)  # pylint: disable=protected-access
    elif isinstance(x, variables.Variable):
      try:
        outputs = [x.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        outputs = []
    elif tensor_util.is_tf_type(x):
      outputs = x.consumers()
    else:
      raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

    for y in outputs:
      if y not in reachable:
        reachable.add(y)
        if targets:
          remaining_targets.discard(y)
        queue.appendleft(y)

    if targets and not remaining_targets:
      return reachable

  return reachable

102 

103 

104# This function needs access to private functions of `nest`. 

105# pylint: disable=protected-access 

def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Recursively applies `map_fn` to the atomic leaves of `nested`.

  Unlike `nest.map_structure`, the caller decides what counts as a leaf. Any
  element for which `is_atomic_fn` returns a truthy value is passed to
  `map_fn` as a whole and is never recursed into.

  Args:
    is_atomic_fn: Predicate that returns whether an element of `nested` is
      atomic.
    map_fn: Function applied to each atomic element.
    nested: A nested structure.

  Returns:
    A structure of the same type as `nested`, rebuilt with
    `nest._sequence_like`, whose atomic elements have been replaced by the
    results of `map_fn`.

  Raises:
    ValueError: If an element is neither atomic nor a nested structure.
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  if not nest.is_nested(nested):
    raise ValueError(
        'Received non-atomic and non-sequence element: {}'.format(nested))

  # Children are gathered in sorted-key order for mappings and in field order
  # for attrs classes. This presumably matches the order in which
  # `nest._sequence_like` reassembles them.
  if nest.is_mapping(nested):
    children = [nested[key] for key in sorted(nested.keys())]
  elif nest.is_attrs(nested):
    children = _astuple(nested)
  else:
    children = nested

  converted = []
  for child in children:
    converted.append(map_structure_with_atomic(is_atomic_fn, map_fn, child))
  return nest._sequence_like(nested, converted)

139 

140 

def get_shapes(tensors):
  """Gets shapes from tensors.

  Args:
    tensors: A nested structure whose leaves expose a `.shape` attribute.

  Returns:
    The same structure with every leaf replaced by its `.shape`.
  """

  def _shape_of(tensor):
    return tensor.shape

  return nest.map_structure(_shape_of, tensors)

144 

145 

146# pylint: enable=protected-access 

147 

148 

def convert_shapes(input_shape, to_tuples=True):
  """Converts nested shape representations to desired format.

  Performs:

    TensorShapes -> tuples if `to_tuples=True`.
    tuples of int or None -> TensorShapes if `to_tuples=False`.

  Valid objects to be converted are:
    - TensorShapes
    - tuples or lists whose elements are ints (including NumPy integer
      scalars), `Dimension`s, or None.
    - ints (including NumPy integer scalars)
    - None

  Args:
    input_shape: A nested structure of objects to be converted to TensorShapes.
    to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
      all tuples representing shapes to TensorShapes.

  Returns:
    Nested structure of shapes in desired format.

  Raises:
    ValueError: when the input tensor shape can't be converted to tuples, eg
      unknown tensor shape, or when an element of `input_shape` is neither a
      shape nor a nested structure.
  """

  def _is_shape_component(value):
    # NumPy integer scalars (e.g. `np.int64`) do not subclass Python `int`.
    # Without this they made a whole shape tuple non-atomic, and the
    # recursion then rejected them as non-sequence leaves.
    return value is None or isinstance(
        value, (int, np.integer, tensor_shape.Dimension))

  def _is_atomic_shape(input_shape):
    # Ex: TensorShape or (None, 10, 32) or 5 or `None`
    if _is_shape_component(input_shape):
      return True
    if isinstance(input_shape, tensor_shape.TensorShape):
      return True
    if (isinstance(input_shape, (tuple, list)) and
        all(_is_shape_component(ele) for ele in input_shape)):
      return True
    return False

  def _convert_shape(input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if to_tuples:
      # `as_list` raises ValueError for shapes of unknown rank.
      input_shape = tuple(input_shape.as_list())
    return input_shape

  return map_structure_with_atomic(_is_atomic_shape, _convert_shape,
                                   input_shape)

198 

199 

class ListWrapper(object):
  """Boxes a list so that `nest` treats it as a single element.

  The wrapped list is stored by reference and returned unchanged by
  `as_list`.
  """

  def __init__(self, list_to_wrap):
    # Kept by reference; no copy is made.
    self._list = list_to_wrap

  def as_list(self):
    """Returns the wrapped list (the very object given to the constructor)."""
    wrapped = self._list
    return wrapped

208 

209 

def convert_inner_node_data(nested, wrap=False):
  """Either wraps or unwraps innermost node data lists in `ListWrapper` objects.

  Args:
    nested: A nested data structure.
    wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
      unwraps `ListWrapper` objects into lists.

  Returns:
    Structure of same type as nested, with lists wrapped/unwrapped.
  """

  def _is_serialized_node_data(candidate):
    # Node data has the form `[layer_name, node_id, tensor_id]` or
    # `[layer_name, node_id, tensor_id, kwargs]`.
    return (isinstance(candidate, list) and len(candidate) in (3, 4) and
            isinstance(candidate[0], str))

  def _is_atomic_nested(candidate):
    """Returns `True` if `candidate` should not be recursed into."""
    return (isinstance(candidate, ListWrapper) or
            _is_serialized_node_data(candidate) or
            not nest.is_nested(candidate))

  def _wrap(candidate):
    """Boxes serialized node data; everything else passes through."""
    if isinstance(candidate, ListWrapper):
      return candidate
    if _is_serialized_node_data(candidate):
      return ListWrapper(candidate)
    return candidate

  def _unwrap(candidate):
    """Unboxes `ListWrapper`s; everything else passes through."""
    if isinstance(candidate, ListWrapper):
      return candidate.as_list()
    return candidate

  convert_fn = _wrap if wrap else _unwrap
  return map_structure_with_atomic(_is_atomic_nested, convert_fn, nested)

253 

254 

def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`. Before calling `fn`, a non-None
  `input_shape` is converted to tuples. A non-None value returned by `fn` is
  converted back to TensorShapes.

  Args:
    fn: function to wrap. Called as `fn(instance, input_shape)`.

  Returns:
    Wrapped function. It carries `fn`'s name, docstring and `__wrapped__`
    attribute (via `functools.wraps`), so introspection and error messages
    still point at `fn`.
  """
  import functools  # pylint: disable=g-import-not-at-top

  @functools.wraps(fn)
  def wrapper(instance, input_shape):
    # Pass shapes as tuples to `fn`
    # This preserves compatibility with external Keras.
    if input_shape is not None:
      input_shape = convert_shapes(input_shape, to_tuples=True)
    output_shape = fn(instance, input_shape)
    # Return shapes from `fn` as TensorShapes.
    if output_shape is not None:
      output_shape = convert_shapes(output_shape, to_tuples=False)
    return output_shape

  return wrapper

279 

280 

def are_all_symbolic_tensors(tensors):
  """Returns whether every element of `tensors` is a symbolic tensor.

  Evaluation stops at the first non-symbolic element. An empty iterable
  yields `True`.
  """
  return all(is_symbolic_tensor(t) for t in tensors)

283 

284 

# Registry of user classes added via `register_symbolic_tensor_type`.
# `is_symbolic_tensor` converts instances of these types with
# `ops.convert_to_tensor_or_composite` before testing them.
# `get_reachable_from_inputs` skips them because their consumers cannot be
# discovered. The set is only mutated in place, never rebound.
_user_convertible_tensor_types = set()

286 

287 

def is_extension_type(tensor):
  """Returns whether `tensor` is an instance of an ExtensionType.

  See github.com/tensorflow/community/pull/269. For now this is an
  `isinstance` check against `CompositeTensor`. It is meant to be replaced by
  a dedicated ExtensionType protocol check once ExtensionType is public.

  Args:
    tensor: An object to test.

  Returns:
    True if `tensor` is a `CompositeTensor`, False otherwise.
  """
  composite_cls = composite_tensor.CompositeTensor
  return isinstance(tensor, composite_cls)

303 

304 

def is_symbolic_tensor(tensor):
  """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.

  A Variable can be seen as either: it is considered symbolic
  when we are in a graph scope, and eager when we are in an eager scope.

  Args:
    tensor: A tensor instance to test.

  Returns:
    True for symbolic tensors, False for eager tensors and for non-tensor
    objects.
  """
  if isinstance(tensor, ops.Tensor):
    # Symbolic tensors expose `graph`. Presumably eager tensors fail this
    # attribute check.
    return hasattr(tensor, 'graph')
  if is_extension_type(tensor):
    # A composite is symbolic if any of its component tensors is.
    component_tensors = nest.flatten(tensor, expand_composites=True)
    return any(hasattr(t, 'graph') for t in component_tensors)
  if isinstance(tensor, variables.Variable):
    # Variables that are output of a Keras Layer in Functional API mode
    # should be considered symbolic.
    # TODO(omalleyt): We need a better way to check this in order to
    # enable `run_eagerly=True` for Models containing Layers that
    # return Variables as outputs.
    # `_keras_history` is an arbitrary object when present. Coerce so a real
    # bool is returned, as documented.
    return bool(getattr(tensor, '_keras_history', False) or
                not context.executing_eagerly())
  if isinstance(tensor, tuple(_user_convertible_tensor_types)):
    # User-registered types are converted first, then tested recursively.
    tensor = ops.convert_to_tensor_or_composite(tensor)
    return is_symbolic_tensor(tensor)
  return False

335 

336 

@keras_export('keras.__internal__.utils.register_symbolic_tensor_type', v1=[])
def register_symbolic_tensor_type(cls):
  """Allows users to specify types regarded as symbolic `Tensor`s.

  Used in conjunction with `tf.register_tensor_conversion_function`, calling
  `tf.keras.__internal__.utils.register_symbolic_tensor_type(cls)`
  allows non-`Tensor` objects to be plumbed through Keras layers.

  Example:

  ```python
  # One-time setup.
  class Foo(object):
    def __init__(self, input_):
      self._input = input_
    def value(self):
      return tf.constant(42.)

  tf.register_tensor_conversion_function(
      Foo, lambda x, *args, **kwargs: x.value())

  tf.keras.__internal__.utils.register_symbolic_tensor_type(Foo)

  # User-land.
  layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
  ```

  Registering the same class twice is a no-op.

  Args:
    cls: A `class` type which shall be regarded as a symbolic `Tensor`.
  """
  # The module-level registry is mutated in place, never rebound, so no
  # `global` declaration is required.
  if cls in _user_convertible_tensor_types:
    return
  keras_tensor.register_keras_tensor_specialization(
      cls, keras_tensor.UserRegisteredTypeKerasTensor)
  _user_convertible_tensor_types.add(cls)

372 

373 

def type_spec_from_value(value):
  """Grab type_spec without converting array-likes to tensors.

  Args:
    value: A composite tensor, an array-like exposing `shape` and `dtype`, or
      any value accepted by `type_spec.type_spec_from_value`.

  Returns:
    The value's `_type_spec` for composite tensors, a `TensorSpec` built from
    `shape`/`dtype` for array-likes, or the generic `TypeSpec` otherwise.
  """
  if is_extension_type(value):
    return value._type_spec  # pylint: disable=protected-access
  looks_array_like = hasattr(value, 'shape') and hasattr(value, 'dtype')
  if not looks_array_like:
    return type_spec.type_spec_from_value(value)
  # Build the spec directly so array-like data is never converted to a Tensor.
  return tensor_spec.TensorSpec(value.shape, value.dtype)

384 

385 

def is_ragged(tensor):
  """Returns true if `tensor` is a ragged tensor or ragged tensor value."""
  ragged_types = (ragged_tensor.RaggedTensor,
                  ragged_tensor_value.RaggedTensorValue)
  return isinstance(tensor, ragged_types)

391 

392 

def is_sparse(tensor):
  """Returns true if `tensor` is a sparse tensor or sparse tensor value."""
  sparse_types = (sparse_tensor.SparseTensor,
                  sparse_tensor.SparseTensorValue)
  return isinstance(tensor, sparse_types)

398 

399 

def is_tensor_or_variable(x):
  """Returns whether `x` is a TF tensor-like value or a `Variable`."""
  if tensor_util.is_tf_type(x):
    return True
  return isinstance(x, variables.Variable)

402 

403 

def assert_no_legacy_layers(layers):
  """Prevent tf.layers.Layers from being used with Keras.

  Certain legacy layers inherit from their keras analogs; however they are
  not supported with keras and can lead to subtle and hard to diagnose bugs.

  Args:
    layers: A list of layers to check.

  Raises:
    TypeError: If any element of `layers` has a truthy `_is_legacy_layer`
      attribute. The message lists every offending layer.
  """
  # An isinstance check for tf.layers.Layer would introduce a circular
  # dependency, so legacy layers are detected through a marker attribute.
  offenders = []
  for layer in layers:
    if getattr(layer, '_is_legacy_layer', None):
      offenders.append(layer)
  if not offenders:
    return
  layer_str = '\n'.join(' ' + str(layer) for layer in offenders)
  raise TypeError(
      'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
      'framework (for instance using the Network, Model, or Sequential '
      'classes), please use the tf.keras.layers implementation instead. '
      '(Or, if writing custom layers, subclass from tf.keras.layers rather '
      'than tf.layers)'.format(layer_str))

427 

428 

@tf_contextlib.contextmanager
def maybe_init_scope(layer):
  """Open an `init_scope` if in V2 mode and using the keras graph.

  The scope is entered only when `ops.executing_eagerly_outside_functions()`
  is true and `layer` is Keras-style (its `_keras_style` attribute is missing
  or truthy). Otherwise the body runs without any extra scope.

  Args:
    layer: The Layer/Model that is currently active.

  Yields:
    None
  """
  # Don't open an init_scope in V1 mode or when using legacy tf.layers.
  wants_init_scope = (ops.executing_eagerly_outside_functions() and
                      getattr(layer, '_keras_style', True))
  if not wants_init_scope:
    yield
    return
  with ops.init_scope():
    yield

446 

447 

@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
  """Returns graph context manager if any of the inputs is a symbolic tensor.

  Positional arguments are checked first, then keyword-argument values. If
  any of them is symbolic, the Keras graph (`K.get_graph()`) is made the
  default graph for the body. Otherwise no context is entered.
  """
  candidates = args + tuple(kwargs.values())
  if not any(is_symbolic_tensor(v) for v in candidates):
    yield
    return
  with K.get_graph().as_default():
    yield

456 

457 

def dataset_is_infinite(dataset):
  """True if the passed dataset is infinite.

  Args:
    dataset: The dataset whose cardinality is checked.

  Returns:
    When executing eagerly outside functions, the result of
    `math_ops.equal(cardinality, INFINITE)`. Otherwise, the cardinality is
    evaluated with the Keras session and compared to `INFINITE` in Python.
  """
  if not ops.executing_eagerly_outside_functions():
    # Graph (V1) mode: the cardinality op has to be run in the Keras session.
    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
    return dataset_size == cardinality.INFINITE
  return math_ops.equal(
      cardinality.cardinality(dataset), cardinality.INFINITE)

466 

467 

def get_tensor_spec(t, dynamic_batch=False, name=None):
  """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.

  Args:
    t: A `TypeSpec`, composite tensor, Keras symbolic tensor, or any object
      with `shape` and `dtype` attributes.
    dynamic_batch: If `True`, the returned spec is a deep copy whose leading
      dimension is replaced with `None` (when the rank is known and positive).
    name: Name given to a `TensorSpec` built from `shape`/`dtype`. It is not
      used by the other branches.

  Returns:
    A spec describing `t`, or `None` if `t` is not tensor-like.
  """
  # pylint: disable=protected-access

  def _relax_leading_dim(spec):
    # Deep copy so the caller's spec is never mutated.
    relaxed = copy.deepcopy(spec)
    # RaggedTensorSpec only has a private _shape.
    shape = relaxed._shape
    if shape.rank is None or shape.rank <= 0:
      return relaxed
    dims = shape.as_list()
    dims[0] = None
    relaxed._shape = tensor_shape.TensorShape(dims)
    return relaxed

  if isinstance(t, type_spec.TypeSpec):
    base_spec = t
  elif is_extension_type(t):
    # TODO(b/148821952): Should these specs have a name attr?
    base_spec = t._type_spec
  elif (hasattr(t, '_keras_history') and
        hasattr(t._keras_history[0], '_type_spec')):
    # NOTE(review): this branch returns immediately, so `dynamic_batch` is
    # never applied to it. Confirm that this is intended.
    return t._keras_history[0]._type_spec
  elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
    base_spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
  else:
    return None  # Allow non-Tensors to pass through.

  if dynamic_batch:
    return _relax_leading_dim(base_spec)
  return base_spec
  # pylint: enable=protected-access

496 

497 

def sync_to_numpy_or_python_type(tensors):
  """Syncs and converts a structure of `Tensor`s to `NumPy` arrays or Python scalar types.

  For each tensor, it calls `tensor.numpy()`. If the result is a scalar value,
  it converts it to a Python type, such as a float or int, by calling
  `result.item()`.

  Numpy scalars are converted, as Python types are often more convenient to deal
  with. This is especially useful for bfloat16 Numpy scalars, which don't
  support as many operations as other Numpy values.

  Async strategies (such as `TPUStrategy` and `ParameterServerStrategy`) are
  forced to sync during this process.

  Args:
    tensors: A structure of tensors, or a single `RemoteValue`.

  Returns:
    `tensors`, but scalar tensors are converted to Python types and non-scalar
    tensors are converted to Numpy arrays. A `RemoteValue` is returned as the
    result of its `fetch()` call.
  """
  if isinstance(tensors, coordinator_lib.RemoteValue):
    return tensors.fetch()

  def _convert_leaf(leaf):
    # Anything that is not an `ops.Tensor` is returned unchanged. This
    # includes ragged and sparse tensors, which are not turned into NumPy.
    if not isinstance(leaf, ops.Tensor):
      return leaf
    array = leaf.numpy()
    if np.ndim(array) == 0:
      return array.item()
    return array

  return nest.map_structure(_convert_leaf, tensors)

530 

531 

532def _astuple(attrs): 

533 """Converts the given attrs to tuple non-recursively.""" 

534 cls = type(attrs) 

535 fields = getattr(cls, '__attrs_attrs__', None) 

536 if fields is None: 

537 raise ValueError('%r is not an attrs-decorated class.' % cls) 

538 values = [] 

539 for field in fields: 

540 values.append(getattr(attrs, field.name)) 

541 return tuple(values)