Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/backend.py: 29%

2449 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16 

17"""Keras backend API.""" 

18 

19import collections 

20import itertools 

21import json 

22import os 

23import random 

24import sys 

25import threading 

26import warnings 

27import weakref 

28 

29import numpy as np 

30import tensorflow.compat.v2 as tf 

31 

32from keras.src import backend_config 

33from keras.src.distribute import distribute_coordinator_utils as dc 

34from keras.src.dtensor import dtensor_api as dtensor 

35from keras.src.engine import keras_tensor 

36from keras.src.utils import control_flow_util 

37from keras.src.utils import object_identity 

38from keras.src.utils import tf_contextlib 

39from keras.src.utils import tf_inspect 

40from keras.src.utils import tf_utils 

41 

42# isort: off 

43from tensorflow.core.protobuf import config_pb2 

44from tensorflow.python.eager import context 

45from tensorflow.python.eager.context import get_config 

46from tensorflow.python.platform import tf_logging as logging 

47from tensorflow.python.util.tf_export import keras_export 

48from tensorflow.tools.docs import doc_controls 

49 

50py_all = all 

51py_sum = sum 

52py_any = any 

53 

54# INTERNAL UTILS 

55 

56# The internal graph maintained by Keras and used by the symbolic Keras APIs 

57# while executing eagerly (such as the functional API for model-building). 

58# This is thread-local to allow building separate models in different threads 

59# concurrently, but comes at the cost of not being able to build one model 

60# across threads. 

61_GRAPH = threading.local() 

62 

63# A graph which is used for constructing functions in eager mode. 

64_CURRENT_SCRATCH_GRAPH = threading.local() 

65 

66 

67# This is a thread local object that will hold the default internal TF session 

68# used by Keras. It can be set manually via `set_session(sess)`. 

69class SessionLocal(threading.local): 

70 def __init__(self): 

71 super().__init__() 

72 self.session = None 

73 

74 

75_SESSION = SessionLocal() 

76 

77 

78# A global dictionary mapping graph objects to an index of counters used 

79# for various layer/optimizer names in each graph. 

80# Allows to give unique autogenerated names to layers, in a graph-specific way. 

81PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary() 

82 

83 

84# A global set tracking what object names have been seen so far. 

85# Optionally used as an avoid-list when generating names 

86OBSERVED_NAMES = set() 

87 

88 

89# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES. 

90# We keep a separate reference to it to make sure it does not get removed from 

91# _GRAPH_LEARNING_PHASES. 

92# _DummyEagerGraph inherits from threading.local to make its `key` attribute 

93# thread local. This is needed to make set_learning_phase affect only the 

94# current thread during eager execution (see b/123096885 for more details). 

95class _DummyEagerGraph(threading.local): 

96 """_DummyEagerGraph provides a thread local `key` attribute. 

97 

98 We can't use threading.local directly, i.e. without subclassing, because 

99 gevent monkey patches threading.local and its version does not support 

100 weak references. 

101 """ 

102 

103 class _WeakReferencableClass: 

104 """This dummy class is needed for two reasons. 

105 

106 - We need something that supports weak references. Basic types like 

107 string and ints don't. 

108 - We need something whose hash and equality are based on object identity 

109 to make sure they are treated as different keys to 

110 _GRAPH_LEARNING_PHASES. 

111 

112 An empty Python class satisfies both of these requirements. 

113 """ 

114 

115 pass 

116 

117 def __init__(self): 

118 # Constructors for classes subclassing threading.local run once 

119 # per thread accessing something in the class. Thus, each thread will 

120 # get a different key. 

121 super().__init__() 

122 self.key = _DummyEagerGraph._WeakReferencableClass() 

123 self.learning_phase_is_set = False 

124 

125 

126_DUMMY_EAGER_GRAPH = _DummyEagerGraph() 

127 

128# This boolean flag can be set to True to leave variable initialization 

129# up to the user. 

130# Change its value via `manual_variable_initialization(value)`. 

131_MANUAL_VAR_INIT = False 

132 

133# This list holds the available devices. 

134# It is populated when `_get_available_gpus()` is called for the first time. 

135# We assume our devices don't change henceforth. 

136_LOCAL_DEVICES = None 

137 

138# The below functions are kept accessible from backend for compatibility. 

139epsilon = backend_config.epsilon 

140floatx = backend_config.floatx 

141image_data_format = backend_config.image_data_format 

142set_epsilon = backend_config.set_epsilon 

143set_floatx = backend_config.set_floatx 

144set_image_data_format = backend_config.set_image_data_format 

145 

146 

147@keras_export("keras.backend.backend") 

148@doc_controls.do_not_generate_docs 

149def backend(): 

150 """Publicly accessible method for determining the current backend. 

151 

152 Only exists for API compatibility with multi-backend Keras. 

153 

154 Returns: 

155 The string "tensorflow". 

156 """ 

157 return "tensorflow" 

158 

159 

160@keras_export("keras.backend.cast_to_floatx") 

161@tf.__internal__.dispatch.add_dispatch_support 

162@doc_controls.do_not_generate_docs 

163def cast_to_floatx(x): 

164 """Cast a Numpy array to the default Keras float type. 

165 

166 Args: 

167 x: Numpy array or TensorFlow tensor. 

168 

169 Returns: 

170 The same array (Numpy array if `x` was a Numpy array, or TensorFlow 

171 tensor if `x` was a tensor), cast to its new type. 

172 

173 Example: 

174 

175 >>> tf.keras.backend.floatx() 

176 'float32' 

177 >>> arr = np.array([1.0, 2.0], dtype='float64') 

178 >>> arr.dtype 

179 dtype('float64') 

180 >>> new_arr = cast_to_floatx(arr) 

181 >>> new_arr 

182 array([1., 2.], dtype=float32) 

183 >>> new_arr.dtype 

184 dtype('float32') 

185 

186 """ 

187 if isinstance(x, (tf.Tensor, tf.Variable, tf.SparseTensor)): 

188 return tf.cast(x, dtype=floatx()) 

189 return np.asarray(x, dtype=floatx()) 

190 

191 

192@keras_export("keras.backend.get_uid") 

193def get_uid(prefix=""): 

194 """Associates a string prefix with an integer counter in a TensorFlow graph. 

195 

196 Args: 

197 prefix: String prefix to index. 

198 

199 Returns: 

200 Unique integer ID. 

201 

202 Example: 

203 

204 >>> get_uid('dense') 

205 1 

206 >>> get_uid('dense') 

207 2 

208 

209 """ 

210 graph = get_graph() 

211 if graph not in PER_GRAPH_OBJECT_NAME_UIDS: 

212 PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int) 

213 layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph] 

214 layer_name_uids[prefix] += 1 

215 return layer_name_uids[prefix] 

216 

217 

218@keras_export("keras.backend.reset_uids") 

219def reset_uids(): 

220 """Resets graph identifiers.""" 

221 

222 PER_GRAPH_OBJECT_NAME_UIDS.clear() 

223 OBSERVED_NAMES.clear() 

224 

225 

226@keras_export("keras.backend.clear_session") 

227def clear_session(): 

228 """Resets all state generated by Keras. 

229 

230 Keras manages a global state, which it uses to implement the Functional 

231 model-building API and to uniquify autogenerated layer names. 

232 

233 If you are creating many models in a loop, this global state will consume 

234 an increasing amount of memory over time, and you may want to clear it. 

235 Calling `clear_session()` releases the global state: this helps avoid 

236 clutter from old models and layers, especially when memory is limited. 

237 

238 Example 1: calling `clear_session()` when creating models in a loop 

239 

240 ```python 

241 for _ in range(100): 

242 # Without `clear_session()`, each iteration of this loop will 

243 # slightly increase the size of the global state managed by Keras 

244 model = tf.keras.Sequential([ 

245 tf.keras.layers.Dense(10) for _ in range(10)]) 

246 

247 for _ in range(100): 

248 # With `clear_session()` called at the beginning, 

249 # Keras starts with a blank state at each iteration 

250 # and memory consumption is constant over time. 

251 tf.keras.backend.clear_session() 

252 model = tf.keras.Sequential([ 

253 tf.keras.layers.Dense(10) for _ in range(10)]) 

254 ``` 

255 

256 Example 2: resetting the layer name generation counter 

257 

258 >>> import tensorflow as tf 

259 >>> layers = [tf.keras.layers.Dense(10) for _ in range(10)] 

260 >>> new_layer = tf.keras.layers.Dense(10) 

261 >>> print(new_layer.name) 

262 dense_10 

263 >>> tf.keras.backend.set_learning_phase(1) 

264 >>> print(tf.keras.backend.learning_phase()) 

265 1 

266 >>> tf.keras.backend.clear_session() 

267 >>> new_layer = tf.keras.layers.Dense(10) 

268 >>> print(new_layer.name) 

269 dense 

270 """ 

271 global _SESSION 

272 global _GRAPH_LEARNING_PHASES 

273 global _GRAPH_VARIABLES 

274 global _GRAPH_TF_OPTIMIZERS 

275 global _GRAPH 

276 _GRAPH.graph = None 

277 tf.compat.v1.reset_default_graph() 

278 reset_uids() 

279 if _SESSION.session is not None: 

280 _SESSION.session.close() 

281 _SESSION.session = None 

282 graph = get_graph() 

283 with graph.as_default(): 

284 _DUMMY_EAGER_GRAPH.learning_phase_is_set = False 

285 

286 _GRAPH_LEARNING_PHASES = {} 

287 # Create the learning phase placeholder in graph using the default 

288 # factory 

289 phase = _default_learning_phase() 

290 _internal_set_learning_phase(graph, phase) 

291 

292 _GRAPH_VARIABLES.pop(graph, None) 

293 _GRAPH_TF_OPTIMIZERS.pop(graph, None) 

294 if tf.executing_eagerly(): 

295 # Clear pending nodes in eager executors, kernel caches and 

296 # step_containers. 

297 context.context().clear_kernel_cache() 

298 

299 

300# Inject the clear_session function to keras_deps to remove the dependency 

301# from TFLite to Keras. 

302tf.__internal__.register_clear_session_function(clear_session) 

303 

304 

305@keras_export("keras.backend.manual_variable_initialization") 

306@doc_controls.do_not_generate_docs 

307def manual_variable_initialization(value): 

308 """Sets the manual variable initialization flag. 

309 

310 This boolean flag determines whether 

311 variables should be initialized 

312 as they are instantiated (default), or if 

313 the user should handle the initialization 

314 (e.g. via `tf.compat.v1.initialize_all_variables()`). 

315 

316 Args: 

317 value: Python boolean. 

318 """ 

319 global _MANUAL_VAR_INIT 

320 _MANUAL_VAR_INIT = value 

321 

322 

323@keras_export("keras.backend.learning_phase") 

324@doc_controls.do_not_generate_docs 

325def learning_phase(): 

326 """Returns the learning phase flag. 

327 

328 The learning phase flag is a bool tensor (0 = test, 1 = train) 

329 to be passed as input to any Keras function 

330 that uses a different behavior at train time and test time. 

331 

332 Returns: 

333 Learning phase (scalar integer tensor or Python integer). 

334 """ 

335 graph = tf.compat.v1.get_default_graph() 

336 if graph is getattr(_GRAPH, "graph", None): 

337 # Don't enter an init_scope for the learning phase if eager execution 

338 # is enabled but we're inside the Keras workspace graph. 

339 learning_phase = symbolic_learning_phase() 

340 else: 

341 with tf.init_scope(): 

342 # We always check & set the learning phase inside the init_scope, 

343 # otherwise the wrong default_graph will be used to look up the 

344 # learning phase inside of functions & defuns. 

345 # 

346 # This is because functions & defuns (both in graph & in eager mode) 

347 # will always execute non-eagerly using a function-specific default 

348 # subgraph. 

349 if context.executing_eagerly(): 

350 if _DUMMY_EAGER_GRAPH.key not in _GRAPH_LEARNING_PHASES: 

351 return _default_learning_phase() 

352 else: 

353 return _internal_get_learning_phase(_DUMMY_EAGER_GRAPH.key) 

354 else: 

355 learning_phase = symbolic_learning_phase() 

356 _mark_func_graph_as_unsaveable(graph, learning_phase) 

357 return learning_phase 

358 

359 

360def global_learning_phase_is_set(): 

361 return _DUMMY_EAGER_GRAPH.learning_phase_is_set 

362 

363 

364def _mark_func_graph_as_unsaveable(graph, learning_phase): 

365 """Mark graph as unsaveable due to use of symbolic keras learning phase. 

366 

367 Functions that capture the symbolic learning phase cannot be exported to 

368 SavedModel. Mark the funcgraph as unsaveable, so that an error will be 

369 raised if it is exported. 

370 

371 Args: 

372 graph: Graph or FuncGraph object. 

373 learning_phase: Learning phase placeholder or int defined in the graph. 

374 """ 

375 if graph.building_function and is_placeholder(learning_phase): 

376 graph.mark_as_unsaveable( 

377 "The keras learning phase placeholder was used inside a function. " 

378 "Exporting placeholders is not supported when saving out a " 

379 "SavedModel. Please call `tf.keras.backend.set_learning_phase(0)` " 

380 "in the function to set the learning phase to a constant value." 

381 ) 

382 

383 

384def symbolic_learning_phase(): 

385 graph = get_graph() 

386 with graph.as_default(): 

387 if graph not in _GRAPH_LEARNING_PHASES: 

388 phase = _default_learning_phase() 

389 _internal_set_learning_phase(graph, phase) 

390 

391 return _internal_get_learning_phase(graph) 

392 

393 

394def _internal_set_learning_phase(graph, value): 

395 global _GRAPH_LEARNING_PHASES 

396 

397 if isinstance(value, tf.Tensor): 

398 # The 'value' here is a tf.Tensor with attribute 'graph'. 

399 # There is a circular reference between key 'graph' and attribute 

400 # 'graph'. So we need use a weakref.ref to refer to the 'value' tensor 

401 # here. Otherwise, it would lead to memory leak. 

402 value_ref = weakref.ref(value) 

403 _GRAPH_LEARNING_PHASES[graph] = value_ref 

404 else: 

405 _GRAPH_LEARNING_PHASES[graph] = value 

406 

407 

408def _internal_get_learning_phase(graph): 

409 phase = _GRAPH_LEARNING_PHASES.get(graph, None) 

410 if isinstance(phase, weakref.ref): 

411 return phase() 

412 else: 

413 return phase 

414 

415 

416def _default_learning_phase(): 

417 if context.executing_eagerly(): 

418 return 0 

419 else: 

420 with name_scope(""): 

421 return tf.compat.v1.placeholder_with_default( 

422 False, shape=(), name="keras_learning_phase" 

423 ) 

424 

425 

426@keras_export("keras.backend.set_learning_phase") 

427@doc_controls.do_not_generate_docs 

428def set_learning_phase(value): 

429 """Sets the learning phase to a fixed value. 

430 

431 The backend learning phase affects any code that calls 

432 `backend.learning_phase()` 

433 In particular, all Keras built-in layers use the learning phase as the 

434 default for the `training` arg to `Layer.__call__`. 

435 

436 User-written layers and models can achieve the same behavior with code that 

437 looks like: 

438 

439 ```python 

440 def call(self, inputs, training=None): 

441 if training is None: 

442 training = backend.learning_phase() 

443 ``` 

444 

445 Args: 

446 value: Learning phase value, either 0 or 1 (integers). 

447 0 = test, 1 = train 

448 

449 Raises: 

450 ValueError: if `value` is neither `0` nor `1`. 

451 """ 

452 warnings.warn( 

453 "`tf.keras.backend.set_learning_phase` is deprecated and " 

454 "will be removed after 2020-10-11. To update it, simply " 

455 "pass a True/False value to the `training` argument of the " 

456 "`__call__` method of your layer or model." 

457 ) 

458 deprecated_internal_set_learning_phase(value) 

459 

460 

461def deprecated_internal_set_learning_phase(value): 

462 """A deprecated internal implementation of set_learning_phase. 

463 

464 This method is an internal-only version of `set_learning_phase` that 

465 does not raise a deprecation error. It is required because 

466 saved_model needs to keep working with user code that uses the deprecated 

467 learning phase methods until those APIs are fully removed from the public 

468 API. 

469 

470 Specifically SavedModel saving needs to make sure the learning phase is 0 

471 during tracing even if users overwrote it to a different value. 

472 

473 But, we don't want to raise deprecation warnings for users when savedmodel 

474 sets learning phase just for compatibility with code that relied on 

475 explicitly setting the learning phase for other values. 

476 

477 Args: 

478 value: Learning phase value, either 0 or 1 (integers). 

479 0 = test, 1 = train 

480 

481 Raises: 

482 ValueError: if `value` is neither `0` nor `1`. 

483 """ 

484 if value not in {0, 1}: 

485 raise ValueError("Expected learning phase to be 0 or 1.") 

486 with tf.init_scope(): 

487 if tf.executing_eagerly(): 

488 # In an eager context, the learning phase values applies to both the 

489 # eager context and the internal Keras graph. 

490 _DUMMY_EAGER_GRAPH.learning_phase_is_set = True 

491 _internal_set_learning_phase(_DUMMY_EAGER_GRAPH.key, value) 

492 

493 _internal_set_learning_phase(get_graph(), value) 

494 

495 

496@keras_export("keras.backend.learning_phase_scope") 

497@tf_contextlib.contextmanager 

498@doc_controls.do_not_generate_docs 

499def learning_phase_scope(value): 

500 """Provides a scope within which the learning phase is equal to `value`. 

501 

502 The learning phase gets restored to its original value upon exiting the 

503 scope. 

504 

505 Args: 

506 value: Learning phase value, either 0 or 1 (integers). 

507 0 = test, 1 = train 

508 

509 Yields: 

510 None. 

511 

512 Raises: 

513 ValueError: if `value` is neither `0` nor `1`. 

514 """ 

515 warnings.warn( 

516 "`tf.keras.backend.learning_phase_scope` is deprecated and " 

517 "will be removed after 2020-10-11. To update it, simply " 

518 "pass a True/False value to the `training` argument of the " 

519 "`__call__` method of your layer or model.", 

520 stacklevel=2, 

521 ) 

522 with deprecated_internal_learning_phase_scope(value): 

523 try: 

524 yield 

525 finally: 

526 pass 

527 

528 

529@tf_contextlib.contextmanager 

530def deprecated_internal_learning_phase_scope(value): 

531 """An internal-only version of `learning_phase_scope`. 

532 

533 Unlike the public method, this method does not raise a deprecation warning. 

534 This is needed because saved model saving needs to set learning phase 

535 to maintain compatibility 

536 with code that sets/gets the learning phase, but saved model 

537 saving itself shouldn't raise a deprecation warning. 

538 

539 We can get rid of this method and its usages when the public API is 

540 removed. 

541 

542 Args: 

543 value: Learning phase value, either 0 or 1 (integers). 

544 0 = test, 1 = train 

545 

546 Yields: 

547 None. 

548 

549 Raises: 

550 ValueError: if `value` is neither `0` nor `1`. 

551 """ 

552 global _GRAPH_LEARNING_PHASES 

553 if value not in {0, 1}: 

554 raise ValueError("Expected learning phase to be 0 or 1.") 

555 

556 with tf.init_scope(): 

557 if tf.executing_eagerly(): 

558 previous_eager_value = _internal_get_learning_phase( 

559 _DUMMY_EAGER_GRAPH.key 

560 ) 

561 previous_graph_value = _internal_get_learning_phase(get_graph()) 

562 

563 learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set 

564 try: 

565 deprecated_internal_set_learning_phase(value) 

566 yield 

567 finally: 

568 # Restore learning phase to initial value. 

569 if not learning_phase_previously_set: 

570 _DUMMY_EAGER_GRAPH.learning_phase_is_set = False 

571 with tf.init_scope(): 

572 if tf.executing_eagerly(): 

573 if previous_eager_value is not None: 

574 _internal_set_learning_phase( 

575 _DUMMY_EAGER_GRAPH.key, previous_eager_value 

576 ) 

577 elif _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES: 

578 del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] 

579 

580 graph = get_graph() 

581 if previous_graph_value is not None: 

582 _internal_set_learning_phase(graph, previous_graph_value) 

583 elif graph in _GRAPH_LEARNING_PHASES: 

584 del _GRAPH_LEARNING_PHASES[graph] 

585 

586 

587@tf_contextlib.contextmanager 

588def eager_learning_phase_scope(value): 

589 """Internal scope that sets the learning phase in eager / tf.function only. 

590 

591 Args: 

592 value: Learning phase value, either 0 or 1 (integers). 

593 0 = test, 1 = train 

594 

595 Yields: 

596 None. 

597 

598 Raises: 

599 ValueError: if `value` is neither `0` nor `1`. 

600 """ 

601 global _GRAPH_LEARNING_PHASES 

602 assert value in {0, 1} 

603 assert tf.compat.v1.executing_eagerly_outside_functions() 

604 global_learning_phase_was_set = global_learning_phase_is_set() 

605 if global_learning_phase_was_set: 

606 previous_value = learning_phase() 

607 try: 

608 _internal_set_learning_phase(_DUMMY_EAGER_GRAPH.key, value) 

609 yield 

610 finally: 

611 # Restore learning phase to initial value or unset. 

612 if global_learning_phase_was_set: 

613 _internal_set_learning_phase(_DUMMY_EAGER_GRAPH.key, previous_value) 

614 else: 

615 del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] 

616 

617 

618def _as_graph_element(obj): 

619 """Convert `obj` to a graph element if possible, otherwise return `None`. 

620 

621 Args: 

622 obj: Object to convert. 

623 

624 Returns: 

625 The result of `obj._as_graph_element()` if that method is available; 

626 otherwise `None`. 

627 """ 

628 conv_fn = getattr(obj, "_as_graph_element", None) 

629 if conv_fn and callable(conv_fn): 

630 return conv_fn() 

631 return None 

632 

633 

634def _assert_same_graph(original_item, item): 

635 """Fail if the 2 items are from different graphs. 

636 

637 Args: 

638 original_item: Original item to check against. 

639 item: Item to check. 

640 

641 Raises: 

642 ValueError: if graphs do not match. 

643 """ 

644 original_graph = getattr(original_item, "graph", None) 

645 graph = getattr(item, "graph", None) 

646 if original_graph and graph and original_graph is not graph: 

647 raise ValueError( 

648 "%s must be from the same graph as %s (graphs are %s and %s)." 

649 % (item, original_item, graph, original_graph) 

650 ) 

651 

652 

653def _current_graph(op_input_list, graph=None): 

654 """Returns the appropriate graph to use for the given inputs. 

655 

656 This library method provides a consistent algorithm for choosing the graph 

657 in which an Operation should be constructed: 

658 

659 1. If the default graph is being used to construct a function, we 

660 use the default graph. 

661 2. If the "graph" is specified explicitly, we validate that all of the 

662 inputs in "op_input_list" are compatible with that graph. 

663 3. Otherwise, we attempt to select a graph from the first Operation- 

664 or Tensor-valued input in "op_input_list", and validate that all other 

665 such inputs are in the same graph. 

666 4. If the graph was not specified and it could not be inferred from 

667 "op_input_list", we attempt to use the default graph. 

668 

669 Args: 

670 op_input_list: A list of inputs to an operation, which may include 

671 `Tensor`, `Operation`, and other objects that may be converted to a 

672 graph element. 

673 graph: (Optional) The explicit graph to use. 

674 

675 Raises: 

676 TypeError: If op_input_list is not a list or tuple, or if graph is not a 

677 Graph. 

678 ValueError: If a graph is explicitly passed and not all inputs are from 

679 it, or if the inputs are from multiple graphs, or we could not find a 

680 graph and there was no default graph. 

681 

682 Returns: 

683 The appropriate graph to use for the given inputs. 

684 

685 """ 

686 current_default_graph = tf.compat.v1.get_default_graph() 

687 if current_default_graph.building_function: 

688 return current_default_graph 

689 

690 op_input_list = tuple(op_input_list) # Handle generators correctly 

691 if graph and not isinstance(graph, tf.Graph): 

692 raise TypeError(f"Input graph needs to be a Graph: {graph}") 

693 

694 def _is_symbolic_tensor(tensor): 

695 if hasattr(tf, "is_symbolic_tensor"): 

696 return tf.is_symbolic_tensor(tensor) 

697 return type(tensor) == tf.Tensor 

698 

699 # 1. We validate that all of the inputs are from the same graph. This is 

700 # either the supplied graph parameter, or the first one selected from one 

701 # the graph-element-valued inputs. In the latter case, we hold onto 

702 # that input in original_graph_element so we can provide a more 

703 # informative error if a mismatch is found. 

704 original_graph_element = None 

705 for op_input in op_input_list: 

706 if isinstance( 

707 op_input, (tf.Operation, tf.__internal__.CompositeTensor) 

708 ) or _is_symbolic_tensor(op_input): 

709 graph_element = op_input 

710 else: 

711 graph_element = _as_graph_element(op_input) 

712 

713 if graph_element is not None: 

714 if not graph: 

715 original_graph_element = graph_element 

716 graph = getattr(graph_element, "graph", None) 

717 elif original_graph_element is not None: 

718 _assert_same_graph(original_graph_element, graph_element) 

719 elif graph_element.graph is not graph: 

720 raise ValueError( 

721 f"{graph_element} is not from the passed-in graph." 

722 ) 

723 

724 # 2. If all else fails, we use the default graph, which is always there. 

725 return graph or current_default_graph 

726 

727 

728def _get_session(op_input_list=()): 

729 """Returns the session object for the current thread.""" 

730 global _SESSION 

731 default_session = tf.compat.v1.get_default_session() 

732 if default_session is not None: 

733 session = default_session 

734 else: 

735 if tf.inside_function(): 

736 raise RuntimeError( 

737 "Cannot get session inside Tensorflow graph function." 

738 ) 

739 # If we don't have a session, or that session does not match the current 

740 # graph, create and cache a new session. 

741 if getattr( 

742 _SESSION, "session", None 

743 ) is None or _SESSION.session.graph is not _current_graph( 

744 op_input_list 

745 ): 

746 # If we are creating the Session inside a tf.distribute.Strategy 

747 # scope, we ask the strategy for the right session options to use. 

748 if tf.distribute.has_strategy(): 

749 configure_and_create_distributed_session( 

750 tf.distribute.get_strategy() 

751 ) 

752 else: 

753 _SESSION.session = tf.compat.v1.Session( 

754 config=get_default_session_config() 

755 ) 

756 session = _SESSION.session 

757 return session 

758 

759 

760@keras_export(v1=["keras.backend.get_session"]) 

761def get_session(op_input_list=()): 

762 """Returns the TF session to be used by the backend. 

763 

764 If a default TensorFlow session is available, we will return it. 

765 

766 Else, we will return the global Keras session assuming it matches 

767 the current graph. 

768 

769 If no global Keras session exists at this point: 

770 we will create a new global session. 

771 

772 Note that you can manually set the global session 

773 via `K.set_session(sess)`. 

774 

775 Args: 

776 op_input_list: An option sequence of tensors or ops, which will be used 

777 to determine the current graph. Otherwise the default graph will be 

778 used. 

779 

780 Returns: 

781 A TensorFlow session. 

782 """ 

783 session = _get_session(op_input_list) 

784 if not _MANUAL_VAR_INIT: 

785 with session.graph.as_default(): 

786 _initialize_variables(session) 

787 return session 

788 

789 

790# Inject the get_session function to keras_deps to remove the dependency 

791# from TFLite to Keras. 

792tf.__internal__.register_get_session_function(get_session) 

793 

794# Inject the get_session function to tracking_util to avoid the backward 

795# dependency from TF to Keras. 

796tf.__internal__.tracking.register_session_provider(get_session) 

797 

798 

799def get_graph(): 

800 if tf.executing_eagerly(): 

801 global _GRAPH 

802 if not getattr(_GRAPH, "graph", None): 

803 _GRAPH.graph = tf.__internal__.FuncGraph("keras_graph") 

804 return _GRAPH.graph 

805 else: 

806 return tf.compat.v1.get_default_graph() 

807 

808 

809@tf_contextlib.contextmanager 

810def _scratch_graph(graph=None): 

811 """Retrieve a shared and temporary func graph. 

812 

813 The eager execution path lifts a subgraph from the keras global graph into 

814 a scratch graph in order to create a function. DistributionStrategies, in 

815 turn, constructs multiple functions as well as a final combined function. In 

816 order for that logic to work correctly, all of the functions need to be 

817 created on the same scratch FuncGraph. 

818 

819 Args: 

820 graph: A graph to be used as the current scratch graph. If not set then 

821 a scratch graph will either be retrieved or created: 

822 

823 Yields: 

824 The current scratch graph. 

825 """ 

826 global _CURRENT_SCRATCH_GRAPH 

827 scratch_graph = getattr(_CURRENT_SCRATCH_GRAPH, "graph", None) 

828 # If scratch graph and `graph` are both configured, they must match. 

829 if ( 

830 scratch_graph is not None 

831 and graph is not None 

832 and scratch_graph is not graph 

833 ): 

834 raise ValueError("Multiple scratch graphs specified.") 

835 

836 if scratch_graph: 

837 yield scratch_graph 

838 return 

839 

840 graph = graph or tf.__internal__.FuncGraph("keras_scratch_graph") 

841 try: 

842 _CURRENT_SCRATCH_GRAPH.graph = graph 

843 yield graph 

844 finally: 

845 _CURRENT_SCRATCH_GRAPH.graph = None 

846 

847 

848@keras_export(v1=["keras.backend.set_session"]) 

849def set_session(session): 

850 """Sets the global TensorFlow session. 

851 

852 Args: 

853 session: A TF Session. 

854 """ 

855 global _SESSION 

856 _SESSION.session = session 

857 

858 

859def get_default_session_config(): 

860 if os.environ.get("OMP_NUM_THREADS"): 

861 logging.warning( 

862 "OMP_NUM_THREADS is no longer used by the default Keras config. " 

863 "To configure the number of threads, use tf.config.threading APIs." 

864 ) 

865 

866 config = get_config() 

867 config.allow_soft_placement = True 

868 

869 return config 

870 

871 

872def get_default_graph_uid_map(): 

873 graph = tf.compat.v1.get_default_graph() 

874 name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None) 

875 if name_uid_map is None: 

876 name_uid_map = collections.defaultdict(int) 

877 PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map 

878 return name_uid_map 

879 

880 

881# DEVICE MANIPULATION 

882 

883 

884class _TfDeviceCaptureOp: 

885 """Class for capturing the TF device scope.""" 

886 

887 def __init__(self): 

888 self.device = None 

889 

890 def _set_device(self, device): 

891 """This method captures TF's explicit device scope setting.""" 

892 if isinstance(device, tf.DeviceSpec): 

893 device = device.to_string() 

894 self.device = device 

895 

896 def _set_device_from_string(self, device_str): 

897 self.device = device_str 

898 

899 

900def _get_current_tf_device(): 

901 """Return explicit device of current context, otherwise returns `None`. 

902 

903 Returns: 

904 If the current device scope is explicitly set, it returns a string with 

905 the device (`CPU` or `GPU`). If the scope is not explicitly set, it will 

906 return `None`. 

907 """ 

908 graph = get_graph() 

909 op = _TfDeviceCaptureOp() 

910 graph._apply_device_functions(op) 

911 if tf.__internal__.tf2.enabled(): 

912 return tf.DeviceSpec.from_string(op.device) 

913 else: 

914 return tf.compat.v1.DeviceSpec.from_string(op.device) 

915 

916 

917def _is_current_explicit_device(device_type): 

918 """Check if the current device is explicitly set to `device_type`. 

919 

920 Args: 

921 device_type: A string containing `GPU` or `CPU` (case-insensitive). 

922 

923 Returns: 

924 A boolean indicating if the current device scope is explicitly set on 

925 the device type. 

926 

927 Raises: 

928 ValueError: If the `device_type` string indicates an unsupported device. 

929 """ 

930 device_type = device_type.upper() 

931 if device_type not in ["CPU", "GPU"]: 

932 raise ValueError('`device_type` should be either "CPU" or "GPU".') 

933 device = _get_current_tf_device() 

934 return device is not None and device.device_type == device_type.upper() 

935 

936 

937def _get_available_gpus(): 

938 """Get a list of available GPU devices (formatted as strings). 

939 

940 Returns: 

941 A list of available GPU devices. 

942 """ 

943 if tf.compat.v1.executing_eagerly_outside_functions(): 

944 # Returns names of devices directly. 

945 return [d.name for d in tf.config.list_logical_devices("GPU")] 

946 

947 global _LOCAL_DEVICES 

948 if _LOCAL_DEVICES is None: 

949 _LOCAL_DEVICES = get_session().list_devices() 

950 return [x.name for x in _LOCAL_DEVICES if x.device_type == "GPU"] 

951 

952 

953def _has_nchw_support(): 

954 """Check whether the current scope supports NCHW ops. 

955 

956 TensorFlow does not support NCHW on CPU. Therefore we check if we are not 

957 explicitly put on 

958 CPU, and have GPUs available. In this case there will be soft-placing on the 

959 GPU device. 

960 

961 Returns: 

962 bool: if the current scope device placement would support nchw 

963 """ 

964 explicitly_on_cpu = _is_current_explicit_device("CPU") 

965 gpus_available = bool(_get_available_gpus()) 

966 return not explicitly_on_cpu and gpus_available 

967 

968 

969# VARIABLE MANIPULATION 

970 

971 

972def _constant_to_tensor(x, dtype): 

973 """Convert the input `x` to a tensor of type `dtype`. 

974 

975 This is slightly faster than the _to_tensor function, at the cost of 

976 handling fewer cases. 

977 

978 Args: 

979 x: An object to be converted (numpy arrays, floats, ints and lists of 

980 them). 

981 dtype: The destination type. 

982 

983 Returns: 

984 A tensor. 

985 """ 

986 return tf.constant(x, dtype=dtype) 

987 

988 

989def _to_tensor(x, dtype): 

990 """Convert the input `x` to a tensor of type `dtype`. 

991 

992 Args: 

993 x: An object to be converted (numpy array, list, tensors). 

994 dtype: The destination type. 

995 

996 Returns: 

997 A tensor. 

998 """ 

999 return tf.convert_to_tensor(x, dtype=dtype) 

1000 

1001 

1002@keras_export("keras.backend.is_sparse") 

1003@doc_controls.do_not_generate_docs 

1004def is_sparse(tensor): 

1005 """Returns whether a tensor is a sparse tensor. 

1006 

1007 Args: 

1008 tensor: A tensor instance. 

1009 

1010 Returns: 

1011 A boolean. 

1012 

1013 Example: 

1014 

1015 

1016 >>> a = tf.keras.backend.placeholder((2, 2), sparse=False) 

1017 >>> print(tf.keras.backend.is_sparse(a)) 

1018 False 

1019 >>> b = tf.keras.backend.placeholder((2, 2), sparse=True) 

1020 >>> print(tf.keras.backend.is_sparse(b)) 

1021 True 

1022 

1023 """ 

1024 spec = getattr(tensor, "_type_spec", None) 

1025 if spec is not None: 

1026 return isinstance(spec, tf.SparseTensorSpec) 

