Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/backend.py: 30%
2306 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=redefined-outer-name
17# pylint: disable=redefined-builtin
18# pylint: disable=g-classes-have-attributes
19"""Keras backend API."""
21import collections
22import itertools
23import json
24import os
25import sys
26import threading
27import warnings
28import weakref
30import numpy as np
32from tensorflow.core.protobuf import config_pb2
33from tensorflow.python import tf2
34from tensorflow.python.client import session as session_module
35from tensorflow.python.distribute import distribute_lib
36from tensorflow.python.eager import context
37from tensorflow.python.eager.context import get_config
38from tensorflow.python.framework import composite_tensor
39from tensorflow.python.framework import config
40from tensorflow.python.framework import constant_op
41from tensorflow.python.framework import device_spec
42from tensorflow.python.framework import dtypes as dtypes_module
43from tensorflow.python.framework import func_graph
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import sparse_tensor
46from tensorflow.python.framework import tensor_conversion
47from tensorflow.python.framework import tensor_shape
48from tensorflow.python.framework import tensor_spec
49from tensorflow.python.framework import tensor_util
50from tensorflow.python.keras import backend_config
51from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc
52from tensorflow.python.keras.engine import keras_tensor
53from tensorflow.python.keras.utils import control_flow_util
54from tensorflow.python.keras.utils import object_identity
55from tensorflow.python.keras.utils import tf_contextlib
56from tensorflow.python.keras.utils import tf_inspect
57from tensorflow.python.ops import array_ops
58from tensorflow.python.ops import array_ops_stack
59from tensorflow.python.ops import clip_ops
60from tensorflow.python.ops import cond
61from tensorflow.python.ops import control_flow_ops
62from tensorflow.python.ops import ctc_ops as ctc
63from tensorflow.python.ops import functional_ops
64from tensorflow.python.ops import gradients as gradients_module
65from tensorflow.python.ops import image_ops
66from tensorflow.python.ops import init_ops
67from tensorflow.python.ops import linalg_ops
68from tensorflow.python.ops import logging_ops
69from tensorflow.python.ops import map_fn as map_fn_lib
70from tensorflow.python.ops import math_ops
71from tensorflow.python.ops import nn
72from tensorflow.python.ops import random_ops
73from tensorflow.python.ops import sparse_ops
74from tensorflow.python.ops import state_ops
75from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
76from tensorflow.python.ops import tensor_array_ops
77from tensorflow.python.ops import variable_v1
78from tensorflow.python.ops import variables as variables_module
79from tensorflow.python.ops import while_loop
80from tensorflow.python.ops.ragged import ragged_tensor
81from tensorflow.python.platform import tf_logging as logging
82from tensorflow.python.training import moving_averages
83from tensorflow.python.util import dispatch
84from tensorflow.python.util import nest
85from tensorflow.python.util.tf_export import keras_export
86from tensorflow.tools.docs import doc_controls
88py_all = all
89py_sum = sum
90py_any = any
92# INTERNAL UTILS
94# The internal graph maintained by Keras and used by the symbolic Keras APIs
95# while executing eagerly (such as the functional API for model-building).
96# This is thread-local to allow building separate models in different threads
97# concurrently, but comes at the cost of not being able to build one model
98# across threads.
99_GRAPH = threading.local()
101# A graph which is used for constructing functions in eager mode.
102_CURRENT_SCRATCH_GRAPH = threading.local()
104# This is a thread local object that will hold the default internal TF session
105# used by Keras. It can be set manually via `set_session(sess)`.
106_SESSION = threading.local()
109# A global dictionary mapping graph objects to an index of counters used
110# for various layer/optimizer names in each graph.
111# Allows to give unique autogenerated names to layers, in a graph-specific way.
112PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()
115# A global set tracking what object names have been seen so far.
116# Optionally used as an avoid-list when generating names
117OBSERVED_NAMES = set()
120# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
121# We keep a separate reference to it to make sure it does not get removed from
122# _GRAPH_LEARNING_PHASES.
123# _DummyEagerGraph inherits from threading.local to make its `key` attribute
124# thread local. This is needed to make set_learning_phase affect only the
125# current thread during eager execution (see b/123096885 for more details).
126class _DummyEagerGraph(threading.local):
127 """_DummyEagerGraph provides a thread local `key` attribute.
129 We can't use threading.local directly, i.e. without subclassing, because
130 gevent monkey patches threading.local and its version does not support
131 weak references.
132 """
134 class _WeakReferencableClass:
135 """This dummy class is needed for two reasons.
137 - We need something that supports weak references. Basic types like string
138 and ints don't.
139 - We need something whose hash and equality are based on object identity
140 to make sure they are treated as different keys to _GRAPH_LEARNING_PHASES.
142 An empty Python class satisfies both of these requirements.
143 """
144 pass
146 def __init__(self):
147 # Constructors for classes subclassing threading.local run once
148 # per thread accessing something in the class. Thus, each thread will
149 # get a different key.
150 super(_DummyEagerGraph, self).__init__()
151 self.key = _DummyEagerGraph._WeakReferencableClass()
152 self.learning_phase_is_set = False
155_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
157# This boolean flag can be set to True to leave variable initialization
158# up to the user.
159# Change its value via `manual_variable_initialization(value)`.
160_MANUAL_VAR_INIT = False
162# This list holds the available devices.
163# It is populated when `_get_available_gpus()` is called for the first time.
164# We assume our devices don't change henceforth.
165_LOCAL_DEVICES = None
167# The below functions are kept accessible from backend for compatibility.
168epsilon = backend_config.epsilon
169floatx = backend_config.floatx
170image_data_format = backend_config.image_data_format
171set_epsilon = backend_config.set_epsilon
172set_floatx = backend_config.set_floatx
173set_image_data_format = backend_config.set_image_data_format
176@keras_export('keras.backend.backend')
177@doc_controls.do_not_generate_docs
178def backend():
179 """Publicly accessible method for determining the current backend.
181 Only exists for API compatibility with multi-backend Keras.
183 Returns:
184 The string "tensorflow".
185 """
186 return 'tensorflow'
189@keras_export('keras.backend.cast_to_floatx')
190@dispatch.add_dispatch_support
191@doc_controls.do_not_generate_docs
192def cast_to_floatx(x):
193 """Cast a Numpy array to the default Keras float type.
195 Args:
196 x: Numpy array or TensorFlow tensor.
198 Returns:
199 The same array (Numpy array if `x` was a Numpy array, or TensorFlow tensor
200 if `x` was a tensor), cast to its new type.
202 Example:
204 >>> tf.keras.backend.floatx()
205 'float32'
206 >>> arr = np.array([1.0, 2.0], dtype='float64')
207 >>> arr.dtype
208 dtype('float64')
209 >>> new_arr = cast_to_floatx(arr)
210 >>> new_arr
211 array([1., 2.], dtype=float32)
212 >>> new_arr.dtype
213 dtype('float32')
215 """
216 if isinstance(x, (ops.Tensor,
217 variables_module.Variable,
218 sparse_tensor.SparseTensor)):
219 return math_ops.cast(x, dtype=floatx())
220 return np.asarray(x, dtype=floatx())
223@keras_export('keras.backend.get_uid')
224def get_uid(prefix=''):
225 """Associates a string prefix with an integer counter in a TensorFlow graph.
227 Args:
228 prefix: String prefix to index.
230 Returns:
231 Unique integer ID.
233 Example:
235 >>> get_uid('dense')
236 1
237 >>> get_uid('dense')
238 2
240 """
241 graph = get_graph()
242 if graph not in PER_GRAPH_OBJECT_NAME_UIDS:
243 PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int)
244 layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph]
245 layer_name_uids[prefix] += 1
246 return layer_name_uids[prefix]
249@keras_export('keras.backend.reset_uids')
250def reset_uids():
251 """Resets graph identifiers.
252 """
254 PER_GRAPH_OBJECT_NAME_UIDS.clear()
255 OBSERVED_NAMES.clear()
258@keras_export('keras.backend.clear_session')
259def clear_session():
260 """Resets all state generated by Keras.
262 Keras manages a global state, which it uses to implement the Functional
263 model-building API and to uniquify autogenerated layer names.
265 If you are creating many models in a loop, this global state will consume
266 an increasing amount of memory over time, and you may want to clear it.
267 Calling `clear_session()` releases the global state: this helps avoid clutter
268 from old models and layers, especially when memory is limited.
270 Example 1: calling `clear_session()` when creating models in a loop
272 ```python
273 for _ in range(100):
274 # Without `clear_session()`, each iteration of this loop will
275 # slightly increase the size of the global state managed by Keras
276 model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
278 for _ in range(100):
279 # With `clear_session()` called at the beginning,
280 # Keras starts with a blank state at each iteration
281 # and memory consumption is constant over time.
282 tf.keras.backend.clear_session()
283 model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
284 ```
286 Example 2: resetting the layer name generation counter
288 >>> import tensorflow as tf
289 >>> layers = [tf.keras.layers.Dense(10) for _ in range(10)]
290 >>> new_layer = tf.keras.layers.Dense(10)
291 >>> print(new_layer.name)
292 dense_10
293 >>> tf.keras.backend.set_learning_phase(1)
294 >>> print(tf.keras.backend.learning_phase())
295 1
296 >>> tf.keras.backend.clear_session()
297 >>> new_layer = tf.keras.layers.Dense(10)
298 >>> print(new_layer.name)
299 dense
300 """
301 global _SESSION
302 global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
303 global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned
304 global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned
305 global _GRAPH
306 _GRAPH.graph = None
307 ops.reset_default_graph()
308 reset_uids()
309 _SESSION.session = None
310 graph = get_graph()
311 with graph.as_default():
312 _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
313 _GRAPH_LEARNING_PHASES.clear()
314 # Create the learning phase placeholder in graph using the default factory.
315 _GRAPH_LEARNING_PHASES.setdefault(graph)
316 _GRAPH_VARIABLES.pop(graph, None)
317 _GRAPH_TF_OPTIMIZERS.pop(graph, None)
318 if context.executing_eagerly():
319 # Clear pending nodes in eager executors, kernel caches and step_containers.
320 context.context().clear_kernel_cache()
323@keras_export('keras.backend.manual_variable_initialization')
324@doc_controls.do_not_generate_docs
325def manual_variable_initialization(value):
326 """Sets the manual variable initialization flag.
328 This boolean flag determines whether
329 variables should be initialized
330 as they are instantiated (default), or if
331 the user should handle the initialization
332 (e.g. via `tf.compat.v1.initialize_all_variables()`).
334 Args:
335 value: Python boolean.
336 """
337 global _MANUAL_VAR_INIT
338 _MANUAL_VAR_INIT = value
341@keras_export('keras.backend.learning_phase')
342@doc_controls.do_not_generate_docs
343def learning_phase():
344 """Returns the learning phase flag.
346 The learning phase flag is a bool tensor (0 = test, 1 = train)
347 to be passed as input to any Keras function
348 that uses a different behavior at train time and test time.
350 Returns:
351 Learning phase (scalar integer tensor or Python integer).
352 """
353 graph = ops.get_default_graph()
354 if graph is getattr(_GRAPH, 'graph', None):
355 # Don't enter an init_scope for the learning phase if eager execution
356 # is enabled but we're inside the Keras workspace graph.
357 learning_phase = symbolic_learning_phase()
358 else:
359 with ops.init_scope():
360 # We always check & set the learning phase inside the init_scope,
361 # otherwise the wrong default_graph will be used to look up the learning
362 # phase inside of functions & defuns.
363 #
364 # This is because functions & defuns (both in graph & in eager mode)
365 # will always execute non-eagerly using a function-specific default
366 # subgraph.
367 learning_phase = _GRAPH_LEARNING_PHASES[None]
368 _mark_func_graph_as_unsaveable(graph, learning_phase)
369 return learning_phase
372def global_learning_phase_is_set():
373 return _DUMMY_EAGER_GRAPH.learning_phase_is_set
376def _mark_func_graph_as_unsaveable(graph, learning_phase):
377 """Mark func graph as unsaveable due to use of symbolic keras learning phase.
379 Functions that capture the symbolic learning phase cannot be exported to
380 SavedModel. Mark the funcgraph as unsaveable, so that an error will be raised
381 if it is exported.
383 Args:
384 graph: Graph or FuncGraph object.
385 learning_phase: Learning phase placeholder or int defined in the graph.
386 """
387 if graph.building_function and is_placeholder(learning_phase):
388 graph.mark_as_unsaveable(
389 'The keras learning phase placeholder was used inside a function. '
390 'Exporting placeholders is not supported when saving out a SavedModel. '
391 'Please call `tf.keras.backend.set_learning_phase(0)` in the function '
392 'to set the learning phase to a constant value.')
395def symbolic_learning_phase():
396 graph = get_graph()
397 with graph.as_default():
398 return _GRAPH_LEARNING_PHASES[graph]
401def _default_learning_phase():
402 if context.executing_eagerly():
403 return 0
404 else:
405 with name_scope(''):
406 return array_ops.placeholder_with_default(
407 False, shape=(), name='keras_learning_phase')
410@keras_export('keras.backend.set_learning_phase')
411@doc_controls.do_not_generate_docs
412def set_learning_phase(value):
413 """Sets the learning phase to a fixed value.
415 The backend learning phase affects any code that calls
416 `backend.learning_phase()`
417 In particular, all Keras built-in layers use the learning phase as the default
418 for the `training` arg to `Layer.__call__`.
420 User-written layers and models can achieve the same behavior with code that
421 looks like:
423 ```python
424 def call(self, inputs, training=None):
425 if training is None:
426 training = backend.learning_phase()
427 ```
429 Args:
430 value: Learning phase value, either 0 or 1 (integers).
431 0 = test, 1 = train
433 Raises:
434 ValueError: if `value` is neither `0` nor `1`.
435 """
436 warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
437 'will be removed after 2020-10-11. To update it, simply '
438 'pass a True/False value to the `training` argument of the '
439 '`__call__` method of your layer or model.')
440 deprecated_internal_set_learning_phase(value)
443def deprecated_internal_set_learning_phase(value):
444 """A deprecated internal implementation of set_learning_phase.
446 This method is an internal-only version of `set_learning_phase` that
447 does not raise a deprecation error. It is required because
448 saved_model needs to keep working with user code that uses the deprecated
449 learning phase methods until those APIs are fully removed from the public API.
451 Specifically SavedModel saving needs to make sure the learning phase is 0
452 during tracing even if users overwrote it to a different value.
454 But, we don't want to raise deprecation warnings for users when savedmodel
455 sets learning phase just for compatibility with code that relied on
456 explicitly setting the learning phase for other values.
458 Args:
459 value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
461 Raises:
462 ValueError: if `value` is neither `0` nor `1`.
463 """
464 global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
465 if value not in {0, 1}:
466 raise ValueError('Expected learning phase to be 0 or 1.')
467 with ops.init_scope():
468 if context.executing_eagerly():
469 # In an eager context, the learning phase values applies to both the eager
470 # context and the internal Keras graph.
471 _DUMMY_EAGER_GRAPH.learning_phase_is_set = True
472 _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
473 _GRAPH_LEARNING_PHASES[get_graph()] = value
476@keras_export('keras.backend.learning_phase_scope')
477@tf_contextlib.contextmanager
478@doc_controls.do_not_generate_docs
479def learning_phase_scope(value):
480 """Provides a scope within which the learning phase is equal to `value`.
482 The learning phase gets restored to its original value upon exiting the scope.
484 Args:
485 value: Learning phase value, either 0 or 1 (integers).
486 0 = test, 1 = train
488 Yields:
489 None.
491 Raises:
492 ValueError: if `value` is neither `0` nor `1`.
493 """
494 warnings.warn('`tf.keras.backend.learning_phase_scope` is deprecated and '
495 'will be removed after 2020-10-11. To update it, simply '
496 'pass a True/False value to the `training` argument of the '
497 '`__call__` method of your layer or model.')
498 with deprecated_internal_learning_phase_scope(value):
499 try:
500 yield
501 finally:
502 pass
505@tf_contextlib.contextmanager
506def deprecated_internal_learning_phase_scope(value):
507 """An internal-only version of `learning_phase_scope`.
509 Unlike the public method, this method does not raise a deprecation warning.
510 This is needed because saved model saving needs to set learning phase
511 to maintain compatibility
512 with code that sets/gets the learning phase, but saved model
513 saving itself shouldn't raise a deprecation warning.
515 We can get rid of this method and its usages when the public API is
516 removed.
518 Args:
519 value: Learning phase value, either 0 or 1 (integers). 0 = test, 1 = train
521 Yields:
522 None.
524 Raises:
525 ValueError: if `value` is neither `0` nor `1`.
526 """
527 global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
528 if value not in {0, 1}:
529 raise ValueError('Expected learning phase to be 0 or 1.')
531 with ops.init_scope():
532 if context.executing_eagerly():
533 previous_eager_value = _GRAPH_LEARNING_PHASES.get(
534 _DUMMY_EAGER_GRAPH.key, None)
535 previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)
537 learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
538 try:
539 deprecated_internal_set_learning_phase(value)
540 yield
541 finally:
542 # Restore learning phase to initial value.
543 if not learning_phase_previously_set:
544 _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
545 with ops.init_scope():
546 if context.executing_eagerly():
547 if previous_eager_value is not None:
548 _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_eager_value
549 elif _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES:
550 del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
552 graph = get_graph()
553 if previous_graph_value is not None:
554 _GRAPH_LEARNING_PHASES[graph] = previous_graph_value
555 elif graph in _GRAPH_LEARNING_PHASES:
556 del _GRAPH_LEARNING_PHASES[graph]
559@tf_contextlib.contextmanager
560def eager_learning_phase_scope(value):
561 """Internal scope that sets the learning phase in eager / tf.function only.
563 Args:
564 value: Learning phase value, either 0 or 1 (integers).
565 0 = test, 1 = train
567 Yields:
568 None.
570 Raises:
571 ValueError: if `value` is neither `0` nor `1`.
572 """
573 global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
574 assert value in {0, 1}
575 assert ops.executing_eagerly_outside_functions()
576 global_learning_phase_was_set = global_learning_phase_is_set()
577 if global_learning_phase_was_set:
578 previous_value = learning_phase()
579 try:
580 _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
581 yield
582 finally:
583 # Restore learning phase to initial value or unset.
584 if global_learning_phase_was_set:
585 _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = previous_value
586 else:
587 del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
590def _as_graph_element(obj):
591 """Convert `obj` to a graph element if possible, otherwise return `None`.
593 Args:
594 obj: Object to convert.
596 Returns:
597 The result of `obj._as_graph_element()` if that method is available;
598 otherwise `None`.
599 """
600 conv_fn = getattr(obj, '_as_graph_element', None)
601 if conv_fn and callable(conv_fn):
602 return conv_fn()
603 return None
606def _assert_same_graph(original_item, item):
607 """Fail if the 2 items are from different graphs.
609 Args:
610 original_item: Original item to check against.
611 item: Item to check.
613 Raises:
614 ValueError: if graphs do not match.
615 """
616 original_graph = getattr(original_item, 'graph', None)
617 graph = getattr(item, 'graph', None)
618 if original_graph and graph and original_graph is not graph:
619 raise ValueError(
620 '%s must be from the same graph as %s (graphs are %s and %s).' %
621 (item, original_item, graph, original_graph))
624def _current_graph(op_input_list, graph=None):
625 """Returns the appropriate graph to use for the given inputs.
627 This library method provides a consistent algorithm for choosing the graph
628 in which an Operation should be constructed:
630 1. If the default graph is being used to construct a function, we
631 use the default graph.
632 2. If the "graph" is specified explicitly, we validate that all of the inputs
633 in "op_input_list" are compatible with that graph.
634 3. Otherwise, we attempt to select a graph from the first Operation-
635 or Tensor-valued input in "op_input_list", and validate that all other
636 such inputs are in the same graph.
637 4. If the graph was not specified and it could not be inferred from
638 "op_input_list", we attempt to use the default graph.
640 Args:
641 op_input_list: A list of inputs to an operation, which may include `Tensor`,
642 `Operation`, and other objects that may be converted to a graph element.
643 graph: (Optional) The explicit graph to use.
645 Raises:
646 TypeError: If op_input_list is not a list or tuple, or if graph is not a
647 Graph.
648 ValueError: If a graph is explicitly passed and not all inputs are from it,
649 or if the inputs are from multiple graphs, or we could not find a graph
650 and there was no default graph.
652 Returns:
653 The appropriate graph to use for the given inputs.
655 """
656 current_default_graph = ops.get_default_graph()
657 if current_default_graph.building_function:
658 return current_default_graph
660 op_input_list = tuple(op_input_list) # Handle generators correctly
661 if graph and not isinstance(graph, ops.Graph):
662 raise TypeError('Input graph needs to be a Graph: %s' % (graph,))
664 # 1. We validate that all of the inputs are from the same graph. This is
665 # either the supplied graph parameter, or the first one selected from one
666 # the graph-element-valued inputs. In the latter case, we hold onto
667 # that input in original_graph_element so we can provide a more
668 # informative error if a mismatch is found.
669 original_graph_element = None
670 for op_input in op_input_list:
671 # Determine if this is a valid graph_element.
672 # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
673 # up.
674 if (isinstance(op_input, (
675 ops.Operation, ops.Tensor, composite_tensor.CompositeTensor)) and
676 ((not isinstance(op_input, ops.Tensor))
677 or type(op_input) == ops.Tensor)): # pylint: disable=unidiomatic-typecheck
678 graph_element = op_input
679 else:
680 graph_element = _as_graph_element(op_input)
682 if graph_element is not None:
683 if not graph:
684 original_graph_element = graph_element
685 graph = getattr(graph_element, 'graph', None)
686 elif original_graph_element is not None:
687 _assert_same_graph(original_graph_element, graph_element)
688 elif graph_element.graph is not graph:
689 raise ValueError('%s is not from the passed-in graph.' % graph_element)
691 # 2. If all else fails, we use the default graph, which is always there.
692 return graph or current_default_graph
695def _get_session(op_input_list=()):
696 """Returns the session object for the current thread."""
697 global _SESSION
698 default_session = ops.get_default_session()
699 if default_session is not None:
700 session = default_session
701 else:
702 if ops.inside_function():
703 raise RuntimeError('Cannot get session inside Tensorflow graph function.')
704 # If we don't have a session, or that session does not match the current
705 # graph, create and cache a new session.
706 if (getattr(_SESSION, 'session', None) is None or
707 _SESSION.session.graph is not _current_graph(op_input_list)):
708 # If we are creating the Session inside a tf.distribute.Strategy scope,
709 # we ask the strategy for the right session options to use.
710 if distribute_lib.has_strategy():
711 configure_and_create_distributed_session(
712 distribute_lib.get_strategy())
713 else:
714 _SESSION.session = session_module.Session(
715 config=get_default_session_config())
716 session = _SESSION.session
717 return session
720@keras_export(v1=['keras.backend.get_session'])
721def get_session(op_input_list=()):
722 """Returns the TF session to be used by the backend.
724 If a default TensorFlow session is available, we will return it.
726 Else, we will return the global Keras session assuming it matches
727 the current graph.
729 If no global Keras session exists at this point:
730 we will create a new global session.
732 Note that you can manually set the global session
733 via `K.set_session(sess)`.
735 Args:
736 op_input_list: An option sequence of tensors or ops, which will be used
737 to determine the current graph. Otherwise the default graph will be
738 used.
740 Returns:
741 A TensorFlow session.
742 """
743 session = _get_session(op_input_list)
744 if not _MANUAL_VAR_INIT:
745 with session.graph.as_default():
746 _initialize_variables(session)
747 return session
750def get_graph():
751 if context.executing_eagerly():
752 global _GRAPH
753 if not getattr(_GRAPH, 'graph', None):
754 _GRAPH.graph = func_graph.FuncGraph('keras_graph')
755 return _GRAPH.graph
756 else:
757 return ops.get_default_graph()
760@tf_contextlib.contextmanager
761def _scratch_graph(graph=None):
762 """Retrieve a shared and temporary func graph.
764 The eager execution path lifts a subgraph from the keras global graph into
765 a scratch graph in order to create a function. DistributionStrategies, in
766 turn, constructs multiple functions as well as a final combined function. In
767 order for that logic to work correctly, all of the functions need to be
768 created on the same scratch FuncGraph.
770 Args:
771 graph: A graph to be used as the current scratch graph. If not set then
772 a scratch graph will either be retrieved or created:
774 Yields:
775 The current scratch graph.
776 """
777 global _CURRENT_SCRATCH_GRAPH
778 scratch_graph = getattr(_CURRENT_SCRATCH_GRAPH, 'graph', None)
779 # If scratch graph and `graph` are both configured, they must match.
780 if (scratch_graph is not None and graph is not None and
781 scratch_graph is not graph):
782 raise ValueError('Multiple scratch graphs specified.')
784 if scratch_graph:
785 yield scratch_graph
786 return
788 graph = graph or func_graph.FuncGraph('keras_scratch_graph')
789 try:
790 _CURRENT_SCRATCH_GRAPH.graph = graph
791 yield graph
792 finally:
793 _CURRENT_SCRATCH_GRAPH.graph = None
796@keras_export(v1=['keras.backend.set_session'])
797def set_session(session):
798 """Sets the global TensorFlow session.
800 Args:
801 session: A TF Session.
802 """
803 global _SESSION
804 _SESSION.session = session
807def get_default_session_config():
808 if os.environ.get('OMP_NUM_THREADS'):
809 logging.warning(
810 'OMP_NUM_THREADS is no longer used by the default Keras config. '
811 'To configure the number of threads, use tf.config.threading APIs.')
813 config = get_config()
814 config.allow_soft_placement = True
816 return config
819def get_default_graph_uid_map():
820 graph = ops.get_default_graph()
821 name_uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(graph, None)
822 if name_uid_map is None:
823 name_uid_map = collections.defaultdict(int)
824 PER_GRAPH_OBJECT_NAME_UIDS[graph] = name_uid_map
825 return name_uid_map
828# DEVICE MANIPULATION
831class _TfDeviceCaptureOp:
832 """Class for capturing the TF device scope."""
834 def __init__(self):
835 self.device = None
837 def _set_device(self, device):
838 """This method captures TF's explicit device scope setting."""
839 if isinstance(device, device_spec.DeviceSpecV2):
840 device = device.to_string()
841 self.device = device
843 def _set_device_from_string(self, device_str):
844 self.device = device_str
847def _get_current_tf_device():
848 """Return explicit device of current context, otherwise returns `None`.
850 Returns:
851 If the current device scope is explicitly set, it returns a string with
852 the device (`CPU` or `GPU`). If the scope is not explicitly set, it will
853 return `None`.
854 """
855 graph = get_graph()
856 op = _TfDeviceCaptureOp()
857 graph._apply_device_functions(op)
858 if tf2.enabled():
859 return device_spec.DeviceSpecV2.from_string(op.device)
860 else:
861 return device_spec.DeviceSpecV1.from_string(op.device)
864def _is_current_explicit_device(device_type):
865 """Check if the current device is explicitly set on the device type specified.
867 Args:
868 device_type: A string containing `GPU` or `CPU` (case-insensitive).
870 Returns:
871 A boolean indicating if the current device scope is explicitly set on the
872 device type.
874 Raises:
875 ValueError: If the `device_type` string indicates an unsupported device.
876 """
877 device_type = device_type.upper()
878 if device_type not in ['CPU', 'GPU']:
879 raise ValueError('`device_type` should be either "CPU" or "GPU".')
880 device = _get_current_tf_device()
881 return device is not None and device.device_type == device_type.upper()
884def _get_available_gpus():
885 """Get a list of available GPU devices (formatted as strings).
887 Returns:
888 A list of available GPU devices.
889 """
890 if ops.executing_eagerly_outside_functions():
891 # Returns names of devices directly.
892 return [d.name for d in config.list_logical_devices('GPU')]
894 global _LOCAL_DEVICES
895 if _LOCAL_DEVICES is None:
896 _LOCAL_DEVICES = get_session().list_devices()
897 return [x.name for x in _LOCAL_DEVICES if x.device_type == 'GPU']
900def _has_nchw_support():
901 """Check whether the current scope supports NCHW ops.
903 TensorFlow does not support NCHW on CPU. Therefore we check if we are not
904 explicitly put on
905 CPU, and have GPUs available. In this case there will be soft-placing on the
906 GPU device.
908 Returns:
909 bool: if the current scope device placement would support nchw
910 """
911 explicitly_on_cpu = _is_current_explicit_device('CPU')
912 gpus_available = bool(_get_available_gpus())
913 return not explicitly_on_cpu and gpus_available
916# VARIABLE MANIPULATION
919def _constant_to_tensor(x, dtype):
920 """Convert the input `x` to a tensor of type `dtype`.
922 This is slightly faster than the _to_tensor function, at the cost of
923 handling fewer cases.
925 Args:
926 x: An object to be converted (numpy arrays, floats, ints and lists of
927 them).
928 dtype: The destination type.
930 Returns:
931 A tensor.
932 """
933 return constant_op.constant(x, dtype=dtype)
936def _to_tensor(x, dtype):
937 """Convert the input `x` to a tensor of type `dtype`.
939 Args:
940 x: An object to be converted (numpy array, list, tensors).
941 dtype: The destination type.
943 Returns:
944 A tensor.
945 """
946 return tensor_conversion.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
949@keras_export('keras.backend.is_sparse')
950@doc_controls.do_not_generate_docs
951def is_sparse(tensor):
952 """Returns whether a tensor is a sparse tensor.
954 Args:
955 tensor: A tensor instance.
957 Returns:
958 A boolean.
960 Example:
963 >>> a = tf.keras.backend.placeholder((2, 2), sparse=False)
964 >>> print(tf.keras.backend.is_sparse(a))
965 False
966 >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
967 >>> print(tf.keras.backend.is_sparse(b))
968 True
970 """
971 spec = getattr(tensor, '_type_spec', None)
972 if spec is not None:
973 return isinstance(spec, sparse_tensor.SparseTensorSpec)
974 return isinstance(tensor, sparse_tensor.SparseTensor)
977@keras_export('keras.backend.to_dense')
978@dispatch.add_dispatch_support
979@doc_controls.do_not_generate_docs
980def to_dense(tensor):
981 """Converts a sparse tensor into a dense tensor and returns it.
983 Args:
984 tensor: A tensor instance (potentially sparse).
986 Returns:
987 A dense tensor.
989 Examples:
992 >>> b = tf.keras.backend.placeholder((2, 2), sparse=True)
993 >>> print(tf.keras.backend.is_sparse(b))
994 True
995 >>> c = tf.keras.backend.to_dense(b)
996 >>> print(tf.keras.backend.is_sparse(c))
997 False
999 """
1000 if is_sparse(tensor):
1001 return sparse_ops.sparse_tensor_to_dense(tensor)
1002 else:
1003 return tensor
1006@keras_export('keras.backend.name_scope', v1=[])
1007@doc_controls.do_not_generate_docs
1008def name_scope(name):
1009 """A context manager for use when defining a Python op.
1011 This context manager pushes a name scope, which will make the name of all
1012 operations added within it have a prefix.
1014 For example, to define a new Python op called `my_op`:
1017 def my_op(a):
1018 with tf.name_scope("MyOp") as scope:
1019 a = tf.convert_to_tensor(a, name="a")
1020 # Define some computation that uses `a`.
1021 return foo_op(..., name=scope)
1024 When executed, the Tensor `a` will have the name `MyOp/a`.
1026 Args:
1027 name: The prefix to use on all names created within the name scope.
1029 Returns:
1030 Name scope context manager.
1031 """
1032 return ops.name_scope_v2(name)
1034# Export V1 version.
1035_v1_name_scope = ops.name_scope_v1
1036keras_export(v1=['keras.backend.name_scope'])(_v1_name_scope)
1039@keras_export('keras.backend.variable')
1040@doc_controls.do_not_generate_docs
1041def variable(value, dtype=None, name=None, constraint=None):
1042 """Instantiates a variable and returns it.
1044 Args:
1045 value: Numpy array, initial value of the tensor.
1046 dtype: Tensor type.
1047 name: Optional name string for the tensor.
1048 constraint: Optional projection function to be
1049 applied to the variable after an optimizer update.
1051 Returns:
1052 A variable instance (with Keras metadata included).
1054 Examples:
1056 >>> val = np.array([[1, 2], [3, 4]])
1057 >>> kvar = tf.keras.backend.variable(value=val, dtype='float64',
1058 ... name='example_var')
1059 >>> tf.keras.backend.dtype(kvar)
1060 'float64'
1061 >>> print(kvar)
1062 <tf.Variable 'example_var:...' shape=(2, 2) dtype=float64, numpy=
1063 array([[1., 2.],
1064 [3., 4.]])>
1066 """
1067 if dtype is None:
1068 dtype = floatx()
1069 if hasattr(value, 'tocoo'):
1070 sparse_coo = value.tocoo()
1071 indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(
1072 sparse_coo.col, 1)), 1)
1073 v = sparse_tensor.SparseTensor(
1074 indices=indices, values=sparse_coo.data, dense_shape=sparse_coo.shape)
1075 v._keras_shape = sparse_coo.shape
1076 return v
1077 v = variables_module.Variable(
1078 value,
1079 dtype=dtypes_module.as_dtype(dtype),
1080 name=name,
1081 constraint=constraint)
1082 if isinstance(value, np.ndarray):
1083 v._keras_shape = value.shape
1084 elif hasattr(value, 'shape'):
1085 v._keras_shape = int_shape(value)
1086 track_variable(v)
1087 return v
1090def track_tf_optimizer(tf_optimizer):
1091 """Tracks the given TF optimizer for initialization of its variables."""
1092 if context.executing_eagerly():
1093 return
1094 optimizers = _GRAPH_TF_OPTIMIZERS[None]
1095 optimizers.add(tf_optimizer)
1098@keras_export('keras.__internal__.backend.track_variable', v1=[])
1099def track_variable(v):
1100 """Tracks the given variable for initialization."""
1101 if context.executing_eagerly():
1102 return
1103 graph = v.graph if hasattr(v, 'graph') else get_graph()
1104 _GRAPH_VARIABLES[graph].add(v)
1107def observe_object_name(name):
1108 """Observe a name and make sure it won't be used by `unique_object_name`."""
1109 OBSERVED_NAMES.add(name)
1112def unique_object_name(name,
1113 name_uid_map=None,
1114 avoid_names=None,
1115 namespace='',
1116 zero_based=False,
1117 avoid_observed_names=False):
1118 """Makes a object name (or arbitrary string) unique within a TensorFlow graph.
1120 Args:
1121 name: String name to make unique.
1122 name_uid_map: An optional defaultdict(int) to use when creating unique
1123 names. If None (default), uses a per-Graph dictionary.
1124 avoid_names: An optional set or dict with names which should not be used. If
1125 None (default), don't avoid any names unless `avoid_observed_names` is
1126 True.
1127 namespace: Gets a name which is unique within the (graph, namespace). Layers
1128 which are not Networks use a blank namespace and so get graph-global
1129 names.
1130 zero_based: If True, name sequences start with no suffix (e.g. "dense",
1131 "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
1132 avoid_observed_names: If True, avoid any names that have been observed by
1133 `backend.observe_object_name`.
1135 Returns:
1136 Unique string name.
1138 Example:
1141 unique_object_name('dense') # dense_1
1142 unique_object_name('dense') # dense_2
1144 """
1145 if name_uid_map is None:
1146 name_uid_map = get_default_graph_uid_map()
1147 if avoid_names is None:
1148 if avoid_observed_names:
1149 avoid_names = OBSERVED_NAMES
1150 else:
1151 avoid_names = set()
1152 proposed_name = None
1153 while proposed_name is None or proposed_name in avoid_names:
1154 name_key = (namespace, name)
1155 if zero_based:
1156 number = name_uid_map[name_key]
1157 if number:
1158 proposed_name = name + '_' + str(number)
1159 else:
1160 proposed_name = name
1161 name_uid_map[name_key] += 1
1162 else:
1163 name_uid_map[name_key] += 1
1164 proposed_name = name + '_' + str(name_uid_map[name_key])
1165 return proposed_name
1168def _get_variables(graph=None):
1169 """Returns variables corresponding to the given graph for initialization."""
1170 assert not context.executing_eagerly()
1171 variables = _GRAPH_VARIABLES[graph]
1172 for opt in _GRAPH_TF_OPTIMIZERS[graph]:
1173 variables.update(opt.optimizer.variables())
1174 return variables
1177@keras_export('keras.__internal__.backend.initialize_variables', v1=[])
1178def _initialize_variables(session):
1179 """Utility to initialize uninitialized variables on the fly."""
1180 variables = _get_variables(get_graph())
1181 candidate_vars = []
1182 for v in variables:
1183 if not getattr(v, '_keras_initialized', False):
1184 candidate_vars.append(v)
1185 if candidate_vars:
1186 # This step is expensive, so we only run it on variables not already
1187 # marked as initialized.
1188 is_initialized = session.run(
1189 [variable_v1.is_variable_initialized(v) for v in candidate_vars])
1190 # TODO(kathywu): Some metric variables loaded from SavedModel are never
1191 # actually used, and do not have an initializer.
1192 should_be_initialized = [
1193 (not is_initialized[n]) and v.initializer is not None
1194 for n, v in enumerate(candidate_vars)]
1195 uninitialized_vars = []
1196 for flag, v in zip(should_be_initialized, candidate_vars):
1197 if flag:
1198 uninitialized_vars.append(v)
1199 v._keras_initialized = True
1200 if uninitialized_vars:
1201 session.run(variables_module.variables_initializer(uninitialized_vars))
1204@keras_export('keras.backend.constant')
1205@dispatch.add_dispatch_support
1206@doc_controls.do_not_generate_docs
1207def constant(value, dtype=None, shape=None, name=None):
1208 """Creates a constant tensor.
1210 Args:
1211 value: A constant value (or list)
1212 dtype: The type of the elements of the resulting tensor.
1213 shape: Optional dimensions of resulting tensor.
1214 name: Optional name for the tensor.
1216 Returns:
1217 A Constant Tensor.
1218 """
1219 if dtype is None:
1220 dtype = floatx()
1222 return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
1225@keras_export('keras.backend.is_keras_tensor')
1226def is_keras_tensor(x):
1227 """Returns whether `x` is a Keras tensor.
1229 A "Keras tensor" is a tensor that was returned by a Keras layer,
1230 (`Layer` class) or by `Input`.
1232 Args:
1233 x: A candidate tensor.
1235 Returns:
1236 A boolean: Whether the argument is a Keras tensor.
1238 Raises:
1239 ValueError: In case `x` is not a symbolic tensor.
1241 Examples:
1243 >>> np_var = np.array([1, 2])
1244 >>> # A numpy array is not a symbolic tensor.
1245 >>> tf.keras.backend.is_keras_tensor(np_var)
1246 Traceback (most recent call last):
1247 ...
1248 ValueError: Unexpectedly found an instance of type `<class 'numpy.ndarray'>`.
1249 Expected a symbolic tensor instance.
1250 >>> keras_var = tf.keras.backend.variable(np_var)
1251 >>> # A variable created with the keras backend is not a Keras tensor.
1252 >>> tf.keras.backend.is_keras_tensor(keras_var)
1253 False
1254 >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5))
1255 >>> # A placeholder is a Keras tensor.
1256 >>> tf.keras.backend.is_keras_tensor(keras_placeholder)
1257 True
1258 >>> keras_input = tf.keras.layers.Input([10])
1259 >>> # An Input is a Keras tensor.
1260 >>> tf.keras.backend.is_keras_tensor(keras_input)
1261 True
1262 >>> keras_layer_output = tf.keras.layers.Dense(10)(keras_input)
1263 >>> # Any Keras layer output is a Keras tensor.
1264 >>> tf.keras.backend.is_keras_tensor(keras_layer_output)
1265 True
1267 """
1268 if not isinstance(x,
1269 (ops.Tensor, variables_module.Variable,
1270 sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor,
1271 keras_tensor.KerasTensor)):
1272 raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
1273 '`. Expected a symbolic tensor instance.')
1274 if ops.executing_eagerly_outside_functions():
1275 return isinstance(x, keras_tensor.KerasTensor)
1276 return hasattr(x, '_keras_history')
1279@keras_export('keras.backend.placeholder')
1280@doc_controls.do_not_generate_docs
1281def placeholder(shape=None,
1282 ndim=None,
1283 dtype=None,
1284 sparse=False,
1285 name=None,
1286 ragged=False):
1287 """Instantiates a placeholder tensor and returns it.
1289 Args:
1290 shape: Shape of the placeholder
1291 (integer tuple, may include `None` entries).
1292 ndim: Number of axes of the tensor.
1293 At least one of {`shape`, `ndim`} must be specified.
1294 If both are specified, `shape` is used.
1295 dtype: Placeholder type.
1296 sparse: Boolean, whether the placeholder should have a sparse type.
1297 name: Optional name string for the placeholder.
1298 ragged: Boolean, whether the placeholder should have a ragged type.
1299 In this case, values of 'None' in the 'shape' argument represent
1300 ragged dimensions. For more information about RaggedTensors, see this
1301 [guide](https://www.tensorflow.org/guide/ragged_tensors).
1303 Raises:
1304 ValueError: If called with sparse = True and ragged = True.
1306 Returns:
1307 Tensor instance (with Keras metadata included).
1309 Examples:
1312 >>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
1313 >>> input_ph
1314 <KerasTensor: shape=(2, 4, 5) dtype=float32 (created by layer ...)>
1316 """
1317 if sparse and ragged:
1318 raise ValueError(
1319 'Cannot set both sparse and ragged to True when creating a placeholder.'
1320 )
1321 if dtype is None:
1322 dtype = floatx()
1323 if not shape:
1324 if ndim:
1325 shape = (None,) * ndim
1326 if ops.executing_eagerly_outside_functions():
1327 if sparse:
1328 spec = sparse_tensor.SparseTensorSpec(
1329 shape=shape, dtype=dtype)
1330 elif ragged:
1331 ragged_rank = 0
1332 for i in range(1, len(shape)):
1333 # Hacky because could be tensorshape or tuple maybe?
1334 # Or just tensorshape?
1335 if shape[i] is None or (
1336 hasattr(shape[i], 'value') and
1337 shape[i].value is None):
1338 ragged_rank = i
1339 spec = ragged_tensor.RaggedTensorSpec(
1340 shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1341 else:
1342 spec = tensor_spec.TensorSpec(
1343 shape=shape, dtype=dtype, name=name)
1344 x = keras_tensor.keras_tensor_from_type_spec(spec, name=name)
1345 else:
1346 with get_graph().as_default():
1347 if sparse:
1348 x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
1349 elif ragged:
1350 ragged_rank = 0
1351 for i in range(1, len(shape)):
1352 if shape[i] is None:
1353 ragged_rank = i
1354 type_spec = ragged_tensor.RaggedTensorSpec(
1355 shape=shape, dtype=dtype, ragged_rank=ragged_rank)
1356 def tensor_spec_to_placeholder(tensorspec):
1357 return array_ops.placeholder(tensorspec.dtype, tensorspec.shape)
1358 x = nest.map_structure(tensor_spec_to_placeholder, type_spec,
1359 expand_composites=True)
1360 else:
1361 x = array_ops.placeholder(dtype, shape=shape, name=name)
1363 if context.executing_eagerly():
1364 # Add keras_history connectivity information to the placeholder
1365 # when the placeholder is built in a top-level eager context
1366 # (intended to be used with keras.backend.function)
1367 from tensorflow.python.keras.engine import input_layer # pylint: disable=g-import-not-at-top
1368 x = input_layer.Input(tensor=x)
1369 x._is_backend_placeholder = True
1371 return x
1374def is_placeholder(x):
1375 """Returns whether `x` is a placeholder.
1377 Args:
1378 x: A candidate placeholder.
1380 Returns:
1381 Boolean.
1382 """
1383 try:
1384 if ops.executing_eagerly_outside_functions():
1385 return hasattr(x, '_is_backend_placeholder')
1386 from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
1387 if tf_utils.is_extension_type(x):
1388 flat_components = nest.flatten(x, expand_composites=True)
1389 return py_any(is_placeholder(c) for c in flat_components)
1390 else:
1391 return x.op.type == 'Placeholder'
1392 except AttributeError:
1393 return False
1396@keras_export('keras.backend.shape')
1397@dispatch.add_dispatch_support
1398@doc_controls.do_not_generate_docs
1399def shape(x):
1400 """Returns the symbolic shape of a tensor or variable.
1402 Args:
1403 x: A tensor or variable.
1405 Returns:
1406 A symbolic shape (which is itself a tensor).
1408 Examples:
1410 >>> val = np.array([[1, 2], [3, 4]])
1411 >>> kvar = tf.keras.backend.variable(value=val)
1412 >>> tf.keras.backend.shape(kvar)
1413 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
1414 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1415 >>> tf.keras.backend.shape(input)
1416 <KerasTensor: shape=(3,) dtype=int32 inferred_value=[2, 4, 5] ...>
1418 """
1419 return array_ops.shape(x)
1422@keras_export('keras.backend.int_shape')
1423@doc_controls.do_not_generate_docs
1424def int_shape(x):
1425 """Returns the shape of tensor or variable as a tuple of int or None entries.
1427 Args:
1428 x: Tensor or variable.
1430 Returns:
1431 A tuple of integers (or None entries).
1433 Examples:
1435 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1436 >>> tf.keras.backend.int_shape(input)
1437 (2, 4, 5)
1438 >>> val = np.array([[1, 2], [3, 4]])
1439 >>> kvar = tf.keras.backend.variable(value=val)
1440 >>> tf.keras.backend.int_shape(kvar)
1441 (2, 2)
1443 """
1444 try:
1445 shape = x.shape
1446 if not isinstance(shape, tuple):
1447 shape = tuple(shape.as_list())
1448 return shape
1449 except ValueError:
1450 return None
1453@keras_export('keras.backend.ndim')
1454@doc_controls.do_not_generate_docs
1455def ndim(x):
1456 """Returns the number of axes in a tensor, as an integer.
1458 Args:
1459 x: Tensor or variable.
1461 Returns:
1462 Integer (scalar), number of axes.
1464 Examples:
1467 >>> input = tf.keras.backend.placeholder(shape=(2, 4, 5))
1468 >>> val = np.array([[1, 2], [3, 4]])
1469 >>> kvar = tf.keras.backend.variable(value=val)
1470 >>> tf.keras.backend.ndim(input)
1471 3
1472 >>> tf.keras.backend.ndim(kvar)
1473 2
1475 """
1476 return x.shape.rank
1479@keras_export('keras.backend.dtype')
1480@dispatch.add_dispatch_support
1481@doc_controls.do_not_generate_docs
1482def dtype(x):
1483 """Returns the dtype of a Keras tensor or variable, as a string.
1485 Args:
1486 x: Tensor or variable.
1488 Returns:
1489 String, dtype of `x`.
1491 Examples:
1493 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5)))
1494 'float32'
1495 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1496 ... dtype='float32'))
1497 'float32'
1498 >>> tf.keras.backend.dtype(tf.keras.backend.placeholder(shape=(2,4,5),
1499 ... dtype='float64'))
1500 'float64'
1501 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]))
1502 >>> tf.keras.backend.dtype(kvar)
1503 'float32'
1504 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1505 ... dtype='float32')
1506 >>> tf.keras.backend.dtype(kvar)
1507 'float32'
1509 """
1510 return x.dtype.base_dtype.name
1513@doc_controls.do_not_generate_docs
1514def dtype_numpy(x):
1515 """Returns the numpy dtype of a Keras tensor or variable.
1517 Args:
1518 x: Tensor or variable.
1520 Returns:
1521 numpy.dtype, dtype of `x`.
1522 """
1523 return dtypes_module.as_dtype(x.dtype).as_numpy_dtype
1526@keras_export('keras.backend.eval')
1527@doc_controls.do_not_generate_docs
1528def eval(x):
1529 """Evaluates the value of a variable.
1531 Args:
1532 x: A variable.
1534 Returns:
1535 A Numpy array.
1537 Examples:
1539 >>> kvar = tf.keras.backend.variable(np.array([[1, 2], [3, 4]]),
1540 ... dtype='float32')
1541 >>> tf.keras.backend.eval(kvar)
1542 array([[1., 2.],
1543 [3., 4.]], dtype=float32)
1545 """
1546 return get_value(to_dense(x))
1549@keras_export('keras.backend.zeros')
1550@doc_controls.do_not_generate_docs
1551def zeros(shape, dtype=None, name=None):
1552 """Instantiates an all-zeros variable and returns it.
1554 Args:
1555 shape: Tuple or list of integers, shape of returned Keras variable
1556 dtype: data type of returned Keras variable
1557 name: name of returned Keras variable
1559 Returns:
1560 A variable (including Keras metadata), filled with `0.0`.
1561 Note that if `shape` was symbolic, we cannot return a variable,
1562 and will return a dynamically-shaped tensor instead.
1564 Example:
1566 >>> kvar = tf.keras.backend.zeros((3,4))
1567 >>> tf.keras.backend.eval(kvar)
1568 array([[0., 0., 0., 0.],
1569 [0., 0., 0., 0.],
1570 [0., 0., 0., 0.]], dtype=float32)
1571 >>> A = tf.constant([1,2,3])
1572 >>> kvar2 = tf.keras.backend.zeros(A.shape) # [0., 0., 0.]
1573 >>> tf.keras.backend.eval(kvar2)
1574 array([0., 0., 0.], dtype=float32)
1575 >>> kvar3 = tf.keras.backend.zeros(A.shape,dtype=tf.int32)
1576 >>> tf.keras.backend.eval(kvar3)
1577 array([0, 0, 0], dtype=int32)
1578 >>> kvar4 = tf.keras.backend.zeros([2,3])
1579 >>> tf.keras.backend.eval(kvar4)
1580 array([[0., 0., 0.],
1581 [0., 0., 0.]], dtype=float32)
1583 """
1584 with ops.init_scope():
1585 if dtype is None:
1586 dtype = floatx()
1587 tf_dtype = dtypes_module.as_dtype(dtype)
1588 v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
1589 if py_all(v.shape.as_list()):
1590 return variable(v, dtype=dtype, name=name)
1591 return v
1594@keras_export('keras.backend.ones')
1595@dispatch.add_dispatch_support
1596@doc_controls.do_not_generate_docs
1597def ones(shape, dtype=None, name=None):
1598 """Instantiates an all-ones variable and returns it.
1600 Args:
1601 shape: Tuple of integers, shape of returned Keras variable.
1602 dtype: String, data type of returned Keras variable.
1603 name: String, name of returned Keras variable.
1605 Returns:
1606 A Keras variable, filled with `1.0`.
1607 Note that if `shape` was symbolic, we cannot return a variable,
1608 and will return a dynamically-shaped tensor instead.
1610 Example:
1613 >>> kvar = tf.keras.backend.ones((3,4))
1614 >>> tf.keras.backend.eval(kvar)
1615 array([[1., 1., 1., 1.],
1616 [1., 1., 1., 1.],
1617 [1., 1., 1., 1.]], dtype=float32)
1619 """
1620 with ops.init_scope():
1621 if dtype is None:
1622 dtype = floatx()
1623 tf_dtype = dtypes_module.as_dtype(dtype)
1624 v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
1625 if py_all(v.shape.as_list()):
1626 return variable(v, dtype=dtype, name=name)
1627 return v
1630@keras_export('keras.backend.eye')
1631@dispatch.add_dispatch_support
1632@doc_controls.do_not_generate_docs
1633def eye(size, dtype=None, name=None):
1634 """Instantiate an identity matrix and returns it.
1636 Args:
1637 size: Integer, number of rows/columns.
1638 dtype: String, data type of returned Keras variable.
1639 name: String, name of returned Keras variable.
1641 Returns:
1642 A Keras variable, an identity matrix.
1644 Example:
1647 >>> kvar = tf.keras.backend.eye(3)
1648 >>> tf.keras.backend.eval(kvar)
1649 array([[1., 0., 0.],
1650 [0., 1., 0.],
1651 [0., 0., 1.]], dtype=float32)
1654 """
1655 if dtype is None:
1656 dtype = floatx()
1657 tf_dtype = dtypes_module.as_dtype(dtype)
1658 return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name)
1661@keras_export('keras.backend.zeros_like')
1662@doc_controls.do_not_generate_docs
1663def zeros_like(x, dtype=None, name=None):
1664 """Instantiates an all-zeros variable of the same shape as another tensor.
1666 Args:
1667 x: Keras variable or Keras tensor.
1668 dtype: dtype of returned Keras variable.
1669 `None` uses the dtype of `x`.
1670 name: name for the variable to create.
1672 Returns:
1673 A Keras variable with the shape of `x` filled with zeros.
1675 Example:
1677 ```python
1678 from tensorflow.keras import backend as K
1679 kvar = K.variable(np.random.random((2,3)))
1680 kvar_zeros = K.zeros_like(kvar)
1681 K.eval(kvar_zeros)
1682 # array([[ 0., 0., 0.], [ 0., 0., 0.]], dtype=float32)
1683 ```
1684 """
1685 return array_ops.zeros_like(x, dtype=dtype, name=name)
1688@keras_export('keras.backend.ones_like')
1689@dispatch.add_dispatch_support
1690@doc_controls.do_not_generate_docs
1691def ones_like(x, dtype=None, name=None):
1692 """Instantiates an all-ones variable of the same shape as another tensor.
1694 Args:
1695 x: Keras variable or tensor.
1696 dtype: String, dtype of returned Keras variable.
1697 None uses the dtype of x.
1698 name: String, name for the variable to create.
1700 Returns:
1701 A Keras variable with the shape of x filled with ones.
1703 Example:
1705 >>> kvar = tf.keras.backend.variable(np.random.random((2,3)))
1706 >>> kvar_ones = tf.keras.backend.ones_like(kvar)
1707 >>> tf.keras.backend.eval(kvar_ones)
1708 array([[1., 1., 1.],
1709 [1., 1., 1.]], dtype=float32)
1711 """
1712 return array_ops.ones_like(x, dtype=dtype, name=name)
1715def identity(x, name=None):
1716 """Returns a tensor with the same content as the input tensor.
1718 Args:
1719 x: The input tensor.
1720 name: String, name for the variable to create.
1722 Returns:
1723 A tensor of the same shape, type and content.
1724 """
1725 return array_ops.identity(x, name=name)
1728@keras_export('keras.backend.random_uniform_variable')
1729@doc_controls.do_not_generate_docs
1730def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
1731 """Instantiates a variable with values drawn from a uniform distribution.
1733 Args:
1734 shape: Tuple of integers, shape of returned Keras variable.
1735 low: Float, lower boundary of the output interval.
1736 high: Float, upper boundary of the output interval.
1737 dtype: String, dtype of returned Keras variable.
1738 name: String, name of returned Keras variable.
1739 seed: Integer, random seed.
1741 Returns:
1742 A Keras variable, filled with drawn samples.
1744 Example:
1746 >>> kvar = tf.keras.backend.random_uniform_variable(shape=(2,3),
1747 ... low=0.0, high=1.0)
1748 >>> kvar
1749 <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1750 dtype=float32)>
1751 """
1752 if dtype is None:
1753 dtype = floatx()
1754 tf_dtype = dtypes_module.as_dtype(dtype)
1755 if seed is None:
1756 # ensure that randomness is conditioned by the Numpy RNG
1757 seed = np.random.randint(10e8)
1758 value = init_ops.random_uniform_initializer(
1759 low, high, dtype=tf_dtype, seed=seed)(shape)
1760 return variable(value, dtype=dtype, name=name)
1763@keras_export('keras.backend.random_normal_variable')
1764@doc_controls.do_not_generate_docs
1765def random_normal_variable(shape, mean, scale, dtype=None, name=None,
1766 seed=None):
1767 """Instantiates a variable with values drawn from a normal distribution.
1769 Args:
1770 shape: Tuple of integers, shape of returned Keras variable.
1771 mean: Float, mean of the normal distribution.
1772 scale: Float, standard deviation of the normal distribution.
1773 dtype: String, dtype of returned Keras variable.
1774 name: String, name of returned Keras variable.
1775 seed: Integer, random seed.
1777 Returns:
1778 A Keras variable, filled with drawn samples.
1780 Example:
1782 >>> kvar = tf.keras.backend.random_normal_variable(shape=(2,3),
1783 ... mean=0.0, scale=1.0)
1784 >>> kvar
1785 <tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
1786 dtype=float32)>
1787 """
1788 if dtype is None:
1789 dtype = floatx()
1790 tf_dtype = dtypes_module.as_dtype(dtype)
1791 if seed is None:
1792 # ensure that randomness is conditioned by the Numpy RNG
1793 seed = np.random.randint(10e8)
1794 value = init_ops.random_normal_initializer(
1795 mean, scale, dtype=tf_dtype, seed=seed)(shape)
1796 return variable(value, dtype=dtype, name=name)
1799@keras_export('keras.backend.count_params')
1800@doc_controls.do_not_generate_docs
1801def count_params(x):
1802 """Returns the static number of elements in a variable or tensor.
1804 Args:
1805 x: Variable or tensor.
1807 Returns:
1808 Integer, the number of scalars in `x`.
1810 Example:
1812 >>> kvar = tf.keras.backend.zeros((2,3))
1813 >>> tf.keras.backend.count_params(kvar)
1814 6
1815 >>> tf.keras.backend.eval(kvar)
1816 array([[0., 0., 0.],
1817 [0., 0., 0.]], dtype=float32)
1819 """
1820 return np.prod(x.shape.as_list())
1823@keras_export('keras.backend.cast')
1824@dispatch.add_dispatch_support
1825@doc_controls.do_not_generate_docs
1826def cast(x, dtype):
1827 """Casts a tensor to a different dtype and returns it.
1829 You can cast a Keras variable but it still returns a Keras tensor.
1831 Args:
1832 x: Keras tensor (or variable).
1833 dtype: String, either (`'float16'`, `'float32'`, or `'float64'`).
1835 Returns:
1836 Keras tensor with dtype `dtype`.
1838 Examples:
1839 Cast a float32 variable to a float64 tensor
1841 >>> input = tf.keras.backend.ones(shape=(1,3))
1842 >>> print(input)
1843 <tf.Variable 'Variable:0' shape=(1, 3) dtype=float32,
1844 numpy=array([[1., 1., 1.]], dtype=float32)>
1845 >>> cast_input = tf.keras.backend.cast(input, dtype='float64')
1846 >>> print(cast_input)
1847 tf.Tensor([[1. 1. 1.]], shape=(1, 3), dtype=float64)
1849 """
1850 return math_ops.cast(x, dtype)
1853# UPDATES OPS
1856@keras_export('keras.backend.update')
1857@doc_controls.do_not_generate_docs
1858def update(x, new_x):
1859 return state_ops.assign(x, new_x)
1862@keras_export('keras.backend.update_add')
1863@doc_controls.do_not_generate_docs
1864def update_add(x, increment):
1865 """Update the value of `x` by adding `increment`.
1867 Args:
1868 x: A Variable.
1869 increment: A tensor of same shape as `x`.
1871 Returns:
1872 The variable `x` updated.
1873 """
1874 return state_ops.assign_add(x, increment)
1877@keras_export('keras.backend.update_sub')
1878@doc_controls.do_not_generate_docs
1879def update_sub(x, decrement):
1880 """Update the value of `x` by subtracting `decrement`.
1882 Args:
1883 x: A Variable.
1884 decrement: A tensor of same shape as `x`.
1886 Returns:
1887 The variable `x` updated.
1888 """
1889 return state_ops.assign_sub(x, decrement)
1892@keras_export('keras.backend.moving_average_update')
1893@doc_controls.do_not_generate_docs
1894def moving_average_update(x, value, momentum):
1895 """Compute the exponential moving average of a value.
1897 The moving average 'x' is updated with 'value' following:
1899 ```
1900 x = x * momentum + value * (1 - momentum)
1901 ```
1903 For example:
1905 >>> x = tf.Variable(0.0)
1906 >>> momentum=0.9
1907 >>> moving_average_update(x, value = 2.0, momentum=momentum).numpy()
1908 >>> x.numpy()
1909 0.2
1911 The result will be biased towards the initial value of the variable.
1913 If the variable was initialized to zero, you can divide by
1914 `1 - momentum ** num_updates` to debias it (Section 3 of
1915 [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)):
1917 >>> num_updates = 1.0
1918 >>> x_zdb = x/(1 - momentum**num_updates)
1919 >>> x_zdb.numpy()
1920 2.0
1922 Args:
1923 x: A Variable, the moving average.
1924 value: A tensor with the same shape as `x`, the new value to be
1925 averaged in.
1926 momentum: The moving average momentum.
1928 Returns:
1929 The updated variable.
1930 """
1931 if tf2.enabled():
1932 momentum = math_ops.cast(momentum, x.dtype)
1933 value = math_ops.cast(value, x.dtype)
1934 return x.assign(x * momentum + value * (1 - momentum))
1935 else:
1936 return moving_averages.assign_moving_average(
1937 x, value, momentum, zero_debias=True)
1940# LINEAR ALGEBRA
1943@keras_export('keras.backend.dot')
1944@dispatch.add_dispatch_support
1945@doc_controls.do_not_generate_docs
1946def dot(x, y):
1947 """Multiplies 2 tensors (and/or variables) and returns a tensor.
1949 This operation corresponds to `numpy.dot(a, b, out=None)`.
1951 Args:
1952 x: Tensor or variable.
1953 y: Tensor or variable.
1955 Returns:
1956 A tensor, dot product of `x` and `y`.
1958 Examples:
1960 If inputs `x` and `y` are 2-D arrays, then it is equivalent to `tf.matmul`.
1961 >>> x = tf.keras.backend.placeholder(shape=(2, 3))
1962 >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1963 >>> xy = tf.keras.backend.dot(x, y)
1964 >>> xy
1965 <KerasTensor: shape=(2, 4) dtype=float32 ...>
1967 >>> x = tf.keras.backend.placeholder(shape=(32, 28, 3))
1968 >>> y = tf.keras.backend.placeholder(shape=(3, 4))
1969 >>> xy = tf.keras.backend.dot(x, y)
1970 >>> xy
1971 <KerasTensor: shape=(32, 28, 4) dtype=float32 ...>
1973 If `x` is an N-D array and `y` is an M-D array (where M>=2), it is a sum
1974 product over the last axis of `x` and the second-to-last axis of `y`.
1975 >>> x = tf.keras.backend.random_uniform_variable(shape=(2, 3), low=0, high=1)
1976 >>> y = tf.keras.backend.ones((4, 3, 5))
1977 >>> xy = tf.keras.backend.dot(x, y)
1978 >>> tf.keras.backend.int_shape(xy)
1979 (2, 4, 5)
1980 """
1981 if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
1982 x_shape = []
1983 for i, s in zip(int_shape(x), array_ops_stack.unstack(array_ops.shape(x))):
1984 if i is not None:
1985 x_shape.append(i)
1986 else:
1987 x_shape.append(s)
1988 x_shape = tuple(x_shape)
1989 y_shape = []
1990 for i, s in zip(int_shape(y), array_ops_stack.unstack(array_ops.shape(y))):
1991 if i is not None:
1992 y_shape.append(i)
1993 else:
1994 y_shape.append(s)
1995 y_shape = tuple(y_shape)
1996 y_permute_dim = list(range(ndim(y)))
1997 y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
1998 xt = array_ops.reshape(x, [-1, x_shape[-1]])
1999 yt = array_ops.reshape(
2000 array_ops.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
2001 return array_ops.reshape(
2002 math_ops.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
2003 if is_sparse(x):
2004 out = sparse_ops.sparse_tensor_dense_matmul(x, y)
2005 else:
2006 out = math_ops.matmul(x, y)
2007 return out
2010@keras_export('keras.backend.batch_dot')
2011@dispatch.add_dispatch_support
2012@doc_controls.do_not_generate_docs
2013def batch_dot(x, y, axes=None):
2014 """Batchwise dot product.
2016 `batch_dot` is used to compute dot product of `x` and `y` when
2017 `x` and `y` are data in batch, i.e. in a shape of
2018 `(batch_size, :)`.
2019 `batch_dot` results in a tensor or variable with less dimensions
2020 than the input. If the number of dimensions is reduced to 1,
2021 we use `expand_dims` to make sure that ndim is at least 2.
2023 Args:
2024 x: Keras tensor or variable with `ndim >= 2`.
2025 y: Keras tensor or variable with `ndim >= 2`.
2026 axes: Tuple or list of integers with target dimensions, or single integer.
2027 The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` should be equal.
2029 Returns:
2030 A tensor with shape equal to the concatenation of `x`'s shape
2031 (less the dimension that was summed over) and `y`'s shape
2032 (less the batch dimension and the dimension that was summed over).
2033 If the final rank is 1, we reshape it to `(batch_size, 1)`.
2035 Examples:
2037 >>> x_batch = tf.keras.backend.ones(shape=(32, 20, 1))
2038 >>> y_batch = tf.keras.backend.ones(shape=(32, 30, 20))
2039 >>> xy_batch_dot = tf.keras.backend.batch_dot(x_batch, y_batch, axes=(1, 2))
2040 >>> tf.keras.backend.int_shape(xy_batch_dot)
2041 (32, 1, 30)
2043 Shape inference:
2044 Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
2045 If `axes` is (1, 2), to find the output shape of resultant tensor,
2046 loop through each dimension in `x`'s shape and `y`'s shape:
2047 * `x.shape[0]` : 100 : append to output shape
2048 * `x.shape[1]` : 20 : do not append to output shape,
2049 dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
2050 * `y.shape[0]` : 100 : do not append to output shape,
2051 always ignore first dimension of `y`
2052 * `y.shape[1]` : 30 : append to output shape
2053 * `y.shape[2]` : 20 : do not append to output shape,
2054 dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
2055 `output_shape` = `(100, 30)`
2056 """
2057 x_shape = int_shape(x)
2058 y_shape = int_shape(y)
2060 x_ndim = len(x_shape)
2061 y_ndim = len(y_shape)
2063 if x_ndim < 2 or y_ndim < 2:
2064 raise ValueError('Cannot do batch_dot on inputs '
2065 'with rank < 2. '
2066 'Received inputs with shapes ' +
2067 str(x_shape) + ' and ' +
2068 str(y_shape) + '.')
2070 x_batch_size = x_shape[0]
2071 y_batch_size = y_shape[0]
2073 if x_batch_size is not None and y_batch_size is not None:
2074 if x_batch_size != y_batch_size:
2075 raise ValueError('Cannot do batch_dot on inputs '
2076 'with different batch sizes. '
2077 'Received inputs with shapes ' +
2078 str(x_shape) + ' and ' +
2079 str(y_shape) + '.')
2080 if isinstance(axes, int):
2081 axes = [axes, axes]
2083 if axes is None:
2084 if y_ndim == 2:
2085 axes = [x_ndim - 1, y_ndim - 1]
2086 else:
2087 axes = [x_ndim - 1, y_ndim - 2]
2089 if py_any(isinstance(a, (list, tuple)) for a in axes):
2090 raise ValueError('Multiple target dimensions are not supported. ' +
2091 'Expected: None, int, (int, int), ' +
2092 'Provided: ' + str(axes))
2094 # if tuple, convert to list.
2095 axes = list(axes)
2097 # convert negative indices.
2098 if axes[0] < 0:
2099 axes[0] += x_ndim
2100 if axes[1] < 0:
2101 axes[1] += y_ndim
2103 # sanity checks
2104 if 0 in axes:
2105 raise ValueError('Cannot perform batch_dot over axis 0. '
2106 'If your inputs are not batched, '
2107 'add a dummy batch dimension to your '
2108 'inputs using K.expand_dims(x, 0)')
2109 a0, a1 = axes
2110 d1 = x_shape[a0]
2111 d2 = y_shape[a1]
2113 if d1 is not None and d2 is not None and d1 != d2:
2114 raise ValueError('Cannot do batch_dot on inputs with shapes ' +
2115 str(x_shape) + ' and ' + str(y_shape) +
2116 ' with axes=' + str(axes) + '. x.shape[%d] != '
2117 'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))
2119 # backup ndims. Need them later.
2120 orig_x_ndim = x_ndim
2121 orig_y_ndim = y_ndim
2123 # if rank is 2, expand to 3.
2124 if x_ndim == 2:
2125 x = array_ops.expand_dims(x, 1)
2126 a0 += 1
2127 x_ndim += 1
2128 if y_ndim == 2:
2129 y = array_ops.expand_dims(y, 2)
2130 y_ndim += 1
2132 # bring x's dimension to be reduced to last axis.
2133 if a0 != x_ndim - 1:
2134 pattern = list(range(x_ndim))
2135 for i in range(a0, x_ndim - 1):
2136 pattern[i] = pattern[i + 1]
2137 pattern[-1] = a0
2138 x = array_ops.transpose(x, pattern)
2140 # bring y's dimension to be reduced to axis 1.
2141 if a1 != 1:
2142 pattern = list(range(y_ndim))
2143 for i in range(a1, 1, -1):
2144 pattern[i] = pattern[i - 1]
2145 pattern[1] = a1
2146 y = array_ops.transpose(y, pattern)
2148 # normalize both inputs to rank 3.
2149 if x_ndim > 3:
2150 # squash middle dimensions of x.
2151 x_shape = shape(x)
2152 x_mid_dims = x_shape[1:-1]
2153 x_squashed_shape = array_ops_stack.stack(
2154 [x_shape[0], -1, x_shape[-1]])
2155 x = array_ops.reshape(x, x_squashed_shape)
2156 x_squashed = True
2157 else:
2158 x_squashed = False
2160 if y_ndim > 3:
2161 # squash trailing dimensions of y.
2162 y_shape = shape(y)
2163 y_trail_dims = y_shape[2:]
2164 y_squashed_shape = array_ops_stack.stack(
2165 [y_shape[0], y_shape[1], -1])
2166 y = array_ops.reshape(y, y_squashed_shape)
2167 y_squashed = True
2168 else:
2169 y_squashed = False
2171 result = math_ops.matmul(x, y)
2173 # if inputs were squashed, we have to reshape the matmul output.
2174 output_shape = array_ops.shape(result)
2175 do_reshape = False
2177 if x_squashed:
2178 output_shape = array_ops.concat(
2179 [output_shape[:1],
2180 x_mid_dims,
2181 output_shape[-1:]], 0)
2182 do_reshape = True
2184 if y_squashed:
2185 output_shape = array_ops.concat([output_shape[:-1], y_trail_dims], 0)
2186 do_reshape = True
2188 if do_reshape:
2189 result = array_ops.reshape(result, output_shape)
2191 # if the inputs were originally rank 2, we remove the added 1 dim.
2192 if orig_x_ndim == 2:
2193 result = array_ops.squeeze(result, 1)
2194 elif orig_y_ndim == 2:
2195 result = array_ops.squeeze(result, -1)
2197 return result
2200@keras_export('keras.backend.transpose')
2201@dispatch.add_dispatch_support
2202@doc_controls.do_not_generate_docs
2203def transpose(x):
2204 """Transposes a tensor and returns it.
2206 Args:
2207 x: Tensor or variable.
2209 Returns:
2210 A tensor.
2212 Examples:
2214 >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2215 >>> tf.keras.backend.eval(var)
2216 array([[1., 2., 3.],
2217 [4., 5., 6.]], dtype=float32)
2218 >>> var_transposed = tf.keras.backend.transpose(var)
2219 >>> tf.keras.backend.eval(var_transposed)
2220 array([[1., 4.],
2221 [2., 5.],
2222 [3., 6.]], dtype=float32)
2223 >>> input = tf.keras.backend.placeholder((2, 3))
2224 >>> input
2225 <KerasTensor: shape=(2, 3) dtype=float32 ...>
2226 >>> input_transposed = tf.keras.backend.transpose(input)
2227 >>> input_transposed
2228 <KerasTensor: shape=(3, 2) dtype=float32 ...>
2229 """
2230 return array_ops.transpose(x)
2233@keras_export('keras.backend.gather')
2234@dispatch.add_dispatch_support
2235@doc_controls.do_not_generate_docs
2236def gather(reference, indices):
2237 """Retrieves the elements of indices `indices` in the tensor `reference`.
2239 Args:
2240 reference: A tensor.
2241 indices: An integer tensor of indices.
2243 Returns:
2244 A tensor of same type as `reference`.
2246 Examples:
2248 >>> var = tf.keras.backend.variable([[1, 2, 3], [4, 5, 6]])
2249 >>> tf.keras.backend.eval(var)
2250 array([[1., 2., 3.],
2251 [4., 5., 6.]], dtype=float32)
2252 >>> var_gathered = tf.keras.backend.gather(var, [0])
2253 >>> tf.keras.backend.eval(var_gathered)
2254 array([[1., 2., 3.]], dtype=float32)
2255 >>> var_gathered = tf.keras.backend.gather(var, [1])
2256 >>> tf.keras.backend.eval(var_gathered)
2257 array([[4., 5., 6.]], dtype=float32)
2258 >>> var_gathered = tf.keras.backend.gather(var, [0,1,0])
2259 >>> tf.keras.backend.eval(var_gathered)
2260 array([[1., 2., 3.],
2261 [4., 5., 6.],
2262 [1., 2., 3.]], dtype=float32)
2263 """
2264 return array_ops.gather(reference, indices)
2267# ELEMENT-WISE OPERATIONS
2270@keras_export('keras.backend.max')
2271@dispatch.add_dispatch_support
2272@doc_controls.do_not_generate_docs
2273def max(x, axis=None, keepdims=False):
2274 """Maximum value in a tensor.
2276 Args:
2277 x: A tensor or variable.
2278 axis: An integer, the axis to find maximum values.
2279 keepdims: A boolean, whether to keep the dimensions or not.
2280 If `keepdims` is `False`, the rank of the tensor is reduced
2281 by 1. If `keepdims` is `True`,
2282 the reduced dimension is retained with length 1.
2284 Returns:
2285 A tensor with maximum values of `x`.
2286 """
2287 return math_ops.reduce_max(x, axis, keepdims)
2290@keras_export('keras.backend.min')
2291@dispatch.add_dispatch_support
2292@doc_controls.do_not_generate_docs
2293def min(x, axis=None, keepdims=False):
2294 """Minimum value in a tensor.
2296 Args:
2297 x: A tensor or variable.
2298 axis: An integer, the axis to find minimum values.
2299 keepdims: A boolean, whether to keep the dimensions or not.
2300 If `keepdims` is `False`, the rank of the tensor is reduced
2301 by 1. If `keepdims` is `True`,
2302 the reduced dimension is retained with length 1.
2304 Returns:
2305 A tensor with minimum values of `x`.
2306 """
2307 return math_ops.reduce_min(x, axis, keepdims)
2310@keras_export('keras.backend.sum')
2311@dispatch.add_dispatch_support
2312@doc_controls.do_not_generate_docs
2313def sum(x, axis=None, keepdims=False):
2314 """Sum of the values in a tensor, alongside the specified axis.
2316 Args:
2317 x: A tensor or variable.
2318 axis: An integer, the axis to sum over.
2319 keepdims: A boolean, whether to keep the dimensions or not.
2320 If `keepdims` is `False`, the rank of the tensor is reduced
2321 by 1. If `keepdims` is `True`,
2322 the reduced dimension is retained with length 1.
2324 Returns:
2325 A tensor with sum of `x`.
2326 """
2327 return math_ops.reduce_sum(x, axis, keepdims)
2330@keras_export('keras.backend.prod')
2331@dispatch.add_dispatch_support
2332@doc_controls.do_not_generate_docs
2333def prod(x, axis=None, keepdims=False):
2334 """Multiplies the values in a tensor, alongside the specified axis.
2336 Args:
2337 x: A tensor or variable.
2338 axis: An integer, the axis to compute the product.
2339 keepdims: A boolean, whether to keep the dimensions or not.
2340 If `keepdims` is `False`, the rank of the tensor is reduced
2341 by 1. If `keepdims` is `True`,
2342 the reduced dimension is retained with length 1.
2344 Returns:
2345 A tensor with the product of elements of `x`.
2346 """
2347 return math_ops.reduce_prod(x, axis, keepdims)
2350@keras_export('keras.backend.cumsum')
2351@dispatch.add_dispatch_support
2352@doc_controls.do_not_generate_docs
2353def cumsum(x, axis=0):
2354 """Cumulative sum of the values in a tensor, alongside the specified axis.
2356 Args:
2357 x: A tensor or variable.
2358 axis: An integer, the axis to compute the sum.
2360 Returns:
2361 A tensor of the cumulative sum of values of `x` along `axis`.
2362 """
2363 return math_ops.cumsum(x, axis=axis)
2366@keras_export('keras.backend.cumprod')
2367@dispatch.add_dispatch_support
2368@doc_controls.do_not_generate_docs
2369def cumprod(x, axis=0):
2370 """Cumulative product of the values in a tensor, alongside the specified axis.
2372 Args:
2373 x: A tensor or variable.
2374 axis: An integer, the axis to compute the product.
2376 Returns:
2377 A tensor of the cumulative product of values of `x` along `axis`.
2378 """
2379 return math_ops.cumprod(x, axis=axis)
2382@keras_export('keras.backend.var')
2383@doc_controls.do_not_generate_docs
2384def var(x, axis=None, keepdims=False):
2385 """Variance of a tensor, alongside the specified axis.
2387 Args:
2388 x: A tensor or variable.
2389 axis: An integer, the axis to compute the variance.
2390 keepdims: A boolean, whether to keep the dimensions or not.
2391 If `keepdims` is `False`, the rank of the tensor is reduced
2392 by 1. If `keepdims` is `True`,
2393 the reduced dimension is retained with length 1.
2395 Returns:
2396 A tensor with the variance of elements of `x`.
2397 """
2398 if x.dtype.base_dtype == dtypes_module.bool:
2399 x = math_ops.cast(x, floatx())
2400 return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims)
2403@keras_export('keras.backend.std')
2404@dispatch.add_dispatch_support
2405@doc_controls.do_not_generate_docs
2406def std(x, axis=None, keepdims=False):
2407 """Standard deviation of a tensor, alongside the specified axis.
2409 It is an alias to `tf.math.reduce_std`.
2411 Args:
2412 x: A tensor or variable. It should have numerical dtypes. Boolean type
2413 inputs will be converted to float.
2414 axis: An integer, the axis to compute the standard deviation. If `None`
2415 (the default), reduces all dimensions. Must be in the range
2416 `[-rank(x), rank(x))`.
2417 keepdims: A boolean, whether to keep the dimensions or not.
2418 If `keepdims` is `False`, the rank of the tensor is reduced
2419 by 1. If `keepdims` is `True`, the reduced dimension is retained with
2420 length 1.
2422 Returns:
2423 A tensor with the standard deviation of elements of `x` with same dtype.
2424 Boolean type input will be converted to float.
2425 """
2426 if x.dtype.base_dtype == dtypes_module.bool:
2427 x = math_ops.cast(x, floatx())
2428 return math_ops.reduce_std(x, axis=axis, keepdims=keepdims)
2431@keras_export('keras.backend.mean')
2432@dispatch.add_dispatch_support
2433@doc_controls.do_not_generate_docs
2434def mean(x, axis=None, keepdims=False):
2435 """Mean of a tensor, alongside the specified axis.
2437 Args:
2438 x: A tensor or variable.
2439 axis: A list of integer. Axes to compute the mean.
2440 keepdims: A boolean, whether to keep the dimensions or not.
2441 If `keepdims` is `False`, the rank of the tensor is reduced
2442 by 1 for each entry in `axis`. If `keepdims` is `True`,
2443 the reduced dimensions are retained with length 1.
2445 Returns:
2446 A tensor with the mean of elements of `x`.
2447 """
2448 if x.dtype.base_dtype == dtypes_module.bool:
2449 x = math_ops.cast(x, floatx())
2450 return math_ops.reduce_mean(x, axis, keepdims)
2453@keras_export('keras.backend.any')
2454@dispatch.add_dispatch_support
2455@doc_controls.do_not_generate_docs
2456def any(x, axis=None, keepdims=False):
2457 """Bitwise reduction (logical OR).
2459 Args:
2460 x: Tensor or variable.
2461 axis: axis along which to perform the reduction.
2462 keepdims: whether the drop or broadcast the reduction axes.
2464 Returns:
2465 A uint8 tensor (0s and 1s).
2466 """
2467 x = math_ops.cast(x, dtypes_module.bool)
2468 return math_ops.reduce_any(x, axis, keepdims)
2471@keras_export('keras.backend.all')
2472@dispatch.add_dispatch_support
2473@doc_controls.do_not_generate_docs
2474def all(x, axis=None, keepdims=False):
2475 """Bitwise reduction (logical AND).
2477 Args:
2478 x: Tensor or variable.
2479 axis: axis along which to perform the reduction.
2480 keepdims: whether the drop or broadcast the reduction axes.
2482 Returns:
2483 A uint8 tensor (0s and 1s).
2484 """
2485 x = math_ops.cast(x, dtypes_module.bool)
2486 return math_ops.reduce_all(x, axis, keepdims)
2489@keras_export('keras.backend.argmax')
2490@dispatch.add_dispatch_support
2491@doc_controls.do_not_generate_docs
2492def argmax(x, axis=-1):
2493 """Returns the index of the maximum value along an axis.
2495 Args:
2496 x: Tensor or variable.
2497 axis: axis along which to perform the reduction.
2499 Returns:
2500 A tensor.
2501 """
2502 return math_ops.argmax(x, axis)
2505@keras_export('keras.backend.argmin')
2506@dispatch.add_dispatch_support
2507@doc_controls.do_not_generate_docs
2508def argmin(x, axis=-1):
2509 """Returns the index of the minimum value along an axis.
2511 Args:
2512 x: Tensor or variable.
2513 axis: axis along which to perform the reduction.
2515 Returns:
2516 A tensor.
2517 """
2518 return math_ops.argmin(x, axis)
2521@keras_export('keras.backend.square')
2522@dispatch.add_dispatch_support
2523@doc_controls.do_not_generate_docs
2524def square(x):
2525 """Element-wise square.
2527 Args:
2528 x: Tensor or variable.
2530 Returns:
2531 A tensor.
2532 """
2533 return math_ops.square(x)
2536@keras_export('keras.backend.abs')
2537@dispatch.add_dispatch_support
2538@doc_controls.do_not_generate_docs
2539def abs(x):
2540 """Element-wise absolute value.
2542 Args:
2543 x: Tensor or variable.
2545 Returns:
2546 A tensor.
2547 """
2548 return math_ops.abs(x)
2551@keras_export('keras.backend.sqrt')
2552@dispatch.add_dispatch_support
2553@doc_controls.do_not_generate_docs
2554def sqrt(x):
2555 """Element-wise square root.
2557 This function clips negative tensor values to 0 before computing the
2558 square root.
2560 Args:
2561 x: Tensor or variable.
2563 Returns:
2564 A tensor.
2565 """
2566 zero = _constant_to_tensor(0., x.dtype.base_dtype)
2567 x = math_ops.maximum(x, zero)
2568 return math_ops.sqrt(x)
2571@keras_export('keras.backend.exp')
2572@dispatch.add_dispatch_support
2573@doc_controls.do_not_generate_docs
2574def exp(x):
2575 """Element-wise exponential.
2577 Args:
2578 x: Tensor or variable.
2580 Returns:
2581 A tensor.
2582 """
2583 return math_ops.exp(x)
2586@keras_export('keras.backend.log')
2587@dispatch.add_dispatch_support
2588@doc_controls.do_not_generate_docs
2589def log(x):
2590 """Element-wise log.
2592 Args:
2593 x: Tensor or variable.
2595 Returns:
2596 A tensor.
2597 """
2598 return math_ops.log(x)
2601def logsumexp(x, axis=None, keepdims=False):
2602 """Computes log(sum(exp(elements across dimensions of a tensor))).
2604 This function is more numerically stable than log(sum(exp(x))).
2605 It avoids overflows caused by taking the exp of large inputs and
2606 underflows caused by taking the log of small inputs.
2608 Args:
2609 x: A tensor or variable.
2610 axis: An integer, the axis to reduce over.
2611 keepdims: A boolean, whether to keep the dimensions or not.
2612 If `keepdims` is `False`, the rank of the tensor is reduced
2613 by 1. If `keepdims` is `True`, the reduced dimension is
2614 retained with length 1.
2616 Returns:
2617 The reduced tensor.
2618 """
2619 return math_ops.reduce_logsumexp(x, axis, keepdims)
2622@keras_export('keras.backend.round')
2623@dispatch.add_dispatch_support
2624@doc_controls.do_not_generate_docs
2625def round(x):
2626 """Element-wise rounding to the closest integer.
2628 In case of tie, the rounding mode used is "half to even".
2630 Args:
2631 x: Tensor or variable.
2633 Returns:
2634 A tensor.
2635 """
2636 return math_ops.round(x)
2639@keras_export('keras.backend.sign')
2640@dispatch.add_dispatch_support
2641@doc_controls.do_not_generate_docs
2642def sign(x):
2643 """Element-wise sign.
2645 Args:
2646 x: Tensor or variable.
2648 Returns:
2649 A tensor.
2650 """
2651 return math_ops.sign(x)
2654@keras_export('keras.backend.pow')
2655@dispatch.add_dispatch_support
2656@doc_controls.do_not_generate_docs
2657def pow(x, a):
2658 """Element-wise exponentiation.
2660 Args:
2661 x: Tensor or variable.
2662 a: Python integer.
2664 Returns:
2665 A tensor.
2666 """
2667 return math_ops.pow(x, a)
2670@keras_export('keras.backend.clip')
2671@dispatch.add_dispatch_support
2672@doc_controls.do_not_generate_docs
2673def clip(x, min_value, max_value):
2674 """Element-wise value clipping.
2676 Args:
2677 x: Tensor or variable.
2678 min_value: Python float, integer, or tensor.
2679 max_value: Python float, integer, or tensor.
2681 Returns:
2682 A tensor.
2683 """
2684 if (isinstance(min_value, (int, float)) and
2685 isinstance(max_value, (int, float))):
2686 if max_value < min_value:
2687 max_value = min_value
2688 if min_value is None:
2689 min_value = -np.inf
2690 if max_value is None:
2691 max_value = np.inf
2692 return clip_ops.clip_by_value(x, min_value, max_value)
2695@keras_export('keras.backend.equal')
2696@dispatch.add_dispatch_support
2697@doc_controls.do_not_generate_docs
2698def equal(x, y):
2699 """Element-wise equality between two tensors.
2701 Args:
2702 x: Tensor or variable.
2703 y: Tensor or variable.
2705 Returns:
2706 A bool tensor.
2707 """
2708 return math_ops.equal(x, y)
2711@keras_export('keras.backend.not_equal')
2712@dispatch.add_dispatch_support
2713@doc_controls.do_not_generate_docs
2714def not_equal(x, y):
2715 """Element-wise inequality between two tensors.
2717 Args:
2718 x: Tensor or variable.
2719 y: Tensor or variable.
2721 Returns:
2722 A bool tensor.
2723 """
2724 return math_ops.not_equal(x, y)
2727@keras_export('keras.backend.greater')
2728@dispatch.add_dispatch_support
2729@doc_controls.do_not_generate_docs
2730def greater(x, y):
2731 """Element-wise truth value of (x > y).
2733 Args:
2734 x: Tensor or variable.
2735 y: Tensor or variable.
2737 Returns:
2738 A bool tensor.
2739 """
2740 return math_ops.greater(x, y)
2743@keras_export('keras.backend.greater_equal')
2744@dispatch.add_dispatch_support
2745@doc_controls.do_not_generate_docs
2746def greater_equal(x, y):
2747 """Element-wise truth value of (x >= y).
2749 Args:
2750 x: Tensor or variable.
2751 y: Tensor or variable.
2753 Returns:
2754 A bool tensor.
2755 """
2756 return math_ops.greater_equal(x, y)
2759@keras_export('keras.backend.less')
2760@dispatch.add_dispatch_support
2761@doc_controls.do_not_generate_docs
2762def less(x, y):
2763 """Element-wise truth value of (x < y).
2765 Args:
2766 x: Tensor or variable.
2767 y: Tensor or variable.
2769 Returns:
2770 A bool tensor.
2771 """
2772 return math_ops.less(x, y)
2775@keras_export('keras.backend.less_equal')
2776@dispatch.add_dispatch_support
2777@doc_controls.do_not_generate_docs
2778def less_equal(x, y):
2779 """Element-wise truth value of (x <= y).
2781 Args:
2782 x: Tensor or variable.
2783 y: Tensor or variable.
2785 Returns:
2786 A bool tensor.
2787 """
2788 return math_ops.less_equal(x, y)
2791@keras_export('keras.backend.maximum')
2792@dispatch.add_dispatch_support
2793@doc_controls.do_not_generate_docs
2794def maximum(x, y):
2795 """Element-wise maximum of two tensors.
2797 Args:
2798 x: Tensor or variable.
2799 y: Tensor or variable.
2801 Returns:
2802 A tensor with the element wise maximum value(s) of `x` and `y`.
2804 Examples:
2806 >>> x = tf.Variable([[1, 2], [3, 4]])
2807 >>> y = tf.Variable([[2, 1], [0, -1]])
2808 >>> m = tf.keras.backend.maximum(x, y)
2809 >>> m
2810 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2811 array([[2, 2],
2812 [3, 4]], dtype=int32)>
2813 """
2814 return math_ops.maximum(x, y)
2817@keras_export('keras.backend.minimum')
2818@dispatch.add_dispatch_support
2819@doc_controls.do_not_generate_docs
2820def minimum(x, y):
2821 """Element-wise minimum of two tensors.
2823 Args:
2824 x: Tensor or variable.
2825 y: Tensor or variable.
2827 Returns:
2828 A tensor.
2829 """
2830 return math_ops.minimum(x, y)
2833@keras_export('keras.backend.sin')
2834@dispatch.add_dispatch_support
2835@doc_controls.do_not_generate_docs
2836def sin(x):
2837 """Computes sin of x element-wise.
2839 Args:
2840 x: Tensor or variable.
2842 Returns:
2843 A tensor.
2844 """
2845 return math_ops.sin(x)
2848@keras_export('keras.backend.cos')
2849@dispatch.add_dispatch_support
2850@doc_controls.do_not_generate_docs
2851def cos(x):
2852 """Computes cos of x element-wise.
2854 Args:
2855 x: Tensor or variable.
2857 Returns:
2858 A tensor.
2859 """
2860 return math_ops.cos(x)
2863def _regular_normalize_batch_in_training(x,
2864 gamma,
2865 beta,
2866 reduction_axes,
2867 epsilon=1e-3):
2868 """Non-fused version of `normalize_batch_in_training`.
2870 Args:
2871 x: Input tensor or variable.
2872 gamma: Tensor by which to scale the input.
2873 beta: Tensor with which to center the input.
2874 reduction_axes: iterable of integers,
2875 axes over which to normalize.
2876 epsilon: Fuzz factor.
2878 Returns:
2879 A tuple length of 3, `(normalized_tensor, mean, variance)`.
2880 """
2881 mean, var = nn.moments(x, reduction_axes, None, None, False)
2882 normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
2883 return normed, mean, var
2886def _broadcast_normalize_batch_in_training(x,
2887 gamma,
2888 beta,
2889 reduction_axes,
2890 epsilon=1e-3):
2891 """Non-fused, broadcast version of `normalize_batch_in_training`.
2893 Args:
2894 x: Input tensor or variable.
2895 gamma: Tensor by which to scale the input.
2896 beta: Tensor with which to center the input.
2897 reduction_axes: iterable of integers,
2898 axes over which to normalize.
2899 epsilon: Fuzz factor.
2901 Returns:
2902 A tuple length of 3, `(normalized_tensor, mean, variance)`.
2903 """
2904 mean, var = nn.moments(x, reduction_axes, None, None, False)
2905 target_shape = []
2906 for axis in range(ndim(x)):
2907 if axis in reduction_axes:
2908 target_shape.append(1)
2909 else:
2910 target_shape.append(array_ops.shape(x)[axis])
2911 target_shape = array_ops_stack.stack(target_shape)
2913 broadcast_mean = array_ops.reshape(mean, target_shape)
2914 broadcast_var = array_ops.reshape(var, target_shape)
2915 if gamma is None:
2916 broadcast_gamma = None
2917 else:
2918 broadcast_gamma = array_ops.reshape(gamma, target_shape)
2919 if beta is None:
2920 broadcast_beta = None
2921 else:
2922 broadcast_beta = array_ops.reshape(beta, target_shape)
2924 normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
2925 broadcast_beta, broadcast_gamma, epsilon)
2926 return normed, mean, var
2929def _fused_normalize_batch_in_training(x,
2930 gamma,
2931 beta,
2932 reduction_axes,
2933 epsilon=1e-3):
2934 """Fused version of `normalize_batch_in_training`.
2936 Args:
2937 x: Input tensor or variable.
2938 gamma: Tensor by which to scale the input.
2939 beta: Tensor with which to center the input.
2940 reduction_axes: iterable of integers,
2941 axes over which to normalize.
2942 epsilon: Fuzz factor.
2944 Returns:
2945 A tuple length of 3, `(normalized_tensor, mean, variance)`.
2946 """
2947 if list(reduction_axes) == [0, 1, 2]:
2948 normalization_axis = 3
2949 tf_data_format = 'NHWC'
2950 else:
2951 normalization_axis = 1
2952 tf_data_format = 'NCHW'
2954 if gamma is None:
2955 gamma = constant_op.constant(
2956 1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2957 if beta is None:
2958 beta = constant_op.constant(
2959 0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
2961 return nn.fused_batch_norm(
2962 x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
2965@keras_export('keras.backend.normalize_batch_in_training')
2966@doc_controls.do_not_generate_docs
2967def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
2968 """Computes mean and std for batch then apply batch_normalization on batch.
2970 Args:
2971 x: Input tensor or variable.
2972 gamma: Tensor by which to scale the input.
2973 beta: Tensor with which to center the input.
2974 reduction_axes: iterable of integers,
2975 axes over which to normalize.
2976 epsilon: Fuzz factor.
2978 Returns:
2979 A tuple length of 3, `(normalized_tensor, mean, variance)`.
2980 """
2981 if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
2982 if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
2983 return _broadcast_normalize_batch_in_training(
2984 x, gamma, beta, reduction_axes, epsilon=epsilon)
2985 return _fused_normalize_batch_in_training(
2986 x, gamma, beta, reduction_axes, epsilon=epsilon)
2987 else:
2988 if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
2989 return _regular_normalize_batch_in_training(
2990 x, gamma, beta, reduction_axes, epsilon=epsilon)
2991 else:
2992 return _broadcast_normalize_batch_in_training(
2993 x, gamma, beta, reduction_axes, epsilon=epsilon)
2996@keras_export('keras.backend.batch_normalization')
2997@dispatch.add_dispatch_support
2998@doc_controls.do_not_generate_docs
2999def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
3000 """Applies batch normalization on x given mean, var, beta and gamma.
3002 I.e. returns:
3003 `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta`
3005 Args:
3006 x: Input tensor or variable.
3007 mean: Mean of batch.
3008 var: Variance of batch.
3009 beta: Tensor with which to center the input.
3010 gamma: Tensor by which to scale the input.
3011 axis: Integer, the axis that should be normalized.
3012 (typically the features axis).
3013 epsilon: Fuzz factor.
3015 Returns:
3016 A tensor.
3017 """
3018 if ndim(x) == 4:
3019 # The CPU implementation of `fused_batch_norm` only supports NHWC
3020 if axis == 1 or axis == -3:
3021 tf_data_format = 'NCHW'
3022 elif axis == 3 or axis == -1:
3023 tf_data_format = 'NHWC'
3024 else:
3025 tf_data_format = None
3027 if (tf_data_format == 'NHWC' or
3028 tf_data_format == 'NCHW' and _has_nchw_support()):
3029 # The mean / var / beta / gamma tensors may be broadcasted
3030 # so they may have extra axes of size 1, which should be squeezed.
3031 if ndim(mean) > 1:
3032 mean = array_ops.reshape(mean, [-1])
3033 if ndim(var) > 1:
3034 var = array_ops.reshape(var, [-1])
3035 if beta is None:
3036 beta = zeros_like(mean)
3037 elif ndim(beta) > 1:
3038 beta = array_ops.reshape(beta, [-1])
3039 if gamma is None:
3040 gamma = ones_like(mean)
3041 elif ndim(gamma) > 1:
3042 gamma = array_ops.reshape(gamma, [-1])
3043 y, _, _ = nn.fused_batch_norm(
3044 x,
3045 gamma,
3046 beta,
3047 epsilon=epsilon,
3048 mean=mean,
3049 variance=var,
3050 data_format=tf_data_format,
3051 is_training=False
3052 )
3053 return y
3054 return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
3057# SHAPE OPERATIONS
3060@keras_export('keras.backend.concatenate')
3061@dispatch.add_dispatch_support
3062@doc_controls.do_not_generate_docs
3063def concatenate(tensors, axis=-1):
3064 """Concatenates a list of tensors alongside the specified axis.
3066 Args:
3067 tensors: list of tensors to concatenate.
3068 axis: concatenation axis.
3070 Returns:
3071 A tensor.
3073 Example:
3075 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
3076 >>> b = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
3077 >>> tf.keras.backend.concatenate((a, b), axis=-1)
3078 <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
3079 array([[ 1, 2, 3, 10, 20, 30],
3080 [ 4, 5, 6, 40, 50, 60],
3081 [ 7, 8, 9, 70, 80, 90]], dtype=int32)>
3083 """
3084 if axis < 0:
3085 rank = ndim(tensors[0])
3086 if rank:
3087 axis %= rank
3088 else:
3089 axis = 0
3091 if py_all(is_sparse(x) for x in tensors):
3092 return sparse_ops.sparse_concat(axis, tensors)
3093 elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors):
3094 return array_ops.concat(tensors, axis)
3095 else:
3096 return array_ops.concat([to_dense(x) for x in tensors], axis)
3099@keras_export('keras.backend.reshape')
3100@dispatch.add_dispatch_support
3101@doc_controls.do_not_generate_docs
3102def reshape(x, shape):
3103 """Reshapes a tensor to the specified shape.
3105 Args:
3106 x: Tensor or variable.
3107 shape: Target shape tuple.
3109 Returns:
3110 A tensor.
3112 Example:
3114 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3115 >>> a
3116 <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3117 array([[ 1, 2, 3],
3118 [ 4, 5, 6],
3119 [ 7, 8, 9],
3120 [10, 11, 12]], dtype=int32)>
3121 >>> tf.keras.backend.reshape(a, shape=(2, 6))
3122 <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
3123 array([[ 1, 2, 3, 4, 5, 6],
3124 [ 7, 8, 9, 10, 11, 12]], dtype=int32)>
3126 """
3127 return array_ops.reshape(x, shape)
3130@keras_export('keras.backend.permute_dimensions')
3131@dispatch.add_dispatch_support
3132@doc_controls.do_not_generate_docs
3133def permute_dimensions(x, pattern):
3134 """Permutes axes in a tensor.
3136 Args:
3137 x: Tensor or variable.
3138 pattern: A tuple of
3139 dimension indices, e.g. `(0, 2, 1)`.
3141 Returns:
3142 A tensor.
3144 Example:
3146 >>> a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
3147 >>> a
3148 <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
3149 array([[ 1, 2, 3],
3150 [ 4, 5, 6],
3151 [ 7, 8, 9],
3152 [10, 11, 12]], dtype=int32)>
3153 >>> tf.keras.backend.permute_dimensions(a, pattern=(1, 0))
3154 <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3155 array([[ 1, 4, 7, 10],
3156 [ 2, 5, 8, 11],
3157 [ 3, 6, 9, 12]], dtype=int32)>
3159 """
3160 return array_ops.transpose(x, perm=pattern)
3163@keras_export('keras.backend.resize_images')
3164@dispatch.add_dispatch_support
3165@doc_controls.do_not_generate_docs
3166def resize_images(x, height_factor, width_factor, data_format,
3167 interpolation='nearest'):
3168 """Resizes the images contained in a 4D tensor.
3170 Args:
3171 x: Tensor or variable to resize.
3172 height_factor: Positive integer.
3173 width_factor: Positive integer.
3174 data_format: One of `"channels_first"`, `"channels_last"`.
3175 interpolation: A string, one of `nearest` or `bilinear`.
3177 Returns:
3178 A tensor.
3180 Raises:
3181 ValueError: in case of incorrect value for
3182 `data_format` or `interpolation`.
3183 """
3184 if data_format == 'channels_first':
3185 rows, cols = 2, 3
3186 elif data_format == 'channels_last':
3187 rows, cols = 1, 2
3188 else:
3189 raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
3191 new_shape = x.shape[rows:cols + 1]
3192 if new_shape.is_fully_defined():
3193 new_shape = constant_op.constant(new_shape.as_list(), dtype='int32')
3194 else:
3195 new_shape = array_ops.shape_v2(x)[rows:cols + 1]
3196 new_shape *= constant_op.constant(
3197 np.array([height_factor, width_factor], dtype='int32'))
3199 if data_format == 'channels_first':
3200 x = permute_dimensions(x, [0, 2, 3, 1])
3201 if interpolation == 'nearest':
3202 x = image_ops.resize_images_v2(
3203 x, new_shape, method=image_ops.ResizeMethod.NEAREST_NEIGHBOR)
3204 elif interpolation == 'bilinear':
3205 x = image_ops.resize_images_v2(x, new_shape,
3206 method=image_ops.ResizeMethod.BILINEAR)
3207 else:
3208 raise ValueError('interpolation should be one '
3209 'of "nearest" or "bilinear".')
3210 if data_format == 'channels_first':
3211 x = permute_dimensions(x, [0, 3, 1, 2])
3213 return x
3216@keras_export('keras.backend.resize_volumes')
3217@dispatch.add_dispatch_support
3218@doc_controls.do_not_generate_docs
3219def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
3220 """Resizes the volume contained in a 5D tensor.
3222 Args:
3223 x: Tensor or variable to resize.
3224 depth_factor: Positive integer.
3225 height_factor: Positive integer.
3226 width_factor: Positive integer.
3227 data_format: One of `"channels_first"`, `"channels_last"`.
3229 Returns:
3230 A tensor.
3232 Raises:
3233 ValueError: if `data_format` is neither
3234 `channels_last` or `channels_first`.
3235 """
3236 if data_format == 'channels_first':
3237 output = repeat_elements(x, depth_factor, axis=2)
3238 output = repeat_elements(output, height_factor, axis=3)
3239 output = repeat_elements(output, width_factor, axis=4)
3240 return output
3241 elif data_format == 'channels_last':
3242 output = repeat_elements(x, depth_factor, axis=1)
3243 output = repeat_elements(output, height_factor, axis=2)
3244 output = repeat_elements(output, width_factor, axis=3)
3245 return output
3246 else:
3247 raise ValueError('Invalid data_format: ' + str(data_format))
3250@keras_export('keras.backend.repeat_elements')
3251@dispatch.add_dispatch_support
3252@doc_controls.do_not_generate_docs
3253def repeat_elements(x, rep, axis):
3254 """Repeats the elements of a tensor along an axis, like `np.repeat`.
3256 If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
3257 will have shape `(s1, s2 * rep, s3)`.
3259 Args:
3260 x: Tensor or variable.
3261 rep: Python integer, number of times to repeat.
3262 axis: Axis along which to repeat.
3264 Returns:
3265 A tensor.
3267 Example:
3269 >>> b = tf.constant([1, 2, 3])
3270 >>> tf.keras.backend.repeat_elements(b, rep=2, axis=0)
3271 <tf.Tensor: shape=(6,), dtype=int32,
3272 numpy=array([1, 1, 2, 2, 3, 3], dtype=int32)>
3274 """
3275 x_shape = x.shape.as_list()
3276 # For static axis
3277 if x_shape[axis] is not None:
3278 # slices along the repeat axis
3279 splits = array_ops.split(value=x,
3280 num_or_size_splits=x_shape[axis],
3281 axis=axis)
3282 # repeat each slice the given number of reps
3283 x_rep = [s for s in splits for _ in range(rep)]
3284 return concatenate(x_rep, axis)
3286 # Here we use tf.tile to mimic behavior of np.repeat so that
3287 # we can handle dynamic shapes (that include None).
3288 # To do that, we need an auxiliary axis to repeat elements along
3289 # it and then merge them along the desired axis.
3291 # Repeating
3292 auxiliary_axis = axis + 1
3293 x_shape = array_ops.shape(x)
3294 x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
3295 reps = np.ones(len(x.shape) + 1)
3296 reps[auxiliary_axis] = rep
3297 x_rep = array_ops.tile(x_rep, reps)
3299 # Merging
3300 reps = np.delete(reps, auxiliary_axis)
3301 reps[axis] = rep
3302 reps = array_ops.constant(reps, dtype='int32')
3303 x_shape *= reps
3304 x_rep = array_ops.reshape(x_rep, x_shape)
3306 # Fix shape representation
3307 x_shape = x.shape.as_list()
3308 x_rep.set_shape(x_shape)
3309 x_rep._keras_shape = tuple(x_shape)
3310 return x_rep
3313@keras_export('keras.backend.repeat')
3314@dispatch.add_dispatch_support
3315@doc_controls.do_not_generate_docs
3316def repeat(x, n):
3317 """Repeats a 2D tensor.
3319 if `x` has shape (samples, dim) and `n` is `2`,
3320 the output will have shape `(samples, 2, dim)`.
3322 Args:
3323 x: Tensor or variable.
3324 n: Python integer, number of times to repeat.
3326 Returns:
3327 A tensor.
3329 Example:
3331 >>> b = tf.constant([[1, 2], [3, 4]])
3332 >>> b
3333 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3334 array([[1, 2],
3335 [3, 4]], dtype=int32)>
3336 >>> tf.keras.backend.repeat(b, n=2)
3337 <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3338 array([[[1, 2],
3339 [1, 2]],
3340 [[3, 4],
3341 [3, 4]]], dtype=int32)>
3343 """
3344 assert ndim(x) == 2
3345 x = array_ops.expand_dims(x, 1)
3346 pattern = array_ops_stack.stack([1, n, 1])
3347 return array_ops.tile(x, pattern)
3350@keras_export('keras.backend.arange')
3351@dispatch.add_dispatch_support
3352@doc_controls.do_not_generate_docs
3353def arange(start, stop=None, step=1, dtype='int32'):
3354 """Creates a 1D tensor containing a sequence of integers.
3356 The function arguments use the same convention as
3357 Theano's arange: if only one argument is provided,
3358 it is in fact the "stop" argument and "start" is 0.
3360 The default type of the returned tensor is `'int32'` to
3361 match TensorFlow's default.
3363 Args:
3364 start: Start value.
3365 stop: Stop value.
3366 step: Difference between two successive values.
3367 dtype: Integer dtype to use.
3369 Returns:
3370 An integer tensor.
3372 Example:
3374 >>> tf.keras.backend.arange(start=0, stop=10, step=1.5)
3375 <tf.Tensor: shape=(7,), dtype=float32,
3376 numpy=array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ], dtype=float32)>
3380 """
3381 # Match the behavior of numpy and Theano by returning an empty sequence.
3382 if stop is None and start < 0:
3383 start = 0
3384 result = math_ops.range(start, limit=stop, delta=step, name='arange')
3385 if dtype != 'int32':
3386 result = cast(result, dtype)
3387 return result
3390@keras_export('keras.backend.tile')
3391@dispatch.add_dispatch_support
3392@doc_controls.do_not_generate_docs
3393def tile(x, n):
3394 """Creates a tensor by tiling `x` by `n`.
3396 Args:
3397 x: A tensor or variable
3398 n: A list of integer. The length must be the same as the number of
3399 dimensions in `x`.
3401 Returns:
3402 A tiled tensor.
3403 """
3404 if isinstance(n, int):
3405 n = [n]
3406 return array_ops.tile(x, n)
3409@keras_export('keras.backend.flatten')
3410@dispatch.add_dispatch_support
3411@doc_controls.do_not_generate_docs
3412def flatten(x):
3413 """Flatten a tensor.
3415 Args:
3416 x: A tensor or variable.
3418 Returns:
3419 A tensor, reshaped into 1-D
3421 Example:
3423 >>> b = tf.constant([[1, 2], [3, 4]])
3424 >>> b
3425 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3426 array([[1, 2],
3427 [3, 4]], dtype=int32)>
3428 >>> tf.keras.backend.flatten(b)
3429 <tf.Tensor: shape=(4,), dtype=int32,
3430 numpy=array([1, 2, 3, 4], dtype=int32)>
3432 """
3433 return array_ops.reshape(x, [-1])
3436@keras_export('keras.backend.batch_flatten')
3437@dispatch.add_dispatch_support
3438@doc_controls.do_not_generate_docs
3439def batch_flatten(x):
3440 """Turn a nD tensor into a 2D tensor with same 0th dimension.
3442 In other words, it flattens each data samples of a batch.
3444 Args:
3445 x: A tensor or variable.
3447 Returns:
3448 A tensor.
3450 Examples:
3451 Flattening a 3D tensor to 2D by collapsing the last dimension.
3453 >>> x_batch = tf.keras.backend.ones(shape=(2, 3, 4, 5))
3454 >>> x_batch_flatten = batch_flatten(x_batch)
3455 >>> tf.keras.backend.int_shape(x_batch_flatten)
3456 (2, 60)
3458 """
3459 x = array_ops.reshape(x, array_ops_stack.stack([-1, prod(shape(x)[1:])]))
3460 return x
3463@keras_export('keras.backend.expand_dims')
3464@dispatch.add_dispatch_support
3465@doc_controls.do_not_generate_docs
3466def expand_dims(x, axis=-1):
3467 """Adds a 1-sized dimension at index "axis".
3469 Args:
3470 x: A tensor or variable.
3471 axis: Position where to add a new axis.
3473 Returns:
3474 A tensor with expanded dimensions.
3475 """
3476 return array_ops.expand_dims(x, axis)
3479@keras_export('keras.backend.squeeze')
3480@dispatch.add_dispatch_support
3481@doc_controls.do_not_generate_docs
3482def squeeze(x, axis):
3483 """Removes a 1-dimension from the tensor at index "axis".
3485 Args:
3486 x: A tensor or variable.
3487 axis: Axis to drop.
3489 Returns:
3490 A tensor with the same data as `x` but reduced dimensions.
3491 """
3492 return array_ops.squeeze(x, [axis])
3495@keras_export('keras.backend.temporal_padding')
3496@dispatch.add_dispatch_support
3497@doc_controls.do_not_generate_docs
3498def temporal_padding(x, padding=(1, 1)):
3499 """Pads the middle dimension of a 3D tensor.
3501 Args:
3502 x: Tensor or variable.
3503 padding: Tuple of 2 integers, how many zeros to
3504 add at the start and end of dim 1.
3506 Returns:
3507 A padded 3D tensor.
3508 """
3509 assert len(padding) == 2
3510 pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]
3511 return array_ops.pad(x, pattern)
3514@keras_export('keras.backend.spatial_2d_padding')
3515@dispatch.add_dispatch_support
3516@doc_controls.do_not_generate_docs
3517def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
3518 """Pads the 2nd and 3rd dimensions of a 4D tensor.
3520 Args:
3521 x: Tensor or variable.
3522 padding: Tuple of 2 tuples, padding pattern.
3523 data_format: One of `channels_last` or `channels_first`.
3525 Returns:
3526 A padded 4D tensor.
3528 Raises:
3529 ValueError: if `data_format` is neither
3530 `channels_last` or `channels_first`.
3531 """
3532 assert len(padding) == 2
3533 assert len(padding[0]) == 2
3534 assert len(padding[1]) == 2
3535 if data_format is None:
3536 data_format = image_data_format()
3537 if data_format not in {'channels_first', 'channels_last'}:
3538 raise ValueError('Unknown data_format: ' + str(data_format))
3540 if data_format == 'channels_first':
3541 pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
3542 else:
3543 pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
3544 return array_ops.pad(x, pattern)
3547@keras_export('keras.backend.spatial_3d_padding')
3548@dispatch.add_dispatch_support
3549@doc_controls.do_not_generate_docs
3550def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
3551 """Pads 5D tensor with zeros along the depth, height, width dimensions.
3553 Pads these dimensions with respectively
3554 "padding[0]", "padding[1]" and "padding[2]" zeros left and right.
3556 For 'channels_last' data_format,
3557 the 2nd, 3rd and 4th dimension will be padded.
3558 For 'channels_first' data_format,
3559 the 3rd, 4th and 5th dimension will be padded.
3561 Args:
3562 x: Tensor or variable.
3563 padding: Tuple of 3 tuples, padding pattern.
3564 data_format: One of `channels_last` or `channels_first`.
3566 Returns:
3567 A padded 5D tensor.
3569 Raises:
3570 ValueError: if `data_format` is neither
3571 `channels_last` or `channels_first`.
3573 """
3574 assert len(padding) == 3
3575 assert len(padding[0]) == 2
3576 assert len(padding[1]) == 2
3577 assert len(padding[2]) == 2
3578 if data_format is None:
3579 data_format = image_data_format()
3580 if data_format not in {'channels_first', 'channels_last'}:
3581 raise ValueError('Unknown data_format: ' + str(data_format))
3583 if data_format == 'channels_first':
3584 pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]],
3585 [padding[1][0], padding[1][1]], [padding[2][0], padding[2][1]]]
3586 else:
3587 pattern = [[0, 0], [padding[0][0], padding[0][1]],
3588 [padding[1][0], padding[1][1]], [padding[2][0],
3589 padding[2][1]], [0, 0]]
3590 return array_ops.pad(x, pattern)
3593@keras_export('keras.backend.stack')
3594@dispatch.add_dispatch_support
3595@doc_controls.do_not_generate_docs
3596def stack(x, axis=0):
3597 """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
3599 Args:
3600 x: List of tensors.
3601 axis: Axis along which to perform stacking.
3603 Returns:
3604 A tensor.
3606 Example:
3608 >>> a = tf.constant([[1, 2],[3, 4]])
3609 >>> b = tf.constant([[10, 20],[30, 40]])
3610 >>> tf.keras.backend.stack((a, b))
3611 <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3612 array([[[ 1, 2],
3613 [ 3, 4]],
3614 [[10, 20],
3615 [30, 40]]], dtype=int32)>
3617 """
3618 return array_ops_stack.stack(x, axis=axis)
3621@keras_export('keras.backend.one_hot')
3622@dispatch.add_dispatch_support
3623@doc_controls.do_not_generate_docs
3624def one_hot(indices, num_classes):
3625 """Computes the one-hot representation of an integer tensor.
3627 Args:
3628 indices: nD integer tensor of shape
3629 `(batch_size, dim1, dim2, ... dim(n-1))`
3630 num_classes: Integer, number of classes to consider.
3632 Returns:
3633 (n + 1)D one hot representation of the input
3634 with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
3636 Returns:
3637 The one-hot tensor.
3638 """
3639 return array_ops.one_hot(indices, depth=num_classes, axis=-1)
3642@keras_export('keras.backend.reverse')
3643@dispatch.add_dispatch_support
3644@doc_controls.do_not_generate_docs
3645def reverse(x, axes):
3646 """Reverse a tensor along the specified axes.
3648 Args:
3649 x: Tensor to reverse.
3650 axes: Integer or iterable of integers.
3651 Axes to reverse.
3653 Returns:
3654 A tensor.
3655 """
3656 if isinstance(axes, int):
3657 axes = [axes]
3658 return array_ops.reverse(x, axes)
3661# VALUE MANIPULATION
3662_VALUE_SET_CODE_STRING = """
3663 >>> K = tf.keras.backend # Common keras convention
3664 >>> v = K.variable(1.)
3666 >>> # reassign
3667 >>> K.set_value(v, 2.)
3668 >>> print(K.get_value(v))
3669 2.0
3671 >>> # increment
3672 >>> K.set_value(v, K.get_value(v) + 1)
3673 >>> print(K.get_value(v))
3674 3.0
3676 Variable semantics in TensorFlow 2 are eager execution friendly. The above
3677 code is roughly equivalent to:
3679 >>> v = tf.Variable(1.)
3681 >>> v.assign(2.)
3682 >>> print(v.numpy())
3683 2.0
3685 >>> v.assign_add(1.)
3686 >>> print(v.numpy())
3687 3.0"""[3:] # Prune first newline and indent to match the docstring template.
3690@keras_export('keras.backend.get_value')
3691@doc_controls.do_not_generate_docs
3692def get_value(x):
3693 """Returns the value of a variable.
3695 `backend.get_value` is the complement of `backend.set_value`, and provides
3696 a generic interface for reading from variables while abstracting away the
3697 differences between TensorFlow 1.x and 2.x semantics.
3699 {snippet}
3701 Args:
3702 x: input variable.
3704 Returns:
3705 A Numpy array.
3706 """
3707 if not tensor_util.is_tf_type(x):
3708 return x
3709 if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
3710 return x.numpy()
3711 if not getattr(x, '_in_graph_mode', True):
3712 # This is a variable which was created in an eager context, but is being
3713 # evaluated from a Graph.
3714 with context.eager_mode():
3715 return x.numpy()
3717 if ops.executing_eagerly_outside_functions():
3718 # This method of evaluating works inside the Keras FuncGraph.
3719 with ops.init_scope():
3720 return x.numpy()
3722 with x.graph.as_default():
3723 return x.eval(session=get_session((x,)))
3726@keras_export('keras.backend.batch_get_value')
3727@dispatch.add_dispatch_support
3728@doc_controls.do_not_generate_docs
3729def batch_get_value(tensors):
3730 """Returns the value of more than one tensor variable.
3732 Args:
3733 tensors: list of ops to run.
3735 Returns:
3736 A list of Numpy arrays.
3738 Raises:
3739 RuntimeError: If this method is called inside defun.
3740 """
3741 if context.executing_eagerly():
3742 return [x.numpy() for x in tensors]
3743 elif ops.inside_function(): # pylint: disable=protected-access
3744 raise RuntimeError('Cannot get value inside Tensorflow graph function.')
3745 if tensors:
3746 return get_session(tensors).run(tensors)
3747 else:
3748 return []
3751@keras_export('keras.backend.set_value')
3752@doc_controls.do_not_generate_docs
3753def set_value(x, value):
3754 """Sets the value of a variable, from a Numpy array.
3756 `backend.set_value` is the complement of `backend.get_value`, and provides
3757 a generic interface for assigning to variables while abstracting away the
3758 differences between TensorFlow 1.x and 2.x semantics.
3760 {snippet}
3762 Args:
3763 x: Variable to set to a new value.
3764 value: Value to set the tensor to, as a Numpy array
3765 (of the same shape).
3766 """
3767 value = np.asarray(value, dtype=dtype_numpy(x))
3768 if ops.executing_eagerly_outside_functions():
3769 x.assign(value)
3770 else:
3771 with get_graph().as_default():
3772 tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3773 if hasattr(x, '_assign_placeholder'):
3774 assign_placeholder = x._assign_placeholder
3775 assign_op = x._assign_op
3776 else:
3777 # In order to support assigning weights to resizable variables in
3778 # Keras, we make a placeholder with the correct number of dimensions
3779 # but with None in each dimension. This way, we can assign weights
3780 # of any size (as long as they have the correct dimensionality).
3781 placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3782 assign_placeholder = array_ops.placeholder(
3783 tf_dtype, shape=placeholder_shape)
3784 assign_op = x.assign(assign_placeholder)
3785 x._assign_placeholder = assign_placeholder
3786 x._assign_op = assign_op
3787 get_session().run(assign_op, feed_dict={assign_placeholder: value})
3790@keras_export('keras.backend.batch_set_value')
3791@dispatch.add_dispatch_support
3792@doc_controls.do_not_generate_docs
3793def batch_set_value(tuples):
3794 """Sets the values of many tensor variables at once.
3796 Args:
3797 tuples: a list of tuples `(tensor, value)`.
3798 `value` should be a Numpy array.
3799 """
3800 if context.executing_eagerly() or ops.inside_function():
3801 for x, value in tuples:
3802 x.assign(np.asarray(value, dtype=dtype_numpy(x)))
3803 else:
3804 with get_graph().as_default():
3805 if tuples:
3806 assign_ops = []
3807 feed_dict = {}
3808 for x, value in tuples:
3809 value = np.asarray(value, dtype=dtype_numpy(x))
3810 tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
3811 if hasattr(x, '_assign_placeholder'):
3812 assign_placeholder = x._assign_placeholder
3813 assign_op = x._assign_op
3814 else:
3815 # In order to support assigning weights to resizable variables in
3816 # Keras, we make a placeholder with the correct number of dimensions
3817 # but with None in each dimension. This way, we can assign weights
3818 # of any size (as long as they have the correct dimensionality).
3819 placeholder_shape = tensor_shape.TensorShape([None] * value.ndim)
3820 assign_placeholder = array_ops.placeholder(
3821 tf_dtype, shape=placeholder_shape)
3822 assign_op = x.assign(assign_placeholder)
3823 x._assign_placeholder = assign_placeholder
3824 x._assign_op = assign_op
3825 assign_ops.append(assign_op)
3826 feed_dict[assign_placeholder] = value
3827 get_session().run(assign_ops, feed_dict=feed_dict)
3830get_value.__doc__ = get_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3831set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING)
3834@keras_export('keras.backend.print_tensor')
3835@dispatch.add_dispatch_support
3836@doc_controls.do_not_generate_docs
3837def print_tensor(x, message='', summarize=3):
3838 """Prints `message` and the tensor value when evaluated.
3840 Note that `print_tensor` returns a new tensor identical to `x`
3841 which should be used in the following code. Otherwise the
3842 print operation is not taken into account during evaluation.
3844 Example:
3846 >>> x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
3847 >>> _ = tf.keras.backend.print_tensor(x)
3848 [[1 2]
3849 [3 4]]
3851 Args:
3852 x: Tensor to print.
3853 message: Message to print jointly with the tensor.
3854 summarize: The first and last `summarize` elements within each dimension
3855 are recursively printed per Tensor. If None, then the first 3 and last
3856 3 elements of each dimension are printed for each tensor. If set to
3857 -1, it will print all elements of every tensor.
3859 Returns:
3860 The same tensor `x`, unchanged.
3861 """
3862 if isinstance(x, ops.Tensor) and hasattr(x, 'graph'):
3863 with get_graph().as_default():
3864 op = logging_ops.print_v2(
3865 message, x, output_stream=sys.stdout, summarize=summarize)
3866 with ops.control_dependencies([op]):
3867 return array_ops.identity(x)
3868 else:
3869 logging_ops.print_v2(
3870 message, x, output_stream=sys.stdout, summarize=summarize)
3871 return x
3873# GRAPH MANIPULATION
3876class GraphExecutionFunction:
3877 """Runs a computation graph.
3879 It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`.
3880 In particular additional operations via `fetches` argument and additional
3881 tensor substitutions via `feed_dict` arguments. Note that given
3882 substitutions are merged with substitutions from `inputs`. Even though
3883 `feed_dict` is passed once in the constructor (called in `model.compile()`)
3884 we can modify the values in the dictionary. Through this feed_dict we can
3885 provide additional substitutions besides Keras inputs.
3887 Args:
3888 inputs: Feed placeholders to the computation graph.
3889 outputs: Output tensors to fetch.
3890 updates: Additional update ops to be run at function call.
3891 name: A name to help users identify what this function does.
3892 session_kwargs: Arguments to `tf.Session.run()`:
3893 `fetches`, `feed_dict`, `options`, `run_metadata`.
3894 """
3896 def __init__(self, inputs, outputs, updates=None, name=None,
3897 **session_kwargs):
3898 updates = updates or []
3899 if not isinstance(updates, (list, tuple)):
3900 raise TypeError('`updates` in a Keras backend function '
3901 'should be a list or tuple.')
3903 self._inputs_structure = inputs
3904 self.inputs = nest.flatten(inputs, expand_composites=True)
3905 self._outputs_structure = outputs
3906 self.outputs = cast_variables_to_tensor(
3907 nest.flatten(outputs, expand_composites=True))
3908 # TODO(b/127668432): Consider using autograph to generate these
3909 # dependencies in call.
3910 # Index 0 = total loss or model output for `predict`.
3911 with ops.control_dependencies([self.outputs[0]]):
3912 updates_ops = []
3913 for update in updates:
3914 if isinstance(update, tuple):
3915 p, new_p = update
3916 updates_ops.append(state_ops.assign(p, new_p))
3917 else:
3918 # assumed already an op
3919 updates_ops.append(update)
3920 self.updates_op = control_flow_ops.group(*updates_ops)
3921 self.name = name
3922 # additional tensor substitutions
3923 self.feed_dict = session_kwargs.pop('feed_dict', None)
3924 # additional operations
3925 self.fetches = session_kwargs.pop('fetches', [])
3926 if not isinstance(self.fetches, list):
3927 self.fetches = [self.fetches]
3928 self.run_options = session_kwargs.pop('options', None)
3929 self.run_metadata = session_kwargs.pop('run_metadata', None)
3930 # The main use case of `fetches` being passed to a model is the ability
3931 # to run custom updates
3932 # This requires us to wrap fetches in `identity` ops.
3933 self.fetches = [array_ops.identity(x) for x in self.fetches]
3934 self.session_kwargs = session_kwargs
3935 # This mapping keeps track of the function that should receive the
3936 # output from a fetch in `fetches`: { fetch: function(fetch_output) }
3937 # A Callback can use this to register a function with access to the
3938 # output values for a fetch it added.
3939 self.fetch_callbacks = {}
3941 if session_kwargs:
3942 raise ValueError('Some keys in session_kwargs are not supported at this '
3943 'time: %s' % (session_kwargs.keys(),))
3945 self._callable_fn = None
3946 self._feed_arrays = None
3947 self._feed_symbols = None
3948 self._symbol_vals = None
3949 self._fetches = None
3950 self._session = None
3952 def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
3953 """Generates a callable that runs the graph.
3955 Args:
3956 feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
3957 feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
3958 symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
3959 session: Session to use to generate the callable.
3961 Returns:
3962 Function that runs the graph according to the above options.
3963 """
3964 # Prepare callable options.
3965 callable_opts = config_pb2.CallableOptions()
3966 # Handle external-data feed.
3967 for x in feed_arrays:
3968 callable_opts.feed.append(x.name)
3969 if self.feed_dict:
3970 for key in sorted(self.feed_dict.keys()):
3971 callable_opts.feed.append(key.name)
3972 # Handle symbolic feed.
3973 for x, y in zip(feed_symbols, symbol_vals):
3974 connection = callable_opts.tensor_connection.add()
3975 if x.dtype != y.dtype:
3976 y = math_ops.cast(y, dtype=x.dtype)
3977 from_tensor = _as_graph_element(y)
3978 if from_tensor is None:
3979 from_tensor = y
3980 connection.from_tensor = from_tensor.name # Data tensor
3981 connection.to_tensor = x.name # Placeholder
3982 # Handle fetches.
3983 for x in self.outputs + self.fetches:
3984 callable_opts.fetch.append(x.name)
3985 # Handle updates.
3986 callable_opts.target.append(self.updates_op.name)
3987 # Handle run_options.
3988 if self.run_options:
3989 callable_opts.run_options.CopyFrom(self.run_options)
3990 # Create callable.
3991 callable_fn = session._make_callable_from_options(callable_opts)
3992 # Cache parameters corresponding to the generated callable, so that
3993 # we can detect future mismatches and refresh the callable.
3994 self._callable_fn = callable_fn
3995 self._feed_arrays = feed_arrays
3996 self._feed_symbols = feed_symbols
3997 self._symbol_vals = symbol_vals
3998 self._fetches = list(self.fetches)
3999 self._session = session
4001 def _call_fetch_callbacks(self, fetches_output):
4002 for fetch, output in zip(self._fetches, fetches_output):
4003 if fetch in self.fetch_callbacks:
4004 self.fetch_callbacks[fetch](output)
4006 def _eval_if_composite(self, tensor):
4007 """Helper method which evaluates any CompositeTensors passed to it."""
4008 # We need to evaluate any composite tensor objects that have been
4009 # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
4010 # actual CompositeTensor objects instead of the value(s) contained in the
4011 # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
4012 # this ensures that we return its value as a SparseTensorValue rather than
4013 # a SparseTensor.
4014 from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
4015 if tf_utils.is_extension_type(tensor):
4016 return self._session.run(tensor)
4017 else:
4018 return tensor
4020 def __call__(self, inputs):
4021 inputs = nest.flatten(inputs, expand_composites=True)
4023 session = get_session(inputs)
4024 feed_arrays = []
4025 array_vals = []
4026 feed_symbols = []
4027 symbol_vals = []
4028 for tensor, value in zip(self.inputs, inputs):
4029 if value is None:
4030 continue
4032 if tensor_util.is_tf_type(value):
4033 # Case: feeding symbolic tensor.
4034 feed_symbols.append(tensor)
4035 symbol_vals.append(value)
4036 else:
4037 # Case: feeding Numpy array.
4038 feed_arrays.append(tensor)
4039 # We need to do array conversion and type casting at this level, since
4040 # `callable_fn` only supports exact matches.
4041 tensor_type = dtypes_module.as_dtype(tensor.dtype)
4042 array_vals.append(np.asarray(value,
4043 dtype=tensor_type.as_numpy_dtype))
4045 if self.feed_dict:
4046 for key in sorted(self.feed_dict.keys()):
4047 array_vals.append(
4048 np.asarray(self.feed_dict[key], dtype=key.dtype.as_numpy_dtype))
4050 # Refresh callable if anything has changed.
4051 if (self._callable_fn is None or feed_arrays != self._feed_arrays or
4052 symbol_vals != self._symbol_vals or
4053 feed_symbols != self._feed_symbols or self.fetches != self._fetches or
4054 session != self._session):
4055 self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
4057 fetched = self._callable_fn(*array_vals,
4058 run_metadata=self.run_metadata)
4059 self._call_fetch_callbacks(fetched[-len(self._fetches):])
4060 output_structure = nest.pack_sequence_as(
4061 self._outputs_structure,
4062 fetched[:len(self.outputs)],
4063 expand_composites=True)
4064 # We need to evaluate any composite tensor objects that have been
4065 # reconstructed in 'pack_sequence_as', since otherwise they'll be output as
4066 # actual CompositeTensor objects instead of the value(s) contained in the
4067 # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
4068 # this ensures that we return its value as a SparseTensorValue rather than
4069 # a SparseTensor.
4070 return nest.map_structure(self._eval_if_composite, output_structure)
4073@keras_export('keras.backend.function')
4074@doc_controls.do_not_generate_docs
4075def function(inputs, outputs, updates=None, name=None, **kwargs):
4076 """Instantiates a Keras function.
4078 Args:
4079 inputs: List of placeholder tensors.
4080 outputs: List of output tensors.
4081 updates: List of update ops.
4082 name: String, name of function.
4083 **kwargs: Passed to `tf.Session.run`.
4085 Returns:
4086 Output values as Numpy arrays.
4088 Raises:
4089 ValueError: if invalid kwargs are passed in or if in eager execution.
4090 """
4091 if ops.executing_eagerly_outside_functions():
4092 if kwargs:
4093 raise ValueError('Session keyword arguments are not supported during '
4094 'eager execution. You passed: %s' % (kwargs,))
4095 if updates:
4096 raise ValueError('`updates` argument is not supported during '
4097 'eager execution. You passed: %s' % (updates,))
4098 from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
4099 from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
4100 model = models.Model(inputs=inputs, outputs=outputs)
4102 wrap_outputs = isinstance(outputs, list) and len(outputs) == 1
4103 def func(model_inputs):
4104 outs = model(model_inputs)
4105 if wrap_outputs:
4106 outs = [outs]
4107 return tf_utils.sync_to_numpy_or_python_type(outs)
4109 return func
4111 if kwargs:
4112 for key in kwargs:
4113 if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0]
4114 and key not in ['inputs', 'outputs', 'updates', 'name']):
4115 msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
4116 'backend') % key
4117 raise ValueError(msg)
4118 return GraphExecutionFunction(
4119 inputs, outputs, updates=updates, name=name, **kwargs)
4122@keras_export('keras.backend.gradients')
4123@doc_controls.do_not_generate_docs
4124def gradients(loss, variables):
4125 """Returns the gradients of `loss` w.r.t. `variables`.
4127 Args:
4128 loss: Scalar tensor to minimize.
4129 variables: List of variables.
4131 Returns:
4132 A gradients tensor.
4133 """
4134 return gradients_module.gradients(
4135 loss, variables, colocate_gradients_with_ops=True)
4138@keras_export('keras.backend.stop_gradient')
4139@dispatch.add_dispatch_support
4140@doc_controls.do_not_generate_docs
4141def stop_gradient(variables):
4142 """Returns `variables` but with zero gradient w.r.t. every other variable.
4144 Args:
4145 variables: Tensor or list of tensors to consider constant with respect
4146 to any other variable.
4149 Returns:
4150 A single tensor or a list of tensors (depending on the passed argument)
4151 that has no gradient with respect to any other variable.
4152 """
4153 if isinstance(variables, (list, tuple)):
4154 return map(array_ops.stop_gradient, variables)
4155 return array_ops.stop_gradient(variables)
4158# CONTROL FLOW
4161@keras_export('keras.backend.rnn')
4162@dispatch.add_dispatch_support
4163def rnn(step_function,
4164 inputs,
4165 initial_states,
4166 go_backwards=False,
4167 mask=None,
4168 constants=None,
4169 unroll=False,
4170 input_length=None,
4171 time_major=False,
4172 zero_output_for_mask=False):
4173 """Iterates over the time dimension of a tensor.
4175 Args:
4176 step_function: RNN step function.
4177 Args;
4178 input; Tensor with shape `(samples, ...)` (no time dimension),
4179 representing input for the batch of samples at a certain
4180 time step.
4181 states; List of tensors.
4182 Returns;
4183 output; Tensor with shape `(samples, output_dim)`
4184 (no time dimension).
4185 new_states; List of tensors, same length and shapes
4186 as 'states'. The first state in the list must be the
4187 output tensor at the previous timestep.
4188 inputs: Tensor of temporal data of shape `(samples, time, ...)`
4189 (at least 3D), or nested tensors, and each of which has shape
4190 `(samples, time, ...)`.
4191 initial_states: Tensor with shape `(samples, state_size)`
4192 (no time dimension), containing the initial values for the states used
4193 in the step function. In the case that state_size is in a nested
4194 shape, the shape of initial_states will also follow the nested
4195 structure.
4196 go_backwards: Boolean. If True, do the iteration over the time
4197 dimension in reverse order and return the reversed sequence.
4198 mask: Binary tensor with shape `(samples, time, 1)`,
4199 with a zero for every element that is masked.
4200 constants: List of constant values passed at each step.
4201 unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
4202 input_length: An integer or a 1-D Tensor, depending on whether
4203 the time dimension is fixed-length or not. In case of variable length
4204 input, it is used for masking in case there's no mask specified.
4205 time_major: Boolean. If true, the inputs and outputs will be in shape
4206 `(timesteps, batch, ...)`, whereas in the False case, it will be
4207 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
4208 efficient because it avoids transposes at the beginning and end of the
4209 RNN calculation. However, most TensorFlow data is batch-major, so by
4210 default this function accepts input and emits output in batch-major
4211 form.
4212 zero_output_for_mask: Boolean. If True, the output for masked timestep
4213 will be zeros, whereas in the False case, output from previous
4214 timestep is returned.
4216 Returns:
4217 A tuple, `(last_output, outputs, new_states)`.
4218 last_output: the latest output of the rnn, of shape `(samples, ...)`
4219 outputs: tensor with shape `(samples, time, ...)` where each
4220 entry `outputs[s, t]` is the output of the step function
4221 at time `t` for sample `s`.
4222 new_states: list of tensors, latest states returned by
4223 the step function, of shape `(samples, ...)`.
4225 Raises:
4226 ValueError: if input dimension is less than 3.
4227 ValueError: if `unroll` is `True` but input timestep is not a fixed
4228 number.
4229 ValueError: if `mask` is provided (not `None`) but states is not provided
4230 (`len(states)` == 0).
4231 """
4233 def swap_batch_timestep(input_t):
4234 # Swap the batch and timestep dim for the incoming tensor.
4235 axes = list(range(len(input_t.shape)))
4236 axes[0], axes[1] = 1, 0
4237 return array_ops.transpose(input_t, axes)
4239 if not time_major:
4240 inputs = nest.map_structure(swap_batch_timestep, inputs)
4242 flatted_inputs = nest.flatten(inputs)
4243 time_steps = flatted_inputs[0].shape[0]
4244 batch = flatted_inputs[0].shape[1]
4245 time_steps_t = array_ops.shape(flatted_inputs[0])[0]
4247 for input_ in flatted_inputs:
4248 input_.shape.with_rank_at_least(3)
4250 if mask is not None:
4251 if mask.dtype != dtypes_module.bool:
4252 mask = math_ops.cast(mask, dtypes_module.bool)
4253 if len(mask.shape) == 2:
4254 mask = expand_dims(mask)
4255 if not time_major:
4256 mask = swap_batch_timestep(mask)
4258 if constants is None:
4259 constants = []
4261 # tf.where needs its condition tensor to be the same shape as its two
4262 # result tensors, but in our case the condition (mask) tensor is
4263 # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
4264 # So we need to broadcast the mask to match the shape of inputs.
4265 # That's what the tile call does, it just repeats the mask along its
4266 # second dimension n times.
4267 def _expand_mask(mask_t, input_t, fixed_dim=1):
4268 if nest.is_nested(mask_t):
4269 raise ValueError('mask_t is expected to be tensor, but got %s' % mask_t)
4270 if nest.is_nested(input_t):
4271 raise ValueError('input_t is expected to be tensor, but got %s' % input_t)
4272 rank_diff = len(input_t.shape) - len(mask_t.shape)
4273 for _ in range(rank_diff):
4274 mask_t = array_ops.expand_dims(mask_t, -1)
4275 multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
4276 return array_ops.tile(mask_t, multiples)
4278 if unroll:
4279 if not time_steps:
4280 raise ValueError('Unrolling requires a fixed number of timesteps.')
4281 states = tuple(initial_states)
4282 successive_states = []
4283 successive_outputs = []
4285 # Process the input tensors. The input tensor need to be split on the
4286 # time_step dim, and reverse if go_backwards is True. In the case of nested
4287 # input, the input is flattened and then transformed individually.
4288 # The result of this will be a tuple of lists, each of the item in tuple is
4289 # list of the tensor with shape (batch, feature)
4290 def _process_single_input_t(input_t):
4291 input_t = array_ops_stack.unstack(input_t) # unstack for time_step dim
4292 if go_backwards:
4293 input_t.reverse()
4294 return input_t
4296 if nest.is_nested(inputs):
4297 processed_input = nest.map_structure(_process_single_input_t, inputs)
4298 else:
4299 processed_input = (_process_single_input_t(inputs),)
4301 def _get_input_tensor(time):
4302 inp = [t_[time] for t_ in processed_input]
4303 return nest.pack_sequence_as(inputs, inp)
4305 if mask is not None:
4306 mask_list = array_ops_stack.unstack(mask)
4307 if go_backwards:
4308 mask_list.reverse()
4310 for i in range(time_steps):
4311 inp = _get_input_tensor(i)
4312 mask_t = mask_list[i]
4313 output, new_states = step_function(inp,
4314 tuple(states) + tuple(constants))
4315 tiled_mask_t = _expand_mask(mask_t, output)
4317 if not successive_outputs:
4318 prev_output = zeros_like(output)
4319 else:
4320 prev_output = successive_outputs[-1]
4322 output = array_ops.where_v2(tiled_mask_t, output, prev_output)
4324 flat_states = nest.flatten(states)
4325 flat_new_states = nest.flatten(new_states)
4326 tiled_mask_t = tuple(_expand_mask(mask_t, s) for s in flat_states)
4327 flat_final_states = tuple(
4328 array_ops.where_v2(m, s, ps)
4329 for m, s, ps in zip(tiled_mask_t, flat_new_states, flat_states))
4330 states = nest.pack_sequence_as(states, flat_final_states)
4332 successive_outputs.append(output)
4333 successive_states.append(states)
4334 last_output = successive_outputs[-1]
4335 new_states = successive_states[-1]
4336 outputs = array_ops_stack.stack(successive_outputs)
4338 if zero_output_for_mask:
4339 last_output = array_ops.where_v2(
4340 _expand_mask(mask_list[-1], last_output), last_output,
4341 zeros_like(last_output))
4342 outputs = array_ops.where_v2(
4343 _expand_mask(mask, outputs, fixed_dim=2), outputs,
4344 zeros_like(outputs))
4346 else: # mask is None
4347 for i in range(time_steps):
4348 inp = _get_input_tensor(i)
4349 output, states = step_function(inp, tuple(states) + tuple(constants))
4350 successive_outputs.append(output)
4351 successive_states.append(states)
4352 last_output = successive_outputs[-1]
4353 new_states = successive_states[-1]
4354 outputs = array_ops_stack.stack(successive_outputs)
4356 else: # Unroll == False
4357 states = tuple(initial_states)
4359 # Create input tensor array, if the inputs is nested tensors, then it will
4360 # be flattened first, and tensor array will be created one per flattened
4361 # tensor.
4362 input_ta = tuple(
4363 tensor_array_ops.TensorArray(
4364 dtype=inp.dtype,
4365 size=time_steps_t,
4366 tensor_array_name='input_ta_%s' % i)
4367 for i, inp in enumerate(flatted_inputs))
4368 input_ta = tuple(
4369 ta.unstack(input_) if not go_backwards else ta
4370 .unstack(reverse(input_, 0))
4371 for ta, input_ in zip(input_ta, flatted_inputs))
4373 # Get the time(0) input and compute the output for that, the output will be
4374 # used to determine the dtype of output tensor array. Don't read from
4375 # input_ta due to TensorArray clear_after_read default to True.
4376 input_time_zero = nest.pack_sequence_as(inputs,
4377 [inp[0] for inp in flatted_inputs])
4378 # output_time_zero is used to determine the cell output shape and its dtype.
4379 # the value is discarded.
4380 output_time_zero, _ = step_function(
4381 input_time_zero, tuple(initial_states) + tuple(constants))
4382 output_ta = tuple(
4383 tensor_array_ops.TensorArray(
4384 dtype=out.dtype,
4385 size=time_steps_t,
4386 element_shape=out.shape,
4387 tensor_array_name='output_ta_%s' % i)
4388 for i, out in enumerate(nest.flatten(output_time_zero)))
4390 time = constant_op.constant(0, dtype='int32', name='time')
4392 # We only specify the 'maximum_iterations' when building for XLA since that
4393 # causes slowdowns on GPU in TF.
4394 if (not context.executing_eagerly() and
4395 control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
4396 max_iterations = math_ops.reduce_max(input_length)
4397 else:
4398 max_iterations = None
4400 while_loop_kwargs = {
4401 'cond': lambda time, *_: time < time_steps_t,
4402 'maximum_iterations': max_iterations,
4403 'parallel_iterations': 32,
4404 'swap_memory': True,
4405 }
4406 if mask is not None:
4407 if go_backwards:
4408 mask = reverse(mask, 0)
4410 mask_ta = tensor_array_ops.TensorArray(
4411 dtype=dtypes_module.bool,
4412 size=time_steps_t,
4413 tensor_array_name='mask_ta')
4414 mask_ta = mask_ta.unstack(mask)
4416 def masking_fn(time):
4417 return mask_ta.read(time)
4419 def compute_masked_output(mask_t, flat_out, flat_mask):
4420 tiled_mask_t = tuple(
4421 _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
4422 for o in flat_out)
4423 return tuple(
4424 array_ops.where_v2(m, o, fm)
4425 for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask))
4426 elif isinstance(input_length, ops.Tensor):
4427 if go_backwards:
4428 max_len = math_ops.reduce_max(input_length, axis=0)
4429 rev_input_length = math_ops.subtract(max_len - 1, input_length)
4431 def masking_fn(time):
4432 return math_ops.less(rev_input_length, time)
4433 else:
4435 def masking_fn(time):
4436 return math_ops.greater(input_length, time)
4438 def compute_masked_output(mask_t, flat_out, flat_mask):
4439 return tuple(
4440 array_ops.where(mask_t, o, zo)
4441 for (o, zo) in zip(flat_out, flat_mask))
4442 else:
4443 masking_fn = None
4445 if masking_fn is not None:
4446 # Mask for the T output will be base on the output of T - 1. In the case
4447 # T = 0, a zero filled tensor will be used.
4448 flat_zero_output = tuple(array_ops.zeros_like(o)
4449 for o in nest.flatten(output_time_zero))
4450 def _step(time, output_ta_t, prev_output, *states):
4451 """RNN step function.
4453 Args:
4454 time: Current timestep value.
4455 output_ta_t: TensorArray.
4456 prev_output: tuple of outputs from time - 1.
4457 *states: List of states.
4459 Returns:
4460 Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
4461 """
4462 current_input = tuple(ta.read(time) for ta in input_ta)
4463 # maybe set shape.
4464 current_input = nest.pack_sequence_as(inputs, current_input)
4465 mask_t = masking_fn(time)
4466 output, new_states = step_function(current_input,
4467 tuple(states) + tuple(constants))
4468 # mask output
4469 flat_output = nest.flatten(output)
4470 flat_mask_output = (flat_zero_output if zero_output_for_mask
4471 else nest.flatten(prev_output))
4472 flat_new_output = compute_masked_output(mask_t, flat_output,
4473 flat_mask_output)
4475 # mask states
4476 flat_state = nest.flatten(states)
4477 flat_new_state = nest.flatten(new_states)
4478 for state, new_state in zip(flat_state, flat_new_state):
4479 if isinstance(new_state, ops.Tensor):
4480 new_state.set_shape(state.shape)
4481 flat_final_state = compute_masked_output(mask_t, flat_new_state,
4482 flat_state)
4483 new_states = nest.pack_sequence_as(new_states, flat_final_state)
4485 output_ta_t = tuple(
4486 ta.write(time, out)
4487 for ta, out in zip(output_ta_t, flat_new_output))
4488 return (time + 1, output_ta_t,
4489 tuple(flat_new_output)) + tuple(new_states)
4491 final_outputs = while_loop.while_loop(
4492 body=_step,
4493 loop_vars=(time, output_ta, flat_zero_output) + states,
4494 **while_loop_kwargs)
4495 # Skip final_outputs[2] which is the output for final timestep.
4496 new_states = final_outputs[3:]
4497 else:
4498 def _step(time, output_ta_t, *states):
4499 """RNN step function.
4501 Args:
4502 time: Current timestep value.
4503 output_ta_t: TensorArray.
4504 *states: List of states.
4506 Returns:
4507 Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
4508 """
4509 current_input = tuple(ta.read(time) for ta in input_ta)
4510 current_input = nest.pack_sequence_as(inputs, current_input)
4511 output, new_states = step_function(current_input,
4512 tuple(states) + tuple(constants))
4513 flat_state = nest.flatten(states)
4514 flat_new_state = nest.flatten(new_states)
4515 for state, new_state in zip(flat_state, flat_new_state):
4516 if isinstance(new_state, ops.Tensor):
4517 new_state.set_shape(state.shape)
4519 flat_output = nest.flatten(output)
4520 output_ta_t = tuple(
4521 ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
4522 new_states = nest.pack_sequence_as(initial_states, flat_new_state)
4523 return (time + 1, output_ta_t) + tuple(new_states)
4525 final_outputs = while_loop.while_loop(
4526 body=_step, loop_vars=(time, output_ta) + states, **while_loop_kwargs)
4527 new_states = final_outputs[2:]
4529 output_ta = final_outputs[1]
4531 outputs = tuple(o.stack() for o in output_ta)
4532 last_output = tuple(o[-1] for o in outputs)
4534 outputs = nest.pack_sequence_as(output_time_zero, outputs)
4535 last_output = nest.pack_sequence_as(output_time_zero, last_output)
4537 # static shape inference
4538 def set_shape(output_):
4539 if isinstance(output_, ops.Tensor):
4540 shape = output_.shape.as_list()
4541 shape[0] = time_steps
4542 shape[1] = batch
4543 output_.set_shape(shape)
4544 return output_
4546 outputs = nest.map_structure(set_shape, outputs)
4548 if not time_major:
4549 outputs = nest.map_structure(swap_batch_timestep, outputs)
4551 return last_output, outputs, new_states
4554@keras_export('keras.backend.switch')
4555@dispatch.add_dispatch_support
4556@doc_controls.do_not_generate_docs
4557def switch(condition, then_expression, else_expression):
4558 """Switches between two operations depending on a scalar value.
4560 Note that both `then_expression` and `else_expression`
4561 should be symbolic tensors of the *same shape*.
4563 Args:
4564 condition: tensor (`int` or `bool`).
4565 then_expression: either a tensor, or a callable that returns a tensor.
4566 else_expression: either a tensor, or a callable that returns a tensor.
4568 Returns:
4569 The selected tensor.
4571 Raises:
4572 ValueError: If rank of `condition` is greater than rank of expressions.
4573 """
4574 if condition.dtype != dtypes_module.bool:
4575 condition = math_ops.cast(condition, 'bool')
4576 cond_ndim = ndim(condition)
4577 if not cond_ndim:
4578 if not callable(then_expression):
4580 def then_expression_fn():
4581 return then_expression
4582 else:
4583 then_expression_fn = then_expression
4584 if not callable(else_expression):
4586 def else_expression_fn():
4587 return else_expression
4588 else:
4589 else_expression_fn = else_expression
4590 x = cond.cond(condition, then_expression_fn, else_expression_fn)
4591 else:
4592 # tf.where needs its condition tensor
4593 # to be the same shape as its two
4594 # result tensors
4595 if callable(then_expression):
4596 then_expression = then_expression()
4597 if callable(else_expression):
4598 else_expression = else_expression()
4599 expr_ndim = ndim(then_expression)
4600 if cond_ndim > expr_ndim:
4601 raise ValueError('Rank of `condition` should be less than or'
4602 ' equal to rank of `then_expression` and '
4603 '`else_expression`. ndim(condition)=' + str(cond_ndim) +
4604 ', ndim(then_expression)'
4605 '=' + str(expr_ndim))
4606 if cond_ndim > 1:
4607 ndim_diff = expr_ndim - cond_ndim
4608 cond_shape = array_ops.concat(
4609 [array_ops.shape(condition), [1] * ndim_diff], axis=0)
4610 condition = array_ops.reshape(condition, cond_shape)
4611 expr_shape = array_ops.shape(then_expression)
4612 shape_diff = expr_shape - cond_shape
4613 tile_shape = array_ops.where_v2(shape_diff > 0, expr_shape,
4614 array_ops.ones_like(expr_shape))
4615 condition = array_ops.tile(condition, tile_shape)
4616 x = array_ops.where_v2(condition, then_expression, else_expression)
4617 return x
4620@keras_export('keras.backend.in_train_phase')
4621@doc_controls.do_not_generate_docs
4622def in_train_phase(x, alt, training=None):
4623 """Selects `x` in train phase, and `alt` otherwise.
4625 Note that `alt` should have the *same shape* as `x`.
4627 Args:
4628 x: What to return in train phase
4629 (tensor or callable that returns a tensor).
4630 alt: What to return otherwise
4631 (tensor or callable that returns a tensor).
4632 training: Optional scalar tensor
4633 (or Python boolean, or Python integer)
4634 specifying the learning phase.
4636 Returns:
4637 Either `x` or `alt` based on the `training` flag.
4638 the `training` flag defaults to `K.learning_phase()`.
4639 """
4640 from tensorflow.python.keras.engine import base_layer_utils # pylint: disable=g-import-not-at-top
4641 if training is None:
4642 training = base_layer_utils.call_context().training
4644 if training is None:
4645 training = learning_phase()
4647 # TODO(b/138862903): Handle the case when training is tensor.
4648 if not tensor_util.is_tf_type(training):
4649 if training == 1 or training is True:
4650 if callable(x):
4651 return x()
4652 else:
4653 return x
4655 elif training == 0 or training is False:
4656 if callable(alt):
4657 return alt()
4658 else:
4659 return alt
4661 # else: assume learning phase is a placeholder tensor.
4662 x = switch(training, x, alt)
4663 return x
4666@keras_export('keras.backend.in_test_phase')
4667@doc_controls.do_not_generate_docs
4668def in_test_phase(x, alt, training=None):
4669 """Selects `x` in test phase, and `alt` otherwise.
4671 Note that `alt` should have the *same shape* as `x`.
4673 Args:
4674 x: What to return in test phase
4675 (tensor or callable that returns a tensor).
4676 alt: What to return otherwise
4677 (tensor or callable that returns a tensor).
4678 training: Optional scalar tensor
4679 (or Python boolean, or Python integer)
4680 specifying the learning phase.
4682 Returns:
4683 Either `x` or `alt` based on `K.learning_phase`.
4684 """
4685 return in_train_phase(alt, x, training=training)
4688# NN OPERATIONS
4691@keras_export('keras.backend.relu')
4692@dispatch.add_dispatch_support
4693@doc_controls.do_not_generate_docs
4694def relu(x, alpha=0., max_value=None, threshold=0):
4695 """Rectified linear unit.
4697 With default values, it returns element-wise `max(x, 0)`.
4699 Otherwise, it follows:
4700 `f(x) = max_value` for `x >= max_value`,
4701 `f(x) = x` for `threshold <= x < max_value`,
4702 `f(x) = alpha * (x - threshold)` otherwise.
4704 Args:
4705 x: A tensor or variable.
4706 alpha: A scalar, slope of negative section (default=`0.`).
4707 max_value: float. Saturation threshold.
4708 threshold: float. Threshold value for thresholded activation.
4710 Returns:
4711 A tensor.
4712 """
4713 # While x can be a tensor or variable, we also see cases where
4714 # numpy arrays, lists, tuples are passed as well.
4715 # lists, tuples do not have 'dtype' attribute.
4716 dtype = getattr(x, 'dtype', floatx())
4717 if alpha != 0.:
4718 if max_value is None and threshold == 0:
4719 return nn.leaky_relu(x, alpha=alpha)
4721 if threshold != 0:
4722 negative_part = nn.relu(-x + threshold)
4723 else:
4724 negative_part = nn.relu(-x)
4726 clip_max = max_value is not None
4728 if threshold != 0:
4729 # computes x for x > threshold else 0
4730 x = x * math_ops.cast(math_ops.greater(x, threshold), dtype=dtype)
4731 elif max_value == 6:
4732 # if no threshold, then can use nn.relu6 native TF op for performance
4733 x = nn.relu6(x)
4734 clip_max = False
4735 else:
4736 x = nn.relu(x)
4738 if clip_max:
4739 max_value = _constant_to_tensor(max_value, x.dtype.base_dtype)
4740 zero = _constant_to_tensor(0, x.dtype.base_dtype)
4741 x = clip_ops.clip_by_value(x, zero, max_value)
4743 if alpha != 0.:
4744 alpha = _to_tensor(alpha, x.dtype.base_dtype)
4745 x -= alpha * negative_part
4746 return x
4749@keras_export('keras.backend.elu')
4750@dispatch.add_dispatch_support
4751@doc_controls.do_not_generate_docs
4752def elu(x, alpha=1.):
4753 """Exponential linear unit.
4755 Args:
4756 x: A tensor or variable to compute the activation function for.
4757 alpha: A scalar, slope of negative section.
4759 Returns:
4760 A tensor.
4761 """
4762 res = nn.elu(x)
4763 if alpha == 1:
4764 return res
4765 else:
4766 return array_ops.where_v2(x > 0, res, alpha * res)
4769@keras_export('keras.backend.softmax')
4770@dispatch.add_dispatch_support
4771@doc_controls.do_not_generate_docs
4772def softmax(x, axis=-1):
4773 """Softmax of a tensor.
4775 Args:
4776 x: A tensor or variable.
4777 axis: The dimension softmax would be performed on.
4778 The default is -1 which indicates the last dimension.
4780 Returns:
4781 A tensor.
4782 """
4783 return nn.softmax(x, axis=axis)
4786@keras_export('keras.backend.softplus')
4787@dispatch.add_dispatch_support
4788@doc_controls.do_not_generate_docs
4789def softplus(x):
4790 """Softplus of a tensor.
4792 Args:
4793 x: A tensor or variable.
4795 Returns:
4796 A tensor.
4797 """
4798 return math_ops.softplus(x)
4801@keras_export('keras.backend.softsign')
4802@dispatch.add_dispatch_support
4803@doc_controls.do_not_generate_docs
4804def softsign(x):
4805 """Softsign of a tensor.
4807 Args:
4808 x: A tensor or variable.
4810 Returns:
4811 A tensor.
4812 """
4813 return nn.softsign(x)
4816@keras_export('keras.backend.categorical_crossentropy')
4817@dispatch.add_dispatch_support
4818@doc_controls.do_not_generate_docs
4819def categorical_crossentropy(target, output, from_logits=False, axis=-1):
4820 """Categorical crossentropy between an output tensor and a target tensor.
4822 Args:
4823 target: A tensor of the same shape as `output`.
4824 output: A tensor resulting from a softmax
4825 (unless `from_logits` is True, in which
4826 case `output` is expected to be the logits).
4827 from_logits: Boolean, whether `output` is the
4828 result of a softmax, or is a tensor of logits.
4829 axis: Int specifying the channels axis. `axis=-1` corresponds to data
4830 format `channels_last`, and `axis=1` corresponds to data format
4831 `channels_first`.
4833 Returns:
4834 Output tensor.
4836 Raises:
4837 ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4839 Example:
4841 >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3])
4842 >>> print(a)
4843 tf.Tensor(
4844 [[1. 0. 0.]
4845 [0. 1. 0.]
4846 [0. 0. 1.]], shape=(3, 3), dtype=float32)
4847 >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94], shape=[3,3])
4848 >>> print(b)
4849 tf.Tensor(
4850 [[0.9 0.05 0.05]
4851 [0.05 0.89 0.06]
4852 [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
4853 >>> loss = tf.keras.backend.categorical_crossentropy(a, b)
4854 >>> print(np.around(loss, 5))
4855 [0.10536 0.11653 0.06188]
4856 >>> loss = tf.keras.backend.categorical_crossentropy(a, a)
4857 >>> print(np.around(loss, 5))
4858 [0. 0. 0.]
4860 """
4861 target = tensor_conversion.convert_to_tensor_v2_with_dispatch(target)
4862 output = tensor_conversion.convert_to_tensor_v2_with_dispatch(output)
4863 target.shape.assert_is_compatible_with(output.shape)
4865 # Use logits whenever they are available. `softmax` and `sigmoid`
4866 # activations cache logits on the `output` Tensor.
4867 if hasattr(output, '_keras_logits'):
4868 output = output._keras_logits # pylint: disable=protected-access
4869 if from_logits:
4870 warnings.warn(
4871 '"`categorical_crossentropy` received `from_logits=True`, but '
4872 'the `output` argument was produced by a sigmoid or softmax '
4873 'activation and thus does not represent logits. Was this intended?"')
4874 from_logits = True
4876 if from_logits:
4877 return nn.softmax_cross_entropy_with_logits_v2(
4878 labels=target, logits=output, axis=axis)
4880 if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
4881 output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
4882 # When softmax activation function is used for output operation, we
4883 # use logits from the softmax function directly to compute loss in order
4884 # to prevent collapsing zero when training.
4885 # See b/117284466
4886 assert len(output.op.inputs) == 1
4887 output = output.op.inputs[0]
4888 return nn.softmax_cross_entropy_with_logits_v2(
4889 labels=target, logits=output, axis=axis)
4891 # scale preds so that the class probas of each sample sum to 1
4892 output = output / math_ops.reduce_sum(output, axis, True)
4893 # Compute cross entropy from probabilities.
4894 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4895 output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
4896 return -math_ops.reduce_sum(target * math_ops.log(output), axis)
4899@keras_export('keras.backend.sparse_categorical_crossentropy')
4900@dispatch.add_dispatch_support
4901@doc_controls.do_not_generate_docs
4902def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
4903 """Categorical crossentropy with integer targets.
4905 Args:
4906 target: An integer tensor.
4907 output: A tensor resulting from a softmax
4908 (unless `from_logits` is True, in which
4909 case `output` is expected to be the logits).
4910 from_logits: Boolean, whether `output` is the
4911 result of a softmax, or is a tensor of logits.
4912 axis: Int specifying the channels axis. `axis=-1` corresponds to data
4913 format `channels_last`, and `axis=1` corresponds to data format
4914 `channels_first`.
4916 Returns:
4917 Output tensor.
4919 Raises:
4920 ValueError: if `axis` is neither -1 nor one of the axes of `output`.
4921 """
4922 target = tensor_conversion.convert_to_tensor_v2_with_dispatch(target)
4923 output = tensor_conversion.convert_to_tensor_v2_with_dispatch(output)
4925 # Use logits whenever they are available. `softmax` and `sigmoid`
4926 # activations cache logits on the `output` Tensor.
4927 if hasattr(output, '_keras_logits'):
4928 output = output._keras_logits # pylint: disable=protected-access
4929 if from_logits:
4930 warnings.warn(
4931 '"`sparse_categorical_crossentropy` received `from_logits=True`, but '
4932 'the `output` argument was produced by a sigmoid or softmax '
4933 'activation and thus does not represent logits. Was this intended?"')
4934 from_logits = True
4935 elif (not from_logits and
4936 not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
4937 output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
4938 # When softmax activation function is used for output operation, we
4939 # use logits from the softmax function directly to compute loss in order
4940 # to prevent collapsing zero when training.
4941 # See b/117284466
4942 assert len(output.op.inputs) == 1
4943 output = output.op.inputs[0]
4944 from_logits = True
4945 elif not from_logits:
4946 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
4947 output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
4948 output = math_ops.log(output)
4950 if isinstance(output.shape, (tuple, list)):
4951 output_rank = len(output.shape)
4952 else:
4953 output_rank = output.shape.ndims
4954 if output_rank is not None:
4955 axis %= output_rank
4956 if axis != output_rank - 1:
4957 permutation = list(
4958 itertools.chain(range(axis), range(axis + 1, output_rank), [axis]))
4959 output = array_ops.transpose(output, perm=permutation)
4960 elif axis != -1:
4961 raise ValueError(
4962 'Cannot compute sparse categorical crossentropy with `axis={}` on an '
4963 'output tensor with unknown rank'.format(axis))
4965 target = cast(target, 'int64')
4967 # Try to adjust the shape so that rank of labels = rank of logits - 1.
4968 output_shape = array_ops.shape_v2(output)
4969 target_rank = target.shape.ndims
4971 update_shape = (
4972 target_rank is not None and output_rank is not None and
4973 target_rank != output_rank - 1)
4974 if update_shape:
4975 target = flatten(target)
4976 output = array_ops.reshape(output, [-1, output_shape[-1]])
4978 if py_any(_is_symbolic_tensor(v) for v in [target, output]):
4979 with get_graph().as_default():
4980 res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4981 labels=target, logits=output)
4982 else:
4983 res = nn.sparse_softmax_cross_entropy_with_logits_v2(
4984 labels=target, logits=output)
4986 if update_shape and output_rank >= 3:
4987 # If our output includes timesteps or spatial dimensions we need to reshape
4988 return array_ops.reshape(res, output_shape[:-1])
4989 else:
4990 return res
4993@keras_export('keras.backend.binary_crossentropy')
4994@dispatch.add_dispatch_support
4995@doc_controls.do_not_generate_docs
4996def binary_crossentropy(target, output, from_logits=False):
4997 """Binary crossentropy between an output tensor and a target tensor.
4999 Args:
5000 target: A tensor with the same shape as `output`.
5001 output: A tensor.
5002 from_logits: Whether `output` is expected to be a logits tensor.
5003 By default, we consider that `output`
5004 encodes a probability distribution.
5006 Returns:
5007 A tensor.
5008 """
5009 target = tensor_conversion.convert_to_tensor_v2_with_dispatch(target)
5010 output = tensor_conversion.convert_to_tensor_v2_with_dispatch(output)
5012 # Use logits whenever they are available. `softmax` and `sigmoid`
5013 # activations cache logits on the `output` Tensor.
5014 if hasattr(output, '_keras_logits'):
5015 output = output._keras_logits # pylint: disable=protected-access
5016 if from_logits:
5017 warnings.warn(
5018 '"`binary_crossentropy` received `from_logits=True`, but the `output`'
5019 ' argument was produced by a sigmoid or softmax activation and thus '
5020 'does not represent logits. Was this intended?"')
5021 from_logits = True
5023 if from_logits:
5024 return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
5026 if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
5027 output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
5028 # When sigmoid activation function is used for output operation, we
5029 # use logits from the sigmoid function directly to compute loss in order
5030 # to prevent collapsing zero when training.
5031 assert len(output.op.inputs) == 1
5032 output = output.op.inputs[0]
5033 return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
5035 epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
5036 output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
5038 # Compute cross entropy from probabilities.
5039 bce = target * math_ops.log(output + epsilon())
5040 bce += (1 - target) * math_ops.log(1 - output + epsilon())
5041 return -bce
5044@keras_export('keras.backend.sigmoid')
5045@dispatch.add_dispatch_support
5046@doc_controls.do_not_generate_docs
5047def sigmoid(x):
5048 """Element-wise sigmoid.
5050 Args:
5051 x: A tensor or variable.
5053 Returns:
5054 A tensor.
5055 """
5056 return nn.sigmoid(x)
5059@keras_export('keras.backend.hard_sigmoid')
5060@dispatch.add_dispatch_support
5061@doc_controls.do_not_generate_docs
5062def hard_sigmoid(x):
5063 """Segment-wise linear approximation of sigmoid.
5065 Faster than sigmoid.
5066 Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
5067 In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
5069 Args:
5070 x: A tensor or variable.
5072 Returns:
5073 A tensor.
5074 """
5075 point_two = _constant_to_tensor(0.2, x.dtype.base_dtype)
5076 point_five = _constant_to_tensor(0.5, x.dtype.base_dtype)
5077 x = math_ops.multiply(x, point_two)
5078 x = math_ops.add(x, point_five)
5079 x = clip_ops.clip_by_value(x, 0., 1.)
5080 return x
5083@keras_export('keras.backend.tanh')
5084@dispatch.add_dispatch_support
5085@doc_controls.do_not_generate_docs
5086def tanh(x):
5087 """Element-wise tanh.
5089 Args:
5090 x: A tensor or variable.
5092 Returns:
5093 A tensor.
5094 """
5095 return nn.tanh(x)
5098@keras_export('keras.backend.dropout')
5099@dispatch.add_dispatch_support
5100@doc_controls.do_not_generate_docs
5101def dropout(x, level, noise_shape=None, seed=None):
5102 """Sets entries in `x` to zero at random, while scaling the entire tensor.
5104 Args:
5105 x: tensor
5106 level: fraction of the entries in the tensor
5107 that will be set to 0.
5108 noise_shape: shape for randomly generated keep/drop flags,
5109 must be broadcastable to the shape of `x`
5110 seed: random seed to ensure determinism.
5112 Returns:
5113 A tensor.
5114 """
5115 if seed is None:
5116 seed = np.random.randint(10e6)
5117 return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=seed)
5120@keras_export('keras.backend.l2_normalize')
5121@dispatch.add_dispatch_support
5122@doc_controls.do_not_generate_docs
5123def l2_normalize(x, axis=None):
5124 """Normalizes a tensor wrt the L2 norm alongside the specified axis.
5126 Args:
5127 x: Tensor or variable.
5128 axis: axis along which to perform normalization.
5130 Returns:
5131 A tensor.
5132 """
5133 return nn.l2_normalize(x, axis=axis)
5136@keras_export('keras.backend.in_top_k')
5137@dispatch.add_dispatch_support
5138@doc_controls.do_not_generate_docs
5139def in_top_k(predictions, targets, k):
5140 """Returns whether the `targets` are in the top `k` `predictions`.
5142 Args:
5143 predictions: A tensor of shape `(batch_size, classes)` and type `float32`.
5144 targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
5145 k: An `int`, number of top elements to consider.
5147 Returns:
5148 A 1D tensor of length `batch_size` and type `bool`.
5149 `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
5150 values of `predictions[i]`.
5151 """
5152 return nn.in_top_k(predictions, targets, k)
5155# CONVOLUTIONS
5158def _preprocess_conv1d_input(x, data_format):
5159 """Transpose and cast the input before the conv1d.
5161 Args:
5162 x: input tensor.
5163 data_format: string, `"channels_last"` or `"channels_first"`.
5165 Returns:
5166 A tensor.
5167 """
5168 tf_data_format = 'NWC' # to pass TF Conv2dNative operations
5169 if data_format == 'channels_first':
5170 if not _has_nchw_support():
5171 x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC
5172 else:
5173 tf_data_format = 'NCW'
5174 return x, tf_data_format
5177def _preprocess_conv2d_input(x, data_format, force_transpose=False):
5178 """Transpose and cast the input before the conv2d.
5180 Args:
5181 x: input tensor.
5182 data_format: string, `"channels_last"` or `"channels_first"`.
5183 force_transpose: Boolean. If True, the input will always be transposed
5184 from NCHW to NHWC if `data_format` is `"channels_first"`.
5185 If False, the transposition only occurs on CPU (GPU ops are
5186 assumed to support NCHW).
5188 Returns:
5189 A tensor.
5190 """
5191 tf_data_format = 'NHWC'
5192 if data_format == 'channels_first':
5193 if not _has_nchw_support() or force_transpose:
5194 x = array_ops.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
5195 else:
5196 tf_data_format = 'NCHW'
5197 return x, tf_data_format
5200def _preprocess_conv3d_input(x, data_format):
5201 """Transpose and cast the input before the conv3d.
5203 Args:
5204 x: input tensor.
5205 data_format: string, `"channels_last"` or `"channels_first"`.
5207 Returns:
5208 A tensor.
5209 """
5210 tf_data_format = 'NDHWC'
5211 if data_format == 'channels_first':
5212 if not _has_nchw_support():
5213 x = array_ops.transpose(x, (0, 2, 3, 4, 1))
5214 else:
5215 tf_data_format = 'NCDHW'
5216 return x, tf_data_format
5219def _preprocess_padding(padding):
5220 """Convert keras' padding to TensorFlow's padding.
5222 Args:
5223 padding: string, one of 'same' , 'valid'
5225 Returns:
5226 a string, one of 'SAME', 'VALID'.
5228 Raises:
5229 ValueError: if invalid `padding'`
5230 """
5231 if padding == 'same':
5232 padding = 'SAME'
5233 elif padding == 'valid':
5234 padding = 'VALID'
5235 else:
5236 raise ValueError('Invalid padding: ' + str(padding))
5237 return padding
5240@keras_export('keras.backend.conv1d')
5241@dispatch.add_dispatch_support
5242@doc_controls.do_not_generate_docs
5243def conv1d(x,
5244 kernel,
5245 strides=1,
5246 padding='valid',
5247 data_format=None,
5248 dilation_rate=1):
5249 """1D convolution.
5251 Args:
5252 x: Tensor or variable.
5253 kernel: kernel tensor.
5254 strides: stride integer.
5255 padding: string, `"same"`, `"causal"` or `"valid"`.
5256 data_format: string, one of "channels_last", "channels_first".
5257 dilation_rate: integer dilate rate.
5259 Returns:
5260 A tensor, result of 1D convolution.
5262 Raises:
5263 ValueError: if `data_format` is neither `channels_last` or
5264 `channels_first`.
5265 """
5266 if data_format is None:
5267 data_format = image_data_format()
5268 if data_format not in {'channels_first', 'channels_last'}:
5269 raise ValueError('Unknown data_format: ' + str(data_format))
5271 kernel_shape = kernel.shape.as_list()
5272 if padding == 'causal':
5273 # causal (dilated) convolution:
5274 left_pad = dilation_rate * (kernel_shape[0] - 1)
5275 x = temporal_padding(x, (left_pad, 0))
5276 padding = 'valid'
5277 padding = _preprocess_padding(padding)
5279 x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5280 x = nn.convolution(
5281 input=x,
5282 filter=kernel,
5283 dilation_rate=dilation_rate,
5284 strides=strides,
5285 padding=padding,
5286 data_format=tf_data_format)
5287 if data_format == 'channels_first' and tf_data_format == 'NWC':
5288 x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
5289 return x
5292@keras_export('keras.backend.conv2d')
5293@dispatch.add_dispatch_support
5294@doc_controls.do_not_generate_docs
5295def conv2d(x,
5296 kernel,
5297 strides=(1, 1),
5298 padding='valid',
5299 data_format=None,
5300 dilation_rate=(1, 1)):
5301 """2D convolution.
5303 Args:
5304 x: Tensor or variable.
5305 kernel: kernel tensor.
5306 strides: strides tuple.
5307 padding: string, `"same"` or `"valid"`.
5308 data_format: `"channels_last"` or `"channels_first"`.
5309 dilation_rate: tuple of 2 integers.
5311 Returns:
5312 A tensor, result of 2D convolution.
5314 Raises:
5315 ValueError: if `data_format` is neither `channels_last` or
5316 `channels_first`.
5317 """
5318 if data_format is None:
5319 data_format = image_data_format()
5320 if data_format not in {'channels_first', 'channels_last'}:
5321 raise ValueError('Unknown data_format: ' + str(data_format))
5323 x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5324 padding = _preprocess_padding(padding)
5325 x = nn.convolution(
5326 input=x,
5327 filter=kernel,
5328 dilation_rate=dilation_rate,
5329 strides=strides,
5330 padding=padding,
5331 data_format=tf_data_format)
5332 if data_format == 'channels_first' and tf_data_format == 'NHWC':
5333 x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
5334 return x
5337@keras_export('keras.backend.conv2d_transpose')
5338@dispatch.add_dispatch_support
5339@doc_controls.do_not_generate_docs
5340def conv2d_transpose(x,
5341 kernel,
5342 output_shape,
5343 strides=(1, 1),
5344 padding='valid',
5345 data_format=None,
5346 dilation_rate=(1, 1)):
5347 """2D deconvolution (i.e.
5349 transposed convolution).
5351 Args:
5352 x: Tensor or variable.
5353 kernel: kernel tensor.
5354 output_shape: 1D int tensor for the output shape.
5355 strides: strides tuple.
5356 padding: string, `"same"` or `"valid"`.
5357 data_format: string, `"channels_last"` or `"channels_first"`.
5358 dilation_rate: Tuple of 2 integers.
5360 Returns:
5361 A tensor, result of transposed 2D convolution.
5363 Raises:
5364 ValueError: if `data_format` is neither `channels_last` or
5365 `channels_first`.
5366 """
5367 if data_format is None:
5368 data_format = image_data_format()
5369 if data_format not in {'channels_first', 'channels_last'}:
5370 raise ValueError('Unknown data_format: ' + str(data_format))
5372 # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
5373 if data_format == 'channels_first' and dilation_rate != (1, 1):
5374 force_transpose = True
5375 else:
5376 force_transpose = False
5378 x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
5380 if data_format == 'channels_first' and tf_data_format == 'NHWC':
5381 output_shape = (output_shape[0], output_shape[2], output_shape[3],
5382 output_shape[1])
5383 if output_shape[0] is None:
5384 output_shape = (shape(x)[0],) + tuple(output_shape[1:])
5386 if isinstance(output_shape, (tuple, list)):
5387 output_shape = array_ops_stack.stack(list(output_shape))
5389 padding = _preprocess_padding(padding)
5390 if tf_data_format == 'NHWC':
5391 strides = (1,) + strides + (1,)
5392 else:
5393 strides = (1, 1) + strides
5395 if dilation_rate == (1, 1):
5396 x = nn.conv2d_transpose(x, kernel, output_shape, strides,
5397 padding=padding,
5398 data_format=tf_data_format)
5399 else:
5400 assert dilation_rate[0] == dilation_rate[1]
5401 x = nn.atrous_conv2d_transpose(
5402 x,
5403 kernel,
5404 output_shape,
5405 rate=dilation_rate[0],
5406 padding=padding)
5407 if data_format == 'channels_first' and tf_data_format == 'NHWC':
5408 x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
5409 return x
5412def separable_conv1d(x,
5413 depthwise_kernel,
5414 pointwise_kernel,
5415 strides=1,
5416 padding='valid',
5417 data_format=None,
5418 dilation_rate=1):
5419 """1D convolution with separable filters.
5421 Args:
5422 x: input tensor
5423 depthwise_kernel: convolution kernel for the depthwise convolution.
5424 pointwise_kernel: kernel for the 1x1 convolution.
5425 strides: stride integer.
5426 padding: string, `"same"` or `"valid"`.
5427 data_format: string, `"channels_last"` or `"channels_first"`.
5428 dilation_rate: integer dilation rate.
5430 Returns:
5431 Output tensor.
5433 Raises:
5434 ValueError: if `data_format` is neither `channels_last` or
5435 `channels_first`.
5436 """
5437 if data_format is None:
5438 data_format = image_data_format()
5439 if data_format not in {'channels_first', 'channels_last'}:
5440 raise ValueError('Unknown data_format: ' + str(data_format))
5442 if isinstance(strides, int):
5443 strides = (strides,)
5444 if isinstance(dilation_rate, int):
5445 dilation_rate = (dilation_rate,)
5447 x, tf_data_format = _preprocess_conv1d_input(x, data_format)
5448 padding = _preprocess_padding(padding)
5449 if not isinstance(strides, tuple):
5450 strides = tuple(strides)
5451 if tf_data_format == 'NWC':
5452 spatial_start_dim = 1
5453 strides = (1,) + strides * 2 + (1,)
5454 else:
5455 spatial_start_dim = 2
5456 strides = (1, 1) + strides * 2
5457 x = array_ops.expand_dims(x, spatial_start_dim)
5458 depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
5459 pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
5460 dilation_rate = (1,) + dilation_rate
5462 x = nn.separable_conv2d(
5463 x,
5464 depthwise_kernel,
5465 pointwise_kernel,
5466 strides=strides,
5467 padding=padding,
5468 rate=dilation_rate,
5469 data_format=tf_data_format)
5471 x = array_ops.squeeze(x, [spatial_start_dim])
5473 if data_format == 'channels_first' and tf_data_format == 'NWC':
5474 x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
5476 return x
5479@keras_export('keras.backend.separable_conv2d')
5480@dispatch.add_dispatch_support
5481@doc_controls.do_not_generate_docs
5482def separable_conv2d(x,
5483 depthwise_kernel,
5484 pointwise_kernel,
5485 strides=(1, 1),
5486 padding='valid',
5487 data_format=None,
5488 dilation_rate=(1, 1)):
5489 """2D convolution with separable filters.
5491 Args:
5492 x: input tensor
5493 depthwise_kernel: convolution kernel for the depthwise convolution.
5494 pointwise_kernel: kernel for the 1x1 convolution.
5495 strides: strides tuple (length 2).
5496 padding: string, `"same"` or `"valid"`.
5497 data_format: string, `"channels_last"` or `"channels_first"`.
5498 dilation_rate: tuple of integers,
5499 dilation rates for the separable convolution.
5501 Returns:
5502 Output tensor.
5504 Raises:
5505 ValueError: if `data_format` is neither `channels_last` or
5506 `channels_first`.
5507 ValueError: if `strides` is not a tuple of 2 integers.
5508 """
5509 if data_format is None:
5510 data_format = image_data_format()
5511 if data_format not in {'channels_first', 'channels_last'}:
5512 raise ValueError('Unknown data_format: ' + str(data_format))
5513 if len(strides) != 2:
5514 raise ValueError('`strides` must be a tuple of 2 integers.')
5516 x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5517 padding = _preprocess_padding(padding)
5518 if not isinstance(strides, tuple):
5519 strides = tuple(strides)
5520 if tf_data_format == 'NHWC':
5521 strides = (1,) + strides + (1,)
5522 else:
5523 strides = (1, 1) + strides
5525 x = nn.separable_conv2d(
5526 x,
5527 depthwise_kernel,
5528 pointwise_kernel,
5529 strides=strides,
5530 padding=padding,
5531 rate=dilation_rate,
5532 data_format=tf_data_format)
5533 if data_format == 'channels_first' and tf_data_format == 'NHWC':
5534 x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
5535 return x
5538@keras_export('keras.backend.depthwise_conv2d')
5539@dispatch.add_dispatch_support
5540@doc_controls.do_not_generate_docs
5541def depthwise_conv2d(x,
5542 depthwise_kernel,
5543 strides=(1, 1),
5544 padding='valid',
5545 data_format=None,
5546 dilation_rate=(1, 1)):
5547 """2D convolution with separable filters.
5549 Args:
5550 x: input tensor
5551 depthwise_kernel: convolution kernel for the depthwise convolution.
5552 strides: strides tuple (length 2).
5553 padding: string, `"same"` or `"valid"`.
5554 data_format: string, `"channels_last"` or `"channels_first"`.
5555 dilation_rate: tuple of integers,
5556 dilation rates for the separable convolution.
5558 Returns:
5559 Output tensor.
5561 Raises:
5562 ValueError: if `data_format` is neither `channels_last` or
5563 `channels_first`.
5564 """
5565 if data_format is None:
5566 data_format = image_data_format()
5567 if data_format not in {'channels_first', 'channels_last'}:
5568 raise ValueError('Unknown data_format: ' + str(data_format))
5570 x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5571 padding = _preprocess_padding(padding)
5572 if tf_data_format == 'NHWC':
5573 strides = (1,) + strides + (1,)
5574 else:
5575 strides = (1, 1) + strides
5577 x = nn.depthwise_conv2d(
5578 x,
5579 depthwise_kernel,
5580 strides=strides,
5581 padding=padding,
5582 rate=dilation_rate,
5583 data_format=tf_data_format)
5584 if data_format == 'channels_first' and tf_data_format == 'NHWC':
5585 x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
5586 return x
5589@keras_export('keras.backend.conv3d')
5590@dispatch.add_dispatch_support
5591@doc_controls.do_not_generate_docs
5592def conv3d(x,
5593 kernel,
5594 strides=(1, 1, 1),
5595 padding='valid',
5596 data_format=None,
5597 dilation_rate=(1, 1, 1)):
5598 """3D convolution.
5600 Args:
5601 x: Tensor or variable.
5602 kernel: kernel tensor.
5603 strides: strides tuple.
5604 padding: string, `"same"` or `"valid"`.
5605 data_format: string, `"channels_last"` or `"channels_first"`.
5606 dilation_rate: tuple of 3 integers.
5608 Returns:
5609 A tensor, result of 3D convolution.
5611 Raises:
5612 ValueError: if `data_format` is neither `channels_last` or
5613 `channels_first`.
5614 """
5615 if data_format is None:
5616 data_format = image_data_format()
5617 if data_format not in {'channels_first', 'channels_last'}:
5618 raise ValueError('Unknown data_format: ' + str(data_format))
5620 x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5621 padding = _preprocess_padding(padding)
5622 x = nn.convolution(
5623 input=x,
5624 filter=kernel,
5625 dilation_rate=dilation_rate,
5626 strides=strides,
5627 padding=padding,
5628 data_format=tf_data_format)
5629 if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5630 x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5631 return x
5634def conv3d_transpose(x,
5635 kernel,
5636 output_shape,
5637 strides=(1, 1, 1),
5638 padding='valid',
5639 data_format=None):
5640 """3D deconvolution (i.e.
5642 transposed convolution).
5644 Args:
5645 x: input tensor.
5646 kernel: kernel tensor.
5647 output_shape: 1D int tensor for the output shape.
5648 strides: strides tuple.
5649 padding: string, "same" or "valid".
5650 data_format: string, `"channels_last"` or `"channels_first"`.
5652 Returns:
5653 A tensor, result of transposed 3D convolution.
5655 Raises:
5656 ValueError: if `data_format` is neither `channels_last` or
5657 `channels_first`.
5658 """
5659 if data_format is None:
5660 data_format = image_data_format()
5661 if data_format not in {'channels_first', 'channels_last'}:
5662 raise ValueError('Unknown data_format: ' + str(data_format))
5663 if isinstance(output_shape, (tuple, list)):
5664 output_shape = array_ops_stack.stack(output_shape)
5666 x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5668 if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5669 output_shape = (output_shape[0], output_shape[2], output_shape[3],
5670 output_shape[4], output_shape[1])
5671 if output_shape[0] is None:
5672 output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
5673 output_shape = array_ops_stack.stack(list(output_shape))
5675 padding = _preprocess_padding(padding)
5676 if tf_data_format == 'NDHWC':
5677 strides = (1,) + strides + (1,)
5678 else:
5679 strides = (1, 1) + strides
5681 x = nn.conv3d_transpose(
5682 x,
5683 kernel,
5684 output_shape,
5685 strides,
5686 padding=padding,
5687 data_format=tf_data_format)
5688 if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5689 x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5690 return x
5693@keras_export('keras.backend.pool2d')
5694@dispatch.add_dispatch_support
5695@doc_controls.do_not_generate_docs
5696def pool2d(x,
5697 pool_size,
5698 strides=(1, 1),
5699 padding='valid',
5700 data_format=None,
5701 pool_mode='max'):
5702 """2D Pooling.
5704 Args:
5705 x: Tensor or variable.
5706 pool_size: tuple of 2 integers.
5707 strides: tuple of 2 integers.
5708 padding: string, `"same"` or `"valid"`.
5709 data_format: string, `"channels_last"` or `"channels_first"`.
5710 pool_mode: string, `"max"` or `"avg"`.
5712 Returns:
5713 A tensor, result of 2D pooling.
5715 Raises:
5716 ValueError: if `data_format` is neither `"channels_last"` or
5717 `"channels_first"`.
5718 ValueError: if `pool_size` is not a tuple of 2 integers.
5719 ValueError: if `strides` is not a tuple of 2 integers.
5720 ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
5721 """
5722 if data_format is None:
5723 data_format = image_data_format()
5724 if data_format not in {'channels_first', 'channels_last'}:
5725 raise ValueError('Unknown data_format: ' + str(data_format))
5726 if len(pool_size) != 2:
5727 raise ValueError('`pool_size` must be a tuple of 2 integers.')
5728 if len(strides) != 2:
5729 raise ValueError('`strides` must be a tuple of 2 integers.')
5731 x, tf_data_format = _preprocess_conv2d_input(x, data_format)
5732 padding = _preprocess_padding(padding)
5733 if tf_data_format == 'NHWC':
5734 strides = (1,) + strides + (1,)
5735 pool_size = (1,) + pool_size + (1,)
5736 else:
5737 strides = (1, 1) + strides
5738 pool_size = (1, 1) + pool_size
5740 if pool_mode == 'max':
5741 x = nn.max_pool(
5742 x, pool_size, strides, padding=padding, data_format=tf_data_format)
5743 elif pool_mode == 'avg':
5744 x = nn.avg_pool(
5745 x, pool_size, strides, padding=padding, data_format=tf_data_format)
5746 else:
5747 raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5749 if data_format == 'channels_first' and tf_data_format == 'NHWC':
5750 x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
5751 return x
5754@keras_export('keras.backend.pool3d')
5755@dispatch.add_dispatch_support
5756@doc_controls.do_not_generate_docs
5757def pool3d(x,
5758 pool_size,
5759 strides=(1, 1, 1),
5760 padding='valid',
5761 data_format=None,
5762 pool_mode='max'):
5763 """3D Pooling.
5765 Args:
5766 x: Tensor or variable.
5767 pool_size: tuple of 3 integers.
5768 strides: tuple of 3 integers.
5769 padding: string, `"same"` or `"valid"`.
5770 data_format: string, `"channels_last"` or `"channels_first"`.
5771 pool_mode: string, `"max"` or `"avg"`.
5773 Returns:
5774 A tensor, result of 3D pooling.
5776 Raises:
5777 ValueError: if `data_format` is neither `"channels_last"` or
5778 `"channels_first"`.
5779 ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
5780 """
5781 if data_format is None:
5782 data_format = image_data_format()
5783 if data_format not in {'channels_first', 'channels_last'}:
5784 raise ValueError('Unknown data_format: ' + str(data_format))
5786 x, tf_data_format = _preprocess_conv3d_input(x, data_format)
5787 padding = _preprocess_padding(padding)
5788 if tf_data_format == 'NDHWC':
5789 strides = (1,) + strides + (1,)
5790 pool_size = (1,) + pool_size + (1,)
5791 else:
5792 strides = (1, 1) + strides
5793 pool_size = (1, 1) + pool_size
5795 if pool_mode == 'max':
5796 x = nn.max_pool3d(
5797 x, pool_size, strides, padding=padding, data_format=tf_data_format)
5798 elif pool_mode == 'avg':
5799 x = nn.avg_pool3d(
5800 x, pool_size, strides, padding=padding, data_format=tf_data_format)
5801 else:
5802 raise ValueError('Invalid pooling mode: ' + str(pool_mode))
5804 if data_format == 'channels_first' and tf_data_format == 'NDHWC':
5805 x = array_ops.transpose(x, (0, 4, 1, 2, 3))
5806 return x
5809def local_conv(inputs,
5810 kernel,
5811 kernel_size,
5812 strides,
5813 output_shape,
5814 data_format=None):
5815 """Apply N-D convolution with un-shared weights.
5817 Args:
5818 inputs: (N+2)-D tensor with shape
5819 (batch_size, channels_in, d_in1, ..., d_inN)
5820 if data_format='channels_first', or
5821 (batch_size, d_in1, ..., d_inN, channels_in)
5822 if data_format='channels_last'.
5823 kernel: the unshared weight for N-D convolution,
5824 with shape (output_items, feature_dim, channels_out), where
5825 feature_dim = np.prod(kernel_size) * channels_in,
5826 output_items = np.prod(output_shape).
5827 kernel_size: a tuple of N integers, specifying the
5828 spatial dimensions of the N-D convolution window.
5829 strides: a tuple of N integers, specifying the strides
5830 of the convolution along the spatial dimensions.
5831 output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
5832 dimensionality of the output.
5833 data_format: string, "channels_first" or "channels_last".
5835 Returns:
5836 An (N+2)-D tensor with shape:
5837 (batch_size, channels_out) + output_shape
5838 if data_format='channels_first', or:
5839 (batch_size,) + output_shape + (channels_out,)
5840 if data_format='channels_last'.
5842 Raises:
5843 ValueError: if `data_format` is neither
5844 `channels_last` nor `channels_first`.
5845 """
5846 if data_format is None:
5847 data_format = image_data_format()
5848 if data_format not in {'channels_first', 'channels_last'}:
5849 raise ValueError('Unknown data_format: ' + str(data_format))
5851 kernel_shape = int_shape(kernel)
5852 feature_dim = kernel_shape[1]
5853 channels_out = kernel_shape[-1]
5854 ndims = len(output_shape)
5855 spatial_dimensions = list(range(ndims))
5857 xs = []
5858 output_axes_ticks = [range(axis_max) for axis_max in output_shape]
5859 for position in itertools.product(*output_axes_ticks):
5860 slices = [slice(None)]
5862 if data_format == 'channels_first':
5863 slices.append(slice(None))
5865 slices.extend(
5866 slice(position[d] * strides[d], position[d] * strides[d] +
5867 kernel_size[d]) for d in spatial_dimensions)
5869 if data_format == 'channels_last':
5870 slices.append(slice(None))
5872 xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
5874 x_aggregate = concatenate(xs, axis=0)
5875 output = batch_dot(x_aggregate, kernel)
5876 output = reshape(output, output_shape + (-1, channels_out))
5878 if data_format == 'channels_first':
5879 permutation = [ndims, ndims + 1] + spatial_dimensions
5880 else:
5881 permutation = [ndims] + spatial_dimensions + [ndims + 1]
5883 return permute_dimensions(output, permutation)
5886@keras_export('keras.backend.local_conv1d')
5887@dispatch.add_dispatch_support
5888@doc_controls.do_not_generate_docs
5889def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
5890 """Apply 1D conv with un-shared weights.
5892 Args:
5893 inputs: 3D tensor with shape:
5894 (batch_size, steps, input_dim)
5895 if data_format is "channels_last" or
5896 (batch_size, input_dim, steps)
5897 if data_format is "channels_first".
5898 kernel: the unshared weight for convolution,
5899 with shape (output_length, feature_dim, filters).
5900 kernel_size: a tuple of a single integer,
5901 specifying the length of the 1D convolution window.
5902 strides: a tuple of a single integer,
5903 specifying the stride length of the convolution.
5904 data_format: the data format, channels_first or channels_last.
5906 Returns:
5907 A 3d tensor with shape:
5908 (batch_size, output_length, filters)
5909 if data_format='channels_first'
5910 or 3D tensor with shape:
5911 (batch_size, filters, output_length)
5912 if data_format='channels_last'.
5913 """
5914 output_shape = (kernel.shape[0],)
5915 return local_conv(inputs,
5916 kernel,
5917 kernel_size,
5918 strides,
5919 output_shape,
5920 data_format)
5923@keras_export('keras.backend.local_conv2d')
5924@dispatch.add_dispatch_support
5925@doc_controls.do_not_generate_docs
5926def local_conv2d(inputs,
5927 kernel,
5928 kernel_size,
5929 strides,
5930 output_shape,
5931 data_format=None):
5932 """Apply 2D conv with un-shared weights.
5934 Args:
5935 inputs: 4D tensor with shape:
5936 (batch_size, filters, new_rows, new_cols)
5937 if data_format='channels_first'
5938 or 4D tensor with shape:
5939 (batch_size, new_rows, new_cols, filters)
5940 if data_format='channels_last'.
5941 kernel: the unshared weight for convolution,
5942 with shape (output_items, feature_dim, filters).
5943 kernel_size: a tuple of 2 integers, specifying the
5944 width and height of the 2D convolution window.
5945 strides: a tuple of 2 integers, specifying the strides
5946 of the convolution along the width and height.
5947 output_shape: a tuple with (output_row, output_col).
5948 data_format: the data format, channels_first or channels_last.
5950 Returns:
5951 A 4D tensor with shape:
5952 (batch_size, filters, new_rows, new_cols)
5953 if data_format='channels_first'
5954 or 4D tensor with shape:
5955 (batch_size, new_rows, new_cols, filters)
5956 if data_format='channels_last'.
5957 """
5958 return local_conv(inputs,
5959 kernel,
5960 kernel_size,
5961 strides,
5962 output_shape,
5963 data_format)
5966@keras_export('keras.backend.bias_add')
5967@dispatch.add_dispatch_support
5968@doc_controls.do_not_generate_docs
5969def bias_add(x, bias, data_format=None):
5970 """Adds a bias vector to a tensor.
5972 Args:
5973 x: Tensor or variable.
5974 bias: Bias tensor to add.
5975 data_format: string, `"channels_last"` or `"channels_first"`.
5977 Returns:
5978 Output tensor.
5980 Raises:
5981 ValueError: In one of the two cases below:
5982 1. invalid `data_format` argument.
5983 2. invalid bias shape.
5984 the bias should be either a vector or
5985 a tensor with ndim(x) - 1 dimension
5986 """
5987 if data_format is None:
5988 data_format = image_data_format()
5989 if data_format not in {'channels_first', 'channels_last'}:
5990 raise ValueError('Unknown data_format: ' + str(data_format))
5991 bias_shape = int_shape(bias)
5992 if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:
5993 raise ValueError(
5994 'Unexpected bias dimensions %d, expect to be 1 or %d dimensions' %
5995 (len(bias_shape), ndim(x) - 1))
5997 if len(bias_shape) == 1:
5998 if data_format == 'channels_first':
5999 return nn.bias_add(x, bias, data_format='NCHW')
6000 return nn.bias_add(x, bias, data_format='NHWC')
6001 if ndim(x) in (3, 4, 5):
6002 if data_format == 'channels_first':
6003 bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]
6004 return x + reshape(bias, bias_reshape_axis)
6005 return x + reshape(bias, (1,) + bias_shape)
6006 return nn.bias_add(x, bias)
6009# RANDOMNESS
6012@keras_export('keras.backend.random_normal')
6013@dispatch.add_dispatch_support
6014@doc_controls.do_not_generate_docs
6015def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6016 """Returns a tensor with normal distribution of values.
6018 It is an alias to `tf.random.normal`.
6020 Args:
6021 shape: A tuple of integers, the shape of tensor to create.
6022 mean: A float, the mean value of the normal distribution to draw samples.
6023 Default to 0.0.
6024 stddev: A float, the standard deviation of the normal distribution
6025 to draw samples. Default to 1.0.
6026 dtype: `tf.dtypes.DType`, dtype of returned tensor. Default to use Keras
6027 backend dtype which is float32.
6028 seed: Integer, random seed. Will use a random numpy integer when not
6029 specified.
6031 Returns:
6032 A tensor with normal distribution of values.
6034 Example:
6036 >>> random_normal_tensor = tf.keras.backend.random_normal(shape=(2,3),
6037 ... mean=0.0, stddev=1.0)
6038 >>> random_normal_tensor
6039 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6040 dtype=float32)>
6041 """
6042 if dtype is None:
6043 dtype = floatx()
6044 if seed is None:
6045 seed = np.random.randint(10e6)
6046 return random_ops.random_normal(
6047 shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
6050@keras_export('keras.backend.random_uniform')
6051@dispatch.add_dispatch_support
6052@doc_controls.do_not_generate_docs
6053def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
6054 """Returns a tensor with uniform distribution of values.
6056 Args:
6057 shape: A tuple of integers, the shape of tensor to create.
6058 minval: A float, lower boundary of the uniform distribution
6059 to draw samples.
6060 maxval: A float, upper boundary of the uniform distribution
6061 to draw samples.
6062 dtype: String, dtype of returned tensor.
6063 seed: Integer, random seed.
6065 Returns:
6066 A tensor.
6068 Example:
6070 >>> random_uniform_tensor = tf.keras.backend.random_uniform(shape=(2,3),
6071 ... minval=0.0, maxval=1.0)
6072 >>> random_uniform_tensor
6073 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6074 dtype=float32)>
6075 """
6076 if dtype is None:
6077 dtype = floatx()
6078 if seed is None:
6079 seed = np.random.randint(10e6)
6080 return random_ops.random_uniform(
6081 shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
6084@keras_export('keras.backend.random_binomial')
6085@dispatch.add_dispatch_support
6086@doc_controls.do_not_generate_docs
6087def random_binomial(shape, p=0.0, dtype=None, seed=None):
6088 """Returns a tensor with random binomial distribution of values.
6090 DEPRECATED, use `tf.keras.backend.random_bernoulli` instead.
6092 The binomial distribution with parameters `n` and `p` is the probability
6093 distribution of the number of successful Bernoulli process. Only supports
6094 `n` = 1 for now.
6096 Args:
6097 shape: A tuple of integers, the shape of tensor to create.
6098 p: A float, `0. <= p <= 1`, probability of binomial distribution.
6099 dtype: String, dtype of returned tensor.
6100 seed: Integer, random seed.
6102 Returns:
6103 A tensor.
6105 Example:
6107 >>> random_binomial_tensor = tf.keras.backend.random_binomial(shape=(2,3),
6108 ... p=0.5)
6109 >>> random_binomial_tensor
6110 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...,
6111 dtype=float32)>
6112 """
6113 warnings.warn('`tf.keras.backend.random_binomial` is deprecated, '
6114 'and will be removed in a future version.'
6115 'Please use `tf.keras.backend.random_bernoulli` instead.')
6116 return random_bernoulli(shape, p, dtype, seed)
6119@keras_export('keras.backend.random_bernoulli')
6120@dispatch.add_dispatch_support
6121@doc_controls.do_not_generate_docs
6122def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
6123 """Returns a tensor with random bernoulli distribution of values.
6125 Args:
6126 shape: A tuple of integers, the shape of tensor to create.
6127 p: A float, `0. <= p <= 1`, probability of bernoulli distribution.
6128 dtype: String, dtype of returned tensor.
6129 seed: Integer, random seed.
6131 Returns:
6132 A tensor.
6133 """
6134 if dtype is None:
6135 dtype = floatx()
6136 if seed is None:
6137 seed = np.random.randint(10e6)
6138 return array_ops.where_v2(
6139 random_ops.random_uniform(shape, dtype=dtype, seed=seed) <= p,
6140 array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
6143@keras_export('keras.backend.truncated_normal')
6144@dispatch.add_dispatch_support
6145@doc_controls.do_not_generate_docs
6146def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
6147 """Returns a tensor with truncated random normal distribution of values.
6149 The generated values follow a normal distribution
6150 with specified mean and standard deviation,
6151 except that values whose magnitude is more than
6152 two standard deviations from the mean are dropped and re-picked.
6154 Args:
6155 shape: A tuple of integers, the shape of tensor to create.
6156 mean: Mean of the values.
6157 stddev: Standard deviation of the values.
6158 dtype: String, dtype of returned tensor.
6159 seed: Integer, random seed.
6161 Returns:
6162 A tensor.
6163 """
6164 if dtype is None:
6165 dtype = floatx()
6166 if seed is None:
6167 seed = np.random.randint(10e6)
6168 return random_ops.truncated_normal(
6169 shape, mean, stddev, dtype=dtype, seed=seed)
6172# CTC
6173# TensorFlow has a native implementation, but it uses sparse tensors
6174# and therefore requires a wrapper for Keras. The functions below convert
6175# dense to sparse tensors and also wraps up the beam search code that is
6176# in TensorFlow's CTC implementation
6179@keras_export('keras.backend.ctc_label_dense_to_sparse')
6180@dispatch.add_dispatch_support
6181@doc_controls.do_not_generate_docs
6182def ctc_label_dense_to_sparse(labels, label_lengths):
6183 """Converts CTC labels from dense to sparse.
6185 Args:
6186 labels: dense CTC labels.
6187 label_lengths: length of the labels.
6189 Returns:
6190 A sparse tensor representation of the labels.
6191 """
6192 label_shape = array_ops.shape(labels)
6193 num_batches_tns = array_ops_stack.stack([label_shape[0]])
6194 max_num_labels_tns = array_ops_stack.stack([label_shape[1]])
6196 def range_less_than(old_input, current_input):
6197 return array_ops.expand_dims(
6198 math_ops.range(array_ops.shape(old_input)[1]), 0) < array_ops.fill(
6199 max_num_labels_tns, current_input)
6201 init = math_ops.cast(
6202 array_ops.fill([1, label_shape[1]], 0), dtypes_module.bool)
6203 dense_mask = functional_ops.scan(
6204 range_less_than, label_lengths, initializer=init, parallel_iterations=1)
6205 dense_mask = dense_mask[:, 0, :]
6207 label_array = array_ops.reshape(
6208 array_ops.tile(math_ops.range(0, label_shape[1]), num_batches_tns),
6209 label_shape)
6210 label_ind = array_ops.boolean_mask(label_array, dense_mask)
6212 batch_array = array_ops.transpose(
6213 array_ops.reshape(
6214 array_ops.tile(math_ops.range(0, label_shape[0]), max_num_labels_tns),
6215 reverse(label_shape, 0)))
6216 batch_ind = array_ops.boolean_mask(batch_array, dense_mask)
6217 indices = array_ops.transpose(
6218 array_ops.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]))
6220 vals_sparse = array_ops.gather_nd(labels, indices)
6222 return sparse_tensor.SparseTensor(
6223 math_ops.cast(indices, dtypes_module.int64), vals_sparse,
6224 math_ops.cast(label_shape, dtypes_module.int64))
6227@keras_export('keras.backend.ctc_batch_cost')
6228@dispatch.add_dispatch_support
6229@doc_controls.do_not_generate_docs
6230def ctc_batch_cost(y_true, y_pred, input_length, label_length):
6231 """Runs CTC loss algorithm on each batch element.
6233 Args:
6234 y_true: tensor `(samples, max_string_length)`
6235 containing the truth labels.
6236 y_pred: tensor `(samples, time_steps, num_categories)`
6237 containing the prediction, or output of the softmax.
6238 input_length: tensor `(samples, 1)` containing the sequence length for
6239 each batch item in `y_pred`.
6240 label_length: tensor `(samples, 1)` containing the sequence length for
6241 each batch item in `y_true`.
6243 Returns:
6244 Tensor with shape (samples,1) containing the
6245 CTC loss of each element.
6246 """
6247 label_length = math_ops.cast(
6248 array_ops.squeeze(label_length, axis=-1), dtypes_module.int32)
6249 input_length = math_ops.cast(
6250 array_ops.squeeze(input_length, axis=-1), dtypes_module.int32)
6251 sparse_labels = math_ops.cast(
6252 ctc_label_dense_to_sparse(y_true, label_length), dtypes_module.int32)
6254 y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
6256 return array_ops.expand_dims(
6257 ctc.ctc_loss(
6258 inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)
6261@keras_export('keras.backend.ctc_decode')
6262@dispatch.add_dispatch_support
6263@doc_controls.do_not_generate_docs
6264def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
6265 """Decodes the output of a softmax.
6267 Can use either greedy search (also known as best path)
6268 or a constrained dictionary search.
6270 Args:
6271 y_pred: tensor `(samples, time_steps, num_categories)`
6272 containing the prediction, or output of the softmax.
6273 input_length: tensor `(samples, )` containing the sequence length for
6274 each batch item in `y_pred`.
6275 greedy: perform much faster best-path search if `true`.
6276 This does not use a dictionary.
6277 beam_width: if `greedy` is `false`: a beam search decoder will be used
6278 with a beam of this width.
6279 top_paths: if `greedy` is `false`,
6280 how many of the most probable paths will be returned.
6282 Returns:
6283 Tuple:
6284 List: if `greedy` is `true`, returns a list of one element that
6285 contains the decoded sequence.
6286 If `false`, returns the `top_paths` most probable
6287 decoded sequences.
6288 Each decoded sequence has shape (samples, time_steps).
6289 Important: blank labels are returned as `-1`.
6290 Tensor `(top_paths, )` that contains
6291 the log probability of each decoded sequence.
6292 """
6293 input_shape = shape(y_pred)
6294 num_samples, num_steps = input_shape[0], input_shape[1]
6295 y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
6296 input_length = math_ops.cast(input_length, dtypes_module.int32)
6298 if greedy:
6299 (decoded, log_prob) = ctc.ctc_greedy_decoder(
6300 inputs=y_pred, sequence_length=input_length)
6301 else:
6302 (decoded, log_prob) = ctc.ctc_beam_search_decoder(
6303 inputs=y_pred,
6304 sequence_length=input_length,
6305 beam_width=beam_width,
6306 top_paths=top_paths)
6307 decoded_dense = []
6308 for st in decoded:
6309 st = sparse_tensor.SparseTensor(
6310 st.indices, st.values, (num_samples, num_steps))
6311 decoded_dense.append(
6312 sparse_ops.sparse_tensor_to_dense(sp_input=st, default_value=-1))
6313 return (decoded_dense, log_prob)
6316# HIGH ORDER FUNCTIONS
6319@keras_export('keras.backend.map_fn')
6320@doc_controls.do_not_generate_docs
6321def map_fn(fn, elems, name=None, dtype=None):
6322 """Map the function fn over the elements elems and return the outputs.
6324 Args:
6325 fn: Callable that will be called upon each element in elems
6326 elems: tensor
6327 name: A string name for the map node in the graph
6328 dtype: Output data type.
6330 Returns:
6331 Tensor with dtype `dtype`.
6332 """
6333 return map_fn_lib.map_fn(fn, elems, name=name, dtype=dtype)
6336@keras_export('keras.backend.foldl')
6337@doc_controls.do_not_generate_docs
6338def foldl(fn, elems, initializer=None, name=None):
6339 """Reduce elems using fn to combine them from left to right.
6341 Args:
6342 fn: Callable that will be called upon each element in elems and an
6343 accumulator, for instance `lambda acc, x: acc + x`
6344 elems: tensor
6345 initializer: The first value used (`elems[0]` in case of None)
6346 name: A string name for the foldl node in the graph
6348 Returns:
6349 Tensor with same type and shape as `initializer`.
6350 """
6351 return functional_ops.foldl(fn, elems, initializer=initializer, name=name)
6354@keras_export('keras.backend.foldr')
6355@doc_controls.do_not_generate_docs
6356def foldr(fn, elems, initializer=None, name=None):
6357 """Reduce elems using fn to combine them from right to left.
6359 Args:
6360 fn: Callable that will be called upon each element in elems and an
6361 accumulator, for instance `lambda acc, x: acc + x`
6362 elems: tensor
6363 initializer: The first value used (`elems[-1]` in case of None)
6364 name: A string name for the foldr node in the graph
6366 Returns:
6367 Same type and shape as initializer
6368 """
6369 return functional_ops.foldr(fn, elems, initializer=initializer, name=name)
6371# Load Keras default configuration from config file if present.
6372# Set Keras base dir path given KERAS_HOME env variable, if applicable.
6373# Otherwise either ~/.keras or /tmp.
6374if 'KERAS_HOME' in os.environ:
6375 _keras_dir = os.environ.get('KERAS_HOME')
6376else:
6377 _keras_base_dir = os.path.expanduser('~')
6378 _keras_dir = os.path.join(_keras_base_dir, '.keras')
6379_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
6380if os.path.exists(_config_path):
6381 try:
6382 with open(_config_path) as fh:
6383 _config = json.load(fh)
6384 except ValueError:
6385 _config = {}
6386 _floatx = _config.get('floatx', floatx())
6387 assert _floatx in {'float16', 'float32', 'float64'}
6388 _epsilon = _config.get('epsilon', epsilon())
6389 assert isinstance(_epsilon, float)
6390 _image_data_format = _config.get('image_data_format', image_data_format())
6391 assert _image_data_format in {'channels_last', 'channels_first'}
6392 set_floatx(_floatx)
6393 set_epsilon(_epsilon)
6394 set_image_data_format(_image_data_format)
6396# Save config file.
6397if not os.path.exists(_keras_dir):
6398 try:
6399 os.makedirs(_keras_dir)
6400 except OSError:
6401 # Except permission denied and potential race conditions
6402 # in multi-threaded environments.
6403 pass
6405if not os.path.exists(_config_path):
6406 _config = {
6407 'floatx': floatx(),
6408 'epsilon': epsilon(),
6409 'backend': 'tensorflow',
6410 'image_data_format': image_data_format()
6411 }
6412 try:
6413 with open(_config_path, 'w') as f:
6414 f.write(json.dumps(_config, indent=4))
6415 except IOError:
6416 # Except permission denied.
6417 pass
6420def configure_and_create_distributed_session(distribution_strategy):
6421 """Configure session config and create a session with it."""
6423 def _create_session(distribution_strategy):
6424 """Create the Distributed Strategy session."""
6425 session_config = get_default_session_config()
6427 # If a session already exists, merge in its config; in the case there is a
6428 # conflict, take values of the existing config.
6429 global _SESSION
6430 if getattr(_SESSION, 'session', None) and _SESSION.session._config:
6431 session_config.MergeFrom(_SESSION.session._config)
6433 if is_tpu_strategy(distribution_strategy):
6434 # TODO(priyag, yuefengz): Remove this workaround when Distribute
6435 # Coordinator is integrated with keras and we can create a session from
6436 # there.
6437 distribution_strategy.configure(session_config)
6438 master = distribution_strategy.extended._tpu_cluster_resolver.master() # pylint: disable=protected-access
6439 session = session_module.Session(config=session_config, target=master)
6440 else:
6441 worker_context = dc.get_current_worker_context()
6442 if worker_context:
6443 dc_session_config = worker_context.session_config
6444 # Merge the default session config to the one from distribute
6445 # coordinator, which is fine for now since they don't have
6446 # conflicting configurations.
6447 dc_session_config.MergeFrom(session_config)
6448 session = session_module.Session(
6449 config=dc_session_config, target=worker_context.master_target)
6450 else:
6451 distribution_strategy.configure(session_config)
6452 session = session_module.Session(config=session_config)
6454 set_session(session)
6456 if distribution_strategy.extended._in_multi_worker_mode():
6457 dc.run_distribute_coordinator(
6458 _create_session,
6459 distribution_strategy)
6460 else:
6461 _create_session(distribution_strategy)
6464def _is_tpu_strategy_class(clz):
6465 is_tpu_strat = lambda k: k.__name__.startswith('TPUStrategy')
6466 if is_tpu_strat(clz):
6467 return True
6468 return py_any(map(_is_tpu_strategy_class, clz.__bases__))
6471def is_tpu_strategy(strategy):
6472 """Returns whether input is a TPUStrategy instance or subclass instance."""
6473 return _is_tpu_strategy_class(strategy.__class__)
6476def cast_variables_to_tensor(tensors):
6478 def _cast_variables_to_tensor(tensor):
6479 if isinstance(tensor, variables_module.Variable):
6480 return array_ops.identity(tensor)
6481 return tensor
6483 return nest.map_structure(_cast_variables_to_tensor, tensors)
6486def _is_symbolic_tensor(x):
6487 return tensor_util.is_tf_type(x) and not isinstance(x, ops.EagerTensor)
6490def convert_inputs_if_ragged(inputs):
6491 """Converts any ragged tensors to dense."""
6493 def _convert_ragged_input(inputs):
6494 if isinstance(inputs, ragged_tensor.RaggedTensor):
6495 return inputs.to_tensor()
6496 return inputs
6498 flat_inputs = nest.flatten(inputs)
6499 contains_ragged = py_any(
6500 isinstance(i, ragged_tensor.RaggedTensor) for i in flat_inputs)
6502 if not contains_ragged:
6503 return inputs, None
6505 inputs = nest.map_structure(_convert_ragged_input, inputs)
6506 # Multiple mask are not yet supported, so one mask is used on all inputs.
6507 # We approach this similarly when using row lengths to ignore steps.
6508 nested_row_lengths = math_ops.cast(flat_inputs[0].nested_row_lengths()[0],
6509 'int32')
6510 return inputs, nested_row_lengths
6513def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths,
6514 go_backwards=False):
6515 """Converts any ragged input back to its initial structure."""
6516 if not is_ragged_input:
6517 return output
6519 if go_backwards:
6520 # Reverse based on the timestep dim, so that nested_row_lengths will mask
6521 # from the correct direction. Return the reverse ragged tensor.
6522 output = reverse(output, [1])
6523 ragged = ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
6524 return reverse(ragged, [1])
6525 else:
6526 return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
6529class ContextValueCache(weakref.WeakKeyDictionary):
6530 """Container that caches (possibly tensor) values based on the context.
6532 This class is similar to defaultdict, where values may be produced by the
6533 default factory specified during initialization. This class also has a default
6534 value for the key (when key is `None`) -- the key is set to the current graph
6535 or eager context. The default factories for key and value are only used in
6536 `__getitem__` and `setdefault`. The `.get()` behavior remains the same.
6538 This object will return the value of the current graph or closest parent graph
6539 if the current graph is a function. This is to reflect the fact that if a
6540 tensor is created in eager/graph, child functions may capture that tensor.
6542 The default factory method may accept keyword arguments (unlike defaultdict,
6543 which only accepts callables with 0 arguments). To pass keyword arguments to
6544 `default_factory`, use the `setdefault` method instead of `__getitem__`.
6546 An example of how this class can be used in different contexts:
6548 ```
6549 cache = ContextValueCache(int)
6551 # Eager mode
6552 cache[None] += 2
6553 cache[None] += 4
6554 assert cache[None] == 6
6556 # Graph mode
6557 with tf.Graph().as_default() as g:
6558 cache[None] += 5
6559 cache[g] += 3
6560 assert cache[g] == 8
6561 ```
6563 Example of a default factory with arguments:
6565 ```
6566 cache = ContextValueCache(lambda x: x + 1)
6567 g = tf.get_default_graph()
6569 # Example with keyword argument.
6570 value = cache.setdefault(key=g, kwargs={'x': 3})
6571 assert cache[g] == 4
6572 ```
6573 """
6575 def __init__(self, default_factory):
6576 self.default_factory = default_factory
6577 weakref.WeakKeyDictionary.__init__(self)
6579 def _key(self):
6580 if context.executing_eagerly():
6581 return _DUMMY_EAGER_GRAPH.key
6582 else:
6583 return ops.get_default_graph()
6585 def _get_parent_graph(self, graph):
6586 """Returns the parent graph or dummy eager object."""
6587 # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as the
6588 # outer graph. This results in outer_graph always being a Graph,
6589 # even in eager mode (get_default_graph will create a new Graph if there
6590 # isn't a default graph). Because of this bug, we have to specially set the
6591 # key when eager execution is enabled.
6592 parent_graph = graph.outer_graph
6593 if (not isinstance(parent_graph, func_graph.FuncGraph) and
6594 ops.executing_eagerly_outside_functions()):
6595 return _DUMMY_EAGER_GRAPH.key
6596 return parent_graph
6598 def _get_recursive(self, key):
6599 """Gets the value at key or the closest parent graph."""
6600 value = self.get(key)
6601 if value is not None:
6602 return value
6604 # Since FuncGraphs are able to capture tensors and variables from their
6605 # parent graphs, recursively search to see if there is a value stored for
6606 # one of the parent graphs.
6607 if isinstance(key, func_graph.FuncGraph):
6608 return self._get_recursive(self._get_parent_graph(key))
6609 return None
6611 def __getitem__(self, key):
6612 """Gets the value at key (or current context), or sets default value.
6614 Args:
6615 key: May be `None` or `Graph`object. When `None`, the key is set to the
6616 current context.
6618 Returns:
6619 Either the cached or default value.
6620 """
6621 if key is None:
6622 key = self._key()
6624 value = self._get_recursive(key)
6625 if value is None:
6626 value = self[key] = self.default_factory() # pylint:disable=not-callable
6627 return value
6629 def setdefault(self, key=None, default=None, kwargs=None):
6630 """Sets the default value if key is not in dict, and returns the value."""
6631 if key is None:
6632 key = self._key()
6633 kwargs = kwargs or {}
6635 if default is None and key not in self:
6636 default = self.default_factory(**kwargs)
6637 return weakref.WeakKeyDictionary.setdefault(self, key, default)
6639# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
6640# dummy object is used.
6641# A learning phase is a bool tensor used to run Keras models in
6642# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
6643_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
6645# This dictionary holds a mapping between a graph and variables to initialize
6646# in the graph.
6647_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
6649# This dictionary holds a mapping between a graph and TF optimizers created in
6650# the graph.
6651_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)