# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras backend API."""

import collections
import itertools
import json
import os
import random
import sys
import threading
import warnings
import weakref

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend_config
from keras.src.distribute import distribute_coordinator_utils as dc
from keras.src.dtensor import dtensor_api as dtensor
from keras.src.engine import keras_tensor
from keras.src.utils import control_flow_util
from keras.src.utils import object_identity
from keras.src.utils import tf_contextlib
from keras.src.utils import tf_inspect
from keras.src.utils import tf_utils

# isort: off
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager.context import get_config
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls

py_all = all
py_sum = sum
py_any = any

# INTERNAL UTILS

# The internal graph maintained by Keras and used by the symbolic Keras APIs
# while executing eagerly (such as the functional API for model-building).
# This is thread-local to allow building separate models in different threads
# concurrently, but comes at the cost of not being able to build one model
# across threads.
_GRAPH = threading.local()

# A graph which is used for constructing functions in eager mode.
_CURRENT_SCRATCH_GRAPH = threading.local()


# This is a thread local object that will hold the default internal TF session
# used by Keras. It can be set manually via `set_session(sess)`.
class SessionLocal(threading.local):
    def __init__(self):
        super().__init__()
        self.session = None


_SESSION = SessionLocal()

# A global dictionary mapping graph objects to an index of counters used
# for various layer/optimizer names in each graph.
# Allows to give unique autogenerated names to layers, in a graph-specific way.
PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()

# A global set tracking what object names have been seen so far.
# Optionally used as an avoid-list when generating names.
OBSERVED_NAMES = set()


# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
# We keep a separate reference to it to make sure it does not get removed from
# _GRAPH_LEARNING_PHASES.
# _DummyEagerGraph inherits from threading.local to make its `key` attribute
# thread local. This is needed to make set_learning_phase affect only the
# current thread during eager execution (see b/123096885 for more details).
class _DummyEagerGraph(threading.local):
    """_DummyEagerGraph provides a thread local `key` attribute.

    We can't use threading.local directly, i.e. without subclassing, because
    gevent monkey patches threading.local and its version does not support
    weak references.
    """

    class _WeakReferencableClass:
        """This dummy class is needed for two reasons.

        - We need something that supports weak references. Basic types like
          string and ints don't.
        - We need something whose hash and equality are based on object
          identity to make sure they are treated as different keys to
          _GRAPH_LEARNING_PHASES.

        An empty Python class satisfies both of these requirements.
        """

        pass

    def __init__(self):
        # Constructors for classes subclassing threading.local run once
        # per thread accessing something in the class. Thus, each thread will
        # get a different key.
        super().__init__()
        self.key = _DummyEagerGraph._WeakReferencableClass()
        self.learning_phase_is_set = False


_DUMMY_EAGER_GRAPH = _DummyEagerGraph()

# This boolean flag can be set to True to leave variable initialization
# up to the user.
# Change its value via `manual_variable_initialization(value)`.
_MANUAL_VAR_INIT = False

# This list holds the available devices.
# It is populated when `_get_available_gpus()` is called for the first time.
# We assume our devices don't change henceforth.
_LOCAL_DEVICES = None

# The below functions are kept accessible from backend for compatibility.
epsilon = backend_config.epsilon
floatx = backend_config.floatx
image_data_format = backend_config.image_data_format
set_epsilon = backend_config.set_epsilon
set_floatx = backend_config.set_floatx
set_image_data_format = backend_config.set_image_data_format


@keras_export("keras.backend.backend")
@doc_controls.do_not_generate_docs
def backend():
    """Publicly accessible method for determining the current backend.

    Only exists for API compatibility with multi-backend Keras.

    Returns:
        The string "tensorflow".
    """
    return "tensorflow"


@keras_export("keras.backend.cast_to_floatx")
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
def cast_to_floatx(x):
    """Cast a Numpy array to the default Keras float type.

    Args:
        x: Numpy array or TensorFlow tensor.

    Returns:
        The same array (Numpy array if `x` was a Numpy array, or TensorFlow
        tensor if `x` was a tensor), cast to its new type.

    Example:

    >>> tf.keras.backend.floatx()
    'float32'
    >>> arr = np.array([1.0, 2.0], dtype='float64')
    >>> arr.dtype
    dtype('float64')
    >>> new_arr = cast_to_floatx(arr)
    >>> new_arr
    array([1., 2.], dtype=float32)
    >>> new_arr.dtype
    dtype('float32')

    """
    if isinstance(x, (tf.Tensor, tf.Variable, tf.SparseTensor)):
        return tf.cast(x, dtype=floatx())
    return np.asarray(x, dtype=floatx())


@keras_export("keras.backend.get_uid")
def get_uid(prefix=""):
    """Associates a string prefix with an integer counter in a TensorFlow graph.

    Args:
        prefix: String prefix to index.

    Returns:
        Unique integer ID.

    Example:

    >>> get_uid('dense')
    1
    >>> get_uid('dense')
    2

    """
    graph = get_graph()
    if graph not in PER_GRAPH_OBJECT_NAME_UIDS:
        PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int)
    layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph]
    layer_name_uids[prefix] += 1
    return layer_name_uids[prefix]


@keras_export("keras.backend.reset_uids")
def reset_uids():
    """Resets graph identifiers."""

    PER_GRAPH_OBJECT_NAME_UIDS.clear()
    OBSERVED_NAMES.clear()


@keras_export("keras.backend.clear_session")
def clear_session():
    """Resets all state generated by Keras.

    Keras manages a global state, which it uses to implement the Functional
    model-building API and to uniquify autogenerated layer names.

    If you are creating many models in a loop, this global state will consume
    an increasing amount of memory over time, and you may want to clear it.
    Calling `clear_session()` releases the global state: this helps avoid
    clutter from old models and layers, especially when memory is limited.

    Example 1: calling `clear_session()` when creating models in a loop

    ```python
    for _ in range(100):
      # Without `clear_session()`, each iteration of this loop will
      # slightly increase the size of the global state managed by Keras
      model = tf.keras.Sequential([
          tf.keras.layers.Dense(10) for _ in range(10)])

    for _ in range(100):
      # With `clear_session()` called at the beginning,
      # Keras starts with a blank state at each iteration
      # and memory consumption is constant over time.
      tf.keras.backend.clear_session()
      model = tf.keras.Sequential([
          tf.keras.layers.Dense(10) for _ in range(10)])
    ```

    Example 2: resetting the layer name generation counter

    >>> import tensorflow as tf
    >>> layers = [tf.keras.layers.Dense(10) for _ in range(10)]
    >>> new_layer = tf.keras.layers.Dense(10)
    >>> print(new_layer.name)
    dense_10
    >>> tf.keras.backend.set_learning_phase(1)
    >>> print(tf.keras.backend.learning_phase())
    1
    >>> tf.keras.backend.clear_session()
    >>> new_layer = tf.keras.layers.Dense(10)
    >>> print(new_layer.name)
    dense
    """
    global _SESSION
    global _GRAPH_LEARNING_PHASES
    global _GRAPH_VARIABLES
    global _GRAPH_TF_OPTIMIZERS
    global _GRAPH
    _GRAPH.graph = None
    tf.compat.v1.reset_default_graph()
    reset_uids()
    if _SESSION.session is not None:
        _SESSION.session.close()
        _SESSION.session = None
    graph = get_graph()
    with graph.as_default():
        _DUMMY_EAGER_GRAPH.learning_phase_is_set = False

        _GRAPH_LEARNING_PHASES = {}
        # Create the learning phase placeholder in graph using the default
        # factory.
        phase = _default_learning_phase()
        _internal_set_learning_phase(graph, phase)

        _GRAPH_VARIABLES.pop(graph, None)
        _GRAPH_TF_OPTIMIZERS.pop(graph, None)
    if tf.executing_eagerly():
        # Clear pending nodes in eager executors, kernel caches and
        # step_containers.
        context.context().clear_kernel_cache()


# Inject the clear_session function to keras_deps to remove the dependency
# from TFLite to Keras.
tf.__internal__.register_clear_session_function(clear_session)


@keras_export("keras.backend.manual_variable_initialization")
@doc_controls.do_not_generate_docs
def manual_variable_initialization(value):
    """Sets the manual variable initialization flag.

    This boolean flag determines whether variables should be initialized
    as they are instantiated (default), or if the user should handle the
    initialization (e.g. via `tf.compat.v1.initialize_all_variables()`).

    Args:
        value: Python boolean.
    """
    global _MANUAL_VAR_INIT
    _MANUAL_VAR_INIT = value
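
# Illustrative usage sketch (not part of the original file; assumes TF1-style
# graph execution). With manual initialization enabled, the caller runs the
# initializers instead of `get_session()` doing it:
#
#     tf.compat.v1.disable_eager_execution()
#     tf.keras.backend.manual_variable_initialization(True)
#     v = tf.keras.backend.variable(np.zeros((2, 2)))
#     sess = tf.compat.v1.keras.backend.get_session()
#     sess.run(tf.compat.v1.global_variables_initializer())
#     sess.run(v)
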

@keras_export("keras.backend.learning_phase")
@doc_controls.do_not_generate_docs
def learning_phase():
    """Returns the learning phase flag.

    The learning phase flag is a bool tensor (0 = test, 1 = train)
    to be passed as input to any Keras function
    that uses a different behavior at train time and test time.

    Returns:
        Learning phase (scalar integer tensor or Python integer).
    """
    graph = tf.compat.v1.get_default_graph()
    if graph is getattr(_GRAPH, "graph", None):
        # Don't enter an init_scope for the learning phase if eager execution
        # is enabled but we're inside the Keras workspace graph.
        learning_phase = symbolic_learning_phase()
    else:
        with tf.init_scope():
            # We always check & set the learning phase inside the init_scope,
            # otherwise the wrong default_graph will be used to look up the
            # learning phase inside of functions & defuns.
            #
            # This is because functions & defuns (both in graph & in eager
            # mode) will always execute non-eagerly using a function-specific
            # default subgraph.
            if context.executing_eagerly():
                if _DUMMY_EAGER_GRAPH.key not in _GRAPH_LEARNING_PHASES:
                    return _default_learning_phase()
                else:
                    return _internal_get_learning_phase(_DUMMY_EAGER_GRAPH.key)
            else:
                learning_phase = symbolic_learning_phase()
    _mark_func_graph_as_unsaveable(graph, learning_phase)
    return learning_phase
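
# Illustrative sketch (not part of the original file): in eager execution the
# flag defaults to 0 (test) until it is explicitly set.
#
#     tf.keras.backend.learning_phase()       # -> 0 by default
#     tf.keras.backend.set_learning_phase(1)
#     tf.keras.backend.learning_phase()       # -> 1
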

def global_learning_phase_is_set():
    return _DUMMY_EAGER_GRAPH.learning_phase_is_set


def _mark_func_graph_as_unsaveable(graph, learning_phase):
    """Mark graph as unsaveable due to use of the symbolic Keras learning phase.

    Functions that capture the symbolic learning phase cannot be exported to
    SavedModel. Mark the funcgraph as unsaveable, so that an error will be
    raised if it is exported.

    Args:
        graph: Graph or FuncGraph object.
        learning_phase: Learning phase placeholder or int defined in the
            graph.
    """
    if graph.building_function and is_placeholder(learning_phase):
        graph.mark_as_unsaveable(
            "The keras learning phase placeholder was used inside a function. "
            "Exporting placeholders is not supported when saving out a "
            "SavedModel. Please call `tf.keras.backend.set_learning_phase(0)` "
            "in the function to set the learning phase to a constant value."
        )


def symbolic_learning_phase():
    graph = get_graph()
    with graph.as_default():
        if graph not in _GRAPH_LEARNING_PHASES:
            phase = _default_learning_phase()
            _internal_set_learning_phase(graph, phase)

        return _internal_get_learning_phase(graph)


def _internal_set_learning_phase(graph, value):
    global _GRAPH_LEARNING_PHASES

    if isinstance(value, tf.Tensor):
        # The 'value' here is a tf.Tensor with attribute 'graph'.
        # There is a circular reference between the key 'graph' and the
        # attribute 'graph'. So we need to use a weakref.ref to refer to the
        # 'value' tensor here. Otherwise, it would lead to a memory leak.
        value_ref = weakref.ref(value)
        _GRAPH_LEARNING_PHASES[graph] = value_ref
    else:
        _GRAPH_LEARNING_PHASES[graph] = value


def _internal_get_learning_phase(graph):
    phase = _GRAPH_LEARNING_PHASES.get(graph, None)
    if isinstance(phase, weakref.ref):
        return phase()
    else:
        return phase


def _default_learning_phase():
    if context.executing_eagerly():
        return 0
    else:
        with name_scope(""):
            return tf.compat.v1.placeholder_with_default(
                False, shape=(), name="keras_learning_phase"
            )


@keras_export("keras.backend.set_learning_phase")
@doc_controls.do_not_generate_docs
def set_learning_phase(value):
    """Sets the learning phase to a fixed value.

    The backend learning phase affects any code that calls
    `backend.learning_phase()`.
    In particular, all Keras built-in layers use the learning phase as the
    default for the `training` arg to `Layer.__call__`.

    User-written layers and models can achieve the same behavior with code
    that looks like:

    ```python
      def call(self, inputs, training=None):
        if training is None:
          training = backend.learning_phase()
    ```

    Args:
        value: Learning phase value, either 0 or 1 (integers).
            0 = test, 1 = train

    Raises:
        ValueError: if `value` is neither `0` nor `1`.
    """
    warnings.warn(
        "`tf.keras.backend.set_learning_phase` is deprecated and "
        "will be removed after 2020-10-11. To update it, simply "
        "pass a True/False value to the `training` argument of the "
        "`__call__` method of your layer or model."
    )
    deprecated_internal_set_learning_phase(value)


def deprecated_internal_set_learning_phase(value):
    """A deprecated internal implementation of set_learning_phase.

    This method is an internal-only version of `set_learning_phase` that
    does not raise a deprecation error. It is required because
    saved_model needs to keep working with user code that uses the deprecated
    learning phase methods until those APIs are fully removed from the public
    API.

    Specifically, SavedModel saving needs to make sure the learning phase is 0
    during tracing even if users overwrote it to a different value.

    But we don't want to raise deprecation warnings for users when SavedModel
    sets the learning phase just for compatibility with code that relied on
    explicitly setting the learning phase to other values.

    Args:
        value: Learning phase value, either 0 or 1 (integers).
            0 = test, 1 = train

    Raises:
        ValueError: if `value` is neither `0` nor `1`.
    """
    if value not in {0, 1}:
        raise ValueError("Expected learning phase to be 0 or 1.")
    with tf.init_scope():
        if tf.executing_eagerly():
            # In an eager context, the learning phase value applies to both
            # the eager context and the internal Keras graph.
            _DUMMY_EAGER_GRAPH.learning_phase_is_set = True
            _internal_set_learning_phase(_DUMMY_EAGER_GRAPH.key, value)

        _internal_set_learning_phase(get_graph(), value)


@keras_export("keras.backend.learning_phase_scope")
@tf_contextlib.contextmanager
@doc_controls.do_not_generate_docs
def learning_phase_scope(value):
    """Provides a scope within which the learning phase is equal to `value`.

    The learning phase gets restored to its original value upon exiting the
    scope.

    Args:
        value: Learning phase value, either 0 or 1 (integers).
            0 = test, 1 = train

    Yields:
        None.

    Raises:
        ValueError: if `value` is neither `0` nor `1`.
    """
    warnings.warn(
        "`tf.keras.backend.learning_phase_scope` is deprecated and "
        "will be removed after 2020-10-11. To update it, simply "
        "pass a True/False value to the `training` argument of the "
        "`__call__` method of your layer or model.",
        stacklevel=2,
    )
    with deprecated_internal_learning_phase_scope(value):
        try:
            yield
        finally:
            pass
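
# Illustrative usage sketch (not part of the original file): the scope
# temporarily forces "train" behavior and restores the previous phase on exit.
#
#     with tf.keras.backend.learning_phase_scope(1):
#         pass  # code here observes learning_phase() == 1
#     # on exit, the previous learning phase is restored
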

@tf_contextlib.contextmanager
def deprecated_internal_learning_phase_scope(value):
    """An internal-only version of `learning_phase_scope`.

    Unlike the public method, this method does not raise a deprecation
    warning. This is needed because saved model saving needs to set the
    learning phase to maintain compatibility with code that sets/gets the
    learning phase, but saved model saving itself shouldn't raise a
    deprecation warning.

    We can get rid of this method and its usages when the public API is
    removed.

    Args:
        value: Learning phase value, either 0 or 1 (integers).
            0 = test, 1 = train

    Yields:
        None.

    Raises:
        ValueError: if `value` is neither `0` nor `1`.
    """
    global _GRAPH_LEARNING_PHASES
    if value not in {0, 1}:
        raise ValueError("Expected learning phase to be 0 or 1.")

    with tf.init_scope():
        if tf.executing_eagerly():
            previous_eager_value = _internal_get_learning_phase(
                _DUMMY_EAGER_GRAPH.key
            )
        previous_graph_value = _internal_get_learning_phase(get_graph())

    learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
    try:
        deprecated_internal_set_learning_phase(value)
        yield
    finally:
        # Restore learning phase to initial value.
        if not learning_phase_previously_set:
            _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
        with tf.init_scope():
            if tf.executing_eagerly():
                if previous_eager_value is not None:
                    _internal_set_learning_phase(
                        _DUMMY_EAGER_GRAPH.key, previous_eager_value
                    )
                elif _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES:
                    del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]

            graph = get_graph()
            if previous_graph_value is not None:
                _internal_set_learning_phase(graph, previous_graph_value)
            elif graph in _GRAPH_LEARNING_PHASES:
                del _GRAPH_LEARNING_PHASES[graph]


@tf_contextlib.contextmanager
def eager_learning_phase_scope(value):
    """Internal scope that sets the learning phase in eager / tf.function only.

    Args:
        value: Learning phase value, either 0 or 1 (integers).
            0 = test, 1 = train

    Yields:
        None.

    Raises:
        ValueError: if `value` is neither `0` nor `1`.
    """
    global _GRAPH_LEARNING_PHASES
    assert value in {0, 1}
    assert tf.compat.v1.executing_eagerly_outside_functions()
    global_learning_phase_was_set = global_learning_phase_is_set()
    if global_learning_phase_was_set:
        previous_value = learning_phase()
    try:
        _internal_set_learning_phase(_DUMMY_EAGER_GRAPH.key, value)
        yield
    finally:
        # Restore learning phase to initial value or unset.
        if global_learning_phase_was_set:
            _internal_set_learning_phase(_DUMMY_EAGER_GRAPH.key, previous_value)
        else:
            del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]


def _as_graph_element(obj):
    """Convert `obj` to a graph element if possible, otherwise return `None`.

    Args:
        obj: Object to convert.

    Returns:
        The result of `obj._as_graph_element()` if that method is available;
        otherwise `None`.
    """
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
        return conv_fn()
    return None


def _assert_same_graph(original_item, item):
    """Fail if the 2 items are from different graphs.

    Args:
        original_item: Original item to check against.
        item: Item to check.

    Raises:
        ValueError: if graphs do not match.
    """
    original_graph = getattr(original_item, "graph", None)
    graph = getattr(item, "graph", None)
    if original_graph and graph and original_graph is not graph:
        raise ValueError(
            "%s must be from the same graph as %s (graphs are %s and %s)."
            % (item, original_item, graph, original_graph)
        )


def _current_graph(op_input_list, graph=None):
    """Returns the appropriate graph to use for the given inputs.

    This library method provides a consistent algorithm for choosing the graph
    in which an Operation should be constructed:

    1. If the default graph is being used to construct a function, we
       use the default graph.
    2. If the "graph" is specified explicitly, we validate that all of the
       inputs in "op_input_list" are compatible with that graph.
    3. Otherwise, we attempt to select a graph from the first Operation-
       or Tensor-valued input in "op_input_list", and validate that all other
       such inputs are in the same graph.
    4. If the graph was not specified and it could not be inferred from
       "op_input_list", we attempt to use the default graph.

    Args:
        op_input_list: A list of inputs to an operation, which may include
            `Tensor`, `Operation`, and other objects that may be converted to
            a graph element.
        graph: (Optional) The explicit graph to use.

    Raises:
        TypeError: If op_input_list is not a list or tuple, or if graph is not
            a Graph.
        ValueError: If a graph is explicitly passed and not all inputs are
            from it, or if the inputs are from multiple graphs, or we could
            not find a graph and there was no default graph.

    Returns:
        The appropriate graph to use for the given inputs.
    """
    current_default_graph = tf.compat.v1.get_default_graph()
    if current_default_graph.building_function:
        return current_default_graph

    op_input_list = tuple(op_input_list)  # Handle generators correctly
    if graph and not isinstance(graph, tf.Graph):
        raise TypeError(f"Input graph needs to be a Graph: {graph}")

    def _is_symbolic_tensor(tensor):
        if hasattr(tf, "is_symbolic_tensor"):
            return tf.is_symbolic_tensor(tensor)
        return type(tensor) == tf.Tensor

    # 1. We validate that all of the inputs are from the same graph. This is
    # either the supplied graph parameter, or the first one selected from one
    # of the graph-element-valued inputs. In the latter case, we hold onto
    # that input in original_graph_element so we can provide a more
    # informative error if a mismatch is found.
    original_graph_element = None
    for op_input in op_input_list:
        if isinstance(
            op_input, (tf.Operation, tf.__internal__.CompositeTensor)
        ) or _is_symbolic_tensor(op_input):
            graph_element = op_input
        else:
            graph_element = _as_graph_element(op_input)

        if graph_element is not None:
            if not graph:
                original_graph_element = graph_element
                graph = getattr(graph_element, "graph", None)
            elif original_graph_element is not None:
                _assert_same_graph(original_graph_element, graph_element)
            elif graph_element.graph is not graph:
                raise ValueError(
                    f"{graph_element} is not from the passed-in graph."
                )

    # 2. If all else fails, we use the default graph, which is always there.
    return graph or current_default_graph
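
# Illustrative sketch (not part of the original file): tensors created in the
# same `tf.Graph` resolve to that graph, while mixing graphs raises a
# ValueError via `_assert_same_graph`.
#
#     g = tf.Graph()
#     with g.as_default():
#         a = tf.compat.v1.placeholder(tf.float32)
#         b = tf.compat.v1.placeholder(tf.float32)
#     assert _current_graph([a, b]) is g
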

def _get_session(op_input_list=()):
    """Returns the session object for the current thread."""
    global _SESSION
    default_session = tf.compat.v1.get_default_session()
    if default_session is not None:
        session = default_session
    else:
        if tf.inside_function():
            raise RuntimeError(
                "Cannot get session inside Tensorflow graph function."
            )
        # If we don't have a session, or that session does not match the
        # current graph, create and cache a new session.
        if getattr(
            _SESSION, "session", None
        ) is None or _SESSION.session.graph is not _current_graph(
            op_input_list
        ):
            # If we are creating the Session inside a tf.distribute.Strategy
            # scope, we ask the strategy for the right session options to use.
            if tf.distribute.has_strategy():
                configure_and_create_distributed_session(
                    tf.distribute.get_strategy()
                )
            else:
                _SESSION.session = tf.compat.v1.Session(
                    config=get_default_session_config()
                )
        session = _SESSION.session
    return session


@keras_export(v1=["keras.backend.get_session"])
def get_session(op_input_list=()):
    """Returns the TF session to be used by the backend.

    If a default TensorFlow session is available, we will return it.

    Else, we will return the global Keras session assuming it matches
    the current graph.

    If no global Keras session exists at this point:
    we will create a new global session.

    Note that you can manually set the global session
    via `K.set_session(sess)`.

    Args:
        op_input_list: An optional sequence of tensors or ops, which will be
            used to determine the current graph. Otherwise the default graph
            will be used.

    Returns:
        A TensorFlow session.
    """
    session = _get_session(op_input_list)
    if not _MANUAL_VAR_INIT:
        with session.graph.as_default():
            _initialize_variables(session)
    return session
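
# Illustrative usage sketch (not part of the original file; only meaningful
# with TF1-style graph execution, e.g. after
# `tf.compat.v1.disable_eager_execution()`):
#
#     v = tf.keras.backend.variable(np.ones((2, 2)))
#     sess = tf.compat.v1.keras.backend.get_session()
#     sess.run(v)  # variables were already initialized by get_session()
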

# Inject the get_session function to keras_deps to remove the dependency
# from TFLite to Keras.
tf.__internal__.register_get_session_function(get_session)

# Inject the get_session function to tracking_util to avoid the backward
# dependency from TF to Keras.
tf.__internal__.tracking.register_session_provider(get_session)


def get_graph():
    if tf.executing_eagerly():
        global _GRAPH
        if not getattr(_GRAPH, "graph", None):
            _GRAPH.graph = tf.__internal__.FuncGraph("keras_graph")
        return _GRAPH.graph
    else:
        return tf.compat.v1.get_default_graph()


@tf_contextlib.contextmanager
def _scratch_graph(graph=None):
    """Retrieve a shared and temporary func graph.

    The eager execution path lifts a subgraph from the Keras global graph into
    a scratch graph in order to create a function. DistributionStrategies, in
    turn, construct multiple functions as well as a final combined function.
    In order for that logic to work correctly, all of the functions need to be
    created on the same scratch FuncGraph.

    Args:
        graph: A graph to be used as the current scratch graph. If not set,
            a scratch graph will either be retrieved or created.

    Yields:
        The current scratch graph.
    """
    global _CURRENT_SCRATCH_GRAPH
    scratch_graph = getattr(_CURRENT_SCRATCH_GRAPH, "graph", None)
    # If scratch graph and `graph` are both configured, they must match.
    if (
        scratch_graph is not None
        and graph is not None
        and scratch_graph is not graph
    ):
        raise ValueError("Multiple scratch graphs specified.")

    if scratch_graph:
        yield scratch_graph
        return

    graph = graph or tf.__internal__.FuncGraph("keras_scratch_graph")
    try:
        _CURRENT_SCRATCH_GRAPH.graph = graph
        yield graph
    finally:
        _CURRENT_SCRATCH_GRAPH.graph = None


@keras_export(v1=["keras.backend.set_session"])
def set_session(session):
    """Sets the global TensorFlow session.

    Args:
        session: A TF Session.
    """
    global _SESSION
    _SESSION.session = session


def get_default_session_config():
    if os.environ.get("OMP_NUM_THREADS"):
        logging.warning(
            "OMP_NUM_THREADS is no longer used by the default Keras config. "
            "To configure the number of threads, use tf.config.threading APIs."
        )

    config = get_config()
    config.allow_soft_placement = True

    return config


def get_default_graph_uid_map():
    graph = tf.compat.v1.get_default_graph()
    name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None)
    if name_uid_map is None:
        name_uid_map = collections.defaultdict(int)
        PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map
    return name_uid_map


# DEVICE MANIPULATION


class _TfDeviceCaptureOp:
    """Class for capturing the TF device scope."""

    def __init__(self):
        self.device = None

    def _set_device(self, device):
        """This method captures TF's explicit device scope setting."""
        if isinstance(device, tf.DeviceSpec):
            device = device.to_string()
        self.device = device

    def _set_device_from_string(self, device_str):
        self.device = device_str


def _get_current_tf_device():
    """Return the explicit device of the current context, otherwise `None`.

    Returns:
        If the current device scope is explicitly set, it returns a string
        with the device (`CPU` or `GPU`). If the scope is not explicitly set,
        it will return `None`.
    """
    graph = get_graph()
    op = _TfDeviceCaptureOp()
    graph._apply_device_functions(op)
    if tf.__internal__.tf2.enabled():
        return tf.DeviceSpec.from_string(op.device)
    else:
        return tf.compat.v1.DeviceSpec.from_string(op.device)


def _is_current_explicit_device(device_type):
    """Check if the current device is explicitly set to `device_type`.

    Args:
        device_type: A string containing `GPU` or `CPU` (case-insensitive).

    Returns:
        A boolean indicating if the current device scope is explicitly set on
        the device type.

    Raises:
        ValueError: If the `device_type` string indicates an unsupported
            device.
    """
    device_type = device_type.upper()
    if device_type not in ["CPU", "GPU"]:
        raise ValueError('`device_type` should be either "CPU" or "GPU".')
    device = _get_current_tf_device()
    return device is not None and device.device_type == device_type.upper()


def _get_available_gpus():
    """Get a list of available GPU devices (formatted as strings).

    Returns:
        A list of available GPU devices.
    """
    if tf.compat.v1.executing_eagerly_outside_functions():
        # Returns names of devices directly.
        return [d.name for d in tf.config.list_logical_devices("GPU")]

    global _LOCAL_DEVICES
    if _LOCAL_DEVICES is None:
        _LOCAL_DEVICES = get_session().list_devices()
    return [x.name for x in _LOCAL_DEVICES if x.device_type == "GPU"]


def _has_nchw_support():
    """Check whether the current scope supports NCHW ops.

    TensorFlow does not support NCHW on CPU. Therefore we check if we are not
    explicitly placed on CPU, and have GPUs available. In this case there will
    be soft-placing on the GPU device.

    Returns:
        bool: if the current scope device placement would support nchw
    """
    explicitly_on_cpu = _is_current_explicit_device("CPU")
    gpus_available = bool(_get_available_gpus())
    return not explicitly_on_cpu and gpus_available
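
# Illustrative sketch (not part of the original file): callers typically use
# this check to pick a data layout for convolution ops.
#
#     data_format = "NCHW" if _has_nchw_support() else "NHWC"
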

# VARIABLE MANIPULATION


def _constant_to_tensor(x, dtype):
    """Convert the input `x` to a tensor of type `dtype`.

    This is slightly faster than the _to_tensor function, at the cost of
    handling fewer cases.

    Args:
        x: An object to be converted (numpy arrays, floats, ints and lists of
            them).
        dtype: The destination type.

    Returns:
        A tensor.
    """
    return tf.constant(x, dtype=dtype)


def _to_tensor(x, dtype):
    """Convert the input `x` to a tensor of type `dtype`.

    Args:
        x: An object to be converted (numpy array, list, tensors).
        dtype: The destination type.

    Returns:
        A tensor.
    """
    return tf.convert_to_tensor(x, dtype=dtype)


@keras_export("keras.backend.is_sparse")
@doc_controls.do_not_generate_docs
def is_sparse(tensor):
    """Returns whether a tensor is a sparse tensor.

    Args:
        tensor: A tensor instance.

    Returns:
        A boolean.

    Example:

    >>> a = tf.keras.backend.placeholder((2, 2), sparse=False)
    >>> print(tf.keras.backend.is_sparse(a))
    False
    >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
    >>> print(tf.keras.backend.is_sparse(b))
    True

    """
    spec = getattr(tensor, "_type_spec", None)
    if spec is not None:
        return isinstance(spec, tf.SparseTensorSpec)
    return isinstance(tensor, tf.SparseTensor)


@keras_export("keras.backend.to_dense")
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
def to_dense(tensor):
    """Converts a sparse tensor into a dense tensor and returns it.

    Args:
        tensor: A tensor instance (potentially sparse).

    Returns:
        A dense tensor.

    Examples:

    >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
    >>> print(tf.keras.backend.is_sparse(b))
    True
    >>> c = tf.keras.backend.to_dense(b)
    >>> print(tf.keras.backend.is_sparse(c))
    False

    """
    if is_sparse(tensor):
        return tf.sparse.to_dense(tensor)
    else:
        return tensor


@keras_export("keras.backend.name_scope", v1=[])
@doc_controls.do_not_generate_docs
def name_scope(name):
    """A context manager for use when defining a Python op.

    This context manager pushes a name scope, which will make the name of all
    operations added within it have a prefix.

    For example, to define a new Python op called `my_op`:

    def my_op(a):
      with tf.name_scope("MyOp") as scope:
        a = tf.convert_to_tensor(a, name="a")
        # Define some computation that uses `a`.
        return foo_op(..., name=scope)

    When executed, the Tensor `a` will have the name `MyOp/a`.

    Args:
        name: The prefix to use on all names created within the name scope.

    Returns:
        Name scope context manager.
    """
    return tf.name_scope(name)


# Export V1 version.
_v1_name_scope = tf.compat.v1.name_scope
keras_export(v1=["keras.backend.name_scope"], allow_multiple_exports=True)(
    _v1_name_scope
)


@keras_export("keras.backend.variable")
@doc_controls.do_not_generate_docs
def variable(value, dtype=None, name=None, constraint=None):
    """Instantiates a variable and returns it.

    Args:
        value: Numpy array, initial value of the tensor.
        dtype: Tensor type.
        name: Optional name string for the tensor.
        constraint: Optional projection function to be
            applied to the variable after an optimizer update.

    Returns:
        A variable instance (with Keras metadata included).

    Examples:

    >>> val = np.array([[1, 2], [3, 4]])
    >>> kvar = tf.keras.backend.variable(value=val, dtype='float64',
    ...                                  name='example_var')
    >>> tf.keras.backend.dtype(kvar)
    'float64'
    >>> print(kvar)
    <tf.Variable 'example_var:...' shape=(2, 2) dtype=float64, numpy=
      array([[1., 2.],
             [3., 4.]])>

    """
    if dtype is None:
        dtype = floatx()
    if hasattr(value, "tocoo"):
        sparse_coo = value.tocoo()
        indices = np.concatenate(
            (
                np.expand_dims(sparse_coo.row, 1),
                np.expand_dims(sparse_coo.col, 1),
            ),
            1,
        )
        v = tf.SparseTensor(
            indices=indices,
            values=sparse_coo.data,
            dense_shape=sparse_coo.shape,
        )
        v._keras_shape = sparse_coo.shape
        return v
    v = tf.Variable(
        value, dtype=tf.as_dtype(dtype), name=name, constraint=constraint
    )
    if isinstance(value, np.ndarray):
        v._keras_shape = value.shape
    elif hasattr(value, "shape"):
        v._keras_shape = int_shape(value)
    track_variable(v)
    return v
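
# Illustrative sketch (not part of the original file; assumes SciPy is
# installed): a scipy sparse matrix takes the `tocoo()` branch above and
# yields a `tf.SparseTensor` instead of a `tf.Variable`.
#
#     import scipy.sparse as sp
#     sparse_val = sp.eye(3, format="csr")
#     kvar = tf.keras.backend.variable(sparse_val)
#     tf.keras.backend.is_sparse(kvar)  # -> True
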

def track_tf_optimizer(tf_optimizer):
    """Tracks the given TF optimizer for initialization of its variables."""
    if tf.executing_eagerly():
        return
    optimizers = _GRAPH_TF_OPTIMIZERS[None]
    optimizers.add(tf_optimizer)


@keras_export("keras.__internal__.backend.track_variable", v1=[])
def track_variable(v):
    """Tracks the given variable for initialization."""
    if tf.executing_eagerly():
        return
    graph = v.graph if hasattr(v, "graph") else get_graph()
    _GRAPH_VARIABLES[graph].add(v)


def observe_object_name(name):
    """Observe a name and make sure it won't be used by `unique_object_name`."""
    OBSERVED_NAMES.add(name)


def unique_object_name(
    name,
    name_uid_map=None,
    avoid_names=None,
    namespace="",
    zero_based=False,
    avoid_observed_names=False,
):
    """Makes an object name (or any string) unique within a Keras session.

    Args:
        name: String name to make unique.
        name_uid_map: An optional defaultdict(int) to use when creating unique
            names. If None (default), uses a per-Graph dictionary.
        avoid_names: An optional set or dict with names which should not be
            used. If None (default), don't avoid any names unless
            `avoid_observed_names` is True.
        namespace: Gets a name which is unique within the (graph, namespace).
            Layers which are not Networks use a blank namespace and so get
            graph-global names.
        zero_based: If True, name sequences start with no suffix (e.g.
            "dense", "dense_1"). If False, naming is one-based ("dense_1",
            "dense_2").
        avoid_observed_names: If True, avoid any names that have been observed
            by `backend.observe_object_name`.

    Returns:
        Unique string name.

    Example:

    unique_object_name('dense')  # dense_1
    unique_object_name('dense')  # dense_2

    """
    if name_uid_map is None:
        name_uid_map = get_default_graph_uid_map()
    if avoid_names is None:
        if avoid_observed_names:
            avoid_names = OBSERVED_NAMES
        else:
            avoid_names = set()
    proposed_name = None
    while proposed_name is None or proposed_name in avoid_names:
        name_key = (namespace, name)
        if zero_based:
            number = name_uid_map[name_key]
            if number:
                proposed_name = name + "_" + str(number)
            else:
                proposed_name = name
            name_uid_map[name_key] += 1
        else:
            name_uid_map[name_key] += 1
            proposed_name = name + "_" + str(name_uid_map[name_key])
    return proposed_name
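
# Illustrative sketch (not part of the original file): with `zero_based=True`
# the first occurrence gets no suffix, matching how layers are auto-named.
#
#     uid_map = collections.defaultdict(int)
#     unique_object_name("dense", name_uid_map=uid_map, zero_based=True)
#     # -> "dense"
#     unique_object_name("dense", name_uid_map=uid_map, zero_based=True)
#     # -> "dense_1"
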

def _get_variables(graph=None):
    """Returns variables corresponding to the given graph for initialization."""
    assert not tf.executing_eagerly()
    variables = _GRAPH_VARIABLES[graph]
    for opt in _GRAPH_TF_OPTIMIZERS[graph]:
        variables.update(opt.optimizer.variables())
    return variables


@keras_export("keras.__internal__.backend.initialize_variables", v1=[])
def _initialize_variables(session):
    """Utility to initialize uninitialized variables on the fly."""
    variables = _get_variables(get_graph())
    candidate_vars = []
    for v in variables:
        if not getattr(v, "_keras_initialized", False):
            candidate_vars.append(v)
    if candidate_vars:
        # This step is expensive, so we only run it on variables not already
        # marked as initialized.
        is_initialized = session.run(
            [tf.compat.v1.is_variable_initialized(v) for v in candidate_vars]
        )
        # TODO(kathywu): Some metric variables loaded from SavedModel are
        # never actually used, and do not have an initializer.
        should_be_initialized = [
            (not is_initialized[n]) and v.initializer is not None
            for n, v in enumerate(candidate_vars)
        ]
        uninitialized_vars = []
        for flag, v in zip(should_be_initialized, candidate_vars):
            if flag:
                uninitialized_vars.append(v)
                v._keras_initialized = True
        if uninitialized_vars:
            session.run(tf.compat.v1.variables_initializer(uninitialized_vars))
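
# Note (not part of the original file): `get_session()` calls
# `_initialize_variables()`, so in graph mode any variable created via
# `tf.keras.backend.variable(...)` is lazily initialized the next time a
# session is fetched, unless `manual_variable_initialization(True)` is set.
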
1270@keras_export("keras.backend.constant")
1271@tf.__internal__.dispatch.add_dispatch_support
1272@doc_controls.do_not_generate_docs
1273def constant(value, dtype=None, shape=None, name=None):
1274 """Creates a constant tensor.
1276 Args:
1277 value: A constant value (or list)
1278 dtype: The type of the elements of the resulting tensor.
1279 shape: Optional dimensions of resulting tensor.
1280 name: Optional name for the tensor.
1282 Returns:
1283 A Constant Tensor.
1284 """
1285 if dtype is None:
1286 dtype = floatx()
1288 return tf.constant(value, dtype=dtype, shape=shape, name=name)
1291@keras_export("keras.backend.is_keras_tensor")
1292def is_keras_tensor(x):
1293 """Returns whether `x` is a Keras tensor.
1295 A "Keras tensor" is a tensor that was returned by a Keras layer,
1296 (`Layer` class) or by `Input`.
1298 Args:
1299 x: A candidate tensor.
1301 Returns:
1302 A boolean: Whether the argument is a Keras tensor.
1304 Raises:
1305 ValueError: In case `x` is not a symbolic tensor.
1307 Examples:
1309 >>> np_var = np.array([1, 2])
1310 >>> # A numpy array is not a symbolic tensor.
1311 >>> tf.keras.backend.is_keras_tensor(np_var)
1312 Traceback (most recent call last):
1313 ...
1314 ValueError: Unexpectedly found an instance of type
1315 `<class 'numpy.ndarray'>`.
1316 Expected a symbolic tensor instance.
1317 >>> keras_var = tf.keras.backend.variable(np_var)
1318 >>> # A variable created with the keras backend is not a Keras tensor.
1319 >>> tf.keras.backend.is_keras_tensor(keras_var)
1320 False
1321 >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
1322 >>> # A placeholder is a Keras tensor.
1323 >>> tf.keras.backend.is_keras_tensor(keras_placeholder)
1324 True
1325 >>> keras_input = tf.keras.layers.Input([10])
1326 >>> # An Input is a Keras tensor.
1327 >>> tf.keras.backend.is_keras_tensor(keras_input)
1328 True
1329 >>> keras_layer_output = tf.keras.layers.Dense(10)(keras_input)
1330 >>> # Any Keras layer output is a Keras tensor.
1331 >>> tf.keras.backend.is_keras_tensor(keras_layer_output)
1332 True
1334 """
1335 if not isinstance(
1336 x,
1337 (
1338 tf.Tensor,
1339 tf.Variable,
1340 tf.SparseTensor,
1341 tf.RaggedTensor,
1342 keras_tensor.KerasTensor,
1343 ),
1344 ):
1345 raise ValueError(
1346 "Unexpectedly found an instance of type `"
1347 + str(type(x))
1348 + "`. Expected a symbolic tensor instance."
1349 )
1350 if tf.compat.v1.executing_eagerly_outside_functions():
1351 return isinstance(x, keras_tensor.KerasTensor)
1352 return hasattr(x, "_keras_history")
1355@keras_export("keras.backend.placeholder")
1356@doc_controls.do_not_generate_docs
1357def placeholder(
1358 shape=None, ndim=None, dtype=None, sparse=False, name=None, ragged=False
1359):
1360 """Instantiates a placeholder tensor and returns it.
1362 Args:
1363 shape: Shape of the placeholder
1364 (integer tuple, may include `None` entries).
1365 ndim: Number of axes of the tensor.
1366 At least one of {`shape`, `ndim`} must be specified.
1367 If both are specified, `shape` is used.
1368 dtype: Placeholder type.
1369 sparse: Boolean, whether the placeholder should have a sparse type.
1370 name: Optional name string for the placeholder.
1371 ragged: Boolean, whether the placeholder should have a ragged type.
1372 In this case, values of 'None' in the 'shape' argument represent
1373 ragged dimensions. For more information about RaggedTensors, see
1374 this [guide](https://www.tensorflow.org/guide/ragged_tensor).
1376 Raises:
1377 ValueError: If called with sparse = True and ragged = True.
1379 Returns:
1380 Tensor instance (with Keras metadata included).
1382 Examples:
1385 >>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
1386 >>> input_ph
1387 <KerasTensor: shape=(2, 4, 5) dtype=float32 (created by layer ...)>
1389 """
1390 if sparse and ragged:
1391 raise ValueError(
1392 "Cannot set both sparse and ragged to "
1393 "True when creating a placeholder."
1394 )
1395 if dtype is None:
1396 dtype = floatx()
1397 if not shape:
1398 if ndim:
1399 shape = (None,) * ndim
1400 if tf.compat.v1.executing_eagerly_outside_functions():
1401 if sparse:
1402 spec = tf.SparseTensorSpec(shape=shape, dtype=dtype)
1403 elif ragged:
1404 ragged_rank = 0
1405 for i in range(1, len(shape)):
1406 # Hacky because could be tensorshape or tuple maybe?
1407 # Or just tensorshape?
1408 if shape[i] is None or (
1409 hasattr(shape[i], "value") and shape[i].value is None
1410 ):
1411 ragged_rank = i
1412 spec = tf.RaggedTensorSpec(
1413 shape=shape, dtype=dtype, ragged_rank=ragged_rank
1414 )
1415 else:
1416 spec = tf.TensorSpec(shape=shape, dtype=dtype, name=name)
1417 x = keras_tensor.keras_tensor_from_type_spec(spec, name=name)
1418 else:
1419 with get_graph().as_default():
1420 if sparse:
1421 x = tf.compat.v1.sparse_placeholder(
1422 dtype, shape=shape, name=name
1423 )
1424 elif ragged:
1425 ragged_rank = 0
1426 for i in range(1, len(shape)):
1427 if shape[i] is None:
1428 ragged_rank = i
1429 type_spec = tf.RaggedTensorSpec(
1430 shape=shape, dtype=dtype, ragged_rank=ragged_rank
1431 )
1433 def tensor_spec_to_placeholder(tensorspec):
1434 return tf.compat.v1.placeholder(
1435 tensorspec.dtype, tensorspec.shape
1436 )
1438 x = tf.nest.map_structure(
1439 tensor_spec_to_placeholder,
1440 type_spec,
1441 expand_composites=True,
1442 )
1443 else:
1444 x = tf.compat.v1.placeholder(dtype, shape=shape, name=name)
1446 if tf.executing_eagerly():
1447 # Add keras_history connectivity information to the placeholder
1448 # when the placeholder is built in a top-level eager context
1449 # (intended to be used with keras.backend.function)
1450 from keras.src.engine import (
1451 input_layer,
1452 )
1454 x = input_layer.Input(tensor=x)
1455 x._is_backend_placeholder = True
1457 return x
1460def is_placeholder(x):
1461 """Returns whether `x` is a placeholder.
1463 Args:
1464 x: A candidate placeholder.
1466 Returns:
1467 Boolean.
1468 """
1469 try:
1470 if tf.compat.v1.executing_eagerly_outside_functions():
1471 return hasattr(x, "_is_backend_placeholder")
1473 # TODO(b/246438937): Remove the special case for tf.Variable once
1474 # tf.Variable becomes CompositeTensor and will be expanded into
1475 # dt_resource tensors.
1476 if tf_utils.is_extension_type(x) and not isinstance(x, tf.Variable):
1477 flat_components = tf.nest.flatten(x, expand_composites=True)
1478 return py_any(is_placeholder(c) for c in flat_components)
1479 else:
1480 return x.op.type == "Placeholder"
1481 except AttributeError:
1482 return False
1485@keras_export("keras.backend.shape")
1486@tf.__internal__.dispatch.add_dispatch_support
1487@doc_controls.do_not_generate_docs
1488def shape(x):
1489 """Returns the symbolic shape of a tensor or variable.
1491 Args:
1492 x: A tensor or variable.
1494 Returns:
1495 A symbolic shape (which is itself a tensor).
1497 Examples:
1499 >>> val = np.array([[1, 2], [3, 4]])
1500 >>> kvar = tf.keras.backend.variable(value=val)
1501 >>> tf.keras.backend.shape(kvar)
1502 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
1503 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1504 >>> tf.keras.backend.shape(input)
1505 <KerasTensor: shape=(3,) dtype=int32 inferred_value=[2, 4, 5] ...>
1507 """
1508 return tf.shape(x)
1511@keras_export("keras.backend.int_shape")
1512@doc_controls.do_not_generate_docs
1513def int_shape(x):
1514 """Returns shape of tensor/variable as a tuple of int/None entries.
1516 Args:
1517 x: Tensor or variable.
1519 Returns:
1520 A tuple of integers (or None entries).
1522 Examples:
1524 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1525 >>> tf.keras.backend.int_shape(input)
1526 (2, 4, 5)
1527 >>> val = np.array([[1, 2], [3, 4]])
1528 >>> kvar = tf.keras.backend.variable(value=val)
1529 >>> tf.keras.backend.int_shape(kvar)
1530 (2, 2)
1532 """
1533 try:
1534 shape = x.shape
1535 if not isinstance(shape, tuple):
1536 shape = tuple(shape.as_list())
1537 return shape
1538 except ValueError:
1539 return None
1542@keras_export("keras.backend.ndim")
1543@doc_controls.do_not_generate_docs
1544def ndim(x):
1545 """Returns the number of axes in a tensor, as an integer.
1547 Args:
1548 x: Tensor or variable.
1550 Returns:
1551 Integer (scalar), number of axes.
1553 Examples:
1556 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1557 >>> val = np.array([[1, 2], [3, 4]])
1558 >>> kvar = tf.keras.backend.variable(value=val)
1559 >>> tf.keras.backend.ndim(input)
1560 3
1561 >>> tf.keras.backend.ndim(kvar)
1562 2
1564 """
1565 return x.shape.rank
1568@keras_export("keras.backend.dtype")
1569@tf.__internal__.dispatch.add_dispatch_support
1570@doc_controls.do_not_generate_docs
1571def dtype(x):
1572 """Returns the dtype of a Keras tensor or variable, as a string.
1574 Args:
1575 x: Tensor or variable.
1577 Returns:
1578 String, dtype of `x`.
1580 Examples:
1582 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5)))
1583 'float32'
1584 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1585 ... dtype='float32'))
1586 'float32'
1587 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1588 ... dtype='float64'))
1589 'float64'
1590 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]))
1591 >>> tf.keras.backend.dtype(kvar)
1592 'float32'
1593 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1594 ... dtype='float32')
1595 >>> tf.keras.backend.dtype(kvar)
1596 'float32'
1598 """
1599 return x.dtype.base_dtype.name
1602@doc_controls.do_not_generate_docs
1603def dtype_numpy(x):
1604 """Returns the numpy dtype of a Keras tensor or variable.
1606 Args:
1607 x: Tensor or variable.
1609 Returns:
1610 numpy.dtype, dtype of `x`.
1611 """
1612 return tf.as_dtype(x.dtype).as_numpy_dtype
1615@keras_export("keras.backend.eval")
1616@doc_controls.do_not_generate_docs
1617def eval(x):
1618 """Evaluates the value of a variable.
1620 Args:
1621 x: A variable.
1623 Returns:
1624 A Numpy array.
1626 Examples:
1628 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1629 ... dtype='float32')
1630 >>> tf.keras.backend.eval(kvar)
1631 array([[1., 2.],
1632 [3., 4.]], dtype=float32)
1634 """
1635 return get_value(to_dense(x))
1638@keras_export("keras.backend.zeros")
1639@doc_controls.do_not_generate_docs
1640def zeros(shape, dtype=None, name=None):
1641 """Instantiates an all-zeros variable and returns it.
1643 Args:
1644 shape: Tuple or list of integers, shape of returned Keras variable
1645 dtype: data type of returned Keras variable
1646 name: name of returned Keras variable
1648 Returns:
1649 A variable (including Keras metadata), filled with `0.0`.
1650 Note that if `shape` was symbolic, we cannot return a variable,
1651 and will return a dynamically-shaped tensor instead.
1653 Example:
1655 >>> kvar = tf.keras.backend.zeros((3,4))
1656 >>> tf.keras.backend.eval(kvar)
1657 array([[0., 0., 0., 0.],
1658 [0., 0., 0., 0.],
1659 [0., 0., 0., 0.]], dtype=float32)
1660 >>> A = tf.constant([1,2,3])
1661 >>> kvar2 = tf.keras.backend.zeros(A.shape) # [0., 0., 0.]
1662 >>> tf.keras.backend.eval(kvar2)
1663 array([0., 0., 0.], dtype=float32)
1664 >>> kvar3 = tf.keras.backend.zeros(A.shape,dtype=tf.int32)
1665 >>> tf.keras.backend.eval(kvar3)
1666 array([0, 0, 0], dtype=int32)
1667 >>> kvar4 = tf.keras.backend.zeros([2,3])
1668 >>> tf.keras.backend.eval(kvar4)
1669 array([[0., 0., 0.],
1670 [0., 0., 0.]], dtype=float32)
1672 """
1673 with tf.init_scope():
1674 if dtype is None:
1675 dtype = floatx()
1676 tf_dtype = tf.as_dtype(dtype)
1677 v = tf.zeros(shape=shape, dtype=tf_dtype, name=name)
1678 if py_all(v.shape.as_list()):
1679 return variable(v, dtype=dtype, name=name)
1680 return v
1683@keras_export("keras.backend.ones")
1684@tf.__internal__.dispatch.add_dispatch_support
1685@doc_controls.do_not_generate_docs
1686def ones(shape, dtype=None, name=None):
1687 """Instantiates an all-ones variable and returns it.
1689 Args:
1690 shape: Tuple of integers, shape of returned Keras variable.
1691 dtype: String, data type of returned Keras variable.
1692 name: String, name of returned Keras variable.
1694 Returns:
1695 A Keras variable, filled with `1.0`.
1696 Note that if `shape` was symbolic, we cannot return a variable,
1697 and will return a dynamically-shaped tensor instead.
1699 Example:
1702 >>> kvar = tf.keras.backend.ones((3,4))
1703 >>> tf.keras.backend.eval(kvar)
1704 array([[1., 1., 1., 1.],
1705 [1., 1., 1., 1.],
1706 [1., 1., 1., 1.]], dtype=float32)
1708 """
1709 with tf.init_scope():
1710 if dtype is None:
1711 dtype = floatx()
1712 tf_dtype = tf.as_dtype(dtype)
1713 v = tf.ones(shape=shape, dtype=tf_dtype, name=name)
1714 if py_all(v.shape.as_list()):
1715 return variable(v, dtype=dtype, name=name)
1716 return v
1719@keras_export("keras.backend.eye")
1720@tf.__internal__.dispatch.add_dispatch_support
1721@doc_controls.do_not_generate_docs
1722def eye(size, dtype=None, name=None):
1723 """Instantiate an identity matrix and returns it.
1725 Args:
1726 size: Integer, number of rows/columns.
1727 dtype: String, data type of returned Keras variable.
1728 name: String, name of returned Keras variable.
1730 Returns:
1731 A Keras variable, an identity matrix.
1733 Example:
1736 >>> kvar = tf.keras.backend.eye(3)
1737 >>> tf.keras.backend.eval(kvar)
1738 array([[1., 0., 0.],
1739 [0., 1., 0.],
1740 [0., 0., 1.]], dtype=float32)
1743 """
1744 if dtype is None:
1745 dtype = floatx()
1746 tf_dtype = tf.as_dtype(dtype)
1747 return variable(tf.eye(size, dtype=tf_dtype), dtype, name)
1750@keras_export("keras.backend.zeros_like")
1751@doc_controls.do_not_generate_docs
1752def zeros_like(x, dtype=None, name=None):
1753 """Instantiates an all-zeros variable of the same shape as another tensor.
1755 Args:
1756 x: Keras variable or Keras tensor.
1757 dtype: dtype of returned Keras variable.
1758 `None` uses the dtype of `x`.
1759 name: name for the variable to create.
1761 Returns:
1762 A Keras variable with the shape of `x` filled with zeros.
1764 Example:
1766 ```python
1767 kvar = tf.keras.backend.variable(np.random.random((2,3)))
1768 kvar_zeros = tf.keras.backend.zeros_like(kvar)
1769 K.eval(kvar_zeros)
1770 # array([[ 0., 0., 0.], [ 0., 0., 0.]], dtype=float32)
1771 ```
1772 """
1773 return tf.zeros_like(x, dtype=dtype, name=name)
1776@keras_export("keras.backend.ones_like")
1777@tf.__internal__.dispatch.add_dispatch_support
1778@doc_controls.do_not_generate_docs
1779def ones_like(x, dtype=None, name=None):
1780 """Instantiates an all-ones variable of the same shape as another tensor.
1782 Args:
1783 x: Keras variable or tensor.
1784 dtype: String, dtype of returned Keras variable.
1785 None uses the dtype of x.
1786 name: String, name for the variable to create.
1788 Returns:
1789 A Keras variable with the shape of x filled with ones.
1791 Example:
1793 >>> kvar = tf.keras.backend.variable(np.random.random((2,3)))
1794 >>> kvar_ones = tf.keras.backend.ones_like(kvar)
1795 >>> tf.keras.backend.eval(kvar_ones)
1796 array([[1., 1., 1.],
1797 [1., 1., 1.]], dtype=float32)
1799 """
1800 return tf.ones_like(x, dtype=dtype, name=name)
1803def identity(x, name=None):
1804 """Returns a tensor with the same content as the input tensor.
1806 Args:
1807 x: The input tensor.
1808 name: String, name for the variable to create.
1810 Returns:
1811 A tensor of the same shape, type and content.
1812 """
1813 return tf.identity(x, name=name)
1816# Global flag to enforce tf.random.Generator for RandomGenerator.
1817# When this is enabled, for any caller to RandomGenerator, it will use
1818# tf.random.Generator to generate random numbers.
1819# The legacy behavior is to use TF's legacy stateful RNG ops like
1820# tf.random.uniform.
1821_USE_GENERATOR_FOR_RNG = False
1823# The global generator to create the seed when initializing the
1824# tf.random.Genrator used by RandomGenerator. When tf.random.Generator becomes
1825# the default solution, we would like the it to be initialized in a controlable
1826# way, so that each client of the program could start with same seed. This is
1827# very important for certain use case that requires all the client to have their
1828# state in sync. This instance will be set when user call
1829# `tf.keras.utils.set_random_seed()`
1830_SEED_GENERATOR = threading.local()
1833@keras_export(
1834 "keras.backend.experimental.is_tf_random_generator_enabled", v1=[]
1835)
1836def is_tf_random_generator_enabled():
1837 """Check whether `tf.random.Generator` is used for RNG in Keras.
1839 Compared to existing TF stateful random ops, `tf.random.Generator` uses
1840 `tf.Variable` and stateless random ops to generate random numbers,
1841 which leads to better reproducibility in distributed training.
1842 Note that enabling it might introduce some breakage to existing code,
1843 by producing differently-seeded random number sequences
1844 and breaking tests that rely on specific random numbers being generated.
1845 To disable the
1846 usage of `tf.random.Generator`, please use
1847 `tf.keras.backend.experimental.disable_tf_random_generator`.
1849 We expect the `tf.random.Generator` code path to become the default, and
1850 will remove the legacy stateful random ops such as `tf.random.uniform` in
1851 the future (see the [TF RNG guide](
1852 https://www.tensorflow.org/guide/random_numbers)).
1854 This API will be removed in a future release as well, together with
1855 `tf.keras.backend.experimental.enable_tf_random_generator()` and
1856 `tf.keras.backend.experimental.disable_tf_random_generator()`.
1858 Returns:
1859 boolean: whether `tf.random.Generator` is used for random number
1860 generation in Keras.
1861 """
1862 return _USE_GENERATOR_FOR_RNG
1865@keras_export("keras.backend.experimental.enable_tf_random_generator", v1=[])
1866def enable_tf_random_generator():
1867 """Enable the `tf.random.Generator` as the RNG for Keras.
1869 See `tf.keras.backend.experimental.is_tf_random_generator_enabled` for more
1870 details.
1871 """
1873 global _USE_GENERATOR_FOR_RNG
1874 _USE_GENERATOR_FOR_RNG = True
1877@keras_export("keras.backend.experimental.disable_tf_random_generator", v1=[])
1878def disable_tf_random_generator():
1879 """Disable the `tf.random.Generator` as the RNG for Keras.
1881 See `tf.keras.backend.experimental.is_tf_random_generator_enabled` for more
1882 details.
1883 """
1884 global _USE_GENERATOR_FOR_RNG
1885 _USE_GENERATOR_FOR_RNG = False
1888class RandomGenerator(tf.__internal__.tracking.AutoTrackable):
1889 """Random generator that selects appropriate random ops.
1891 This class contains the logic for legacy stateful random ops, as well as the
1892 new stateless random ops with seeds and tf.random.Generator. Any class that
1893 relies on RNG (e.g. initializer, shuffle, dropout) should use this class to
1894 handle the transition from legacy RNGs to new RNGs.
1896 Args:
1897 seed: Optional int seed. When `rng_type` is "stateful", the seed is used
1898 to create `tf.random.Generator` to produce deterministic sequences.
1899 When `rng_type` is "stateless", a new seed will be created if it is not
1900 provided by the user, and it will be passed down to stateless random ops.
1901 When `rng_type` is "legacy_stateful", the seed will be passed down to
1902 stateful random ops.
1903 rng_type: Type of RNG to use, one of "stateful", "stateless",
1904 "legacy_stateful". When `None` it uses "stateful" if
1905 `enable_tf_random_generator` has been activated, or
1906 "legacy_stateful" otherwise.
1907 - When using "stateless", the random ops outputs are constant (the same
1908 inputs result in the same outputs).
1909 - When using "stateful" or "legacy_stateful", the random ops outputs are
1910 non-constant, but deterministic: calling the same random op multiple
1911 times with the same inputs results in a deterministic sequence of
1912 different outputs.
1913 - "legacy_stateful" is backed by TF1 stateful RNG ops
1914 (e.g. `tf.random.uniform`), while "stateful"
1915 is backed by TF2 APIs (e.g. `tf.random.Generator.uniform`).
1916 Defaults to `None`.
1917 """
1919 RNG_STATELESS = "stateless"
1920 RNG_STATEFUL = "stateful"
1921 RNG_LEGACY_STATEFUL = "legacy_stateful"
1923 def __init__(self, seed=None, rng_type=None, **kwargs):
1924 self._seed = seed
1925 self._set_rng_type(rng_type, **kwargs)
1926 self._built = False
1928 def _set_rng_type(self, rng_type, **kwargs):
1929 # The only supported kwarg is "force_generator", which we will remove once
1930 # we clean up all the callers.
1931 # TODO(scottzhu): Remove the kwargs for force_generator.
1932 if kwargs.get("force_generator", False):
1933 rng_type = self.RNG_STATEFUL
1934 if rng_type is None:
1935 if is_tf_random_generator_enabled():
1936 self._rng_type = self.RNG_STATEFUL
1937 else:
1938 self._rng_type = self.RNG_LEGACY_STATEFUL
1939 else:
1940 if rng_type not in [
1941 self.RNG_STATEFUL,
1942 self.RNG_LEGACY_STATEFUL,
1943 self.RNG_STATELESS,
1944 ]:
1945 raise ValueError(
1946 "Invalid `rng_type` received. "
1947 'Valid `rng_type` are ["stateless", '
1948 '"stateful", "legacy_stateful"].'
1949 f" Got: {rng_type}"
1950 )
1951 self._rng_type = rng_type
1953 def _maybe_init(self):
1954 """Lazily init the RandomGenerator.
1956 The TF API executing_eagerly_outside_functions() has side effects and
1957 cannot be used before APIs like tf.enable_eager_execution(). Some
1958 client-side code creates initializers at module load time, which
1959 triggers the creation of RandomGenerator. We lazily initialize this
1960 class to work around the issue until it is resolved on the TF side.
1961 """
1962 # TODO(b/167482354): Change this back to normal init when the bug is
1963 # fixed.
1964 if self._built:
1965 return
1967 if (
1968 self._rng_type == self.RNG_STATEFUL
1969 and not tf.compat.v1.executing_eagerly_outside_functions()
1970 ):
1971 # Fall back to legacy stateful ops since tf.random.Generator requires
1972 # TF2 behavior to be enabled.
1973 self._rng_type = self.RNG_LEGACY_STATEFUL
1975 if self._rng_type == self.RNG_STATELESS:
1976 self._seed = self._create_seed(self._seed)
1977 self._generator = None
1978 elif self._rng_type == self.RNG_STATEFUL:
1979 with tf_utils.maybe_init_scope(self):
1980 seed = self._create_seed(self._seed)
1981 self._generator = tf.random.Generator.from_seed(
1982 seed, alg=tf.random.Algorithm.AUTO_SELECT
1983 )
1984 else:
1985 # In legacy stateful mode, we use stateful ops regardless of whether
1986 # the user provides a seed. A seeded stateful op will ensure that the
1987 # same sequences are generated.
1988 self._generator = None
1989 self._built = True
1991 def make_seed_for_stateless_op(self):
1992 """Generate a new seed based on the init config.
1994 Note that this will not return Python ints, which would be frozen in the
1995 graph and cause stateless ops to return the same value. It will only
1996 return a value when the generator is used; otherwise it returns None.
1998 Returns:
1999 A tensor with shape [2,].
2000 """
2001 self._maybe_init()
2002 if self._rng_type == self.RNG_STATELESS:
2003 return [self._seed, 0]
2004 elif self._rng_type == self.RNG_STATEFUL:
2005 return self._generator.make_seeds()[:, 0]
2006 return None
2008 def make_legacy_seed(self):
2009 """Create a new seed for the legacy stateful ops to use.
2011 When the user didn't provide an original seed, this method will return
2012 None. Otherwise it will increment the stored seed and return it as the
2013 new seed.
2015 Note that it is important to generate a different seed for stateful ops
2016 inside a `tf.function`: the random ops will return the same value when
2017 the same seed is provided in a `tf.function`.
2019 Returns:
2020 int as new seed, or None.
2021 """
2022 if self._seed is not None:
2023 result = self._seed
2024 self._seed += 1
2025 return result
2026 return None
2028 def _create_seed(self, user_specified_seed):
2029 if user_specified_seed is not None:
2030 return user_specified_seed
2031 elif getattr(_SEED_GENERATOR, "generator", None):
2032 return _SEED_GENERATOR.generator.randint(1, int(1e9))
2033 else:
2034 return random.randint(1, int(1e9))
2036 def random_normal(
2037 self, shape, mean=0.0, stddev=1.0, dtype=None, nonce=None
2038 ):
2039 """Produce random number based on the normal distribution.
2041 Args:
2042 shape: The shape of the random values to generate.
2043 mean: Float, defaults to 0. Mean of the random values to generate.
2044 stddev: Float, defaults to 1. Standard deviation of the random values
2045 to generate.
2046 dtype: Optional dtype of the tensor. Only floating point types are
2047 supported. If not specified, `tf.keras.backend.floatx()` is used,
2048 which defaults to `float32` unless you configured it otherwise (via
2049 `tf.keras.backend.set_floatx(float_dtype)`).
2050 nonce: Optional integer scalar that will be folded into the seed in
2051 stateless mode.
2052 """
2053 self._maybe_init()
2054 dtype = dtype or floatx()
2055 if self._rng_type == self.RNG_STATEFUL:
2056 return self._generator.normal(
2057 shape=shape, mean=mean, stddev=stddev, dtype=dtype
2058 )
2059 elif self._rng_type == self.RNG_STATELESS:
2060 seed = self.make_seed_for_stateless_op()
2061 if nonce:
2062 seed = tf.random.experimental.stateless_fold_in(seed, nonce)
2063 return tf.random.stateless_normal(
2064 shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
2065 )
2066 return tf.random.normal(
2067 shape=shape,
2068 mean=mean,
2069 stddev=stddev,
2070 dtype=dtype,
2071 seed=self.make_legacy_seed(),
2072 )
2074 def random_uniform(
2075 self, shape, minval=0.0, maxval=None, dtype=None, nonce=None
2076 ):
2077 """Produce random number based on the uniform distribution.
2079 Args:
2080 shape: The shape of the random values to generate.
2081 minval: Float, defaults to 0. Lower bound of the range of
2082 random values to generate (inclusive).
2083 maxval: Float, defaults to None. Upper bound of the range of
2084 random values to generate (exclusive).
2085 dtype: Optional dtype of the tensor. Only floating point types are
2086 supported. If not specified, `tf.keras.backend.floatx()` is used,
2087 which defaults to `float32` unless you configured it otherwise (via
2088 `tf.keras.backend.set_floatx(float_dtype)`).
2089 nonce: Optional integer scalar that will be folded into the seed in
2090 stateless mode.
2091 """
2092 self._maybe_init()
2093 dtype = dtype or floatx()
2094 if self._rng_type == self.RNG_STATEFUL:
2095 return self._generator.uniform(
2096 shape=shape, minval=minval, maxval=maxval, dtype=dtype
2097 )
2098 elif self._rng_type == self.RNG_STATELESS:
2099 seed = self.make_seed_for_stateless_op()
2100 if nonce:
2101 seed = tf.random.experimental.stateless_fold_in(seed, nonce)
2102 return tf.random.stateless_uniform(
2103 shape=shape,
2104 minval=minval,
2105 maxval=maxval,
2106 dtype=dtype,
2107 seed=seed,
2108 )
2109 return tf.random.uniform(
2110 shape=shape,
2111 minval=minval,
2112 maxval=maxval,
2113 dtype=dtype,
2114 seed=self.make_legacy_seed(),
2115 )
2117 def truncated_normal(
2118 self, shape, mean=0.0, stddev=1.0, dtype=None, nonce=None
2119 ):
2120 """Produce random number based on the truncated normal distribution.
2122 Args:
2123 shape: The shape of the random values to generate.
2124 mean: Float, defaults to 0. Mean of the random values to generate.
2125 stddev: Float, defaults to 1. Standard deviation of the random values
2126 to generate.
2127 dtype: Optional dtype of the tensor. Only floating point types are
2128 supported. If not specified, `tf.keras.backend.floatx()` is used,
2129 which defaults to `float32` unless you configured it otherwise (via
2130 `tf.keras.backend.set_floatx(float_dtype)`).
2131 nonce: Optional integer scalar that will be folded into the seed in
2132 stateless mode.
2133 """
2134 self._maybe_init()
2135 dtype = dtype or floatx()
2136 if self._rng_type == self.RNG_STATEFUL:
2137 return self._generator.truncated_normal(
2138 shape=shape, mean=mean, stddev=stddev, dtype=dtype
2139 )
2140 elif self._rng_type == self.RNG_STATELESS:
2141 seed = self.make_seed_for_stateless_op()
2142 if nonce:
2143 seed = tf.random.experimental.stateless_fold_in(seed, nonce)
2144 return tf.random.stateless_truncated_normal(
2145 shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
2146 )
2147 return tf.random.truncated_normal(
2148 shape=shape,
2149 mean=mean,
2150 stddev=stddev,
2151 dtype=dtype,
2152 seed=self.make_legacy_seed(),
2153 )
2155 def dropout(self, inputs, rate, noise_shape=None):
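"""Apply dropout to `inputs`, using the RNG path selected for this object.
Dispatches to `tf.nn.experimental.general_dropout` when backed by a
`tf.random.Generator`, to `tf.nn.experimental.stateless_dropout` in
stateless mode, and to the legacy `tf.nn.dropout` otherwise.
"""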
2156 self._maybe_init()
2157 if self._rng_type == self.RNG_STATEFUL:
2158 return tf.nn.experimental.general_dropout(
2159 inputs,
2160 rate=rate,
2161 noise_shape=noise_shape,
2162 uniform_sampler=self._generator.uniform,
2163 )
2164 elif self._rng_type == self.RNG_STATELESS:
2165 return tf.nn.experimental.stateless_dropout(
2166 inputs,
2167 rate=rate,
2168 noise_shape=noise_shape,
2169 seed=self.make_seed_for_stateless_op(),
2170 )
2171 else:
2172 return tf.nn.dropout(
2173 inputs,
2174 rate=rate,
2175 noise_shape=noise_shape,
2176 seed=self.make_legacy_seed(),
2177 )
2180@keras_export("keras.backend.random_uniform_variable")
2181@doc_controls.do_not_generate_docs
2182def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
2183 """Instantiates a variable with values drawn from a uniform distribution.
2185 Args:
2186 shape: Tuple of integers, shape of returned Keras variable.
2187 low: Float, lower boundary of the output interval.
2188 high: Float, upper boundary of the output interval.
2189 dtype: String, dtype of returned Keras variable.
2190 name: String, name of returned Keras variable.
2191 seed: Integer, random seed.
2193 Returns:
2194 A Keras variable, filled with drawn samples.
2196 Example:
2198 >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3),
2199 ... low=0.0, high=1.0)
2200 >>> kvar
2201 <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
2202 dtype=float32)>
2203 """
2204 if dtype is None:
2205 dtype = floatx()
2206 tf_dtype = tf.as_dtype(dtype)
2207 if seed is None:
2208 # ensure that randomness is conditioned by the Numpy RNG
2209 seed = np.random.randint(10e8)
2210 value = tf.compat.v1.random_uniform_initializer(
2211 low, high, dtype=tf_dtype, seed=seed
2212 )(shape)
2213 return variable(value, dtype=dtype, name=name)
2216@keras_export("keras.backend.random_normal_variable")
2217@doc_controls.do_not_generate_docs
2218def random_normal_variable(
2219 shape, mean, scale, dtype=None, name=None, seed=None
2220):
2221 """Instantiates a variable with values drawn from a normal distribution.
2223 Args:
2224 shape: Tuple of integers, shape of returned Keras variable.
2225 mean: Float, mean of the normal distribution.
2226 scale: Float, standard deviation of the normal distribution.
2227 dtype: String, dtype of returned Keras variable.
2228 name: String, name of returned Keras variable.
2229 seed: Integer, random seed.
2231 Returns:
2232 A Keras variable, filled with drawn samples.
2234 Example:
2236 >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3),
2237 ... mean=0.0, scale=1.0)
2238 >>> kvar
2239 <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
2240 dtype=float32)>
2241 """
2242 if dtype is None:
2243 dtype = floatx()
2244 tf_dtype = tf.as_dtype(dtype)
2245 if seed is None:
2246 # ensure that randomness is conditioned by the Numpy RNG
2247 seed = np.random.randint(10e8)
2248 value = tf.compat.v1.random_normal_initializer(
2249 mean, scale, dtype=tf_dtype, seed=seed
2250 )(shape)
2251 return variable(value, dtype=dtype, name=name)
2254@keras_export("keras.backend.count_params")
2255@doc_controls.do_not_generate_docs
2256def count_params(x):
2257 """Returns the static number of elements in a variable or tensor.
2259 Args:
2260 x: Variable or tensor.
2262 Returns:
2263 Integer, the number of scalars in `x`.
2265 Example:
2267 >>> kvar = tf.keras.backend.zeros((2,3))
2268 >>> tf.keras.backend.count_params(kvar)
2269 6
2270 >>> tf.keras.backend.eval(kvar)
2271 array([[0., 0., 0.],
2272 [0., 0., 0.]], dtype=float32)
2274 """
2275 return np.prod(x.shape.as_list())
2278@keras_export("keras.backend.cast")
2279@tf.__internal__.dispatch.add_dispatch_support
2280@doc_controls.do_not_generate_docs
2281def cast(x, dtype):
2282 """Casts a tensor to a different dtype and returns it.
2284 You can cast a Keras variable but it still returns a Keras tensor.
2286 Args:
2287 x: Keras tensor (or variable).
2288 dtype: String, either (`'float16'`, `'float32'`, or `'float64'`).
2290 Returns:
2291 Keras tensor with dtype `dtype`.
2293 Examples:
2294 Cast a float32 variable to a float64 tensor
2296 >>> input = tf.keras.backend.ones(shape=(1,3))
2297 >>> print(input)
2298 <tf.Variable 'Variable:0' shape=(1, 3) dtype=float32,
2299 numpy=array([[1., 1., 1.]], dtype=float32)>
2300 >>> cast_input = tf.keras.backend.cast(input, dtype='float64')
2301 >>> print(cast_input)
2302 tf.Tensor([[1. 1. 1.]], shape=(1, 3), dtype=float64)
2304 """
2305 return tf.cast(x, dtype)
2308 # UPDATE OPS
2311@keras_export("keras.backend.update")
2312@doc_controls.do_not_generate_docs
2313def update(x, new_x):
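"""Update the value of `x` to `new_x` via assignment.
Args:
x: A Variable.
new_x: A tensor of the same shape as `x`.
Returns:
The variable `x` updated.
"""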
2314 return tf.compat.v1.assign(x, new_x)
2317@keras_export("keras.backend.update_add")
2318@doc_controls.do_not_generate_docs
2319def update_add(x, increment):
2320 """Update the value of `x` by adding `increment`.
2322 Args:
2323 x: A Variable.
2324 increment: A tensor of same shape as `x`.
2326 Returns:
2327 The variable `x` updated.
2328 """
2329 return tf.compat.v1.assign_add(x, increment)
2332@keras_export("keras.backend.update_sub")
2333@doc_controls.do_not_generate_docs
2334def update_sub(x, decrement):
2335 """Update the value of `x` by subtracting `decrement`.
2337 Args:
2338 x: A Variable.
2339 decrement: A tensor of same shape as `x`.
2341 Returns:
2342 The variable `x` updated.
2343 """
2344 return tf.compat.v1.assign_sub(x, decrement)
2347@keras_export("keras.backend.moving_average_update")
2348@doc_controls.do_not_generate_docs
2349def moving_average_update(x, value, momentum):
2350 """Compute the exponential moving average of a value.
2352 The moving average 'x' is updated with 'value' following:
2354 ```
2355 x = x * momentum + value * (1 - momentum)
2356 ```
2358 For example:
2360 >>> x = tf.Variable(0.0)
2361 >>> momentum=0.9
2362 >>> moving_average_update(x, value = 2.0, momentum=momentum).numpy()
2363 >>> x.numpy()
2364 0.2
2366 The result will be biased towards the initial value of the variable.
2368 If the variable was initialized to zero, you can divide by
2369 `1 - momentum ** num_updates` to debias it (Section 3 of
2370 [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)):
2372 >>> num_updates = 1.0
2373 >>> x_zdb = x/(1 - momentum**num_updates)
2374 >>> x_zdb.numpy()
2375 2.0
2377 Args:
2378 x: A Variable, the moving average.
2379 value: A tensor with the same shape as `x`, the new value to be
2380 averaged in.
2381 momentum: The moving average momentum.
2383 Returns:
2384 The updated variable.
2385 """
2386 if tf.__internal__.tf2.enabled():
2387 momentum = tf.cast(momentum, x.dtype)
2388 value = tf.cast(value, x.dtype)
2389 return x.assign_sub((x - value) * (1 - momentum))
2390 else:
2391 return tf.__internal__.train.assign_moving_average(
2392 x, value, momentum, zero_debias=True
2393 )
2396# LINEAR ALGEBRA
2399@keras_export("keras.backend.dot")
2400@tf.__internal__.dispatch.add_dispatch_support
2401@doc_controls.do_not_generate_docs
2402def dot(x, y):
2403 """Multiplies 2 tensors (and/or variables) and returns a tensor.
2405 This operation corresponds to `numpy.dot(a, b, out=None)`.
2407 Args:
2408 x: Tensor or variable.
2409 y: Tensor or variable.
2411 Returns:
2412 A tensor, dot product of `x` and `y`.
2414 Examples:
2416 If inputs `x` and `y` are 2-D arrays, then it is equivalent to `tf.matmul`.
2417 >>> x = tf.keras.backend.placeholder(shape=(2, 3))
2418 >>> y = tf.keras.backend.placeholder(shape=(3, 4))
2419 >>> xy = tf.keras.backend.dot(x, y)
2420 >>> xy
2421 <KerasTensor: shape=(2, 4) dtype=float32 ...>
2423 >>> x = tf.keras.backend.placeholder(shape=(32, 28, 3))
2424 >>> y = tf.keras.backend.placeholder(shape=(3, 4))
2425 >>> xy = tf.keras.backend.dot(x, y)
2426 >>> xy
2427 <KerasTensor: shape=(32, 28, 4) dtype=float32 ...>
2429 If `x` is an N-D array and `y` is an M-D array (where M>=2), it is a sum
2430 product over the last axis of `x` and the second-to-last axis of `y`.
2431 >>> x = tf.keras.backend.random_uniform_variable(
2432 ... shape=(2, 3), low=0., high=1.)
2433 >>> y = tf.keras.backend.ones((4, 3, 5))
2434 >>> xy = tf.keras.backend.dot(x, y)
2435 >>> tf.keras.backend.int_shape(xy)
2436 (2, 4, 5)
2437 """
2438 if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
2439 x_shape = []
2440 for i, s in zip(int_shape(x), tf.unstack(tf.shape(x))):
2441 if i is not None:
2442 x_shape.append(i)
2443 else:
2444 x_shape.append(s)
2445 x_shape = tuple(x_shape)
2446 y_shape = []
2447 for i, s in zip(int_shape(y), tf.unstack(tf.shape(y))):
2448 if i is not None:
2449 y_shape.append(i)
2450 else:
2451 y_shape.append(s)
2452 y_shape = tuple(y_shape)
2453 y_permute_dim = list(range(ndim(y)))
2454 y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
2455 xt = tf.reshape(x, [-1, x_shape[-1]])
2456 yt = tf.reshape(
2457 tf.compat.v1.transpose(y, perm=y_permute_dim), [y_shape[-2], -1]
2458 )
2459 return tf.reshape(
2460 tf.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:]
2461 )
2462 if is_sparse(x):
2463 out = tf.sparse.sparse_dense_matmul(x, y)
2464 else:
2465 out = tf.matmul(x, y)
2466 return out
2469@keras_export("keras.backend.batch_dot")
2470@tf.__internal__.dispatch.add_dispatch_support
2471@doc_controls.do_not_generate_docs
2472def batch_dot(x, y, axes=None):
2473 """Batchwise dot product.
2475 `batch_dot` is used to compute dot product of `x` and `y` when
2476 `x` and `y` are data in batch, i.e. in a shape of
2477 `(batch_size, :)`.
2478 `batch_dot` results in a tensor or variable with fewer dimensions
2479 than the input. If the number of dimensions is reduced to 1,
2480 we use `expand_dims` to make sure that ndim is at least 2.
2482 Args:
2483 x: Keras tensor or variable with `ndim >= 2`.
2484 y: Keras tensor or variable with `ndim >= 2`.
2485 axes: Tuple or list of integers with target dimensions, or single integer.
2486 The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal.
2488 Returns:
2489 A tensor with shape equal to the concatenation of `x`'s shape
2490 (less the dimension that was summed over) and `y`'s shape
2491 (less the batch dimension and the dimension that was summed over).
2492 If the final rank is 1, we reshape it to `(batch_size, 1)`.
2494 Examples:
2496 >>> x_batch = tf.keras.backend.ones(shape=(32, 20, 1))
2497 >>> y_batch = tf.keras.backend.ones(shape=(32, 30, 20))
2498 >>> xy_batch_dot = tf.keras.backend.batch_dot(x_batch, y_batch, axes=(1, 2))
2499 >>> tf.keras.backend.int_shape(xy_batch_dot)
2500 (32, 1, 30)
2502 Shape inference:
2503 Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
2504 If `axes` is (1, 2), to find the output shape of resultant tensor,
2505 loop through each dimension in `x`'s shape and `y`'s shape:
2506 * `x.shape[0]` : 100 : append to output shape
2507 * `x.shape[1]` : 20 : do not append to output shape,
2508 dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
2509 * `y.shape[0]` : 100 : do not append to output shape,
2510 always ignore first dimension of `y`
2511 * `y.shape[1]` : 30 : append to output shape
2512 * `y.shape[2]` : 20 : do not append to output shape,
2513 dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
2514 `output_shape` = `(100, 30)`
2515 """
2516 x_shape = int_shape(x)
2517 y_shape = int_shape(y)
2519 x_ndim = len(x_shape)
2520 y_ndim = len(y_shape)
2522 if x_ndim < 2 or y_ndim < 2:
2523 raise ValueError(
2524 "Cannot do batch_dot on inputs "
2525 "with rank < 2. "
2526 "Received inputs with shapes "
2527 + str(x_shape)
2528 + " and "
2529 + str(y_shape)
2530 + "."
2531 )
2533 x_batch_size = x_shape[0]
2534 y_batch_size = y_shape[0]
2536 if x_batch_size is not None and y_batch_size is not None:
2537 if x_batch_size != y_batch_size:
2538 raise ValueError(
2539 "Cannot do batch_dot on inputs "
2540 "with different batch sizes. "
2541 "Received inputs with shapes "
2542 + str(x_shape)
2543 + " and "
2544 + str(y_shape)
2545 + "."
2546 )
2547 if isinstance(axes, int):
2548 axes = [axes, axes]
2550 if axes is None:
2551 if y_ndim == 2:
2552 axes = [x_ndim - 1, y_ndim - 1]
2553 else:
2554 axes = [x_ndim - 1, y_ndim - 2]
2556 if py_any(isinstance(a, (list, tuple)) for a in axes):
2557 raise ValueError(
2558 "Multiple target dimensions are not supported. "
2559 + "Expected: None, int, (int, int), "
2560 + "Provided: "
2561 + str(axes)
2562 )
2564 # if tuple, convert to list.
2565 axes = list(axes)
2567 # convert negative indices.
2568 if axes[0] < 0:
2569 axes[0] += x_ndim
2570 if axes[1] < 0:
2571 axes[1] += y_ndim
2573 # sanity checks
2574 if 0 in axes:
2575 raise ValueError(
2576 "Cannot perform batch_dot over axis 0. "
2577 "If your inputs are not batched, "
2578 "add a dummy batch dimension to your "
2579 "inputs using K.expand_dims(x, 0)"
2580 )
2581 a0, a1 = axes
2582 d1 = x_shape[a0]
2583 d2 = y_shape[a1]
2585 if d1 is not None and d2 is not None and d1 != d2:
2586 raise ValueError(
2587 "Cannot do batch_dot on inputs with shapes "
2588 + str(x_shape)
2589 + " and "
2590 + str(y_shape)
2591 + " with axes="
2592 + str(axes)
2593 + ". x.shape[%d] != y.shape[%d] (%d != %d)."
2594 % (axes[0], axes[1], d1, d2)
2595 )
2597 # backup ndims. Need them later.
2598 orig_x_ndim = x_ndim
2599 orig_y_ndim = y_ndim
2601 # if rank is 2, expand to 3.
2602 if x_ndim == 2:
2603 x = tf.expand_dims(x, 1)
2604 a0 += 1
2605 x_ndim += 1
2606 if y_ndim == 2:
2607 y = tf.expand_dims(y, 2)
2608 y_ndim += 1
2610 # bring x's dimension to be reduced to last axis.
2611 if a0 != x_ndim - 1:
2612 pattern = list(range(x_ndim))
2613 for i in range(a0, x_ndim - 1):
2614 pattern[i] = pattern[i + 1]
2615 pattern[-1] = a0
2616 x = tf.compat.v1.transpose(x, pattern)
2618 # bring y's dimension to be reduced to axis 1.
2619 if a1 != 1:
2620 pattern = list(range(y_ndim))
2621 for i in range(a1, 1, -1):
2622 pattern[i] = pattern[i - 1]
2623 pattern[1] = a1
2624 y = tf.compat.v1.transpose(y, pattern)
2626 # normalize both inputs to rank 3.
2627 if x_ndim > 3:
2628 # squash middle dimensions of x.
2629 x_shape = shape(x)
2630 x_mid_dims = x_shape[1:-1]
2631 x_squashed_shape = tf.stack([x_shape[0], -1, x_shape[-1]])
2632 x = tf.reshape(x, x_squashed_shape)
2633 x_squashed = True
2634 else:
2635 x_squashed = False
2637 if y_ndim > 3:
2638 # squash trailing dimensions of y.
2639 y_shape = shape(y)
2640 y_trail_dims = y_shape[2:]
2641 y_squashed_shape = tf.stack([y_shape[0], y_shape[1], -1])
2642 y = tf.reshape(y, y_squashed_shape)
2643 y_squashed = True
2644 else:
2645 y_squashed = False
2647 result = tf.matmul(x, y)
2649 # if inputs were squashed, we have to reshape the matmul output.
2650 output_shape = tf.shape(result)
2651 do_reshape = False
2653 if x_squashed:
2654 output_shape = tf.concat(
2655 [output_shape[:1], x_mid_dims, output_shape[-1:]], 0
2656 )
2657 do_reshape = True
2659 if y_squashed:
2660 output_shape = tf.concat([output_shape[:-1], y_trail_dims], 0)
2661 do_reshape = True
2663 if do_reshape:
2664 result = tf.reshape(result, output_shape)
2666 # if the inputs were originally rank 2, we remove the added 1 dim.
2667 if orig_x_ndim == 2:
2668 result = tf.squeeze(result, 1)
2669 elif orig_y_ndim == 2:
2670 result = tf.squeeze(result, -1)
2672 return result
2675@keras_export("keras.backend.transpose")
2676@tf.__internal__.dispatch.add_dispatch_support
2677@doc_controls.do_not_generate_docs
2678def transpose(x):
2679 """Transposes a tensor and returns it.
2681 Args:
2682 x: Tensor or variable.
2684 Returns:
2685 A tensor.
2687 Examples:
2689 >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2690 >>> tf.keras.backend.eval(var)
2691 array([[1., 2., 3.],
2692 [4., 5., 6.]], dtype=float32)
2693 >>> var_transposed = tf.keras.backend.transpose(var)
2694 >>> tf.keras.backend.eval(var_transposed)
2695 array([[1., 4.],
2696 [2., 5.],
2697 [3., 6.]], dtype=float32)
2698 >>> input = tf.keras.backend.placeholder((2, 3))
2699 >>> input
2700 <KerasTensor: shape=(2, 3) dtype=float32 ...>
2701 >>> input_transposed = tf.keras.backend.transpose(input)
2702 >>> input_transposed
2703 <KerasTensor: shape=(3, 2) dtype=float32 ...>
2704 """
2705 return tf.compat.v1.transpose(x)
2708@keras_export("keras.backend.gather")
2709@tf.__internal__.dispatch.add_dispatch_support
2710@doc_controls.do_not_generate_docs
2711def gather(reference, indices):
2712 """Retrieves the elements of indices `indices` in the tensor `reference`.
2714 Args:
2715 reference: A tensor.
2716 indices: An integer tensor of indices.
2718 Returns:
2719 A tensor of same type as `reference`.
2721 Examples:
2723 >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2724 >>> tf.keras.backend.eval(var)
2725 array([[1., 2., 3.],
2726 [4., 5., 6.]], dtype=float32)
2727 >>> var_gathered = tf.keras.backend.gather(var, [0])
2728 >>> tf.keras.backend.eval(var_gathered)
2729 array([[1., 2., 3.]], dtype=float32)
2730 >>> var_gathered = tf.keras.backend.gather(var, [1])
2731 >>> tf.keras.backend.eval(var_gathered)
2732 array([[4., 5., 6.]], dtype=float32)
2733 >>> var_gathered = tf.keras.backend.gather(var, [0,1,0])
2734 >>> tf.keras.backend.eval(var_gathered)
2735 array([[1., 2., 3.],
2736 [4., 5., 6.],
2737 [1., 2., 3.]], dtype=float32)
2738 """
2739 return tf.compat.v1.gather(reference, indices)
2742# ELEMENT-WISE OPERATIONS
2745@keras_export("keras.backend.max")
2746@tf.__internal__.dispatch.add_dispatch_support
2747@doc_controls.do_not_generate_docs
2748def max(x, axis=None, keepdims=False):
2749 """Maximum value in a tensor.
2751 Args:
2752 x: A tensor or variable.
2753 axis: An integer, the axis to find maximum values.
2754 keepdims: A boolean, whether to keep the dimensions or not.
2755 If `keepdims` is `False`, the rank of the tensor is reduced
2756 by 1. If `keepdims` is `True`,
2757 the reduced dimension is retained with length 1.
2759 Returns:
2760 A tensor with maximum values of `x`.
2761 """
2762 return tf.reduce_max(x, axis, keepdims)
2765@keras_export("keras.backend.min")
2766@tf.__internal__.dispatch.add_dispatch_support
2767@doc_controls.do_not_generate_docs
2768def min(x, axis=None, keepdims=False):
2769 """Minimum value in a tensor.
2771 Args:
2772 x: A tensor or variable.
2773 axis: An integer, the axis to find minimum values.
2774 keepdims: A boolean, whether to keep the dimensions or not.
2775 If `keepdims` is `False`, the rank of the tensor is reduced
2776 by 1. If `keepdims` is `True`,
2777 the reduced dimension is retained with length 1.
2779 Returns:
2780 A tensor with minimum values of `x`.
2781 """
2782 return tf.reduce_min(x, axis, keepdims)
2785@keras_export("keras.backend.sum")
2786@tf.__internal__.dispatch.add_dispatch_support
2787@doc_controls.do_not_generate_docs
2788def sum(x, axis=None, keepdims=False):
2789 """Sum of the values in a tensor, alongside the specified axis.
2791 Args:
2792 x: A tensor or variable.
2793 axis: An integer, the axis to sum over.
2794 keepdims: A boolean, whether to keep the dimensions or not.
2795 If `keepdims` is `False`, the rank of the tensor is reduced
2796 by 1. If `keepdims` is `True`,
2797 the reduced dimension is retained with length 1.
2799 Returns:
2800 A tensor with sum of `x`.
2801 """
2802 return tf.reduce_sum(x, axis, keepdims)
2805@keras_export("keras.backend.prod")
2806@tf.__internal__.dispatch.add_dispatch_support
2807@doc_controls.do_not_generate_docs
2808def prod(x, axis=None, keepdims=False):
2809 """Multiplies the values in a tensor, alongside the specified axis.
2811 Args:
2812 x: A tensor or variable.
2813 axis: An integer, the axis to compute the product.
2814 keepdims: A boolean, whether to keep the dimensions or not.
2815 If `keepdims` is `False`, the rank of the tensor is reduced
2816 by 1. If `keepdims` is `True`,
2817 the reduced dimension is retained with length 1.
2819 Returns:
2820 A tensor with the product of elements of `x`.
2821 """
2822 return tf.reduce_prod(x, axis, keepdims)
2825@keras_export("keras.backend.cumsum")
2826@tf.__internal__.dispatch.add_dispatch_support
2827@doc_controls.do_not_generate_docs
2828def cumsum(x, axis=0):
2829 """Cumulative sum of the values in a tensor, alongside the specified axis.
2831 Args:
2832 x: A tensor or variable.
2833 axis: An integer, the axis to compute the sum.
2835 Returns:
2836 A tensor of the cumulative sum of values of `x` along `axis`.
2837 """
2838 return tf.cumsum(x, axis=axis)
2841@keras_export("keras.backend.cumprod")
2842@tf.__internal__.dispatch.add_dispatch_support
2843@doc_controls.do_not_generate_docs
2844def cumprod(x, axis=0):
2845 """Cumulative product of the values in a tensor alongside `axis`.
2847 Args:
2848 x: A tensor or variable.
2849 axis: An integer, the axis to compute the product.
2851 Returns:
2852 A tensor of the cumulative product of values of `x` along `axis`.
2853 """
2854 return tf.math.cumprod(x, axis=axis)
2857@keras_export("keras.backend.var")
2858@doc_controls.do_not_generate_docs
2859def var(x, axis=None, keepdims=False):
2860 """Variance of a tensor, alongside the specified axis.
2862 Args:
2863 x: A tensor or variable.
2864 axis: An integer, the axis to compute the variance.
2865 keepdims: A boolean, whether to keep the dimensions or not.
2866 If `keepdims` is `False`, the rank of the tensor is reduced
2867 by 1. If `keepdims` is `True`,
2868 the reduced dimension is retained with length 1.
2870 Returns:
2871 A tensor with the variance of elements of `x`.
2872 """
2873 if x.dtype.base_dtype == tf.bool:
2874 x = tf.cast(x, floatx())
2875 return tf.math.reduce_variance(x, axis=axis, keepdims=keepdims)
2878@keras_export("keras.backend.std")
2879@tf.__internal__.dispatch.add_dispatch_support
2880@doc_controls.do_not_generate_docs
2881def std(x, axis=None, keepdims=False):
2882 """Standard deviation of a tensor, alongside the specified axis.
2884 It is an alias to `tf.math.reduce_std`.
2886 Args:
2887 x: A tensor or variable. It should have numerical dtypes. Boolean type
2888 inputs will be converted to float.
2889 axis: An integer, the axis to compute the standard deviation. If `None`
2890 (the default), reduces all dimensions. Must be in the range
2891 `[-rank(x), rank(x))`.
2892 keepdims: A boolean, whether to keep the dimensions or not.
2893 If `keepdims` is `False`, the rank of the tensor is reduced
2894 by 1. If `keepdims` is `True`, the reduced dimension is retained
2895 with length 1.
2897 Returns:
2898 A tensor with the standard deviation of elements of `x` with same dtype.
2899 Boolean type input will be converted to float.
2900 """
2901 if x.dtype.base_dtype == tf.bool:
2902 x = tf.cast(x, floatx())
2903 return tf.math.reduce_std(x, axis=axis, keepdims=keepdims)
2906@keras_export("keras.backend.mean")
2907@tf.__internal__.dispatch.add_dispatch_support
2908@doc_controls.do_not_generate_docs
2909def mean(x, axis=None, keepdims=False):
2910 """Mean of a tensor, alongside the specified axis.
2912 Args:
2913 x: A tensor or variable.
2914 axis: A list of integers. Axes along which to compute the mean.
2915 keepdims: A boolean, whether to keep the dimensions or not.
2916 If `keepdims` is `False`, the rank of the tensor is reduced
2917 by 1 for each entry in `axis`. If `keepdims` is `True`,
2918 the reduced dimensions are retained with length 1.
2920 Returns:
2921 A tensor with the mean of elements of `x`.
2922 """
2923 if x.dtype.base_dtype == tf.bool:
2924 x = tf.cast(x, floatx())
2925 return tf.reduce_mean(x, axis, keepdims)
2928@keras_export("keras.backend.any")
2929@tf.__internal__.dispatch.add_dispatch_support
2930@doc_controls.do_not_generate_docs
2931def any(x, axis=None, keepdims=False):
2932 """Bitwise reduction (logical OR).
2934 Args:
2935 x: Tensor or variable.
2936 axis: axis along which to perform the reduction.
2937 keepdims: whether to drop or broadcast the reduction axes.
2939 Returns:
2940 A bool tensor.
2941 """
2942 x = tf.cast(x, tf.bool)
2943 return tf.reduce_any(x, axis, keepdims)
2946@keras_export("keras.backend.all")
2947@tf.__internal__.dispatch.add_dispatch_support
2948@doc_controls.do_not_generate_docs
2949def all(x, axis=None, keepdims=False):
2950 """Bitwise reduction (logical AND).
2952 Args:
2953 x: Tensor or variable.
2954 axis: axis along which to perform the reduction.
2955 keepdims: whether to drop or broadcast the reduction axes.
2957 Returns:
2958 A bool tensor.
2959 """
2960 x = tf.cast(x, tf.bool)
2961 return tf.reduce_all(x, axis, keepdims)
2964@keras_export("keras.backend.argmax")
2965@tf.__internal__.dispatch.add_dispatch_support
2966@doc_controls.do_not_generate_docs
2967def argmax(x, axis=-1):
2968 """Returns the index of the maximum value along an axis.
2970 Args:
2971 x: Tensor or variable.
2972 axis: axis along which to perform the reduction.
2974 Returns:
2975 A tensor.
2976 """
2977 return tf.argmax(x, axis)
2980@keras_export("keras.backend.argmin")
2981@tf.__internal__.dispatch.add_dispatch_support
2982@doc_controls.do_not_generate_docs
2983def argmin(x, axis=-1):
2984 """Returns the index of the minimum value along an axis.
2986 Args:
2987 x: Tensor or variable.
2988 axis: axis along which to perform the reduction.
2990 Returns:
2991 A tensor.
2992 """
2993 return tf.argmin(x, axis)
2996@keras_export("keras.backend.square")
2997@tf.__internal__.dispatch.add_dispatch_support
2998@doc_controls.do_not_generate_docs
2999def square(x):
3000 """Element-wise square.
3002 Args:
3003 x: Tensor or variable.
3005 Returns:
3006 A tensor.
3007 """
3008 return tf.square(x)
3011@keras_export("keras.backend.abs")
3012@tf.__internal__.dispatch.add_dispatch_support
3013@doc_controls.do_not_generate_docs
3014def abs(x):
3015 """Element-wise absolute value.
3017 Args:
3018 x: Tensor or variable.
3020 Returns:
3021 A tensor.
3022 """
3023 return tf.abs(x)
3026@keras_export("keras.backend.sqrt")
3027@tf.__internal__.dispatch.add_dispatch_support
3028@doc_controls.do_not_generate_docs
3029def sqrt(x):
3030 """Element-wise square root.
3032 This function clips negative tensor values to 0 before computing the
3033 square root.
3035 Args:
3036 x: Tensor or variable.
3038 Returns:
3039 A tensor.
3040 """
3041 zero = _constant_to_tensor(0.0, x.dtype.base_dtype)
3042 x = tf.maximum(x, zero)
3043 return tf.sqrt(x)
3046@keras_export("keras.backend.exp")
3047@tf.__internal__.dispatch.add_dispatch_support
3048@doc_controls.do_not_generate_docs
3049def exp(x):
3050 """Element-wise exponential.
3052 Args:
3053 x: Tensor or variable.
3055 Returns:
3056 A tensor.
3057 """
3058 return tf.exp(x)
3061@keras_export("keras.backend.log")
3062@tf.__internal__.dispatch.add_dispatch_support
3063@doc_controls.do_not_generate_docs
3064def log(x):
3065 """Element-wise log.
3067 Args:
3068 x: Tensor or variable.
3070 Returns:
3071 A tensor.
3072 """
3073 return tf.math.log(x)
3076def logsumexp(x, axis=None, keepdims=False):
3077 """Computes log(sum(exp(elements across dimensions of a tensor))).
3079 This function is more numerically stable than log(sum(exp(x))).
3080 It avoids overflows caused by taking the exp of large inputs and
3081 underflows caused by taking the log of small inputs.
3083 Args:
3084 x: A tensor or variable.
3085 axis: An integer, the axis to reduce over.
3086 keepdims: A boolean, whether to keep the dimensions or not.
3087 If `keepdims` is `False`, the rank of the tensor is reduced
3088 by 1. If `keepdims` is `True`, the reduced dimension is
3089 retained with length 1.
3091 Returns:
3092 The reduced tensor.
3093 """
3094 return tf.reduce_logsumexp(x, axis, keepdims)
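# A small worked example (sketch): for x = tf.math.log([1., 2., 3.]),
# logsumexp(x) == log(1 + 2 + 3) == log(6), approximately 1.7918, and the
# reduction avoids materializing potentially overflowing exp(x) terms.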
3097@keras_export("keras.backend.round")
3098@tf.__internal__.dispatch.add_dispatch_support
3099@doc_controls.do_not_generate_docs
3100def round(x):
3101 """Element-wise rounding to the closest integer.
3103 In case of tie, the rounding mode used is "half to even".
3105 Args:
3106 x: Tensor or variable.
3108 Returns:
3109 A tensor.
3110 """
3111 return tf.round(x)
3114@keras_export("keras.backend.sign")
3115@tf.__internal__.dispatch.add_dispatch_support
3116@doc_controls.do_not_generate_docs
3117def sign(x):
3118 """Element-wise sign.
3120 Args:
3121 x: Tensor or variable.
3123 Returns:
3124 A tensor.
3125 """
3126 return tf.sign(x)
3129@keras_export("keras.backend.pow")
3130@tf.__internal__.dispatch.add_dispatch_support
3131@doc_controls.do_not_generate_docs
3132def pow(x, a):
3133 """Element-wise exponentiation.
3135 Args:
3136 x: Tensor or variable.
3137 a: Python integer.
3139 Returns:
3140 A tensor.
3141 """
3142 return tf.pow(x, a)
3145@keras_export("keras.backend.clip")
3146@tf.__internal__.dispatch.add_dispatch_support
3147@doc_controls.do_not_generate_docs
3148def clip(x, min_value, max_value):
3149 """Element-wise value clipping.
3151 Args:
3152 x: Tensor or variable.
3153 min_value: Python float, integer, or tensor.
3154 max_value: Python float, integer, or tensor.
3156 Returns:
3157 A tensor.
3158 """
3159 if isinstance(min_value, (int, float)) and isinstance(
3160 max_value, (int, float)
3161 ):
3162 if max_value < min_value:
3163 max_value = min_value
3164 if min_value is None:
3165 min_value = -np.inf
3166 if max_value is None:
3167 max_value = np.inf
3168 return tf.clip_by_value(x, min_value, max_value)
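# For example (sketch): clip([-1.0, 0.5, 2.0], 0.0, 1.0) evaluates to
# [0.0, 0.5, 1.0]; tensor-valued bounds are passed straight through to
# `tf.clip_by_value`.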
3171@keras_export("keras.backend.equal")
3172@tf.__internal__.dispatch.add_dispatch_support
3173@doc_controls.do_not_generate_docs
3174def equal(x, y):
3175 """Element-wise equality between two tensors.
3177 Args:
3178 x: Tensor or variable.
3179 y: Tensor or variable.
3181 Returns:
3182 A bool tensor.
3183 """
3184 return tf.equal(x, y)
3187@keras_export("keras.backend.not_equal")
3188@tf.__internal__.dispatch.add_dispatch_support
3189@doc_controls.do_not_generate_docs
3190def not_equal(x, y):
3191 """Element-wise inequality between two tensors.
3193 Args:
3194 x: Tensor or variable.
3195 y: Tensor or variable.
3197 Returns:
3198 A bool tensor.
3199 """
3200 return tf.not_equal(x, y)
3203@keras_export("keras.backend.greater")
3204@tf.__internal__.dispatch.add_dispatch_support
3205@doc_controls.do_not_generate_docs
3206def greater(x, y):
3207 """Element-wise truth value of (x > y).
3209 Args:
3210 x: Tensor or variable.
3211 y: Tensor or variable.
3213 Returns:
3214 A bool tensor.
3215 """
3216 return tf.greater(x, y)
3219@keras_export("keras.backend.greater_equal")
3220@tf.__internal__.dispatch.add_dispatch_support
3221@doc_controls.do_not_generate_docs
3222def greater_equal(x, y):
3223 """Element-wise truth value of (x >= y).
3225 Args:
3226 x: Tensor or variable.
3227 y: Tensor or variable.
3229 Returns:
3230 A bool tensor.
3231 """
3232 return tf.greater_equal(x, y)
3235@keras_export("keras.backend.less")
3236@tf.__internal__.dispatch.add_dispatch_support
3237@doc_controls.do_not_generate_docs
3238def less(x, y):
3239 """Element-wise truth value of (x < y).
3241 Args:
3242 x: Tensor or variable.
3243 y: Tensor or variable.
3245 Returns:
3246 A bool tensor.
3247 """
3248 return tf.less(x, y)
3251@keras_export("keras.backend.less_equal")
3252@tf.__internal__.dispatch.add_dispatch_support
3253@doc_controls.do_not_generate_docs
3254def less_equal(x, y):
3255 """Element-wise truth value of (x <= y).
3257 Args:
3258 x: Tensor or variable.
3259 y: Tensor or variable.
3261 Returns:
3262 A bool tensor.
3263 """
3264 return tf.less_equal(x, y)
3267@keras_export("keras.backend.maximum")
3268@tf.__internal__.dispatch.add_dispatch_support
3269@doc_controls.do_not_generate_docs
3270def maximum(x, y):
3271 """Element-wise maximum of two tensors.
3273 Args:
3274 x: Tensor or variable.
3275 y: Tensor or variable.
3277 Returns:
3278 A tensor with the element wise maximum value(s) of `x` and `y`.
3280 Examples:
3282 >>> x = tf.Variable([[1, 2], [3, 4]])
3283 >>> y = tf.Variable([[2, 1], [0, -1]])
3284 >>> m = tf.keras.backend.maximum(x, y)
3285 >>> m
3286 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3287 array([[2, 2],
3288 [3, 4]], dtype=int32)>
3289 """
3290 return tf.maximum(x, y)
3293@keras_export("keras.backend.minimum")
3294@tf.__internal__.dispatch.add_dispatch_support
3295@doc_controls.do_not_generate_docs
3296def minimum(x, y):
3297 """Element-wise minimum of two tensors.
3299 Args:
3300 x: Tensor or variable.
3301 y: Tensor or variable.
3303 Returns:
3304 A tensor.
3305 """
3306 return tf.minimum(x, y)
3309@keras_export("keras.backend.sin")
3310@tf.__internal__.dispatch.add_dispatch_support
3311@doc_controls.do_not_generate_docs
3312def sin(x):
3313 """Computes sin of x element-wise.
3315 Args:
3316 x: Tensor or variable.
3318 Returns:
3319 A tensor.
3320 """
3321 return tf.sin(x)
3324@keras_export("keras.backend.cos")
3325@tf.__internal__.dispatch.add_dispatch_support
3326@doc_controls.do_not_generate_docs
3327def cos(x):
3328 """Computes cos of x element-wise.
3330 Args:
3331 x: Tensor or variable.
3333 Returns:
3334 A tensor.
3335 """
3336 return tf.cos(x)
3339def _regular_normalize_batch_in_training(
3340 x, gamma, beta, reduction_axes, epsilon=1e-3
3341):
3342 """Non-fused version of `normalize_batch_in_training`.
3344 Args:
3345 x: Input tensor or variable.
3346 gamma: Tensor by which to scale the input.
3347 beta: Tensor with which to center the input.
3348 reduction_axes: iterable of integers,
3349 axes over which to normalize.
3350 epsilon: Fuzz factor.
3352 Returns:
3353 A tuple length of 3, `(normalized_tensor, mean, variance)`.
3354 """
3355 mean, var = tf.compat.v1.nn.moments(x, reduction_axes, None, None, False)
3356 normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
3357 return normed, mean, var
3360def _broadcast_normalize_batch_in_training(
3361 x, gamma, beta, reduction_axes, epsilon=1e-3
3362):
3363 """Non-fused, broadcast version of `normalize_batch_in_training`.
3365 Args:
3366 x: Input tensor or variable.
3367 gamma: Tensor by which to scale the input.
3368 beta: Tensor with which to center the input.
3369 reduction_axes: iterable of integers,
3370 axes over which to normalize.
3371 epsilon: Fuzz factor.
3373 Returns:
3374 A tuple length of 3, `(normalized_tensor, mean, variance)`.
3375 """
3376 mean, var = tf.compat.v1.nn.moments(x, reduction_axes, None, None, False)
3377 target_shape = []
3378 for axis in range(ndim(x)):
3379 if axis in reduction_axes:
3380 target_shape.append(1)
3381 else:
3382 target_shape.append(tf.shape(x)[axis])
3383 target_shape = tf.stack(target_shape)
3385 broadcast_mean = tf.reshape(mean, target_shape)
3386 broadcast_var = tf.reshape(var, target_shape)
3387 if gamma is None:
3388 broadcast_gamma = None
3389 else:
3390 broadcast_gamma = tf.reshape(gamma, target_shape)
3391 if beta is None:
3392 broadcast_beta = None
3393 else:
3394 broadcast_beta = tf.reshape(beta, target_shape)
3396 normed = tf.nn.batch_normalization(
3397 x,
3398 broadcast_mean,
3399 broadcast_var,
3400 broadcast_beta,
3401 broadcast_gamma,
3402 epsilon,
3403 )
3404 return normed, mean, var
3407def _fused_normalize_batch_in_training(
3408 x, gamma, beta, reduction_axes, epsilon=1e-3
3409):
3410 """Fused version of `normalize_batch_in_training`.
3412 Args:
3413 x: Input tensor or variable.
3414 gamma: Tensor by which to scale the input.
3415 beta: Tensor with which to center the input.
3416 reduction_axes: iterable of integers,
3417 axes over which to normalize.
3418 epsilon: Fuzz factor.
3420 Returns:
3421 A tuple length of 3, `(normalized_tensor, mean, variance)`.
3422 """
3423 if list(reduction_axes) == [0, 1, 2]:
3424 normalization_axis = 3
3425 tf_data_format = "NHWC"
3426 else:
3427 normalization_axis = 1
3428 tf_data_format = "NCHW"
3430 if gamma is None:
3431 gamma = tf.constant(
3432 1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]]
3433 )
3434 if beta is None:
3435 beta = tf.constant(
3436 0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]]
3437 )
3439 return tf.compat.v1.nn.fused_batch_norm(
3440 x, gamma, beta, epsilon=epsilon, data_format=tf_data_format
3441 )
3444@keras_export("keras.backend.normalize_batch_in_training")
3445@doc_controls.do_not_generate_docs
3446def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
3447 """Computes mean and std for batch then apply batch_normalization on batch.
3449 Args:
3450 x: Input tensor or variable.
3451 gamma: Tensor by which to scale the input.
3452 beta: Tensor with which to center the input.
3453 reduction_axes: iterable of integers,
3454 axes over which to normalize.
3455 epsilon: Fuzz factor.
3457 Returns:
3458 A tuple length of 3, `(normalized_tensor, mean, variance)`.
3459 """
3460 if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
3461 if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
3462 return _broadcast_normalize_batch_in_training(
3463 x, gamma, beta, reduction_axes, epsilon=epsilon
3464 )
3465 return _fused_normalize_batch_in_training(
3466 x, gamma, beta, reduction_axes, epsilon=epsilon
3467 )
3468 else:
3469 if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
3470 return _regular_normalize_batch_in_training(
3471 x, gamma, beta, reduction_axes, epsilon=epsilon
3472 )
3473 else:
3474 return _broadcast_normalize_batch_in_training(
3475 x, gamma, beta, reduction_axes, epsilon=epsilon
3476 )
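# A minimal usage sketch (hedged example; a channels-last 4D input with
# reduction_axes=[0, 1, 2] takes the fused path above when supported):
#
#     x = tf.random.normal((8, 32, 32, 16))
#     gamma, beta = tf.ones((16,)), tf.zeros((16,))
#     normed, mean, variance = normalize_batch_in_training(
#         x, gamma, beta, reduction_axes=[0, 1, 2])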
3479@keras_export("keras.backend.batch_normalization")
3480@tf.__internal__.dispatch.add_dispatch_support
3481@doc_controls.do_not_generate_docs
3482def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
3483 """Applies batch normalization on x given mean, var, beta and gamma.
3485 I.e. returns:
3486 `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`
3488 Args:
3489 x: Input tensor or variable.
3490 mean: Mean of batch.
3491 var: Variance of batch.
3492 beta: Tensor with which to center the input.
3493 gamma: Tensor by which to scale the input.
3494 axis: Integer, the axis that should be normalized.
3495 (typically the features axis).
3496 epsilon: Fuzz factor.
3498 Returns:
3499 A tensor.
3500 """
3501 if ndim(x) == 4:
3502 # The CPU implementation of `fused_batch_norm` only supports NHWC
3503 if axis == 1 or axis == -3:
3504 tf_data_format = "NCHW"
3505 elif axis == 3 or axis == -1:
3506 tf_data_format = "NHWC"
3507 else:
3508 tf_data_format = None
3510 if (
3511 tf_data_format == "NHWC"
3512 or tf_data_format == "NCHW"
3513 and _has_nchw_support()
3514 ):
3515 # The mean / var / beta / gamma tensors may be broadcasted
3516 # so they may have extra axes of size 1, which should be squeezed.
3517 if ndim(mean) > 1:
3518 mean = tf.reshape(mean, [-1])
3519 if ndim(var) > 1:
3520 var = tf.reshape(var, [-1])
3521 if beta is None:
3522 beta = zeros_like(mean)
3523 elif ndim(beta) > 1:
3524 beta = tf.reshape(beta, [-1])
3525 if gamma is None:
3526 gamma = ones_like(mean)
3527 elif ndim(gamma) > 1:
3528 gamma = tf.reshape(gamma, [-1])
3529 y, _, _ = tf.compat.v1.nn.fused_batch_norm(
3530 x,
3531 gamma,
3532 beta,
3533 epsilon=epsilon,
3534 mean=mean,
3535 variance=var,
3536 data_format=tf_data_format,
3537 is_training=False,
3538 )
3539 return y
3540 return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
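# A small worked example (sketch): with mean=0., var=1., beta=0., gamma=1.
# and a negligible epsilon, the output is (x - 0) / sqrt(1 + epsilon) * 1 + 0,
# i.e. approximately the identity on `x`.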
3543# SHAPE OPERATIONS
3546@keras_export("keras.backend.concatenate")
3547@tf.__internal__.dispatch.add_dispatch_support
3548@doc_controls.do_not_generate_docs
3549def concatenate(tensors, axis=-1):
3550 """Concatenates a list of tensors alongside the specified axis.
3552 Args:
3553 tensors: list of tensors to concatenate.
3554 axis: concatenation axis.
3556 Returns:
3557 A tensor.
3559 Example:
3561 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
3562 >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
3563 >>> tf.keras.backend.concatenate((a, b), axis=-1)
3564 <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
3565 array([[ 1, 2, 3, 10, 20, 30],
3566 [ 4, 5, 6, 40, 50, 60],
3567 [ 7, 8, 9, 70, 80, 90]], dtype=int32)>
3569 """
3570 if axis < 0:
3571 rank = ndim(tensors[0])
3572 if rank:
3573 axis %= rank
3574 else:
3575 axis = 0
3577 if py_all(is_sparse(x) for x in tensors):
3578 return tf.compat.v1.sparse_concat(axis, tensors)
3579 elif py_all(isinstance(x, tf.RaggedTensor) for x in tensors):
3580 return tf.concat(tensors, axis)
3581 else:
3582 return tf.concat([to_dense(x) for x in tensors], axis)
3585@keras_export("keras.backend.reshape")
3586@tf.__internal__.dispatch.add_dispatch_support
3587@doc_controls.do_not_generate_docs
3588def reshape(x, shape):
3589 """Reshapes a tensor to the specified shape.
3591 Args:
3592 x: Tensor or variable.
3593 shape: Target shape tuple.
3595 Returns:
3596 A tensor.
3598 Example:
3600 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3601 >>> a
3602 <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3603 array([[ 1, 2, 3],
3604 [ 4, 5, 6],
3605 [ 7, 8, 9],
3606 [10, 11, 12]], dtype=int32)>
3607 >>> tf.keras.backend.reshape(a, shape=(2, 6))
3608 <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
3609 array([[ 1, 2, 3, 4, 5, 6],
3610 [ 7, 8, 9, 10, 11, 12]], dtype=int32)>
3612 """
3613 return tf.reshape(x, shape)
3616@keras_export("keras.backend.permute_dimensions")
3617@tf.__internal__.dispatch.add_dispatch_support
3618@doc_controls.do_not_generate_docs
3619def permute_dimensions(x, pattern):
3620 """Permutes axes in a tensor.
3622 Args:
3623 x: Tensor or variable.
3624 pattern: A tuple of
3625 dimension indices, e.g. `(0, 2, 1)`.
3627 Returns:
3628 A tensor.
3630 Example:
3632 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3633 >>> a
3634 <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3635 array([[ 1, 2, 3],
3636 [ 4, 5, 6],
3637 [ 7, 8, 9],
3638 [10, 11, 12]], dtype=int32)>
3639 >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0))
3640 <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3641 array([[ 1, 4, 7, 10],
3642 [ 2, 5, 8, 11],
3643 [ 3, 6, 9, 12]], dtype=int32)>
3645 """
3646 return tf.compat.v1.transpose(x, perm=pattern)
3649@keras_export("keras.backend.resize_images")
3650@tf.__internal__.dispatch.add_dispatch_support
3651@doc_controls.do_not_generate_docs
3652def resize_images(
3653 x, height_factor, width_factor, data_format, interpolation="nearest"
3654):
3655 """Resizes the images contained in a 4D tensor.
3657 Args:
3658 x: Tensor or variable to resize.
3659 height_factor: Positive integer.
3660 width_factor: Positive integer.
3661 data_format: One of `"channels_first"`, `"channels_last"`.
3662 interpolation: A string, one of `"area"`, `"bicubic"`, `"bilinear"`,
3663 `"gaussian"`, `"lanczos3"`, `"lanczos5"`, `"mitchellcubic"`,
3664 `"nearest"`.
3666 Returns:
3667 A tensor.
3669 Raises:
3670 ValueError: in case of incorrect value for
3671 `data_format` or `interpolation`.
3672 """
3673 if data_format == "channels_first":
3674 rows, cols = 2, 3
3675 elif data_format == "channels_last":
3676 rows, cols = 1, 2
3677 else:
3678 raise ValueError(f"Invalid `data_format` argument: {data_format}")
3680 new_shape = x.shape[rows : cols + 1]
3681 if new_shape.is_fully_defined():
3682 new_shape = tf.constant(new_shape.as_list(), dtype="int32")
3683 else:
3684 new_shape = tf.shape(x)[rows : cols + 1]
3685 new_shape *= tf.constant(
3686 np.array([height_factor, width_factor], dtype="int32")
3687 )
3689 if data_format == "channels_first":
3690 x = permute_dimensions(x, [0, 2, 3, 1])
3691 interpolations = {
3692 "area": tf.image.ResizeMethod.AREA,
3693 "bicubic": tf.image.ResizeMethod.BICUBIC,
3694 "bilinear": tf.image.ResizeMethod.BILINEAR,
3695 "gaussian": tf.image.ResizeMethod.GAUSSIAN,
3696 "lanczos3": tf.image.ResizeMethod.LANCZOS3,
3697 "lanczos5": tf.image.ResizeMethod.LANCZOS5,
3698 "mitchellcubic": tf.image.ResizeMethod.MITCHELLCUBIC,
3699 "nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR,
3700 }
3701 interpolations_list = '"' + '", "'.join(interpolations.keys()) + '"'
3702 if interpolation in interpolations:
3703 x = tf.image.resize(x, new_shape, method=interpolations[interpolation])
3704 else:
3705 raise ValueError(
3706 "`interpolation` argument should be one of: "
3707 f'{interpolations_list}. Received: "{interpolation}".'
3708 )
3709 if data_format == "channels_first":
3710 x = permute_dimensions(x, [0, 3, 1, 2])
3712 return x
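# A minimal usage sketch (hedged example; channels_last input, doubling both
# spatial dimensions with nearest-neighbor interpolation):
#
#     x = tf.zeros((1, 2, 2, 3))
#     y = resize_images(x, height_factor=2, width_factor=2,
#                       data_format="channels_last", interpolation="nearest")
#     # tf.keras.backend.int_shape(y) == (1, 4, 4, 3)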
3715@keras_export("keras.backend.resize_volumes")
3716@tf.__internal__.dispatch.add_dispatch_support
3717@doc_controls.do_not_generate_docs
3718def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
3719 """Resizes the volume contained in a 5D tensor.
3721 Args:
3722 x: Tensor or variable to resize.
3723 depth_factor: Positive integer.
3724 height_factor: Positive integer.
3725 width_factor: Positive integer.
3726 data_format: One of `"channels_first"`, `"channels_last"`.
3728 Returns:
3729 A tensor.
3731 Raises:
3732 ValueError: if `data_format` is neither
3733 `channels_last` nor `channels_first`.
3734 """
3735 if data_format == "channels_first":
3736 output = repeat_elements(x, depth_factor, axis=2)
3737 output = repeat_elements(output, height_factor, axis=3)
3738 output = repeat_elements(output, width_factor, axis=4)
3739 return output
3740 elif data_format == "channels_last":
3741 output = repeat_elements(x, depth_factor, axis=1)
3742 output = repeat_elements(output, height_factor, axis=2)
3743 output = repeat_elements(output, width_factor, axis=3)
3744 return output
3745 else:
3746 raise ValueError("Invalid data_format: " + str(data_format))
3749@keras_export("keras.backend.repeat_elements")
3750@tf.__internal__.dispatch.add_dispatch_support
3751@doc_controls.do_not_generate_docs
3752def repeat_elements(x, rep, axis):
3753 """Repeats the elements of a tensor along an axis, like `np.repeat`.
3755 If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
3756 will have shape `(s1, s2 * rep, s3)`.
3758 Args:
3759 x: Tensor or variable.
3760 rep: Python integer, number of times to repeat.
3761 axis: Axis along which to repeat.
3763 Returns:
3764 A tensor.
3766 Example:
3768 >>> b = tf.constant([1, 2, 3])
3769 >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0)
3770 <tf.Tensor: shape=(6,), dtype=int32,
3771 numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)>
3773 """
3774 x_shape = x.shape.as_list()
3775 # For static axis
3776 if x_shape[axis] is not None:
3777 # slices along the repeat axis
3778 splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)
3779 # repeat each slice the given number of reps
3780 x_rep = [s for s in splits for _ in range(rep)]
3781 return concatenate(x_rep, axis)
3783 # Here we use tf.tile to mimic behavior of np.repeat so that
3784 # we can handle dynamic shapes (that include None).
3785 # To do that, we need an auxiliary axis to repeat elements along
3786 # it and then merge them along the desired axis.
3788 # Repeating
3789 auxiliary_axis = axis + 1
3790 x_shape = tf.shape(x)
3791 x_rep = tf.expand_dims(x, axis=auxiliary_axis)
3792 reps = np.ones(len(x.shape) + 1)
3793 reps[auxiliary_axis] = rep
3794 x_rep = tf.tile(x_rep, reps)
3796 # Merging
3797 reps = np.delete(reps, auxiliary_axis)
3798 reps[axis] = rep
3799 reps = tf.constant(reps, dtype="int32")
3800 x_shape *= reps
3801 x_rep = tf.reshape(x_rep, x_shape)
3803 # Fix shape representation
3804 x_shape = x.shape.as_list()
3805 x_rep.set_shape(x_shape)
3806 x_rep._keras_shape = tuple(x_shape)
3807 return x_rep
3810@keras_export("keras.backend.repeat")
3811@tf.__internal__.dispatch.add_dispatch_support
3812@doc_controls.do_not_generate_docs
3813def repeat(x, n):
3814 """Repeats a 2D tensor.
3816 If `x` has shape `(samples, dim)` and `n` is `2`,
3817 the output will have shape `(samples, 2, dim)`.
3819 Args:
3820 x: Tensor or variable.
3821 n: Python integer, number of times to repeat.
3823 Returns:
3824 A tensor.
3826 Example:
3828 >>> b = tf.constant([[1, 2], [3, 4]])
3829 >>> b
3830 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3831 array([[1, 2],
3832 [3, 4]], dtype=int32)>
3833 >>> tf.keras.backend.repeat(b, n=2)
3834 <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3835 array([[[1, 2],
3836 [1, 2]],
3837 [[3, 4],
3838 [3, 4]]], dtype=int32)>
3840 """
3841 assert ndim(x) == 2
3842 x = tf.expand_dims(x, 1)
3843 pattern = tf.stack([1, n, 1])
3844 return tf.tile(x, pattern)
3847@keras_export("keras.backend.arange")
3848@tf.__internal__.dispatch.add_dispatch_support
3849@doc_controls.do_not_generate_docs
3850def arange(start, stop=None, step=1, dtype="int32"):
3851 """Creates a 1D tensor containing a sequence of integers.
3853 The function arguments use the same convention as
3854 Theano's arange: if only one argument is provided,
3855 it is in fact the "stop" argument and "start" is 0.
3857 The default type of the returned tensor is `'int32'` to
3858 match TensorFlow's default.
3860 Args:
3861 start: Start value.
3862 stop: Stop value.
3863 step: Difference between two successive values.
3864 dtype: Integer dtype to use.
3866 Returns:
3867 An integer tensor.
3869 Example:
3871 >>> tf.keras.backend.arange(start=0, stop=10, step=1.5)
3872 <tf.Tensor: shape=(7,), dtype=float32,
3873 numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)>
3877 """
3878 # Match the behavior of numpy and Theano by returning an empty sequence.
3879 if stop is None and start < 0:
3880 start = 0
3881 result = tf.range(start, limit=stop, delta=step, name="arange")
3882 if dtype != "int32":
3883 result = cast(result, dtype)
3884 return result
3887@keras_export("keras.backend.tile")
3888@tf.__internal__.dispatch.add_dispatch_support
3889@doc_controls.do_not_generate_docs
3890def tile(x, n):
3891 """Creates a tensor by tiling `x` by `n`.
3893 Args:
3894 x: A tensor or variable
3895 n: A list of integers. The length must be the same as the number of
3896 dimensions in `x`.
3898 Returns:
3899 A tiled tensor.
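Example (illustrative; the output shown assumes eager execution):
>>> b = tf.constant([1, 2, 3])
>>> tf.keras.backend.tile(b, n=2)
<tf.Tensor: shape=(6,), dtype=int32,
numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>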
3900 """
3901 if isinstance(n, int):
3902 n = [n]
3903 return tf.tile(x, n)
3906@keras_export("keras.backend.flatten")
3907@tf.__internal__.dispatch.add_dispatch_support
3908@doc_controls.do_not_generate_docs
3909def flatten(x):
3910 """Flatten a tensor.
3912 Args:
3913 x: A tensor or variable.
3915 Returns:
3916 A tensor, reshaped into 1-D
3918 Example:
3920 >>> b = tf.constant([[1, 2], [3, 4]])
3921 >>> b
3922 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3923 array([[1, 2],
3924 [3, 4]], dtype=int32)>
3925 >>> tf.keras.backend.flatten(b)
3926 <tf.Tensor: shape=(4,), dtype=int32,
3927 numpy=array([1, 2, 3, 4], dtype=int32)>
3929 """
3930 return tf.reshape(x, [-1])
3933@keras_export("keras.backend.batch_flatten")
3934@tf.__internal__.dispatch.add_dispatch_support
3935@doc_controls.do_not_generate_docs
3936def batch_flatten(x):
3937 Turns an nD tensor into a 2D tensor with the same 0th dimension.
3939 In other words, it flattens each data sample of a batch.
3941 Args:
3942 x: A tensor or variable.
3944 Returns:
3945 A tensor.
3947 Examples:
3948 Flattening a 3D tensor to 2D by collapsing the last dimension.
3950 >>> x_batch = tf.keras.backend.ones(shape=(2, 3, 4, 5))
3951 >>> x_batch_flatten = batch_flatten(x_batch)
3952 >>> tf.keras.backend.int_shape(x_batch_flatten)
3953 (2, 60)
3955 """
3956 x = tf.reshape(x, tf.stack([-1, prod(shape(x)[1:])]))
3957 return x
3960@keras_export("keras.backend.expand_dims")
3961@tf.__internal__.dispatch.add_dispatch_support
3962@doc_controls.do_not_generate_docs
3963def expand_dims(x, axis=-1):
3964 """Adds a 1-sized dimension at index "axis".
3966 Args:
3967 x: A tensor or variable.
3968 axis: Position where to add a new axis.
3970 Returns:
3971 A tensor with expanded dimensions.
3972 """
3973 return tf.expand_dims(x, axis)
3976@keras_export("keras.backend.squeeze")
3977@tf.__internal__.dispatch.add_dispatch_support
3978@doc_controls.do_not_generate_docs
3979def squeeze(x, axis):
3980 Removes a 1-sized dimension from the tensor at index "axis".
3982 Args:
3983 x: A tensor or variable.
3984 axis: Axis to drop.
3986 Returns:
3987 A tensor with the same data as `x` but reduced dimensions.
3988 """
3989 return tf.squeeze(x, [axis])
3992@keras_export("keras.backend.temporal_padding")
3993@tf.__internal__.dispatch.add_dispatch_support
3994@doc_controls.do_not_generate_docs
3995def temporal_padding(x, padding=(1, 1)):
3996 """Pads the middle dimension of a 3D tensor.
3998 Args:
3999 x: Tensor or variable.
4000 padding: Tuple of 2 integers, how many zeros to
4001 add at the start and end of dim 1.
4003 Returns:
4004 A padded 3D tensor.
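Example (illustrative; shapes assume eager execution):
>>> x = tf.keras.backend.ones(shape=(2, 3, 4))
>>> y = tf.keras.backend.temporal_padding(x, padding=(1, 1))
>>> tf.keras.backend.int_shape(y)
(2, 5, 4)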
4005 """
4006 assert len(padding) == 2
4007 pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
4008 return tf.compat.v1.pad(x, pattern)
4011@keras_export("keras.backend.spatial_2d_padding")
4012@tf.__internal__.dispatch.add_dispatch_support
4013@doc_controls.do_not_generate_docs
4014def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
4015 """Pads the 2nd and 3rd dimensions of a 4D tensor.
4017 Args:
4018 x: Tensor or variable.
4019 padding: Tuple of 2 tuples, padding pattern.
4020 data_format: One of `channels_last` or `channels_first`.
4022 Returns:
4023 A padded 4D tensor.
4025 Raises:
4026 ValueError: if `data_format` is neither
4027 `channels_last` nor `channels_first`.
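Example (illustrative; shapes assume the default `"channels_last"` image data
format and eager execution):
>>> x = tf.keras.backend.ones(shape=(1, 2, 2, 3))
>>> y = tf.keras.backend.spatial_2d_padding(x)
>>> tf.keras.backend.int_shape(y)
(1, 4, 4, 3)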
4028 """
4029 assert len(padding) == 2
4030 assert len(padding[0]) == 2
4031 assert len(padding[1]) == 2
4032 if data_format is None:
4033 data_format = image_data_format()
4034 if data_format not in {"channels_first", "channels_last"}:
4035 raise ValueError("Unknown data_format: " + str(data_format))
4037 if data_format == "channels_first":
4038 pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
4039 else:
4040 pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
4041 return tf.compat.v1.pad(x, pattern)
4044@keras_export("keras.backend.spatial_3d_padding")
4045@tf.__internal__.dispatch.add_dispatch_support
4046@doc_controls.do_not_generate_docs
4047def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
4048 """Pads 5D tensor with zeros along the depth, height, width dimensions.
4050 Pads these dimensions with "padding[0]", "padding[1]" and "padding[2]"
4051 zeros on the left and right, respectively.
4053 For 'channels_last' data_format,
4054 the 2nd, 3rd and 4th dimension will be padded.
4055 For 'channels_first' data_format,
4056 the 3rd, 4th and 5th dimension will be padded.
4058 Args:
4059 x: Tensor or variable.
4060 padding: Tuple of 3 tuples, padding pattern.
4061 data_format: One of `channels_last` or `channels_first`.
4063 Returns:
4064 A padded 5D tensor.
4066 Raises:
4067 ValueError: if `data_format` is neither
4068 `channels_last` nor `channels_first`.
4070 """
4071 assert len(padding) == 3
4072 assert len(padding[0]) == 2
4073 assert len(padding[1]) == 2
4074 assert len(padding[2]) == 2
4075 if data_format is None:
4076 data_format = image_data_format()
4077 if data_format not in {"channels_first", "channels_last"}:
4078 raise ValueError("Unknown data_format: " + str(data_format))
4080 if data_format == "channels_first":
4081 pattern = [
4082 [0, 0],
4083 [0, 0],
4084 [padding[0][0], padding[0][1]],
4085 [padding[1][0], padding[1][1]],
4086 [padding[2][0], padding[2][1]],
4087 ]
4088 else:
4089 pattern = [
4090 [0, 0],
4091 [padding[0][0], padding[0][1]],
4092 [padding[1][0], padding[1][1]],
4093 [padding[2][0], padding[2][1]],
4094 [0, 0],
4095 ]
4096 return tf.compat.v1.pad(x, pattern)
4099@keras_export("keras.backend.stack")
4100@tf.__internal__.dispatch.add_dispatch_support
4101@doc_controls.do_not_generate_docs
4102def stack(x, axis=0):
4103 """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
4105 Args:
4106 x: List of tensors.
4107 axis: Axis along which to perform stacking.
4109 Returns:
4110 A tensor.
4112 Example:
4114 >>> a = tf.constant([[1, 2],[3, 4]])
4115 >>> b = tf.constant([[10, 20],[30, 40]])
4116 >>> tf.keras.backend.stack((a, b))
4117 <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
4118 array([[[ 1, 2],
4119 [ 3, 4]],
4120 [[10, 20],
4121 [30, 40]]], dtype=int32)>
4123 """
4124 return tf.stack(x, axis=axis)
4127@keras_export("keras.backend.one_hot")
4128@tf.__internal__.dispatch.add_dispatch_support
4129@doc_controls.do_not_generate_docs
4130def one_hot(indices, num_classes):
4131 """Computes the one-hot representation of an integer tensor.
4133 Args:
4134 indices: nD integer tensor of shape
4135 `(batch_size, dim1, dim2, ... dim(n-1))`
4136 num_classes: Integer, number of classes to consider.
4138 Returns:
4139 The (n + 1)D one-hot representation of the input,
4140 with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`.
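Example (illustrative; the output shown assumes eager execution):
>>> tf.keras.backend.one_hot(tf.constant([0, 1, 2]), num_classes=3)
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>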
4144 """
4145 return tf.one_hot(indices, depth=num_classes, axis=-1)
4148@keras_export("keras.backend.reverse")
4149@tf.__internal__.dispatch.add_dispatch_support
4150@doc_controls.do_not_generate_docs
4151def reverse(x, axes):
4152 """Reverse a tensor along the specified axes.
4154 Args:
4155 x: Tensor to reverse.
4156 axes: Integer or iterable of integers.
4157 Axes to reverse.
4159 Returns:
4160 A tensor.
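Example (illustrative; the output shown assumes eager execution):
>>> b = tf.constant([1, 2, 3])
>>> tf.keras.backend.reverse(b, 0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([3, 2, 1], dtype=int32)>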
4161 """
4162 if isinstance(axes, int):
4163 axes = [axes]
4164 return tf.reverse(x, axes)
4167# VALUE MANIPULATION
4168_VALUE_SET_CODE_STRING = """
4169 >>> K = tf.keras.backend # Common keras convention
4170 >>> v = K.variable(1.)
4172 >>> # reassign
4173 >>> K.set_value(v, 2.)
4174 >>> print(K.get_value(v))
4175 2.0
4177 >>> # increment
4178 >>> K.set_value(v, K.get_value(v) + 1)
4179 >>> print(K.get_value(v))
4180 3.0
4182 Variable semantics in TensorFlow 2 are eager execution friendly. The above
4183 code is roughly equivalent to:
4185 >>> v = tf.Variable(1.)
4187 >>> v.assign(2.)
4188 >>> print(v.numpy())
4189 2.0
4191 >>> v.assign_add(1.)
4192 >>> print(v.numpy())
4193 3.0"""[
4194 3:
4195] # Prune first newline and indent to match the docstring template.
4198@keras_export("keras.backend.get_value")
4199@doc_controls.do_not_generate_docs
4200def get_value(x):
4201 """Returns the value of a variable.
4203 `backend.get_value` is the complement of `backend.set_value`, and provides
4204 a generic interface for reading from variables while abstracting away the
4205 differences between TensorFlow 1.x and 2.x semantics.
4207 {snippet}
4209 Args:
4210 x: input variable.
4212 Returns:
4213 A Numpy array.
4214 """
4215 if not tf.is_tensor(x):
4216 return x
4217 if tf.executing_eagerly() or isinstance(x, tf.__internal__.EagerTensor):
4218 return x.numpy()
4219 if not getattr(x, "_in_graph_mode", True):
4220 # This is a variable which was created in an eager context, but is being
4221 # evaluated from a Graph.
4222 with tf.__internal__.eager_context.eager_mode():
4223 return x.numpy()
4225 if tf.compat.v1.executing_eagerly_outside_functions():
4226 # This method of evaluating works inside the Keras FuncGraph.
4227 with tf.init_scope():
4228 return x.numpy()
4230 with x.graph.as_default():
4231 return x.eval(session=get_session((x,)))
4234@keras_export("keras.backend.batch_get_value")
4235@tf.__internal__.dispatch.add_dispatch_support
4236@doc_controls.do_not_generate_docs
4237def batch_get_value(tensors):
4238 """Returns the value of more than one tensor variable.
4240 Args:
4241 tensors: list of ops to run.
4243 Returns:
4244 A list of Numpy arrays.
4246 Raises:
4247 RuntimeError: If this method is called inside defun.
4248 """
4249 if tf.executing_eagerly():
4250 return [x.numpy() for x in tensors]
4251 elif tf.inside_function():
4252 raise RuntimeError("Cannot get value inside Tensorflow graph function.")
4253 if tensors:
4254 return get_session(tensors).run(tensors)
4255 else:
4256 return []
4259@keras_export("keras.backend.set_value")
4260@doc_controls.do_not_generate_docs
4261def set_value(x, value):
4262 """Sets the value of a variable, from a Numpy array.
4264 `backend.set_value` is the complement of `backend.get_value`, and provides
4265 a generic interface for assigning to variables while abstracting away the
4266 differences between TensorFlow 1.x and 2.x semantics.
4268 {snippet}
4270 Args:
4271 x: Variable to set to a new value.
4272 value: Value to set the tensor to, as a Numpy array
4273 (of the same shape).
4274 """
4275 value = np.asarray(value, dtype=dtype_numpy(x))
4276 if tf.compat.v1.executing_eagerly_outside_functions():
4277 _assign_value_to_variable(x, value)
4278 else:
4279 with get_graph().as_default():
4280 tf_dtype = tf.as_dtype(x.dtype.name.split("_")[0])
4281 if hasattr(x, "_assign_placeholder"):
4282 assign_placeholder = x._assign_placeholder
4283 assign_op = x._assign_op
4284 else:
4285 # In order to support assigning weights to resizable variables
4286 # in Keras, we make a placeholder with the correct number of
4287 # dimensions but with None in each dimension. This way, we can
4288 # assign weights of any size (as long as they have the correct
4289 # dimensionality).
4290 placeholder_shape = tf.TensorShape([None] * value.ndim)
4291 assign_placeholder = tf.compat.v1.placeholder(
4292 tf_dtype, shape=placeholder_shape
4293 )
4294 assign_op = x.assign(assign_placeholder)
4295 x._assign_placeholder = assign_placeholder
4296 x._assign_op = assign_op
4297 get_session().run(assign_op, feed_dict={assign_placeholder: value})
4300@keras_export("keras.backend.batch_set_value")
4301@tf.__internal__.dispatch.add_dispatch_support
4302@doc_controls.do_not_generate_docs
4303def batch_set_value(tuples):
4304 """Sets the values of many tensor variables at once.
4306 Args:
4307 tuples: a list of tuples `(tensor, value)`.
4308 `value` should be a Numpy array.
4309 """
4310 if tf.executing_eagerly() or tf.inside_function():
4311 for x, value in tuples:
4312 value = np.asarray(value, dtype=dtype_numpy(x))
4313 _assign_value_to_variable(x, value)
4314 else:
4315 with get_graph().as_default():
4316 if tuples:
4317 assign_ops = []
4318 feed_dict = {}
4319 for x, value in tuples:
4320 value = np.asarray(value, dtype=dtype_numpy(x))
4321 tf_dtype = tf.as_dtype(x.dtype.name.split("_")[0])
4322 if hasattr(x, "_assign_placeholder"):
4323 assign_placeholder = x._assign_placeholder
4324 assign_op = x._assign_op
4325 else:
4326 # In order to support assigning weights to resizable
4327 # variables in Keras, we make a placeholder with the
4328 # correct number of dimensions but with None in each
4329 # dimension. This way, we can assign weights of any size
4330 # (as long as they have the correct dimensionality).
4331 placeholder_shape = tf.TensorShape([None] * value.ndim)
4332 assign_placeholder = tf.compat.v1.placeholder(
4333 tf_dtype, shape=placeholder_shape
4334 )
4335 assign_op = x.assign(assign_placeholder)
4336 x._assign_placeholder = assign_placeholder
4337 x._assign_op = assign_op
4338 assign_ops.append(assign_op)
4339 feed_dict[assign_placeholder] = value
4340 get_session().run(assign_ops, feed_dict=feed_dict)
4343get_value.__doc__ = get_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
4344set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
4347def _assign_value_to_variable(variable, value):
4348 # Helper function to assign a value to a variable. It handles both normal
4349 # tf.Variable and DTensor variables.
4350 if isinstance(variable, dtensor.DVariable):
4351 mesh = variable.layout.mesh
4352 replicate_layout = dtensor.Layout.replicated(
4353 rank=variable.shape.rank, mesh=mesh
4354 )
4355 # TODO(b/262894693): Avoid the broadcast of tensor to all devices.
4356 d_value = dtensor.copy_to_mesh(value, replicate_layout)
4357 d_value = dtensor.relayout(d_value, variable.layout)
4358 variable.assign(d_value)
4359 else:
4360 # For the normal tf.Variable assign
4361 variable.assign(value)
4364@keras_export("keras.backend.print_tensor")
4365@tf.__internal__.dispatch.add_dispatch_support
4366@doc_controls.do_not_generate_docs
4367def print_tensor(x, message="", summarize=3):
4368 """Prints `message` and the tensor value when evaluated.
4370 Note that `print_tensor` returns a new tensor identical to `x`
4371 which should be used in the following code. Otherwise the
4372 print operation is not taken into account during evaluation.
4374 Example:
4376 >>> x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
4377 >>> tf.keras.backend.print_tensor(x)
4378 <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
4379 array([[1., 2.],
4380 [3., 4.]], dtype=float32)>
4382 Args:
4383 x: Tensor to print.
4384 message: Message to print jointly with the tensor.
4385 summarize: The first and last `summarize` elements within each dimension
4386 are recursively printed per Tensor. If None, then the first 3 and
4387 last 3 elements of each dimension are printed for each tensor. If
4388 set to -1, it will print all elements of every tensor.
4390 Returns:
4391 The same tensor `x`, unchanged.
4392 """
4393 if isinstance(x, tf.Tensor) and hasattr(x, "graph"):
4394 with get_graph().as_default():
4395 op = tf.print(
4396 message, x, output_stream=sys.stdout, summarize=summarize
4397 )
4398 with tf.control_dependencies([op]):
4399 return tf.identity(x)
4400 else:
4401 tf.print(message, x, output_stream=sys.stdout, summarize=summarize)
4402 return x
4405# GRAPH MANIPULATION
4408class GraphExecutionFunction:
4409 """Runs a computation graph.
4411 It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`.
4412 In particular, additional operations can be passed via the `fetches`
4413 argument and additional tensor substitutions via the `feed_dict` argument.
4414 Note that the given substitutions are merged with the substitutions from
4415 `inputs`. Even though `feed_dict` is passed once in the constructor
4416 (called in `model.compile()`), the values in the dictionary can be modified
4417 afterwards to provide additional substitutions besides the Keras inputs.
4419 Args:
4420 inputs: Feed placeholders to the computation graph.
4421 outputs: Output tensors to fetch.
4422 updates: Additional update ops to be run at function call.
4423 name: A name to help users identify what this function does.
4424 session_kwargs: Arguments to `tf.Session.run()`:
4425 `fetches`, `feed_dict`, `options`, `run_metadata`.
4426 """
4428 def __init__(
4429 self, inputs, outputs, updates=None, name=None, **session_kwargs
4430 ):
4431 updates = updates or []
4432 if not isinstance(updates, (list, tuple)):
4433 raise TypeError(
4434 "`updates` in a Keras backend function "
4435 "should be a list or tuple."
4436 )
4438 self.inputs = tf.nest.flatten(
4439 tf_utils.convert_variables_to_tensors(inputs),
4440 expand_composites=True,
4441 )
4442 self._outputs_structure = tf_utils.convert_variables_to_tensors(outputs)
4443 self.outputs = tf.nest.flatten(
4444 self._outputs_structure, expand_composites=True
4445 )
4446 # TODO(b/127668432): Consider using autograph to generate these
4447 # dependencies in call.
4448 # Index 0 = total loss or model output for `predict`.
4449 with tf.control_dependencies([self.outputs[0]]):
4450 updates_ops = []
4451 for update in updates:
4452 if isinstance(update, tuple):
4453 p, new_p = update
4454 updates_ops.append(tf.compat.v1.assign(p, new_p))
4455 else:
4456 # assumed already an op
4457 updates_ops.append(update)
4458 self.updates_op = tf.group(*updates_ops)
4459 self.name = name
4460 # additional tensor substitutions
4461 self.feed_dict = session_kwargs.pop("feed_dict", None)
4462 # additional operations
4463 self.fetches = session_kwargs.pop("fetches", [])
4464 if not isinstance(self.fetches, list):
4465 self.fetches = [self.fetches]
4466 self.run_options = session_kwargs.pop("options", None)
4467 self.run_metadata = session_kwargs.pop("run_metadata", None)
4468 # The main use case of `fetches` being passed to a model is the ability
4469 # to run custom updates
4470 # This requires us to wrap fetches in `identity` ops.
4471 self.fetches = [tf.identity(x) for x in self.fetches]
4472 self.session_kwargs = session_kwargs
4473 # This mapping keeps track of the function that should receive the
4474 # output from a fetch in `fetches`: { fetch: function(fetch_output) }
4475 # A Callback can use this to register a function with access to the
4476 # output values for a fetch it added.
4477 self.fetch_callbacks = {}
4479 if session_kwargs:
4480 raise ValueError(
4481 "Some keys in session_kwargs are not supported at this time: %s"
4482 % (session_kwargs.keys(),)
4483 )
4485 self._callable_fn = None
4486 self._feed_arrays = None
4487 self._feed_symbols = None
4488 self._symbol_vals = None
4489 self._fetches = None
4490 self._session = None
4492 def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
4493 """Generates a callable that runs the graph.
4495 Args:
4496 feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
4497 feed_symbols: List of input tensors to be fed symbolic tensors at
4498 runtime.
4499 symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
4500 session: Session to use to generate the callable.
4502 Returns:
4503 Function that runs the graph according to the above options.
4504 """
4505 # Prepare callable options.
4506 callable_opts = config_pb2.CallableOptions()
4507 # Handle external-data feed.
4508 for x in feed_arrays:
4509 callable_opts.feed.append(x.name)
4510 if self.feed_dict:
4511 for key in sorted(self.feed_dict.keys()):
4512 callable_opts.feed.append(key.name)
4513 # Handle symbolic feed.
4514 for x, y in zip(feed_symbols, symbol_vals):
4515 connection = callable_opts.tensor_connection.add()
4516 if x.dtype != y.dtype:
4517 y = tf.cast(y, dtype=x.dtype)
4518 from_tensor = _as_graph_element(y)
4519 if from_tensor is None:
4520 from_tensor = y
4521 connection.from_tensor = from_tensor.name # Data tensor
4522 connection.to_tensor = x.name # Placeholder
4523 # Handle fetches.
4524 for x in self.outputs + self.fetches:
4525 callable_opts.fetch.append(x.name)
4526 # Handle updates.
4527 callable_opts.target.append(self.updates_op.name)
4528 # Handle run_options.
4529 if self.run_options:
4530 callable_opts.run_options.CopyFrom(self.run_options)
4531 # Create callable.
4532 callable_fn = session._make_callable_from_options(callable_opts)
4533 # Cache parameters corresponding to the generated callable, so that
4534 # we can detect future mismatches and refresh the callable.
4535 self._callable_fn = callable_fn
4536 self._feed_arrays = feed_arrays
4537 self._feed_symbols = feed_symbols
4538 self._symbol_vals = symbol_vals
4539 self._fetches = list(self.fetches)
4540 self._session = session
4542 def _call_fetch_callbacks(self, fetches_output):
4543 for fetch, output in zip(self._fetches, fetches_output):
4544 if fetch in self.fetch_callbacks:
4545 self.fetch_callbacks[fetch](output)
4547 def _eval_if_composite(self, tensor):
4548 """Helper method which evaluates any CompositeTensors passed to it."""
4549 # We need to evaluate any composite tensor objects that have been
4550 # reconstructed in 'pack_sequence_as', since otherwise they'll be output
4551 # as actual CompositeTensor objects instead of the value(s) contained in
4552 # the CompositeTensors. E.g., if output_structure contains a
4553 # SparseTensor, then this ensures that we return its value as a
4554 # SparseTensorValue rather than a SparseTensor.
4556 if tf_utils.is_extension_type(tensor):
4557 return self._session.run(tensor)
4558 else:
4559 return tensor
4561 def __call__(self, inputs):
4562 inputs = tf.nest.flatten(
4563 tf_utils.convert_variables_to_tensors(inputs),
4564 expand_composites=True,
4565 )
4567 session = get_session(inputs)
4568 feed_arrays = []
4569 array_vals = []
4570 feed_symbols = []
4571 symbol_vals = []
4572 for tensor, value in zip(self.inputs, inputs):
4573 if value is None:
4574 continue
4576 if tf.is_tensor(value):
4577 # Case: feeding symbolic tensor.
4578 feed_symbols.append(tensor)
4579 symbol_vals.append(value)
4580 else:
4581 # Case: feeding Numpy array.
4582 feed_arrays.append(tensor)
4583 # We need to do array conversion and type casting at this level,
4584 # since `callable_fn` only supports exact matches.
4585 tensor_type = tf.as_dtype(tensor.dtype)
4586 array_vals.append(
4587 np.asarray(value, dtype=tensor_type.as_numpy_dtype)
4588 )
4590 if self.feed_dict:
4591 for key in sorted(self.feed_dict.keys()):
4592 array_vals.append(
4593 np.asarray(
4594 self.feed_dict[key], dtype=key.dtype.as_numpy_dtype
4595 )
4596 )
4598 # Refresh callable if anything has changed.
4599 if (
4600 self._callable_fn is None
4601 or feed_arrays != self._feed_arrays
4602 or symbol_vals != self._symbol_vals
4603 or feed_symbols != self._feed_symbols
4604 or self.fetches != self._fetches
4605 or session != self._session
4606 ):
4607 self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
4609 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
4610 self._call_fetch_callbacks(fetched[-len(self._fetches) :])
4611 output_structure = tf.nest.pack_sequence_as(
4612 self._outputs_structure,
4613 fetched[: len(self.outputs)],
4614 expand_composites=True,
4615 )
4616 # Evaluate any composite tensor objects that were reconstructed by
4617 # 'pack_sequence_as' so that their contained values are returned (e.g. a
4618 # SparseTensorValue rather than a SparseTensor); see `_eval_if_composite`.
4622 return tf.nest.map_structure(self._eval_if_composite, output_structure)
4625@keras_export("keras.backend.function")
4626@doc_controls.do_not_generate_docs
4627def function(inputs, outputs, updates=None, name=None, **kwargs):
4628 """Instantiates a Keras function.
4630 Args:
4631 inputs: List of placeholder tensors.
4632 outputs: List of output tensors.
4633 updates: List of update ops.
4634 name: String, name of function.
4635 **kwargs: Passed to `tf.Session.run`.
4637 Returns:
4638 Output values as Numpy arrays.
4640 Raises:
4641 ValueError: if invalid kwargs are passed in or if in eager execution.
4642 """
4643 if tf.compat.v1.executing_eagerly_outside_functions():
4644 if kwargs:
4645 raise ValueError(
4646 "Session keyword arguments are not supported during "
4647 "eager execution. You passed: %s" % (kwargs,)
4648 )
4649 if updates:
4650 raise ValueError(
4651 "`updates` argument is not supported during "
4652 "eager execution. You passed: %s" % (updates,)
4653 )
4654 from keras.src import models
4656 model = models.Model(inputs=inputs, outputs=outputs)
4658 wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
4660 def func(model_inputs):
4661 outs = model(model_inputs)
4662 if wrap_outputs:
4663 outs = [outs]
4664 return tf_utils.sync_to_numpy_or_python_type(outs)
4666 return func
4668 if kwargs:
4669 for key in kwargs:
4670 if key not in tf_inspect.getfullargspec(tf.compat.v1.Session.run)[
4671 0
4672 ] and key not in ["inputs", "outputs", "updates", "name"]:
4673 msg = (
4674 'Invalid argument "%s" passed to K.function with '
4675 "TensorFlow backend" % key
4676 )
4677 raise ValueError(msg)
4678 return GraphExecutionFunction(
4679 inputs, outputs, updates=updates, name=name, **kwargs
4680 )
4683@keras_export("keras.backend.gradients")
4684@doc_controls.do_not_generate_docs
4685def gradients(loss, variables):
4686 """Returns the gradients of `loss` w.r.t. `variables`.
4688 Args:
4689 loss: Scalar tensor to minimize.
4690 variables: List of variables.
4692 Returns:
4693 A gradients tensor.
4694 """
4695 return tf.compat.v1.gradients(
4696 loss, variables, colocate_gradients_with_ops=True
4697 )
4700@keras_export("keras.backend.stop_gradient")
4701@tf.__internal__.dispatch.add_dispatch_support
4702@doc_controls.do_not_generate_docs
4703def stop_gradient(variables):
4704 """Returns `variables` but with zero gradient w.r.t. every other variable.
4706 Args:
4707 variables: Tensor or list of tensors to consider constant with respect
4708 to any other variable.
4711 Returns:
4712 A single tensor or a list of tensors (depending on the passed argument)
4713 that has no gradient with respect to any other variable.
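Example (illustrative sketch; assumes eager execution with a
`tf.GradientTape`):
>>> v = tf.Variable(2.0)
>>> with tf.GradientTape() as tape:
...     y = tf.keras.backend.stop_gradient(v) * v
>>> tape.gradient(y, v)  # Only the non-stopped factor contributes.
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>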
4714 """
4715 if isinstance(variables, (list, tuple)):
4716 return list(map(tf.stop_gradient, variables))
4717 return tf.stop_gradient(variables)
4720# CONTROL FLOW
4723@keras_export("keras.backend.rnn")
4724@tf.__internal__.dispatch.add_dispatch_support
4725def rnn(
4726 step_function,
4727 inputs,
4728 initial_states,
4729 go_backwards=False,
4730 mask=None,
4731 constants=None,
4732 unroll=False,
4733 input_length=None,
4734 time_major=False,
4735 zero_output_for_mask=False,
4736 return_all_outputs=True,
4737):
4738 """Iterates over the time dimension of a tensor.
4740 Args:
4741 step_function: RNN step function.
4742 Args;
4743 input; Tensor with shape `(samples, ...)` (no time dimension),
4744 representing input for the batch of samples at a certain
4745 time step.
4746 states; List of tensors.
4747 Returns;
4748 output; Tensor with shape `(samples, output_dim)`
4749 (no time dimension).
4750 new_states; List of tensors, same length and shapes
4751 as 'states'. The first state in the list must be the
4752 output tensor at the previous timestep.
4753 inputs: Tensor of temporal data of shape `(samples, time, ...)`
4754 (at least 3D), or nested tensors, and each of which has shape
4755 `(samples, time, ...)`.
4756 initial_states: Tensor with shape `(samples, state_size)`
4757 (no time dimension), containing the initial values for the states
4758 used in the step function. In the case that state_size is in a
4759 nested shape, the shape of initial_states will also follow the
4760 nested structure.
4761 go_backwards: Boolean. If True, do the iteration over the time
4762 dimension in reverse order and return the reversed sequence.
4763 mask: Binary tensor with shape `(samples, time, 1)`,
4764 with a zero for every element that is masked.
4765 constants: List of constant values passed at each step.
4766 unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
4767 input_length: An integer or a 1-D Tensor, depending on whether
4768 the time dimension is fixed-length or not. In case of variable
4769 length input, it is used for masking in case there's no mask
4770 specified.
4771 time_major: Boolean. If true, the inputs and outputs will be in shape
4772 `(timesteps, batch, ...)`, whereas in the False case, it will be
4773 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
4774 efficient because it avoids transposes at the beginning and end of
4775 the RNN calculation. However, most TensorFlow data is batch-major,
4776 so by default this function accepts input and emits output in
4777 batch-major form.
4778 zero_output_for_mask: Boolean. If True, the output for masked timestep
4779 will be zeros, whereas in the False case, output from previous
4780 timestep is returned.
4781 return_all_outputs: Boolean. If True, return the recurrent outputs for
4782 all timesteps in the sequence. If False, only return the output for
4783 the last timestep (which consumes less memory).
4785 Returns:
4786 A tuple, `(last_output, outputs, new_states)`.
4787 last_output: the latest output of the rnn, of shape `(samples, ...)`
4788 outputs:
4789 - If `return_all_outputs=True`: a tensor with shape
4790 `(samples, time, ...)` where each entry `outputs[s, t]` is the
4791 output of the step function at time `t` for sample `s`
4792 - Else, a tensor equal to `last_output` with shape
4793 `(samples, 1, ...)`
4794 new_states: list of tensors, latest states returned by
4795 the step function, of shape `(samples, ...)`.
4797 Raises:
4798 ValueError: if input dimension is less than 3.
4799 ValueError: if `unroll` is `True` but input timestep is not a fixed
4800 number.
4801 ValueError: if `mask` is provided (not `None`) but states is not
4802 provided (`len(states)` == 0).
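Example (illustrative sketch of a step function that accumulates the inputs
over time; shapes assume eager execution):
>>> inputs = tf.ones((2, 3, 4))  # (samples, timesteps, features)
>>> initial_states = [tf.zeros((2, 4))]
>>> def step(x, states):
...     out = states[0] + x
...     return out, [out]
>>> last_output, outputs, new_states = tf.keras.backend.rnn(
...     step, inputs, initial_states)
>>> tf.keras.backend.int_shape(outputs)
(2, 3, 4)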
4803 """
4804 if not tf.__internal__.tf2.enabled():
4805 return_all_outputs = True # Not supported in TF1.
4807 def swap_batch_timestep(input_t):
4808 # Swap the batch and timestep dim for the incoming tensor.
4809 axes = list(range(len(input_t.shape)))
4810 axes[0], axes[1] = 1, 0
4811 return tf.compat.v1.transpose(input_t, axes)
4813 if not time_major:
4814 inputs = tf.nest.map_structure(swap_batch_timestep, inputs)
4816 flatted_inputs = tf.nest.flatten(inputs)
4817 time_steps = flatted_inputs[0].shape[0]
4818 batch = flatted_inputs[0].shape[1]
4819 time_steps_t = tf.shape(flatted_inputs[0])[0]
4821 for input_ in flatted_inputs:
4822 input_.shape.with_rank_at_least(3)
4824 if mask is not None:
4825 if mask.dtype != tf.bool:
4826 mask = tf.cast(mask, tf.bool)
4827 if len(mask.shape) == 2:
4828 mask = expand_dims(mask)
4829 if not time_major:
4830 mask = swap_batch_timestep(mask)
4832 if constants is None:
4833 constants = []
4835 # tf.where needs its condition tensor to be the same shape as its two
4836 # result tensors, but in our case the condition (mask) tensor is
4837 # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
4838 # So we need to broadcast the mask to match the shape of inputs.
4839 # That's what the tile call does, it just repeats the mask along its
4840 # second dimension n times.
4841 def _expand_mask(mask_t, input_t, fixed_dim=1):
4842 if tf.nest.is_nested(mask_t):
4843 raise ValueError(
4844 f"mask_t is expected to be tensor, but got {mask_t}"
4845 )
4846 if tf.nest.is_nested(input_t):
4847 raise ValueError(
4848 f"input_t is expected to be tensor, but got {input_t}"
4849 )
4850 rank_diff = len(input_t.shape) - len(mask_t.shape)
4851 for _ in range(rank_diff):
4852 mask_t = tf.expand_dims(mask_t, -1)
4853 multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
4854 return tf.tile(mask_t, multiples)
4856 if unroll:
4857 if not time_steps:
4858 raise ValueError("Unrolling requires a fixed number of timesteps.")
4859 states = tuple(initial_states)
4860 successive_states = []
4861 successive_outputs = []
4863 # Process the input tensors. The input tensors need to be split on the
4864 # time_step dim, and reversed if go_backwards is True. In the case of
4865 # nested input, the input is flattened first and then each tensor is
4866 # transformed individually. The result is a tuple of lists; each item in
4867 # the tuple is a list of tensors with shape (batch, feature).
4868 def _process_single_input_t(input_t):
4869 input_t = tf.unstack(input_t) # unstack for time_step dim
4870 if go_backwards:
4871 input_t.reverse()
4872 return input_t
4874 if tf.nest.is_nested(inputs):
4875 processed_input = tf.nest.map_structure(
4876 _process_single_input_t, inputs
4877 )
4878 else:
4879 processed_input = (_process_single_input_t(inputs),)
4881 def _get_input_tensor(time):
4882 inp = [t_[time] for t_ in processed_input]
4883 return tf.nest.pack_sequence_as(inputs, inp)
4885 if mask is not None:
4886 mask_list = tf.unstack(mask)
4887 if go_backwards:
4888 mask_list.reverse()
4890 for i in range(time_steps):
4891 inp = _get_input_tensor(i)
4892 mask_t = mask_list[i]
4893 output, new_states = step_function(
4894 inp, tuple(states) + tuple(constants)
4895 )
4896 tiled_mask_t = _expand_mask(mask_t, output)
4898 if not successive_outputs:
4899 prev_output = zeros_like(output)
4900 else:
4901 prev_output = successive_outputs[-1]
4903 output = tf.where(tiled_mask_t, output, prev_output)
4905 flat_states = tf.nest.flatten(states)
4906 flat_new_states = tf.nest.flatten(new_states)
4907 tiled_mask_t = tuple(
4908 _expand_mask(mask_t, s) for s in flat_states
4909 )
4910 flat_final_states = tuple(
4911 tf.where(m, s, ps)
4912 for m, s, ps in zip(
4913 tiled_mask_t, flat_new_states, flat_states
4914 )
4915 )
4916 states = tf.nest.pack_sequence_as(states, flat_final_states)
4918 if return_all_outputs:
4919 successive_outputs.append(output)
4920 successive_states.append(states)
4921 else:
4922 successive_outputs = [output]
4923 successive_states = [states]
4924 last_output = successive_outputs[-1]
4925 new_states = successive_states[-1]
4926 outputs = tf.stack(successive_outputs)
4928 if zero_output_for_mask:
4929 last_output = tf.where(
4930 _expand_mask(mask_list[-1], last_output),
4931 last_output,
4932 zeros_like(last_output),
4933 )
4934 outputs = tf.where(
4935 _expand_mask(mask, outputs, fixed_dim=2),
4936 outputs,
4937 zeros_like(outputs),
4938 )
4940 else: # mask is None
4941 for i in range(time_steps):
4942 inp = _get_input_tensor(i)
4943 output, states = step_function(
4944 inp, tuple(states) + tuple(constants)
4945 )
4946 if return_all_outputs:
4947 successive_outputs.append(output)
4948 successive_states.append(states)
4949 else:
4950 successive_outputs = [output]
4951 successive_states = [states]
4952 last_output = successive_outputs[-1]
4953 new_states = successive_states[-1]
4954 outputs = tf.stack(successive_outputs)
4956 else: # Unroll == False
4957 states = tuple(initial_states)
4959 # Create the input tensor array. If the input is a nested structure of
4960 # tensors, it is flattened first and one tensor array is created per
4961 # flattened tensor.
4962 input_ta = tuple(
4963 tf.TensorArray(
4964 dtype=inp.dtype,
4965 size=time_steps_t,
4966 tensor_array_name=f"input_ta_{i}",
4967 )
4968 for i, inp in enumerate(flatted_inputs)
4969 )
4970 input_ta = tuple(
4971 ta.unstack(input_)
4972 if not go_backwards
4973 else ta.unstack(reverse(input_, 0))
4974 for ta, input_ in zip(input_ta, flatted_inputs)
4975 )
4977 # Get the time(0) input and compute the output for that; the output is
4978 # used to determine the dtype of the output tensor array. Don't read from
4979 # input_ta because TensorArray's clear_after_read defaults to True.
4980 input_time_zero = tf.nest.pack_sequence_as(
4981 inputs, [inp[0] for inp in flatted_inputs]
4982 )
4983 # output_time_zero is used to determine the cell output shape and its
4984 # dtype. The value is discarded.
4985 output_time_zero, _ = step_function(
4986 input_time_zero, tuple(initial_states) + tuple(constants)
4987 )
4989 output_ta_size = time_steps_t if return_all_outputs else 1
4990 output_ta = tuple(
4991 tf.TensorArray(
4992 dtype=out.dtype,
4993 size=output_ta_size,
4994 element_shape=out.shape,
4995 tensor_array_name=f"output_ta_{i}",
4996 )
4997 for i, out in enumerate(tf.nest.flatten(output_time_zero))
4998 )
5000 time = tf.constant(0, dtype="int32", name="time")
5002 # We only specify 'maximum_iterations' when building for XLA, since
5003 # specifying it in other cases causes slowdowns on GPU in TF.
5004 if (
5005 not tf.executing_eagerly()
5006 and control_flow_util.GraphOrParentsInXlaContext(
5007 tf.compat.v1.get_default_graph()
5008 )
5009 ):
5010 if input_length is None:
5011 max_iterations = time_steps_t
5012 else:
5013 max_iterations = tf.reduce_max(input_length)
5014 else:
5015 max_iterations = None
5017 while_loop_kwargs = {
5018 "cond": lambda time, *_: time < time_steps_t,
5019 "maximum_iterations": max_iterations,
5020 "parallel_iterations": 32,
5021 "swap_memory": True,
5022 }
5023 if mask is not None:
5024 if go_backwards:
5025 mask = reverse(mask, 0)
5027 mask_ta = tf.TensorArray(
5028 dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
5029 )
5030 mask_ta = mask_ta.unstack(mask)
5032 def masking_fn(time):
5033 return mask_ta.read(time)
5035 def compute_masked_output(mask_t, flat_out, flat_mask):
5036 tiled_mask_t = tuple(
5037 _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
5038 for o in flat_out
5039 )
5040 return tuple(
5041 tf.where(m, o, fm)
5042 for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
5043 )
5045 elif isinstance(input_length, tf.Tensor):
5046 if go_backwards:
5047 max_len = tf.reduce_max(input_length, axis=0)
5048 rev_input_length = tf.subtract(max_len - 1, input_length)
5050 def masking_fn(time):
5051 return tf.less(rev_input_length, time)
5053 else:
5055 def masking_fn(time):
5056 return tf.greater(input_length, time)
5058 def compute_masked_output(mask_t, flat_out, flat_mask):
5059 return tuple(
5060 tf.compat.v1.where(mask_t, o, zo)
5061 for (o, zo) in zip(flat_out, flat_mask)
5062 )
5064 else:
5065 masking_fn = None
5067 if masking_fn is not None:
5068 # The mask for the output at time T is based on the output at time T - 1.
5069 # In the case T = 0, a zero-filled tensor is used.
5070 flat_zero_output = tuple(
5071 tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)
5072 )
5074 def _step(time, output_ta_t, prev_output, *states):
5075 """RNN step function.
5077 Args:
5078 time: Current timestep value.
5079 output_ta_t: TensorArray.
5080 prev_output: tuple of outputs from time - 1.
5081 *states: List of states.
5083 Returns:
5084 Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
5085 """
5086 current_input = tuple(ta.read(time) for ta in input_ta)
5087 # maybe set shape.
5088 current_input = tf.nest.pack_sequence_as(inputs, current_input)
5089 mask_t = masking_fn(time)
5090 output, new_states = step_function(
5091 current_input, tuple(states) + tuple(constants)
5092 )
5093 # mask output
5094 flat_output = tf.nest.flatten(output)
5095 flat_mask_output = (
5096 flat_zero_output
5097 if zero_output_for_mask
5098 else tf.nest.flatten(prev_output)
5099 )
5100 flat_new_output = compute_masked_output(
5101 mask_t, flat_output, flat_mask_output
5102 )
5104 # mask states
5105 flat_state = tf.nest.flatten(states)
5106 flat_new_state = tf.nest.flatten(new_states)
5107 for state, new_state in zip(flat_state, flat_new_state):
5108 if isinstance(new_state, tf.Tensor):
5109 new_state.set_shape(state.shape)
5110 flat_final_state = compute_masked_output(
5111 mask_t, flat_new_state, flat_state
5112 )
5113 new_states = tf.nest.pack_sequence_as(
5114 new_states, flat_final_state
5115 )
5117 ta_index_to_write = time if return_all_outputs else 0
5118 output_ta_t = tuple(
5119 ta.write(ta_index_to_write, out)
5120 for ta, out in zip(output_ta_t, flat_new_output)
5121 )
5123 return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
5124 new_states
5125 )
5127 final_outputs = tf.compat.v1.while_loop(
5128 body=_step,
5129 loop_vars=(time, output_ta, flat_zero_output) + states,
5130 **while_loop_kwargs,
5131 )
5132 # Skip final_outputs[2] which is the output for final timestep.
5133 new_states = final_outputs[3:]
5134 else:
5136 def _step(time, output_ta_t, *states):
5137 """RNN step function.
5139 Args:
5140 time: Current timestep value.
5141 output_ta_t: TensorArray.
5142 *states: List of states.
5144 Returns:
5145 Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
5146 """
5147 current_input = tuple(ta.read(time) for ta in input_ta)
5148 current_input = tf.nest.pack_sequence_as(inputs, current_input)
5149 output, new_states = step_function(
5150 current_input, tuple(states) + tuple(constants)
5151 )
5152 flat_state = tf.nest.flatten(states)
5153 flat_new_state = tf.nest.flatten(new_states)
5154 for state, new_state in zip(flat_state, flat_new_state):
5155 if isinstance(new_state, tf.Tensor):
5156 new_state.set_shape(state.shape)
5158 flat_output = tf.nest.flatten(output)
5159 ta_index_to_write = time if return_all_outputs else 0
5160 output_ta_t = tuple(
5161 ta.write(ta_index_to_write, out)
5162 for ta, out in zip(output_ta_t, flat_output)
5163 )
5165 new_states = tf.nest.pack_sequence_as(
5166 initial_states, flat_new_state
5167 )
5168 return (time + 1, output_ta_t) + tuple(new_states)
5170 final_outputs = tf.compat.v1.while_loop(
5171 body=_step,
5172 loop_vars=(time, output_ta) + states,
5173 **while_loop_kwargs,
5174 )
5175 new_states = final_outputs[2:]
5177 output_ta = final_outputs[1]
5179 outputs = tuple(o.stack() for o in output_ta)
5180 last_output = tuple(o[-1] for o in outputs)
5182 outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)
5183 last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)
5185 # static shape inference
5186 def set_shape(output_):
5187 if isinstance(output_, tf.Tensor):
5188 shape = output_.shape.as_list()
5189 if return_all_outputs:
5190 shape[0] = time_steps
5191 else:
5192 shape[0] = 1
5193 shape[1] = batch
5194 output_.set_shape(shape)
5195 return output_
5197 outputs = tf.nest.map_structure(set_shape, outputs)
5199 if not time_major:
5200 outputs = tf.nest.map_structure(swap_batch_timestep, outputs)
5202 return last_output, outputs, new_states
5205@keras_export("keras.backend.switch")
5206@tf.__internal__.dispatch.add_dispatch_support
5207@doc_controls.do_not_generate_docs
5208def switch(condition, then_expression, else_expression):
5209 """Switches between two operations depending on a scalar value.
5211 Note that both `then_expression` and `else_expression`
5212 should be symbolic tensors of the *same shape*.
5214 Args:
5215 condition: tensor (`int` or `bool`).
5216 then_expression: either a tensor, or a callable that returns a tensor.
5217 else_expression: either a tensor, or a callable that returns a tensor.
5219 Returns:
5220 The selected tensor.
5222 Raises:
5223 ValueError: If rank of `condition` is greater than rank of expressions.
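Example (illustrative; the output shown assumes eager execution):
>>> cond = tf.constant(True)
>>> tf.keras.backend.switch(cond, lambda: tf.constant(1), lambda: tf.constant(0))
<tf.Tensor: shape=(), dtype=int32, numpy=1>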
5224 """
5225 if condition.dtype != tf.bool:
5226 condition = tf.cast(condition, "bool")
5227 cond_ndim = ndim(condition)
5228 if not cond_ndim:
5229 if not callable(then_expression):
5231 def then_expression_fn():
5232 return then_expression
5234 else:
5235 then_expression_fn = then_expression
5236 if not callable(else_expression):
5238 def else_expression_fn():
5239 return else_expression
5241 else:
5242 else_expression_fn = else_expression
5243 x = tf.compat.v1.cond(condition, then_expression_fn, else_expression_fn)
5244 else:
5245 # tf.where needs its condition tensor
5246 # to be the same shape as its two
5247 # result tensors
5248 if callable(then_expression):
5249 then_expression = then_expression()
5250 if callable(else_expression):
5251 else_expression = else_expression()
5252 expr_ndim = ndim(then_expression)
5253 if cond_ndim > expr_ndim:
5254 raise ValueError(
5255 "Rank of `condition` should be less than or"
5256 " equal to rank of `then_expression` and "
5257 "`else_expression`. ndim(condition)="
5258 + str(cond_ndim)
5259 + ", ndim(then_expression)="
5260 + str(expr_ndim)
5261 )
5262 if cond_ndim > 1:
5263 ndim_diff = expr_ndim - cond_ndim
5264 cond_shape = tf.concat(
5265 [tf.shape(condition), [1] * ndim_diff], axis=0
5266 )
5267 condition = tf.reshape(condition, cond_shape)
5268 expr_shape = tf.shape(then_expression)
5269 shape_diff = expr_shape - cond_shape
5270 tile_shape = tf.where(
5271 shape_diff > 0, expr_shape, tf.ones_like(expr_shape)
5272 )
5273 condition = tf.tile(condition, tile_shape)
5274 x = tf.where(condition, then_expression, else_expression)
5275 return x
5278@keras_export("keras.backend.in_train_phase")
5279@doc_controls.do_not_generate_docs
5280def in_train_phase(x, alt, training=None):
5281 """Selects `x` in train phase, and `alt` otherwise.
5283 Note that `alt` should have the *same shape* as `x`.
5285 Args:
5286 x: What to return in train phase
5287 (tensor or callable that returns a tensor).
5288 alt: What to return otherwise
5289 (tensor or callable that returns a tensor).
5290 training: Optional scalar tensor
5291 (or Python boolean, or Python integer)
5292 specifying the learning phase.
5294 Returns:
5295 Either `x` or `alt` based on the `training` flag.
5296 The `training` flag defaults to `K.learning_phase()`.
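Example (illustrative; the output shown assumes eager execution):
>>> train_val = tf.constant(1.)
>>> test_val = tf.constant(0.)
>>> tf.keras.backend.in_train_phase(train_val, test_val, training=True)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>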
5297 """
5298 from keras.src.engine import (
5299 base_layer_utils,
5300 )
5302 if training is None:
5303 training = base_layer_utils.call_context().training
5305 if training is None:
5306 training = learning_phase()
5308 # TODO(b/138862903): Handle the case when training is tensor.
5309 if not tf.is_tensor(training):
5310 if training == 1 or training is True:
5311 if callable(x):
5312 return x()
5313 else:
5314 return x
5316 elif training == 0 or training is False:
5317 if callable(alt):
5318 return alt()
5319 else:
5320 return alt
5322 # else: assume learning phase is a placeholder tensor.
5323 x = switch(training, x, alt)
5324 return x
5327@keras_export("keras.backend.in_test_phase")
5328@doc_controls.do_not_generate_docs
5329def in_test_phase(x, alt, training=None):
5330 """Selects `x` in test phase, and `alt` otherwise.
5332 Note that `alt` should have the *same shape* as `x`.
5334 Args:
5335 x: What to return in test phase
5336 (tensor or callable that returns a tensor).
5337 alt: What to return otherwise
5338 (tensor or callable that returns a tensor).
5339 training: Optional scalar tensor
5340 (or Python boolean, or Python integer)
5341 specifying the learning phase.
5343 Returns:
5344 Either `x` or `alt` based on `K.learning_phase`.
5345 """
5346 return in_train_phase(alt, x, training=training)
5349# NN OPERATIONS
5352@keras_export("keras.backend.relu")
5353@tf.__internal__.dispatch.add_dispatch_support
5354@doc_controls.do_not_generate_docs
5355def relu(x, alpha=0.0, max_value=None, threshold=0.0):
5356 """Rectified linear unit.
5358 With default values, it returns element-wise `max(x, 0)`.
5360 Otherwise, it follows:
5361 `f(x) = max_value` for `x >= max_value`,
5362 `f(x) = x` for `threshold <= x < max_value`,
5363 `f(x) = alpha * (x - threshold)` otherwise.
5365 Args:
5366 x: A tensor or variable.
5367 alpha: A scalar, slope of negative section (default=`0.`).
5368 max_value: float. Saturation threshold.
5369 threshold: float. Threshold value for thresholded activation.
5371 Returns:
5372 A tensor.
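Example (illustrative, using the default arguments; the output shown assumes
eager execution):
>>> x = tf.constant([-10., -5., 0., 5., 10.])
>>> tf.keras.backend.relu(x)
<tf.Tensor: shape=(5,), dtype=float32,
numpy=array([ 0.,  0.,  0.,  5., 10.], dtype=float32)>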
5373 """
5374 # While x can be a tensor or variable, we also see cases where numpy
5375 # arrays, lists, and tuples are passed.
5376 # Lists and tuples do not have a 'dtype' attribute.
5377 dtype = getattr(x, "dtype", floatx())
5378 if alpha != 0.0:
5379 if max_value is None and threshold == 0:
5380 return tf.nn.leaky_relu(x, alpha=alpha)
5382 if threshold != 0:
5383 negative_part = tf.nn.relu(-x + threshold)
5384 else:
5385 negative_part = tf.nn.relu(-x)
5387 clip_max = max_value is not None
5389 if threshold != 0:
5390 # computes x for x > threshold else 0
5391 x = x * tf.cast(tf.greater(x, threshold), dtype=dtype)
5392 elif max_value == 6:
5393 # if no threshold, then can use nn.relu6 native TF op for performance
5394 x = tf.nn.relu6(x)
5395 clip_max = False
5396 else:
5397 x = tf.nn.relu(x)
5399 if clip_max:
5400 max_value = _constant_to_tensor(max_value, x.dtype.base_dtype)
5401 zero = _constant_to_tensor(0, x.dtype.base_dtype)
5402 x = tf.clip_by_value(x, zero, max_value)
5404 if alpha != 0.0:
5405 alpha = _to_tensor(alpha, x.dtype.base_dtype)
5406 x -= alpha * negative_part
5407 return x
5410@keras_export("keras.backend.elu")
5411@tf.__internal__.dispatch.add_dispatch_support
5412@doc_controls.do_not_generate_docs
5413def elu(x, alpha=1.0):
5414 """Exponential linear unit.
5416 Args:
5417 x: A tensor or variable to compute the activation function for.
5418 alpha: A scalar, slope of negative section.
5420 Returns:
5421 A tensor.
5422 """
5423 res = tf.nn.elu(x)
5424 if alpha == 1:
5425 return res
5426 else:
5427 return tf.where(x > 0, res, alpha * res)
5430@keras_export("keras.backend.softmax")
5431@tf.__internal__.dispatch.add_dispatch_support
5432@doc_controls.do_not_generate_docs
5433def softmax(x, axis=-1):
5434 """Softmax of a tensor.
5436 Args:
5437 x: A tensor or variable.
5438 axis: The dimension softmax would be performed on.
5439 The default is -1 which indicates the last dimension.
5441 Returns:
5442 A tensor.
5443 """
5444 return tf.nn.softmax(x, axis=axis)
5447@keras_export("keras.backend.softplus")
5448@tf.__internal__.dispatch.add_dispatch_support
5449@doc_controls.do_not_generate_docs
5450def softplus(x):
5451 """Softplus of a tensor.
5453 Args:
5454 x: A tensor or variable.
5456 Returns:
5457 A tensor.
5458 """
5459 return tf.math.softplus(x)
5462@keras_export("keras.backend.softsign")
5463@tf.__internal__.dispatch.add_dispatch_support
5464@doc_controls.do_not_generate_docs
5465def softsign(x):
5466 """Softsign of a tensor.
5468 Args:
5469 x: A tensor or variable.
5471 Returns:
5472 A tensor.
5473 """
5474 return tf.math.softsign(x)
5477def _get_logits(output, from_logits, op_type, fn_name):
5478 output_ = output
5479 from_logits_ = from_logits
5481 has_keras_logits = hasattr(output, "_keras_logits")
5482 if has_keras_logits:
5483 output_ = output._keras_logits
5484 from_logits_ = True
5486 from_expected_op_type = (
5487 not isinstance(output, (tf.__internal__.EagerTensor, tf.Variable))
5488 and output.op.type == op_type
5489 ) and not has_keras_logits
5491 if from_expected_op_type:
5492 # When the softmax activation function is used for the output operation,
5493 # we use the logits from the softmax function directly to compute the loss
5494 # in order to prevent collapsing to zero when training.
5495 # See b/117284466
5496 assert len(output.op.inputs) == 1
5497 output_ = output.op.inputs[0]
5498 from_logits_ = True
5500 if from_logits and (has_keras_logits or from_expected_op_type):
5501 warnings.warn(
5502 f'"`{fn_name}` received `from_logits=True`, but '
5503 f"the `output` argument was produced by a {op_type} "
5504 "activation and thus does not represent logits. "
5505 "Was this intended?",
5506 stacklevel=2,
5507 )
5509 return output_, from_logits_
5512@keras_export("keras.backend.categorical_crossentropy")
5513@tf.__internal__.dispatch.add_dispatch_support
5514@doc_controls.do_not_generate_docs
5515def categorical_crossentropy(target, output, from_logits=False, axis=-1):
5516 """Categorical crossentropy between an output tensor and a target tensor.
5518 Args:
5519 target: A tensor of the same shape as `output`.
5520 output: A tensor resulting from a softmax
5521 (unless `from_logits` is True, in which
5522 case `output` is expected to be the logits).
5523 from_logits: Boolean, whether `output` is the
5524 result of a softmax, or is a tensor of logits.
5525 axis: Int specifying the channels axis. `axis=-1` corresponds to data
5526 format `channels_last`, and `axis=1` corresponds to data format
5527 `channels_first`.
5529 Returns:
5530 Output tensor.
5532 Raises:
5533 ValueError: if `axis` is neither -1 nor one of the axes of `output`.
5535 Example:
5537 >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3])
5538 >>> print(a)
5539 tf.Tensor(
5540 [[1. 0. 0.]
5541 [0. 1. 0.]
5542 [0. 0. 1.]], shape=(3, 3), dtype=float32)
5543 >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94],
5544 ... shape=[3, 3])
5545 >>> print(b)
5546 tf.Tensor(
5547 [[0.9 0.05 0.05]
5548 [0.05 0.89 0.06]
5549 [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
5550 >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
5551 >>> print(np.around(loss, 5))
5552 [0.10536 0.11653 0.06188]
5553 >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
5554 >>> print(np.around(loss, 5))
5555 [0. 0. 0.]
5557 """
5558 target = tf.convert_to_tensor(target)
5559 output = tf.convert_to_tensor(output)
5560 target.shape.assert_is_compatible_with(output.shape)
5562 output, from_logits = _get_logits(
5563 output, from_logits, "Softmax", "categorical_crossentropy"
5564 )
5565 if from_logits:
5566 return tf.nn.softmax_cross_entropy_with_logits(
5567 labels=target, logits=output, axis=axis
5568 )
5570 # Adjust the predictions so that the probability of
5571 # each class for every sample adds up to 1
5572 # This is needed to ensure that the cross entropy is
5573 # computed correctly.
5574 output = output / tf.reduce_sum(output, axis, True)
5576 # Compute cross entropy from probabilities.
5577 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5578 output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
5579 return -tf.reduce_sum(target * tf.math.log(output), axis)
5582@keras_export("keras.backend.categorical_focal_crossentropy")
5583@tf.__internal__.dispatch.add_dispatch_support
5584@doc_controls.do_not_generate_docs
5585def categorical_focal_crossentropy(
5586 target,
5587 output,
5588 alpha=0.25,
5589 gamma=2.0,
5590 from_logits=False,
5591 axis=-1,
5592):
5593 """Computes the alpha balanced focal crossentropy loss.
5595 According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
5596 helps to apply a focal factor to down-weight easy examples and focus more on
5597 hard examples. The general formula for the focal loss (FL)
5598 is as follows:
5600 `FL(p_t) = −(1 − p_t)^gamma * log(p_t)`
5602 where `p_t` is defined as follows:
5603 `p_t = output if target == 1, else 1 - output`
5605 `(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing
5606 parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
5607 `gamma` reduces the importance given to simple examples in a smooth manner.
5609 The authors use the alpha-balanced variant of focal loss (FL) in the paper:
5610 `FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)`
5612 where `alpha` is the weight factor for the classes. If `alpha` = 1, the
5613 loss won't be able to handle class imbalance properly as all
5614 classes will have the same weight. This can be a constant or a list of
5615 constants. If alpha is a list, it must have the same length as the number
5616 of classes.
5618 The formula above can be generalized to:
5619 `FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(target, output)`
5621 where the minus sign comes from `CrossEntropy(target, output)` (CE).
5623 Extending this to the multi-class case is straightforward:
5624 `FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(target, output)`
5626 Args:
5627 target: Ground truth values from the dataset.
5628 output: Predictions of the model.
5629 alpha: A weight balancing factor for all classes, default is `0.25` as
5630 mentioned in the reference. It can be a list of floats or a scalar.
5631 In the multi-class case, alpha may be set by inverse class
5632 frequency by using `compute_class_weight` from `sklearn.utils`.
5633 gamma: A focusing parameter, default is `2.0` as mentioned in the
5634 reference. It helps to gradually reduce the importance given to
5635 simple examples in a smooth manner.
5636 from_logits: Whether `output` is expected to be a logits tensor. By
5637 default, we consider that `output` encodes a probability
5638 distribution.
5639 axis: Int specifying the channels axis. `axis=-1` corresponds to data
5640 format `channels_last`, and `axis=1` corresponds to data format
5641 `channels_first`.
5643 Returns:
5644 A tensor.
5645 """
5646 target = tf.convert_to_tensor(target)
5647 output = tf.convert_to_tensor(output)
5648 target.shape.assert_is_compatible_with(output.shape)
5650 output, from_logits = _get_logits(
5651 output, from_logits, "Softmax", "categorical_focal_crossentropy"
5652 )
5654 if from_logits:
5655 output = tf.nn.softmax(output, axis=axis)
5657 # Adjust the predictions so that the probability of
5658 # each class for every sample adds up to 1
5659 # This is needed to ensure that the cross entropy is
5660 # computed correctly.
5661 output = output / tf.reduce_sum(output, axis=axis, keepdims=True)
5663 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5664 output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
5666 # Calculate cross entropy
5667 cce = -target * tf.math.log(output)
5669 # Calculate factors
5670 modulating_factor = tf.pow(1.0 - output, gamma)
5671 weighting_factor = tf.multiply(modulating_factor, alpha)
5673 # Apply weighting factor
5674 focal_cce = tf.multiply(weighting_factor, cce)
5675 focal_cce = tf.reduce_sum(focal_cce, axis=axis)
5676 return focal_cce
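# Editor's note: the sketch below is not part of the original module; it is a
# minimal eager-mode usage example of `categorical_focal_crossentropy` with
# made-up probability inputs, relying on the module-level `tf` import.
def _example_categorical_focal_crossentropy():
    # One confident, correct sample and one less confident sample.
    y_true = tf.constant([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    y_pred = tf.constant([[0.90, 0.05, 0.05], [0.20, 0.70, 0.10]])
    loss = categorical_focal_crossentropy(y_true, y_pred, alpha=0.25, gamma=2.0)
    # `loss` has shape (2,); easy samples (high p_t) are down-weighted by the
    # `(1 - p_t) ** gamma` modulating factor relative to plain crossentropy.
    return loss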
5679@keras_export("keras.backend.sparse_categorical_crossentropy")
5680@tf.__internal__.dispatch.add_dispatch_support
5681@doc_controls.do_not_generate_docs
5682def sparse_categorical_crossentropy(
5683 target, output, from_logits=False, axis=-1, ignore_class=None
5684):
5685 """Categorical crossentropy with integer targets.
5687 Args:
5688 target: An integer tensor.
5689 output: A tensor resulting from a softmax
5690 (unless `from_logits` is True, in which
5691 case `output` is expected to be the logits).
5692 from_logits: Boolean, whether `output` is the
5693 result of a softmax, or is a tensor of logits.
5694 axis: Int specifying the channels axis. `axis=-1` corresponds to data
5695 format `channels_last`, and `axis=1` corresponds to data format
5696 `channels_first`.
5697 ignore_class: Optional integer. The ID of a class to be ignored
5698 during loss computation. This is useful, for example, in
5699 segmentation problems featuring a "void" class (commonly -1
5700 or 255) in segmentation maps.
5701 By default (`ignore_class=None`), all classes are considered.
5703 Returns:
5704 Output tensor.
5706 Raises:
5707 ValueError: if `axis` is neither -1 nor one of the axes of `output`.
5708 """
5709 target = tf.convert_to_tensor(target)
5710 output = tf.convert_to_tensor(output)
5712 target = cast(target, "int64")
5714 output, from_logits = _get_logits(
5715 output, from_logits, "Softmax", "sparse_categorical_crossentropy"
5716 )
5717 if not from_logits:
5718 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5719 output = tf.clip_by_value(output, epsilon_, 1 - epsilon_)
5720 output = tf.math.log(output)
5722 # Permute output so that the last axis contains the logits/probabilities.
5723 if isinstance(output.shape, (tuple, list)):
5724 output_rank = len(output.shape)
5725 else:
5726 output_rank = output.shape.ndims
5727 if output_rank is not None:
5728 axis %= output_rank
5729 if axis != output_rank - 1:
5730 permutation = list(
5731 itertools.chain(
5732 range(axis), range(axis + 1, output_rank), [axis]
5733 )
5734 )
5735 output = tf.compat.v1.transpose(output, perm=permutation)
5736 elif axis != -1:
5737 raise ValueError(
5738 "Cannot compute sparse categorical crossentropy with `axis={}` "
5739 "on an output tensor with unknown rank".format(axis)
5740 )
5742 # Try to adjust the shape so that rank of labels = rank of logits - 1.
5743 output_shape = tf.shape(output)
5744 target_rank = target.shape.ndims
5746 update_shape = (
5747 target_rank is not None
5748 and output_rank is not None
5749 and target_rank != output_rank - 1
5750 )
5751 if update_shape:
5752 target = flatten(target)
5753 output = tf.reshape(output, [-1, output_shape[-1]])
5755 if ignore_class is not None:
5756 valid_mask = tf.not_equal(target, cast(ignore_class, target.dtype))
5757 target = target[valid_mask]
5758 output = output[valid_mask]
5760 if py_any(_is_symbolic_tensor(v) for v in [target, output]):
5761 with get_graph().as_default():
5762 res = tf.nn.sparse_softmax_cross_entropy_with_logits(
5763 labels=target, logits=output
5764 )
5765 else:
5766 res = tf.nn.sparse_softmax_cross_entropy_with_logits(
5767 labels=target, logits=output
5768 )
5770 if ignore_class is not None:
5771 res_shape = cast(output_shape[:-1], "int64")
5772 valid_mask = tf.reshape(valid_mask, res_shape)
5773 res = tf.scatter_nd(tf.where(valid_mask), res, res_shape)
5774 res._keras_mask = valid_mask
5776 return res
5778 if update_shape and output_rank >= 3:
5779 # If our output includes timesteps or
5780 # spatial dimensions we need to reshape
5781 res = tf.reshape(res, output_shape[:-1])
5783 return res
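# Editor's note: an illustrative sketch (not in the original source) showing
# `sparse_categorical_crossentropy` with integer targets; values are made up.
def _example_sparse_categorical_crossentropy():
    y_true = tf.constant([0, 2])  # integer class ids, one per sample
    y_pred = tf.constant([[0.90, 0.05, 0.05], [0.10, 0.10, 0.80]])
    loss = sparse_categorical_crossentropy(y_true, y_pred)
    # `loss` has shape (2,) and matches categorical_crossentropy computed on
    # the equivalent one-hot targets.
    return loss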
5786@keras_export("keras.backend.binary_crossentropy")
5787@tf.__internal__.dispatch.add_dispatch_support
5788@doc_controls.do_not_generate_docs
5789def binary_crossentropy(target, output, from_logits=False):
5790 """Binary crossentropy between an output tensor and a target tensor.
5792 Args:
5793 target: A tensor with the same shape as `output`.
5794 output: A tensor.
5795 from_logits: Whether `output` is expected to be a logits tensor.
5796 By default, we consider that `output`
5797 encodes a probability distribution.
5799 Returns:
5800 A tensor.
5801 """
5802 target = tf.convert_to_tensor(target)
5803 output = tf.convert_to_tensor(output)
5805 output, from_logits = _get_logits(
5806 output, from_logits, "Sigmoid", "binary_crossentropy"
5807 )
5808 if from_logits:
5809 return tf.nn.sigmoid_cross_entropy_with_logits(
5810 labels=target, logits=output
5811 )
5813 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5814 output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
5816 # Compute cross entropy from probabilities.
5817 bce = target * tf.math.log(output + epsilon())
5818 bce += (1 - target) * tf.math.log(1 - output + epsilon())
5819 return -bce
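# Editor's note: illustrative sketch, not part of the original module. It shows
# the element-wise behaviour of `binary_crossentropy` on probability inputs.
def _example_binary_crossentropy():
    y_true = tf.constant([[1.0], [0.0]])
    y_pred = tf.constant([[0.8], [0.3]])
    loss = binary_crossentropy(y_true, y_pred)
    # Element-wise -[y*log(p) + (1-y)*log(1-p)]: roughly [[0.223], [0.357]]
    # (up to the epsilon() fuzz factor used for numerical stability).
    return loss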
5822@keras_export("keras.backend.binary_focal_crossentropy")
5823@tf.__internal__.dispatch.add_dispatch_support
5824@doc_controls.do_not_generate_docs
5825def binary_focal_crossentropy(
5826 target,
5827 output,
5828 apply_class_balancing=False,
5829 alpha=0.25,
5830 gamma=2.0,
5831 from_logits=False,
5832):
5833 """Binary focal crossentropy between an output tensor and a target tensor.
5835 According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
5836 helps to apply a focal factor to down-weight easy examples and focus more on
5837 hard examples. By default, the focal tensor is computed as follows:
5839 `focal_factor = (1 - output) ** gamma` for class 1
5840 `focal_factor = output ** gamma` for class 0
5841 where `gamma` is a focusing parameter. When `gamma` = 0, there is no focal
5842 effect on the binary crossentropy.
5844 If `apply_class_balancing == True`, this function also takes into account a
5845 weight balancing factor for the binary classes 0 and 1 as follows:
5847 `weight = alpha` for class 1 (`target == 1`)
5848 `weight = 1 - alpha` for class 0
5849 where `alpha` is a float in the range of `[0, 1]`.
5851 Args:
5852 target: A tensor with the same shape as `output`.
5853 output: A tensor.
5854 apply_class_balancing: A bool, whether to apply weight balancing on the
5855 binary classes 0 and 1.
5856 alpha: A weight balancing factor for class 1, default is `0.25` as
5857 mentioned in the reference. The weight for class 0 is `1.0 - alpha`.
5858 gamma: A focusing parameter, default is `2.0` as mentioned in the
5859 reference.
5860 from_logits: Whether `output` is expected to be a logits tensor. By
5861 default, we consider that `output` encodes a probability
5862 distribution.
5864 Returns:
5865 A tensor.
5866 """
5868 sigmoidal = sigmoid(output) if from_logits else output
5870 p_t = target * sigmoidal + (1 - target) * (1 - sigmoidal)
5872 # Calculate focal factor
5873 focal_factor = tf.pow(1.0 - p_t, gamma)
5875 # Binary crossentropy
5876 bce = binary_crossentropy(
5877 target=target,
5878 output=output,
5879 from_logits=from_logits,
5880 )
5881 focal_bce = focal_factor * bce
5883 if apply_class_balancing:
5884 weight = target * alpha + (1 - target) * (1 - alpha)
5885 focal_bce = weight * focal_bce
5887 return focal_bce
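# Editor's note: illustrative sketch (not in the original source) of
# `binary_focal_crossentropy` with class balancing enabled; inputs are made up.
def _example_binary_focal_crossentropy():
    y_true = tf.constant([[1.0], [0.0]])
    y_pred = tf.constant([[0.8], [0.3]])
    loss = binary_focal_crossentropy(
        y_true, y_pred, apply_class_balancing=True, alpha=0.25, gamma=2.0
    )
    # Each BCE term is scaled by `(1 - p_t) ** gamma`, and additionally by
    # `alpha` for class 1 and `1 - alpha` for class 0.
    return loss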
5890@keras_export("keras.backend.sigmoid")
5891@tf.__internal__.dispatch.add_dispatch_support
5892@doc_controls.do_not_generate_docs
5893def sigmoid(x):
5894 """Element-wise sigmoid.
5896 Args:
5897 x: A tensor or variable.
5899 Returns:
5900 A tensor.
5901 """
5902 return tf.math.sigmoid(x)
5905@keras_export("keras.backend.hard_sigmoid")
5906@tf.__internal__.dispatch.add_dispatch_support
5907@doc_controls.do_not_generate_docs
5908def hard_sigmoid(x):
5909 """Segment-wise linear approximation of sigmoid.
5911 Faster than sigmoid.
5912 Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
5913 In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
5915 Args:
5916 x: A tensor or variable.
5918 Returns:
5919 A tensor.
5920 """
5921 point_two = _constant_to_tensor(0.2, x.dtype.base_dtype)
5922 point_five = _constant_to_tensor(0.5, x.dtype.base_dtype)
5923 x = tf.multiply(x, point_two)
5924 x = tf.add(x, point_five)
5925 x = tf.clip_by_value(x, 0.0, 1.0)
5926 return x
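# Editor's note: a small sketch (not in the original source) spelling out the
# piecewise-linear values produced by `hard_sigmoid`.
def _example_hard_sigmoid():
    x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0])
    # 0.2 * x + 0.5, clipped to [0, 1]: [0.0, 0.3, 0.5, 0.7, 1.0]
    return hard_sigmoid(x)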
5929@keras_export("keras.backend.tanh")
5930@tf.__internal__.dispatch.add_dispatch_support
5931@doc_controls.do_not_generate_docs
5932def tanh(x):
5933 """Element-wise tanh.
5935 Args:
5936 x: A tensor or variable.
5938 Returns:
5939 A tensor.
5940 """
5941 return tf.tanh(x)
5944@keras_export("keras.backend.dropout")
5945@tf.__internal__.dispatch.add_dispatch_support
5946@doc_controls.do_not_generate_docs
5947def dropout(x, level, noise_shape=None, seed=None):
5948 """Sets entries in `x` to zero at random, while scaling the entire tensor.
5950 Args:
5951 x: tensor
5952 level: fraction of the entries in the tensor
5953 that will be set to 0.
5954 noise_shape: shape for randomly generated keep/drop flags,
5955 must be broadcastable to the shape of `x`
5956 seed: random seed to ensure determinism.
5958 Returns:
5959 A tensor.
5960 """
5961 if seed is None:
5962 seed = np.random.randint(10e6)
5963 return tf.nn.dropout(x, rate=level, noise_shape=noise_shape, seed=seed)
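# Editor's note: illustrative sketch, not part of the original module. It shows
# the inverted-dropout scaling performed by `dropout`.
def _example_dropout():
    x = tf.ones((2, 4))
    y = dropout(x, level=0.5, seed=1)
    # Roughly half of the entries are zeroed; the surviving entries are scaled
    # by 1 / (1 - level) = 2.0, so the expected sum of `y` matches that of `x`.
    return y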
5966@keras_export("keras.backend.l2_normalize")
5967@tf.__internal__.dispatch.add_dispatch_support
5968@doc_controls.do_not_generate_docs
5969def l2_normalize(x, axis=None):
5970 """Normalizes a tensor wrt the L2 norm alongside the specified axis.
5972 Args:
5973 x: Tensor or variable.
5974 axis: axis along which to perform normalization.
5976 Returns:
5977 A tensor.
5978 """
5979 return tf.linalg.l2_normalize(x, axis=axis)
5982@keras_export("keras.backend.in_top_k")
5983@tf.__internal__.dispatch.add_dispatch_support
5984@doc_controls.do_not_generate_docs
5985def in_top_k(predictions, targets, k):
5986 """Returns whether the `targets` are in the top `k` `predictions`.
5988 Args:
5989 predictions: A tensor of shape `(batch_size, classes)` and type
5990 `float32`.
5991 targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
5992 k: An `int`, number of top elements to consider.
5994 Returns:
5995 A 1D tensor of length `batch_size` and type `bool`.
5996 `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
5997 values of `predictions[i]`.
5998 """
5999 return tf.compat.v1.math.in_top_k(predictions, targets, k)
6002# CONVOLUTIONS
6005def _preprocess_conv1d_input(x, data_format):
6006 """Transpose and cast the input before the conv1d.
6008 Args:
6009 x: input tensor.
6010 data_format: string, `"channels_last"` or `"channels_first"`.
6012 Returns:
6013 A tensor.
6014 """
6015 tf_data_format = "NWC" # to pass TF Conv2dNative operations
6016 if data_format == "channels_first":
6017 if not _has_nchw_support():
6018 x = tf.compat.v1.transpose(x, (0, 2, 1)) # NCW -> NWC
6019 else:
6020 tf_data_format = "NCW"
6021 return x, tf_data_format
6024def _preprocess_conv2d_input(x, data_format, force_transpose=False):
6025 """Transpose and cast the input before the conv2d.
6027 Args:
6028 x: input tensor.
6029 data_format: string, `"channels_last"` or `"channels_first"`.
6030 force_transpose: Boolean. If True, the input will always be transposed
6031 from NCHW to NHWC if `data_format` is `"channels_first"`.
6032 If False, the transposition only occurs on CPU (GPU ops are
6033 assumed to support NCHW).
6035 Returns:
6036 A tensor.
6037 """
6038 tf_data_format = "NHWC"
6039 if data_format == "channels_first":
6040 if not _has_nchw_support() or force_transpose:
6041 x = tf.compat.v1.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
6042 else:
6043 tf_data_format = "NCHW"
6044 return x, tf_data_format
6047def _preprocess_conv3d_input(x, data_format):
6048 """Transpose and cast the input before the conv3d.
6050 Args:
6051 x: input tensor.
6052 data_format: string, `"channels_last"` or `"channels_first"`.
6054 Returns:
6055 A tensor.
6056 """
6057 tf_data_format = "NDHWC"
6058 if data_format == "channels_first":
6059 if not _has_nchw_support():
6060 x = tf.compat.v1.transpose(x, (0, 2, 3, 4, 1))
6061 else:
6062 tf_data_format = "NCDHW"
6063 return x, tf_data_format
6066def _preprocess_padding(padding):
6067 """Convert keras' padding to TensorFlow's padding.
6069 Args:
6070 padding: string, one of 'same' , 'valid'
6072 Returns:
6073 a string, one of 'SAME', 'VALID'.
6075 Raises:
6076 ValueError: if invalid `padding'`
6077 """
6078 if padding == "same":
6079 padding = "SAME"
6080 elif padding == "valid":
6081 padding = "VALID"
6082 else:
6083 raise ValueError("Invalid padding: " + str(padding))
6084 return padding
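# Editor's note: a tiny sketch (not in the original source) of the mapping
# performed by the private `_preprocess_padding` helper.
def _example_preprocess_padding():
    assert _preprocess_padding("same") == "SAME"
    assert _preprocess_padding("valid") == "VALID"
    # Any other value (e.g. "full") raises ValueError.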
6087@keras_export("keras.backend.conv1d")
6088@tf.__internal__.dispatch.add_dispatch_support
6089@doc_controls.do_not_generate_docs
6090def conv1d(
6091 x, kernel, strides=1, padding="valid", data_format=None, dilation_rate=1
6092):
6093 """1D convolution.
6095 Args:
6096 x: Tensor or variable.
6097 kernel: kernel tensor.
6098 strides: stride integer.
6099 padding: string, `"same"`, `"causal"` or `"valid"`.
6100 data_format: string, one of "channels_last", "channels_first".
6101 dilation_rate: integer dilate rate.
6103 Returns:
6104 A tensor, result of 1D convolution.
6106 Raises:
6107 ValueError: if `data_format` is neither `channels_last` nor
6108 `channels_first`.
6109 """
6110 if data_format is None:
6111 data_format = image_data_format()
6112 if data_format not in {"channels_first", "channels_last"}:
6113 raise ValueError("Unknown data_format: " + str(data_format))
6115 kernel_shape = kernel.shape.as_list()
6116 if padding == "causal":
6117 # causal (dilated) convolution:
6118 left_pad = dilation_rate * (kernel_shape[0] - 1)
6119 x = temporal_padding(x, (left_pad, 0))
6120 padding = "valid"
6121 padding = _preprocess_padding(padding)
6123 x, tf_data_format = _preprocess_conv1d_input(x, data_format)
6124 x = tf.compat.v1.nn.convolution(
6125 input=x,
6126 filter=kernel,
6127 dilation_rate=dilation_rate,
6128 strides=strides,
6129 padding=padding,
6130 data_format=tf_data_format,
6131 )
6132 if data_format == "channels_first" and tf_data_format == "NWC":
6133 x = tf.compat.v1.transpose(x, (0, 2, 1)) # NWC -> NCW
6134 return x
6137@keras_export("keras.backend.conv2d")
6138@tf.__internal__.dispatch.add_dispatch_support
6139@doc_controls.do_not_generate_docs
6140def conv2d(
6141 x,
6142 kernel,
6143 strides=(1, 1),
6144 padding="valid",
6145 data_format=None,
6146 dilation_rate=(1, 1),
6147):
6148 """2D convolution.
6150 Args:
6151 x: Tensor or variable.
6152 kernel: kernel tensor.
6153 strides: strides tuple.
6154 padding: string, `"same"` or `"valid"`.
6155 data_format: `"channels_last"` or `"channels_first"`.
6156 dilation_rate: tuple of 2 integers.
6158 Returns:
6159 A tensor, result of 2D convolution.
6161 Raises:
6162 ValueError: if `data_format` is neither `channels_last` nor
6163 `channels_first`.
6164 """
6165 if data_format is None:
6166 data_format = image_data_format()
6167 if data_format not in {"channels_first", "channels_last"}:
6168 raise ValueError("Unknown data_format: " + str(data_format))
6170 x, tf_data_format = _preprocess_conv2d_input(x, data_format)
6171 padding = _preprocess_padding(padding)
6172 x = tf.compat.v1.nn.convolution(
6173 input=x,
6174 filter=kernel,
6175 dilation_rate=dilation_rate,
6176 strides=strides,
6177 padding=padding,
6178 data_format=tf_data_format,
6179 )
6180 if data_format == "channels_first" and tf_data_format == "NHWC":
6181 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
6182 return x
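# Editor's note: illustrative sketch, not part of the original module. It shows
# the kernel layout and shape bookkeeping for `conv2d` with channels_last data;
# shapes and values are made up.
def _example_conv2d():
    x = tf.random.normal((1, 8, 8, 3))        # (batch, rows, cols, channels)
    kernel = tf.random.normal((3, 3, 3, 16))  # (h, w, in_channels, out_channels)
    y = conv2d(x, kernel, strides=(1, 1), padding="same",
               data_format="channels_last")
    # With "same" padding and unit strides, spatial dims are preserved:
    # y.shape == (1, 8, 8, 16)
    return y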
6185@keras_export("keras.backend.conv2d_transpose")
6186@tf.__internal__.dispatch.add_dispatch_support
6187@doc_controls.do_not_generate_docs
6188def conv2d_transpose(
6189 x,
6190 kernel,
6191 output_shape,
6192 strides=(1, 1),
6193 padding="valid",
6194 data_format=None,
6195 dilation_rate=(1, 1),
6196):
6197 """2D deconvolution (i.e.
6199 transposed convolution).
6201 Args:
6202 x: Tensor or variable.
6203 kernel: kernel tensor.
6204 output_shape: 1D int tensor for the output shape.
6205 strides: strides tuple.
6206 padding: string, `"same"` or `"valid"`.
6207 data_format: string, `"channels_last"` or `"channels_first"`.
6208 dilation_rate: Tuple of 2 integers.
6210 Returns:
6211 A tensor, result of transposed 2D convolution.
6213 Raises:
6214 ValueError: if `data_format` is neither `channels_last` nor
6215 `channels_first`.
6216 """
6217 if data_format is None:
6218 data_format = image_data_format()
6219 if data_format not in {"channels_first", "channels_last"}:
6220 raise ValueError("Unknown data_format: " + str(data_format))
6222 # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
6223 if data_format == "channels_first" and dilation_rate != (1, 1):
6224 force_transpose = True
6225 else:
6226 force_transpose = False
6228 x, tf_data_format = _preprocess_conv2d_input(
6229 x, data_format, force_transpose
6230 )
6232 if data_format == "channels_first" and tf_data_format == "NHWC":
6233 output_shape = (
6234 output_shape[0],
6235 output_shape[2],
6236 output_shape[3],
6237 output_shape[1],
6238 )
6239 if output_shape[0] is None:
6240 output_shape = (shape(x)[0],) + tuple(output_shape[1:])
6242 if isinstance(output_shape, (tuple, list)):
6243 output_shape = tf.stack(list(output_shape))
6245 padding = _preprocess_padding(padding)
6246 if tf_data_format == "NHWC":
6247 strides = (1,) + strides + (1,)
6248 else:
6249 strides = (1, 1) + strides
6251 if dilation_rate == (1, 1):
6252 x = tf.compat.v1.nn.conv2d_transpose(
6253 x,
6254 kernel,
6255 output_shape,
6256 strides,
6257 padding=padding,
6258 data_format=tf_data_format,
6259 )
6260 else:
6261 if dilation_rate[0] != dilation_rate[1]:
6262 raise ValueError(
6263 "Expected the 2 dimensions of the `dilation_rate` argument "
6264 "to be equal to each other. "
6265 f"Received: dilation_rate={dilation_rate}"
6266 )
6267 x = tf.nn.atrous_conv2d_transpose(
6268 x, kernel, output_shape, rate=dilation_rate[0], padding=padding
6269 )
6270 if data_format == "channels_first" and tf_data_format == "NHWC":
6271 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
6272 return x
6275def separable_conv1d(
6276 x,
6277 depthwise_kernel,
6278 pointwise_kernel,
6279 strides=1,
6280 padding="valid",
6281 data_format=None,
6282 dilation_rate=1,
6283):
6284 """1D convolution with separable filters.
6286 Args:
6287 x: input tensor
6288 depthwise_kernel: convolution kernel for the depthwise convolution.
6289 pointwise_kernel: kernel for the 1x1 convolution.
6290 strides: stride integer.
6291 padding: string, `"same"` or `"valid"`.
6292 data_format: string, `"channels_last"` or `"channels_first"`.
6293 dilation_rate: integer dilation rate.
6295 Returns:
6296 Output tensor.
6298 Raises:
6299 ValueError: if `data_format` is neither `channels_last` nor
6300 `channels_first`.
6301 """
6302 if data_format is None:
6303 data_format = image_data_format()
6304 if data_format not in {"channels_first", "channels_last"}:
6305 raise ValueError("Unknown data_format: " + str(data_format))
6307 if isinstance(strides, int):
6308 strides = (strides,)
6309 if isinstance(dilation_rate, int):
6310 dilation_rate = (dilation_rate,)
6312 x, tf_data_format = _preprocess_conv1d_input(x, data_format)
6313 padding = _preprocess_padding(padding)
6314 if not isinstance(strides, tuple):
6315 strides = tuple(strides)
6316 if tf_data_format == "NWC":
6317 spatial_start_dim = 1
6318 strides = (1,) + strides * 2 + (1,)
6319 else:
6320 spatial_start_dim = 2
6321 strides = (1, 1) + strides * 2
6322 x = tf.expand_dims(x, spatial_start_dim)
6323 depthwise_kernel = tf.expand_dims(depthwise_kernel, 0)
6324 pointwise_kernel = tf.expand_dims(pointwise_kernel, 0)
6325 dilation_rate = (1,) + dilation_rate
6327 x = tf.nn.separable_conv2d(
6328 x,
6329 depthwise_kernel,
6330 pointwise_kernel,
6331 strides=strides,
6332 padding=padding,
6333 dilations=dilation_rate,
6334 data_format=tf_data_format,
6335 )
6337 x = tf.squeeze(x, [spatial_start_dim])
6339 if data_format == "channels_first" and tf_data_format == "NWC":
6340 x = tf.compat.v1.transpose(x, (0, 2, 1)) # NWC -> NCW
6342 return x
6345@keras_export("keras.backend.separable_conv2d")
6346@tf.__internal__.dispatch.add_dispatch_support
6347@doc_controls.do_not_generate_docs
6348def separable_conv2d(
6349 x,
6350 depthwise_kernel,
6351 pointwise_kernel,
6352 strides=(1, 1),
6353 padding="valid",
6354 data_format=None,
6355 dilation_rate=(1, 1),
6356):
6357 """2D convolution with separable filters.
6359 Args:
6360 x: input tensor
6361 depthwise_kernel: convolution kernel for the depthwise convolution.
6362 pointwise_kernel: kernel for the 1x1 convolution.
6363 strides: strides tuple (length 2).
6364 padding: string, `"same"` or `"valid"`.
6365 data_format: string, `"channels_last"` or `"channels_first"`.
6366 dilation_rate: tuple of integers,
6367 dilation rates for the separable convolution.
6369 Returns:
6370 Output tensor.
6372 Raises:
6373 ValueError: if `data_format` is neither `channels_last` nor
6374 `channels_first`.
6375 ValueError: if `strides` is not a tuple of 2 integers.
6376 """
6377 if data_format is None:
6378 data_format = image_data_format()
6379 if data_format not in {"channels_first", "channels_last"}:
6380 raise ValueError("Unknown data_format: " + str(data_format))
6381 if len(strides) != 2:
6382 raise ValueError("`strides` must be a tuple of 2 integers.")
6384 x, tf_data_format = _preprocess_conv2d_input(x, data_format)
6385 padding = _preprocess_padding(padding)
6386 if not isinstance(strides, tuple):
6387 strides = tuple(strides)
6388 if tf_data_format == "NHWC":
6389 strides = (1,) + strides + (1,)
6390 else:
6391 strides = (1, 1) + strides
6393 x = tf.nn.separable_conv2d(
6394 x,
6395 depthwise_kernel,
6396 pointwise_kernel,
6397 strides=strides,
6398 padding=padding,
6399 dilations=dilation_rate,
6400 data_format=tf_data_format,
6401 )
6402 if data_format == "channels_first" and tf_data_format == "NHWC":
6403 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
6404 return x
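# Editor's note: illustrative sketch (not in the original source) of the two
# kernels expected by `separable_conv2d`; the shapes follow
# tf.nn.separable_conv2d, and all values are made up.
def _example_separable_conv2d():
    x = tf.random.normal((1, 8, 8, 3))
    # Depthwise kernel: (h, w, in_channels, depth_multiplier).
    depthwise = tf.random.normal((3, 3, 3, 2))
    # Pointwise kernel: (1, 1, in_channels * depth_multiplier, out_channels).
    pointwise = tf.random.normal((1, 1, 6, 16))
    y = separable_conv2d(x, depthwise, pointwise, padding="same",
                         data_format="channels_last")
    # y.shape == (1, 8, 8, 16)
    return y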
6407@keras_export("keras.backend.depthwise_conv2d")
6408@tf.__internal__.dispatch.add_dispatch_support
6409@doc_controls.do_not_generate_docs
6410def depthwise_conv2d(
6411 x,
6412 depthwise_kernel,
6413 strides=(1, 1),
6414 padding="valid",
6415 data_format=None,
6416 dilation_rate=(1, 1),
6417):
6418 """2D convolution with separable filters.
6420 Args:
6421 x: input tensor
6422 depthwise_kernel: convolution kernel for the depthwise convolution.
6423 strides: strides tuple (length 2).
6424 padding: string, `"same"` or `"valid"`.
6425 data_format: string, `"channels_last"` or `"channels_first"`.
6426 dilation_rate: tuple of integers,
6427 dilation rates for the depthwise convolution.
6429 Returns:
6430 Output tensor.
6432 Raises:
6433 ValueError: if `data_format` is neither `channels_last` nor
6434 `channels_first`.
6435 """
6436 if data_format is None:
6437 data_format = image_data_format()
6438 if data_format not in {"channels_first", "channels_last"}:
6439 raise ValueError("Unknown data_format: " + str(data_format))
6441 x, tf_data_format = _preprocess_conv2d_input(x, data_format)
6442 padding = _preprocess_padding(padding)
6443 if tf_data_format == "NHWC":
6444 strides = (1,) + strides + (1,)
6445 else:
6446 strides = (1, 1) + strides
6448 x = tf.nn.depthwise_conv2d(
6449 x,
6450 depthwise_kernel,
6451 strides=strides,
6452 padding=padding,
6453 dilations=dilation_rate,
6454 data_format=tf_data_format,
6455 )
6456 if data_format == "channels_first" and tf_data_format == "NHWC":
6457 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
6458 return x
6461@keras_export("keras.backend.conv3d")
6462@tf.__internal__.dispatch.add_dispatch_support
6463@doc_controls.do_not_generate_docs
6464def conv3d(
6465 x,
6466 kernel,
6467 strides=(1, 1, 1),
6468 padding="valid",
6469 data_format=None,
6470 dilation_rate=(1, 1, 1),
6471):
6472 """3D convolution.
6474 Args:
6475 x: Tensor or variable.
6476 kernel: kernel tensor.
6477 strides: strides tuple.
6478 padding: string, `"same"` or `"valid"`.
6479 data_format: string, `"channels_last"` or `"channels_first"`.
6480 dilation_rate: tuple of 3 integers.
6482 Returns:
6483 A tensor, result of 3D convolution.
6485 Raises:
6486 ValueError: if `data_format` is neither `channels_last` nor
6487 `channels_first`.
6488 """
6489 if data_format is None:
6490 data_format = image_data_format()
6491 if data_format not in {"channels_first", "channels_last"}:
6492 raise ValueError("Unknown data_format: " + str(data_format))
6494 x, tf_data_format = _preprocess_conv3d_input(x, data_format)
6495 padding = _preprocess_padding(padding)
6496 x = tf.compat.v1.nn.convolution(
6497 input=x,
6498 filter=kernel,
6499 dilation_rate=dilation_rate,
6500 strides=strides,
6501 padding=padding,
6502 data_format=tf_data_format,
6503 )
6504 if data_format == "channels_first" and tf_data_format == "NDHWC":
6505 x = tf.compat.v1.transpose(x, (0, 4, 1, 2, 3))
6506 return x
6509def conv3d_transpose(
6510 x,
6511 kernel,
6512 output_shape,
6513 strides=(1, 1, 1),
6514 padding="valid",
6515 data_format=None,
6516):
6517 """3D deconvolution (i.e.
6519 transposed convolution).
6521 Args:
6522 x: input tensor.
6523 kernel: kernel tensor.
6524 output_shape: 1D int tensor for the output shape.
6525 strides: strides tuple.
6526 padding: string, "same" or "valid".
6527 data_format: string, `"channels_last"` or `"channels_first"`.
6529 Returns:
6530 A tensor, result of transposed 3D convolution.
6532 Raises:
6533 ValueError: if `data_format` is neither `channels_last` nor
6534 `channels_first`.
6535 """
6536 if data_format is None:
6537 data_format = image_data_format()
6538 if data_format not in {"channels_first", "channels_last"}:
6539 raise ValueError("Unknown data_format: " + str(data_format))
6540 if isinstance(output_shape, (tuple, list)):
6541 output_shape = tf.stack(output_shape)
6543 x, tf_data_format = _preprocess_conv3d_input(x, data_format)
6545 if data_format == "channels_first" and tf_data_format == "NDHWC":
6546 output_shape = (
6547 output_shape[0],
6548 output_shape[2],
6549 output_shape[3],
6550 output_shape[4],
6551 output_shape[1],
6552 )
6553 if output_shape[0] is None:
6554 output_shape = (tf.shape(x)[0],) + tuple(output_shape[1:])
6555 output_shape = tf.stack(list(output_shape))
6557 padding = _preprocess_padding(padding)
6558 if tf_data_format == "NDHWC":
6559 strides = (1,) + strides + (1,)
6560 else:
6561 strides = (1, 1) + strides
6563 x = tf.compat.v1.nn.conv3d_transpose(
6564 x,
6565 kernel,
6566 output_shape,
6567 strides,
6568 padding=padding,
6569 data_format=tf_data_format,
6570 )
6571 if data_format == "channels_first" and tf_data_format == "NDHWC":
6572 x = tf.compat.v1.transpose(x, (0, 4, 1, 2, 3))
6573 return x
6576@keras_export("keras.backend.pool2d")
6577@tf.__internal__.dispatch.add_dispatch_support
6578@doc_controls.do_not_generate_docs
6579def pool2d(
6580 x,
6581 pool_size,
6582 strides=(1, 1),
6583 padding="valid",
6584 data_format=None,
6585 pool_mode="max",
6586):
6587 """2D Pooling.
6589 Args:
6590 x: Tensor or variable.
6591 pool_size: tuple of 2 integers.
6592 strides: tuple of 2 integers.
6593 padding: string, `"same"` or `"valid"`.
6594 data_format: string, `"channels_last"` or `"channels_first"`.
6595 pool_mode: string, `"max"` or `"avg"`.
6597 Returns:
6598 A tensor, result of 2D pooling.
6600 Raises:
6601 ValueError: if `data_format` is neither `"channels_last"` nor
6602 `"channels_first"`.
6603 ValueError: if `pool_size` is not a tuple of 2 integers.
6604 ValueError: if `strides` is not a tuple of 2 integers.
6605 ValueError: if `pool_mode` is neither `"max"` nor `"avg"`.
6606 """
6607 if data_format is None:
6608 data_format = image_data_format()
6609 if data_format not in {"channels_first", "channels_last"}:
6610 raise ValueError("Unknown data_format: " + str(data_format))
6611 if len(pool_size) != 2:
6612 raise ValueError("`pool_size` must be a tuple of 2 integers.")
6613 if len(strides) != 2:
6614 raise ValueError("`strides` must be a tuple of 2 integers.")
6616 x, tf_data_format = _preprocess_conv2d_input(x, data_format)
6617 padding = _preprocess_padding(padding)
6618 if tf_data_format == "NHWC":
6619 strides = (1,) + strides + (1,)
6620 pool_size = (1,) + pool_size + (1,)
6621 else:
6622 strides = (1, 1) + strides
6623 pool_size = (1, 1) + pool_size
6625 if pool_mode == "max":
6626 x = tf.compat.v1.nn.max_pool(
6627 x, pool_size, strides, padding=padding, data_format=tf_data_format
6628 )
6629 elif pool_mode == "avg":
6630 x = tf.compat.v1.nn.avg_pool(
6631 x, pool_size, strides, padding=padding, data_format=tf_data_format
6632 )
6633 else:
6634 raise ValueError("Invalid pooling mode: " + str(pool_mode))
6636 if data_format == "channels_first" and tf_data_format == "NHWC":
6637 x = tf.compat.v1.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
6638 return x
6641@keras_export("keras.backend.pool3d")
6642@tf.__internal__.dispatch.add_dispatch_support
6643@doc_controls.do_not_generate_docs
6644def pool3d(
6645 x,
6646 pool_size,
6647 strides=(1, 1, 1),
6648 padding="valid",
6649 data_format=None,
6650 pool_mode="max",
6651):
6652 """3D Pooling.
6654 Args:
6655 x: Tensor or variable.
6656 pool_size: tuple of 3 integers.
6657 strides: tuple of 3 integers.
6658 padding: string, `"same"` or `"valid"`.
6659 data_format: string, `"channels_last"` or `"channels_first"`.
6660 pool_mode: string, `"max"` or `"avg"`.
6662 Returns:
6663 A tensor, result of 3D pooling.
6665 Raises:
6666 ValueError: if `data_format` is neither `"channels_last"` nor
6667 `"channels_first"`.
6668 ValueError: if `pool_mode` is neither `"max"` nor `"avg"`.
6669 """
6670 if data_format is None:
6671 data_format = image_data_format()
6672 if data_format not in {"channels_first", "channels_last"}:
6673 raise ValueError("Unknown data_format: " + str(data_format))
6675 x, tf_data_format = _preprocess_conv3d_input(x, data_format)
6676 padding = _preprocess_padding(padding)
6677 if tf_data_format == "NDHWC":
6678 strides = (1,) + strides + (1,)
6679 pool_size = (1,) + pool_size + (1,)
6680 else:
6681 strides = (1, 1) + strides
6682 pool_size = (1, 1) + pool_size
6684 if pool_mode == "max":
6685 x = tf.nn.max_pool3d(
6686 x, pool_size, strides, padding=padding, data_format=tf_data_format
6687 )
6688 elif pool_mode == "avg":
6689 x = tf.nn.avg_pool3d(
6690 x, pool_size, strides, padding=padding, data_format=tf_data_format
6691 )
6692 else:
6693 raise ValueError("Invalid pooling mode: " + str(pool_mode))
6695 if data_format == "channels_first" and tf_data_format == "NDHWC":
6696 x = tf.compat.v1.transpose(x, (0, 4, 1, 2, 3))
6697 return x
6700def local_conv(
6701 inputs, kernel, kernel_size, strides, output_shape, data_format=None
6702):
6703 """Apply N-D convolution with un-shared weights.
6705 Args:
6706 inputs: (N+2)-D tensor with shape
6707 (batch_size, channels_in, d_in1, ..., d_inN)
6708 if data_format='channels_first', or
6709 (batch_size, d_in1, ..., d_inN, channels_in)
6710 if data_format='channels_last'.
6711 kernel: the unshared weight for N-D convolution,
6712 with shape (output_items, feature_dim, channels_out), where
6713 feature_dim = np.prod(kernel_size) * channels_in,
6714 output_items = np.prod(output_shape).
6715 kernel_size: a tuple of N integers, specifying the
6716 spatial dimensions of the N-D convolution window.
6717 strides: a tuple of N integers, specifying the strides
6718 of the convolution along the spatial dimensions.
6719 output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
6720 dimensionality of the output.
6721 data_format: string, "channels_first" or "channels_last".
6723 Returns:
6724 An (N+2)-D tensor with shape:
6725 (batch_size, channels_out) + output_shape
6726 if data_format='channels_first', or:
6727 (batch_size,) + output_shape + (channels_out,)
6728 if data_format='channels_last'.
6730 Raises:
6731 ValueError: if `data_format` is neither
6732 `channels_last` nor `channels_first`.
6733 """
6734 if data_format is None:
6735 data_format = image_data_format()
6736 if data_format not in {"channels_first", "channels_last"}:
6737 raise ValueError("Unknown data_format: " + str(data_format))
6739 kernel_shape = int_shape(kernel)
6740 feature_dim = kernel_shape[1]
6741 channels_out = kernel_shape[-1]
6742 ndims = len(output_shape)
6743 spatial_dimensions = list(range(ndims))
6745 xs = []
6746 output_axes_ticks = [range(axis_max) for axis_max in output_shape]
6747 for position in itertools.product(*output_axes_ticks):
6748 slices = [slice(None)]
6750 if data_format == "channels_first":
6751 slices.append(slice(None))
6753 slices.extend(
6754 slice(
6755 position[d] * strides[d],
6756 position[d] * strides[d] + kernel_size[d],
6757 )
6758 for d in spatial_dimensions
6759 )
6761 if data_format == "channels_last":
6762 slices.append(slice(None))
6764 xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
6766 x_aggregate = concatenate(xs, axis=0)
6767 output = batch_dot(x_aggregate, kernel)
6768 output = reshape(output, output_shape + (-1, channels_out))
6770 if data_format == "channels_first":
6771 permutation = [ndims, ndims + 1] + spatial_dimensions
6772 else:
6773 permutation = [ndims] + spatial_dimensions + [ndims + 1]
6775 return permute_dimensions(output, permutation)
6778@keras_export("keras.backend.local_conv1d")
6779@tf.__internal__.dispatch.add_dispatch_support
6780@doc_controls.do_not_generate_docs
6781def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
6782 """Apply 1D conv with un-shared weights.
6784 Args:
6785 inputs: 3D tensor with shape:
6786 (batch_size, steps, input_dim)
6787 if data_format is "channels_last" or
6788 (batch_size, input_dim, steps)
6789 if data_format is "channels_first".
6790 kernel: the unshared weight for convolution,
6791 with shape (output_length, feature_dim, filters).
6792 kernel_size: a tuple of a single integer,
6793 specifying the length of the 1D convolution window.
6794 strides: a tuple of a single integer,
6795 specifying the stride length of the convolution.
6796 data_format: the data format, channels_first or channels_last.
6798 Returns:
6799 A 3D tensor with shape:
6800 (batch_size, filters, output_length)
6801 if data_format='channels_first'
6802 or 3D tensor with shape:
6803 (batch_size, output_length, filters)
6804 if data_format='channels_last'.
6805 """
6806 output_shape = (kernel.shape[0],)
6807 return local_conv(
6808 inputs, kernel, kernel_size, strides, output_shape, data_format
6809 )
6812@keras_export("keras.backend.local_conv2d")
6813@tf.__internal__.dispatch.add_dispatch_support
6814@doc_controls.do_not_generate_docs
6815def local_conv2d(
6816 inputs, kernel, kernel_size, strides, output_shape, data_format=None
6817):
6818 """Apply 2D conv with un-shared weights.
6820 Args:
6821 inputs: 4D tensor with shape:
6822 (batch_size, filters, new_rows, new_cols)
6823 if data_format='channels_first'
6824 or 4D tensor with shape:
6825 (batch_size, new_rows, new_cols, filters)
6826 if data_format='channels_last'.
6827 kernel: the unshared weight for convolution,
6828 with shape (output_items, feature_dim, filters).
6829 kernel_size: a tuple of 2 integers, specifying the
6830 width and height of the 2D convolution window.
6831 strides: a tuple of 2 integers, specifying the strides
6832 of the convolution along the width and height.
6833 output_shape: a tuple with (output_row, output_col).
6834 data_format: the data format, channels_first or channels_last.
6836 Returns:
6837 A 4D tensor with shape:
6838 (batch_size, filters, new_rows, new_cols)
6839 if data_format='channels_first'
6840 or 4D tensor with shape:
6841 (batch_size, new_rows, new_cols, filters)
6842 if data_format='channels_last'.
6843 """
6844 return local_conv(
6845 inputs, kernel, kernel_size, strides, output_shape, data_format
6846 )
6849@keras_export("keras.backend.bias_add")
6850@tf.__internal__.dispatch.add_dispatch_support
6851@doc_controls.do_not_generate_docs
6852def bias_add(x, bias, data_format=None):
6853 """Adds a bias vector to a tensor.
6855 Args:
6856 x: Tensor or variable.
6857 bias: Bias tensor to add.
6858 data_format: string, `"channels_last"` or `"channels_first"`.
6860 Returns:
6861 Output tensor.
6863 Raises:
6864 ValueError: In one of the two cases below:
6865 1. invalid `data_format` argument.
6866 2. invalid bias shape.
6867 The bias should be either a vector or
6868 a tensor with `ndim(x) - 1` dimensions.
6869 """
6870 if data_format is None:
6871 data_format = image_data_format()
6872 if data_format not in {"channels_first", "channels_last"}:
6873 raise ValueError("Unknown data_format: " + str(data_format))
6874 bias_shape = int_shape(bias)
6875 if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
6876 raise ValueError(
6877 "Unexpected bias dimensions %d, expect to be 1 or %d dimensions"
6878 % (len(bias_shape), ndim(x) - 1)
6879 )
6881 if len(bias_shape) == 1:
6882 if data_format == "channels_first":
6883 return tf.nn.bias_add(x, bias, data_format="NCHW")
6884 return tf.nn.bias_add(x, bias, data_format="NHWC")
6885 if ndim(x) in (3, 4, 5):
6886 if data_format == "channels_first":
6887 bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
6888 return x + reshape(bias, bias_reshape_axis)
6889 return x + reshape(bias, (1,) + bias_shape)
6890 return tf.nn.bias_add(x, bias)
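# Editor's note: illustrative sketch, not part of the original module, showing
# `bias_add` broadcasting a 1-D bias over the channel axis.
def _example_bias_add():
    x = tf.zeros((2, 4, 4, 3))            # channels_last layout
    bias = tf.constant([0.1, 0.2, 0.3])   # one value per channel
    y = bias_add(x, bias, data_format="channels_last")
    # Every spatial position of `y` now equals [0.1, 0.2, 0.3].
    return y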
6893# RANDOMNESS
6896@keras_export("keras.backend.random_normal")
6897@tf.__internal__.dispatch.add_dispatch_support
6898@doc_controls.do_not_generate_docs
6899def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6900 """Returns a tensor with normal distribution of values.
6902 It is an alias to `tf.random.normal`.
6904 Args:
6905 shape: A tuple of integers, the shape of tensor to create.
6906 mean: A float, the mean value of the normal distribution to draw
6907 samples. Defaults to `0.0`.
6908 stddev: A float, the standard deviation of the normal distribution
6909 to draw samples. Defaults to `1.0`.
6910 dtype: `tf.dtypes.DType`, dtype of returned tensor. None uses Keras
6911 backend dtype which is float32. Defaults to `None`.
6912 seed: Integer, random seed. Will use a random numpy integer when not
6913 specified.
6915 Returns:
6916 A tensor with normal distribution of values.
6918 Example:
6920 >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3),
6921 ... mean=0.0, stddev=1.0)
6922 >>> random_normal_tensor
6923 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6924 dtype=float32)>
6925 """
6926 if dtype is None:
6927 dtype = floatx()
6928 if seed is None:
6929 seed = np.random.randint(10e6)
6930 return tf.random.normal(
6931 shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
6932 )
6935@keras_export("keras.backend.random_uniform")
6936@tf.__internal__.dispatch.add_dispatch_support
6937@doc_controls.do_not_generate_docs
6938def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
6939 """Returns a tensor with uniform distribution of values.
6941 Args:
6942 shape: A tuple of integers, the shape of tensor to create.
6943 minval: A float, lower boundary of the uniform distribution
6944 to draw samples.
6945 maxval: A float, upper boundary of the uniform distribution
6946 to draw samples.
6947 dtype: String, dtype of returned tensor.
6948 seed: Integer, random seed.
6950 Returns:
6951 A tensor.
6953 Example:
6955 >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3),
6956 ... minval=0.0, maxval=1.0)
6957 >>> random_uniform_tensor
6958 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6959 dtype=float32)>
6960 """
6961 if dtype is None:
6962 dtype = floatx()
6963 if seed is None:
6964 seed = np.random.randint(10e6)
6965 return tf.random.uniform(
6966 shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed
6967 )
6970@keras_export("keras.backend.random_binomial")
6971@tf.__internal__.dispatch.add_dispatch_support
6972@doc_controls.do_not_generate_docs
6973def random_binomial(shape, p=0.0, dtype=None, seed=None):
6974 """Returns a tensor with random binomial distribution of values.
6976 DEPRECATED, use `tf.keras.backend.random_bernoulli` instead.
6978 The binomial distribution with parameters `n` and `p` is the probability
6979 distribution of the number of successes in `n` independent Bernoulli trials.
6980 Only `n = 1` is supported for now.
6982 Args:
6983 shape: A tuple of integers, the shape of tensor to create.
6984 p: A float, `0. <= p <= 1`, probability of binomial distribution.
6985 dtype: String, dtype of returned tensor.
6986 seed: Integer, random seed.
6988 Returns:
6989 A tensor.
6991 Example:
6993 >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3),
6994 ... p=0.5)
6995 >>> random_binomial_tensor
6996 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6997 dtype=float32)>
6998 """
6999 warnings.warn(
7000 "`tf.keras.backend.random_binomial` is deprecated, "
7001 "and will be removed in a future version."
7002 "Please use `tf.keras.backend.random_bernoulli` instead.",
7003 stacklevel=2,
7004 )
7005 return random_bernoulli(shape, p, dtype, seed)
7008@keras_export("keras.backend.random_bernoulli")
7009@tf.__internal__.dispatch.add_dispatch_support
7010@doc_controls.do_not_generate_docs
7011def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
7012 """Returns a tensor with random bernoulli distribution of values.
7014 Args:
7015 shape: A tuple of integers, the shape of tensor to create.
7016 p: A float, `0. <= p <= 1`, probability of bernoulli distribution.
7017 dtype: String, dtype of returned tensor.
7018 seed: Integer, random seed.
7020 Returns:
7021 A tensor.
7022 """
7023 if dtype is None:
7024 dtype = floatx()
7025 if seed is None:
7026 seed = np.random.randint(10e6)
7027 return tf.where(
7028 tf.random.uniform(shape, dtype=dtype, seed=seed) <= p,
7029 tf.ones(shape, dtype=dtype),
7030 tf.zeros(shape, dtype=dtype),
7031 )
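# Editor's note: a small sketch (not in the original source) of
# `random_bernoulli`, which draws independent 0/1 values.
def _example_random_bernoulli():
    draws = random_bernoulli((2, 3), p=0.5, seed=1)
    # A (2, 3) float tensor whose entries are 1.0 with probability `p`
    # and 0.0 otherwise.
    return draws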
7034@keras_export("keras.backend.truncated_normal")
7035@tf.__internal__.dispatch.add_dispatch_support
7036@doc_controls.do_not_generate_docs
7037def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
7038 """Returns a tensor with truncated random normal distribution of values.
7040 The generated values follow a normal distribution
7041 with specified mean and standard deviation,
7042 except that values whose magnitude is more than
7043 two standard deviations from the mean are dropped and re-picked.
7045 Args:
7046 shape: A tuple of integers, the shape of tensor to create.
7047 mean: Mean of the values.
7048 stddev: Standard deviation of the values.
7049 dtype: String, dtype of returned tensor.
7050 seed: Integer, random seed.
7052 Returns:
7053 A tensor.
7054 """
7055 if dtype is None:
7056 dtype = floatx()
7057 if seed is None:
7058 seed = np.random.randint(10e6)
7059 return tf.random.truncated_normal(
7060 shape, mean, stddev, dtype=dtype, seed=seed
7061 )
7064# CTC
7065# TensorFlow has a native implementation, but it uses sparse tensors
7066# and therefore requires a wrapper for Keras. The functions below convert
7067# dense to sparse tensors and also wrap up the beam search code that is
7068# in TensorFlow's CTC implementation.
7071@keras_export("keras.backend.ctc_label_dense_to_sparse")
7072@tf.__internal__.dispatch.add_dispatch_support
7073@doc_controls.do_not_generate_docs
7074def ctc_label_dense_to_sparse(labels, label_lengths):
7075 """Converts CTC labels from dense to sparse.
7077 Args:
7078 labels: dense CTC labels.
7079 label_lengths: length of the labels.
7081 Returns:
7082 A sparse tensor representation of the labels.
7083 """
7084 label_shape = tf.shape(labels)
7085 num_batches_tns = tf.stack([label_shape[0]])
7086 max_num_labels_tns = tf.stack([label_shape[1]])
7088 def range_less_than(old_input, current_input):
7089 return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(
7090 max_num_labels_tns, current_input
7091 )
7093 init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)
7094 dense_mask = tf.compat.v1.scan(
7095 range_less_than, label_lengths, initializer=init, parallel_iterations=1
7096 )
7097 dense_mask = dense_mask[:, 0, :]
7099 label_array = tf.reshape(
7100 tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape
7101 )
7102 label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)
7104 batch_array = tf.compat.v1.transpose(
7105 tf.reshape(
7106 tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
7107 reverse(label_shape, 0),
7108 )
7109 )
7110 batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
7111 indices = tf.compat.v1.transpose(
7112 tf.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1])
7113 )
7115 vals_sparse = tf.compat.v1.gather_nd(labels, indices)
7117 return tf.SparseTensor(
7118 tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)
7119 )
7122@keras_export("keras.backend.ctc_batch_cost")
7123@tf.__internal__.dispatch.add_dispatch_support
7124@doc_controls.do_not_generate_docs
7125def ctc_batch_cost(y_true, y_pred, input_length, label_length):
7126 """Runs CTC loss algorithm on each batch element.
7128 Args:
7129 y_true: tensor `(samples, max_string_length)`
7130 containing the truth labels.
7131 y_pred: tensor `(samples, time_steps, num_categories)`
7132 containing the prediction, or output of the softmax.
7133 input_length: tensor `(samples, 1)` containing the sequence length for
7134 each batch item in `y_pred`.
7135 label_length: tensor `(samples, 1)` containing the sequence length for
7136 each batch item in `y_true`.
7138 Returns:
7139 Tensor with shape (samples,1) containing the
7140 CTC loss of each element.
7141 """
7142 label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
7143 input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
7144 sparse_labels = tf.cast(
7145 ctc_label_dense_to_sparse(y_true, label_length), tf.int32
7146 )
7148 y_pred = tf.math.log(
7149 tf.compat.v1.transpose(y_pred, perm=[1, 0, 2]) + epsilon()
7150 )
7152 return tf.expand_dims(
7153 tf.compat.v1.nn.ctc_loss(
7154 inputs=y_pred, labels=sparse_labels, sequence_length=input_length
7155 ),
7156 1,
7157 )
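# Editor's note: illustrative sketch, not part of the original module. It spells
# out the tensor shapes expected by `ctc_batch_cost`; the labels, lengths, and
# random predictions are made up for illustration only.
def _example_ctc_batch_cost():
    batch, timesteps, classes = 2, 10, 5
    # Softmax outputs over `classes`; the last class index serves as the blank.
    y_pred = tf.nn.softmax(tf.random.normal((batch, timesteps, classes)))
    y_true = tf.constant([[1, 2, 3, 0], [2, 2, 0, 0]])  # padded dense labels
    input_length = tf.constant([[10], [10]])  # timesteps used per sample
    label_length = tf.constant([[3], [2]])    # true label length per sample
    loss = ctc_batch_cost(y_true, y_pred, input_length, label_length)
    # loss.shape == (2, 1): one CTC loss value per batch element.
    return loss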
7160@keras_export("keras.backend.ctc_decode")
7161@tf.__internal__.dispatch.add_dispatch_support
7162@doc_controls.do_not_generate_docs
7163def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
7164 """Decodes the output of a softmax.
7166 Can use either greedy search (also known as best path)
7167 or a constrained dictionary search.
7169 Args:
7170 y_pred: tensor `(samples, time_steps, num_categories)`
7171 containing the prediction, or output of the softmax.
7172 input_length: tensor `(samples, )` containing the sequence length for
7173 each batch item in `y_pred`.
7174 greedy: perform much faster best-path search if `true`.
7175 This does not use a dictionary.
7176 beam_width: if `greedy` is `false`: a beam search decoder will be used
7177 with a beam of this width.
7178 top_paths: if `greedy` is `false`,
7179 how many of the most probable paths will be returned.
7181 Returns:
7182 Tuple:
7183 List: if `greedy` is `true`, returns a list of one element that
7184 contains the decoded sequence.
7185 If `false`, returns the `top_paths` most probable
7186 decoded sequences.
7187 Each decoded sequence has shape (samples, time_steps).
7188 Important: blank labels are returned as `-1`.
7189 Tensor `(top_paths, )` that contains
7190 the log probability of each decoded sequence.
7191 """
7192 input_shape = shape(y_pred)
7193 num_samples, num_steps = input_shape[0], input_shape[1]
7194 y_pred = tf.math.log(
7195 tf.compat.v1.transpose(y_pred, perm=[1, 0, 2]) + epsilon()
7196 )
7197 input_length = tf.cast(input_length, tf.int32)
7199 if greedy:
7200 (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
7201 inputs=y_pred, sequence_length=input_length
7202 )
7203 else:
7204 (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
7205 inputs=y_pred,
7206 sequence_length=input_length,
7207 beam_width=beam_width,
7208 top_paths=top_paths,
7209 )
7210 decoded_dense = []
7211 for st in decoded:
7212 st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
7213 decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
7214 return (decoded_dense, log_prob)
7217# HIGH ORDER FUNCTIONS
7220@keras_export("keras.backend.map_fn")
7221@doc_controls.do_not_generate_docs
7222def map_fn(fn, elems, name=None, dtype=None):
7223 """Map the function fn over the elements elems and return the outputs.
7225 Args:
7226 fn: Callable that will be called upon each element in elems
7227 elems: tensor
7228 name: A string name for the map node in the graph
7229 dtype: Output data type.
7231 Returns:
7232 Tensor with dtype `dtype`.
7233 """
7234 return tf.compat.v1.map_fn(fn, elems, name=name, dtype=dtype)
7237@keras_export("keras.backend.foldl")
7238@doc_controls.do_not_generate_docs
7239def foldl(fn, elems, initializer=None, name=None):
7240 """Reduce elems using fn to combine them from left to right.
7242 Args:
7243 fn: Callable that will be called upon each element in elems and an
7244 accumulator, for instance `lambda acc, x: acc + x`
7245 elems: tensor
7246 initializer: The first value used (`elems[0]` in case of None)
7247 name: A string name for the foldl node in the graph
7249 Returns:
7250 Tensor with same type and shape as `initializer`.
7251 """
7252 return tf.compat.v1.foldl(fn, elems, initializer=initializer, name=name)
7255@keras_export("keras.backend.foldr")
7256@doc_controls.do_not_generate_docs
7257def foldr(fn, elems, initializer=None, name=None):
7258 """Reduce elems using fn to combine them from right to left.
7260 Args:
7261 fn: Callable that will be called upon each element in elems and an
7262 accumulator, for instance `lambda acc, x: acc + x`
7263 elems: tensor
7264 initializer: The first value used (`elems[-1]` in case of None)
7265 name: A string name for the foldr node in the graph
7267 Returns:
7268 Same type and shape as initializer
7269 """
7270 return tf.compat.v1.foldr(fn, elems, initializer=initializer, name=name)
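# Editor's note: a brief sketch (not in the original source) contrasting the
# three higher-order helpers defined above on a small vector.
def _example_higher_order_functions():
    elems = tf.constant([1.0, 2.0, 3.0, 4.0])
    squares = map_fn(lambda x: x * x, elems)       # [1., 4., 9., 16.]
    left = foldl(lambda acc, x: acc + x, elems)    # ((1+2)+3)+4 = 10.0
    right = foldr(lambda acc, x: acc + x, elems)   # 1+(2+(3+4)) = 10.0
    return squares, left, right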
7273# Load Keras default configuration from config file if present.
7274# Set Keras base dir path given KERAS_HOME env variable, if applicable.
7275# Otherwise either ~/.keras or /tmp.
7276if "KERAS_HOME" in os.environ:
7277 _keras_dir = os.environ.get("KERAS_HOME")
7278else:
7279 _keras_base_dir = os.path.expanduser("~")
7280 _keras_dir = os.path.join(_keras_base_dir, ".keras")
7281_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras.json"))
7282if os.path.exists(_config_path):
7283 try:
7284 with open(_config_path) as fh:
7285 _config = json.load(fh)
7286 except ValueError:
7287 _config = {}
7288 _floatx = _config.get("floatx", floatx())
7289 assert _floatx in {"float16", "float32", "float64"}
7290 _epsilon = _config.get("epsilon", epsilon())
7291 assert isinstance(_epsilon, float)
7292 _image_data_format = _config.get("image_data_format", image_data_format())
7293 assert _image_data_format in {"channels_last", "channels_first"}
7294 set_floatx(_floatx)
7295 set_epsilon(_epsilon)
7296 set_image_data_format(_image_data_format)
7298# Save config file.
7299if not os.path.exists(_keras_dir):
7300 try:
7301 os.makedirs(_keras_dir)
7302 except OSError:
7303 # Except permission denied and potential race conditions
7304 # in multi-threaded environments.
7305 pass
7307if not os.path.exists(_config_path):
7308 _config = {
7309 "floatx": floatx(),
7310 "epsilon": epsilon(),
7311 "backend": "tensorflow",
7312 "image_data_format": image_data_format(),
7313 }
7314 try:
7315 with open(_config_path, "w") as f:
7316 f.write(json.dumps(_config, indent=4))
7317 except IOError:
7318 # Except permission denied.
7319 pass
7322def configure_and_create_distributed_session(distribution_strategy):
7323 """Configure session config and create a session with it."""
7325 def _create_session(distribution_strategy):
7326 """Create the Distributed Strategy session."""
7327 session_config = get_default_session_config()
7329 # If a session already exists, merge in its config; if there is a
7330 # conflict, the existing config's values take precedence.
7331 global _SESSION
7332 if getattr(_SESSION, "session", None) and _SESSION.session._config:
7333 session_config.MergeFrom(_SESSION.session._config)
7335 if is_tpu_strategy(distribution_strategy):
7336 # TODO(priyag, yuefengz): Remove this workaround when Distribute
7337 # Coordinator is integrated with keras and we can create a session
7338 # from there.
7339 distribution_strategy.configure(session_config)
7340 master = (
7341 distribution_strategy.extended._tpu_cluster_resolver.master()
7342 )
7343 session = tf.compat.v1.Session(config=session_config, target=master)
7344 else:
7345 worker_context = dc.get_current_worker_context()
7346 if worker_context:
7347 dc_session_config = worker_context.session_config
7348 # Merge the default session config to the one from distribute
7349 # coordinator, which is fine for now since they don't have
7350 # conflicting configurations.
7351 dc_session_config.MergeFrom(session_config)
7352 session = tf.compat.v1.Session(
7353 config=dc_session_config,
7354 target=worker_context.master_target,
7355 )
7356 else:
7357 distribution_strategy.configure(session_config)
7358 session = tf.compat.v1.Session(config=session_config)
7360 set_session(session)
7362 if distribution_strategy.extended._in_multi_worker_mode():
7363 dc.run_distribute_coordinator(_create_session, distribution_strategy)
7364 else:
7365 _create_session(distribution_strategy)
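# Hedged usage sketch (not part of the original source): in TF1-style graph
# mode this helper would typically be driven roughly as follows, with
# `get_session` returning the session installed by `set_session` above:
#
#     tf.compat.v1.disable_eager_execution()
#     strategy = tf.distribute.MirroredStrategy()
#     configure_and_create_distributed_session(strategy)
#     sess = get_session()  # session configured for the strategy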
7368def _is_tpu_strategy_class(clz):
7369 is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy")
7370 if is_tpu_strat(clz):
7371 return True
7372 return py_any(map(_is_tpu_strategy_class, clz.__bases__))
7375def is_tpu_strategy(strategy):
7376 """Returns whether input is a TPUStrategy instance or subclass instance."""
7377 return _is_tpu_strategy_class(strategy.__class__)
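# Illustration (not part of the original source): the check is purely
# name-based, so any class whose own name, or an ancestor's name, starts
# with "TPUStrategy" is treated as a TPU strategy.
class TPUStrategyDemo:  # hypothetical stand-in; name deliberately matches
    pass

assert _is_tpu_strategy_class(TPUStrategyDemo)
assert is_tpu_strategy(TPUStrategyDemo())
assert not is_tpu_strategy(object())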
7380def _is_symbolic_tensor(x):
7381 return tf.is_tensor(x) and not isinstance(x, tf.__internal__.EagerTensor)
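# Illustration (not part of the original source), assuming eager execution:
# eager constants are EagerTensors and therefore not "symbolic" here; only
# tensors created inside a graph (e.g. within a `tf.function`) qualify.
assert not _is_symbolic_tensor(tf.constant(1.0))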
7384def convert_inputs_if_ragged(inputs):
7385 """Converts any ragged tensors to dense."""
7387 def _convert_ragged_input(inputs):
7388 if isinstance(inputs, tf.RaggedTensor):
7389 return inputs.to_tensor()
7390 return inputs
7392 flat_inputs = tf.nest.flatten(inputs)
7393 contains_ragged = py_any(
7394 isinstance(i, tf.RaggedTensor) for i in flat_inputs
7395 )
7397 if not contains_ragged:
7398 return inputs, None
7400 inputs = tf.nest.map_structure(_convert_ragged_input, inputs)
7401 # Multiple masks are not yet supported, so a single mask is applied to all inputs.
7402 # We approach this similarly when using row lengths to ignore steps.
7403 nested_row_lengths = tf.cast(
7404 flat_inputs[0].nested_row_lengths()[0], "int32"
7405 )
7406 return inputs, nested_row_lengths
7409def maybe_convert_to_ragged(
7410 is_ragged_input, output, nested_row_lengths, go_backwards=False
7411):
7412 """Converts any ragged input back to its initial structure."""
7413 if not is_ragged_input:
7414 return output
7416 if go_backwards:
7417 # Reverse based on the timestep dim, so that nested_row_lengths will
7418 # mask from the correct direction. Return the reverse ragged tensor.
7419 output = reverse(output, [1])
7420 ragged = tf.RaggedTensor.from_tensor(output, nested_row_lengths)
7421 return reverse(ragged, [1])
7422 else:
7423 return tf.RaggedTensor.from_tensor(output, nested_row_lengths)
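# Illustrative round trip (not part of the original source), assuming eager
# execution: a ragged batch is densified for computation and then restored
# from the recorded row lengths.
_ragged_demo = tf.ragged.constant([[1, 2, 3], [4]])
_dense_demo, _row_lengths_demo = convert_inputs_if_ragged(_ragged_demo)
# _dense_demo is zero-padded to shape [2, 3]; _row_lengths_demo == [3, 1].
_restored_demo = maybe_convert_to_ragged(True, _dense_demo, _row_lengths_demo)
# _restored_demo is ragged again, with row lengths 3 and 1.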
7426class ContextValueCache(weakref.WeakKeyDictionary):
7427 """Container that caches (possibly tensor) values based on the context.
7429 This class is similar to defaultdict, where values may be produced by the
7430 default factory specified during initialization. This class also has a
7431 default value for the key (when key is `None`) -- the key is set to the
7432 current graph or eager context. The default factories for key and value are
7433 only used in `__getitem__` and `setdefault`. The `.get()` behavior remains
7434 the same.
7436 This object will return the value of the current graph or closest parent
7437 graph if the current graph is a function. This is to reflect the fact that
7438 if a tensor is created in eager/graph, child functions may capture that
7439 tensor.
7441 The default factory method may accept keyword arguments (unlike defaultdict,
7442 which only accepts callables with 0 arguments). To pass keyword arguments to
7443 `default_factory`, use the `setdefault` method instead of `__getitem__`.
7445 An example of how this class can be used in different contexts:
7447 ```
7448 cache = ContextValueCache(int)
7450 # Eager mode
7451 cache[None] += 2
7452 cache[None] += 4
7453 assert cache[None] == 6
7455 # Graph mode
7456 with tf.Graph().as_default() as g:
7457 cache[None] += 5
7458 cache[g] += 3
7459 assert cache[g] == 8
7460 ```
7462 Example of a default factory with arguments:
7464 ```
7465 cache = ContextValueCache(lambda x: x + 1)
7466 g = tf.get_default_graph()
7468 # Example with keyword argument.
7469 value = cache.setdefault(key=g, kwargs={'x': 3})
7470 assert cache[g] == 4
7471 ```
7472 """
7474 def __init__(self, default_factory):
7475 self.default_factory = default_factory
7476 weakref.WeakKeyDictionary.__init__(self)
7478 def _key(self):
7479 if tf.executing_eagerly():
7480 return _DUMMY_EAGER_GRAPH.key
7481 else:
7482 return tf.compat.v1.get_default_graph()
7484 def _get_parent_graph(self, graph):
7485 """Returns the parent graph or dummy eager object."""
7486 # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as
7487 # the outer graph. This results in outer_graph always being a Graph,
7488 # even in eager mode (get_default_graph will create a new Graph if there
7489 # isn't a default graph). Because of this bug, we have to specially set
7490 # the key when eager execution is enabled.
7491 parent_graph = graph.outer_graph
7492 if (
7493 not isinstance(parent_graph, tf.__internal__.FuncGraph)
7494 and tf.compat.v1.executing_eagerly_outside_functions()
7495 ):
7496 return _DUMMY_EAGER_GRAPH.key
7497 return parent_graph
7499 def _get_recursive(self, key):
7500 """Gets the value at key or the closest parent graph."""
7501 value = self.get(key)
7502 if value is not None:
7503 return value
7505 # Since FuncGraphs are able to capture tensors and variables from their
7506 # parent graphs, recursively search to see if there is a value stored
7507 # for one of the parent graphs.
7508 if isinstance(key, tf.__internal__.FuncGraph):
7509 return self._get_recursive(self._get_parent_graph(key))
7510 return None
7512 def __getitem__(self, key):
7513 """Gets the value at key (or current context), or sets default value.
7515 Args:
7516 key: May be `None` or a `Graph` object. When `None`, the key is set
7517 to the current context.
7519 Returns:
7520 Either the cached or default value.
7521 """
7522 if key is None:
7523 key = self._key()
7525 value = self._get_recursive(key)
7526 if value is None:
7527 value = self[key] = self.default_factory()
7528 return value
7530 def setdefault(self, key=None, default=None, kwargs=None):
7531 """Sets the default value if key is not in dict, and returns the
7532 value."""
7533 if key is None:
7534 key = self._key()
7535 kwargs = kwargs or {}
7537 if default is None and key not in self:
7538 default = self.default_factory(**kwargs)
7539 return weakref.WeakKeyDictionary.setdefault(self, key, default)
7542# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
7543# dummy object is used.
7544# A learning phase is a bool tensor used to run Keras models in
7545# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
7546 _GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
7550# This dictionary holds a mapping between a graph and variables to initialize
7551# in the graph.
7552_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
7554# This dictionary holds a mapping between a graph and TF optimizers created in
7555# the graph.
7556_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
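# Hedged sketch (not part of the original source) of the access pattern these
# caches support: values are keyed by graph, or by the current eager/graph
# context when the key is `None`, e.g.
#
#     _GRAPH_VARIABLES[None].add(some_variable)  # record in current context
#     graph_vars = _GRAPH_VARIABLES[graph]       # later lookup by graph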