1027 return isinstance(tensor, tf.SparseTensor) 

1028 

1029 

1030@keras_export("keras.backend.to_dense") 

1031@tf.__internal__.dispatch.add_dispatch_support 

1032@doc_controls.do_not_generate_docs 

1033def to_dense(tensor): 

1034 """Converts a sparse tensor into a dense tensor and returns it. 

1035 

1036 Args: 

1037 tensor: A tensor instance (potentially sparse). 

1038 

1039 Returns: 

1040 A dense tensor. 

1041 

1042 Examples: 

1043 

1044 

1045 >>> b = tf.keras.backend.placeholder((2, 2), sparse=True) 

1046 >>> print(tf.keras.backend.is_sparse(b)) 

1047 True 

1048 >>> c = tf.keras.backend.to_dense(b) 

1049 >>> print(tf.keras.backend.is_sparse(c)) 

1050 False 

1051 

1052 """ 

1053 if is_sparse(tensor): 

1054 return tf.sparse.to_dense(tensor) 

1055 else: 

1056 return tensor 

1057 

1058 

1059@keras_export("keras.backend.name_scope", v1=[]) 

1060@doc_controls.do_not_generate_docs 

1061def name_scope(name): 

1062 """A context manager for use when defining a Python op. 

1063 

1064 This context manager pushes a name scope, which will make the name of all 

1065 operations added within it have a prefix. 

1066 

1067 For example, to define a new Python op called `my_op`: 

1068 

1069 

1070 def my_op(a): 

1071 with tf.name_scope("MyOp") as scope: 

1072 a = tf.convert_to_tensor(a, name="a") 

1073 # Define some computation that uses `a`. 

1074 return foo_op(..., name=scope) 

1075 

1076 

1077 When executed, the Tensor `a` will have the name `MyOp/a`. 

1078 

1079 Args: 

1080 name: The prefix to use on all names created within the name scope. 

1081 

1082 Returns: 

1083 Name scope context manager. 

1084 """ 

1085 return tf.name_scope(name) 

1086 

1087 

1088# Export V1 version. 

1089_v1_name_scope = tf.compat.v1.name_scope 

1090keras_export(v1=["keras.backend.name_scope"], allow_multiple_exports=True)( 

1091 _v1_name_scope 

1092) 

1093 

1094 

1095@keras_export("keras.backend.variable") 

1096@doc_controls.do_not_generate_docs 

1097def variable(value, dtype=None, name=None, constraint=None): 

1098 """Instantiates a variable and returns it. 

1099 

1100 Args: 

1101 value: Numpy array, initial value of the tensor. 

1102 dtype: Tensor type. 

1103 name: Optional name string for the tensor. 

1104 constraint: Optional projection function to be 

1105 applied to the variable after an optimizer update. 

1106 

1107 Returns: 

1108 A variable instance (with Keras metadata included). 

1109 

1110 Examples: 

1111 

1112 >>> val = np.array([[1, 2], [3, 4]]) 

1113 >>> kvar = tf.keras.backend.variable(value=val, dtype='float64', 

1114 ... name='example_var') 

1115 >>> tf.keras.backend.dtype(kvar) 

1116 'float64' 

1117 >>> print(kvar) 

1118 <tf.Variable 'example_var:...' shape=(2, 2) dtype=float64, numpy= 

1119 array([[1., 2.], 

1120 [3., 4.]])> 

1121 

1122 """ 

1123 if dtype is None: 

1124 dtype = floatx() 

1125 if hasattr(value, "tocoo"): 

1126 sparse_coo = value.tocoo() 

1127 indices = np.concatenate( 

1128 ( 

1129 np.expand_dims(sparse_coo.row, 1), 

1130 np.expand_dims(sparse_coo.col, 1), 

1131 ), 

1132 1, 

1133 ) 

1134 v = tf.SparseTensor( 

1135 indices=indices, 

1136 values=sparse_coo.data, 

1137 dense_shape=sparse_coo.shape, 

1138 ) 

1139 v._keras_shape = sparse_coo.shape 

1140 return v 

1141 v = tf.Variable( 

1142 value, dtype=tf.as_dtype(dtype), name=name, constraint=constraint 

1143 ) 

1144 if isinstance(value, np.ndarray): 

1145 v._keras_shape = value.shape 

1146 elif hasattr(value, "shape"): 

1147 v._keras_shape = int_shape(value) 

1148 track_variable(v) 

1149 return v 

1150 

1151 

1152def track_tf_optimizer(tf_optimizer): 

1153 """Tracks the given TF optimizer for initialization of its variables.""" 

1154 if tf.executing_eagerly(): 

1155 return 

1156 optimizers = _GRAPH_TF_OPTIMIZERS[None] 

1157 optimizers.add(tf_optimizer) 

1158 

1159 

1160@keras_export("keras.__internal__.backend.track_variable", v1=[]) 

1161def track_variable(v): 

1162 """Tracks the given variable for initialization.""" 

1163 if tf.executing_eagerly(): 

1164 return 

1165 graph = v.graph if hasattr(v, "graph") else get_graph() 

1166 _GRAPH_VARIABLES[graph].add(v) 

1167 

1168 

1169def observe_object_name(name): 

1170 """Observe a name and make sure it won't be used by `unique_object_name`.""" 

1171 OBSERVED_NAMES.add(name) 

1172 

1173 

1174def unique_object_name( 

1175 name, 

1176 name_uid_map=None, 

1177 avoid_names=None, 

1178 namespace="", 

1179 zero_based=False, 

1180 avoid_observed_names=False, 

1181): 

1182 """Makes a object name (or any string) unique within a Keras session. 

1183 

1184 Args: 

1185 name: String name to make unique. 

1186 name_uid_map: An optional defaultdict(int) to use when creating unique 

1187 names. If None (default), uses a per-Graph dictionary. 

1188 avoid_names: An optional set or dict with names which should not be used. 

1189 If None (default), don't avoid any names unless `avoid_observed_names` 

1190 is True. 

1191 namespace: Gets a name which is unique within the (graph, namespace). 

1192 Layers which are not Networks use a blank namespace and so get 

1193 graph-global names. 

1194 zero_based: If True, name sequences start with no suffix (e.g. "dense", 

1195 "dense_1"). If False, naming is one-based ("dense_1", "dense_2"). 

1196 avoid_observed_names: If True, avoid any names that have been observed by 

1197 `backend.observe_object_name`. 

1198 

1199 Returns: 

1200 Unique string name. 

1201 

1202 Example: 

1203 

1204 

1205 unique_object_name('dense') # dense_1 

1206 unique_object_name('dense') # dense_2 

1207 

1208 """ 

1209 if name_uid_map is None: 

1210 name_uid_map = get_default_graph_uid_map() 

1211 if avoid_names is None: 

1212 if avoid_observed_names: 

1213 avoid_names = OBSERVED_NAMES 

1214 else: 

1215 avoid_names = set() 

1216 proposed_name = None 

1217 while proposed_name is None or proposed_name in avoid_names: 

1218 name_key = (namespace, name) 

1219 if zero_based: 

1220 number = name_uid_map[name_key] 

1221 if number: 

1222 proposed_name = name + "_" + str(number) 

1223 else: 

1224 proposed_name = name 

1225 name_uid_map[name_key] += 1 

1226 else: 

1227 name_uid_map[name_key] += 1 

1228 proposed_name = name + "_" + str(name_uid_map[name_key]) 

1229 return proposed_name 

1230 

1231 

1232def _get_variables(graph=None): 

1233 """Returns variables corresponding to the given graph for initialization.""" 

1234 assert not tf.executing_eagerly() 

1235 variables = _GRAPH_VARIABLES[graph] 

1236 for opt in _GRAPH_TF_OPTIMIZERS[graph]: 

1237 variables.update(opt.optimizer.variables()) 

1238 return variables 

1239 

1240 

1241@keras_export("keras.__internal__.backend.initialize_variables", v1=[]) 

1242def _initialize_variables(session): 

1243 """Utility to initialize uninitialized variables on the fly.""" 

1244 variables = _get_variables(get_graph()) 

1245 candidate_vars = [] 

1246 for v in variables: 

1247 if not getattr(v, "_keras_initialized", False): 

1248 candidate_vars.append(v) 

1249 if candidate_vars: 

1250 # This step is expensive, so we only run it on variables not already 

1251 # marked as initialized. 

1252 is_initialized = session.run( 

1253 [tf.compat.v1.is_variable_initialized(v) for v in candidate_vars] 

1254 ) 

1255 # TODO(kathywu): Some metric variables loaded from SavedModel are never 

1256 # actually used, and do not have an initializer. 

1257 should_be_initialized = [ 

1258 (not is_initialized[n]) and v.initializer is not None 

1259 for n, v in enumerate(candidate_vars) 

1260 ] 

1261 uninitialized_vars = [] 

1262 for flag, v in zip(should_be_initialized, candidate_vars): 

1263 if flag: 

1264 uninitialized_vars.append(v) 

1265 v._keras_initialized = True 

1266 if uninitialized_vars: 

1267 session.run(tf.compat.v1.variables_initializer(uninitialized_vars)) 

1268 

1269 

1270@keras_export("keras.backend.constant") 

1271@tf.__internal__.dispatch.add_dispatch_support 

1272@doc_controls.do_not_generate_docs 

1273def constant(value, dtype=None, shape=None, name=None): 

1274 """Creates a constant tensor. 

1275 

1276 Args: 

1277 value: A constant value (or list) 

1278 dtype: The type of the elements of the resulting tensor. 

1279 shape: Optional dimensions of resulting tensor. 

1280 name: Optional name for the tensor. 

1281 

1282 Returns: 

1283 A Constant Tensor. 

1284 """ 

1285 if dtype is None: 

1286 dtype = floatx() 

1287 

1288 return tf.constant(value, dtype=dtype, shape=shape, name=name) 

1289 

1290 

1291@keras_export("keras.backend.is_keras_tensor") 

1292def is_keras_tensor(x): 

1293 """Returns whether `x` is a Keras tensor. 

1294 

1295 A "Keras tensor" is a tensor that was returned by a Keras layer, 

1296 (`Layer` class) or by `Input`. 

1297 

1298 Args: 

1299 x: A candidate tensor. 

1300 

1301 Returns: 

1302 A boolean: Whether the argument is a Keras tensor. 

1303 

1304 Raises: 

1305 ValueError: In case `x` is not a symbolic tensor. 

1306 

1307 Examples: 

1308 

1309 >>> np_var = np.array([1, 2]) 

1310 >>> # A numpy array is not a symbolic tensor. 

1311 >>> tf.keras.backend.is_keras_tensor(np_var) 

1312 Traceback (most recent call last): 

1313 ... 

1314 ValueError: Unexpectedly found an instance of type 

1315 `<class 'numpy.ndarray'>`. 

1316 Expected a symbolic tensor instance. 

1317 >>> keras_var = tf.keras.backend.variable(np_var) 

1318 >>> # A variable created with the keras backend is not a Keras tensor. 

1319 >>> tf.keras.backend.is_keras_tensor(keras_var) 

1320 False 

1321 >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5)) 

1322 >>> # A placeholder is a Keras tensor. 

1323 >>> tf.keras.backend.is_keras_tensor(keras_placeholder) 

1324 True 

1325 >>> keras_input = tf.keras.layers.Input([10]) 

1326 >>> # An Input is a Keras tensor. 

1327 >>> tf.keras.backend.is_keras_tensor(keras_input) 

1328 True 

1329 >>> keras_layer_output = tf.keras.layers.Dense(10)(keras_input) 

1330 >>> # Any Keras layer output is a Keras tensor. 

1331 >>> tf.keras.backend.is_keras_tensor(keras_layer_output) 

1332 True 

1333 

1334 """ 

1335 if not isinstance( 

1336 x, 

1337 ( 

1338 tf.Tensor, 

1339 tf.Variable, 

1340 tf.SparseTensor, 

1341 tf.RaggedTensor, 

1342 keras_tensor.KerasTensor, 

1343 ), 

1344 ): 

1345 raise ValueError( 

1346 "Unexpectedly found an instance of type `" 

1347 + str(type(x)) 

1348 + "`. Expected a symbolic tensor instance." 

1349 ) 

1350 if tf.compat.v1.executing_eagerly_outside_functions(): 

1351 return isinstance(x, keras_tensor.KerasTensor) 

1352 return hasattr(x, "_keras_history") 

1353 

1354 

1355@keras_export("keras.backend.placeholder") 

1356@doc_controls.do_not_generate_docs 

1357def placeholder( 

1358 shape=None, ndim=None, dtype=None, sparse=False, name=None, ragged=False 

1359): 

1360 """Instantiates a placeholder tensor and returns it. 

1361 

1362 Args: 

1363 shape: Shape of the placeholder 

1364 (integer tuple, may include `None` entries). 

1365 ndim: Number of axes of the tensor. 

1366 At least one of {`shape`, `ndim`} must be specified. 

1367 If both are specified, `shape` is used. 

1368 dtype: Placeholder type. 

1369 sparse: Boolean, whether the placeholder should have a sparse type. 

1370 name: Optional name string for the placeholder. 

1371 ragged: Boolean, whether the placeholder should have a ragged type. 

1372 In this case, values of 'None' in the 'shape' argument represent 

1373 ragged dimensions. For more information about RaggedTensors, see 

1374 this [guide](https://www.tensorflow.org/guide/ragged_tensor). 

1375 

1376 Raises: 

1377 ValueError: If called with sparse = True and ragged = True. 

1378 

1379 Returns: 

1380 Tensor instance (with Keras metadata included). 

1381 

1382 Examples: 

1383 

1384 

1385 >>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5)) 

1386 >>> input_ph 

1387 <KerasTensor: shape=(2, 4, 5) dtype=float32 (created by layer ...)> 

1388 

1389 """ 

1390 if sparse and ragged: 

1391 raise ValueError( 

1392 "Cannot set both sparse and ragged to " 

1393 "True when creating a placeholder." 

1394 ) 

1395 if dtype is None: 

1396 dtype = floatx() 

1397 if not shape: 

1398 if ndim: 

1399 shape = (None,) * ndim 

1400 if tf.compat.v1.executing_eagerly_outside_functions(): 

1401 if sparse: 

1402 spec = tf.SparseTensorSpec(shape=shape, dtype=dtype) 

1403 elif ragged: 

1404 ragged_rank = 0 

1405 for i in range(1, len(shape)): 

1406 # Hacky because could be tensorshape or tuple maybe? 

1407 # Or just tensorshape? 

1408 if shape[i] is None or ( 

1409 hasattr(shape[i], "value") and shape[i].value is None 

1410 ): 

1411 ragged_rank = i 

1412 spec = tf.RaggedTensorSpec( 

1413 shape=shape, dtype=dtype, ragged_rank=ragged_rank 

1414 ) 

1415 else: 

1416 spec = tf.TensorSpec(shape=shape, dtype=dtype, name=name) 

1417 x = keras_tensor.keras_tensor_from_type_spec(spec, name=name) 

1418 else: 

1419 with get_graph().as_default(): 

1420 if sparse: 

1421 x = tf.compat.v1.sparse_placeholder( 

1422 dtype, shape=shape, name=name 

1423 ) 

1424 elif ragged: 

1425 ragged_rank = 0 

1426 for i in range(1, len(shape)): 

1427 if shape[i] is None: 

1428 ragged_rank = i 

1429 type_spec = tf.RaggedTensorSpec( 

1430 shape=shape, dtype=dtype, ragged_rank=ragged_rank 

1431 ) 

1432 

1433 def tensor_spec_to_placeholder(tensorspec): 

1434 return tf.compat.v1.placeholder( 

1435 tensorspec.dtype, tensorspec.shape 

1436 ) 

1437 

1438 x = tf.nest.map_structure( 

1439 tensor_spec_to_placeholder, 

1440 type_spec, 

1441 expand_composites=True, 

1442 ) 

1443 else: 

1444 x = tf.compat.v1.placeholder(dtype, shape=shape, name=name) 

1445 

1446 if tf.executing_eagerly(): 

1447 # Add keras_history connectivity information to the placeholder 

1448 # when the placeholder is built in a top-level eager context 

1449 # (intended to be used with keras.backend.function) 

1450 from keras.src.engine import ( 

1451 input_layer, 

1452 ) 

1453 

1454 x = input_layer.Input(tensor=x) 

1455 x._is_backend_placeholder = True 

1456 

1457 return x 

1458 

1459 

1460def is_placeholder(x): 

1461 """Returns whether `x` is a placeholder. 

1462 

1463 Args: 

1464 x: A candidate placeholder. 

1465 

1466 Returns: 

1467 Boolean. 

1468 """ 

1469 try: 

1470 if tf.compat.v1.executing_eagerly_outside_functions(): 

1471 return hasattr(x, "_is_backend_placeholder") 

1472 

1473 # TODO(b/246438937): Remove the special case for tf.Variable once 

1474 # tf.Variable becomes CompositeTensor and will be expanded into 

1475 # dt_resource tensors. 

1476 if tf_utils.is_extension_type(x) and not isinstance(x, tf.Variable): 

1477 flat_components = tf.nest.flatten(x, expand_composites=True) 

1478 return py_any(is_placeholder(c) for c in flat_components) 

1479 else: 

1480 return x.op.type == "Placeholder" 

1481 except AttributeError: 

1482 return False 

1483 

1484 

1485@keras_export("keras.backend.shape") 

1486@tf.__internal__.dispatch.add_dispatch_support 

1487@doc_controls.do_not_generate_docs 

1488def shape(x): 

1489 """Returns the symbolic shape of a tensor or variable. 

1490 

1491 Args: 

1492 x: A tensor or variable. 

1493 

1494 Returns: 

1495 A symbolic shape (which is itself a tensor). 

1496 

1497 Examples: 

1498 

1499 >>> val = np.array([[1, 2], [3, 4]]) 

1500 >>> kvar = tf.keras.backend.variable(value=val) 

1501 >>> tf.keras.backend.shape(kvar) 

1502 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)> 

1503 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5)) 

1504 >>> tf.keras.backend.shape(input) 

1505 <KerasTensor: shape=(3,) dtype=int32 inferred_value=[2, 4, 5] ...> 

1506 

1507 """ 

1508 return tf.shape(x) 

1509 

1510 

1511@keras_export("keras.backend.int_shape") 

1512@doc_controls.do_not_generate_docs 

1513def int_shape(x): 

1514 """Returns shape of tensor/variable as a tuple of int/None entries. 

1515 

1516 Args: 

1517 x: Tensor or variable. 

1518 

1519 Returns: 

1520 A tuple of integers (or None entries). 

1521 

1522 Examples: 

1523 

1524 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5)) 

1525 >>> tf.keras.backend.int_shape(input) 

1526 (2, 4, 5) 

1527 >>> val = np.array([[1, 2], [3, 4]]) 

1528 >>> kvar = tf.keras.backend.variable(value=val) 

1529 >>> tf.keras.backend.int_shape(kvar) 

1530 (2, 2) 

1531 

1532 """ 

1533 try: 

1534 shape = x.shape 

1535 if not isinstance(shape, tuple): 

1536 shape = tuple(shape.as_list()) 

1537 return shape 

1538 except ValueError: 

1539 return None 

1540 

1541 

1542@keras_export("keras.backend.ndim") 

1543@doc_controls.do_not_generate_docs 

1544def ndim(x): 

1545 """Returns the number of axes in a tensor, as an integer. 

1546 

1547 Args: 

1548 x: Tensor or variable. 

1549 

1550 Returns: 

1551 Integer (scalar), number of axes. 

1552 

1553 Examples: 

1554 

1555 

1556 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5)) 

1557 >>> val = np.array([[1, 2], [3, 4]]) 

1558 >>> kvar = tf.keras.backend.variable(value=val) 

1559 >>> tf.keras.backend.ndim(input) 

1560 3 

1561 >>> tf.keras.backend.ndim(kvar) 

1562 2 

1563 

1564 """ 

1565 return x.shape.rank 

1566 

1567 

1568@keras_export("keras.backend.dtype") 

1569@tf.__internal__.dispatch.add_dispatch_support 

1570@doc_controls.do_not_generate_docs 

1571def dtype(x): 

1572 """Returns the dtype of a Keras tensor or variable, as a string. 

1573 

1574 Args: 

1575 x: Tensor or variable. 

1576 

1577 Returns: 

1578 String, dtype of `x`. 

1579 

1580 Examples: 

1581 

1582 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5))) 

1583 'float32' 

1584 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5), 

1585 ... dtype='float32')) 

1586 'float32' 

1587 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5), 

1588 ... dtype='float64')) 

1589 'float64' 

1590 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]])) 

1591 >>> tf.keras.backend.dtype(kvar) 

1592 'float32' 

1593 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]), 

1594 ... dtype='float32') 

1595 >>> tf.keras.backend.dtype(kvar) 

1596 'float32' 

1597 

1598 """ 

1599 return x.dtype.base_dtype.name 

1600 

1601 

1602@doc_controls.do_not_generate_docs 

1603def dtype_numpy(x): 

1604 """Returns the numpy dtype of a Keras tensor or variable. 

1605 

1606 Args: 

1607 x: Tensor or variable. 

1608 

1609 Returns: 

1610 numpy.dtype, dtype of `x`. 

1611 """ 

1612 return tf.as_dtype(x.dtype).as_numpy_dtype 

1613 

1614 

1615@keras_export("keras.backend.eval") 

1616@doc_controls.do_not_generate_docs 

1617def eval(x): 

1618 """Evaluates the value of a variable. 

1619 

1620 Args: 

1621 x: A variable. 

1622 

1623 Returns: 

1624 A Numpy array. 

1625 

1626 Examples: 

1627 

1628 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]), 

1629 ... dtype='float32') 

1630 >>> tf.keras.backend.eval(kvar) 

1631 array([[1., 2.], 

1632 [3., 4.]], dtype=float32) 

1633 

1634 """ 

1635 return get_value(to_dense(x)) 

1636 

1637 

1638@keras_export("keras.backend.zeros") 

1639@doc_controls.do_not_generate_docs 

1640def zeros(shape, dtype=None, name=None): 

1641 """Instantiates an all-zeros variable and returns it. 

1642 

1643 Args: 

1644 shape: Tuple or list of integers, shape of returned Keras variable 

1645 dtype: data type of returned Keras variable 

1646 name: name of returned Keras variable 

1647 

1648 Returns: 

1649 A variable (including Keras metadata), filled with `0.0`. 

1650 Note that if `shape` was symbolic, we cannot return a variable, 

1651 and will return a dynamically-shaped tensor instead. 

1652 

1653 Example: 

1654 

1655 >>> kvar = tf.keras.backend.zeros((3,4)) 

1656 >>> tf.keras.backend.eval(kvar) 

1657 array([[0., 0., 0., 0.], 

1658 [0., 0., 0., 0.], 

1659 [0., 0., 0., 0.]], dtype=float32) 

1660 >>> A = tf.constant([1,2,3]) 

1661 >>> kvar2 = tf.keras.backend.zeros(A.shape) # [0., 0., 0.] 

1662 >>> tf.keras.backend.eval(kvar2) 

1663 array([0., 0., 0.], dtype=float32) 

1664 >>> kvar3 = tf.keras.backend.zeros(A.shape,dtype=tf.int32) 

1665 >>> tf.keras.backend.eval(kvar3) 

1666 array([0, 0, 0], dtype=int32) 

1667 >>> kvar4 = tf.keras.backend.zeros([2,3]) 

1668 >>> tf.keras.backend.eval(kvar4) 

1669 array([[0., 0., 0.], 

1670 [0., 0., 0.]], dtype=float32) 

1671 

1672 """ 

1673 with tf.init_scope(): 

1674 if dtype is None: 

1675 dtype = floatx() 

1676 tf_dtype = tf.as_dtype(dtype) 

1677 v = tf.zeros(shape=shape, dtype=tf_dtype, name=name) 

1678 if py_all(v.shape.as_list()): 

1679 return variable(v, dtype=dtype, name=name) 

1680 return v 

1681 

1682 

1683@keras_export("keras.backend.ones") 

1684@tf.__internal__.dispatch.add_dispatch_support 

1685@doc_controls.do_not_generate_docs 

1686def ones(shape, dtype=None, name=None): 

1687 """Instantiates an all-ones variable and returns it. 

1688 

1689 Args: 

1690 shape: Tuple of integers, shape of returned Keras variable. 

1691 dtype: String, data type of returned Keras variable. 

1692 name: String, name of returned Keras variable. 

1693 

1694 Returns: 

1695 A Keras variable, filled with `1.0`. 

1696 Note that if `shape` was symbolic, we cannot return a variable, 

1697 and will return a dynamically-shaped tensor instead. 

1698 

1699 Example: 

1700 

1701 

1702 >>> kvar = tf.keras.backend.ones((3,4)) 

1703 >>> tf.keras.backend.eval(kvar) 

1704 array([[1., 1., 1., 1.], 

1705 [1., 1., 1., 1.], 

1706 [1., 1., 1., 1.]], dtype=float32) 

1707 

1708 """ 

1709 with tf.init_scope(): 

1710 if dtype is None: 

1711 dtype = floatx() 

1712 tf_dtype = tf.as_dtype(dtype) 

1713 v = tf.ones(shape=shape, dtype=tf_dtype, name=name) 

1714 if py_all(v.shape.as_list()): 

1715 return variable(v, dtype=dtype, name=name) 

1716 return v 

1717 

1718 

1719@keras_export("keras.backend.eye") 

1720@tf.__internal__.dispatch.add_dispatch_support 

1721@doc_controls.do_not_generate_docs 

1722def eye(size, dtype=None, name=None): 

1723 """Instantiate an identity matrix and returns it. 

1724 

1725 Args: 

1726 size: Integer, number of rows/columns. 

1727 dtype: String, data type of returned Keras variable. 

1728 name: String, name of returned Keras variable. 

1729 

1730 Returns: 

1731 A Keras variable, an identity matrix. 

1732 

1733 Example: 

1734 

1735 

1736 >>> kvar = tf.keras.backend.eye(3) 

1737 >>> tf.keras.backend.eval(kvar) 

1738 array([[1., 0., 0.], 

1739 [0., 1., 0.], 

1740 [0., 0., 1.]], dtype=float32) 

1741 

1742 

1743 """ 

1744 if dtype is None: 

1745 dtype = floatx() 

1746 tf_dtype = tf.as_dtype(dtype) 

1747 return variable(tf.eye(size, dtype=tf_dtype), dtype, name) 

1748 

1749 

1750@keras_export("keras.backend.zeros_like") 

1751@doc_controls.do_not_generate_docs 

1752def zeros_like(x, dtype=None, name=None): 

1753 """Instantiates an all-zeros variable of the same shape as another tensor. 

1754 

1755 Args: 

1756 x: Keras variable or Keras tensor. 

1757 dtype: dtype of returned Keras variable. 

1758 `None` uses the dtype of `x`. 

1759 name: name for the variable to create. 

1760 

1761 Returns: 

1762 A Keras variable with the shape of `x` filled with zeros. 

1763 

1764 Example: 

1765 

1766 ```python 

1767 kvar = tf.keras.backend.variable(np.random.random((2,3))) 

1768 kvar_zeros = tf.keras.backend.zeros_like(kvar) 

1769 K.eval(kvar_zeros) 

1770 # array([[ 0., 0., 0.], [ 0., 0., 0.]], dtype=float32) 

1771 ``` 

1772 """ 

1773 return tf.zeros_like(x, dtype=dtype, name=name) 

1774 

1775 

1776@keras_export("keras.backend.ones_like") 

1777@tf.__internal__.dispatch.add_dispatch_support 

1778@doc_controls.do_not_generate_docs 

1779def ones_like(x, dtype=None, name=None): 

1780 """Instantiates an all-ones variable of the same shape as another tensor. 

1781 

1782 Args: 

1783 x: Keras variable or tensor. 

1784 dtype: String, dtype of returned Keras variable. 

1785 None uses the dtype of x. 

1786 name: String, name for the variable to create. 

1787 

1788 Returns: 

1789 A Keras variable with the shape of x filled with ones. 

1790 

1791 Example: 

1792 

1793 >>> kvar = tf.keras.backend.variable(np.random.random((2,3))) 

1794 >>> kvar_ones = tf.keras.backend.ones_like(kvar) 

1795 >>> tf.keras.backend.eval(kvar_ones) 

1796 array([[1., 1., 1.], 

1797 [1., 1., 1.]], dtype=float32) 

1798 

1799 """ 

1800 return tf.ones_like(x, dtype=dtype, name=name) 

1801 

1802 

1803def identity(x, name=None): 

1804 """Returns a tensor with the same content as the input tensor. 

1805 

1806 Args: 

1807 x: The input tensor. 

1808 name: String, name for the variable to create. 

1809 

1810 Returns: 

1811 A tensor of the same shape, type and content. 

1812 """ 

1813 return tf.identity(x, name=name) 

1814 

1815 

1816# Global flag to enforce tf.random.Generator for RandomGenerator. 

1817# When this is enabled, for any caller to RandomGenerator, it will use 

1818# tf.random.Generator to generate random numbers. 

1819# The legacy behavior is to use TF's legacy stateful RNG ops like 

1820# tf.random.uniform. 

1821_USE_GENERATOR_FOR_RNG = False 

1822 

1823# The global generator to create the seed when initializing the 

1824# tf.random.Genrator used by RandomGenerator. When tf.random.Generator becomes 

1825# the default solution, we would like the it to be initialized in a controlable 

1826# way, so that each client of the program could start with same seed. This is 

1827# very important for certain use case that requires all the client to have their 

1828# state in sync. This instance will be set when user call 

1829# `tf.keras.utils.set_random_seed()` 

1830_SEED_GENERATOR = threading.local() 

1831 

1832 

1833@keras_export( 

1834 "keras.backend.experimental.is_tf_random_generator_enabled", v1=[] 

1835) 

1836def is_tf_random_generator_enabled(): 

1837 """Check whether `tf.random.Generator` is used for RNG in Keras. 

1838 

1839 Compared to existing TF stateful random ops, `tf.random.Generator` uses 

1840 `tf.Variable` and stateless random ops to generate random numbers, 

1841 which leads to better reproducibility in distributed training. 

1842 Note enabling it might introduce some breakage to existing code, 

1843 by producing differently-seeded random number sequences 

1844 and breaking tests that rely on specific random numbers being generated. 

1845 To disable the 

1846 usage of `tf.random.Generator`, please use 

1847 `tf.keras.backend.experimental.disable_random_generator`. 

1848 

1849 We expect the `tf.random.Generator` code path to become the default, and 

1850 will remove the legacy stateful random ops such as `tf.random.uniform` in 

1851 the future (see the [TF RNG guide]( 

1852 https://www.tensorflow.org/guide/random_numbers)). 

1853 

1854 This API will also be removed in a future release as well, together with 

1855 `tf.keras.backend.experimental.enable_tf_random_generator()` and 

1856 `tf.keras.backend.experimental.disable_tf_random_generator()` 

1857 

1858 Returns: 

1859 boolean: whether `tf.random.Generator` is used for random number 

1860 generation in Keras. 

1861 """ 

1862 return _USE_GENERATOR_FOR_RNG 

1863 

1864 

1865@keras_export("keras.backend.experimental.enable_tf_random_generator", v1=[]) 

1866def enable_tf_random_generator(): 

1867 """Enable the `tf.random.Generator` as the RNG for Keras. 

1868 

1869 See `tf.keras.backend.experimental.is_tf_random_generator_enabled` for more 

1870 details. 

1871 """ 

1872 

1873 global _USE_GENERATOR_FOR_RNG 

1874 _USE_GENERATOR_FOR_RNG = True 

1875 

1876 

1877@keras_export("keras.backend.experimental.disable_tf_random_generator", v1=[]) 

1878def disable_tf_random_generator(): 

1879 """Disable the `tf.random.Generator` as the RNG for Keras. 

1880 

1881 See `tf.keras.backend.experimental.is_tf_random_generator_enabled` for more 

1882 details. 

1883 """ 

1884 global _USE_GENERATOR_FOR_RNG 

1885 _USE_GENERATOR_FOR_RNG = False 

1886 

1887 

1888class RandomGenerator(tf.__internal__.tracking.AutoTrackable): 

1889 """Random generator that selects appropriate random ops. 

1890 

1891 This class contains the logic for legacy stateful random ops, as well as the 

1892 new stateless random ops with seeds and tf.random.Generator. Any class that 

1893 relies on RNG (eg initializer, shuffle, dropout) should use this class to 

1894 handle the transition from legacy RNGs to new RNGs. 

1895 

1896 Args: 

1897 seed: Optional int seed. When `rng_type` is "stateful", the seed is used 

1898 to create `tf.random.Generator` to produce deterministic sequences. 

1899 When `rng_type` is "stateless", new seed will be created if it is not 

1900 provided by user, and it will be passed down to stateless random ops. 

1901 When `rng_type` is "legacy_stateful", the seed will be passed down to 

1902 stateful random ops. 

1903 rng_type: Type of RNG to use, one of "stateful", "stateless", 

1904 "legacy_stateful". When `None` it uses "stateful" if 

1905 `enable_tf_random_generator` has been activated, or 

1906 "legacy_stateful" otherwise. 

1907 - When using "stateless", the random ops outputs are constant (the same 

1908 inputs result in the same outputs). 

1909 - When using "stateful" or "legacy_stateful", the random ops outputs are 

1910 non-constant, but deterministic: calling the same random op multiple 

1911 times with the same inputs results in a deterministic sequence of 

1912 different outputs. 

1913 - "legacy_stateful" is backed by TF1 stateful RNG ops 

1914 (e.g. `tf.random.uniform`), while "stateful" 

1915 is backed by TF2 APIs (e.g. `tf.random.Generator.uniform`). 

1916 Defaults to `None`. 

1917 """ 

1918 

1919 RNG_STATELESS = "stateless" 

1920 RNG_STATEFUL = "stateful" 

1921 RNG_LEGACY_STATEFUL = "legacy_stateful" 

1922 

1923 def __init__(self, seed=None, rng_type=None, **kwargs): 

1924 self._seed = seed 

1925 self._set_rng_type(rng_type, **kwargs) 

1926 self._built = False 

1927 

1928 def _set_rng_type(self, rng_type, **kwargs): 

1929 # Only supported kwargs is "force_generator", which we will remove once 

1930 # we clean up all the caller. 

1931 # TODO(scottzhu): Remove the kwargs for force_generator. 

1932 if kwargs.get("force_generator", False): 

1933 rng_type = self.RNG_STATEFUL 

1934 if rng_type is None: 

1935 if is_tf_random_generator_enabled(): 

1936 self._rng_type = self.RNG_STATEFUL 

1937 else: 

1938 self._rng_type = self.RNG_LEGACY_STATEFUL 

1939 else: 

1940 if rng_type not in [ 

1941 self.RNG_STATEFUL, 

1942 self.RNG_LEGACY_STATEFUL, 

1943 self.RNG_STATELESS, 

1944 ]: 

1945 raise ValueError( 

1946 "Invalid `rng_type` received. " 

1947 'Valid `rng_type` are ["stateless", ' 

1948 '"stateful", "legacy_stateful"].' 

1949 f" Got: {rng_type}" 

1950 ) 

1951 self._rng_type = rng_type 

1952 

1953 def _maybe_init(self): 

1954 """Lazily init the RandomGenerator. 

1955 

1956 The TF API executing_eagerly_outside_functions() has some side effect, 

1957 and couldn't be used before API like tf.enable_eager_execution(). Some 

1958 of the client side code was creating the initializer at the code load 

1959 time, which triggers the creation of RandomGenerator. Lazy init this 

1960 class to walkaround this issue until it is resolved on TF side. 

1961 """ 

1962 # TODO(b/167482354): Change this back to normal init when the bug is 

1963 # fixed. 

1964 if self._built: 

1965 return 

1966 

1967 if ( 

1968 self._rng_type == self.RNG_STATEFUL 

1969 and not tf.compat.v1.executing_eagerly_outside_functions() 

1970 ): 

1971 # Fall back to legacy stateful since the generator need to work in 

1972 # tf2. 

1973 self._rng_type = self.RNG_LEGACY_STATEFUL 

1974 

1975 if self._rng_type == self.RNG_STATELESS: 

1976 self._seed = self._create_seed(self._seed) 

1977 self._generator = None 

1978 elif self._rng_type == self.RNG_STATEFUL: 

1979 with tf_utils.maybe_init_scope(self): 

1980 seed = self._create_seed(self._seed) 

1981 self._generator = tf.random.Generator.from_seed( 

1982 seed, alg=tf.random.Algorithm.AUTO_SELECT 

1983 ) 

1984 else: 

1985 # In legacy stateful, we use stateful op, regardless whether user 

1986 # provide seed or not. Seeded stateful op will ensure generating 

1987 # same sequences. 

1988 self._generator = None 

1989 self._built = True 

1990 

1991 def make_seed_for_stateless_op(self): 

1992 """Generate a new seed based on the init config. 

1993 

1994 Note that this will not return python ints which will be frozen in the 

1995 graph and cause stateless op to return the same value. It will only 

1996 return value when generator is used, otherwise it will return None. 

1997 

1998 Returns: 

1999 A tensor with shape [2,]. 

2000 """ 

2001 self._maybe_init() 

2002 if self._rng_type == self.RNG_STATELESS: 

2003 return [self._seed, 0] 

2004 elif self._rng_type == self.RNG_STATEFUL: 

2005 return self._generator.make_seeds()[:, 0] 

2006 return None 

2007 

2008 def make_legacy_seed(self): 

2009 """Create a new seed for the legacy stateful ops to use. 

2010 

2011 When user didn't provide any original seed, this method will return 

2012 None. Otherwise it will increment the counter and return as the new 

2013 seed. 

2014 

2015 Note that it is important to generate different seed for stateful ops in 

2016 the `tf.function`. The random ops will return same value when same seed 

2017 is provided in the `tf.function`. 

2018 

2019 Returns: 

2020 int as new seed, or None. 

2021 """ 

2022 if self._seed is not None: 

2023 result = self._seed 

2024 self._seed += 1 

2025 return result 

2026 return None 

2027 

2028 def _create_seed(self, user_specified_seed): 

2029 if user_specified_seed is not None: 

2030 return user_specified_seed 

2031 elif getattr(_SEED_GENERATOR, "generator", None): 

2032 return _SEED_GENERATOR.generator.randint(1, 1e9) 

2033 else: 

2034 return random.randint(1, int(1e9)) 

2035 

2036 def random_normal( 

2037 self, shape, mean=0.0, stddev=1.0, dtype=None, nonce=None 

2038 ): 

2039 """Produce random number based on the normal distribution. 

2040 

2041 Args: 

2042 shape: The shape of the random values to generate. 

2043 mean: Floats, default to 0. Mean of the random values to generate. 

2044 stddev: Floats, default to 1. Standard deviation of the random values 

2045 to generate. 

2046 dtype: Optional dtype of the tensor. Only floating point types are 

2047 supported. If not specified, `tf.keras.backend.floatx()` is used, 

2048 which default to `float32` unless you configured it otherwise (via 

2049 `tf.keras.backend.set_floatx(float_dtype)`) 

2050 nonce: Optional integer scalar, that will be folded into the seed in 

2051 the stateless mode. 

2052 """ 

2053 self._maybe_init() 

2054 dtype = dtype or floatx() 

2055 if self._rng_type == self.RNG_STATEFUL: 

2056 return self._generator.normal( 

2057 shape=shape, mean=mean, stddev=stddev, dtype=dtype 

2058 ) 

2059 elif self._rng_type == self.RNG_STATELESS: 

2060 seed = self.make_seed_for_stateless_op() 

2061 if nonce: 

2062 seed = tf.random.experimental.stateless_fold_in(seed, nonce) 

2063 return tf.random.stateless_normal( 

2064 shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed 

2065 ) 

2066 return tf.random.normal( 

2067 shape=shape, 

2068 mean=mean, 

2069 stddev=stddev, 

2070 dtype=dtype, 

2071 seed=self.make_legacy_seed(), 

2072 ) 

2073 

2074 def random_uniform( 

2075 self, shape, minval=0.0, maxval=None, dtype=None, nonce=None 

2076 ): 

2077 """Produce random number based on the uniform distribution. 

2078 

2079 Args: 

2080 shape: The shape of the random values to generate. 

2081 minval: Floats, default to 0. Lower bound of the range of 

2082 random values to generate (inclusive). 

2083 minval: Floats, default to None. Upper bound of the range of 

2084 random values to generate (exclusive). 

2085 dtype: Optional dtype of the tensor. Only floating point types are 

2086 supported. If not specified, `tf.keras.backend.floatx()` is used, 

2087 which default to `float32` unless you configured it otherwise (via 

2088 `tf.keras.backend.set_floatx(float_dtype)`) 

2089 nonce: Optional integer scalar, that will be folded into the seed in 

2090 the stateless mode. 

2091 """ 

2092 self._maybe_init() 

2093 dtype = dtype or floatx() 

2094 if self._rng_type == self.RNG_STATEFUL: 

2095 return self._generator.uniform( 

2096 shape=shape, minval=minval, maxval=maxval, dtype=dtype 

2097 ) 

2098 elif self._rng_type == self.RNG_STATELESS: 

2099 seed = self.make_seed_for_stateless_op() 

2100 if nonce: 

2101 seed = tf.random.experimental.stateless_fold_in(seed, nonce) 

2102 return tf.random.stateless_uniform( 

2103 shape=shape, 

2104 minval=minval, 

2105 maxval=maxval, 

2106 dtype=dtype, 

2107 seed=seed, 

2108 ) 

2109 return tf.random.uniform( 

2110 shape=shape, 

2111 minval=minval, 

2112 maxval=maxval, 

2113 dtype=dtype, 

2114 seed=self.make_legacy_seed(), 

2115 ) 

2116 

2117 def truncated_normal( 

2118 self, shape, mean=0.0, stddev=1.0, dtype=None, nonce=None 

2119 ): 

2120 """Produce random number based on the truncated normal distribution. 

2121 

2122 Args: 

2123 shape: The shape of the random values to generate. 

2124 mean: Floats, default to 0. Mean of the random values to generate. 

2125 stddev: Floats, default to 1. Standard deviation of the random values 

2126 to generate. 

2127 dtype: Optional dtype of the tensor. Only floating point types are 

2128 supported. If not specified, `tf.keras.backend.floatx()` is used, 

2129 which default to `float32` unless you configured it otherwise (via 

2130 `tf.keras.backend.set_floatx(float_dtype)`) 

2131 nonce: Optional integer scalar, that will be folded into the seed in 

2132 the stateless mode. 

2133 """ 

2134 self._maybe_init() 

2135 dtype = dtype or floatx() 

2136 if self._rng_type == self.RNG_STATEFUL: 

2137 return self._generator.truncated_normal( 

2138 shape=shape, mean=mean, stddev=stddev, dtype=dtype 

2139 ) 

2140 elif self._rng_type == self.RNG_STATELESS: 

2141 seed = self.make_seed_for_stateless_op() 

2142 if nonce: 

2143 seed = tf.random.experimental.stateless_fold_in(seed, nonce) 

2144 return tf.random.stateless_truncated_normal( 

2145 shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed 

2146 ) 

2147 return tf.random.truncated_normal( 

2148 shape=shape, 

2149 mean=mean, 

2150 stddev=stddev, 

2151 dtype=dtype, 

2152 seed=self.make_legacy_seed(), 

2153 ) 

2154 

2155 def dropout(self, inputs, rate, noise_shape=None): 

2156 self._maybe_init() 

2157 if self._rng_type == self.RNG_STATEFUL: 

2158 return tf.nn.experimental.general_dropout( 

2159 inputs, 

2160 rate=rate, 

2161 noise_shape=noise_shape, 

2162 uniform_sampler=self._generator.uniform, 

2163 ) 

2164 elif self._rng_type == self.RNG_STATELESS: 

2165 return tf.nn.experimental.stateless_dropout( 

2166 inputs, 

2167 rate=rate, 

2168 noise_shape=noise_shape, 

2169 seed=self.make_seed_for_stateless_op(), 

2170 ) 

2171 else: 

2172 return tf.nn.dropout( 

2173 inputs, 

2174 rate=rate, 

2175 noise_shape=noise_shape, 

2176 seed=self.make_legacy_seed(), 

2177 ) 

2178 

2179 

2180@keras_export("keras.backend.random_uniform_variable") 

2181@doc_controls.do_not_generate_docs 

2182def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): 

2183 """Instantiates a variable with values drawn from a uniform distribution. 

2184 

2185 Args: 

2186 shape: Tuple of integers, shape of returned Keras variable. 

2187 low: Float, lower boundary of the output interval. 

2188 high: Float, upper boundary of the output interval. 

2189 dtype: String, dtype of returned Keras variable. 

2190 name: String, name of returned Keras variable. 

2191 seed: Integer, random seed. 

2192 

2193 Returns: 

2194 A Keras variable, filled with drawn samples. 

2195 

2196 Example: 

2197 

2198 >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3), 

2199 ... low=0.0, high=1.0) 

2200 >>> kvar 

2201 <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=..., 

2202 dtype=float32)> 

2203 """ 

2204 if dtype is None: 

2205 dtype = floatx() 

2206 tf_dtype = tf.as_dtype(dtype) 

2207 if seed is None: 

2208 # ensure that randomness is conditioned by the Numpy RNG 

2209 seed = np.random.randint(10e8) 

2210 value = tf.compat.v1.random_uniform_initializer( 

2211 low, high, dtype=tf_dtype, seed=seed 

2212 )(shape) 

2213 return variable(value, dtype=dtype, name=name) 

2214 

2215 

2216@keras_export("keras.backend.random_normal_variable") 

2217@doc_controls.do_not_generate_docs 

2218def random_normal_variable( 

2219 shape, mean, scale, dtype=None, name=None, seed=None 

2220): 

2221 """Instantiates a variable with values drawn from a normal distribution. 

2222 

2223 Args: 

2224 shape: Tuple of integers, shape of returned Keras variable. 

2225 mean: Float, mean of the normal distribution. 

2226 scale: Float, standard deviation of the normal distribution. 

2227 dtype: String, dtype of returned Keras variable. 

2228 name: String, name of returned Keras variable. 

2229 seed: Integer, random seed. 

2230 

2231 Returns: 

2232 A Keras variable, filled with drawn samples. 

2233 

2234 Example: 

2235 

2236 >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3), 

2237 ... mean=0.0, scale=1.0) 

2238 >>> kvar 

2239 <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=..., 

2240 dtype=float32)> 

2241 """ 

2242 if dtype is None: 

2243 dtype = floatx() 

2244 tf_dtype = tf.as_dtype(dtype) 

2245 if seed is None: 

2246 # ensure that randomness is conditioned by the Numpy RNG 

2247 seed = np.random.randint(10e8) 

2248 value = tf.compat.v1.random_normal_initializer( 

2249 mean, scale, dtype=tf_dtype, seed=seed 

2250 )(shape) 

2251 return variable(value, dtype=dtype, name=name) 

2252 

2253 

2254@keras_export("keras.backend.count_params") 

2255@doc_controls.do_not_generate_docs 

2256def count_params(x): 

2257 """Returns the static number of elements in a variable or tensor. 

2258 

2259 Args: 

2260 x: Variable or tensor. 

2261 

2262 Returns: 

2263 Integer, the number of scalars in `x`. 

2264 

2265 Example: 

2266 

2267 >>> kvar = tf.keras.backend.zeros((2,3)) 

2268 >>> tf.keras.backend.count_params(kvar) 

2269 6 

2270 >>> tf.keras.backend.eval(kvar) 

2271 array([[0., 0., 0.], 

2272 [0., 0., 0.]], dtype=float32) 

2273 

2274 """ 

2275 return np.prod(x.shape.as_list()) 

2276 

2277 

2278@keras_export("keras.backend.cast") 

2279@tf.__internal__.dispatch.add_dispatch_support 

2280@doc_controls.do_not_generate_docs 

2281def cast(x, dtype): 

2282 """Casts a tensor to a different dtype and returns it. 

2283 

2284 You can cast a Keras variable but it still returns a Keras tensor. 

2285 

2286 Args: 

2287 x: Keras tensor (or variable). 

2288 dtype: String, either (`'float16'`, `'float32'`, or `'float64'`). 

2289 

2290 Returns: 

2291 Keras tensor with dtype `dtype`. 

2292 

2293 Examples: 

2294 Cast a float32 variable to a float64 tensor 

2295 

2296 >>> input = tf.keras.backend.ones(shape=(1,3)) 

2297 >>> print(input) 

2298 <tf.Variable 'Variable:0' shape=(1, 3) dtype=float32, 

2299 numpy=array([[1., 1., 1.]], dtype=float32)> 

2300 >>> cast_input = tf.keras.backend.cast(input, dtype='float64') 

2301 >>> print(cast_input) 

2302 tf.Tensor([[1. 1. 1.]], shape=(1, 3), dtype=float64) 

2303 

2304 """ 

2305 return tf.cast(x, dtype) 

2306 

2307 

2308# UPDATES OPS 

2309 

2310 

2311@keras_export("keras.backend.update") 

2312@doc_controls.do_not_generate_docs 

2313def update(x, new_x): 

2314 return tf.compat.v1.assign(x, new_x) 

2315 

2316 

2317@keras_export("keras.backend.update_add") 

2318@doc_controls.do_not_generate_docs 

2319def update_add(x, increment): 

2320 """Update the value of `x` by adding `increment`. 

2321 

2322 Args: 

2323 x: A Variable. 

2324 increment: A tensor of same shape as `x`. 

2325 

2326 Returns: 

2327 The variable `x` updated. 

2328 """ 

2329 return tf.compat.v1.assign_add(x, increment) 

2330 

2331 

2332@keras_export("keras.backend.update_sub") 

2333@doc_controls.do_not_generate_docs 

2334def update_sub(x, decrement): 

2335 """Update the value of `x` by subtracting `decrement`. 

2336 

2337 Args: 

2338 x: A Variable. 

2339 decrement: A tensor of same shape as `x`. 

2340 

2341 Returns: 

2342 The variable `x` updated. 

2343 """ 

2344 return tf.compat.v1.assign_sub(x, decrement) 

2345 

2346 

2347@keras_export("keras.backend.moving_average_update") 

2348@doc_controls.do_not_generate_docs 

2349def moving_average_update(x, value, momentum): 

2350 """Compute the exponential moving average of a value. 

2351 

2352 The moving average 'x' is updated with 'value' following: 

2353 

2354 ``` 

2355 x = x * momentum + value * (1 - momentum) 

2356 ``` 

2357 

2358 For example: 

2359 

2360 >>> x = tf.Variable(0.0) 

2361 >>> momentum=0.9 

2362 >>> moving_average_update(x, value = 2.0, momentum=momentum).numpy() 

2363 >>> x.numpy() 

2364 0.2 

2365 

2366 The result will be biased towards the initial value of the variable. 

2367 

2368 If the variable was initialized to zero, you can divide by 

2369 `1 - momentum ** num_updates` to debias it (Section 3 of 

2370 [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)): 

2371 

2372 >>> num_updates = 1.0 

2373 >>> x_zdb = x/(1 - momentum**num_updates) 

2374 >>> x_zdb.numpy() 

2375 2.0 

2376 

2377 Args: 

2378 x: A Variable, the moving average. 

2379 value: A tensor with the same shape as `x`, the new value to be 

2380 averaged in. 

2381 momentum: The moving average momentum. 

2382 

2383 Returns: 

2384 The updated variable. 

2385 """ 

2386 if tf.__internal__.tf2.enabled(): 

2387 momentum = tf.cast(momentum, x.dtype) 

2388 value = tf.cast(value, x.dtype) 

2389 return x.assign_sub((x - value) * (1 - momentum)) 

2390 else: 

2391 return tf.__internal__.train.assign_moving_average( 

2392 x, value, momentum, zero_debias=True 

2393 ) 

2394 

2395 

2396# LINEAR ALGEBRA 

2397 

2398 

2399@keras_export("keras.backend.dot") 

2400@tf.__internal__.dispatch.add_dispatch_support 

2401@doc_controls.do_not_generate_docs 

2402def dot(x, y): 

2403 """Multiplies 2 tensors (and/or variables) and returns a tensor. 

2404 

2405 This operation corresponds to `numpy.dot(a, b, out=None)`. 

2406 

2407 Args: 

2408 x: Tensor or variable. 

2409 y: Tensor or variable. 

2410 

2411 Returns: 

2412 A tensor, dot product of `x` and `y`. 

2413 

2414 Examples: 

2415 

2416 If inputs `x` and `y` are 2-D arrays, then it is equivalent to `tf.matmul`. 

2417 >>> x = tf.keras.backend.placeholder(shape=(2, 3)) 

2418 >>> y = tf.keras.backend.placeholder(shape=(3, 4)) 

2419 >>> xy = tf.keras.backend.dot(x, y) 

2420 >>> xy 

2421 <KerasTensor: shape=(2, 4) dtype=float32 ...> 

2422 

2423 >>> x = tf.keras.backend.placeholder(shape=(32, 28, 3)) 

2424 >>> y = tf.keras.backend.placeholder(shape=(3, 4)) 

2425 >>> xy = tf.keras.backend.dot(x, y) 

2426 >>> xy 

2427 <KerasTensor: shape=(32, 28, 4) dtype=float32 ...> 

2428 

2429 If `x` is an N-D array and `y` is an M-D array (where M>=2), it is a sum 

2430 product over the last axis of `x` and the second-to-last axis of `y`. 

2431 >>> x = tf.keras.backend.random_uniform_variable( 

2432 ... shape=(2, 3), low=0., high=1.) 

2433 >>> y = tf.keras.backend.ones((4, 3, 5)) 

2434 >>> xy = tf.keras.backend.dot(x, y) 

2435 >>> tf.keras.backend.int_shape(xy) 

2436 (2, 4, 5) 

2437 """ 

2438 if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2): 

2439 x_shape = [] 

2440 for i, s in zip(int_shape(x), tf.unstack(tf.shape(x))): 

2441 if i is not None: 

2442 x_shape.append(i) 

2443 else: 

2444 x_shape.append(s) 

2445 x_shape = tuple(x_shape) 

2446 y_shape = [] 

2447 for i, s in zip(int_shape(y), tf.unstack(tf.shape(y))): 

2448 if i is not None: 

2449 y_shape.append(i) 

2450 else: 

2451 y_shape.append(s) 

2452 y_shape = tuple(y_shape) 

2453 y_permute_dim = list(range(ndim(y))) 

2454 y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim 

2455 xt = tf.reshape(x, [-1, x_shape[-1]]) 

2456 yt = tf.reshape( 

2457 tf.compat.v1.transpose(y, perm=y_permute_dim), [y_shape[-2], -1] 

2458 ) 

2459 return tf.reshape( 

2460 tf.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:] 

2461 ) 

2462 if is_sparse(x): 

2463 out = tf.sparse.sparse_dense_matmul(x, y) 

2464 else: 

2465 out = tf.matmul(x, y) 

2466 return out 

2467 

2468 

2469@keras_export("keras.backend.batch_dot") 

2470@tf.__internal__.dispatch.add_dispatch_support 

2471@doc_controls.do_not_generate_docs 

2472def batch_dot(x, y, axes=None): 

2473 """Batchwise dot product. 

2474 

2475 `batch_dot` is used to compute dot product of `x` and `y` when 

2476 `x` and `y` are data in batch, i.e. in a shape of 

2477 `(batch_size, :)`. 

2478 `batch_dot` results in a tensor or variable with less dimensions 

2479 than the input. If the number of dimensions is reduced to 1, 

2480 we use `expand_dims` to make sure that ndim is at least 2. 

2481 

2482 Args: 

2483 x: Keras tensor or variable with `ndim >= 2`. 

2484 y: Keras tensor or variable with `ndim >= 2`. 

2485 axes: Tuple or list of integers with target dimensions, or single integer. 

2486 The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal. 

2487 

2488 Returns: 

2489 A tensor with shape equal to the concatenation of `x`'s shape 

2490 (less the dimension that was summed over) and `y`'s shape 

2491 (less the batch dimension and the dimension that was summed over). 

2492 If the final rank is 1, we reshape it to `(batch_size, 1)`. 

2493 

2494 Examples: 

2495 

2496 >>> x_batch = tf.keras.backend.ones(shape=(32, 20, 1)) 

2497 >>> y_batch = tf.keras.backend.ones(shape=(32, 30, 20)) 

2498 >>> xy_batch_dot = tf.keras.backend.batch_dot(x_batch, y_batch, axes=(1, 2)) 

2499 >>> tf.keras.backend.int_shape(xy_batch_dot) 

2500 (32, 1, 30) 

2501 

2502 Shape inference: 

2503 Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`. 

2504 If `axes` is (1, 2), to find the output shape of resultant tensor, 

2505 loop through each dimension in `x`'s shape and `y`'s shape: 

2506 * `x.shape[0]` : 100 : append to output shape 

2507 * `x.shape[1]` : 20 : do not append to output shape, 

2508 dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1) 

2509 * `y.shape[0]` : 100 : do not append to output shape, 

2510 always ignore first dimension of `y` 

2511 * `y.shape[1]` : 30 : append to output shape 

2512 * `y.shape[2]` : 20 : do not append to output shape, 

2513 dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2) 

2514 `output_shape` = `(100, 30)` 

2515 """ 

2516 x_shape = int_shape(x) 

2517 y_shape = int_shape(y) 

2518 

2519 x_ndim = len(x_shape) 

2520 y_ndim = len(y_shape) 

2521 

2522 if x_ndim < 2 or y_ndim < 2: 

2523 raise ValueError( 

2524 "Cannot do batch_dot on inputs " 

2525 "with rank < 2. " 

2526 "Received inputs with shapes " 

2527 + str(x_shape) 

2528 + " and " 

2529 + str(y_shape) 

2530 + "." 

2531 ) 

2532 

2533 x_batch_size = x_shape[0] 

2534 y_batch_size = y_shape[0] 

2535 

2536 if x_batch_size is not None and y_batch_size is not None: 

2537 if x_batch_size != y_batch_size: 

2538 raise ValueError( 

2539 "Cannot do batch_dot on inputs " 

2540 "with different batch sizes. " 

2541 "Received inputs with shapes " 

2542 + str(x_shape) 

2543 + " and " 

2544 + str(y_shape) 

2545 + "." 

2546 ) 

2547 if isinstance(axes, int): 

2548 axes = [axes, axes] 

2549 

2550 if axes is None: 

2551 if y_ndim == 2: 

2552 axes = [x_ndim - 1, y_ndim - 1] 

2553 else: 

2554 axes = [x_ndim - 1, y_ndim - 2] 

2555 

2556 if py_any(isinstance(a, (list, tuple)) for a in axes): 

2557 raise ValueError( 

2558 "Multiple target dimensions are not supported. " 

2559 + "Expected: None, int, (int, int), " 

2560 + "Provided: " 

2561 + str(axes) 

2562 ) 

2563 

2564 # if tuple, convert to list. 

2565 axes = list(axes) 

2566 

2567 # convert negative indices. 

2568 if axes[0] < 0: 

2569 axes[0] += x_ndim 

2570 if axes[1] < 0: 

2571 axes[1] += y_ndim 

2572 

2573 # sanity checks 

2574 if 0 in axes: 

2575 raise ValueError( 

2576 "Cannot perform batch_dot over axis 0. " 

2577 "If your inputs are not batched, " 

2578 "add a dummy batch dimension to your " 

2579 "inputs using K.expand_dims(x, 0)" 

2580 ) 

2581 a0, a1 = axes 

2582 d1 = x_shape[a0] 

2583 d2 = y_shape[a1] 

2584 

2585 if d1 is not None and d2 is not None and d1 != d2: 

2586 raise ValueError( 

2587 "Cannot do batch_dot on inputs with shapes " 

2588 + str(x_shape) 

2589 + " and " 

2590 + str(y_shape) 

2591 + " with axes=" 

2592 + str(axes) 

2593 + ". x.shape[%d] != y.shape[%d] (%d != %d)." 

2594 % (axes[0], axes[1], d1, d2) 

2595 ) 

2596 

2597 # backup ndims. Need them later. 

2598 orig_x_ndim = x_ndim 

2599 orig_y_ndim = y_ndim 

2600 

2601 # if rank is 2, expand to 3. 

2602 if x_ndim == 2: 

2603 x = tf.expand_dims(x, 1) 

2604 a0 += 1 

2605 x_ndim += 1 

2606 if y_ndim == 2: 

2607 y = tf.expand_dims(y, 2) 

2608 y_ndim += 1 

2609 

2610 # bring x's dimension to be reduced to last axis. 

2611 if a0 != x_ndim - 1: 

2612 pattern = list(range(x_ndim)) 

2613 for i in range(a0, x_ndim - 1): 

2614 pattern[i] = pattern[i + 1] 

2615 pattern[-1] = a0 

2616 x = tf.compat.v1.transpose(x, pattern) 

2617 

2618 # bring y's dimension to be reduced to axis 1. 

2619 if a1 != 1: 

2620 pattern = list(range(y_ndim)) 

2621 for i in range(a1, 1, -1): 

2622 pattern[i] = pattern[i - 1] 

2623 pattern[1] = a1 

2624 y = tf.compat.v1.transpose(y, pattern) 

2625 

2626 # normalize both inputs to rank 3. 

2627 if x_ndim > 3: 

2628 # squash middle dimensions of x. 

2629 x_shape = shape(x) 

2630 x_mid_dims = x_shape[1:-1] 

2631 x_squashed_shape = tf.stack([x_shape[0], -1, x_shape[-1]]) 

2632 x = tf.reshape(x, x_squashed_shape) 

2633 x_squashed = True 

2634 else: 

2635 x_squashed = False 

2636 

2637 if y_ndim > 3: 

2638 # squash trailing dimensions of y. 

2639 y_shape = shape(y) 

2640 y_trail_dims = y_shape[2:] 

2641 y_squashed_shape = tf.stack([y_shape[0], y_shape[1], -1]) 

2642 y = tf.reshape(y, y_squashed_shape) 

2643 y_squashed = True 

2644 else: 

2645 y_squashed = False 

2646 

2647 result = tf.matmul(x, y) 

2648 

2649 # if inputs were squashed, we have to reshape the matmul output. 

2650 output_shape = tf.shape(result) 

2651 do_reshape = False 

2652 

2653 if x_squashed: 

2654 output_shape = tf.concat( 

2655 [output_shape[:1], x_mid_dims, output_shape[-1:]], 0 

2656 ) 

2657 do_reshape = True 

2658 

2659 if y_squashed: 

2660 output_shape = tf.concat([output_shape[:-1], y_trail_dims], 0) 

2661 do_reshape = True 

2662 

2663 if do_reshape: 

2664 result = tf.reshape(result, output_shape) 

2665 

2666 # if the inputs were originally rank 2, we remove the added 1 dim. 

2667 if orig_x_ndim == 2: 

2668 result = tf.squeeze(result, 1) 

2669 elif orig_y_ndim == 2: 

2670 result = tf.squeeze(result, -1) 

2671 

2672 return result 

2673 

2674 

2675@keras_export("keras.backend.transpose") 

2676@tf.__internal__.dispatch.add_dispatch_support 

2677@doc_controls.do_not_generate_docs 

2678def transpose(x): 

2679 """Transposes a tensor and returns it. 

2680 

2681 Args: 

2682 x: Tensor or variable. 

2683 

2684 Returns: 

2685 A tensor. 

2686 

2687 Examples: 

2688 

2689 >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]]) 

2690 >>> tf.keras.backend.eval(var) 

2691 array([[1., 2., 3.], 

2692 [4., 5., 6.]], dtype=float32) 

2693 >>> var_transposed = tf.keras.backend.transpose(var) 

2694 >>> tf.keras.backend.eval(var_transposed) 

2695 array([[1., 4.], 

2696 [2., 5.], 

2697 [3., 6.]], dtype=float32) 

2698 >>> input = tf.keras.backend.placeholder((2, 3)) 

2699 >>> input 

2700 <KerasTensor: shape=(2, 3) dtype=float32 ...> 

2701 >>> input_transposed = tf.keras.backend.transpose(input) 

2702 >>> input_transposed 

2703 <KerasTensor: shape=(3, 2) dtype=float32 ...> 

2704 """ 

2705 return tf.compat.v1.transpose(x) 

2706 

2707 

2708@keras_export("keras.backend.gather") 

2709@tf.__internal__.dispatch.add_dispatch_support 

2710@doc_controls.do_not_generate_docs 

2711def gather(reference, indices): 

2712 """Retrieves the elements of indices `indices` in the tensor `reference`. 

2713 

2714 Args: 

2715 reference: A tensor. 

2716 indices: An integer tensor of indices. 

2717 

2718 Returns: 

2719 A tensor of same type as `reference`. 

2720 

2721 Examples: 

2722 

2723 >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]]) 

2724 >>> tf.keras.backend.eval(var) 

2725 array([[1., 2., 3.], 

2726 [4., 5., 6.]], dtype=float32) 

2727 >>> var_gathered = tf.keras.backend.gather(var, [0]) 

2728 >>> tf.keras.backend.eval(var_gathered) 

2729 array([[1., 2., 3.]], dtype=float32) 

2730 >>> var_gathered = tf.keras.backend.gather(var, [1]) 

2731 >>> tf.keras.backend.eval(var_gathered) 

2732 array([[4., 5., 6.]], dtype=float32) 

2733 >>> var_gathered = tf.keras.backend.gather(var, [0,1,0]) 

2734 >>> tf.keras.backend.eval(var_gathered) 

2735 array([[1., 2., 3.], 

2736 [4., 5., 6.], 

2737 [1., 2., 3.]], dtype=float32) 

2738 """ 

2739 return tf.compat.v1.gather(reference, indices) 

2740 

2741 

2742# ELEMENT-WISE OPERATIONS 

2743 

2744 

2745@keras_export("keras.backend.max") 

2746@tf.__internal__.dispatch.add_dispatch_support 

2747@doc_controls.do_not_generate_docs 

2748def max(x, axis=None, keepdims=False): 

2749 """Maximum value in a tensor. 

2750 

2751 Args: 

2752 x: A tensor or variable. 

2753 axis: An integer, the axis to find maximum values. 

2754 keepdims: A boolean, whether to keep the dimensions or not. 

2755 If `keepdims` is `False`, the rank of the tensor is reduced 

2756 by 1. If `keepdims` is `True`, 

2757 the reduced dimension is retained with length 1. 

2758 

2759 Returns: 

2760 A tensor with maximum values of `x`. 

2761 """ 

2762 return tf.reduce_max(x, axis, keepdims) 

2763 

2764 

2765@keras_export("keras.backend.min") 

2766@tf.__internal__.dispatch.add_dispatch_support 

2767@doc_controls.do_not_generate_docs 

2768def min(x, axis=None, keepdims=False): 

2769 """Minimum value in a tensor. 

2770 

2771 Args: 

2772 x: A tensor or variable. 

2773 axis: An integer, the axis to find minimum values. 

2774 keepdims: A boolean, whether to keep the dimensions or not. 

2775 If `keepdims` is `False`, the rank of the tensor is reduced 

2776 by 1. If `keepdims` is `True`, 

2777 the reduced dimension is retained with length 1. 

2778 

2779 Returns: 

2780 A tensor with minimum values of `x`. 

2781 """ 

2782 return tf.reduce_min(x, axis, keepdims) 

2783 

2784 

2785@keras_export("keras.backend.sum") 

2786@tf.__internal__.dispatch.add_dispatch_support 

2787@doc_controls.do_not_generate_docs 

2788def sum(x, axis=None, keepdims=False): 

2789 """Sum of the values in a tensor, alongside the specified axis. 

2790 

2791 Args: 

2792 x: A tensor or variable. 

2793 axis: An integer, the axis to sum over. 

2794 keepdims: A boolean, whether to keep the dimensions or not. 

2795 If `keepdims` is `False`, the rank of the tensor is reduced 

2796 by 1. If `keepdims` is `True`, 

2797 the reduced dimension is retained with length 1. 

2798 

2799 Returns: 

2800 A tensor with sum of `x`. 

2801 """ 

2802 return tf.reduce_sum(x, axis, keepdims) 

2803 

2804 

2805@keras_export("keras.backend.prod") 

2806@tf.__internal__.dispatch.add_dispatch_support 

2807@doc_controls.do_not_generate_docs 

2808def prod(x, axis=None, keepdims=False): 

2809 """Multiplies the values in a tensor, alongside the specified axis. 

2810 

2811 Args: 

2812 x: A tensor or variable. 

2813 axis: An integer, the axis to compute the product. 

2814 keepdims: A boolean, whether to keep the dimensions or not. 

2815 If `keepdims` is `False`, the rank of the tensor is reduced 

2816 by 1. If `keepdims` is `True`, 

2817 the reduced dimension is retained with length 1. 

2818 

2819 Returns: 

2820 A tensor with the product of elements of `x`. 

2821 """ 

2822 return tf.reduce_prod(x, axis, keepdims) 

2823 

2824 

2825@keras_export("keras.backend.cumsum") 

2826@tf.__internal__.dispatch.add_dispatch_support 

2827@doc_controls.do_not_generate_docs 

2828def cumsum(x, axis=0): 

2829 """Cumulative sum of the values in a tensor, alongside the specified axis. 

2830 

2831 Args: 

2832 x: A tensor or variable. 

2833 axis: An integer, the axis to compute the sum. 

2834 

2835 Returns: 

2836 A tensor of the cumulative sum of values of `x` along `axis`. 

2837 """ 

2838 return tf.cumsum(x, axis=axis) 

2839 

2840 

2841@keras_export("keras.backend.cumprod") 

2842@tf.__internal__.dispatch.add_dispatch_support 

2843@doc_controls.do_not_generate_docs 

2844def cumprod(x, axis=0): 

2845 """Cumulative product of the values in a tensor alongside `axis`. 

2846 

2847 Args: 

2848 x: A tensor or variable. 

2849 axis: An integer, the axis to compute the product. 

2850 

2851 Returns: 

2852 A tensor of the cumulative product of values of `x` along `axis`. 

2853 """ 

2854 return tf.math.cumprod(x, axis=axis) 

2855 

2856 

2857@keras_export("keras.backend.var") 

2858@doc_controls.do_not_generate_docs 

2859def var(x, axis=None, keepdims=False): 

2860 """Variance of a tensor, alongside the specified axis. 

2861 

2862 Args: 

2863 x: A tensor or variable. 

2864 axis: An integer, the axis to compute the variance. 

2865 keepdims: A boolean, whether to keep the dimensions or not. 

2866 If `keepdims` is `False`, the rank of the tensor is reduced 

2867 by 1. If `keepdims` is `True`, 

2868 the reduced dimension is retained with length 1. 

2869 

2870 Returns: 

2871 A tensor with the variance of elements of `x`. 

2872 """ 

2873 if x.dtype.base_dtype == tf.bool: 

2874 x = tf.cast(x, floatx()) 

2875 return tf.math.reduce_variance(x, axis=axis, keepdims=keepdims) 

2876 

2877 

2878@keras_export("keras.backend.std") 

2879@tf.__internal__.dispatch.add_dispatch_support 

2880@doc_controls.do_not_generate_docs 

2881def std(x, axis=None, keepdims=False): 

2882 """Standard deviation of a tensor, alongside the specified axis. 

2883 

2884 It is an alias to `tf.math.reduce_std`. 

2885 

2886 Args: 

2887 x: A tensor or variable. It should have numerical dtypes. Boolean type 

2888 inputs will be converted to float. 

2889 axis: An integer, the axis to compute the standard deviation. If `None` 

2890 (the default), reduces all dimensions. Must be in the range 

2891 `[-rank(x), rank(x))`. 

2892 keepdims: A boolean, whether to keep the dimensions or not. 

2893 If `keepdims` is `False`, the rank of the tensor is reduced 

2894 by 1. If `keepdims` is `True`, the reduced dimension is retained 

2895 with length 1. 

2896 

2897 Returns: 

2898 A tensor with the standard deviation of elements of `x` with same dtype. 

2899 Boolean type input will be converted to float. 

2900 """ 

2901 if x.dtype.base_dtype == tf.bool: 

2902 x = tf.cast(x, floatx()) 

2903 return tf.math.reduce_std(x, axis=axis, keepdims=keepdims) 

2904 

2905 

2906@keras_export("keras.backend.mean") 

2907@tf.__internal__.dispatch.add_dispatch_support 

2908@doc_controls.do_not_generate_docs 

2909def mean(x, axis=None, keepdims=False): 

2910 """Mean of a tensor, alongside the specified axis. 

2911 

2912 Args: 

2913 x: A tensor or variable. 

2914 axis: A list of integer. Axes to compute the mean. 

2915 keepdims: A boolean, whether to keep the dimensions or not. 

2916 If `keepdims` is `False`, the rank of the tensor is reduced 

2917 by 1 for each entry in `axis`. If `keepdims` is `True`, 

2918 the reduced dimensions are retained with length 1. 

2919 

2920 Returns: 

2921 A tensor with the mean of elements of `x`. 

2922 """ 

2923 if x.dtype.base_dtype == tf.bool: 

2924 x = tf.cast(x, floatx()) 

2925 return tf.reduce_mean(x, axis, keepdims) 

2926 

2927 

2928@keras_export("keras.backend.any") 

2929@tf.__internal__.dispatch.add_dispatch_support 

2930@doc_controls.do_not_generate_docs 

2931def any(x, axis=None, keepdims=False): 

2932 """Bitwise reduction (logical OR). 

2933 

2934 Args: 

2935 x: Tensor or variable. 

2936 axis: axis along which to perform the reduction. 

2937 keepdims: whether the drop or broadcast the reduction axes. 

2938 

2939 Returns: 

2940 A uint8 tensor (0s and 1s). 

2941 """ 

2942 x = tf.cast(x, tf.bool) 

2943 return tf.reduce_any(x, axis, keepdims) 

2944 

2945 

2946@keras_export("keras.backend.all") 

2947@tf.__internal__.dispatch.add_dispatch_support 

2948@doc_controls.do_not_generate_docs 

2949def all(x, axis=None, keepdims=False): 

2950 """Bitwise reduction (logical AND). 

2951 

2952 Args: 

2953 x: Tensor or variable. 

2954 axis: axis along which to perform the reduction. 

2955 keepdims: whether the drop or broadcast the reduction axes. 

2956 

2957 Returns: 

2958 A uint8 tensor (0s and 1s). 

2959 """ 

2960 x = tf.cast(x, tf.bool) 

2961 return tf.reduce_all(x, axis, keepdims) 

2962 

2963 

2964@keras_export("keras.backend.argmax") 

2965@tf.__internal__.dispatch.add_dispatch_support 

2966@doc_controls.do_not_generate_docs 

2967def argmax(x, axis=-1): 

2968 """Returns the index of the maximum value along an axis. 

2969 

2970 Args: 

2971 x: Tensor or variable. 

2972 axis: axis along which to perform the reduction. 

2973 

2974 Returns: 

2975 A tensor. 

2976 """ 

2977 return tf.argmax(x, axis) 

2978 

2979 

2980@keras_export("keras.backend.argmin") 

2981@tf.__internal__.dispatch.add_dispatch_support 

2982@doc_controls.do_not_generate_docs 

2983def argmin(x, axis=-1): 

2984 """Returns the index of the minimum value along an axis. 

2985 

2986 Args: 

2987 x: Tensor or variable. 

2988 axis: axis along which to perform the reduction. 

2989 

2990 Returns: 

2991 A tensor. 

2992 """ 

2993 return tf.argmin(x, axis) 

2994 

2995 

2996@keras_export("keras.backend.square") 

2997@tf.__internal__.dispatch.add_dispatch_support 

2998@doc_controls.do_not_generate_docs 

2999def square(x): 

3000 """Element-wise square. 

3001 

3002 Args: 

3003 x: Tensor or variable. 

3004 

3005 Returns: 

3006 A tensor. 

3007 """ 

3008 return tf.square(x) 

3009 

3010 

3011@keras_export("keras.backend.abs") 

3012@tf.__internal__.dispatch.add_dispatch_support 

3013@doc_controls.do_not_generate_docs 

3014def abs(x): 

3015 """Element-wise absolute value. 

3016 

3017 Args: 

3018 x: Tensor or variable. 

3019 

3020 Returns: 

3021 A tensor. 

3022 """ 

3023 return tf.abs(x) 

3024 

3025 

3026@keras_export("keras.backend.sqrt") 

3027@tf.__internal__.dispatch.add_dispatch_support 

3028@doc_controls.do_not_generate_docs 

3029def sqrt(x): 

3030 """Element-wise square root. 

3031 

3032 This function clips negative tensor values to 0 before computing the 

3033 square root. 

3034 

3035 Args: 

3036 x: Tensor or variable. 

3037 

3038 Returns: 

3039 A tensor. 

3040 """ 

3041 zero = _constant_to_tensor(0.0, x.dtype.base_dtype) 

3042 x = tf.maximum(x, zero) 

3043 return tf.sqrt(x) 

3044 

3045 

3046@keras_export("keras.backend.exp") 

3047@tf.__internal__.dispatch.add_dispatch_support 

3048@doc_controls.do_not_generate_docs 

3049def exp(x): 

3050 """Element-wise exponential. 

3051 

3052 Args: 

3053 x: Tensor or variable. 

3054 

3055 Returns: 

3056 A tensor. 

3057 """ 

3058 return tf.exp(x) 

3059 

3060 

3061@keras_export("keras.backend.log") 

3062@tf.__internal__.dispatch.add_dispatch_support 

3063@doc_controls.do_not_generate_docs 

3064def log(x): 

3065 """Element-wise log. 

3066 

3067 Args: 

3068 x: Tensor or variable. 

3069 

3070 Returns: 

3071 A tensor. 

3072 """ 

3073 return tf.math.log(x) 

3074 

3075 

3076def logsumexp(x, axis=None, keepdims=False): 

3077 """Computes log(sum(exp(elements across dimensions of a tensor))). 

3078 

3079 This function is more numerically stable than log(sum(exp(x))). 

3080 It avoids overflows caused by taking the exp of large inputs and 

3081 underflows caused by taking the log of small inputs. 

3082 

3083 Args: 

3084 x: A tensor or variable. 

3085 axis: An integer, the axis to reduce over. 

3086 keepdims: A boolean, whether to keep the dimensions or not. 

3087 If `keepdims` is `False`, the rank of the tensor is reduced 

3088 by 1. If `keepdims` is `True`, the reduced dimension is 

3089 retained with length 1. 

3090 

3091 Returns: 

3092 The reduced tensor. 

3093 """ 

3094 return tf.reduce_logsumexp(x, axis, keepdims) 

3095 

3096 

3097@keras_export("keras.backend.round") 

3098@tf.__internal__.dispatch.add_dispatch_support 

3099@doc_controls.do_not_generate_docs 

3100def round(x): 

3101 """Element-wise rounding to the closest integer. 

3102 

3103 In case of tie, the rounding mode used is "half to even". 

3104 

3105 Args: 

3106 x: Tensor or variable. 

3107 

3108 Returns: 

3109 A tensor. 

3110 """ 

3111 return tf.round(x) 

3112 

3113 

3114@keras_export("keras.backend.sign") 

3115@tf.__internal__.dispatch.add_dispatch_support 

3116@doc_controls.do_not_generate_docs 

3117def sign(x): 

3118 """Element-wise sign. 

3119 

3120 Args: 

3121 x: Tensor or variable. 

3122 

3123 Returns: 

3124 A tensor. 

3125 """ 

3126 return tf.sign(x) 

3127 

3128 

3129@keras_export("keras.backend.pow") 

3130@tf.__internal__.dispatch.add_dispatch_support 

3131@doc_controls.do_not_generate_docs 

3132def pow(x, a): 

3133 """Element-wise exponentiation. 

3134 

3135 Args: 

3136 x: Tensor or variable. 

3137 a: Python integer. 

3138 

3139 Returns: 

3140 A tensor. 

3141 """ 

3142 return tf.pow(x, a) 

3143 

3144 

3145@keras_export("keras.backend.clip") 

3146@tf.__internal__.dispatch.add_dispatch_support 

3147@doc_controls.do_not_generate_docs 

3148def clip(x, min_value, max_value): 

3149 """Element-wise value clipping. 

3150 

3151 Args: 

3152 x: Tensor or variable. 

3153 min_value: Python float, integer, or tensor. 

3154 max_value: Python float, integer, or tensor. 

3155 

3156 Returns: 

3157 A tensor. 

3158 """ 

3159 if isinstance(min_value, (int, float)) and isinstance( 

3160 max_value, (int, float) 

3161 ): 

3162 if max_value < min_value: 

3163 max_value = min_value 

3164 if min_value is None: 

3165 min_value = -np.inf 

3166 if max_value is None: 

3167 max_value = np.inf 

3168 return tf.clip_by_value(x, min_value, max_value) 

3169 

3170 

3171@keras_export("keras.backend.equal") 

3172@tf.__internal__.dispatch.add_dispatch_support 

3173@doc_controls.do_not_generate_docs 

3174def equal(x, y): 

3175 """Element-wise equality between two tensors. 

3176 

3177 Args: 

3178 x: Tensor or variable. 

3179 y: Tensor or variable. 

3180 

3181 Returns: 

3182 A bool tensor. 

3183 """ 

3184 return tf.equal(x, y) 

3185 

3186 

3187@keras_export("keras.backend.not_equal") 

3188@tf.__internal__.dispatch.add_dispatch_support 

3189@doc_controls.do_not_generate_docs 

3190def not_equal(x, y): 

3191 """Element-wise inequality between two tensors. 

3192 

3193 Args: 

3194 x: Tensor or variable. 

3195 y: Tensor or variable. 

3196 

3197 Returns: 

3198 A bool tensor. 

3199 """ 

3200 return tf.not_equal(x, y) 

3201 

3202 

3203@keras_export("keras.backend.greater") 

3204@tf.__internal__.dispatch.add_dispatch_support 

3205@doc_controls.do_not_generate_docs 

3206def greater(x, y): 

3207 """Element-wise truth value of (x > y). 

3208 

3209 Args: 

3210 x: Tensor or variable. 

3211 y: Tensor or variable. 

3212 

3213 Returns: 

3214 A bool tensor. 

3215 """ 

3216 return tf.greater(x, y) 

3217 

3218 

3219@keras_export("keras.backend.greater_equal") 

3220@tf.__internal__.dispatch.add_dispatch_support 

3221@doc_controls.do_not_generate_docs 

3222def greater_equal(x, y): 

3223 """Element-wise truth value of (x >= y). 

3224 

3225 Args: 

3226 x: Tensor or variable. 

3227 y: Tensor or variable. 

3228 

3229 Returns: 

3230 A bool tensor. 

3231 """ 

3232 return tf.greater_equal(x, y) 

3233 

3234 

3235@keras_export("keras.backend.less") 

3236@tf.__internal__.dispatch.add_dispatch_support 

3237@doc_controls.do_not_generate_docs 

3238def less(x, y): 

3239 """Element-wise truth value of (x < y). 

3240 

3241 Args: 

3242 x: Tensor or variable. 

3243 y: Tensor or variable. 

3244 

3245 Returns: 

3246 A bool tensor. 

3247 """ 

3248 return tf.less(x, y) 

3249 

3250 

3251@keras_export("keras.backend.less_equal") 

3252@tf.__internal__.dispatch.add_dispatch_support 

3253@doc_controls.do_not_generate_docs 

3254def less_equal(x, y): 

3255 """Element-wise truth value of (x <= y). 

3256 

3257 Args: 

3258 x: Tensor or variable. 

3259 y: Tensor or variable. 

3260 

3261 Returns: 

3262 A bool tensor. 

3263 """ 

3264 return tf.less_equal(x, y) 

3265 

3266 

3267@keras_export("keras.backend.maximum") 

3268@tf.__internal__.dispatch.add_dispatch_support 

3269@doc_controls.do_not_generate_docs 

3270def maximum(x, y): 

3271 """Element-wise maximum of two tensors. 

3272 

3273 Args: 

3274 x: Tensor or variable. 

3275 y: Tensor or variable. 

3276 

3277 Returns: 

3278 A tensor with the element wise maximum value(s) of `x` and `y`. 

3279 

3280 Examples: 

3281 

3282 >>> x = tf.Variable([[1, 2], [3, 4]]) 

3283 >>> y = tf.Variable([[2, 1], [0, -1]]) 

3284 >>> m = tf.keras.backend.maximum(x, y) 

3285 >>> m 

3286 <tf.Tensor: shape=(2, 2), dtype=int32, numpy= 

3287 array([[2, 2], 

3288 [3, 4]], dtype=int32)> 

3289 """ 

3290 return tf.maximum(x, y) 

3291 

3292 

3293@keras_export("keras.backend.minimum") 

3294@tf.__internal__.dispatch.add_dispatch_support 

3295@doc_controls.do_not_generate_docs 

3296def minimum(x, y): 

3297 """Element-wise minimum of two tensors. 

3298 

3299 Args: 

3300 x: Tensor or variable. 

3301 y: Tensor or variable. 

3302 

3303 Returns: 

3304 A tensor. 

3305 """ 

3306 return tf.minimum(x, y) 

3307 

3308 

3309@keras_export("keras.backend.sin") 

3310@tf.__internal__.dispatch.add_dispatch_support 

3311@doc_controls.do_not_generate_docs 

3312def sin(x): 

3313 """Computes sin of x element-wise. 

3314 

3315 Args: 

3316 x: Tensor or variable. 

3317 

3318 Returns: 

3319 A tensor. 

3320 """ 

3321 return tf.sin(x) 

3322 

3323 

3324@keras_export("keras.backend.cos") 

3325@tf.__internal__.dispatch.add_dispatch_support 

3326@doc_controls.do_not_generate_docs 

3327def cos(x): 

3328 """Computes cos of x element-wise. 

3329 

3330 Args: 

3331 x: Tensor or variable. 

3332 

3333 Returns: 

3334 A tensor. 

3335 """ 

3336 return tf.cos(x) 

3337 

3338 

3339def _regular_normalize_batch_in_training( 

3340 x, gamma, beta, reduction_axes, epsilon=1e-3 

3341): 

3342 """Non-fused version of `normalize_batch_in_training`. 

3343 

3344 Args: 

3345 x: Input tensor or variable. 

3346 gamma: Tensor by which to scale the input. 

3347 beta: Tensor with which to center the input. 

3348 reduction_axes: iterable of integers, 

3349 axes over which to normalize. 

3350 epsilon: Fuzz factor. 

3351 

3352 Returns: 

3353 A tuple length of 3, `(normalized_tensor, mean, variance)`. 

3354 """ 

3355 mean, var = tf.compat.v1.nn.moments(x, reduction_axes, None, None, False) 

3356 normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon) 

3357 return normed, mean, var 

3358 

3359 

3360def _broadcast_normalize_batch_in_training( 

3361 x, gamma, beta, reduction_axes, epsilon=1e-3 

3362): 

3363 """Non-fused, broadcast version of `normalize_batch_in_training`. 

3364 

3365 Args: 

3366 x: Input tensor or variable. 

3367 gamma: Tensor by which to scale the input. 

3368 beta: Tensor with which to center the input. 

3369 reduction_axes: iterable of integers, 

3370 axes over which to normalize. 

3371 epsilon: Fuzz factor. 

3372 

3373 Returns: 

3374 A tuple length of 3, `(normalized_tensor, mean, variance)`. 

3375 """ 

3376 mean, var = tf.compat.v1.nn.moments(x, reduction_axes, None, None, False) 

3377 target_shape = [] 

3378 for axis in range(ndim(x)): 

3379 if axis in reduction_axes: 

3380 target_shape.append(1) 

3381 else: 

3382 target_shape.append(tf.shape(x)[axis]) 

3383 target_shape = tf.stack(target_shape) 

3384 

3385 broadcast_mean = tf.reshape(mean, target_shape) 

3386 broadcast_var = tf.reshape(var, target_shape) 

3387 if gamma is None: 

3388 broadcast_gamma = None 

3389 else: 

3390 broadcast_gamma = tf.reshape(gamma, target_shape) 

3391 if beta is None: 

3392 broadcast_beta = None 

3393 else: 

3394 broadcast_beta = tf.reshape(beta, target_shape) 

3395 

3396 normed = tf.nn.batch_normalization( 

3397 x, 

3398 broadcast_mean, 

3399 broadcast_var, 

3400 broadcast_beta, 

3401 broadcast_gamma, 

3402 epsilon, 

3403 ) 

3404 return normed, mean, var 

3405 

3406 

3407def _fused_normalize_batch_in_training( 

3408 x, gamma, beta, reduction_axes, epsilon=1e-3 

3409): 

3410 """Fused version of `normalize_batch_in_training`. 

3411 

3412 Args: 

3413 x: Input tensor or variable. 

3414 gamma: Tensor by which to scale the input. 

3415 beta: Tensor with which to center the input. 

3416 reduction_axes: iterable of integers, 

3417 axes over which to normalize. 

3418 epsilon: Fuzz factor. 

3419 

3420 Returns: 

3421 A tuple length of 3, `(normalized_tensor, mean, variance)`. 

3422 """ 

3423 if list(reduction_axes) == [0, 1, 2]: 

3424 normalization_axis = 3 

3425 tf_data_format = "NHWC" 

3426 else: 

3427 normalization_axis = 1 

3428 tf_data_format = "NCHW" 

3429 

3430 if gamma is None: 

3431 gamma = tf.constant( 

3432 1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]] 

3433 ) 

3434 if beta is None: 

3435 beta = tf.constant( 

3436 0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]] 

3437 ) 

3438 

3439 return tf.compat.v1.nn.fused_batch_norm( 

3440 x, gamma, beta, epsilon=epsilon, data_format=tf_data_format 

3441 ) 

3442 

3443 

3444@keras_export("keras.backend.normalize_batch_in_training") 

3445@doc_controls.do_not_generate_docs 

3446def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3): 

3447 """Computes mean and std for batch then apply batch_normalization on batch. 

3448 

3449 Args: 

3450 x: Input tensor or variable. 

3451 gamma: Tensor by which to scale the input. 

3452 beta: Tensor with which to center the input. 

3453 reduction_axes: iterable of integers, 

3454 axes over which to normalize. 

3455 epsilon: Fuzz factor. 

3456 

3457 Returns: 

3458 A tuple length of 3, `(normalized_tensor, mean, variance)`. 

3459 """ 

3460 if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]: 

3461 if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]: 

3462 return _broadcast_normalize_batch_in_training( 

3463 x, gamma, beta, reduction_axes, epsilon=epsilon 

3464 ) 

3465 return _fused_normalize_batch_in_training( 

3466 x, gamma, beta, reduction_axes, epsilon=epsilon 

3467 ) 

3468 else: 

3469 if sorted(reduction_axes) == list(range(ndim(x)))[:-1]: 

3470 return _regular_normalize_batch_in_training( 

3471 x, gamma, beta, reduction_axes, epsilon=epsilon 

3472 ) 

3473 else: 

3474 return _broadcast_normalize_batch_in_training( 

3475 x, gamma, beta, reduction_axes, epsilon=epsilon 

3476 ) 

3477 

3478 

3479@keras_export("keras.backend.batch_normalization") 

3480@tf.__internal__.dispatch.add_dispatch_support 

3481@doc_controls.do_not_generate_docs 

3482def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): 

3483 """Applies batch normalization on x given mean, var, beta and gamma. 

3484 

3485 I.e. returns: 

3486 `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta` 

3487 

3488 Args: 

3489 x: Input tensor or variable. 

3490 mean: Mean of batch. 

3491 var: Variance of batch. 

3492 beta: Tensor with which to center the input. 

3493 gamma: Tensor by which to scale the input. 

3494 axis: Integer, the axis that should be normalized. 

3495 (typically the features axis). 

3496 epsilon: Fuzz factor. 

3497 

3498 Returns: 

3499 A tensor. 

3500 """ 

3501 if ndim(x) == 4: 

3502 # The CPU implementation of `fused_batch_norm` only supports NHWC 

3503 if axis == 1 or axis == -3: 

3504 tf_data_format = "NCHW" 

3505 elif axis == 3 or axis == -1: 

3506 tf_data_format = "NHWC" 

3507 else: 

3508 tf_data_format = None 

3509 

3510 if ( 

3511 tf_data_format == "NHWC" 

3512 or tf_data_format == "NCHW" 

3513 and _has_nchw_support() 

3514 ): 

3515 # The mean / var / beta / gamma tensors may be broadcasted 

3516 # so they may have extra axes of size 1, which should be squeezed. 

3517 if ndim(mean) > 1: 

3518 mean = tf.reshape(mean, [-1]) 

3519 if ndim(var) > 1: 

3520 var = tf.reshape(var, [-1]) 

3521 if beta is None: 

3522 beta = zeros_like(mean) 

3523 elif ndim(beta) > 1: 

3524 beta = tf.reshape(beta, [-1]) 

3525 if gamma is None: 

3526 gamma = ones_like(mean) 

3527 elif ndim(gamma) > 1: 

3528 gamma = tf.reshape(gamma, [-1]) 

3529 y, _, _ = tf.compat.v1.nn.fused_batch_norm( 

3530 x, 

3531 gamma, 

3532 beta, 

3533 epsilon=epsilon, 

3534 mean=mean, 

3535 variance=var, 

3536 data_format=tf_data_format, 

3537 is_training=False, 

3538 ) 

3539 return y 

3540 return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon) 

3541 

3542 

3543# SHAPE OPERATIONS 

3544 

3545 

3546@keras_export("keras.backend.concatenate") 

3547@tf.__internal__.dispatch.add_dispatch_support 

3548@doc_controls.do_not_generate_docs 

3549def concatenate(tensors, axis=-1): 

3550 """Concatenates a list of tensors alongside the specified axis. 

3551 

3552 Args: 

3553 tensors: list of tensors to concatenate. 

3554 axis: concatenation axis. 

3555 

3556 Returns: 

3557 A tensor. 

3558 

3559 Example: 

3560 

3561 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 

3562 >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]]) 

3563 >>> tf.keras.backend.concatenate((a, b), axis=-1) 

3564 <tf.Tensor: shape=(3, 6), dtype=int32, numpy= 

3565 array([[ 1, 2, 3, 10, 20, 30], 

3566 [ 4, 5, 6, 40, 50, 60], 

3567 [ 7, 8, 9, 70, 80, 90]], dtype=int32)> 

3568 

3569 """ 

3570 if axis < 0: 

3571 rank = ndim(tensors[0]) 

3572 if rank: 

3573 axis %= rank 

3574 else: 

3575 axis = 0 

3576 

3577 if py_all(is_sparse(x) for x in tensors): 

3578 return tf.compat.v1.sparse_concat(axis, tensors) 

3579 elif py_all(isinstance(x, tf.RaggedTensor) for x in tensors): 

3580 return tf.concat(tensors, axis) 

3581 else: 

3582 return tf.concat([to_dense(x) for x in tensors], axis) 

3583 

3584 

3585@keras_export("keras.backend.reshape") 

3586@tf.__internal__.dispatch.add_dispatch_support 

3587@doc_controls.do_not_generate_docs 

3588def reshape(x, shape): 

3589 """Reshapes a tensor to the specified shape. 

3590 

3591 Args: 

3592 x: Tensor or variable. 

3593 shape: Target shape tuple. 

3594 

3595 Returns: 

3596 A tensor. 

3597 

3598 Example: 

3599 

3600 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) 

3601 >>> a 

3602 <tf.Tensor: shape=(4, 3), dtype=int32, numpy= 

3603 array([[ 1, 2, 3], 

3604 [ 4, 5, 6], 

3605 [ 7, 8, 9], 

3606 [10, 11, 12]], dtype=int32)> 

3607 >>> tf.keras.backend.reshape(a, shape=(2, 6)) 

3608 <tf.Tensor: shape=(2, 6), dtype=int32, numpy= 

3609 array([[ 1, 2, 3, 4, 5, 6], 

3610 [ 7, 8, 9, 10, 11, 12]], dtype=int32)> 

3611 

3612 """ 

3613 return tf.reshape(x, shape) 

3614 

3615 

3616@keras_export("keras.backend.permute_dimensions") 

3617@tf.__internal__.dispatch.add_dispatch_support 

3618@doc_controls.do_not_generate_docs 

3619def permute_dimensions(x, pattern): 

3620 """Permutes axes in a tensor. 

3621 

3622 Args: 

3623 x: Tensor or variable. 

3624 pattern: A tuple of 

3625 dimension indices, e.g. `(0, 2, 1)`. 

3626 

3627 Returns: 

3628 A tensor. 

3629 

3630 Example: 

3631 

3632 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) 

3633 >>> a 

3634 <tf.Tensor: shape=(4, 3), dtype=int32, numpy= 

3635 array([[ 1, 2, 3], 

3636 [ 4, 5, 6], 

3637 [ 7, 8, 9], 

3638 [10, 11, 12]], dtype=int32)> 

3639 >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0)) 

3640 <tf.Tensor: shape=(3, 4), dtype=int32, numpy= 

3641 array([[ 1, 4, 7, 10], 

3642 [ 2, 5, 8, 11], 

3643 [ 3, 6, 9, 12]], dtype=int32)> 

3644 

3645 """ 

3646 return tf.compat.v1.transpose(x, perm=pattern) 

3647 

3648 

3649@keras_export("keras.backend.resize_images") 

3650@tf.__internal__.dispatch.add_dispatch_support 

3651@doc_controls.do_not_generate_docs 

3652def resize_images( 

3653 x, height_factor, width_factor, data_format, interpolation="nearest" 

3654): 

3655 """Resizes the images contained in a 4D tensor. 

3656 

3657 Args: 

3658 x: Tensor or variable to resize. 

3659 height_factor: Positive integer. 

3660 width_factor: Positive integer. 

3661 data_format: One of `"channels_first"`, `"channels_last"`. 

3662 interpolation: A string, one of `"area"`, `"bicubic"`, `"bilinear"`, 

3663 `"gaussian"`, `"lanczos3"`, `"lanczos5"`, `"mitchellcubic"`, 

3664 `"nearest"`. 

3665 

3666 Returns: 

3667 A tensor. 

3668 

3669 Raises: 

3670 ValueError: in case of incorrect value for 

3671 `data_format` or `interpolation`. 

3672 """ 

3673 if data_format == "channels_first": 

3674 rows, cols = 2, 3 

3675 elif data_format == "channels_last": 

3676 rows, cols = 1, 2 

3677 else: 

3678 raise ValueError(f"Invalid `data_format` argument: {data_format}") 

3679 

3680 new_shape = x.shape[rows : cols + 1] 

3681 if new_shape.is_fully_defined(): 

3682 new_shape = tf.constant(new_shape.as_list(), dtype="int32") 

3683 else: 

3684 new_shape = tf.shape(x)[rows : cols + 1] 

3685 new_shape *= tf.constant( 

3686 np.array([height_factor, width_factor], dtype="int32") 

3687 ) 

3688 

3689 if data_format == "channels_first": 

3690 x = permute_dimensions(x, [0, 2, 3, 1]) 

3691 interpolations = { 

3692 "area": tf.image.ResizeMethod.AREA, 

3693 "bicubic": tf.image.ResizeMethod.BICUBIC, 

3694 "bilinear": tf.image.ResizeMethod.BILINEAR, 

3695 "gaussian": tf.image.ResizeMethod.GAUSSIAN, 

3696 "lanczos3": tf.image.ResizeMethod.LANCZOS3, 

3697 "lanczos5": tf.image.ResizeMethod.LANCZOS5, 

3698 "mitchellcubic": tf.image.ResizeMethod.MITCHELLCUBIC, 

3699 "nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR, 

3700 } 

3701 interploations_list = '"' + '", "'.join(interpolations.keys()) + '"' 

3702 if interpolation in interpolations: 

3703 x = tf.image.resize(x, new_shape, method=interpolations[interpolation]) 

3704 else: 

3705 raise ValueError( 

3706 "`interpolation` argument should be one of: " 

3707 f'{interploations_list}. Received: "{interpolation}".' 

3708 ) 

3709 if data_format == "channels_first": 

3710 x = permute_dimensions(x, [0, 3, 1, 2]) 

3711 

3712 return x 

3713 

3714 

3715@keras_export("keras.backend.resize_volumes") 

3716@tf.__internal__.dispatch.add_dispatch_support 

3717@doc_controls.do_not_generate_docs 

3718def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): 

3719 """Resizes the volume contained in a 5D tensor. 

3720 

3721 Args: 

3722 x: Tensor or variable to resize. 

3723 depth_factor: Positive integer. 

3724 height_factor: Positive integer. 

3725 width_factor: Positive integer. 

3726 data_format: One of `"channels_first"`, `"channels_last"`. 

3727 

3728 Returns: 

3729 A tensor. 

3730 

3731 Raises: 

3732 ValueError: if `data_format` is neither 

3733 `channels_last` or `channels_first`. 

3734 """ 

3735 if data_format == "channels_first": 

3736 output = repeat_elements(x, depth_factor, axis=2) 

3737 output = repeat_elements(output, height_factor, axis=3) 

3738 output = repeat_elements(output, width_factor, axis=4) 

3739 return output 

3740 elif data_format == "channels_last": 

3741 output = repeat_elements(x, depth_factor, axis=1) 

3742 output = repeat_elements(output, height_factor, axis=2) 

3743 output = repeat_elements(output, width_factor, axis=3) 

3744 return output 

3745 else: 

3746 raise ValueError("Invalid data_format: " + str(data_format)) 

3747 

3748 

3749@keras_export("keras.backend.repeat_elements") 

3750@tf.__internal__.dispatch.add_dispatch_support 

3751@doc_controls.do_not_generate_docs 

3752def repeat_elements(x, rep, axis): 

3753 """Repeats the elements of a tensor along an axis, like `np.repeat`. 

3754 

3755 If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output 

3756 will have shape `(s1, s2 * rep, s3)`. 

3757 

3758 Args: 

3759 x: Tensor or variable. 

3760 rep: Python integer, number of times to repeat. 

3761 axis: Axis along which to repeat. 

3762 

3763 Returns: 

3764 A tensor. 

3765 

3766 Example: 

3767 

3768 >>> b = tf.constant([1, 2, 3]) 

3769 >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0) 

3770 <tf.Tensor: shape=(6,), dtype=int32, 

3771 numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)> 

3772 

3773 """ 

3774 x_shape = x.shape.as_list() 

3775 # For static axis 

3776 if x_shape[axis] is not None: 

3777 # slices along the repeat axis 

3778 splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis) 

3779 # repeat each slice the given number of reps 

3780 x_rep = [s for s in splits for _ in range(rep)] 

3781 return concatenate(x_rep, axis) 

3782 

3783 # Here we use tf.tile to mimic behavior of np.repeat so that 

3784 # we can handle dynamic shapes (that include None). 

3785 # To do that, we need an auxiliary axis to repeat elements along 

3786 # it and then merge them along the desired axis. 

3787 

3788 # Repeating 

3789 auxiliary_axis = axis + 1 

3790 x_shape = tf.shape(x) 

3791 x_rep = tf.expand_dims(x, axis=auxiliary_axis) 

3792 reps = np.ones(len(x.shape) + 1) 

3793 reps[auxiliary_axis] = rep 

3794 x_rep = tf.tile(x_rep, reps) 

3795 

3796 # Merging 

3797 reps = np.delete(reps, auxiliary_axis) 

3798 reps[axis] = rep 

3799 reps = tf.constant(reps, dtype="int32") 

3800 x_shape *= reps 

3801 x_rep = tf.reshape(x_rep, x_shape) 

3802 

3803 # Fix shape representation 

3804 x_shape = x.shape.as_list() 

3805 x_rep.set_shape(x_shape) 

3806 x_rep._keras_shape = tuple(x_shape) 

3807 return x_rep 

3808 

3809 

3810@keras_export("keras.backend.repeat") 

3811@tf.__internal__.dispatch.add_dispatch_support 

3812@doc_controls.do_not_generate_docs 

3813def repeat(x, n): 

3814 """Repeats a 2D tensor. 

3815 

3816 if `x` has shape (samples, dim) and `n` is `2`, 

3817 the output will have shape `(samples, 2, dim)`. 

3818 

3819 Args: 

3820 x: Tensor or variable. 

3821 n: Python integer, number of times to repeat. 

3822 

3823 Returns: 

3824 A tensor. 

3825 

3826 Example: 

3827 

3828 >>> b = tf.constant([[1, 2], [3, 4]]) 

3829 >>> b 

3830 <tf.Tensor: shape=(2, 2), dtype=int32, numpy= 

3831 array([[1, 2], 

3832 [3, 4]], dtype=int32)> 

3833 >>> tf.keras.backend.repeat(b, n=2) 

3834 <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy= 

3835 array([[[1, 2], 

3836 [1, 2]], 

3837 [[3, 4], 

3838 [3, 4]]], dtype=int32)> 

3839 

3840 """ 

3841 assert ndim(x) == 2 

3842 x = tf.expand_dims(x, 1) 

3843 pattern = tf.stack([1, n, 1]) 

3844 return tf.tile(x, pattern) 

3845 

3846 

3847@keras_export("keras.backend.arange") 

3848@tf.__internal__.dispatch.add_dispatch_support 

3849@doc_controls.do_not_generate_docs 

3850def arange(start, stop=None, step=1, dtype="int32"): 

3851 """Creates a 1D tensor containing a sequence of integers. 

3852 

3853 The function arguments use the same convention as 

3854 Theano's arange: if only one argument is provided, 

3855 it is in fact the "stop" argument and "start" is 0. 

3856 

3857 The default type of the returned tensor is `'int32'` to 

3858 match TensorFlow's default. 

3859 

3860 Args: 

3861 start: Start value. 

3862 stop: Stop value. 

3863 step: Difference between two successive values. 

3864 dtype: Integer dtype to use. 

3865 

3866 Returns: 

3867 An integer tensor. 

3868 

3869 Example: 

3870 

3871 >>> tf.keras.backend.arange(start=0, stop=10, step=1.5) 

3872 <tf.Tensor: shape=(7,), dtype=float32, 

3873 numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)> 

3874 

3875 

3876 

3877 """ 

3878 # Match the behavior of numpy and Theano by returning an empty sequence. 

3879 if stop is None and start < 0: 

3880 start = 0 

3881 result = tf.range(start, limit=stop, delta=step, name="arange") 

3882 if dtype != "int32": 

3883 result = cast(result, dtype) 

3884 return result 

3885 

3886 

3887@keras_export("keras.backend.tile") 

3888@tf.__internal__.dispatch.add_dispatch_support 

3889@doc_controls.do_not_generate_docs 

3890def tile(x, n): 

3891 """Creates a tensor by tiling `x` by `n`. 

3892 

3893 Args: 

3894 x: A tensor or variable 

3895 n: A list of integer. The length must be the same as the number of 

3896 dimensions in `x`. 

3897 

3898 Returns: 

3899 A tiled tensor. 

3900 """ 

3901 if isinstance(n, int): 

3902 n = [n] 

3903 return tf.tile(x, n) 

3904 

3905 

3906@keras_export("keras.backend.flatten") 

3907@tf.__internal__.dispatch.add_dispatch_support 

3908@doc_controls.do_not_generate_docs 

3909def flatten(x): 

3910 """Flatten a tensor. 

3911 

3912 Args: 

3913 x: A tensor or variable. 

3914 

3915 Returns: 

3916 A tensor, reshaped into 1-D 

3917 

3918 Example: 

3919 

3920 >>> b = tf.constant([[1, 2], [3, 4]]) 

3921 >>> b 

3922 <tf.Tensor: shape=(2, 2), dtype=int32, numpy= 

3923 array([[1, 2], 

3924 [3, 4]], dtype=int32)> 

3925 >>> tf.keras.backend.flatten(b) 

3926 <tf.Tensor: shape=(4,), dtype=int32, 

3927 numpy=array([1, 2, 3, 4], dtype=int32)> 

3928 

3929 """ 

3930 return tf.reshape(x, [-1]) 

3931 

3932 

3933@keras_export("keras.backend.batch_flatten") 

3934@tf.__internal__.dispatch.add_dispatch_support 

3935@doc_controls.do_not_generate_docs 

3936def batch_flatten(x): 

3937 """Turn a nD tensor into a 2D tensor with same 0th dimension. 

3938 

3939 In other words, it flattens each data samples of a batch. 

3940 

3941 Args: 

3942 x: A tensor or variable. 

3943 

3944 Returns: 

3945 A tensor. 

3946 

3947 Examples: 

3948 Flattening a 3D tensor to 2D by collapsing the last dimension. 

3949 

3950 >>> x_batch = tf.keras.backend.ones(shape=(2, 3, 4, 5)) 

3951 >>> x_batch_flatten = batch_flatten(x_batch) 

3952 >>> tf.keras.backend.int_shape(x_batch_flatten) 

3953 (2, 60) 

3954 

3955 """ 

3956 x = tf.reshape(x, tf.stack([-1, prod(shape(x)[1:])])) 

3957 return x 

3958 

3959 

3960@keras_export("keras.backend.expand_dims") 

3961@tf.__internal__.dispatch.add_dispatch_support 

3962@doc_controls.do_not_generate_docs 

3963def expand_dims(x, axis=-1): 

3964 """Adds a 1-sized dimension at index "axis". 

3965 

3966 Args: 

3967 x: A tensor or variable. 

3968 axis: Position where to add a new axis. 

3969 

3970 Returns: 

3971 A tensor with expanded dimensions. 

3972 """ 

3973 return tf.expand_dims(x, axis) 

3974 

3975 

3976@keras_export("keras.backend.squeeze") 

3977@tf.__internal__.dispatch.add_dispatch_support 

3978@doc_controls.do_not_generate_docs 

3979def squeeze(x, axis): 

3980 """Removes a 1-dimension from the tensor at index "axis". 

3981 

3982 Args: 

3983 x: A tensor or variable. 

3984 axis: Axis to drop. 

3985 

3986 Returns: 

3987 A tensor with the same data as `x` but reduced dimensions. 

3988 """ 

3989 return tf.squeeze(x, [axis]) 

3990 

3991 

3992@keras_export("keras.backend.temporal_padding") 

3993@tf.__internal__.dispatch.add_dispatch_support 

3994@doc_controls.do_not_generate_docs 

3995def temporal_padding(x, padding=(1, 1)): 

3996 """Pads the middle dimension of a 3D tensor. 

3997 

3998 Args: 

3999 x: Tensor or variable. 

4000 padding: Tuple of 2 integers, how many zeros to 

4001 add at the start and end of dim 1. 

4002 

4003 Returns: 

4004 A padded 3D tensor. 

4005 """ 

4006 assert len(padding) == 2 

4007 pattern = [[0, 0], [padding[0], padding[1]], [0, 0]] 

4008 return tf.compat.v1.pad(x, pattern) 

4009 

4010 

4011@keras_export("keras.backend.spatial_2d_padding") 

4012@tf.__internal__.dispatch.add_dispatch_support 

4013@doc_controls.do_not_generate_docs 

4014def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): 

4015 """Pads the 2nd and 3rd dimensions of a 4D tensor. 

4016 

4017 Args: 

4018 x: Tensor or variable. 

4019 padding: Tuple of 2 tuples, padding pattern. 

4020 data_format: One of `channels_last` or `channels_first`. 

4021 

4022 Returns: 

4023 A padded 4D tensor. 

4024 

4025 Raises: 

4026 ValueError: if `data_format` is neither 

4027 `channels_last` or `channels_first`. 

4028 """ 

4029 assert len(padding) == 2 

4030 assert len(padding[0]) == 2 

4031 assert len(padding[1]) == 2 

4032 if data_format is None: 

4033 data_format = image_data_format() 

4034 if data_format not in {"channels_first", "channels_last"}: 

4035 raise ValueError("Unknown data_format: " + str(data_format)) 

4036 

4037 if data_format == "channels_first": 

4038 pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])] 

4039 else: 

4040 pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]] 

4041 return tf.compat.v1.pad(x, pattern) 

4042 

4043 

4044@keras_export("keras.backend.spatial_3d_padding") 

4045@tf.__internal__.dispatch.add_dispatch_support 

4046@doc_controls.do_not_generate_docs 

4047def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): 

4048 """Pads 5D tensor with zeros along the depth, height, width dimensions. 

4049 

4050 Pads these dimensions with respectively 

4051 "padding[0]", "padding[1]" and "padding[2]" zeros left and right. 

4052 

4053 For 'channels_last' data_format, 

4054 the 2nd, 3rd and 4th dimension will be padded. 

4055 For 'channels_first' data_format, 

4056 the 3rd, 4th and 5th dimension will be padded. 

4057 

4058 Args: 

4059 x: Tensor or variable. 

4060 padding: Tuple of 3 tuples, padding pattern. 

4061 data_format: One of `channels_last` or `channels_first`. 

4062 

4063 Returns: 

4064 A padded 5D tensor. 

4065 

4066 Raises: 

4067 ValueError: if `data_format` is neither 

4068 `channels_last` or `channels_first`. 

4069 

4070 """ 

4071 assert len(padding) == 3 

4072 assert len(padding[0]) == 2 

4073 assert len(padding[1]) == 2 

4074 assert len(padding[2]) == 2 

4075 if data_format is None: 

4076 data_format = image_data_format() 

4077 if data_format not in {"channels_first", "channels_last"}: 

4078 raise ValueError("Unknown data_format: " + str(data_format)) 

4079 

4080 if data_format == "channels_first": 

4081 pattern = [ 

4082 [0, 0], 

4083 [0, 0], 

4084 [padding[0][0], padding[0][1]], 

4085 [padding[1][0], padding[1][1]], 

4086 [padding[2][0], padding[2][1]], 

4087 ] 

4088 else: 

4089 pattern = [ 

4090 [0, 0], 

4091 [padding[0][0], padding[0][1]], 

4092 [padding[1][0], padding[1][1]], 

4093 [padding[2][0], padding[2][1]], 

4094 [0, 0], 

4095 ] 

4096 return tf.compat.v1.pad(x, pattern) 

4097 

4098 

4099@keras_export("keras.backend.stack") 

4100@tf.__internal__.dispatch.add_dispatch_support 

4101@doc_controls.do_not_generate_docs 

4102def stack(x, axis=0): 

4103 """Stacks a list of rank `R` tensors into a rank `R+1` tensor. 

4104 

4105 Args: 

4106 x: List of tensors. 

4107 axis: Axis along which to perform stacking. 

4108 

4109 Returns: 

4110 A tensor. 

4111 

4112 Example: 

4113 

4114 >>> a = tf.constant([[1, 2],[3, 4]]) 

4115 >>> b = tf.constant([[10, 20],[30, 40]]) 

4116 >>> tf.keras.backend.stack((a, b)) 

4117 <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy= 

4118 array([[[ 1, 2], 

4119 [ 3, 4]], 

4120 [[10, 20], 

4121 [30, 40]]], dtype=int32)> 

4122 

4123 """ 

4124 return tf.stack(x, axis=axis) 

4125 

4126 

4127@keras_export("keras.backend.one_hot") 

4128@tf.__internal__.dispatch.add_dispatch_support 

4129@doc_controls.do_not_generate_docs 

4130def one_hot(indices, num_classes): 

4131 """Computes the one-hot representation of an integer tensor. 

4132 

4133 Args: 

4134 indices: nD integer tensor of shape 

4135 `(batch_size, dim1, dim2, ... dim(n-1))` 

4136 num_classes: Integer, number of classes to consider. 

4137 

4138 Returns: 

4139 (n + 1)D one hot representation of the input 

4140 with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)` 

4141 

4142 Returns: 

4143 The one-hot tensor. 

4144 """ 

4145 return tf.one_hot(indices, depth=num_classes, axis=-1) 

4146 

4147 

4148@keras_export("keras.backend.reverse") 

4149@tf.__internal__.dispatch.add_dispatch_support 

4150@doc_controls.do_not_generate_docs 

4151def reverse(x, axes): 

4152 """Reverse a tensor along the specified axes. 

4153 

4154 Args: 

4155 x: Tensor to reverse. 

4156 axes: Integer or iterable of integers. 

4157 Axes to reverse. 

4158 

4159 Returns: 

4160 A tensor. 

4161 """ 

4162 if isinstance(axes, int): 

4163 axes = [axes] 

4164 return tf.reverse(x, axes) 

4165 

4166 

4167# VALUE MANIPULATION 

4168_VALUE_SET_CODE_STRING = """ 

4169 >>> K = tf.keras.backend # Common keras convention 

4170 >>> v = K.variable(1.) 

4171 

4172 >>> # reassign 

4173 >>> K.set_value(v, 2.) 

4174 >>> print(K.get_value(v)) 

4175 2.0 

4176 

4177 >>> # increment 

4178 >>> K.set_value(v, K.get_value(v) + 1) 

4179 >>> print(K.get_value(v)) 

4180 3.0 

4181 

4182 Variable semantics in TensorFlow 2 are eager execution friendly. The above 

4183 code is roughly equivalent to: 

4184 

4185 >>> v = tf.Variable(1.) 

4186 

4187 >>> v.assign(2.) 

4188 >>> print(v.numpy()) 

4189 2.0 

4190 

4191 >>> v.assign_add(1.) 

4192 >>> print(v.numpy()) 

4193 3.0"""[ 

4194 3: 

4195] # Prune first newline and indent to match the docstring template. 

4196 

4197 

4198@keras_export("keras.backend.get_value") 

4199@doc_controls.do_not_generate_docs 

4200def get_value(x): 

4201 """Returns the value of a variable. 

4202 

4203 `backend.get_value` is the complement of `backend.set_value`, and provides 

4204 a generic interface for reading from variables while abstracting away the 

4205 differences between TensorFlow 1.x and 2.x semantics. 

4206 

4207 {snippet} 

4208 

4209 Args: 

4210 x: input variable. 

4211 

4212 Returns: 

4213 A Numpy array. 

4214 """ 

4215 if not tf.is_tensor(x): 

4216 return x 

4217 if tf.executing_eagerly() or isinstance(x, tf.__internal__.EagerTensor): 

4218 return x.numpy() 

4219 if not getattr(x, "_in_graph_mode", True): 

4220 # This is a variable which was created in an eager context, but is being 

4221 # evaluated from a Graph. 

4222 with tf.__internal__.eager_context.eager_mode(): 

4223 return x.numpy() 

4224 

4225 if tf.compat.v1.executing_eagerly_outside_functions(): 

4226 # This method of evaluating works inside the Keras FuncGraph. 

4227 with tf.init_scope(): 

4228 return x.numpy() 

4229 

4230 with x.graph.as_default(): 

4231 return x.eval(session=get_session((x,))) 

4232 

4233 

4234@keras_export("keras.backend.batch_get_value") 

4235@tf.__internal__.dispatch.add_dispatch_support 

4236@doc_controls.do_not_generate_docs 

4237def batch_get_value(tensors): 

4238 """Returns the value of more than one tensor variable. 

4239 

4240 Args: 

4241 tensors: list of ops to run. 

4242 

4243 Returns: 

4244 A list of Numpy arrays. 

4245 

4246 Raises: 

4247 RuntimeError: If this method is called inside defun. 

4248 """ 

4249 if tf.executing_eagerly(): 

4250 return [x.numpy() for x in tensors] 

4251 elif tf.inside_function(): 

4252 raise RuntimeError("Cannot get value inside Tensorflow graph function.") 

4253 if tensors: 

4254 return get_session(tensors).run(tensors) 

4255 else: 

4256 return [] 

4257 

4258 

4259@keras_export("keras.backend.set_value") 

4260@doc_controls.do_not_generate_docs 

4261def set_value(x, value): 

4262 """Sets the value of a variable, from a Numpy array. 

4263 

4264 `backend.set_value` is the complement of `backend.get_value`, and provides 

4265 a generic interface for assigning to variables while abstracting away the 

4266 differences between TensorFlow 1.x and 2.x semantics. 

4267 

4268 {snippet} 

4269 

4270 Args: 

4271 x: Variable to set to a new value. 

4272 value: Value to set the tensor to, as a Numpy array 

4273 (of the same shape). 

4274 """ 

4275 value = np.asarray(value, dtype=dtype_numpy(x)) 

4276 if tf.compat.v1.executing_eagerly_outside_functions(): 

4277 _assign_value_to_variable(x, value) 

4278 else: 

4279 with get_graph().as_default(): 

4280 tf_dtype = tf.as_dtype(x.dtype.name.split("_")[0]) 

4281 if hasattr(x, "_assign_placeholder"): 

4282 assign_placeholder = x._assign_placeholder 

4283 assign_op = x._assign_op 

4284 else: 

4285 # In order to support assigning weights to resizable variables 

4286 # in Keras, we make a placeholder with the correct number of 

4287 # dimensions but with None in each dimension. This way, we can 

4288 # assign weights of any size (as long as they have the correct 

4289 # dimensionality). 

4290 placeholder_shape = tf.TensorShape([None] * value.ndim) 

4291 assign_placeholder = tf.compat.v1.placeholder( 

4292 tf_dtype, shape=placeholder_shape 

4293 ) 

4294 assign_op = x.assign(assign_placeholder) 

4295 x._assign_placeholder = assign_placeholder 

4296 x._assign_op = assign_op 

4297 get_session().run(assign_op, feed_dict={assign_placeholder: value}) 

4298 

4299 

4300@keras_export("keras.backend.batch_set_value") 

4301@tf.__internal__.dispatch.add_dispatch_support 

4302@doc_controls.do_not_generate_docs 

4303def batch_set_value(tuples): 

4304 """Sets the values of many tensor variables at once. 

4305 

4306 Args: 

4307 tuples: a list of tuples `(tensor, value)`. 

4308 `value` should be a Numpy array. 

4309 """ 

4310 if tf.executing_eagerly() or tf.inside_function(): 

4311 for x, value in tuples: 

4312 value = np.asarray(value, dtype=dtype_numpy(x)) 

4313 _assign_value_to_variable(x, value) 

4314 else: 

4315 with get_graph().as_default(): 

4316 if tuples: 

4317 assign_ops = [] 

4318 feed_dict = {} 

4319 for x, value in tuples: 

4320 value = np.asarray(value, dtype=dtype_numpy(x)) 

4321 tf_dtype = tf.as_dtype(x.dtype.name.split("_")[0]) 

4322 if hasattr(x, "_assign_placeholder"): 

4323 assign_placeholder = x._assign_placeholder 

4324 assign_op = x._assign_op 

4325 else: 

4326 # In order to support assigning weights to resizable 

4327 # variables in Keras, we make a placeholder with the 

4328 # correct number of dimensions but with None in each 

4329 # dimension. This way, we can assign weights of any size 

4330 # (as long as they have the correct dimensionality). 

4331 placeholder_shape = tf.TensorShape([None] * value.ndim) 

4332 assign_placeholder = tf.compat.v1.placeholder( 

4333 tf_dtype, shape=placeholder_shape 

4334 ) 

4335 assign_op = x.assign(assign_placeholder) 

4336 x._assign_placeholder = assign_placeholder 

4337 x._assign_op = assign_op 

4338 assign_ops.append(assign_op) 

4339 feed_dict[assign_placeholder] = value 

4340 get_session().run(assign_ops, feed_dict=feed_dict) 

4341 

4342 

4343get_value.__doc__ = get_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING) 

4344set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING) 

4345 

4346 

4347def _assign_value_to_variable(variable, value): 

4348 # Helper function to assign value to variable. It handles normal tf.Variable 

4349 # as well as DTensor variable. 

4350 if isinstance(variable, dtensor.DVariable): 

4351 mesh = variable.layout.mesh 

4352 replicate_layout = dtensor.Layout.replicated( 

4353 rank=variable.shape.rank, mesh=mesh 

4354 ) 

4355 # TODO(b/262894693): Avoid the broadcast of tensor to all devices. 

4356 d_value = dtensor.copy_to_mesh(value, replicate_layout) 

4357 d_value = dtensor.relayout(d_value, variable.layout) 

4358 variable.assign(d_value) 

4359 else: 

4360 # For the normal tf.Variable assign 

4361 variable.assign(value) 

4362 

4363 

4364@keras_export("keras.backend.print_tensor") 

4365@tf.__internal__.dispatch.add_dispatch_support 

4366@doc_controls.do_not_generate_docs 

4367def print_tensor(x, message="", summarize=3): 

4368 """Prints `message` and the tensor value when evaluated. 

4369 

4370 Note that `print_tensor` returns a new tensor identical to `x` 

4371 which should be used in the following code. Otherwise the 

4372 print operation is not taken into account during evaluation. 

4373 

4374 Example: 

4375 

4376 >>> x = tf.constant([[1.0, 2.0], [3.0, 4.0]]) 

4377 >>> tf.keras.backend.print_tensor(x) 

4378 <tf.Tensor: shape=(2, 2), dtype=float32, numpy= 

4379 array([[1., 2.], 

4380 [3., 4.]], dtype=float32)> 

4381 

4382 Args: 

4383 x: Tensor to print. 

4384 message: Message to print jointly with the tensor. 

4385 summarize: The first and last `summarize` elements within each dimension 

4386 are recursively printed per Tensor. If None, then the first 3 and 

4387 last 3 elements of each dimension are printed for each tensor. If 

4388 set to -1, it will print all elements of every tensor. 

4389 

4390 Returns: 

4391 The same tensor `x`, unchanged. 

4392 """ 

4393 if isinstance(x, tf.Tensor) and hasattr(x, "graph"): 

4394 with get_graph().as_default(): 

4395 op = tf.print( 

4396 message, x, output_stream=sys.stdout, summarize=summarize 

4397 ) 

4398 with tf.control_dependencies([op]): 

4399 return tf.identity(x) 

4400 else: 

4401 tf.print(message, x, output_stream=sys.stdout, summarize=summarize) 

4402 return x 

4403 

4404 

4405# GRAPH MANIPULATION 

4406 

4407 

4408class GraphExecutionFunction: 

4409 """Runs a computation graph. 

4410 

4411 It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`. 

4412 In particular additional operations via `fetches` argument and additional 

4413 tensor substitutions via `feed_dict` arguments. Note that given 

4414 substitutions are merged with substitutions from `inputs`. Even though 

4415 `feed_dict` is passed once in the constructor (called in `model.compile()`) 

4416 we can modify the values in the dictionary. Through this feed_dict we can 

4417 provide additional substitutions besides Keras inputs. 

4418 

4419 Args: 

4420 inputs: Feed placeholders to the computation graph. 

4421 outputs: Output tensors to fetch. 

4422 updates: Additional update ops to be run at function call. 

4423 name: A name to help users identify what this function does. 

4424 session_kwargs: Arguments to `tf.Session.run()`: 

4425 `fetches`, `feed_dict`, `options`, `run_metadata`. 

4426 """ 

4427 

4428 def __init__( 

4429 self, inputs, outputs, updates=None, name=None, **session_kwargs 

4430 ): 

4431 updates = updates or [] 

4432 if not isinstance(updates, (list, tuple)): 

4433 raise TypeError( 

4434 "`updates` in a Keras backend function " 

4435 "should be a list or tuple." 

4436 ) 

4437 

4438 self.inputs = tf.nest.flatten( 

4439 tf_utils.convert_variables_to_tensors(inputs), 

4440 expand_composites=True, 

4441 ) 

4442 self._outputs_structure = tf_utils.convert_variables_to_tensors(outputs) 

4443 self.outputs = tf.nest.flatten( 

4444 self._outputs_structure, expand_composites=True 

4445 ) 

4446 # TODO(b/127668432): Consider using autograph to generate these 

4447 # dependencies in call. 

4448 # Index 0 = total loss or model output for `predict`. 

4449 with tf.control_dependencies([self.outputs[0]]): 

4450 updates_ops = [] 

4451 for update in updates: 

4452 if isinstance(update, tuple): 

4453 p, new_p = update 

4454 updates_ops.append(tf.compat.v1.assign(p, new_p)) 

4455 else: 

4456 # assumed already an op 

4457 updates_ops.append(update) 

4458 self.updates_op = tf.group(*updates_ops) 

4459 self.name = name 

4460 # additional tensor substitutions 

4461 self.feed_dict = session_kwargs.pop("feed_dict", None) 

4462 # additional operations 

4463 self.fetches = session_kwargs.pop("fetches", []) 

4464 if not isinstance(self.fetches, list): 

4465 self.fetches = [self.fetches] 

4466 self.run_options = session_kwargs.pop("options", None) 

4467 self.run_metadata = session_kwargs.pop("run_metadata", None) 

4468 # The main use case of `fetches` being passed to a model is the ability 

4469 # to run custom updates 

4470 # This requires us to wrap fetches in `identity` ops. 

4471 self.fetches = [tf.identity(x) for x in self.fetches] 

4472 self.session_kwargs = session_kwargs 

4473 # This mapping keeps track of the function that should receive the 

4474 # output from a fetch in `fetches`: { fetch: function(fetch_output) } 

4475 # A Callback can use this to register a function with access to the 

4476 # output values for a fetch it added. 

4477 self.fetch_callbacks = {} 

4478 

4479 if session_kwargs: 

4480 raise ValueError( 

4481 "Some keys in session_kwargs are not supported at this time: %s" 

4482 % (session_kwargs.keys(),) 

4483 ) 

4484 

4485 self._callable_fn = None 

4486 self._feed_arrays = None 

4487 self._feed_symbols = None 

4488 self._symbol_vals = None 

4489 self._fetches = None 

4490 self._session = None 

4491 

4492 def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session): 

4493 """Generates a callable that runs the graph. 

4494 

4495 Args: 

4496 feed_arrays: List of input tensors to be fed Numpy arrays at runtime. 

4497 feed_symbols: List of input tensors to be fed symbolic tensors at 

4498 runtime. 

4499 symbol_vals: List of symbolic tensors to be fed to `feed_symbols`. 

4500 session: Session to use to generate the callable. 

4501 

4502 Returns: 

4503 Function that runs the graph according to the above options. 

4504 """ 

4505 # Prepare callable options. 

4506 callable_opts = config_pb2.CallableOptions() 

4507 # Handle external-data feed. 

4508 for x in feed_arrays: 

4509 callable_opts.feed.append(x.name) 

4510 if self.feed_dict: 

4511 for key in sorted(self.feed_dict.keys()): 

4512 callable_opts.feed.append(key.name) 

4513 # Handle symbolic feed. 

4514 for x, y in zip(feed_symbols, symbol_vals): 

4515 connection = callable_opts.tensor_connection.add() 

4516 if x.dtype != y.dtype: 

4517 y = tf.cast(y, dtype=x.dtype) 

4518 from_tensor = _as_graph_element(y) 

4519 if from_tensor is None: 

4520 from_tensor = y 

4521 connection.from_tensor = from_tensor.name # Data tensor 

4522 connection.to_tensor = x.name # Placeholder 

4523 # Handle fetches. 

4524 for x in self.outputs + self.fetches: 

4525 callable_opts.fetch.append(x.name) 

4526 # Handle updates. 

4527 callable_opts.target.append(self.updates_op.name) 

4528 # Handle run_options. 

4529 if self.run_options: 

4530 callable_opts.run_options.CopyFrom(self.run_options) 

4531 # Create callable. 

4532 callable_fn = session._make_callable_from_options(callable_opts) 

4533 # Cache parameters corresponding to the generated callable, so that 

4534 # we can detect future mismatches and refresh the callable. 

4535 self._callable_fn = callable_fn 

4536 self._feed_arrays = feed_arrays 

4537 self._feed_symbols = feed_symbols 

4538 self._symbol_vals = symbol_vals 

4539 self._fetches = list(self.fetches) 

4540 self._session = session 

4541 

4542 def _call_fetch_callbacks(self, fetches_output): 

4543 for fetch, output in zip(self._fetches, fetches_output): 

4544 if fetch in self.fetch_callbacks: 

4545 self.fetch_callbacks[fetch](output) 

4546 

4547 def _eval_if_composite(self, tensor): 

4548 """Helper method which evaluates any CompositeTensors passed to it.""" 

4549 # We need to evaluate any composite tensor objects that have been 

4550 # reconstructed in 'pack_sequence_as', since otherwise they'll be output 

4551 # as actual CompositeTensor objects instead of the value(s) contained in 

4552 # the CompositeTensors. E.g., if output_structure contains a 

4553 # SparseTensor, then this ensures that we return its value as a 

4554 # SparseTensorValue rather than a SparseTensor. 

4555 

4556 if tf_utils.is_extension_type(tensor): 

4557 return self._session.run(tensor) 

4558 else: 

4559 return tensor 

4560 

4561 def __call__(self, inputs): 

4562 inputs = tf.nest.flatten( 

4563 tf_utils.convert_variables_to_tensors(inputs), 

4564 expand_composites=True, 

4565 ) 

4566 

4567 session = get_session(inputs) 

4568 feed_arrays = [] 

4569 array_vals = [] 

4570 feed_symbols = [] 

4571 symbol_vals = [] 

4572 for tensor, value in zip(self.inputs, inputs): 

4573 if value is None: 

4574 continue 

4575 

4576 if tf.is_tensor(value): 

4577 # Case: feeding symbolic tensor. 

4578 feed_symbols.append(tensor) 

4579 symbol_vals.append(value) 

4580 else: 

4581 # Case: feeding Numpy array. 

4582 feed_arrays.append(tensor) 

4583 # We need to do array conversion and type casting at this level, 

4584 # since `callable_fn` only supports exact matches. 

4585 tensor_type = tf.as_dtype(tensor.dtype) 

4586 array_vals.append( 

4587 np.asarray(value, dtype=tensor_type.as_numpy_dtype) 

4588 ) 

4589 

4590 if self.feed_dict: 

4591 for key in sorted(self.feed_dict.keys()): 

4592 array_vals.append( 

4593 np.asarray( 

4594 self.feed_dict[key], dtype=key.dtype.as_numpy_dtype 

4595 ) 

4596 ) 

4597 

4598 # Refresh callable if anything has changed. 

4599 if ( 

4600 self._callable_fn is None 

4601 or feed_arrays != self._feed_arrays 

4602 or symbol_vals != self._symbol_vals 

4603 or feed_symbols != self._feed_symbols 

4604 or self.fetches != self._fetches 

4605 or session != self._session 

4606 ): 

4607 self._make_callable(feed_arrays, feed_symbols, symbol_vals, session) 

4608 

4609 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata) 

4610 self._call_fetch_callbacks(fetched[-len(self._fetches) :]) 

4611 output_structure = tf.nest.pack_sequence_as( 

4612 self._outputs_structure, 

4613 fetched[: len(self.outputs)], 

4614 expand_composites=True, 

4615 ) 

4616 # We need to evaluate any composite tensor objects that have been 

4617 # reconstructed in 'pack_sequence_as', since otherwise they'll be output 

4618 # as actual CompositeTensor objects instead of the value(s) contained in 

4619 # the CompositeTensors. E.g., if output_structure contains a 

4620 # SparseTensor, then this ensures that we return its value as a 

4621 # SparseTensorValue rather than a SparseTensor. 

4622 return tf.nest.map_structure(self._eval_if_composite, output_structure) 

4623 

4624 

4625@keras_export("keras.backend.function") 

4626@doc_controls.do_not_generate_docs 

4627def function(inputs, outputs, updates=None, name=None, **kwargs): 

4628 """Instantiates a Keras function. 

4629 

4630 Args: 

4631 inputs: List of placeholder tensors. 

4632 outputs: List of output tensors. 

4633 updates: List of update ops. 

4634 name: String, name of function. 

4635 **kwargs: Passed to `tf.Session.run`. 

4636 

4637 Returns: 

4638 Output values as Numpy arrays. 

4639 

4640 Raises: 

4641 ValueError: if invalid kwargs are passed in or if in eager execution. 

4642 """ 

4643 if tf.compat.v1.executing_eagerly_outside_functions(): 

4644 if kwargs: 

4645 raise ValueError( 

4646 "Session keyword arguments are not supported during " 

4647 "eager execution. You passed: %s" % (kwargs,) 

4648 ) 

4649 if updates: 

4650 raise ValueError( 

4651 "`updates` argument is not supported during " 

4652 "eager execution. You passed: %s" % (updates,) 

4653 ) 

4654 from keras.src import models 

4655 

4656 model = models.Model(inputs=inputs, outputs=outputs) 

4657 

4658 wrap_outputs = isinstance(outputs, list) and len(outputs) == 1 

4659 

4660 def func(model_inputs): 

4661 outs = model(model_inputs) 

4662 if wrap_outputs: 

4663 outs = [outs] 

4664 return tf_utils.sync_to_numpy_or_python_type(outs) 

4665 

4666 return func 

4667 

4668 if kwargs: 

4669 for key in kwargs: 

4670 if key not in tf_inspect.getfullargspec(tf.compat.v1.Session.run)[ 

4671 0 

4672 ] and key not in ["inputs", "outputs", "updates", "name"]: 

4673 msg = ( 

4674 'Invalid argument "%s" passed to K.function with ' 

4675 "TensorFlow backend" % key 

4676 ) 

4677 raise ValueError(msg) 

4678 return GraphExecutionFunction( 

4679 inputs, outputs, updates=updates, name=name, **kwargs 

4680 ) 

4681 

4682 

4683@keras_export("keras.backend.gradients") 

4684@doc_controls.do_not_generate_docs 

4685def gradients(loss, variables): 

4686 """Returns the gradients of `loss` w.r.t. `variables`. 

4687 

4688 Args: 

4689 loss: Scalar tensor to minimize. 

4690 variables: List of variables. 

4691 

4692 Returns: 

4693 A gradients tensor. 

4694 """ 

4695 return tf.compat.v1.gradients( 

4696 loss, variables, colocate_gradients_with_ops=True 

4697 ) 

4698 

4699 

4700@keras_export("keras.backend.stop_gradient") 

4701@tf.__internal__.dispatch.add_dispatch_support 

4702@doc_controls.do_not_generate_docs 

4703def stop_gradient(variables): 

4704 """Returns `variables` but with zero gradient w.r.t. every other variable. 

4705 

4706 Args: 

4707 variables: Tensor or list of tensors to consider constant with respect 

4708 to any other variable. 

4709 

4710 

4711 Returns: 

4712 A single tensor or a list of tensors (depending on the passed argument) 

4713 that has no gradient with respect to any other variable. 

4714 """ 

4715 if isinstance(variables, (list, tuple)): 

4716 return map(tf.stop_gradient, variables) 

4717 return tf.stop_gradient(variables) 

4718 

4719 

4720# CONTROL FLOW 

4721 

4722 

4723@keras_export("keras.backend.rnn") 

4724@tf.__internal__.dispatch.add_dispatch_support 

4725def rnn( 

4726 step_function, 

4727 inputs, 

4728 initial_states, 

4729 go_backwards=False, 

4730 mask=None, 

4731 constants=None, 

4732 unroll=False, 

4733 input_length=None, 

4734 time_major=False, 

4735 zero_output_for_mask=False, 

4736 return_all_outputs=True, 

4737): 

4738 """Iterates over the time dimension of a tensor. 

4739 

4740 Args: 

4741 step_function: RNN step function. 

4742 Args; 

4743 input; Tensor with shape `(samples, ...)` (no time dimension), 

4744 representing input for the batch of samples at a certain 

4745 time step. 

4746 states; List of tensors. 

4747 Returns; 

4748 output; Tensor with shape `(samples, output_dim)` 

4749 (no time dimension). 

4750 new_states; List of tensors, same length and shapes 

4751 as 'states'. The first state in the list must be the 

4752 output tensor at the previous timestep. 

4753 inputs: Tensor of temporal data of shape `(samples, time, ...)` 

4754 (at least 3D), or nested tensors, and each of which has shape 

4755 `(samples, time, ...)`. 

4756 initial_states: Tensor with shape `(samples, state_size)` 

4757 (no time dimension), containing the initial values for the states 

4758 used in the step function. In the case that state_size is in a 

4759 nested shape, the shape of initial_states will also follow the 

4760 nested structure. 

4761 go_backwards: Boolean. If True, do the iteration over the time 

4762 dimension in reverse order and return the reversed sequence. 

4763 mask: Binary tensor with shape `(samples, time, 1)`, 

4764 with a zero for every element that is masked. 

4765 constants: List of constant values passed at each step. 

4766 unroll: Whether to unroll the RNN or to use a symbolic `while_loop`. 

4767 input_length: An integer or a 1-D Tensor, depending on whether 

4768 the time dimension is fixed-length or not. In case of variable 

4769 length input, it is used for masking in case there's no mask 

4770 specified. 

4771 time_major: Boolean. If true, the inputs and outputs will be in shape 

4772 `(timesteps, batch, ...)`, whereas in the False case, it will be 

4773 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 

4774 efficient because it avoids transposes at the beginning and end of 

4775 the RNN calculation. However, most TensorFlow data is batch-major, 

4776 so by default this function accepts input and emits output in 

4777 batch-major form. 

4778 zero_output_for_mask: Boolean. If True, the output for masked timestep 

4779 will be zeros, whereas in the False case, output from previous 

4780 timestep is returned. 

4781 return_all_outputs: Boolean. If True, return the recurrent outputs for 

4782 all timesteps in the sequence. If False, only return the output for 

4783 the last timestep (which consumes less memory). 

4784 

4785 Returns: 

4786 A tuple, `(last_output, outputs, new_states)`. 

4787 last_output: the latest output of the rnn, of shape `(samples, ...)` 

4788 outputs: 

4789 - If `return_all_outputs=True`: a tensor with shape 

4790 `(samples, time, ...)` where each entry `outputs[s, t]` is the 

4791 output of the step function at time `t` for sample `s` 

4792 - Else, a tensor equal to `last_output` with shape 

4793 `(samples, 1, ...)` 

4794 new_states: list of tensors, latest states returned by 

4795 the step function, of shape `(samples, ...)`. 

4796 

4797 Raises: 

4798 ValueError: if input dimension is less than 3. 

4799 ValueError: if `unroll` is `True` but input timestep is not a fixed 

4800 number. 

4801 ValueError: if `mask` is provided (not `None`) but states is not 

4802 provided (`len(states)` == 0). 

4803 """ 

4804 if not tf.__internal__.tf2.enabled(): 

4805 return_all_outputs = True # Not supported in TF1. 

4806 

4807 def swap_batch_timestep(input_t): 

4808 # Swap the batch and timestep dim for the incoming tensor. 

4809 axes = list(range(len(input_t.shape))) 

4810 axes[0], axes[1] = 1, 0 

4811 return tf.compat.v1.transpose(input_t, axes) 

4812 

4813 if not time_major: 

4814 inputs = tf.nest.map_structure(swap_batch_timestep, inputs) 

4815 

4816 flatted_inputs = tf.nest.flatten(inputs) 

4817 time_steps = flatted_inputs[0].shape[0] 

4818 batch = flatted_inputs[0].shape[1] 

4819 time_steps_t = tf.shape(flatted_inputs[0])[0] 

4820 

4821 for input_ in flatted_inputs: 

4822 input_.shape.with_rank_at_least(3) 

4823 

4824 if mask is not None: 

4825 if mask.dtype != tf.bool: 

4826 mask = tf.cast(mask, tf.bool) 

4827 if len(mask.shape) == 2: 

4828 mask = expand_dims(mask) 

4829 if not time_major: 

4830 mask = swap_batch_timestep(mask) 

4831 

4832 if constants is None: 

4833 constants = [] 

4834 

4835 # tf.where needs its condition tensor to be the same shape as its two 

4836 # result tensors, but in our case the condition (mask) tensor is 

4837 # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. 

4838 # So we need to broadcast the mask to match the shape of inputs. 

4839 # That's what the tile call does, it just repeats the mask along its 

4840 # second dimension n times. 

4841 def _expand_mask(mask_t, input_t, fixed_dim=1): 

4842 if tf.nest.is_nested(mask_t): 

4843 raise ValueError( 

4844 f"mask_t is expected to be tensor, but got {mask_t}" 

4845 ) 

4846 if tf.nest.is_nested(input_t): 

4847 raise ValueError( 

4848 f"input_t is expected to be tensor, but got {input_t}" 

4849 ) 

4850 rank_diff = len(input_t.shape) - len(mask_t.shape) 

4851 for _ in range(rank_diff): 

4852 mask_t = tf.expand_dims(mask_t, -1) 

4853 multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:] 

4854 return tf.tile(mask_t, multiples) 

4855 

4856 if unroll: 

4857 if not time_steps: 

4858 raise ValueError("Unrolling requires a fixed number of timesteps.") 

4859 states = tuple(initial_states) 

4860 successive_states = [] 

4861 successive_outputs = [] 

4862 

4863 # Process the input tensors. The input tensor need to be split on the 

4864 # time_step dim, and reverse if go_backwards is True. In the case of 

4865 # nested input, the input is flattened and then transformed 

4866 # individually. The result of this will be a tuple of lists, each of 

4867 # the item in tuple is list of the tensor with shape (batch, feature) 

4868 def _process_single_input_t(input_t): 

4869 input_t = tf.unstack(input_t) # unstack for time_step dim 

4870 if go_backwards: 

4871 input_t.reverse() 

4872 return input_t 

4873 

4874 if tf.nest.is_nested(inputs): 

4875 processed_input = tf.nest.map_structure( 

4876 _process_single_input_t, inputs 

4877 ) 

4878 else: 

4879 processed_input = (_process_single_input_t(inputs),) 

4880 

4881 def _get_input_tensor(time): 

4882 inp = [t_[time] for t_ in processed_input] 

4883 return tf.nest.pack_sequence_as(inputs, inp) 

4884 

4885 if mask is not None: 

4886 mask_list = tf.unstack(mask) 

4887 if go_backwards: 

4888 mask_list.reverse() 

4889 

4890 for i in range(time_steps): 

4891 inp = _get_input_tensor(i) 

4892 mask_t = mask_list[i] 

4893 output, new_states = step_function( 

4894 inp, tuple(states) + tuple(constants) 

4895 ) 

4896 tiled_mask_t = _expand_mask(mask_t, output) 

4897 

4898 if not successive_outputs: 

4899 prev_output = zeros_like(output) 

4900 else: 

4901 prev_output = successive_outputs[-1] 

4902 

4903 output = tf.where(tiled_mask_t, output, prev_output) 

4904 

4905 flat_states = tf.nest.flatten(states) 

4906 flat_new_states = tf.nest.flatten(new_states) 

4907 tiled_mask_t = tuple( 

4908 _expand_mask(mask_t, s) for s in flat_states 

4909 ) 

4910 flat_final_states = tuple( 

4911 tf.where(m, s, ps) 

4912 for m, s, ps in zip( 

4913 tiled_mask_t, flat_new_states, flat_states 

4914 ) 

4915 ) 

4916 states = tf.nest.pack_sequence_as(states, flat_final_states) 

4917 

4918 if return_all_outputs: 

4919 successive_outputs.append(output) 

4920 successive_states.append(states) 

4921 else: 

4922 successive_outputs = [output] 

4923 successive_states = [states] 

4924 last_output = successive_outputs[-1] 

4925 new_states = successive_states[-1] 

4926 outputs = tf.stack(successive_outputs) 

4927 

4928 if zero_output_for_mask: 

4929 last_output = tf.where( 

4930 _expand_mask(mask_list[-1], last_output), 

4931 last_output, 

4932 zeros_like(last_output), 

4933 ) 

4934 outputs = tf.where( 

4935 _expand_mask(mask, outputs, fixed_dim=2), 

4936 outputs, 

4937 zeros_like(outputs), 

4938 ) 

4939 

4940 else: # mask is None 

4941 for i in range(time_steps): 

4942 inp = _get_input_tensor(i) 

4943 output, states = step_function( 

4944 inp, tuple(states) + tuple(constants) 

4945 ) 

4946 if return_all_outputs: 

4947 successive_outputs.append(output) 

4948 successive_states.append(states) 

4949 else: 

4950 successive_outputs = [output] 

4951 successive_states = [states] 

4952 last_output = successive_outputs[-1] 

4953 new_states = successive_states[-1] 

4954 outputs = tf.stack(successive_outputs) 

4955 

4956 else: # Unroll == False 

4957 states = tuple(initial_states) 

4958 

4959 # Create input tensor array, if the inputs is nested tensors, then it 

4960 # will be flattened first, and tensor array will be created one per 

4961 # flattened tensor. 

4962 input_ta = tuple( 

4963 tf.TensorArray( 

4964 dtype=inp.dtype, 

4965 size=time_steps_t, 

4966 tensor_array_name=f"input_ta_{i}", 

4967 ) 

4968 for i, inp in enumerate(flatted_inputs) 

4969 ) 

4970 input_ta = tuple( 

4971 ta.unstack(input_) 

4972 if not go_backwards 

4973 else ta.unstack(reverse(input_, 0)) 

4974 for ta, input_ in zip(input_ta, flatted_inputs) 

4975 ) 

4976 

4977 # Get the time(0) input and compute the output for that, the output will 

4978 # be used to determine the dtype of output tensor array. Don't read from 

4979 # input_ta due to TensorArray clear_after_read default to True. 

4980 input_time_zero = tf.nest.pack_sequence_as( 

4981 inputs, [inp[0] for inp in flatted_inputs] 

4982 ) 

4983 # output_time_zero is used to determine the cell output shape and its 

4984 # dtype. the value is discarded. 

4985 output_time_zero, _ = step_function( 

4986 input_time_zero, tuple(initial_states) + tuple(constants) 

4987 ) 

4988 

4989 output_ta_size = time_steps_t if return_all_outputs else 1 

4990 output_ta = tuple( 

4991 tf.TensorArray( 

4992 dtype=out.dtype, 

4993 size=output_ta_size, 

4994 element_shape=out.shape, 

4995 tensor_array_name=f"output_ta_{i}", 

4996 ) 

4997 for i, out in enumerate(tf.nest.flatten(output_time_zero)) 

4998 ) 

4999 

5000 time = tf.constant(0, dtype="int32", name="time") 

5001 

5002 # We only specify the 'maximum_iterations' when building for XLA since 

5003 # that causes slowdowns on GPU in TF. 

5004 if ( 

5005 not tf.executing_eagerly() 

5006 and control_flow_util.GraphOrParentsInXlaContext( 

5007 tf.compat.v1.get_default_graph() 

5008 ) 

5009 ): 

5010 if input_length is None: 

5011 max_iterations = time_steps_t 

5012 else: 

5013 max_iterations = tf.reduce_max(input_length) 

5014 else: 

5015 max_iterations = None 

5016 

5017 while_loop_kwargs = { 

5018 "cond": lambda time, *_: time < time_steps_t, 

5019 "maximum_iterations": max_iterations, 

5020 "parallel_iterations": 32, 

5021 "swap_memory": True, 

5022 } 

5023 if mask is not None: 

5024 if go_backwards: 

5025 mask = reverse(mask, 0) 

5026 

5027 mask_ta = tf.TensorArray( 

5028 dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta" 

5029 ) 

5030 mask_ta = mask_ta.unstack(mask) 

5031 

5032 def masking_fn(time): 

5033 return mask_ta.read(time) 

5034 

5035 def compute_masked_output(mask_t, flat_out, flat_mask): 

5036 tiled_mask_t = tuple( 

5037 _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape)) 

5038 for o in flat_out 

5039 ) 

5040 return tuple( 

5041 tf.where(m, o, fm) 

5042 for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask) 

5043 ) 

5044 

5045 elif isinstance(input_length, tf.Tensor): 

5046 if go_backwards: 

5047 max_len = tf.reduce_max(input_length, axis=0) 

5048 rev_input_length = tf.subtract(max_len - 1, input_length) 

5049 

5050 def masking_fn(time): 

5051 return tf.less(rev_input_length, time) 

5052 

5053 else: 

5054 

5055 def masking_fn(time): 

5056 return tf.greater(input_length, time) 

5057 

5058 def compute_masked_output(mask_t, flat_out, flat_mask): 

5059 return tuple( 

5060 tf.compat.v1.where(mask_t, o, zo) 

5061 for (o, zo) in zip(flat_out, flat_mask) 

5062 ) 

5063 

5064 else: 

5065 masking_fn = None 

5066 

5067 if masking_fn is not None: 

5068 # Mask for the T output will be base on the output of T - 1. In the 

5069 # case T = 0, a zero filled tensor will be used. 

5070 flat_zero_output = tuple( 

5071 tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero) 

5072 ) 

5073 

5074 def _step(time, output_ta_t, prev_output, *states): 

5075 """RNN step function. 

5076 

5077 Args: 

5078 time: Current timestep value. 

5079 output_ta_t: TensorArray. 

5080 prev_output: tuple of outputs from time - 1. 

5081 *states: List of states. 

5082 

5083 Returns: 

5084 Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)` 

5085 """ 

5086 current_input = tuple(ta.read(time) for ta in input_ta) 

5087 # maybe set shape. 

5088 current_input = tf.nest.pack_sequence_as(inputs, current_input) 

5089 mask_t = masking_fn(time) 

5090 output, new_states = step_function( 

5091 current_input, tuple(states) + tuple(constants) 

5092 ) 

5093 # mask output 

5094 flat_output = tf.nest.flatten(output) 

5095 flat_mask_output = ( 

5096 flat_zero_output 

5097 if zero_output_for_mask 

5098 else tf.nest.flatten(prev_output) 

5099 ) 

5100 flat_new_output = compute_masked_output( 

5101 mask_t, flat_output, flat_mask_output 

5102 ) 

5103 

5104 # mask states 

5105 flat_state = tf.nest.flatten(states) 

5106 flat_new_state = tf.nest.flatten(new_states) 

5107 for state, new_state in zip(flat_state, flat_new_state): 

5108 if isinstance(new_state, tf.Tensor): 

5109 new_state.set_shape(state.shape) 

5110 flat_final_state = compute_masked_output( 

5111 mask_t, flat_new_state, flat_state 

5112 ) 

5113 new_states = tf.nest.pack_sequence_as( 

5114 new_states, flat_final_state 

5115 ) 

5116 

5117 ta_index_to_write = time if return_all_outputs else 0 

5118 output_ta_t = tuple( 

5119 ta.write(ta_index_to_write, out) 

5120 for ta, out in zip(output_ta_t, flat_new_output) 

5121 ) 

5122 

5123 return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple( 

5124 new_states 

5125 ) 

5126 

5127 final_outputs = tf.compat.v1.while_loop( 

5128 body=_step, 

5129 loop_vars=(time, output_ta, flat_zero_output) + states, 

5130 **while_loop_kwargs, 

5131 ) 

5132 # Skip final_outputs[2] which is the output for final timestep. 

5133 new_states = final_outputs[3:] 

5134 else: 

5135 

5136 def _step(time, output_ta_t, *states): 

5137 """RNN step function. 

5138 

5139 Args: 

5140 time: Current timestep value. 

5141 output_ta_t: TensorArray. 

5142 *states: List of states. 

5143 

5144 Returns: 

5145 Tuple: `(time + 1,output_ta_t) + tuple(new_states)` 

5146 """ 

5147 current_input = tuple(ta.read(time) for ta in input_ta) 

5148 current_input = tf.nest.pack_sequence_as(inputs, current_input) 

5149 output, new_states = step_function( 

5150 current_input, tuple(states) + tuple(constants) 

5151 ) 

5152 flat_state = tf.nest.flatten(states) 

5153 flat_new_state = tf.nest.flatten(new_states) 

5154 for state, new_state in zip(flat_state, flat_new_state): 

5155 if isinstance(new_state, tf.Tensor): 

5156 new_state.set_shape(state.shape) 

5157 

5158 flat_output = tf.nest.flatten(output) 

5159 ta_index_to_write = time if return_all_outputs else 0 

5160 output_ta_t = tuple( 

5161 ta.write(ta_index_to_write, out) 

5162 for ta, out in zip(output_ta_t, flat_output) 

5163 ) 

5164 

5165 new_states = tf.nest.pack_sequence_as( 

5166 initial_states, flat_new_state 

5167 ) 

5168 return (time + 1, output_ta_t) + tuple(new_states) 

5169 

5170 final_outputs = tf.compat.v1.while_loop( 

5171 body=_step, 

5172 loop_vars=(time, output_ta) + states, 

5173 **while_loop_kwargs, 

5174 ) 

5175 new_states = final_outputs[2:] 

5176 

5177 output_ta = final_outputs[1] 

5178 

5179 outputs = tuple(o.stack() for o in output_ta) 

5180 last_output = tuple(o[-1] for o in outputs) 

5181 

5182 outputs = tf.nest.pack_sequence_as(output_time_zero, outputs) 

5183 last_output = tf.nest.pack_sequence_as(output_time_zero, last_output) 

5184 

5185 # static shape inference 

5186 def set_shape(output_): 

5187 if isinstance(output_, tf.Tensor): 

5188 shape = output_.shape.as_list() 

5189 if return_all_outputs: 

5190 shape[0] = time_steps 

5191 else: 

5192 shape[0] = 1 

5193 shape[1] = batch 

5194 output_.set_shape(shape) 

5195 return output_ 

5196 

5197 outputs = tf.nest.map_structure(set_shape, outputs) 

5198 

5199 if not time_major: 

5200 outputs = tf.nest.map_structure(swap_batch_timestep, outputs) 

5201 

5202 return last_output, outputs, new_states 

5203 

5204 

5205@keras_export("keras.backend.switch") 

5206@tf.__internal__.dispatch.add_dispatch_support 

5207@doc_controls.do_not_generate_docs 

5208def switch(condition, then_expression, else_expression): 

5209 """Switches between two operations depending on a scalar value. 

5210 

5211 Note that both `then_expression` and `else_expression` 

5212 should be symbolic tensors of the *same shape*. 

5213 

5214 Args: 

5215 condition: tensor (`int` or `bool`). 

5216 then_expression: either a tensor, or a callable that returns a tensor. 

5217 else_expression: either a tensor, or a callable that returns a tensor. 

5218 

5219 Returns: 

5220 The selected tensor. 

5221 

5222 Raises: 

5223 ValueError: If rank of `condition` is greater than rank of expressions. 

5224 """ 

5225 if condition.dtype != tf.bool: 

5226 condition = tf.cast(condition, "bool") 

5227 cond_ndim = ndim(condition) 

5228 if not cond_ndim: 

5229 if not callable(then_expression): 

5230 

5231 def then_expression_fn(): 

5232 return then_expression 

5233 

5234 else: 

5235 then_expression_fn = then_expression 

5236 if not callable(else_expression): 

5237 

5238 def else_expression_fn(): 

5239 return else_expression 

5240 

5241 else: 

5242 else_expression_fn = else_expression 

5243 x = tf.compat.v1.cond(condition, then_expression_fn, else_expression_fn) 

5244 else: 

5245 # tf.where needs its condition tensor 

5246 # to be the same shape as its two 

5247 # result tensors 

5248 if callable(then_expression): 

5249 then_expression = then_expression() 

5250 if callable(else_expression): 

5251 else_expression = else_expression() 

5252 expr_ndim = ndim(then_expression) 

5253 if cond_ndim > expr_ndim: 

5254 raise ValueError( 

5255 "Rank of `condition` should be less than or" 

5256 " equal to rank of `then_expression` and " 

5257 "`else_expression`. ndim(condition)=" 

5258 + str(cond_ndim) 

5259 + ", ndim(then_expression)=" 

5260 + str(expr_ndim) 

5261 ) 

5262 if cond_ndim > 1: 

5263 ndim_diff = expr_ndim - cond_ndim 

5264 cond_shape = tf.concat( 

5265 [tf.shape(condition), [1] * ndim_diff], axis=0 

5266 ) 

5267 condition = tf.reshape(condition, cond_shape) 

5268 expr_shape = tf.shape(then_expression) 

5269 shape_diff = expr_shape - cond_shape 

5270 tile_shape = tf.where( 

5271 shape_diff > 0, expr_shape, tf.ones_like(expr_shape) 

5272 ) 

5273 condition = tf.tile(condition, tile_shape) 

5274 x = tf.where(condition, then_expression, else_expression) 

5275 return x 

5276 

5277 

5278@keras_export("keras.backend.in_train_phase") 

5279@doc_controls.do_not_generate_docs 

5280def in_train_phase(x, alt, training=None): 

5281 """Selects `x` in train phase, and `alt` otherwise. 

5282 

5283 Note that `alt` should have the *same shape* as `x`. 

5284 

5285 Args: 

5286 x: What to return in train phase 

5287 (tensor or callable that returns a tensor). 

5288 alt: What to return otherwise 

5289 (tensor or callable that returns a tensor). 

5290 training: Optional scalar tensor 

5291 (or Python boolean, or Python integer) 

5292 specifying the learning phase. 

5293 

5294 Returns: 

5295 Either `x` or `alt` based on the `training` flag. 

5296 the `training` flag defaults to `K.learning_phase()`. 

5297 """ 

5298 from keras.src.engine import ( 

5299 base_layer_utils, 

5300 ) 

5301 

5302 if training is None: 

5303 training = base_layer_utils.call_context().training 

5304 

5305 if training is None: 

5306 training = learning_phase() 

5307 

5308 # TODO(b/138862903): Handle the case when training is tensor. 

5309 if not tf.is_tensor(training): 

5310 if training == 1 or training is True: 

5311 if callable(x): 

5312 return x() 

5313 else: 

5314 return x 

5315 

5316 elif training == 0 or training is False: 

5317 if callable(alt): 

5318 return alt() 

5319 else: 

5320 return alt 

5321 

5322 # else: assume learning phase is a placeholder tensor. 

5323 x = switch(training, x, alt) 

5324 return x 

5325 

5326 

5327@keras_export("keras.backend.in_test_phase") 

5328@doc_controls.do_not_generate_docs 

5329def in_test_phase(x, alt, training=None): 

5330 """Selects `x` in test phase, and `alt` otherwise. 

5331 

5332 Note that `alt` should have the *same shape* as `x`. 

5333 

5334 Args: 

5335 x: What to return in test phase 

5336 (tensor or callable that returns a tensor). 

5337 alt: What to return otherwise 

5338 (tensor or callable that returns a tensor). 

5339 training: Optional scalar tensor 

5340 (or Python boolean, or Python integer) 

5341 specifying the learning phase. 

5342 

5343 Returns: 

5344 Either `x` or `alt` based on `K.learning_phase`. 

5345 """ 

5346 return in_train_phase(alt, x, training=training) 

5347 

5348 

5349# NN OPERATIONS 

5350 

5351 

5352@keras_export("keras.backend.relu") 

5353@tf.__internal__.dispatch.add_dispatch_support 

5354@doc_controls.do_not_generate_docs 

5355def relu(x, alpha=0.0, max_value=None, threshold=0.0): 

5356 """Rectified linear unit. 

5357 

5358 With default values, it returns element-wise `max(x, 0)`. 

5359 

5360 Otherwise, it follows: 

5361 `f(x) = max_value` for `x >= max_value`, 

5362 `f(x) = x` for `threshold <= x < max_value`, 

5363 `f(x) = alpha * (x - threshold)` otherwise. 

5364 

5365 Args: 

5366 x: A tensor or variable. 

5367 alpha: A scalar, slope of negative section (default=`0.`). 

5368 max_value: float. Saturation threshold. 

5369 threshold: float. Threshold value for thresholded activation. 

5370 

5371 Returns: 

5372 A tensor. 

5373 """ 

5374 # While x can be a tensor or variable, we also see cases where 

5375 # numpy arrays, lists, tuples are passed as well. 

5376 # lists, tuples do not have 'dtype' attribute. 

5377 dtype = getattr(x, "dtype", floatx()) 

5378 if alpha != 0.0: 

5379 if max_value is None and threshold == 0: 

5380 return tf.nn.leaky_relu(x, alpha=alpha) 

5381 

5382 if threshold != 0: 

5383 negative_part = tf.nn.relu(-x + threshold) 

5384 else: 

5385 negative_part = tf.nn.relu(-x) 

5386 

5387 clip_max = max_value is not None 

5388 

5389 if threshold != 0: 

5390 # computes x for x > threshold else 0 

5391 x = x * tf.cast(tf.greater(x, threshold), dtype=dtype) 

5392 elif max_value == 6: 

5393 # if no threshold, then can use nn.relu6 native TF op for performance 

5394 x = tf.nn.relu6(x) 

5395 clip_max = False 

5396 else: 

5397 x = tf.nn.relu(x) 

5398 

5399 if clip_max: 

5400 max_value = _constant_to_tensor(max_value, x.dtype.base_dtype) 

5401 zero = _constant_to_tensor(0, x.dtype.base_dtype) 

5402 x = tf.clip_by_value(x, zero, max_value) 

5403 

5404 if alpha != 0.0: 

5405 alpha = _to_tensor(alpha, x.dtype.base_dtype) 

5406 x -= alpha * negative_part 

5407 return x 

5408 

5409 

5410@keras_export("keras.backend.elu") 

5411@tf.__internal__.dispatch.add_dispatch_support 

5412@doc_controls.do_not_generate_docs 

5413def elu(x, alpha=1.0): 

5414 """Exponential linear unit. 

5415 

5416 Args: 

5417 x: A tensor or variable to compute the activation function for. 

5418 alpha: A scalar, slope of negative section. 

5419 

5420 Returns: 

5421 A tensor. 

5422 """ 

5423 res = tf.nn.elu(x) 

5424 if alpha == 1: 

5425 return res 

5426 else: 

5427 return tf.where(x > 0, res, alpha * res) 

5428 

5429 

5430@keras_export("keras.backend.softmax") 

5431@tf.__internal__.dispatch.add_dispatch_support 

5432@doc_controls.do_not_generate_docs 

5433def softmax(x, axis=-1): 

5434 """Softmax of a tensor. 

5435 

5436 Args: 

5437 x: A tensor or variable. 

5438 axis: The dimension softmax would be performed on. 

5439 The default is -1 which indicates the last dimension. 

5440 

5441 Returns: 

5442 A tensor. 

5443 """ 

5444 return tf.nn.softmax(x, axis=axis) 

5445 

5446 

5447@keras_export("keras.backend.softplus") 

5448@tf.__internal__.dispatch.add_dispatch_support 

5449@doc_controls.do_not_generate_docs 

5450def softplus(x): 

5451 """Softplus of a tensor. 

5452 

5453 Args: 

5454 x: A tensor or variable. 

5455 

5456 Returns: 

5457 A tensor. 

5458 """ 

5459 return tf.math.softplus(x) 

5460 

5461 

5462@keras_export("keras.backend.softsign") 

5463@tf.__internal__.dispatch.add_dispatch_support 

5464@doc_controls.do_not_generate_docs 

5465def softsign(x): 

5466 """Softsign of a tensor. 

5467 

5468 Args: 

5469 x: A tensor or variable. 

5470 

5471 Returns: 

5472 A tensor. 

5473 """ 

5474 return tf.math.softsign(x) 

5475 

5476 

5477def _get_logits(output, from_logits, op_type, fn_name): 

5478 output_ = output 

5479 from_logits_ = from_logits 

5480 

5481 has_keras_logits = hasattr(output, "_keras_logits") 

5482 if has_keras_logits: 

5483 output_ = output._keras_logits 

5484 from_logits_ = True 

5485 

5486 from_expected_op_type = ( 

5487 not isinstance(output, (tf.__internal__.EagerTensor, tf.Variable)) 

5488 and output.op.type == op_type 

5489 ) and not has_keras_logits 

5490 

5491 if from_expected_op_type: 

5492 # When softmax activation function is used for output operation, we 

5493 # use logits from the softmax function directly to compute loss in order 

5494 # to prevent collapsing zero when training. 

5495 # See b/117284466 

5496 assert len(output.op.inputs) == 1 

5497 output_ = output.op.inputs[0] 

5498 from_logits_ = True 

5499 

5500 if from_logits and (has_keras_logits or from_expected_op_type): 

5501 warnings.warn( 

5502 f'"`{fn_name}` received `from_logits=True`, but ' 

5503 f"the `output` argument was produced by a {op_type} " 

5504 "activation and thus does not represent logits. " 

5505 "Was this intended?", 

5506 stacklevel=2, 

5507 ) 

5508 

5509 return output_, from_logits_ 

5510 

5511 

5512@keras_export("keras.backend.categorical_crossentropy") 

5513@tf.__internal__.dispatch.add_dispatch_support 

5514@doc_controls.do_not_generate_docs 

5515def categorical_crossentropy(target, output, from_logits=False, axis=-1): 

5516 """Categorical crossentropy between an output tensor and a target tensor. 

5517 

5518 Args: 

5519 target: A tensor of the same shape as `output`. 

5520 output: A tensor resulting from a softmax 

5521 (unless `from_logits` is True, in which 

5522 case `output` is expected to be the logits). 

5523 from_logits: Boolean, whether `output` is the 

5524 result of a softmax, or is a tensor of logits. 

5525 axis: Int specifying the channels axis. `axis=-1` corresponds to data 

5526 format `channels_last`, and `axis=1` corresponds to data format 

5527 `channels_first`. 

5528 

5529 Returns: 

5530 Output tensor. 

5531 

5532 Raises: 

5533 ValueError: if `axis` is neither -1 nor one of the axes of `output`. 

5534 

5535 Example: 

5536 

5537 >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3]) 

5538 >>> print(a) 

5539 tf.Tensor( 

5540 [[1. 0. 0.] 

5541 [0. 1. 0.] 

5542 [0. 0. 1.]], shape=(3, 3), dtype=float32) 

5543 >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94], 

5544 ... shape=[3, 3]) 

5545 >>> print(b) 

5546 tf.Tensor( 

5547 [[0.9 0.05 0.05] 

5548 [0.05 0.89 0.06] 

5549 [0.05 0.01 0.94]], shape=(3, 3), dtype=float32) 

5550 >>> loss = tf.keras.backend.categorical_crossentropy(a, b) 

5551 >>> print(np.around(loss, 5)) 

5552 [0.10536 0.11653 0.06188] 

5553 >>> loss = tf.keras.backend.categorical_crossentropy(a, a) 

5554 >>> print(np.around(loss, 5)) 

5555 [0. 0. 0.] 

5556 

5557 """ 

5558 target = tf.convert_to_tensor(target) 

5559 output = tf.convert_to_tensor(output) 

5560 target.shape.assert_is_compatible_with(output.shape) 

5561 

5562 output, from_logits = _get_logits( 

5563 output, from_logits, "Softmax", "categorical_crossentropy" 

5564 ) 

5565 if from_logits: 

5566 return tf.nn.softmax_cross_entropy_with_logits( 

5567 labels=target, logits=output, axis=axis 

5568 ) 

5569 

5570 # Adjust the predictions so that the probability of 

5571 # each class for every sample adds up to 1 

5572 # This is needed to ensure that the cross entropy is 

5573 # computed correctly. 

5574 output = output / tf.reduce_sum(output, axis, True) 

5575 

5576 # Compute cross entropy from probabilities. 

5577 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype) 

5578 output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_) 

5579 return -tf.reduce_sum(target * tf.math.log(output), axis) 

5580 

5581 

5582@keras_export("keras.backend.categorical_focal_crossentropy") 

5583@tf.__internal__.dispatch.add_dispatch_support 

5584@doc_controls.do_not_generate_docs 

5585def categorical_focal_crossentropy( 

5586 target, 

5587 output, 

5588 alpha=0.25, 

5589 gamma=2.0, 

5590 from_logits=False, 

5591 axis=-1, 

5592): 

5593 """Computes the alpha balanced focal crossentropy loss. 

5594 

5595 According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it 

5596 helps to apply a focal factor to down-weight easy examples and focus more on 

5597 hard examples. The general formula for the focal loss (FL) 

5598 is as follows: 

5599 

5600 `FL(p_t) = (1 − p_t)^gamma * log(p_t)` 

5601 

5602 where `p_t` is defined as follows: 

5603 `p_t = output if y_true == 1, else 1 - output` 

5604 

5605 `(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing 

5606 parameter. When `gamma` = 0, there is no focal effect on the cross entropy. 

5607 `gamma` reduces the importance given to simple examples in a smooth manner. 

5608 

5609 The authors use alpha-balanced variant of focal loss (FL) in the paper: 

5610 `FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)` 

5611 

5612 where `alpha` is the weight factor for the classes. If `alpha` = 1, the 

5613 loss won't be able to handle class imbalance properly as all 

5614 classes will have the same weight. This can be a constant or a list of 

5615 constants. If alpha is a list, it must have the same length as the number 

5616 of classes. 

5617 

5618 The formula above can be generalized to: 

5619 `FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(target, output)` 

5620 

5621 where minus comes from `CrossEntropy(target, output)` (CE). 

5622 

5623 Extending this to multi-class case is straightforward: 

5624 `FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(target, output)` 

5625 

5626 Args: 

5627 target: Ground truth values from the dataset. 

5628 output: Predictions of the model. 

5629 alpha: A weight balancing factor for all classes, default is `0.25` as 

5630 mentioned in the reference. It can be a list of floats or a scalar. 

5631 In the multi-class case, alpha may be set by inverse class 

5632 frequency by using `compute_class_weight` from `sklearn.utils`. 

5633 gamma: A focusing parameter, default is `2.0` as mentioned in the 

5634 reference. It helps to gradually reduce the importance given to 

5635 simple examples in a smooth manner. 

5636 from_logits: Whether `output` is expected to be a logits tensor. By 

5637 default, we consider that `output` encodes a probability 

5638 distribution. 

5639 axis: Int specifying the channels axis. `axis=-1` corresponds to data 

5640 format `channels_last`, and `axis=1` corresponds to data format 

5641 `channels_first`. 

5642 

5643 Returns: 

5644 A tensor. 

5645 """ 

5646 target = tf.convert_to_tensor(target) 

5647 output = tf.convert_to_tensor(output) 

5648 target.shape.assert_is_compatible_with(output.shape) 

5649 

5650 output, from_logits = _get_logits( 

5651 output, from_logits, "Softmax", "categorical_focal_crossentropy" 

5652 ) 

5653 

5654 if from_logits: 

5655 output = tf.nn.softmax(output, axis=axis) 

5656 

5657 # Adjust the predictions so that the probability of 

5658 # each class for every sample adds up to 1 

5659 # This is needed to ensure that the cross entropy is 

5660 # computed correctly. 

5661 output = output / tf.reduce_sum(output, axis=axis, keepdims=True) 

5662 

5663 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype) 

5664 output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_) 

5665 

5666 # Calculate cross entropy 

5667 cce = -target * tf.math.log(output) 

5668 

5669 # Calculate factors 

5670 modulating_factor = tf.pow(1.0 - output, gamma) 

5671 weighting_factor = tf.multiply(modulating_factor, alpha) 

5672 

5673 # Apply weighting factor 

5674 focal_cce = tf.multiply(weighting_factor, cce) 

5675 focal_cce = tf.reduce_sum(focal_cce, axis=axis) 

5676 return focal_cce 

5677 

5678 

5679@keras_export("keras.backend.sparse_categorical_crossentropy") 

5680@tf.__internal__.dispatch.add_dispatch_support 

5681@doc_controls.do_not_generate_docs 

5682def sparse_categorical_crossentropy( 

5683 target, output, from_logits=False, axis=-1, ignore_class=None 

5684): 

5685 """Categorical crossentropy with integer targets. 

5686 

5687 Args: 

5688 target: An integer tensor. 

5689 output: A tensor resulting from a softmax 

5690 (unless `from_logits` is True, in which 

5691 case `output` is expected to be the logits). 

5692 from_logits: Boolean, whether `output` is the 

5693 result of a softmax, or is a tensor of logits. 

5694 axis: Int specifying the channels axis. `axis=-1` corresponds to data 

5695 format `channels_last`, and `axis=1` corresponds to data format 

5696 `channels_first`. 

5697 ignore_class: Optional integer. The ID of a class to be ignored 

5698 during loss computation. This is useful, for example, in 

5699 segmentation problems featuring a "void" class (commonly -1 

5700 or 255) in segmentation maps. 

5701 By default (`ignore_class=None`), all classes are considered. 

5702 

5703 Returns: 

5704 Output tensor. 

5705 

5706 Raises: 

5707 ValueError: if `axis` is neither -1 nor one of the axes of `output`. 

5708 """ 

5709 target = tf.convert_to_tensor(target) 

5710 output = tf.convert_to_tensor(output) 

5711 

5712 target = cast(target, "int64") 

5713 

5714 output, from_logits = _get_logits( 

5715 output, from_logits, "Softmax", "sparse_categorical_crossentropy" 

5716 ) 

5717 if not from_logits: 

5718 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype) 

5719 output = tf.clip_by_value(output, epsilon_, 1 - epsilon_) 

5720 output = tf.math.log(output) 

5721 

5722 # Permute output so that the last axis contains the logits/probabilities. 

5723 if isinstance(output.shape, (tuple, list)): 

5724 output_rank = len(output.shape) 

5725 else: 

5726 output_rank = output.shape.ndims 

5727 if output_rank is not None: 

5728 axis %= output_rank 

5729 if axis != output_rank - 1: 

5730 permutation = list( 

5731 itertools.chain( 

5732 range(axis), range(axis + 1, output_rank), [axis] 

5733 ) 

5734 ) 

5735 output = tf.compat.v1.transpose(output, perm=permutation) 

5736 elif axis != -1: 

5737 raise ValueError( 

5738 "Cannot compute sparse categorical crossentropy with `axis={}` " 

5739 "on an output tensor with unknown rank".format(axis) 

5740 ) 

5741 

5742 # Try to adjust the shape so that rank of labels = rank of logits - 1. 

5743 output_shape = tf.shape(output) 

5744 target_rank = target.shape.ndims 

5745 

5746 update_shape = ( 

5747 target_rank is not None 

5748 and output_rank is not None 

5749 and target_rank != output_rank - 1 

5750 ) 

5751 if update_shape: 

5752 target = flatten(target) 

5753 output = tf.reshape(output, [-1, output_shape[-1]]) 

5754 

5755 if ignore_class is not None: 

5756 valid_mask = tf.not_equal(target, cast(ignore_class, target.dtype)) 

5757 target = target[valid_mask] 

5758 output = output[valid_mask] 

5759 

5760 if py_any(_is_symbolic_tensor(v) for v in [target, output]): 

5761 with get_graph().as_default(): 

5762 res = tf.nn.sparse_softmax_cross_entropy_with_logits( 

5763 labels=target, logits=output 

5764 ) 

5765 else: 

5766 res = tf.nn.sparse_softmax_cross_entropy_with_logits( 

5767 labels=target, logits=output 

5768 ) 

5769 

5770 if ignore_class is not None: 

5771 res_shape = cast(output_shape[:-1], "int64") 

5772 valid_mask = tf.reshape(valid_mask, res_shape) 

5773 res = tf.scatter_nd(tf.where(valid_mask), res, res_shape) 

5774 res._keras_mask = valid_mask 

5775 

5776 return res 

5777 

5778 if update_shape and output_rank >= 3: 

5779 # If our output includes timesteps or 

5780 # spatial dimensions we need to reshape 

5781 res = tf.reshape(res, output_shape[:-1]) 

5782 

5783 return res 

5784 

5785 

5786@keras_export("keras.backend.binary_crossentropy") 

5787@tf.__internal__.dispatch.add_dispatch_support 

5788@doc_controls.do_not_generate_docs 

5789def binary_crossentropy(target, output, from_logits=False): 

5790 """Binary crossentropy between an output tensor and a target tensor. 

5791 

5792 Args: 

5793 target: A tensor with the same shape as `output`. 

5794 output: A tensor. 

5795 from_logits: Whether `output` is expected to be a logits tensor. 

5796 By default, we consider that `output` 

5797 encodes a probability distribution. 

5798 

5799 Returns: 

5800 A tensor. 

5801 """ 

5802 target = tf.convert_to_tensor(target) 

5803 output = tf.convert_to_tensor(output) 

5804 

5805 output, from_logits = _get_logits( 

5806 output, from_logits, "Sigmoid", "binary_crossentropy" 

5807 ) 

5808 if from_logits: 

5809 return tf.nn.sigmoid_cross_entropy_with_logits( 

5810 labels=target, logits=output 

5811 ) 

5812 

5813 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype) 

5814 output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_) 

5815 

5816 # Compute cross entropy from probabilities. 

5817 bce = target * tf.math.log(output + epsilon()) 

5818 bce += (1 - target) * tf.math.log(1 - output + epsilon()) 

5819 return -bce 

5820 

5821 

5822@keras_export("keras.backend.binary_focal_crossentropy") 

5823@tf.__internal__.dispatch.add_dispatch_support 

5824@doc_controls.do_not_generate_docs 

5825def binary_focal_crossentropy( 

5826 target, 

5827 output, 

5828 apply_class_balancing=False, 

5829 alpha=0.25, 

5830 gamma=2.0, 

5831 from_logits=False, 

5832): 

5833 """Binary focal crossentropy between an output tensor and a target tensor. 

5834 

5835 According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it 

5836 helps to apply a focal factor to down-weight easy examples and focus more on 

5837 hard examples. By default, the focal tensor is computed as follows: 

5838 

5839 `focal_factor = (1 - output) ** gamma` for class 1 

5840 `focal_factor = output ** gamma` for class 0 

5841 where `gamma` is a focusing parameter. When `gamma` = 0, there is no focal 

5842 effect on the binary crossentropy. 

5843 

5844 If `apply_class_balancing == True`, this function also takes into account a 

5845 weight balancing factor for the binary classes 0 and 1 as follows: 

5846 

5847 `weight = alpha` for class 1 (`target == 1`) 

5848 `weight = 1 - alpha` for class 0 

5849 where `alpha` is a float in the range of `[0, 1]`. 

5850 

5851 Args: 

5852 target: A tensor with the same shape as `output`. 

5853 output: A tensor. 

5854 apply_class_balancing: A bool, whether to apply weight balancing on the 

5855 binary classes 0 and 1. 

5856 alpha: A weight balancing factor for class 1, default is `0.25` as 

5857 mentioned in the reference. The weight for class 0 is `1.0 - alpha`. 

5858 gamma: A focusing parameter, default is `2.0` as mentioned in the 

5859 reference. 

5860 from_logits: Whether `output` is expected to be a logits tensor. By 

5861 default, we consider that `output` encodes a probability 

5862 distribution. 

5863 

5864 Returns: 

5865 A tensor. 

5866 """ 

5867 

5868 sigmoidal = sigmoid(output) if from_logits else output 

5869 

5870 p_t = target * sigmoidal + (1 - target) * (1 - sigmoidal) 

5871 

5872 # Calculate focal factor 

5873 focal_factor = tf.pow(1.0 - p_t, gamma) 

5874 

5875 # Binary crossentropy 

5876 bce = binary_crossentropy( 

5877 target=target, 

5878 output=output, 

5879 from_logits=from_logits, 

5880 ) 

5881 focal_bce = focal_factor * bce 

5882 

5883 if apply_class_balancing: 

5884 weight = target * alpha + (1 - target) * (1 - alpha) 

5885 focal_bce = weight * focal_bce 

5886 

5887 return focal_bce 

5888 

5889 

5890@keras_export("keras.backend.sigmoid") 

5891@tf.__internal__.dispatch.add_dispatch_support 

5892@doc_controls.do_not_generate_docs 

5893def sigmoid(x): 

5894 """Element-wise sigmoid. 

5895 

5896 Args: 

5897 x: A tensor or variable. 

5898 

5899 Returns: 

5900 A tensor. 

5901 """ 

5902 return tf.math.sigmoid(x) 

5903 

5904 

5905@keras_export("keras.backend.hard_sigmoid") 

5906@tf.__internal__.dispatch.add_dispatch_support 

5907@doc_controls.do_not_generate_docs 

5908def hard_sigmoid(x): 

5909 """Segment-wise linear approximation of sigmoid. 

5910 

5911 Faster than sigmoid. 

5912 Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`. 

5913 In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`. 

5914 

5915 Args: 

5916 x: A tensor or variable. 

5917 

5918 Returns: 

5919 A tensor. 

5920 """ 

5921 point_two = _constant_to_tensor(0.2, x.dtype.base_dtype) 

5922 point_five = _constant_to_tensor(0.5, x.dtype.base_dtype) 

5923 x = tf.multiply(x, point_two) 

5924 x = tf.add(x, point_five) 

5925 x = tf.clip_by_value(x, 0.0, 1.0) 

5926 return x 

5927 

5928 

5929@keras_export("keras.backend.tanh") 

5930@tf.__internal__.dispatch.add_dispatch_support 

5931@doc_controls.do_not_generate_docs 

5932def tanh(x): 

5933 """Element-wise tanh. 

5934 

5935 Args: 

5936 x: A tensor or variable. 

5937 

5938 Returns: 

5939 A tensor. 

5940 """ 

5941 return tf.tanh(x) 

5942 

5943 

5944@keras_export("keras.backend.dropout") 

5945@tf.__internal__.dispatch.add_dispatch_support 

5946@doc_controls.do_not_generate_docs 

5947def dropout(x, level, noise_shape=None, seed=None): 

5948 """Sets entries in `x` to zero at random, while scaling the entire tensor. 

5949 

5950 Args: 

5951 x: tensor 

5952 level: fraction of the entries in the tensor 

5953 that will be set to 0. 

5954 noise_shape: shape for randomly generated keep/drop flags, 

5955 must be broadcastable to the shape of `x` 

5956 seed: random seed to ensure determinism. 

5957 

5958 Returns: 

5959 A tensor. 

5960 """ 

5961 if seed is None: 

5962 seed = np.random.randint(10e6) 

5963 return tf.nn.dropout(x, rate=level, noise_shape=noise_shape, seed=seed) 

5964 

5965 

5966@keras_export("keras.backend.l2_normalize") 

5967@tf.__internal__.dispatch.add_dispatch_support 

5968@doc_controls.do_not_generate_docs 

5969def l2_normalize(x, axis=None): 

5970 """Normalizes a tensor wrt the L2 norm alongside the specified axis. 

5971 

5972 Args: 

5973 x: Tensor or variable. 

5974 axis: axis along which to perform normalization. 

5975 

5976 Returns: 

5977 A tensor. 

5978 """ 

5979 return tf.linalg.l2_normalize(x, axis=axis) 

5980 

5981 

5982@keras_export("keras.backend.in_top_k") 

5983@tf.__internal__.dispatch.add_dispatch_support 

5984@doc_controls.do_not_generate_docs 

5985def in_top_k(predictions, targets, k): 

5986 """Returns whether the `targets` are in the top `k` `predictions`. 

5987 

5988 Args: 

5989 predictions: A tensor of shape `(batch_size, classes)` and type 

5990 `float32`. 

5991 targets: A 1D tensor of length `batch_size` and type `int32` or `int64`. 

5992 k: An `int`, number of top elements to consider. 

5993 

5994 Returns: 

5995 A 1D tensor of length `batch_size` and type `bool`. 

5996 `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k` 

5997 values of `predictions[i]`. 

5998 """ 

5999 return tf.compat.v1.math.in_top_k(predictions, targets, k) 

6000 

6001 

6002# CONVOLUTIONS 

6003 

6004 

6005def _preprocess_conv1d_input(x, data_format): 

6006 """Transpose and cast the input before the conv1d. 

6007 

6008 Args: 

6009 x: input tensor. 

6010 data_format: string, `"channels_last"` or `"channels_first"`. 

6011 

6012 Returns: 

6013 A tensor. 

6014 """ 

6015 tf_data_format = "NWC" # to pass TF Conv2dNative operations 

6016 if data_format == "channels_first": 

6017 if not _has_nchw_support(): 

6018 x = tf.compat.v1.transpose(x, (0, 2, 1)) # NCW -> NWC 

6019 else: 

6020 tf_data_format = "NCW" 

6021 return x, tf_data_format 

6022 

6023 

6024def _preprocess_conv2d_input(x, data_format, force_transpose=False): 

6025 """Transpose and cast the input before the conv2d. 

6026 

6027 Args: 

6028 x: input tensor. 

6029 data_format: string, `"channels_last"` or `"channels_first"`. 

6030 force_transpose: Boolean. If True, the input will always be transposed 

6031 from NCHW to NHWC if `data_format` is `"channels_first"`. 

6032 If False, the transposition only occurs on CPU (GPU ops are 

6033 assumed to support NCHW). 

6034 

6035 Returns: 

6036 A tensor. 

6037 """ 

6038 tf_data_format = "NHWC" 

6039 if data_format == "channels_first": 

6040 if not _has_nchw_support() or force_transpose: 

6041 x = tf.compat.v1.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC 

6042 else: 

6043 tf_data_format = "NCHW" 

6044 return x, tf_data_format 

6045 

6046 

6047def _preprocess_conv3d_input(x, data_format): 

6048 """Transpose and cast the input before the conv3d. 

6049 

6050 Args: 

6051 x: input tensor. 

6052 data_format: string, `"channels_last"` or `"channels_first"`. 

6053 

6054 Returns: 

6055 A tensor. 

6056 """ 

6057 tf_data_format = "NDHWC" 

6058 if data_format == "channels_first": 

6059 if not _has_nchw_support(): 

6060 x = tf.compat.v1.transpose(x, (0, 2, 3, 4, 1)) 

6061 else: 

6062 tf_data_format = "NCDHW" 

6063 return x, tf_data_format 

6064 

6065 

6066def _preprocess_padding(padding): 

6067 """Convert keras' padding to TensorFlow's padding. 

6068 

6069 Args: 

6070 padding: string, one of 'same' , 'valid' 

6071 

6072 Returns: 

6073 a string, one of 'SAME', 'VALID'. 

6074 

6075 Raises: 

6076 ValueError: if invalid `padding'` 

6077 """ 

6078 if padding == "same": 

6079 padding = "SAME" 

6080 elif padding == "valid": 

6081 padding = "VALID" 

6082 else: 

6083 raise ValueError("Invalid padding: " + str(padding)) 

6084 return padding 

6085 

6086 

6087@keras_export("keras.backend.conv1d") 

6088@tf.__internal__.dispatch.add_dispatch_support 

6089@doc_controls.do_not_generate_docs 

6090def conv1d( 

6091 x, kernel, strides=1, padding="valid", data_format=None, dilation_rate=1 

6092): 

6093 """1D convolution. 

6094 

6095 Args: 

6096 x: Tensor or variable. 

6097 kernel: kernel tensor. 

6098 strides: stride integer. 

6099 padding: string, `"same"`, `"causal"` or `"valid"`. 

6100 data_format: string, one of "channels_last", "channels_first". 

6101 dilation_rate: integer dilate rate. 

6102 

6103 Returns: 

6104 A tensor, result of 1D convolution. 

6105 

6106 Raises: 

6107 ValueError: if `data_format` is neither `channels_last` or 

6108 `channels_first`. 

6109 """ 

6110 if data_format is None: 

6111 data_format = image_data_format() 

6112 if data_format not in {"channels_first", "channels_last"}: 

6113 raise ValueError("Unknown data_format: " + str(data_format)) 

6114 

6115 kernel_shape = kernel.shape.as_list() 

6116 if padding == "causal": 

6117 # causal (dilated) convolution: 

6118 left_pad = dilation_rate * (kernel_shape[0] - 1) 

6119 x = temporal_padding(x, (left_pad, 0)) 

6120 padding = "valid" 

6121 padding = _preprocess_padding(padding) 

6122 

6123 x, tf_data_format = _preprocess_conv1d_input(x, data_format) 

6124 x = tf.compat.v1.nn.convolution( 

6125 input=x, 

6126 filter=kernel, 

6127 dilation_rate=dilation_rate, 

6128 strides=strides, 

6129 padding=padding, 

6130 data_format=tf_data_format, 

6131 ) 

6132 if data_format == "channels_first" and tf_data_format == "NWC": 

6133 x = tf.compat.v1.transpose(x, (0, 2, 1)) # NWC -> NCW 

6134 return x 

6135 

6136 

6137@keras_export("keras.backend.conv2d") 

6138@tf.__internal__.dispatch.add_dispatch_support 

6139@doc_controls.do_not_generate_docs 

6140def conv2d( 

6141 x, 

6142 kernel, 

6143 strides=(1, 1), 

6144 padding="valid", 

6145 data_format=None, 

6146 dilation_rate=(1, 1), 

6147): 

6148 """2D convolution. 

6149 

6150 Args: 

6151 x: Tensor or variable. 

6152 kernel: kernel tensor. 

6153 strides: strides tuple. 

6154 padding: string, `"same"` or `"valid"`. 

6155 data_format: `"channels_last"` or `"channels_first"`. 

6156 dilation_rate: tuple of 2 integers. 

6157 

6158 Returns: 

6159 A tensor, result of 2D convolution. 

6160 

6161 Raises: 

6162 ValueError: if `data_format` is neither `channels_last` or 

6163 `channels_first`. 

6164 """ 

6165 if data_format is None: 

6166 data_format = image_data_format() 

6167 if data_format not in {"channels_first", "channels_last"}: 

6168 raise ValueError("Unknown data_format: " + str(data_format)) 

6169 

6170 x, tf_data_format = _preprocess_conv2d_input(x, data_format) 

6171 padding = _preprocess_padding(padding) 

6172 x = tf.compat.v1.nn.convolution( 

6173 input=x, 

6174 filter=kernel, 

6175 dilation_rate=dilation_rate, 

6176 strides=strides, 

6177 padding=padding, 

6178 data_format=tf_data_format, 

6179 ) 

6180 if data_format == "channels_first" and tf_data_format == "NHWC": 

6181 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW 

6182 return x 

6183 

6184 

6185@keras_export("keras.backend.conv2d_transpose") 

6186@tf.__internal__.dispatch.add_dispatch_support 

6187@doc_controls.do_not_generate_docs 

6188def conv2d_transpose( 

6189 x, 

6190 kernel, 

6191 output_shape, 

6192 strides=(1, 1), 

6193 padding="valid", 

6194 data_format=None, 

6195 dilation_rate=(1, 1), 

6196): 

6197 """2D deconvolution (i.e. 

6198 

6199 transposed convolution). 

6200 

6201 Args: 

6202 x: Tensor or variable. 

6203 kernel: kernel tensor. 

6204 output_shape: 1D int tensor for the output shape. 

6205 strides: strides tuple. 

6206 padding: string, `"same"` or `"valid"`. 

6207 data_format: string, `"channels_last"` or `"channels_first"`. 

6208 dilation_rate: Tuple of 2 integers. 

6209 

6210 Returns: 

6211 A tensor, result of transposed 2D convolution. 

6212 

6213 Raises: 

6214 ValueError: if `data_format` is neither `channels_last` or 

6215 `channels_first`. 

6216 """ 

6217 if data_format is None: 

6218 data_format = image_data_format() 

6219 if data_format not in {"channels_first", "channels_last"}: 

6220 raise ValueError("Unknown data_format: " + str(data_format)) 

6221 

6222 # `atrous_conv2d_transpose` only supports NHWC format, even on GPU. 

6223 if data_format == "channels_first" and dilation_rate != (1, 1): 

6224 force_transpose = True 

6225 else: 

6226 force_transpose = False 

6227 

6228 x, tf_data_format = _preprocess_conv2d_input( 

6229 x, data_format, force_transpose 

6230 ) 

6231 

6232 if data_format == "channels_first" and tf_data_format == "NHWC": 

6233 output_shape = ( 

6234 output_shape[0], 

6235 output_shape[2], 

6236 output_shape[3], 

6237 output_shape[1], 

6238 ) 

6239 if output_shape[0] is None: 

6240 output_shape = (shape(x)[0],) + tuple(output_shape[1:]) 

6241 

6242 if isinstance(output_shape, (tuple, list)): 

6243 output_shape = tf.stack(list(output_shape)) 

6244 

6245 padding = _preprocess_padding(padding) 

6246 if tf_data_format == "NHWC": 

6247 strides = (1,) + strides + (1,) 

6248 else: 

6249 strides = (1, 1) + strides 

6250 

6251 if dilation_rate == (1, 1): 

6252 x = tf.compat.v1.nn.conv2d_transpose( 

6253 x, 

6254 kernel, 

6255 output_shape, 

6256 strides, 

6257 padding=padding, 

6258 data_format=tf_data_format, 

6259 ) 

6260 else: 

6261 if dilation_rate[0] != dilation_rate[1]: 

6262 raise ValueError( 

6263 "Expected the 2 dimensions of the `dilation_rate` argument " 

6264 "to be equal to each other. " 

6265 f"Received: dilation_rate={dilation_rate}" 

6266 ) 

6267 x = tf.nn.atrous_conv2d_transpose( 

6268 x, kernel, output_shape, rate=dilation_rate[0], padding=padding 

6269 ) 

6270 if data_format == "channels_first" and tf_data_format == "NHWC": 

6271 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW 

6272 return x 

6273 

6274 

6275def separable_conv1d( 

6276 x, 

6277 depthwise_kernel, 

6278 pointwise_kernel, 

6279 strides=1, 

6280 padding="valid", 

6281 data_format=None, 

6282 dilation_rate=1, 

6283): 

6284 """1D convolution with separable filters. 

6285 

6286 Args: 

6287 x: input tensor 

6288 depthwise_kernel: convolution kernel for the depthwise convolution. 

6289 pointwise_kernel: kernel for the 1x1 convolution. 

6290 strides: stride integer. 

6291 padding: string, `"same"` or `"valid"`. 

6292 data_format: string, `"channels_last"` or `"channels_first"`. 

6293 dilation_rate: integer dilation rate. 

6294 

6295 Returns: 

6296 Output tensor. 

6297 

6298 Raises: 

6299 ValueError: if `data_format` is neither `channels_last` or 

6300 `channels_first`. 

6301 """ 

6302 if data_format is None: 

6303 data_format = image_data_format() 

6304 if data_format not in {"channels_first", "channels_last"}: 

6305 raise ValueError("Unknown data_format: " + str(data_format)) 

6306 

6307 if isinstance(strides, int): 

6308 strides = (strides,) 

6309 if isinstance(dilation_rate, int): 

6310 dilation_rate = (dilation_rate,) 

6311 

6312 x, tf_data_format = _preprocess_conv1d_input(x, data_format) 

6313 padding = _preprocess_padding(padding) 

6314 if not isinstance(strides, tuple): 

6315 strides = tuple(strides) 

6316 if tf_data_format == "NWC": 

6317 spatial_start_dim = 1 

6318 strides = (1,) + strides * 2 + (1,) 

6319 else: 

6320 spatial_start_dim = 2 

6321 strides = (1, 1) + strides * 2 

6322 x = tf.expand_dims(x, spatial_start_dim) 

6323 depthwise_kernel = tf.expand_dims(depthwise_kernel, 0) 

6324 pointwise_kernel = tf.expand_dims(pointwise_kernel, 0) 

6325 dilation_rate = (1,) + dilation_rate 

6326 

6327 x = tf.nn.separable_conv2d( 

6328 x, 

6329 depthwise_kernel, 

6330 pointwise_kernel, 

6331 strides=strides, 

6332 padding=padding, 

6333 dilations=dilation_rate, 

6334 data_format=tf_data_format, 

6335 ) 

6336 

6337 x = tf.squeeze(x, [spatial_start_dim]) 

6338 

6339 if data_format == "channels_first" and tf_data_format == "NWC": 

6340 x = tf.compat.v1.transpose(x, (0, 2, 1)) # NWC -> NCW 

6341 

6342 return x 

6343 

6344 

6345@keras_export("keras.backend.separable_conv2d") 

6346@tf.__internal__.dispatch.add_dispatch_support 

6347@doc_controls.do_not_generate_docs 

6348def separable_conv2d( 

6349 x, 

6350 depthwise_kernel, 

6351 pointwise_kernel, 

6352 strides=(1, 1), 

6353 padding="valid", 

6354 data_format=None, 

6355 dilation_rate=(1, 1), 

6356): 

6357 """2D convolution with separable filters. 

6358 

6359 Args: 

6360 x: input tensor 

6361 depthwise_kernel: convolution kernel for the depthwise convolution. 

6362 pointwise_kernel: kernel for the 1x1 convolution. 

6363 strides: strides tuple (length 2). 

6364 padding: string, `"same"` or `"valid"`. 

6365 data_format: string, `"channels_last"` or `"channels_first"`. 

6366 dilation_rate: tuple of integers, 

6367 dilation rates for the separable convolution. 

6368 

6369 Returns: 

6370 Output tensor. 

6371 

6372 Raises: 

6373 ValueError: if `data_format` is neither `channels_last` or 

6374 `channels_first`. 

6375 ValueError: if `strides` is not a tuple of 2 integers. 

6376 """ 

6377 if data_format is None: 

6378 data_format = image_data_format() 

6379 if data_format not in {"channels_first", "channels_last"}: 

6380 raise ValueError("Unknown data_format: " + str(data_format)) 

6381 if len(strides) != 2: 

6382 raise ValueError("`strides` must be a tuple of 2 integers.") 

6383 

6384 x, tf_data_format = _preprocess_conv2d_input(x, data_format) 

6385 padding = _preprocess_padding(padding) 

6386 if not isinstance(strides, tuple): 

6387 strides = tuple(strides) 

6388 if tf_data_format == "NHWC": 

6389 strides = (1,) + strides + (1,) 

6390 else: 

6391 strides = (1, 1) + strides 

6392 

6393 x = tf.nn.separable_conv2d( 

6394 x, 

6395 depthwise_kernel, 

6396 pointwise_kernel, 

6397 strides=strides, 

6398 padding=padding, 

6399 dilations=dilation_rate, 

6400 data_format=tf_data_format, 

6401 ) 

6402 if data_format == "channels_first" and tf_data_format == "NHWC": 

6403 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW 

6404 return x 

6405 

6406 

6407@keras_export("keras.backend.depthwise_conv2d") 

6408@tf.__internal__.dispatch.add_dispatch_support 

6409@doc_controls.do_not_generate_docs 

6410def depthwise_conv2d( 

6411 x, 

6412 depthwise_kernel, 

6413 strides=(1, 1), 

6414 padding="valid", 

6415 data_format=None, 

6416 dilation_rate=(1, 1), 

6417): 

6418 """2D convolution with separable filters. 

6419 

6420 Args: 

6421 x: input tensor 

6422 depthwise_kernel: convolution kernel for the depthwise convolution. 

6423 strides: strides tuple (length 2). 

6424 padding: string, `"same"` or `"valid"`. 

6425 data_format: string, `"channels_last"` or `"channels_first"`. 

6426 dilation_rate: tuple of integers, 

6427 dilation rates for the separable convolution. 

6428 

6429 Returns: 

6430 Output tensor. 

6431 

6432 Raises: 

6433 ValueError: if `data_format` is neither `channels_last` or 

6434 `channels_first`. 

6435 """ 

6436 if data_format is None: 

6437 data_format = image_data_format() 

6438 if data_format not in {"channels_first", "channels_last"}: 

6439 raise ValueError("Unknown data_format: " + str(data_format)) 

6440 

6441 x, tf_data_format = _preprocess_conv2d_input(x, data_format) 

6442 padding = _preprocess_padding(padding) 

6443 if tf_data_format == "NHWC": 

6444 strides = (1,) + strides + (1,) 

6445 else: 

6446 strides = (1, 1) + strides 

6447 

6448 x = tf.nn.depthwise_conv2d( 

6449 x, 

6450 depthwise_kernel, 

6451 strides=strides, 

6452 padding=padding, 

6453 dilations=dilation_rate, 

6454 data_format=tf_data_format, 

6455 ) 

6456 if data_format == "channels_first" and tf_data_format == "NHWC": 

6457 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW 

6458 return x 

6459 

6460 

6461@keras_export("keras.backend.conv3d") 

6462@tf.__internal__.dispatch.add_dispatch_support 

6463@doc_controls.do_not_generate_docs 

6464def conv3d( 

6465 x, 

6466 kernel, 

6467 strides=(1, 1, 1), 

6468 padding="valid", 

6469 data_format=None, 

6470 dilation_rate=(1, 1, 1), 

6471): 

6472 """3D convolution. 

6473 

6474 Args: 

6475 x: Tensor or variable. 

6476 kernel: kernel tensor. 

6477 strides: strides tuple. 

6478 padding: string, `"same"` or `"valid"`. 

6479 data_format: string, `"channels_last"` or `"channels_first"`. 

6480 dilation_rate: tuple of 3 integers. 

6481 

6482 Returns: 

6483 A tensor, result of 3D convolution. 

6484 

6485 Raises: 

6486 ValueError: if `data_format` is neither `channels_last` or 

6487 `channels_first`. 

6488 """ 

6489 if data_format is None: 

6490 data_format = image_data_format() 

6491 if data_format not in {"channels_first", "channels_last"}: 

6492 raise ValueError("Unknown data_format: " + str(data_format)) 

6493 

6494 x, tf_data_format = _preprocess_conv3d_input(x, data_format) 

6495 padding = _preprocess_padding(padding) 

6496 x = tf.compat.v1.nn.convolution( 

6497 input=x, 

6498 filter=kernel, 

6499 dilation_rate=dilation_rate, 

6500 strides=strides, 

6501 padding=padding, 

6502 data_format=tf_data_format, 

6503 ) 

6504 if data_format == "channels_first" and tf_data_format == "NDHWC": 

6505 x = tf.compat.v1.transpose(x, (0, 4, 1, 2, 3)) 

6506 return x 

6507 

6508 

6509def conv3d_transpose( 

6510 x, 

6511 kernel, 

6512 output_shape, 

6513 strides=(1, 1, 1), 

6514 padding="valid", 

6515 data_format=None, 

6516): 

6517 """3D deconvolution (i.e. 

6518 

6519 transposed convolution). 

6520 

6521 Args: 

6522 x: input tensor. 

6523 kernel: kernel tensor. 

6524 output_shape: 1D int tensor for the output shape. 

6525 strides: strides tuple. 

6526 padding: string, "same" or "valid". 

6527 data_format: string, `"channels_last"` or `"channels_first"`. 

6528 

6529 Returns: 

6530 A tensor, result of transposed 3D convolution. 

6531 

6532 Raises: 

6533 ValueError: if `data_format` is neither `channels_last` or 

6534 `channels_first`. 

6535 """ 

6536 if data_format is None: 

6537 data_format = image_data_format() 

6538 if data_format not in {"channels_first", "channels_last"}: 

6539 raise ValueError("Unknown data_format: " + str(data_format)) 

6540 if isinstance(output_shape, (tuple, list)): 

6541 output_shape = tf.stack(output_shape) 

6542 

6543 x, tf_data_format = _preprocess_conv3d_input(x, data_format) 

6544 

6545 if data_format == "channels_first" and tf_data_format == "NDHWC": 

6546 output_shape = ( 

6547 output_shape[0], 

6548 output_shape[2], 

6549 output_shape[3], 

6550 output_shape[4], 

6551 output_shape[1], 

6552 ) 

6553 if output_shape[0] is None: 

6554 output_shape = (tf.shape(x)[0],) + tuple(output_shape[1:]) 

6555 output_shape = tf.stack(list(output_shape)) 

6556 

6557 padding = _preprocess_padding(padding) 

6558 if tf_data_format == "NDHWC": 

6559 strides = (1,) + strides + (1,) 

6560 else: 

6561 strides = (1, 1) + strides 

6562 

6563 x = tf.compat.v1.nn.conv3d_transpose( 

6564 x, 

6565 kernel, 

6566 output_shape, 

6567 strides, 

6568 padding=padding, 

6569 data_format=tf_data_format, 

6570 ) 

6571 if data_format == "channels_first" and tf_data_format == "NDHWC": 

6572 x = tf.compat.v1.transpose(x, (0, 4, 1, 2, 3)) 

6573 return x 

6574 

6575 

6576@keras_export("keras.backend.pool2d") 

6577@tf.__internal__.dispatch.add_dispatch_support 

6578@doc_controls.do_not_generate_docs 

6579def pool2d( 

6580 x, 

6581 pool_size, 

6582 strides=(1, 1), 

6583 padding="valid", 

6584 data_format=None, 

6585 pool_mode="max", 

6586): 

6587 """2D Pooling. 

6588 

6589 Args: 

6590 x: Tensor or variable. 

6591 pool_size: tuple of 2 integers. 

6592 strides: tuple of 2 integers. 

6593 padding: string, `"same"` or `"valid"`. 

6594 data_format: string, `"channels_last"` or `"channels_first"`. 

6595 pool_mode: string, `"max"` or `"avg"`. 

6596 

6597 Returns: 

6598 A tensor, result of 2D pooling. 

6599 

6600 Raises: 

6601 ValueError: if `data_format` is neither `"channels_last"` or 

6602 `"channels_first"`. 

6603 ValueError: if `pool_size` is not a tuple of 2 integers. 

6604 ValueError: if `strides` is not a tuple of 2 integers. 

6605 ValueError: if `pool_mode` is neither `"max"` or `"avg"`. 

6606 """ 

6607 if data_format is None: 

6608 data_format = image_data_format() 

6609 if data_format not in {"channels_first", "channels_last"}: 

6610 raise ValueError("Unknown data_format: " + str(data_format)) 

6611 if len(pool_size) != 2: 

6612 raise ValueError("`pool_size` must be a tuple of 2 integers.") 

6613 if len(strides) != 2: 

6614 raise ValueError("`strides` must be a tuple of 2 integers.") 

6615 

6616 x, tf_data_format = _preprocess_conv2d_input(x, data_format) 

6617 padding = _preprocess_padding(padding) 

6618 if tf_data_format == "NHWC": 

6619 strides = (1,) + strides + (1,) 

6620 pool_size = (1,) + pool_size + (1,) 

6621 else: 

6622 strides = (1, 1) + strides 

6623 pool_size = (1, 1) + pool_size 

6624 

6625 if pool_mode == "max": 

6626 x = tf.compat.v1.nn.max_pool( 

6627 x, pool_size, strides, padding=padding, data_format=tf_data_format 

6628 ) 

6629 elif pool_mode == "avg": 

6630 x = tf.compat.v1.nn.avg_pool( 

6631 x, pool_size, strides, padding=padding, data_format=tf_data_format 

6632 ) 

6633 else: 

6634 raise ValueError("Invalid pooling mode: " + str(pool_mode)) 

6635 

6636 if data_format == "channels_first" and tf_data_format == "NHWC": 

6637 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW 

6638 return x 

6639 

6640 

6641@keras_export("keras.backend.pool3d") 

6642@tf.__internal__.dispatch.add_dispatch_support 

6643@doc_controls.do_not_generate_docs 

6644def pool3d( 

6645 x, 

6646 pool_size, 

6647 strides=(1, 1, 1), 

6648 padding="valid", 

6649 data_format=None, 

6650 pool_mode="max", 

6651): 

6652 """3D Pooling. 

6653 

6654 Args: 

6655 x: Tensor or variable. 

6656 pool_size: tuple of 3 integers. 

6657 strides: tuple of 3 integers. 

6658 padding: string, `"same"` or `"valid"`. 

6659 data_format: string, `"channels_last"` or `"channels_first"`. 

6660 pool_mode: string, `"max"` or `"avg"`. 

6661 

6662 Returns: 

6663 A tensor, result of 3D pooling. 

6664 

6665 Raises: 

6666 ValueError: if `data_format` is neither `"channels_last"` or 

6667 `"channels_first"`. 

6668 ValueError: if `pool_mode` is neither `"max"` or `"avg"`. 

6669 """ 

6670 if data_format is None: 

6671 data_format = image_data_format() 

6672 if data_format not in {"channels_first", "channels_last"}: 

6673 raise ValueError("Unknown data_format: " + str(data_format)) 

6674 

6675 x, tf_data_format = _preprocess_conv3d_input(x, data_format) 

6676 padding = _preprocess_padding(padding) 

6677 if tf_data_format == "NDHWC": 

6678 strides = (1,) + strides + (1,) 

6679 pool_size = (1,) + pool_size + (1,) 

6680 else: 

6681 strides = (1, 1) + strides 

6682 pool_size = (1, 1) + pool_size 

6683 

6684 if pool_mode == "max": 

6685 x = tf.nn.max_pool3d( 

6686 x, pool_size, strides, padding=padding, data_format=tf_data_format 

6687 ) 

6688 elif pool_mode == "avg": 

6689 x = tf.nn.avg_pool3d( 

6690 x, pool_size, strides, padding=padding, data_format=tf_data_format 

6691 ) 

6692 else: 

6693 raise ValueError("Invalid pooling mode: " + str(pool_mode)) 

6694 

6695 if data_format == "channels_first" and tf_data_format == "NDHWC": 

6696 x = tf.compat.v1.transpose(x, (0, 4, 1, 2, 3)) 

6697 return x 

6698 

6699 

6700def local_conv( 

6701 inputs, kernel, kernel_size, strides, output_shape, data_format=None 

6702): 

6703 """Apply N-D convolution with un-shared weights. 

6704 

6705 Args: 

6706 inputs: (N+2)-D tensor with shape 

6707 (batch_size, channels_in, d_in1, ..., d_inN) 

6708 if data_format='channels_first', or 

6709 (batch_size, d_in1, ..., d_inN, channels_in) 

6710 if data_format='channels_last'. 

6711 kernel: the unshared weight for N-D convolution, 

6712 with shape (output_items, feature_dim, channels_out), where 

6713 feature_dim = np.prod(kernel_size) * channels_in, 

6714 output_items = np.prod(output_shape). 

6715 kernel_size: a tuple of N integers, specifying the 

6716 spatial dimensions of the N-D convolution window. 

6717 strides: a tuple of N integers, specifying the strides 

6718 of the convolution along the spatial dimensions. 

6719 output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial 

6720 dimensionality of the output. 

6721 data_format: string, "channels_first" or "channels_last". 

6722 

6723 Returns: 

6724 An (N+2)-D tensor with shape: 

6725 (batch_size, channels_out) + output_shape 

6726 if data_format='channels_first', or: 

6727 (batch_size,) + output_shape + (channels_out,) 

6728 if data_format='channels_last'. 

6729 

6730 Raises: 

6731 ValueError: if `data_format` is neither 

6732 `channels_last` nor `channels_first`. 

6733 """ 

6734 if data_format is None: 

6735 data_format = image_data_format() 

6736 if data_format not in {"channels_first", "channels_last"}: 

6737 raise ValueError("Unknown data_format: " + str(data_format)) 

6738 

6739 kernel_shape = int_shape(kernel) 

6740 feature_dim = kernel_shape[1] 

6741 channels_out = kernel_shape[-1] 

6742 ndims = len(output_shape) 

6743 spatial_dimensions = list(range(ndims)) 

6744 

6745 xs = [] 

6746 output_axes_ticks = [range(axis_max) for axis_max in output_shape] 

6747 for position in itertools.product(*output_axes_ticks): 

6748 slices = [slice(None)] 

6749 

6750 if data_format == "channels_first": 

6751 slices.append(slice(None)) 

6752 

6753 slices.extend( 

6754 slice( 

6755 position[d] * strides[d], 

6756 position[d] * strides[d] + kernel_size[d], 

6757 ) 

6758 for d in spatial_dimensions 

6759 ) 

6760 

6761 if data_format == "channels_last": 

6762 slices.append(slice(None)) 

6763 

6764 xs.append(reshape(inputs[slices], (1, -1, feature_dim))) 

6765 

6766 x_aggregate = concatenate(xs, axis=0) 

6767 output = batch_dot(x_aggregate, kernel) 

6768 output = reshape(output, output_shape + (-1, channels_out)) 

6769 

6770 if data_format == "channels_first": 

6771 permutation = [ndims, ndims + 1] + spatial_dimensions 

6772 else: 

6773 permutation = [ndims] + spatial_dimensions + [ndims + 1] 

6774 

6775 return permute_dimensions(output, permutation) 

6776 

6777 

6778@keras_export("keras.backend.local_conv1d") 

6779@tf.__internal__.dispatch.add_dispatch_support 

6780@doc_controls.do_not_generate_docs 

6781def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): 

6782 """Apply 1D conv with un-shared weights. 

6783 

6784 Args: 

6785 inputs: 3D tensor with shape: 

6786 (batch_size, steps, input_dim) 

6787 if data_format is "channels_last" or 

6788 (batch_size, input_dim, steps) 

6789 if data_format is "channels_first". 

6790 kernel: the unshared weight for convolution, 

6791 with shape (output_length, feature_dim, filters). 

6792 kernel_size: a tuple of a single integer, 

6793 specifying the length of the 1D convolution window. 

6794 strides: a tuple of a single integer, 

6795 specifying the stride length of the convolution. 

6796 data_format: the data format, channels_first or channels_last. 

6797 

6798 Returns: 

6799 A 3d tensor with shape: 

6800 (batch_size, output_length, filters) 

6801 if data_format='channels_first' 

6802 or 3D tensor with shape: 

6803 (batch_size, filters, output_length) 

6804 if data_format='channels_last'. 

6805 """ 

6806 output_shape = (kernel.shape[0],) 

6807 return local_conv( 

6808 inputs, kernel, kernel_size, strides, output_shape, data_format 

6809 ) 

6810 

6811 

6812@keras_export("keras.backend.local_conv2d") 

6813@tf.__internal__.dispatch.add_dispatch_support 

6814@doc_controls.do_not_generate_docs 

6815def local_conv2d( 

6816 inputs, kernel, kernel_size, strides, output_shape, data_format=None 

6817): 

6818 """Apply 2D conv with un-shared weights. 

6819 

6820 Args: 

6821 inputs: 4D tensor with shape: 

6822 (batch_size, filters, new_rows, new_cols) 

6823 if data_format='channels_first' 

6824 or 4D tensor with shape: 

6825 (batch_size, new_rows, new_cols, filters) 

6826 if data_format='channels_last'. 

6827 kernel: the unshared weight for convolution, 

6828 with shape (output_items, feature_dim, filters). 

6829 kernel_size: a tuple of 2 integers, specifying the 

6830 width and height of the 2D convolution window. 

6831 strides: a tuple of 2 integers, specifying the strides 

6832 of the convolution along the width and height. 

6833 output_shape: a tuple with (output_row, output_col). 

6834 data_format: the data format, channels_first or channels_last. 

6835 

6836 Returns: 

6837 A 4D tensor with shape: 

6838 (batch_size, filters, new_rows, new_cols) 

6839 if data_format='channels_first' 

6840 or 4D tensor with shape: 

6841 (batch_size, new_rows, new_cols, filters) 

6842 if data_format='channels_last'. 

6843 """ 

6844 return local_conv( 

6845 inputs, kernel, kernel_size, strides, output_shape, data_format 

6846 ) 

6847 

6848 

6849@keras_export("keras.backend.bias_add") 

6850@tf.__internal__.dispatch.add_dispatch_support 

6851@doc_controls.do_not_generate_docs 

6852def bias_add(x, bias, data_format=None): 

6853 """Adds a bias vector to a tensor. 

6854 

6855 Args: 

6856 x: Tensor or variable. 

6857 bias: Bias tensor to add. 

6858 data_format: string, `"channels_last"` or `"channels_first"`. 

6859 

6860 Returns: 

6861 Output tensor. 

6862 

6863 Raises: 

6864 ValueError: In one of the two cases below: 

6865 1. invalid `data_format` argument. 

6866 2. invalid bias shape. 

6867 the bias should be either a vector or 

6868 a tensor with ndim(x) - 1 dimension 

6869 """ 

6870 if data_format is None: 

6871 data_format = image_data_format() 

6872 if data_format not in {"channels_first", "channels_last"}: 

6873 raise ValueError("Unknown data_format: " + str(data_format)) 

6874 bias_shape = int_shape(bias) 

6875 if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1: 

6876 raise ValueError( 

6877 "Unexpected bias dimensions %d, expect to be 1 or %d dimensions" 

6878 % (len(bias_shape), ndim(x) - 1) 

6879 ) 

6880 

6881 if len(bias_shape) == 1: 

6882 if data_format == "channels_first": 

6883 return tf.nn.bias_add(x, bias, data_format="NCHW") 

6884 return tf.nn.bias_add(x, bias, data_format="NHWC") 

6885 if ndim(x) in (3, 4, 5): 

6886 if data_format == "channels_first": 

6887 bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1] 

6888 return x + reshape(bias, bias_reshape_axis) 

6889 return x + reshape(bias, (1,) + bias_shape) 

6890 return tf.nn.bias_add(x, bias) 

6891 

6892 

6893# RANDOMNESS 

6894 

6895 

6896@keras_export("keras.backend.random_normal") 

6897@tf.__internal__.dispatch.add_dispatch_support 

6898@doc_controls.do_not_generate_docs 

6899def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): 

6900 """Returns a tensor with normal distribution of values. 

6901 

6902 It is an alias to `tf.random.normal`. 

6903 

6904 Args: 

6905 shape: A tuple of integers, the shape of tensor to create. 

6906 mean: A float, the mean value of the normal distribution to draw 

6907 samples. Defaults to `0.0`. 

6908 stddev: A float, the standard deviation of the normal distribution 

6909 to draw samples. Defaults to `1.0`. 

6910 dtype: `tf.dtypes.DType`, dtype of returned tensor. None uses Keras 

6911 backend dtype which is float32. Defaults to `None`. 

6912 seed: Integer, random seed. Will use a random numpy integer when not 

6913 specified. 

6914 

6915 Returns: 

6916 A tensor with normal distribution of values. 

6917 

6918 Example: 

6919 

6920 >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3), 

6921 ... mean=0.0, stddev=1.0) 

6922 >>> random_normal_tensor 

6923 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=..., 

6924 dtype=float32)> 

6925 """ 

6926 if dtype is None: 

6927 dtype = floatx() 

6928 if seed is None: 

6929 seed = np.random.randint(10e6) 

6930 return tf.random.normal( 

6931 shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed 

6932 ) 

6933 

6934 

6935@keras_export("keras.backend.random_uniform") 

6936@tf.__internal__.dispatch.add_dispatch_support 

6937@doc_controls.do_not_generate_docs 

6938def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): 

6939 """Returns a tensor with uniform distribution of values. 

6940 

6941 Args: 

6942 shape: A tuple of integers, the shape of tensor to create. 

6943 minval: A float, lower boundary of the uniform distribution 

6944 to draw samples. 

6945 maxval: A float, upper boundary of the uniform distribution 

6946 to draw samples. 

6947 dtype: String, dtype of returned tensor. 

6948 seed: Integer, random seed. 

6949 

6950 Returns: 

6951 A tensor. 

6952 

6953 Example: 

6954 

6955 >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3), 

6956 ... minval=0.0, maxval=1.0) 

6957 >>> random_uniform_tensor 

6958 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=..., 

6959 dtype=float32)> 

6960 """ 

6961 if dtype is None: 

6962 dtype = floatx() 

6963 if seed is None: 

6964 seed = np.random.randint(10e6) 

6965 return tf.random.uniform( 

6966 shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed 

6967 ) 

6968 

6969 

6970@keras_export("keras.backend.random_binomial") 

6971@tf.__internal__.dispatch.add_dispatch_support 

6972@doc_controls.do_not_generate_docs 

6973def random_binomial(shape, p=0.0, dtype=None, seed=None): 

6974 """Returns a tensor with random binomial distribution of values. 

6975 

6976 DEPRECATED, use `tf.keras.backend.random_bernoulli` instead. 

6977 

6978 The binomial distribution with parameters `n` and `p` is the probability 

6979 distribution of the number of successful Bernoulli process. Only supports 

6980 `n` = 1 for now. 

6981 

6982 Args: 

6983 shape: A tuple of integers, the shape of tensor to create. 

6984 p: A float, `0. <= p <= 1`, probability of binomial distribution. 

6985 dtype: String, dtype of returned tensor. 

6986 seed: Integer, random seed. 

6987 

6988 Returns: 

6989 A tensor. 

6990 

6991 Example: 

6992 

6993 >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3), 

6994 ... p=0.5) 

6995 >>> random_binomial_tensor 

6996 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=..., 

6997 dtype=float32)> 

6998 """ 

6999 warnings.warn( 

7000 "`tf.keras.backend.random_binomial` is deprecated, " 

7001 "and will be removed in a future version." 

7002 "Please use `tf.keras.backend.random_bernoulli` instead.", 

7003 stacklevel=2, 

7004 ) 

7005 return random_bernoulli(shape, p, dtype, seed) 

7006 

7007 

7008@keras_export("keras.backend.random_bernoulli") 

7009@tf.__internal__.dispatch.add_dispatch_support 

7010@doc_controls.do_not_generate_docs 

7011def random_bernoulli(shape, p=0.0, dtype=None, seed=None): 

7012 """Returns a tensor with random bernoulli distribution of values. 

7013 

7014 Args: 

7015 shape: A tuple of integers, the shape of tensor to create. 

7016 p: A float, `0. <= p <= 1`, probability of bernoulli distribution. 

7017 dtype: String, dtype of returned tensor. 

7018 seed: Integer, random seed. 

7019 

7020 Returns: 

7021 A tensor. 

7022 """ 

7023 if dtype is None: 

7024 dtype = floatx() 

7025 if seed is None: 

7026 seed = np.random.randint(10e6) 

7027 return tf.where( 

7028 tf.random.uniform(shape, dtype=dtype, seed=seed) <= p, 

7029 tf.ones(shape, dtype=dtype), 

7030 tf.zeros(shape, dtype=dtype), 

7031 ) 

7032 

7033 

7034@keras_export("keras.backend.truncated_normal") 

7035@tf.__internal__.dispatch.add_dispatch_support 

7036@doc_controls.do_not_generate_docs 

7037def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): 

7038 """Returns a tensor with truncated random normal distribution of values. 

7039 

7040 The generated values follow a normal distribution 

7041 with specified mean and standard deviation, 

7042 except that values whose magnitude is more than 

7043 two standard deviations from the mean are dropped and re-picked. 

7044 

7045 Args: 

7046 shape: A tuple of integers, the shape of tensor to create. 

7047 mean: Mean of the values. 

7048 stddev: Standard deviation of the values. 

7049 dtype: String, dtype of returned tensor. 

7050 seed: Integer, random seed. 

7051 

7052 Returns: 

7053 A tensor. 

7054 """ 

7055 if dtype is None: 

7056 dtype = floatx() 

7057 if seed is None: 

7058 seed = np.random.randint(10e6) 

7059 return tf.random.truncated_normal( 

7060 shape, mean, stddev, dtype=dtype, seed=seed 

7061 ) 

7062 

7063 

7064# CTC 

7065# TensorFlow has a native implementation, but it uses sparse tensors 

7066# and therefore requires a wrapper for Keras. The functions below convert 

7067# dense to sparse tensors and also wraps up the beam search code that is 

7068# in TensorFlow's CTC implementation 

7069 

7070 

7071@keras_export("keras.backend.ctc_label_dense_to_sparse") 

7072@tf.__internal__.dispatch.add_dispatch_support 

7073@doc_controls.do_not_generate_docs 

7074def ctc_label_dense_to_sparse(labels, label_lengths): 

7075 """Converts CTC labels from dense to sparse. 

7076 

7077 Args: 

7078 labels: dense CTC labels. 

7079 label_lengths: length of the labels. 

7080 

7081 Returns: 

7082 A sparse tensor representation of the labels. 

7083 """ 

7084 label_shape = tf.shape(labels) 

7085 num_batches_tns = tf.stack([label_shape[0]]) 

7086 max_num_labels_tns = tf.stack([label_shape[1]]) 

7087 

7088 def range_less_than(old_input, current_input): 

7089 return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill( 

7090 max_num_labels_tns, current_input 

7091 ) 

7092 

7093 init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool) 

7094 dense_mask = tf.compat.v1.scan( 

7095 range_less_than, label_lengths, initializer=init, parallel_iterations=1 

7096 ) 

7097 dense_mask = dense_mask[:, 0, :] 

7098 

7099 label_array = tf.reshape( 

7100 tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape 

7101 ) 

7102 label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask) 

7103 

7104 batch_array = tf.compat.v1.transpose( 

7105 tf.reshape( 

7106 tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), 

7107 reverse(label_shape, 0), 

7108 ) 

7109 ) 

7110 batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask) 

7111 indices = tf.compat.v1.transpose( 

7112 tf.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]) 

7113 ) 

7114 

7115 vals_sparse = tf.compat.v1.gather_nd(labels, indices) 

7116 

7117 return tf.SparseTensor( 

7118 tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64) 

7119 ) 

7120 

7121 

7122@keras_export("keras.backend.ctc_batch_cost") 

7123@tf.__internal__.dispatch.add_dispatch_support 

7124@doc_controls.do_not_generate_docs 

7125def ctc_batch_cost(y_true, y_pred, input_length, label_length): 

7126 """Runs CTC loss algorithm on each batch element. 

7127 

7128 Args: 

7129 y_true: tensor `(samples, max_string_length)` 

7130 containing the truth labels. 

7131 y_pred: tensor `(samples, time_steps, num_categories)` 

7132 containing the prediction, or output of the softmax. 

7133 input_length: tensor `(samples, 1)` containing the sequence length for 

7134 each batch item in `y_pred`. 

7135 label_length: tensor `(samples, 1)` containing the sequence length for 

7136 each batch item in `y_true`. 

7137 

7138 Returns: 

7139 Tensor with shape (samples,1) containing the 

7140 CTC loss of each element. 

7141 """ 

7142 label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32) 

7143 input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32) 

7144 sparse_labels = tf.cast( 

7145 ctc_label_dense_to_sparse(y_true, label_length), tf.int32 

7146 ) 

7147 

7148 y_pred = tf.math.log( 

7149 tf.compat.v1.transpose(y_pred, perm=[1, 0, 2]) + epsilon() 

7150 ) 

7151 

7152 return tf.expand_dims( 

7153 tf.compat.v1.nn.ctc_loss( 

7154 inputs=y_pred, labels=sparse_labels, sequence_length=input_length 

7155 ), 

7156 1, 

7157 ) 

7158 

7159 

7160@keras_export("keras.backend.ctc_decode") 

7161@tf.__internal__.dispatch.add_dispatch_support 

7162@doc_controls.do_not_generate_docs 

7163def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): 

7164 """Decodes the output of a softmax. 

7165 

7166 Can use either greedy search (also known as best path) 

7167 or a constrained dictionary search. 

7168 

7169 Args: 

7170 y_pred: tensor `(samples, time_steps, num_categories)` 

7171 containing the prediction, or output of the softmax. 

7172 input_length: tensor `(samples, )` containing the sequence length for 

7173 each batch item in `y_pred`. 

7174 greedy: perform much faster best-path search if `true`. 

7175 This does not use a dictionary. 

7176 beam_width: if `greedy` is `false`: a beam search decoder will be used 

7177 with a beam of this width. 

7178 top_paths: if `greedy` is `false`, 

7179 how many of the most probable paths will be returned. 

7180 

7181 Returns: 

7182 Tuple: 

7183 List: if `greedy` is `true`, returns a list of one element that 

7184 contains the decoded sequence. 

7185 If `false`, returns the `top_paths` most probable 

7186 decoded sequences. 

7187 Each decoded sequence has shape (samples, time_steps). 

7188 Important: blank labels are returned as `-1`. 

7189 Tensor `(top_paths, )` that contains 

7190 the log probability of each decoded sequence. 

7191 """ 

7192 input_shape = shape(y_pred) 

7193 num_samples, num_steps = input_shape[0], input_shape[1] 

7194 y_pred = tf.math.log( 

7195 tf.compat.v1.transpose(y_pred, perm=[1, 0, 2]) + epsilon() 

7196 ) 

7197 input_length = tf.cast(input_length, tf.int32) 

7198 

7199 if greedy: 

7200 (decoded, log_prob) = tf.nn.ctc_greedy_decoder( 

7201 inputs=y_pred, sequence_length=input_length 

7202 ) 

7203 else: 

7204 (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder( 

7205 inputs=y_pred, 

7206 sequence_length=input_length, 

7207 beam_width=beam_width, 

7208 top_paths=top_paths, 

7209 ) 

7210 decoded_dense = [] 

7211 for st in decoded: 

7212 st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps)) 

7213 decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1)) 

7214 return (decoded_dense, log_prob) 

7215 

7216 

7217# HIGH ORDER FUNCTIONS 

7218 

7219 

7220@keras_export("keras.backend.map_fn") 

7221@doc_controls.do_not_generate_docs 

7222def map_fn(fn, elems, name=None, dtype=None): 

7223 """Map the function fn over the elements elems and return the outputs. 

7224 

7225 Args: 

7226 fn: Callable that will be called upon each element in elems 

7227 elems: tensor 

7228 name: A string name for the map node in the graph 

7229 dtype: Output data type. 

7230 

7231 Returns: 

7232 Tensor with dtype `dtype`. 

7233 """ 

7234 return tf.compat.v1.map_fn(fn, elems, name=name, dtype=dtype) 

7235 

7236 

7237@keras_export("keras.backend.foldl") 

7238@doc_controls.do_not_generate_docs 

7239def foldl(fn, elems, initializer=None, name=None): 

7240 """Reduce elems using fn to combine them from left to right. 

7241 

7242 Args: 

7243 fn: Callable that will be called upon each element in elems and an 

7244 accumulator, for instance `lambda acc, x: acc + x` 

7245 elems: tensor 

7246 initializer: The first value used (`elems[0]` in case of None) 

7247 name: A string name for the foldl node in the graph 

7248 

7249 Returns: 

7250 Tensor with same type and shape as `initializer`. 

7251 """ 

7252 return tf.compat.v1.foldl(fn, elems, initializer=initializer, name=name) 

7253 

7254 

7255@keras_export("keras.backend.foldr") 

7256@doc_controls.do_not_generate_docs 

7257def foldr(fn, elems, initializer=None, name=None): 

7258 """Reduce elems using fn to combine them from right to left. 

7259 

7260 Args: 

7261 fn: Callable that will be called upon each element in elems and an 

7262 accumulator, for instance `lambda acc, x: acc + x` 

7263 elems: tensor 

7264 initializer: The first value used (`elems[-1]` in case of None) 

7265 name: A string name for the foldr node in the graph 

7266 

7267 Returns: 

7268 Same type and shape as initializer 

7269 """ 

7270 return tf.compat.v1.foldr(fn, elems, initializer=initializer, name=name) 

7271 

7272 

7273# Load Keras default configuration from config file if present. 

7274# Set Keras base dir path given KERAS_HOME env variable, if applicable. 

7275# Otherwise either ~/.keras or /tmp. 

7276if "KERAS_HOME" in os.environ: 

7277 _keras_dir = os.environ.get("KERAS_HOME") 

7278else: 

7279 _keras_base_dir = os.path.expanduser("~") 

7280 _keras_dir = os.path.join(_keras_base_dir, ".keras") 

7281_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras.json")) 

7282if os.path.exists(_config_path): 

7283 try: 

7284 with open(_config_path) as fh: 

7285 _config = json.load(fh) 

7286 except ValueError: 

7287 _config = {} 

7288 _floatx = _config.get("floatx", floatx()) 

7289 assert _floatx in {"float16", "float32", "float64"} 

7290 _epsilon = _config.get("epsilon", epsilon()) 

7291 assert isinstance(_epsilon, float) 

7292 _image_data_format = _config.get("image_data_format", image_data_format()) 

7293 assert _image_data_format in {"channels_last", "channels_first"} 

7294 set_floatx(_floatx) 

7295 set_epsilon(_epsilon) 

7296 set_image_data_format(_image_data_format) 

7297 

7298# Save config file. 

7299if not os.path.exists(_keras_dir): 

7300 try: 

7301 os.makedirs(_keras_dir) 

7302 except OSError: 

7303 # Except permission denied and potential race conditions 

7304 # in multi-threaded environments. 

7305 pass 

7306 

7307if not os.path.exists(_config_path): 

7308 _config = { 

7309 "floatx": floatx(), 

7310 "epsilon": epsilon(), 

7311 "backend": "tensorflow", 

7312 "image_data_format": image_data_format(), 

7313 } 

7314 try: 

7315 with open(_config_path, "w") as f: 

7316 f.write(json.dumps(_config, indent=4)) 

7317 except IOError: 

7318 # Except permission denied. 

7319 pass 

7320 

7321 

7322def configure_and_create_distributed_session(distribution_strategy): 

7323 """Configure session config and create a session with it.""" 

7324 

7325 def _create_session(distribution_strategy): 

7326 """Create the Distributed Strategy session.""" 

7327 session_config = get_default_session_config() 

7328 

7329 # If a session already exists, merge in its config; in the case there is 

7330 # a conflict, take values of the existing config. 

7331 global _SESSION 

7332 if getattr(_SESSION, "session", None) and _SESSION.session._config: 

7333 session_config.MergeFrom(_SESSION.session._config) 

7334 

7335 if is_tpu_strategy(distribution_strategy): 

7336 # TODO(priyag, yuefengz): Remove this workaround when Distribute 

7337 # Coordinator is integrated with keras and we can create a session 

7338 # from there. 

7339 distribution_strategy.configure(session_config) 

7340 master = ( 

7341 distribution_strategy.extended._tpu_cluster_resolver.master() 

7342 ) 

7343 session = tf.compat.v1.Session(config=session_config, target=master) 

7344 else: 

7345 worker_context = dc.get_current_worker_context() 

7346 if worker_context: 

7347 dc_session_config = worker_context.session_config 

7348 # Merge the default session config to the one from distribute 

7349 # coordinator, which is fine for now since they don't have 

7350 # conflicting configurations. 

7351 dc_session_config.MergeFrom(session_config) 

7352 session = tf.compat.v1.Session( 

7353 config=dc_session_config, 

7354 target=worker_context.master_target, 

7355 ) 

7356 else: 

7357 distribution_strategy.configure(session_config) 

7358 session = tf.compat.v1.Session(config=session_config) 

7359 

7360 set_session(session) 

7361 

7362 if distribution_strategy.extended._in_multi_worker_mode(): 

7363 dc.run_distribute_coordinator(_create_session, distribution_strategy) 

7364 else: 

7365 _create_session(distribution_strategy) 

7366 

7367 

7368def _is_tpu_strategy_class(clz): 

7369 is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy") 

7370 if is_tpu_strat(clz): 

7371 return True 

7372 return py_any(map(_is_tpu_strategy_class, clz.__bases__)) 

7373 

7374 

7375def is_tpu_strategy(strategy): 

7376 """Returns whether input is a TPUStrategy instance or subclass instance.""" 

7377 return _is_tpu_strategy_class(strategy.__class__) 

7378 

7379 

7380def _is_symbolic_tensor(x): 

7381 return tf.is_tensor(x) and not isinstance(x, tf.__internal__.EagerTensor) 

7382 

7383 

7384def convert_inputs_if_ragged(inputs): 

7385 """Converts any ragged tensors to dense.""" 

7386 

7387 def _convert_ragged_input(inputs): 

7388 if isinstance(inputs, tf.RaggedTensor): 

7389 return inputs.to_tensor() 

7390 return inputs 

7391 

7392 flat_inputs = tf.nest.flatten(inputs) 

7393 contains_ragged = py_any( 

7394 isinstance(i, tf.RaggedTensor) for i in flat_inputs 

7395 ) 

7396 

7397 if not contains_ragged: 

7398 return inputs, None 

7399 

7400 inputs = tf.nest.map_structure(_convert_ragged_input, inputs) 

7401 # Multiple mask are not yet supported, so one mask is used on all inputs. 

7402 # We approach this similarly when using row lengths to ignore steps. 

7403 nested_row_lengths = tf.cast( 

7404 flat_inputs[0].nested_row_lengths()[0], "int32" 

7405 ) 

7406 return inputs, nested_row_lengths 

7407 

7408 

7409def maybe_convert_to_ragged( 

7410 is_ragged_input, output, nested_row_lengths, go_backwards=False 

7411): 

7412 """Converts any ragged input back to its initial structure.""" 

7413 if not is_ragged_input: 

7414 return output 

7415 

7416 if go_backwards: 

7417 # Reverse based on the timestep dim, so that nested_row_lengths will 

7418 # mask from the correct direction. Return the reverse ragged tensor. 

7419 output = reverse(output, [1]) 

7420 ragged = tf.RaggedTensor.from_tensor(output, nested_row_lengths) 

7421 return reverse(ragged, [1]) 

7422 else: 

7423 return tf.RaggedTensor.from_tensor(output, nested_row_lengths) 

7424 

7425 

7426class ContextValueCache(weakref.WeakKeyDictionary): 

7427 """Container that caches (possibly tensor) values based on the context. 

7428 

7429 This class is similar to defaultdict, where values may be produced by the 

7430 default factory specified during initialization. This class also has a 

7431 default value for the key (when key is `None`) -- the key is set to the 

7432 current graph or eager context. The default factories for key and value are 

7433 only used in `__getitem__` and `setdefault`. The `.get()` behavior remains 

7434 the same. 

7435 

7436 This object will return the value of the current graph or closest parent 

7437 graph if the current graph is a function. This is to reflect the fact that 

7438 if a tensor is created in eager/graph, child functions may capture that 

7439 tensor. 

7440 

7441 The default factory method may accept keyword arguments (unlike defaultdict, 

7442 which only accepts callables with 0 arguments). To pass keyword arguments to 

7443 `default_factory`, use the `setdefault` method instead of `__getitem__`. 

7444 

7445 An example of how this class can be used in different contexts: 

7446 

7447 ``` 

7448 cache = ContextValueCache(int) 

7449 

7450 # Eager mode 

7451 cache[None] += 2 

7452 cache[None] += 4 

7453 assert cache[None] == 6 

7454 

7455 # Graph mode 

7456 with tf.Graph().as_default() as g: 

7457 cache[None] += 5 

7458 cache[g] += 3 

7459 assert cache[g] == 8 

7460 ``` 

7461 

7462 Example of a default factory with arguments: 

7463 

7464 ``` 

7465 cache = ContextValueCache(lambda x: x + 1) 

7466 g = tf.get_default_graph() 

7467 

7468 # Example with keyword argument. 

7469 value = cache.setdefault(key=g, kwargs={'x': 3}) 

7470 assert cache[g] == 4 

7471 ``` 

7472 """ 

7473 

7474 def __init__(self, default_factory): 

7475 self.default_factory = default_factory 

7476 weakref.WeakKeyDictionary.__init__(self) 

7477 

7478 def _key(self): 

7479 if tf.executing_eagerly(): 

7480 return _DUMMY_EAGER_GRAPH.key 

7481 else: 

7482 return tf.compat.v1.get_default_graph() 

7483 

7484 def _get_parent_graph(self, graph): 

7485 """Returns the parent graph or dummy eager object.""" 

7486 # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as 

7487 # the outer graph. This results in outer_graph always being a Graph, 

7488 # even in eager mode (get_default_graph will create a new Graph if there 

7489 # isn't a default graph). Because of this bug, we have to specially set 

7490 # the key when eager execution is enabled. 

7491 parent_graph = graph.outer_graph 

7492 if ( 

7493 not isinstance(parent_graph, tf.__internal__.FuncGraph) 

7494 and tf.compat.v1.executing_eagerly_outside_functions() 

7495 ): 

7496 return _DUMMY_EAGER_GRAPH.key 

7497 return parent_graph 

7498 

7499 def _get_recursive(self, key): 

7500 """Gets the value at key or the closest parent graph.""" 

7501 value = self.get(key) 

7502 if value is not None: 

7503 return value 

7504 

7505 # Since FuncGraphs are able to capture tensors and variables from their 

7506 # parent graphs, recursively search to see if there is a value stored 

7507 # for one of the parent graphs. 

7508 if isinstance(key, tf.__internal__.FuncGraph): 

7509 return self._get_recursive(self._get_parent_graph(key)) 

7510 return None 

7511 

7512 def __getitem__(self, key): 

7513 """Gets the value at key (or current context), or sets default value. 

7514 

7515 Args: 

7516 key: May be `None` or `Graph`object. When `None`, the key is set to 

7517 the current context. 

7518 

7519 Returns: 

7520 Either the cached or default value. 

7521 """ 

7522 if key is None: 

7523 key = self._key() 

7524 

7525 value = self._get_recursive(key) 

7526 if value is None: 

7527 value = self[key] = self.default_factory() 

7528 return value 

7529 

7530 def setdefault(self, key=None, default=None, kwargs=None): 

7531 """Sets the default value if key is not in dict, and returns the 

7532 value.""" 

7533 if key is None: 

7534 key = self._key() 

7535 kwargs = kwargs or {} 

7536 

7537 if default is None and key not in self: 

7538 default = self.default_factory(**kwargs) 

7539 return weakref.WeakKeyDictionary.setdefault(self, key, default) 

7540 

7541 

7542# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a 

7543# dummy object is used. 

7544# A learning phase is a bool tensor used to run Keras models in 

7545# either train mode (learning_phase == 1) or test mode (learning_phase == 0). 

7546_GRAPH_LEARNING_PHASES = ContextValueCache( 

7547 object_identity.ObjectIdentityWeakSet 

7548) 

7549 

7550# This dictionary holds a mapping between a graph and variables to initialize 

7551# in the graph. 

7552_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet) 

7553 

7554# This dictionary holds a mapping between a graph and TF optimizers created in 

7555# the graph. 

7556_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet) 

7557