Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/nn_impl.py: 29%
431 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Implementation of Neural Net (NN) functions."""
17import math
19from tensorflow.python.distribute import distribute_lib
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import array_ops_stack
25from tensorflow.python.ops import candidate_sampling_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import cond as tf_cond
28from tensorflow.python.ops import custom_gradient
29from tensorflow.python.ops import embedding_ops
30from tensorflow.python.ops import gen_array_ops # pylint: disable=unused-import
31from tensorflow.python.ops import gen_nn_ops
32from tensorflow.python.ops import gen_sparse_ops
33from tensorflow.python.ops import linalg_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import nn_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.ops.losses import util as losses_util
38from tensorflow.python.platform import device_context
39from tensorflow.python.util import dispatch
40from tensorflow.python.util.deprecation import deprecated_args
41from tensorflow.python.util.deprecation import deprecated_argument_lookup
42from tensorflow.python.util.tf_export import tf_export
45@tf_export("nn.log_poisson_loss")
46@dispatch.add_dispatch_support
47def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
48 """Computes log Poisson loss given `log_input`.
50 Gives the log-likelihood loss between the prediction and the target under the
51 assumption that the target has a Poisson distribution.
52 Caveat: By default, this is not the exact loss, but the loss minus a
53 constant term [log(z!)]. That has no effect for optimization, but
54 does not play well with relative loss comparisons. To compute an
55 approximation of the log factorial term, specify
56 compute_full_loss=True to enable Stirling's Approximation.
58 For brevity, let `c = log(x) = log_input`, `z = targets`. The log Poisson
59 loss is
61 -log(exp(-x) * (x^z) / z!)
62 = -log(exp(-x) * (x^z)) + log(z!)
63 ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
64 [ Note the second term is the Stirling's Approximation for log(z!).
65 It is invariant to x and does not affect optimization, though
66 important for correct relative loss comparisons. It is only
67 computed when compute_full_loss == True. ]
68 = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
69 = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
71 Args:
72 targets: A `Tensor` of the same type and shape as `log_input`.
73 log_input: A `Tensor` of type `float32` or `float64`.
74 compute_full_loss: whether to compute the full loss. If false, a constant
75 term is dropped in favor of more efficient optimization.
76 name: A name for the operation (optional).
78 Returns:
79 A `Tensor` of the same shape as `log_input` with the componentwise
80 log Poisson losses.
82 Raises:
83 ValueError: If `log_input` and `targets` do not have the same shape.
84 """
85 with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
86 log_input = ops.convert_to_tensor(log_input, name="log_input")
87 targets = ops.convert_to_tensor(targets, name="targets")
88 try:
89 targets.get_shape().assert_is_compatible_with(log_input.get_shape())
90 except ValueError:
91 raise ValueError(
92 "`log_input` and `targets` must have the same shape, received "
93 f"({log_input.get_shape()} vs {targets.get_shape()}).")
95 result = math_ops.exp(log_input) - log_input * targets
96 if compute_full_loss:
97 # need to create constant tensors here so that their dtypes can be matched
98 # to that of the targets.
99 point_five = constant_op.constant(0.5, dtype=targets.dtype)
100 two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)
102 stirling_approx = (targets * math_ops.log(targets)) - targets + (
103 point_five * math_ops.log(two_pi * targets))
104 zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
105 ones = array_ops.ones_like(targets, dtype=targets.dtype)
106 cond = math_ops.logical_and(targets >= zeros, targets <= ones)
107 result += array_ops.where(cond, zeros, stirling_approx)
108 return result
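# --- Hedged usage sketch (illustration only, not part of the original module).
# Shows the difference between the default partial loss exp(c) - z * c and the
# full loss that adds Stirling's approximation of log(z!). The helper name
# `_example_log_poisson_loss` is hypothetical; assumes TensorFlow 2.x eager mode.
def _example_log_poisson_loss():
  import tensorflow as tf  # local import keeps the sketch self-contained
  targets = tf.constant([0., 1., 4.])        # observed counts z
  log_input = tf.constant([0.2, 0.5, 1.3])   # c = log of the predicted rate
  partial = tf.nn.log_poisson_loss(targets, log_input)
  full = tf.nn.log_poisson_loss(targets, log_input, compute_full_loss=True)
  # The two differ only by the Stirling term, which depends on `targets` alone
  # and therefore does not change gradients with respect to `log_input`.
  return partial, full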
111@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
112@dispatch.add_dispatch_support
113def sigmoid_cross_entropy_with_logits(
114 labels=None,
115 logits=None,
116 name=None):
117 """See sigmoid_cross_entropy_with_logits_v2."""
118 # pylint: disable=protected-access
119 nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", labels, logits)
120 # pylint: enable=protected-access
122 with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
123 logits = ops.convert_to_tensor(logits, name="logits")
124 labels = ops.convert_to_tensor(labels, name="labels")
125 try:
126 labels.get_shape().assert_is_compatible_with(logits.get_shape())
127 except ValueError:
128 raise ValueError("`logits` and `labels` must have the same shape, "
129 f"received ({logits.get_shape()} vs "
130 f"{labels.get_shape()}).")
132 # The logistic loss formula from above is
133 # x - x * z + log(1 + exp(-x))
134 # For x < 0, a more numerically stable formula is
135 # -x * z + log(1 + exp(x))
136 # Note that these two expressions can be combined into the following:
137 # max(x, 0) - x * z + log(1 + exp(-abs(x)))
138 # To allow computing gradients at zero, we define custom versions of max and
139 # abs functions.
140 zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
141 cond = (logits >= zeros)
142 relu_logits = array_ops.where(cond, logits, zeros)
143 neg_abs_logits = array_ops.where(cond, -logits, logits) # pylint: disable=invalid-unary-operand-type
144 return math_ops.add(
145 relu_logits - logits * labels,
146 math_ops.log1p(math_ops.exp(neg_abs_logits)),
147 name=name)
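# --- Hedged verification sketch (illustration only, not part of the module).
# For moderate logits the stable form max(x, 0) - x * z + log(1 + exp(-|x|))
# used above matches the textbook loss z * -log(sigmoid(x)) +
# (1 - z) * -log(1 - sigmoid(x)). Hypothetical helper name; TF 2.x eager mode.
def _example_check_logistic_loss():
  import tensorflow as tf
  x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])   # logits
  z = tf.constant([0.0, 1.0, 0.5, 1.0, 0.0])     # labels in [0, 1]
  stable = tf.nn.sigmoid_cross_entropy_with_logits(labels=z, logits=x)
  naive = -(z * tf.math.log(tf.sigmoid(x)) +
            (1.0 - z) * tf.math.log(1.0 - tf.sigmoid(x)))
  tf.debugging.assert_near(stable, naive, atol=1e-4)
  return stable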
150# Note: intentionally calling this v2 to not allow existing code with indirect
151# imports to ignore the sentinel behavior.
152@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
153@dispatch.register_binary_elementwise_api
154@dispatch.add_dispatch_support
155def sigmoid_cross_entropy_with_logits_v2( # pylint: disable=invalid-name
156 labels=None,
157 logits=None,
158 name=None):
159 r"""Computes sigmoid cross entropy given `logits`.
161 Measures the probability error in tasks with two outcomes in which each
162 outcome is independent and need not have a fully certain label. For instance,
163 one could perform a regression where the probability of an event happening is
164 known and used as a label. This loss may also be used for binary
165 classification, where labels are either zero or one.
167 For brevity, let `x = logits`, `z = labels`. The logistic loss is
169 z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
170 = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
171 = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
172 = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
173 = (1 - z) * x + log(1 + exp(-x))
174 = x - x * z + log(1 + exp(-x))
176 For x < 0, to avoid overflow in exp(-x), we reformulate the above
178 x - x * z + log(1 + exp(-x))
179 = log(exp(x)) - x * z + log(1 + exp(-x))
180 = - x * z + log(1 + exp(x))
182 Hence, to ensure stability and avoid overflow, the implementation uses this
183 equivalent formulation
185 max(x, 0) - x * z + log(1 + exp(-abs(x)))
187 `logits` and `labels` must have the same type and shape.
189 >>> logits = tf.constant([1., -1., 0., 1., -1., 0., 0.])
190 >>> labels = tf.constant([0., 0., 0., 1., 1., 1., 0.5])
191 >>> tf.nn.sigmoid_cross_entropy_with_logits(
192 ... labels=labels, logits=logits).numpy()
193 array([1.3132617, 0.3132617, 0.6931472, 0.3132617, 1.3132617, 0.6931472,
194 0.6931472], dtype=float32)
196 Compared to the losses which handle multiple outcomes,
197 `tf.nn.softmax_cross_entropy_with_logits` for general multi-class
198 classification and `tf.nn.sparse_softmax_cross_entropy_with_logits` for more
199 efficient multi-class classification with hard labels,
200 `sigmoid_cross_entropy_with_logits` is a slight simplification for binary
201 classification:
203 sigmoid(x) = softmax([x, 0])[0]
205 $$\frac{1}{1 + e^{-x}} = \frac{e^x}{e^x + e^0}$$
207 While `sigmoid_cross_entropy_with_logits` works for soft binary labels
208 (probabilities between 0 and 1), it can also be used for binary classification
209 where the labels are hard. There is an equivalence between all three symbols
210 in this case, with a probability 0 indicating the second class or 1 indicating
211 the first class:
213 >>> sigmoid_logits = tf.constant([1., -1., 0.])
214 >>> softmax_logits = tf.stack([sigmoid_logits, tf.zeros_like(sigmoid_logits)],
215 ... axis=-1)
216 >>> soft_binary_labels = tf.constant([1., 1., 0.])
217 >>> soft_multiclass_labels = tf.stack(
218 ... [soft_binary_labels, 1. - soft_binary_labels], axis=-1)
219 >>> hard_labels = tf.constant([0, 0, 1])
220 >>> tf.nn.sparse_softmax_cross_entropy_with_logits(
221 ... labels=hard_labels, logits=softmax_logits).numpy()
222 array([0.31326166, 1.3132616 , 0.6931472 ], dtype=float32)
223 >>> tf.nn.softmax_cross_entropy_with_logits(
224 ... labels=soft_multiclass_labels, logits=softmax_logits).numpy()
225 array([0.31326166, 1.3132616, 0.6931472], dtype=float32)
226 >>> tf.nn.sigmoid_cross_entropy_with_logits(
227 ... labels=soft_binary_labels, logits=sigmoid_logits).numpy()
228 array([0.31326166, 1.3132616, 0.6931472], dtype=float32)
230 Args:
231 labels: A `Tensor` of the same type and shape as `logits`. Between 0 and 1,
232 inclusive.
233 logits: A `Tensor` of type `float32` or `float64`. Any real number.
234 name: A name for the operation (optional).
236 Returns:
237 A `Tensor` of the same shape as `logits` with the componentwise
238 logistic losses.
240 Raises:
241 ValueError: If `logits` and `labels` do not have the same shape.
242 """
243 return sigmoid_cross_entropy_with_logits(
244 logits=logits, labels=labels, name=name)
247sigmoid_cross_entropy_with_logits.__doc__ = (
248 sigmoid_cross_entropy_with_logits_v2.__doc__)
251@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
252@dispatch.add_dispatch_support
253def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
254 name=None):
255 """Computes a weighted cross entropy.
257 This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
258 allows one to trade off recall and precision by up- or down-weighting the
259 cost of a positive error relative to a negative error.
261 The usual cross-entropy cost is defined as:
263 labels * -log(sigmoid(logits)) +
264 (1 - labels) * -log(1 - sigmoid(logits))
266 A value `pos_weight > 1` decreases the false negative count, hence increasing
267 the recall.
268 Conversely setting `pos_weight < 1` decreases the false positive count and
269 increases the precision.
270 This can be seen from the fact that `pos_weight` is introduced as a
271 multiplicative coefficient for the positive labels term
272 in the loss expression:
274 labels * -log(sigmoid(logits)) * pos_weight +
275 (1 - labels) * -log(1 - sigmoid(logits))
277 For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
278 The loss is:
280 qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
281 = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
282 = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
283 = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
284 = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
285 = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
287 Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
288 the implementation uses
290 (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
292 `logits` and `labels` must have the same type and shape.
294 >>> labels = tf.constant([1., 0.5, 0.])
295 >>> logits = tf.constant([1.5, -0.1, -10.])
296 >>> tf.nn.weighted_cross_entropy_with_logits(
297 ... labels=labels, logits=logits, pos_weight=tf.constant(1.5)).numpy()
298 array([3.0211994e-01, 8.8049585e-01, 4.5776367e-05], dtype=float32)
299 >>> tf.nn.weighted_cross_entropy_with_logits(
300 ... labels=labels, logits=logits, pos_weight=tf.constant(0.5)).numpy()
301 array([1.00706644e-01, 5.08297503e-01, 4.57763672e-05], dtype=float32)
303 Args:
304 labels: A `Tensor` of the same type and shape as `logits`, with values
305 between 0 and 1 inclusive.
306 logits: A `Tensor` of type `float32` or `float64`, any real numbers.
307 pos_weight: A coefficient to use on the positive examples, typically a
308 scalar but otherwise broadcastable to the shape of `logits`. Its value
309 should be non-negative.
310 name: A name for the operation (optional).
312 Returns:
313 A `Tensor` of the same shape as `logits` with the componentwise
314 weighted logistic losses.
316 Raises:
317 ValueError: If `logits` and `labels` do not have the same shape.
318 """
319 with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
320 logits = ops.convert_to_tensor(logits, name="logits")
321 labels = ops.convert_to_tensor(labels, name="labels")
322 try:
323 labels.get_shape().assert_is_compatible_with(logits.get_shape())
324 except ValueError:
325 raise ValueError("`logits` and `labels` must have the same shape, "
326 f"received ({logits.get_shape()} vs "
327 f"{labels.get_shape()}).")
329 # The logistic loss formula from above is
330 # (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
331 # For x < 0, a more numerically stable formula is
332 # (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
333 # To avoid branching, we use the combined version
334 # (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
335 log_weight = 1 + (pos_weight - 1) * labels
336 return math_ops.add(
337 (1 - labels) * logits,
338 log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
339 nn_ops.relu(-logits)), # pylint: disable=invalid-unary-operand-type
340 name=name)
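# --- Hedged sanity-check sketch (illustration only, not part of the module).
# With pos_weight = 1 the weight l = 1 + (q - 1) * z collapses to 1, so the
# weighted loss reduces to tf.nn.sigmoid_cross_entropy_with_logits.
# Hypothetical helper name; TF 2.x eager mode.
def _example_weighted_ce_with_unit_pos_weight():
  import tensorflow as tf
  labels = tf.constant([1., 0.5, 0.])
  logits = tf.constant([1.5, -0.1, -10.])
  weighted = tf.nn.weighted_cross_entropy_with_logits(
      labels=labels, logits=logits, pos_weight=1.0)
  unweighted = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=labels, logits=logits)
  tf.debugging.assert_near(weighted, unweighted, atol=1e-5)
  return weighted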
343@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
344@dispatch.add_dispatch_support
345@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
346def weighted_cross_entropy_with_logits(labels=None,
347 logits=None,
348 pos_weight=None,
349 name=None,
350 targets=None):
351 """Computes a weighted cross entropy.
353 This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
354 allows one to trade off recall and precision by up- or down-weighting the
355 cost of a positive error relative to a negative error.
357 The usual cross-entropy cost is defined as:
359 labels * -log(sigmoid(logits)) +
360 (1 - labels) * -log(1 - sigmoid(logits))
362 A value `pos_weight > 1` decreases the false negative count, hence increasing
363 the recall.
364 Conversely setting `pos_weight < 1` decreases the false positive count and
365 increases the precision.
366 This can be seen from the fact that `pos_weight` is introduced as a
367 multiplicative coefficient for the positive labels term
368 in the loss expression:
370 labels * -log(sigmoid(logits)) * pos_weight +
371 (1 - labels) * -log(1 - sigmoid(logits))
373 For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
374 The loss is:
376 qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
377 = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
378 = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
379 = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
380 = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
381 = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
383 Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
384 the implementation uses
386 (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
388 `logits` and `labels` must have the same type and shape.
390 Args:
391 labels: A `Tensor` of the same type and shape as `logits`.
392 logits: A `Tensor` of type `float32` or `float64`.
393 pos_weight: A coefficient to use on the positive examples.
394 name: A name for the operation (optional).
395 targets: Deprecated alias for labels.
397 Returns:
398 A `Tensor` of the same shape as `logits` with the componentwise
399 weighted logistic losses.
401 Raises:
402 ValueError: If `logits` and `labels` do not have the same shape.
403 """
404 labels = deprecated_argument_lookup("labels", labels, "targets", targets)
405 return weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name)
408@tf_export("nn.compute_average_loss")
409@dispatch.add_dispatch_support
410def compute_average_loss(per_example_loss,
411 sample_weight=None,
412 global_batch_size=None):
413 """Scales per-example losses with sample_weights and computes their average.
415 Usage with distribution strategy and custom training loop:
417 ```python
418 with strategy.scope():
419 def compute_loss(labels, predictions, sample_weight=None):
421 # If you are using a `Loss` class instead, set reduction to `NONE` so that
422 # we can do the reduction afterwards and divide by global batch size.
423 per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
424 labels, predictions)
426 # Compute loss that is scaled by sample_weight and by global batch size.
427 return tf.nn.compute_average_loss(
428 per_example_loss,
429 sample_weight=sample_weight,
430 global_batch_size=GLOBAL_BATCH_SIZE)
431 ```
433 Args:
434 per_example_loss: Per-example loss.
435 sample_weight: Optional weighting for each example.
436 global_batch_size: Optional global batch size value. Defaults to (size of
437 first dimension of `per_example_loss`) * (number of replicas).
439 Returns:
440 Scalar loss value, obtained by summing the `per_example_loss` and dividing
441 by `global_batch_size`. If `global_batch_size` is zero, the result is zero.
442 """ # pylint: disable=g-doc-exception
443 per_example_loss = ops.convert_to_tensor(per_example_loss)
444 input_dtype = per_example_loss.dtype
446 with losses_util.check_per_example_loss_rank(per_example_loss):
447 if sample_weight is not None:
448 sample_weight = ops.convert_to_tensor(sample_weight)
449 per_example_loss = losses_util.scale_losses_by_sample_weight(
450 per_example_loss, sample_weight)
451 per_example_loss = math_ops.cast(per_example_loss, input_dtype)
453 if global_batch_size is None:
454 if (distribute_lib.has_strategy()
455 and distribute_lib.in_cross_replica_context()):
456 raise RuntimeError(
457 "You are calling `compute_average_loss` in cross replica context, "
458 "while it was expected to be called in replica context.")
460 num_replicas = distribute_lib.get_strategy().num_replicas_in_sync
461 per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
462 global_batch_size = per_replica_batch_size * num_replicas
464 check_ops.assert_scalar_v2(
465 global_batch_size, message="global_batch_size must be scalar.")
466 check_ops.assert_integer_v2(
467 global_batch_size,
468 message="global_batch_size must be an integer.")
469 check_ops.assert_non_negative_v2(
470 global_batch_size, message="global_batch_size must be non-negative.")
472 loss = math_ops.reduce_sum(per_example_loss)
473 global_batch_size = math_ops.cast(global_batch_size, input_dtype)
474 return math_ops.div_no_nan(loss, global_batch_size)
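# --- Hedged usage sketch (illustration only, not part of the module).
# Outside a tf.distribute strategy there is a single replica, so
# compute_average_loss is just the sum of the (optionally weighted) losses
# divided by the local batch size. Hypothetical helper name; TF 2.x eager mode.
def _example_compute_average_loss():
  import tensorflow as tf
  per_example_loss = tf.constant([1.0, 2.0, 3.0, 4.0])
  sample_weight = tf.constant([1.0, 1.0, 0.5, 0.0])
  avg = tf.nn.compute_average_loss(per_example_loss,
                                   sample_weight=sample_weight)
  manual = tf.reduce_sum(per_example_loss * sample_weight) / 4.0
  tf.debugging.assert_near(avg, manual)
  return avg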
477@tf_export("nn.scale_regularization_loss")
478@dispatch.add_dispatch_support
479def scale_regularization_loss(regularization_loss):
480 """Scales the sum of the given regularization losses by number of replicas.
482 Usage with distribution strategy and custom training loop:
484 ```python
485 with strategy.scope():
486 def compute_loss(self, labels, predictions, sample_weight=None):
487 per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
488 labels, predictions)
490 # Compute loss that is scaled by sample_weight and by global batch size.
491 loss = tf.nn.compute_average_loss(
492 per_example_loss,
493 sample_weight=sample_weight,
494 global_batch_size=GLOBAL_BATCH_SIZE)
496 # Add scaled regularization losses.
497 loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
498 return loss
499 ```
501 Args:
502 regularization_loss: Regularization loss.
504 Returns:
505 Scalar loss value.
506 """ # pylint: disable=g-doc-exception
507 if (distribute_lib.has_strategy()
508 and distribute_lib.in_cross_replica_context()):
509 raise RuntimeError(
510 "You are calling `scale_regularization_loss` in cross replica context, "
511 "while it was expected to be called in replica context.")
513 num_replicas = distribute_lib.get_strategy().num_replicas_in_sync
514 return math_ops.reduce_sum(regularization_loss) / num_replicas
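# --- Hedged usage sketch (illustration only, not part of the module).
# Outside a strategy num_replicas_in_sync is 1, so this is simply the sum of
# the regularization terms; under a strategy each replica contributes
# sum / num_replicas so the term is counted once after aggregation.
# Hypothetical helper name; TF 2.x eager mode.
def _example_scale_regularization_loss():
  import tensorflow as tf
  weights = tf.constant([[0.5, -1.0], [2.0, 0.0]])
  return tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))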
517@tf_export(v1=["nn.relu_layer"])
518@dispatch.add_dispatch_support
519def relu_layer(x, weights, biases, name=None):
520 """Computes Relu(x * weight + biases).
522 Args:
523 x: a 2D tensor. Dimensions typically: batch, in_units
524 weights: a 2D tensor. Dimensions typically: in_units, out_units
525 biases: a 1D tensor. Dimensions: out_units
526 name: A name for the operation (optional). If not specified
527 "nn_relu_layer" is used.
529 Returns:
530 A 2-D Tensor computing relu(matmul(x, weights) + biases).
531 Dimensions typically: batch, out_units.
532 """
533 with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
534 x = ops.convert_to_tensor(x, name="x")
535 weights = ops.convert_to_tensor(weights, name="weights")
536 biases = ops.convert_to_tensor(biases, name="biases")
537 xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
538 return nn_ops.relu(xw_plus_b, name=name)
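# --- Hedged usage sketch (illustration only, not part of the module).
# relu_layer is a v1 convenience wrapper for relu(matmul(x, weights) + biases);
# the equivalence is checked explicitly below. Hypothetical helper name;
# assumes the tf.compat.v1 symbols available in TF 2.x.
def _example_relu_layer():
  import tensorflow as tf
  x = tf.constant([[1.0, -2.0]])                   # [batch, in_units]
  weights = tf.constant([[3.0, 0.0], [0.0, 4.0]])  # [in_units, out_units]
  biases = tf.constant([0.5, -0.5])                # [out_units]
  out = tf.compat.v1.nn.relu_layer(x, weights, biases)
  tf.debugging.assert_near(out, tf.nn.relu(tf.matmul(x, weights) + biases))
  return out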
541@tf_export("nn.silu", "nn.swish")
542@dispatch.register_unary_elementwise_api
543@dispatch.add_dispatch_support
544def swish(features, beta=1.0):
545 # pylint: disable=g-doc-args
546 """Computes the SiLU or Swish activation function: `x * sigmoid(beta * x)`.
548 beta: Hyperparameter for the Swish activation function. Default value 1.0.
550 The SiLU activation function was introduced in "Gaussian Error Linear Units
551 (GELUs)" [Hendrycks et al. 2016](https://arxiv.org/abs/1606.08415) and
552 "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in
553 Reinforcement Learning"
554 [Elfwing et al. 2017](https://arxiv.org/abs/1702.03118) and was independently
555 discovered (and called swish) in "Searching for Activation Functions"
556 [Ramachandran et al. 2017](https://arxiv.org/abs/1710.05941)
558 Args:
559 features: A `Tensor` representing preactivation values.
560 beta: A `Tensor` representing the value of the beta hyperparameter.
562 Returns:
563 The activation value.
564 """
565 # pylint: enable=g-doc-args
566 features = ops.convert_to_tensor(features, name="features")
567 beta = ops.convert_to_tensor(beta, name="beta")
568 beta = math_ops.cast(beta, features.dtype)
570 @custom_gradient.custom_gradient
571 def swish_impl(features, beta):
573 def grad(dy):
574 """Gradient for the Swish activation function."""
575 # Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x)
576 # around for backprop, effectively doubling the tensor's memory
577 # consumption. We use a control dependency here so that sigmoid(features)
578 # is re-computed during backprop (the control dep prevents it being
579 # de-duped with the forward pass) and we can free the sigmoid(features)
580 # expression immediately after use during the forward pass.
581 with ops.control_dependencies([dy]):
582 sigmoid_features = math_ops.sigmoid(beta * features)
584 activation_grad = (
585 sigmoid_features * (1.0 + (beta * features) *
586 (1.0 - sigmoid_features)))
587 beta_grad = math_ops.reduce_sum(
588 dy * math_ops.square(features) * sigmoid_features *
589 (1.0 - sigmoid_features))
590 return (dy * activation_grad, beta_grad)
592 return features * math_ops.sigmoid(beta * features), grad
594 return swish_impl(features, beta)
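# --- Hedged usage sketch (illustration only, not part of the module).
# silu/swish is x * sigmoid(beta * x); with the default beta = 1 it matches the
# plain product below, and the custom gradient above only changes how the
# backward pass is computed. Hypothetical helper name; TF 2.x eager mode.
def _example_silu():
  import tensorflow as tf
  x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
  tf.debugging.assert_near(tf.nn.silu(x), x * tf.sigmoid(x))
  return tf.nn.silu(x, beta=1.5)  # a sharper gate for beta > 1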
597# pylint: disable=redefined-builtin
598@tf_export("linalg.normalize")
599@dispatch.add_dispatch_support
600def normalize(tensor, ord="euclidean", axis=None, name=None):
601 """Normalizes `tensor` along dimension `axis` using specified norm.
603 This uses `tf.linalg.norm` to compute the norm along `axis`.
605 This function can compute several different vector norms (the 1-norm, the
606 Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
607 matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).
609 Args:
610 tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
611 ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
612 `2`, `np.inf` and any positive real number yielding the corresponding
613 p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
614 `tensor` is a matrix and equivalent to 2-norm for vectors.
615 Some restrictions apply: a) The Frobenius norm `'fro'` is not defined for
616 vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
617 `'fro'`, `1`, `2`, `np.inf` are supported. See the description of `axis`
618 on how to compute norms for a batch of vectors or matrices stored in a
619 tensor.
620 axis: If `axis` is `None` (the default), the input is considered a vector
621 and a single vector norm is computed over the entire set of values in the
622 tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
623 `norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer, the
624 input is considered a batch of vectors, and `axis` determines the axis in
625 `tensor` over which to compute vector norms. If `axis` is a 2-tuple of
626 Python integers it is considered a batch of matrices and `axis` determines
627 the axes in `tensor` over which to compute a matrix norm.
628 Negative indices are supported. Example: If you are passing a tensor that
629 can be either a matrix or a batch of matrices at runtime, pass
630 `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
631 computed.
632 name: The name of the op.
634 Returns:
635 normalized: A normalized `Tensor` with the same shape as `tensor`.
636 norm: The computed norms with the same shape and dtype as `tensor` but the
637 final axis is 1 instead. Same as running
638 `tf.cast(tf.linalg.norm(tensor, ord, axis, keepdims=True), tensor.dtype)`.
640 Raises:
641 ValueError: If `ord` or `axis` is invalid.
642 """
643 with ops.name_scope(name, "normalize", [tensor]) as name:
644 tensor = ops.convert_to_tensor(tensor)
645 norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
646 norm = math_ops.cast(norm, tensor.dtype)
647 normalized = tensor / norm
648 return normalized, norm
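# --- Hedged usage sketch (illustration only, not part of the module).
# tf.linalg.normalize returns both the normalized tensor and the norm it
# divided by; normalizing each row by its 2-norm leaves rows of unit length.
# Hypothetical helper name; TF 2.x eager mode.
def _example_linalg_normalize():
  import tensorflow as tf
  t = tf.constant([[3.0, 4.0], [6.0, 8.0]])
  normalized, norm = tf.linalg.normalize(t, ord=2, axis=1)
  # norm has shape [2, 1]; each normalized row equals [0.6, 0.8].
  tf.debugging.assert_near(tf.norm(normalized, axis=1), tf.ones([2]))
  return normalized, norm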
651@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize",
652 v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
653@dispatch.add_dispatch_support
654@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
655def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
656 """Normalizes along dimension `axis` using an L2 norm.
658 For a 1-D tensor with `axis = 0`, computes
660 output = x / sqrt(max(sum(x**2), epsilon))
662 For `x` with more dimensions, independently normalizes each 1-D slice along
663 dimension `axis`.
665 1-D tensor example:
666 >>> x = tf.constant([3.0, 4.0])
667 >>> tf.math.l2_normalize(x).numpy()
668 array([0.6, 0.8], dtype=float32)
670 2-D tensor example:
671 >>> x = tf.constant([[3.0], [4.0]])
672 >>> tf.math.l2_normalize(x, 0).numpy()
673 array([[0.6],
674 [0.8]], dtype=float32)
676 >>> x = tf.constant([[3.0], [4.0]])
677 >>> tf.math.l2_normalize(x, 1).numpy()
678 array([[1.],
679 [1.]], dtype=float32)
681 Args:
682 x: A `Tensor`.
683 axis: Dimension along which to normalize. A scalar or a vector of
684 integers.
685 epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
686 divisor if `norm < sqrt(epsilon)`.
687 name: A name for this operation (optional).
688 dim: Deprecated, do not use.
690 Returns:
691 A `Tensor` with the same shape as `x`.
692 """
693 axis = deprecated_argument_lookup("axis", axis, "dim", dim)
694 with ops.name_scope(name, "l2_normalize", [x]) as name:
695 x = ops.convert_to_tensor(x, name="x")
696 if x.dtype.is_complex:
697 square_real = math_ops.square(math_ops.real(x))
698 square_imag = math_ops.square(math_ops.imag(x))
699 square_sum = math_ops.real(
700 math_ops.reduce_sum(square_real + square_imag, axis, keepdims=True))
701 x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
702 norm_real = math_ops.multiply(math_ops.real(x), x_inv_norm)
703 norm_imag = math_ops.multiply(math_ops.imag(x), x_inv_norm)
704 return math_ops.complex(norm_real, norm_imag, name=name)
705 square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
706 x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
707 return math_ops.multiply(x, x_inv_norm, name=name)
710def _count_nonzero(input_tensor, dtype=dtypes.int64):
711 """Same as math_ops.count_nonzero.
713 The reduction is done in dtype, which can be faster for 32-bit dtypes.
715 Args:
716 input_tensor: numeric tensor
717 dtype: reduction dtype
719 Returns:
720 number of nonzero values with type dtype
721 """
722 with ops.name_scope("count_nonzero", values=[input_tensor]):
723 zero = array_ops.zeros([], dtype=input_tensor.dtype)
724 nonzero_count = math_ops.reduce_sum(
725 math_ops.cast(
726 math_ops.not_equal(input_tensor, zero),
727 dtype=dtype), name="nonzero_count")
728 return nonzero_count
731@tf_export("math.zero_fraction", "nn.zero_fraction")
732@dispatch.add_dispatch_support
733def zero_fraction(value, name=None):
734 """Returns the fraction of zeros in `value`.
736 If `value` is empty, the result is `nan`.
738 This is useful in summaries to measure and report sparsity. For example,
740 ```python
741 z = tf.nn.relu(...)
742 summ = tf.compat.v1.summary.scalar('sparsity', tf.nn.zero_fraction(z))
743 ```
745 Args:
746 value: A tensor of numeric type.
747 name: A name for the operation (optional).
749 Returns:
750 The fraction of zeros in `value`, with type `float32`.
751 """
752 with ops.name_scope(name, "zero_fraction", [value]):
753 value = ops.convert_to_tensor(value, name="value")
754 size = array_ops.size(value, out_type=dtypes.int64)
755 # If the count is small, we can save memory/CPU with an int32 reduction.
756 num_nonzero = tf_cond.cond(
757 size <= dtypes.int32.max,
758 # pylint: disable=g-long-lambda
759 true_fn=lambda: math_ops.cast(
760 _count_nonzero(value, dtype=dtypes.int32),
761 dtype=dtypes.int64),
762 false_fn=lambda: _count_nonzero(value, dtype=dtypes.int64))
764 with ops.name_scope("counts_to_fraction"):
765 num_zero = size - num_nonzero
766 num_zero_float32 = math_ops.cast(num_zero, dtype=dtypes.float32)
767 size_float32 = math_ops.cast(size, dtype=dtypes.float32)
768 zero_fraction_float32 = num_zero_float32 / size_float32
770 return array_ops.identity(zero_fraction_float32, "fraction")
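# --- Hedged usage sketch (illustration only, not part of the module).
# A typical use is monitoring the sparsity of ReLU activations; here half the
# entries are zero, so the result is 0.5. Hypothetical helper name; TF 2.x
# eager mode.
def _example_zero_fraction():
  import tensorflow as tf
  activations = tf.nn.relu(tf.constant([-1.0, 2.0, -3.0, 4.0]))
  return tf.math.zero_fraction(activations)  # float32 scalar 0.5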
773# pylint: disable=redefined-builtin
774@tf_export(v1=["nn.depthwise_conv2d"])
775@dispatch.add_dispatch_support
776def depthwise_conv2d(input,
777 filter,
778 strides,
779 padding,
780 rate=None,
781 name=None,
782 data_format=None,
783 dilations=None):
784 """Depthwise 2-D convolution.
786 Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
787 and a filter tensor of shape
788 `[filter_height, filter_width, in_channels, channel_multiplier]`
789 containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
790 applies a different filter to each input channel (expanding from 1 channel
791 to `channel_multiplier` channels for each), then concatenates the results
792 together. The output has `in_channels * channel_multiplier` channels.
794 In detail, with the default NHWC format,
796 output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
797 filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
798 strides[2] * j + rate[1] * dj, k]
800 Must have `strides[0] = strides[3] = 1`. For the most common case of the
801 same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
802 If any value in `rate` is greater than 1, we perform atrous depthwise
803 convolution, in which case all values in the `strides` tensor must be equal
804 to 1.
806 Usage Example:
808 >>> x = np.array([
809 ... [1., 2.],
810 ... [3., 4.],
811 ... [5., 6.]
812 ... ], dtype=np.float32).reshape((1, 3, 2, 1))
813 >>> kernel = np.array([
814 ... [1., 2.],
815 ... [3., 4]
816 ... ], dtype=np.float32).reshape((2, 1, 1, 2))
817 >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
818 ... padding='VALID').numpy()
819 array([[[[10., 14.],
820 [14., 20.]],
821 [[18., 26.],
822 [22., 32.]]]], dtype=float32)
824 >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
825 ... padding=[[0, 0], [1, 0], [1, 0], [0, 0]]
826 ... ).numpy()
827 array([[[[ 0., 0.],
828 [ 3., 4.],
829 [ 6., 8.]],
830 [[ 0., 0.],
831 [10., 14.],
832 [14., 20.]],
833 [[ 0., 0.],
834 [18., 26.],
835 [22., 32.]]]], dtype=float32)
837 Args:
838 input: 4-D with shape according to `data_format`.
839 filter: 4-D with shape
840 `[filter_height, filter_width, in_channels, channel_multiplier]`.
841 strides: 1-D of size 4. The stride of the sliding window for each
842 dimension of `input`.
843 padding: Controls how to pad the image before applying the convolution. Can
844 be the string `"SAME"` or `"VALID"` indicating the type of padding
845 algorithm to use, or a list indicating the explicit paddings at the start
846 and end of each dimension. When explicit padding is used and data_format
847 is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
848 [pad_left, pad_right], [0, 0]]`. When explicit padding used and
849 data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
850 [pad_top, pad_bottom], [pad_left, pad_right]]`.
851 rate: 1-D of size 2. The dilation rate in which we sample input values
852 across the `height` and `width` dimensions in atrous convolution. If it is
853 greater than 1, then all values of strides must be 1.
854 name: A name for this operation (optional).
855 data_format: The data format for input. Either "NHWC" (default) or "NCHW".
856 dilations: Alias of rate.
858 Returns:
859 A 4-D `Tensor` with shape according to `data_format`. E.g., for
860 "NHWC" format, shape is
861 `[batch, out_height, out_width, in_channels * channel_multiplier].`
862 """
863 rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
864 with ops.name_scope(name, "depthwise", [input, filter]) as name:
865 input = ops.convert_to_tensor(input, name="tensor_in")
866 filter = ops.convert_to_tensor(filter, name="filter_in")
867 if rate is None:
868 rate = [1, 1]
870 # Use depthwise_conv2d_native if executing on TPU.
871 if device_context.enclosing_tpu_context() is not None:
872 if data_format == "NCHW":
873 dilations = [1, 1, rate[0], rate[1]]
874 else:
875 dilations = [1, rate[0], rate[1], 1]
876 return nn_ops.depthwise_conv2d_native(
877 input=input,
878 filter=filter,
879 strides=strides,
880 padding=padding,
881 data_format=data_format,
882 dilations=dilations,
883 name=name)
885 def op(input_converted, _, padding):
886 return nn_ops.depthwise_conv2d_native(
887 input=input_converted,
888 filter=filter,
889 strides=strides,
890 padding=padding,
891 data_format=data_format,
892 name=name)
894 return nn_ops.with_space_to_batch(
895 input=input,
896 filter_shape=array_ops.shape(filter),
897 dilation_rate=rate,
898 padding=padding,
899 data_format=data_format,
900 op=op)
903@tf_export("nn.depthwise_conv2d", v1=[])
904@dispatch.add_dispatch_support
905def depthwise_conv2d_v2(input,
906 filter,
907 strides,
908 padding,
909 data_format=None,
910 dilations=None,
911 name=None):
912 """Depthwise 2-D convolution.
914 Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
915 and a filter tensor of shape
916 `[filter_height, filter_width, in_channels, channel_multiplier]`
917 containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
918 applies a different filter to each input channel (expanding from 1 channel
919 to `channel_multiplier` channels for each), then concatenates the results
920 together. The output has `in_channels * channel_multiplier` channels.
922 In detail, with the default NHWC format,
924 output[b, i, j, k * channel_multiplier + q] =
925 sum_{di, dj} filter[di, dj, k, q] *
926 input[b, strides[1] * i + dilations[0] * di,
927 strides[2] * j + dilations[1] * dj, k]
929 Must have `strides[0] = strides[3] = 1`. For the most common case of the
930 same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
931 If any value in `dilations` is greater than 1, we perform atrous depthwise
932 convolution, in which case all values in the `strides` tensor must be equal
933 to 1.
935 Usage Example:
937 >>> x = np.array([
938 ... [1., 2.],
939 ... [3., 4.],
940 ... [5., 6.]
941 ... ], dtype=np.float32).reshape((1, 3, 2, 1))
942 >>> kernel = np.array([
943 ... [1., 2.],
944 ... [3., 4]
945 ... ], dtype=np.float32).reshape((2, 1, 1, 2))
946 >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
947 ... padding='VALID').numpy()
948 array([[[[10., 14.],
949 [14., 20.]],
950 [[18., 26.],
951 [22., 32.]]]], dtype=float32)
953 >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
954 ... padding=[[0, 0], [1, 0], [1, 0], [0, 0]]).numpy()
955 array([[[[ 0., 0.],
956 [ 3., 4.],
957 [ 6., 8.]],
958 [[ 0., 0.],
959 [10., 14.],
960 [14., 20.]],
961 [[ 0., 0.],
962 [18., 26.],
963 [22., 32.]]]], dtype=float32)
965 Args:
966 input: 4-D with shape according to `data_format`.
967 filter: 4-D with shape
968 `[filter_height, filter_width, in_channels, channel_multiplier]`.
969 strides: 1-D of size 4. The stride of the sliding window for each
970 dimension of `input`.
971 padding: Controls how to pad the image before applying the convolution. Can
972 be the string `"SAME"` or `"VALID"` indicating the type of padding
973 algorithm to use, or a list indicating the explicit paddings at the start
974 and end of each dimension. See
975 [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
976 for more information. When explicit padding is used and data_format
977 is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
978 [pad_left, pad_right], [0, 0]]`. When explicit padding used and
979 data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
980 [pad_top, pad_bottom], [pad_left, pad_right]]`.
981 data_format: The data format for input. Either "NHWC" (default) or "NCHW".
982 dilations: 1-D of size 2. The dilation rate in which we sample input values
983 across the `height` and `width` dimensions in atrous convolution. If it is
984 greater than 1, then all values of strides must be 1.
985 name: A name for this operation (optional).
987 Returns:
988 A 4-D `Tensor` with shape according to `data_format`. E.g., for
989 "NHWC" format, shape is
990 `[batch, out_height, out_width, in_channels * channel_multiplier].`
991 """
992 return depthwise_conv2d(input=input,
993 filter=filter,
994 strides=strides,
995 padding=padding,
996 rate=dilations,
997 name=name,
998 data_format=data_format)
1000# pylint: enable=redefined-builtin
1003# pylint: disable=redefined-builtin,line-too-long
1004@tf_export(v1=["nn.separable_conv2d"])
1005@dispatch.add_dispatch_support
1006def separable_conv2d(input,
1007 depthwise_filter,
1008 pointwise_filter,
1009 strides,
1010 padding,
1011 rate=None,
1012 name=None,
1013 data_format=None,
1014 dilations=None):
1015 """2-D convolution with separable filters.
1017 Performs a depthwise convolution that acts separately on channels followed by
1018 a pointwise convolution that mixes channels. Note that this is separability
1019 between dimensions `[1, 2]` and `3`, not spatial separability between
1020 dimensions `1` and `2`.
1022 In detail, with the default NHWC format,
1024 output[b, i, j, k] = sum_{di, dj, q, r}
1025 input[b, strides[1] * i + di, strides[2] * j + dj, q] *
1026 depthwise_filter[di, dj, q, r] *
1027 pointwise_filter[0, 0, q * channel_multiplier + r, k]
1029 `strides` controls the strides for the depthwise convolution only, since
1030 the pointwise convolution has implicit strides of `[1, 1, 1, 1]`. Must have
1031 `strides[0] = strides[3] = 1`. For the most common case of the same
1032 horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
1033 If any value in `rate` is greater than 1, we perform atrous depthwise
1034 convolution, in which case all values in the `strides` tensor must be equal
1035 to 1.
1037 Args:
1038 input: 4-D `Tensor` with shape according to `data_format`.
1039 depthwise_filter: 4-D `Tensor` with shape
1040 `[filter_height, filter_width, in_channels, channel_multiplier]`.
1041 Contains `in_channels` convolutional filters of depth 1.
1042 pointwise_filter: 4-D `Tensor` with shape
1043 `[1, 1, channel_multiplier * in_channels, out_channels]`. Pointwise
1044 filter to mix channels after `depthwise_filter` has convolved spatially.
1045 strides: 1-D of size 4. The strides for the depthwise convolution for
1046 each dimension of `input`.
1047 padding: Controls how to pad the image before applying the depthwise
1048 convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
1049 of padding algorithm to use, or a Python list indicating the explicit
1050 paddings at the start and end of each dimension. When explicit padding is
1051 used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
1052 [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
1053 padding used and data_format is `"NCHW"`, this should be in the form
1054 `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
1055 rate: 1-D of size 2. The dilation rate in which we sample input values
1056 across the `height` and `width` dimensions in atrous convolution. If it is
1057 greater than 1, then all values of strides must be 1.
1058 name: A name for this operation (optional).
1059 data_format: The data format for input. Either "NHWC" (default) or "NCHW".
1060 dilations: Alias of rate.
1062 Returns:
1063 A 4-D `Tensor` with shape according to 'data_format'. For
1064 example, with data_format="NHWC", shape is [batch, out_height,
1065 out_width, out_channels].
1066 """
1067 rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
1068 with ops.name_scope(name, "separable_conv2d",
1069 [input, depthwise_filter, pointwise_filter]) as name:
1070 input = ops.convert_to_tensor(input, name="tensor_in")
1071 depthwise_filter = ops.convert_to_tensor(
1072 depthwise_filter, name="depthwise_filter")
1073 pointwise_filter = ops.convert_to_tensor(
1074 pointwise_filter, name="pointwise_filter")
1076 pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
1077 pointwise_filter_shape.dims[0].assert_is_compatible_with(1)
1078 pointwise_filter_shape.dims[1].assert_is_compatible_with(1)
1080 if rate is None:
1081 rate = [1, 1]
1083 # The layout of the ops in the graph are expected to be as follows:
1084 # depthwise_conv2d // Conv2D op corresponding to native depthwise conv.
1085 # separable_conv2d // Conv2D op corresponding to the pointwise conv.
1087 def op(input_converted, _, padding):
1088 return nn_ops.depthwise_conv2d_native(
1089 input=input_converted,
1090 filter=depthwise_filter,
1091 strides=strides,
1092 padding=padding,
1093 data_format=data_format,
1094 name="depthwise")
1096 depthwise = nn_ops.with_space_to_batch(
1097 input=input,
1098 filter_shape=array_ops.shape(depthwise_filter),
1099 dilation_rate=rate,
1100 padding=padding,
1101 data_format=data_format,
1102 op=op)
1104 return nn_ops.conv2d(
1105 depthwise,
1106 pointwise_filter, [1, 1, 1, 1],
1107 padding="VALID",
1108 data_format=data_format,
1109 name=name)
1112@tf_export("nn.separable_conv2d", v1=[])
1113@dispatch.add_dispatch_support
1114def separable_conv2d_v2(
1115 input,
1116 depthwise_filter,
1117 pointwise_filter,
1118 strides,
1119 padding,
1120 data_format=None,
1121 dilations=None,
1122 name=None,
1123):
1124 """2-D convolution with separable filters.
1126 Performs a depthwise convolution that acts separately on channels followed by
1127 a pointwise convolution that mixes channels. Note that this is separability
1128 between dimensions `[1, 2]` and `3`, not spatial separability between
1129 dimensions `1` and `2`.
1131 In detail, with the default NHWC format,
1133 output[b, i, j, k] = sum_{di, dj, q, r}
1134 input[b, strides[1] * i + di, strides[2] * j + dj, q] *
1135 depthwise_filter[di, dj, q, r] *
1136 pointwise_filter[0, 0, q * channel_multiplier + r, k]
1138 `strides` controls the strides for the depthwise convolution only, since
1139 the pointwise convolution has implicit strides of `[1, 1, 1, 1]`. Must have
1140 `strides[0] = strides[3] = 1`. For the most common case of the same
1141 horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
1142 If any value in `rate` is greater than 1, we perform atrous depthwise
1143 convolution, in which case all values in the `strides` tensor must be equal
1144 to 1.
1146 Args:
1147 input: 4-D `Tensor` with shape according to `data_format`.
1148 depthwise_filter: 4-D `Tensor` with shape `[filter_height, filter_width,
1149 in_channels, channel_multiplier]`. Contains `in_channels` convolutional
1150 filters of depth 1.
1151 pointwise_filter: 4-D `Tensor` with shape `[1, 1, channel_multiplier *
1152 in_channels, out_channels]`. Pointwise filter to mix channels after
1153 `depthwise_filter` has convolved spatially.
1154 strides: 1-D of size 4. The strides for the depthwise convolution for each
1155 dimension of `input`.
1156 padding: Controls how to pad the image before applying the depthwise
1157 convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
1158 of padding algorithm to use, or a Python list indicating the explicit
1159 paddings at the start and end of each dimension. When explicit padding is
1160 used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
1161 [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
1162 padding used and data_format is `"NCHW"`, this should be in the form
1163 `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
1164 data_format: The data format for input. Either "NHWC" (default) or "NCHW".
1165 dilations: 1-D of size 2. The dilation rate in which we sample input values
1166 across the `height` and `width` dimensions in atrous convolution. If it is
1167 greater than 1, then all values of strides must be 1.
1168 name: A name for this operation (optional).
1170 Returns:
1171 A 4-D `Tensor` with shape according to 'data_format'. For
1172 example, with data_format="NHWC", shape is [batch, out_height,
1173 out_width, out_channels].
1174 """
1175 return separable_conv2d(
1176 input,
1177 depthwise_filter,
1178 pointwise_filter,
1179 strides,
1180 padding,
1181 rate=dilations,
1182 name=name,
1183 data_format=data_format)
1185# pylint: enable=redefined-builtin,line-too-long
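# --- Hedged usage sketch (illustration only, not part of the module).
# Illustrates the shape contract of separable_conv2d: a depthwise stage with a
# [fh, fw, in_channels, channel_multiplier] filter followed by a 1x1 pointwise
# mix down to out_channels. Hypothetical helper name; TF 2.x eager mode.
def _example_separable_conv2d():
  import tensorflow as tf
  x = tf.random.normal([1, 8, 8, 3])              # NHWC input
  depthwise = tf.random.normal([3, 3, 3, 2])      # channel_multiplier = 2
  pointwise = tf.random.normal([1, 1, 3 * 2, 5])  # out_channels = 5
  y = tf.nn.separable_conv2d(x, depthwise, pointwise,
                             strides=[1, 1, 1, 1], padding="SAME")
  return y  # shape [1, 8, 8, 5]: "SAME" padding and unit strides keep 8x8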
1188@tf_export(v1=["nn.sufficient_statistics"])
1189@dispatch.add_dispatch_support
1190def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
1191 keepdims=None):
1192 """Calculate the sufficient statistics for the mean and variance of `x`.
1194 These sufficient statistics are computed using the one pass algorithm on
1195 an input that's optionally shifted. See:
1196 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
1198 For example:
1199 >>> t = [[1, 2, 3], [4, 5, 6]]
1200 >>> sufficient_statistics(t, [1])
1201 (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
1202 dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
1203 dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
1204 >>> sufficient_statistics(t, [-1])
1205 (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
1206 dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
1207 dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
1209 Args:
1210 x: A `Tensor`.
1211 axes: Array of ints. Axes along which to compute mean and variance. As in
1212 Python, the axes can also be negative numbers. A negative axis is
1213 interpreted as counting from the end of the rank, i.e., axis +
1214 rank(values)-th dimension.
1215 shift: A `Tensor` containing the value by which to shift the data for
1216 numerical stability, or `None` if no shift is to be performed. A shift
1217 close to the true mean provides the most numerically stable results.
1218 keep_dims: produce statistics with the same dimensionality as the input.
1219 name: Name used to scope the operations that compute the sufficient stats.
1220 keepdims: Alias for keep_dims.
1222 Returns:
1223 Four `Tensor` objects of the same type as `x`:
1225 * the count (number of elements to average over).
1226 * the (possibly shifted) sum of the elements in the array.
1227 * the (possibly shifted) sum of squares of the elements in the array.
1228 * the shift by which the mean must be corrected or None if `shift` is None.
1229 """
1230 axes = list(set(axes))
1231 keep_dims = deprecated_argument_lookup(
1232 "keepdims", keepdims, "keep_dims", keep_dims)
1233 if keep_dims is None:
1234 keep_dims = False
1235 with ops.name_scope(name, "sufficient_statistics", [x, shift]):
1236 x = ops.convert_to_tensor(x, name="x")
1237 x_shape = x.get_shape()
1238 if x_shape.rank is not None and all(
1239 x_shape.dims[d].value is not None for d in axes):
1240 counts = 1
1241 for d in axes:
1242 counts *= x_shape.dims[d].value
1243 counts = constant_op.constant(counts, dtype=x.dtype)
1244 else: # shape needs to be inferred at runtime.
1245 # Normalize axes to be positive. Required for gather.
1246 rank = array_ops.rank(x)
1247 positive_axes = [axis + rank if axis < 0 else axis for axis in axes]
1248 x_dims = array_ops.gather(
1249 math_ops.cast(array_ops.shape(x), x.dtype), positive_axes)
1250 counts = math_ops.reduce_prod(x_dims, name="count")
1251 if shift is not None:
1252 shift = ops.convert_to_tensor(shift, name="shift")
1253 m_ss = math_ops.subtract(x, shift)
1254 v_ss = math_ops.squared_difference(x, shift)
1255 else: # no shift.
1256 m_ss = x
1257 v_ss = math_ops.square(x)
1258 m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
1259 v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
1260 return counts, m_ss, v_ss, shift
1263@tf_export("nn.sufficient_statistics", v1=[])
1264@dispatch.add_dispatch_support
1265def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
1266 """Calculate the sufficient statistics for the mean and variance of `x`.
1268 These sufficient statistics are computed using the one pass algorithm on
1269 an input that's optionally shifted. See:
1270 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
1272 Args:
1273 x: A `Tensor`.
1274 axes: Array of ints. Axes along which to compute mean and variance.
1275 shift: A `Tensor` containing the value by which to shift the data for
1276 numerical stability, or `None` if no shift is to be performed. A shift
1277 close to the true mean provides the most numerically stable results.
1278 keepdims: produce statistics with the same dimensionality as the input.
1279 name: Name used to scope the operations that compute the sufficient stats.
1281 Returns:
1282 Four `Tensor` objects of the same type as `x`:
1284 * the count (number of elements to average over).
1285 * the (possibly shifted) sum of the elements in the array.
1286 * the (possibly shifted) sum of squares of the elements in the array.
1287 * the shift by which the mean must be corrected or None if `shift` is None.
1288 """
1289 return sufficient_statistics(
1290 x=x, axes=axes, shift=shift, keep_dims=keepdims, name=name)
1293@tf_export("nn.normalize_moments")
1294@dispatch.add_dispatch_support
1295def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
1296 """Calculate the mean and variance of based on the sufficient statistics.
1298 Args:
1299 counts: A `Tensor` containing the total count of the data (one value).
1300 mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
1301 shifted) sum of the elements to average over.
1302 variance_ss: A `Tensor` containing the variance sufficient statistics: the
1303 (possibly shifted) sum of squares of the data to compute the variance over.
1304 shift: A `Tensor` containing the value by which the data is shifted for
1305 numerical stability, or `None` if no shift was performed.
1306 name: Name used to scope the operations that compute the moments.
1308 Returns:
1309 Two `Tensor` objects: `mean` and `variance`.
1310 """
1311 with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
1312 divisor = math_ops.reciprocal(counts, name="divisor")
1313 if shift is not None:
1314 shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
1315 mean = math_ops.add(shifted_mean, shift, name="mean")
1316 else: # no shift.
1317 shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
1318 mean = shifted_mean
1319 variance = math_ops.subtract(
1320 math_ops.multiply(variance_ss, divisor),
1321 math_ops.square(shifted_mean),
1322 name="variance")
1323 return (mean, variance)
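# --- Hedged usage sketch (illustration only, not part of the module).
# sufficient_statistics and normalize_moments compose into the same result as
# tf.nn.moments: the count, (shifted) sum and sum of squares are reduced once
# and then turned into mean and variance. Hypothetical helper name; TF 2.x
# eager mode.
def _example_moments_from_sufficient_statistics():
  import tensorflow as tf
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  counts, mean_ss, var_ss, shift = tf.nn.sufficient_statistics(x, axes=[0])
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, var_ss, shift)
  direct_mean, direct_variance = tf.nn.moments(x, axes=[0])
  tf.debugging.assert_near(mean, direct_mean)
  tf.debugging.assert_near(variance, direct_variance)
  return mean, variance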
1326@tf_export(v1=["nn.moments"])
1327@dispatch.add_dispatch_support
1328def moments(
1329 x,
1330 axes,
1331 shift=None, # pylint: disable=unused-argument
1332 name=None,
1333 keep_dims=None,
1334 keepdims=None):
1335 """Calculate the mean and variance of `x`.
1337 The mean and variance are calculated by aggregating the contents of `x`
1338 across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
1339 and variance of a vector.
1341 Note: shift is currently not used; the true mean is computed and used.
1343 When using these moments for batch normalization (see
1344 `tf.nn.batch_normalization`):
1346 * for so-called "global normalization", used with convolutional filters with
1347 shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1348 * for simple batch normalization pass `axes=[0]` (batch only).
1350 Args:
1351 x: A `Tensor`.
1352 axes: Array of ints. Axes along which to compute mean and
1353 variance.
1354 shift: Not used in the current implementation
1355 name: Name used to scope the operations that compute the moments.
1356 keep_dims: produce moments with the same dimensionality as the input.
1357 keepdims: Alias to keep_dims.
1359 Returns:
1360 Two `Tensor` objects: `mean` and `variance`.
1361 """
1362 keep_dims = deprecated_argument_lookup(
1363 "keepdims", keepdims, "keep_dims", keep_dims)
1364 if keep_dims is None:
1365 keep_dims = False
1366 with ops.name_scope(name, "moments", [x, axes]):
1367 # The dynamic range of fp16 is too limited to support the collection of
1368 # sufficient statistics. As a workaround we simply perform the operations
1369 # on 32-bit floats before converting the mean and variance back to fp16
1370 y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
1371 # Compute true mean while keeping the dims for proper broadcasting.
1372 mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
1373 # sample variance, not unbiased variance
1374 # Note: stop_gradient does not change the gradient that gets
1375 # backpropagated to the mean from the variance calculation,
1376 # because that gradient is zero
1377 variance = math_ops.reduce_mean(
1378 math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
1379 axes,
1380 keepdims=True,
1381 name="variance")
1382 if not keep_dims:
1383 mean = array_ops.squeeze(mean, axes)
1384 variance = array_ops.squeeze(variance, axes)
1385 if x.dtype == dtypes.float16:
1386 return (math_ops.cast(mean, dtypes.float16),
1387 math_ops.cast(variance, dtypes.float16))
1388 else:
1389 return (mean, variance)
1392@tf_export("nn.moments", v1=[])
1393@dispatch.add_dispatch_support
1394def moments_v2(
1395 x,
1396 axes,
1397 shift=None,
1398 keepdims=False,
1399 name=None):
1400 """Calculates the mean and variance of `x`.
1402 The mean and variance are calculated by aggregating the contents of `x`
1403 across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
1404 and variance of a vector.
1406 Note: shift is currently not used; the true mean is computed and used.
1408 When using these moments for batch normalization (see
1409 `tf.nn.batch_normalization`):
1411 * for so-called "global normalization", used with convolutional filters with
1412 shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1413 * for simple batch normalization pass `axes=[0]` (batch only).
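A brief sketch of the `keepdims=True` form that pairs naturally with
`tf.nn.batch_normalization` (shapes are illustrative):

```python
import tensorflow as tf

x = tf.random.normal([8, 4, 4, 3])   # [batch, height, width, depth]
mean, variance = tf.nn.moments(x, axes=[0, 1, 2], keepdims=True)
# mean.shape == variance.shape == [1, 1, 1, 3], broadcastable against x.
```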
1415 Args:
1416 x: A `Tensor`.
1417 axes: Array of ints. Axes along which to compute mean and
1418 variance.
1419 shift: Not used in the current implementation.
1420 keepdims: Produce moments with the same dimensionality as the input.
1421 name: Name used to scope the operations that compute the moments.
1423 Returns:
1424 Two `Tensor` objects: `mean` and `variance`.
1425 """
1426 return moments(x=x, axes=axes, shift=shift, name=name, keep_dims=keepdims)
1429@tf_export(v1=["nn.weighted_moments"])
1430@dispatch.add_dispatch_support
1431def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
1432 keepdims=None):
1433 """Returns the frequency-weighted mean and variance of `x`.
1435 Args:
1436 x: A tensor.
1437 axes: 1-d tensor of int32 values; these are the axes along which
1438 to compute mean and variance.
1439 frequency_weights: A tensor of positive weights which can be
1440 broadcast with x.
1441 name: Name used to scope the operation.
1442 keep_dims: Produce moments with the same dimensionality as the input.
1443 keepdims: Alias of keep_dims.
1445 Returns:
1446 Two tensors: `weighted_mean` and `weighted_variance`.
1447 """
1448 keep_dims = deprecated_argument_lookup(
1449 "keepdims", keepdims, "keep_dims", keep_dims)
1450 if keep_dims is None:
1451 keep_dims = False
1452 with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
1453 x = ops.convert_to_tensor(x, name="x")
1454 frequency_weights = ops.convert_to_tensor(
1455 frequency_weights, name="frequency_weights")
1457 # Unlike moments(), this just uses a simpler two-pass method.
1459 # See comment in moments() WRT precision; it applies here too.
1460 needs_cast = x.dtype == dtypes.float16
1461 if needs_cast:
1462 x = math_ops.cast(x, dtypes.float32)
1464 if frequency_weights.dtype != x.dtype:
1465 frequency_weights = math_ops.cast(frequency_weights, x.dtype)
1467 # Note that we use keepdims=True for our reductions regardless of the arg;
1468 # this is so that the results remain broadcast-compatible with the inputs.
1469 weighted_input_sum = math_ops.reduce_sum(
1470 frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
1472 # The shape of the weights isn't necessarily the same as x's
1473 # shape, just broadcast-compatible with it -- so this expression
1474 # performs broadcasting to give a per-item weight, with the same
1475 # shape as (frequency_weights * x). This avoids having to reason
1476 # through all the broadcast logic to compute a correct
1477 # sum_of_weights.
1478 broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
1480 sum_of_weights = math_ops.reduce_sum(
1481 broadcasted_weights, axes, name="sum_of_weights", keepdims=True)
1483 weighted_mean = math_ops.div_no_nan(weighted_input_sum, sum_of_weights)
1485 # Have the weighted mean; now on to variance:
1486 weighted_distsq = math_ops.reduce_sum(
1487 frequency_weights * math_ops.squared_difference(x, weighted_mean),
1488 axes,
1489 name="weighted_distsq",
1490 keepdims=True)
1492 weighted_variance = math_ops.div_no_nan(weighted_distsq, sum_of_weights)
1494 if not keep_dims:
1495 weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
1496 weighted_variance = array_ops.squeeze(
1497 weighted_variance, axis=axes)
1499 if needs_cast:
1500 weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
1501 weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
1503 return weighted_mean, weighted_variance
1506@tf_export("nn.weighted_moments", v1=[])
1507@dispatch.add_dispatch_support
1508def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
1509 """Returns the frequency-weighted mean and variance of `x`.
1511 Args:
1512 x: A tensor.
1513 axes: 1-d tensor of int32 values; these are the axes along which
1514 to compute mean and variance.
1515 frequency_weights: A tensor of positive weights which can be
1516 broadcast with x.
1517 keepdims: Produce moments with the same dimensionality as the input.
1518 name: Name used to scope the operation.
1520 Returns:
1521 Two tensors: `weighted_mean` and `weighted_variance`.
1522 """
1523 return weighted_moments(
1524 x=x,
1525 axes=axes,
1526 frequency_weights=frequency_weights,
1527 name=name,
1528 keep_dims=keepdims)
1531@tf_export("nn.batch_normalization")
1532@dispatch.add_dispatch_support
1533def batch_normalization(x,
1534 mean,
1535 variance,
1536 offset,
1537 scale,
1538 variance_epsilon,
1539 name=None):
1540 r"""Batch normalization.
1542 Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
1543 `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):
1545 \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)
1547 `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
1548 shapes:
1550 * In all generality, they can have the same number of dimensions as the
1551 input `x`, with identical sizes as `x` for the dimensions that are not
1552 normalized over (the 'depth' dimension(s)), and dimension 1 for the
1553 others which are being normalized over.
1554 `mean` and `variance` in this case would typically be the outputs of
1555 `tf.nn.moments(..., keepdims=True)` during training, or running averages
1556 thereof during inference.
1557 * In the common case where the 'depth' dimension is the last dimension in
1558 the input tensor `x`, they may be one dimensional tensors of the same
1559 size as the 'depth' dimension.
1560 This is the case for example for the common `[batch, depth]` layout of
1561 fully-connected layers, and `[batch, height, width, depth]` for
1562 convolutions.
1563 `mean` and `variance` in this case would typically be the outputs of
1564 `tf.nn.moments(..., keepdims=False)` during training, or running averages
1565 thereof during inference.
1567 See equation 11 in Algorithm 2 of source:
1568 [Batch Normalization: Accelerating Deep Network Training by
1569 Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
1570 (http://arxiv.org/abs/1502.03167).
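A minimal usage sketch for the common `[batch, depth]` case (illustrative
values and epsilon):

```python
import tensorflow as tf

x = tf.random.normal([32, 10])                # [batch, depth]
mean, variance = tf.nn.moments(x, axes=[0])   # per-depth statistics
offset = tf.zeros([10])
scale = tf.ones([10])
y = tf.nn.batch_normalization(x, mean, variance, offset, scale,
                              variance_epsilon=1e-3)
# y is approximately zero-mean, unit-variance along the batch axis.
```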
1572 Args:
1573 x: Input `Tensor` of arbitrary dimensionality.
1574 mean: A mean `Tensor`.
1575 variance: A variance `Tensor`.
1576 offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
1577 None. If present, will be added to the normalized tensor.
1578 scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
1579 `None`. If present, the scale is applied to the normalized tensor.
1580 variance_epsilon: A small float number to avoid dividing by 0.
1581 name: A name for this operation (optional).
1583 Returns:
1584 the normalized, scaled, offset tensor.
1586 References:
1587 Batch Normalization - Accelerating Deep Network Training by Reducing
1588 Internal Covariate Shift:
1589 [Ioffe et al., 2015](http://arxiv.org/abs/1502.03167)
1590 ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1591 """
1592 with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
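# The code below computes scale * (x - mean) * rsqrt(variance + eps) + offset,
# refactored as x * inv + (offset - mean * inv) where inv already folds in
# `scale`, so only two broadcast operands are applied to `x`.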
1593 inv = math_ops.rsqrt(variance + variance_epsilon)
1594 if scale is not None:
1595 inv *= scale
1596 # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on
1597 # the precise order of ops that are generated by the expression below.
1598 return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
1599 offset - mean * inv if offset is not None else -mean * inv, x.dtype)
1602@tf_export(v1=["nn.fused_batch_norm"])
1603@dispatch.add_dispatch_support
1604def fused_batch_norm(
1605 x,
1606 scale,
1607 offset, # pylint: disable=invalid-name
1608 mean=None,
1609 variance=None,
1610 epsilon=0.001,
1611 data_format="NHWC",
1612 is_training=True,
1613 name=None,
1614 exponential_avg_factor=1.0):
1615 r"""Batch normalization.
1618 See Source: [Batch Normalization: Accelerating Deep Network Training by
1619 Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
1620 (http://arxiv.org/abs/1502.03167).
1622 Args:
1623 x: Input `Tensor` of 4 or 5 dimensions.
1624 scale: A `Tensor` of 1 dimension for scaling.
1625 offset: A `Tensor` of 1 dimension for bias.
1626 mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
1627 of this argument depends on the value of is_training and
1628 exponential_avg_factor as follows:
1629 is_training==False (inference):
1630 Mean must be a `Tensor` of the same shape as scale containing the
1631 estimated population mean computed during training.
1632 is_training==True and exponential_avg_factor == 1.0:
1633 Mean must be None.
1634 is_training==True and exponential_avg_factor != 1.0:
1635 Mean must be a `Tensor` of the same shape as scale containing the
1636 exponential running mean.
1637 variance: A `Tensor` of 1 dimension for population variance. The shape and
1638 meaning of this argument depends on the value of is_training and
1639 exponential_avg_factor as follows:
1640 is_training==False (inference):
1641 Variance must be a `Tensor` of the same shape as scale containing
1642 the estimated population variance computed during training.
1643 is_training==True and exponential_avg_factor == 1.0:
1644 Variance must be None.
1645 is_training==True and exponential_avg_factor != 1.0:
1646 Variance must be a `Tensor` of the same shape as scale containing
1647 the exponential running variance.
1648 epsilon: A small float number added to the variance of x.
1649 data_format: The data format for x. Supports "NHWC" (default) or "NCHW" for
1650 4D tensors and "NDHWC" or "NCDHW" for 5D tensors.
1651 is_training: A bool value to specify if the operation is used for
1652 training or inference.
1653 name: A name for this operation (optional).
1654 exponential_avg_factor: A float number (usually between 0 and 1) used
1655 for controlling the decay of the running
1656 population average of mean and variance.
1657 If set to 1.0, the current batch average is
1658 returned.
1660 Returns:
1661 y: A 4D or 5D Tensor for the normalized, scaled, offset x.
1662 running_mean: A 1D Tensor for the exponential running mean of x.
1663 The output value is (1 - exponential_avg_factor) * mean +
1664 exponential_avg_factor * batch_mean, where batch_mean
1665 is the mean of the current batch in x.
1666 running_var: A 1D Tensor for the exponential running variance of x.
1667 The output value is (1 - exponential_avg_factor) * variance +
1668 exponential_avg_factor * batch_variance, where batch_variance
1669 is the variance of the current batch in x.
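A worked instance of the running-average update described above (plain
Python arithmetic; the values are made up):

```python
exponential_avg_factor = 0.1
running_mean_in, batch_mean = 0.0, 2.0
running_mean_out = ((1 - exponential_avg_factor) * running_mean_in
                    + exponential_avg_factor * batch_mean)   # 0.2
```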
1671 References:
1672 Batch Normalization - Accelerating Deep Network Training by Reducing
1673 Internal Covariate Shift:
1674 [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1675 ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1676 """
1677 if (not is_training or exponential_avg_factor != 1.0) and (
1678 (mean is None) or (variance is None)):
1679 raise ValueError("Both `mean` and `variance` must be a 1D tensor when "
1680 "`is_training` is False or `exponential_avg_factor` != "
1681 f"1.0. Received: `mean` {mean!r} and `variance` "
1682 f"{variance!r}")
1683 x = ops.convert_to_tensor(x, name="input")
1684 scale = ops.convert_to_tensor(scale, name="scale")
1685 offset = ops.convert_to_tensor(offset, name="offset")
1686 if mean is None:
1687 mean = constant_op.constant([])
1688 if variance is None:
1689 variance = constant_op.constant([])
1691 y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
1692 x,
1693 scale,
1694 offset,
1695 mean,
1696 variance,
1697 epsilon=epsilon,
1698 exponential_avg_factor=exponential_avg_factor,
1699 data_format=data_format,
1700 is_training=is_training,
1701 name=name)
1702 return y, running_mean, running_var
1705@tf_export(v1=["nn.batch_norm_with_global_normalization"])
1706@dispatch.add_dispatch_support
1707def batch_norm_with_global_normalization(t=None,
1708 m=None,
1709 v=None,
1710 beta=None,
1711 gamma=None,
1712 variance_epsilon=None,
1713 scale_after_normalization=None,
1714 name=None,
1715 input=None, # pylint: disable=redefined-builtin
1716 mean=None,
1717 variance=None):
1718 """Batch normalization.
1720 This op is deprecated. See `tf.nn.batch_normalization`.
1722 Args:
1723 t: A 4D input Tensor.
1724 m: A 1D mean Tensor with size matching the last dimension of t.
1725 This is the first output from tf.nn.moments,
1726 or a saved moving average thereof.
1727 v: A 1D variance Tensor with size matching the last dimension of t.
1728 This is the second output from tf.nn.moments,
1729 or a saved moving average thereof.
1730 beta: A 1D beta Tensor with size matching the last dimension of t.
1731 An offset to be added to the normalized tensor.
1732 gamma: A 1D gamma Tensor with size matching the last dimension of t.
1733 If "scale_after_normalization" is true, this tensor will be multiplied
1734 with the normalized tensor.
1735 variance_epsilon: A small float number to avoid dividing by 0.
1736 scale_after_normalization: A bool indicating whether the resulting tensor
1737 should be multiplied by gamma.
1738 name: A name for this operation (optional).
1739 input: Alias for t.
1740 mean: Alias for m.
1741 variance: Alias for v.
1743 Returns:
1744 A batch-normalized `t`.
1746 References:
1747 Batch Normalization - Accelerating Deep Network Training by Reducing
1748 Internal Covariate Shift:
1749 [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1750 ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1751 """
1752 t = deprecated_argument_lookup("input", input, "t", t)
1753 m = deprecated_argument_lookup("mean", mean, "m", m)
1754 v = deprecated_argument_lookup("variance", variance, "v", v)
1755 return batch_normalization(t, m, v, beta, gamma if scale_after_normalization
1756 else None, variance_epsilon, name)
1759# pylint: disable=redefined-builtin,line-too-long
1760@tf_export("nn.batch_norm_with_global_normalization", v1=[])
1761@dispatch.add_dispatch_support
1762def batch_norm_with_global_normalization_v2(input,
1763 mean,
1764 variance,
1765 beta,
1766 gamma,
1767 variance_epsilon,
1768 scale_after_normalization,
1769 name=None):
1770 """Batch normalization.
1772 This op is deprecated. See `tf.nn.batch_normalization`.
1774 Args:
1775 input: A 4D input Tensor.
1776 mean: A 1D mean Tensor with size matching the last dimension of `input`.
1777 This is the first output from tf.nn.moments,
1778 or a saved moving average thereof.
1779 variance: A 1D variance Tensor with size matching the last dimension of `input`.
1780 This is the second output from tf.nn.moments,
1781 or a saved moving average thereof.
1782 beta: A 1D beta Tensor with size matching the last dimension of `input`.
1783 An offset to be added to the normalized tensor.
1784 gamma: A 1D gamma Tensor with size matching the last dimension of `input`.
1785 If "scale_after_normalization" is true, this tensor will be multiplied
1786 with the normalized tensor.
1787 variance_epsilon: A small float number to avoid dividing by 0.
1788 scale_after_normalization: A bool indicating whether the resulting tensor
1789 should be multiplied by gamma.
1790 name: A name for this operation (optional).
1792 Returns:
1793 A batch-normalized `input`.
1795 References:
1796 Batch Normalization - Accelerating Deep Network Training by Reducing Internal Covariate Shift:
1797 [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1798 ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1799 """
1800 return batch_norm_with_global_normalization(t=input,
1801 m=mean,
1802 v=variance,
1803 beta=beta,
1804 gamma=gamma,
1805 variance_epsilon=variance_epsilon,
1806 scale_after_normalization=scale_after_normalization,
1807 name=name)
1809# pylint: enable=redefined-builtin,line-too-long
1812def _sum_rows(x):
1813 """Returns a vector summing up each row of the matrix x."""
1814 # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
1815 # a matrix. The gradient of _sum_rows(x) is more efficient than
1816 # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
1817 # we use _sum_rows(x) in the nce_loss() computation since the loss
1818 # is mostly used for training.
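  # Illustrative example: for x = [[1., 2.], [3., 4.]], `ones` has shape
  # [2, 1], matmul(x, ones) is [[3.], [7.]], and the final reshape yields
  # [3., 7.] -- the same values as math_ops.reduce_sum(x, 1).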
1819 cols = array_ops.shape(x)[1]
1820 ones_shape = array_ops_stack.stack([cols, 1])
1821 ones = array_ops.ones(ones_shape, x.dtype)
1822 return array_ops.reshape(math_ops.matmul(x, ones), [-1])
1825def _compute_sampled_logits(weights,
1826 biases,
1827 labels,
1828 inputs,
1829 num_sampled,
1830 num_classes,
1831 num_true=1,
1832 sampled_values=None,
1833 subtract_log_q=True,
1834 remove_accidental_hits=False,
1835 partition_strategy="mod",
1836 name=None,
1837 seed=None):
1838 """Helper function for nce_loss and sampled_softmax_loss functions.
1840 Computes sampled output training logits and labels suitable for implementing
1841 e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
1842 sampled_softmax_loss).
1844 Note: In the case where num_true > 1, we assign to each target class
1845 the target probability 1 / num_true so that the target probabilities
1846 sum to 1 per-example.
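A shape sketch (illustrative numbers only):

```python
# batch_size=3, num_true=2, num_sampled=5 (hypothetical values)
# out_logits.shape == out_labels.shape == [3, num_true + num_sampled] == [3, 7]
# each row of out_labels == [0.5, 0.5, 0., 0., 0., 0., 0.]
```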
1848 Args:
1849 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1850 objects whose concatenation along dimension 0 has shape
1851 `[num_classes, dim]`. The (possibly-partitioned) class embeddings.
1852 biases: A `Tensor` of shape `[num_classes]`. The (possibly-partitioned)
1853 class biases.
1854 labels: A `Tensor` of type `int64` and shape `[batch_size,
1855 num_true]`. The target classes. Note that this format differs from
1856 the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
1857 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
1858 activations of the input network.
1859 num_sampled: An `int`. The number of classes to randomly sample per batch.
1860 num_classes: An `int`. The number of possible classes.
1861 num_true: An `int`. The number of target classes per training example.
1862 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1863 `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1864 (if None, we default to `log_uniform_candidate_sampler`)
1865 subtract_log_q: A `bool`. Whether to subtract the log expected count of
1866 the labels in the sample to get the logits of the true labels.
1867 Default is True. Turn off for Negative Sampling.
1868 remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
1869 where a sampled class equals one of the target classes. Default is
1870 False.
1871 partition_strategy: A string specifying the partitioning strategy, relevant
1872 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1873 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1874 name: A name for the operation (optional).
1875 seed: Random seed for candidate sampling. Defaults to None, which doesn't
1876 set the op-level random seed for candidate sampling.
1877 Returns:
1878 out_logits: `Tensor` object with shape
1879 `[batch_size, num_true + num_sampled]`, for passing to either
1880 `nn.sigmoid_cross_entropy_with_logits` (NCE) or
1881 `nn.softmax_cross_entropy_with_logits` (sampled softmax).
1882 out_labels: A Tensor object with the same shape as `out_logits`.
1883 """
1885 if isinstance(weights, variables.PartitionedVariable):
1886 weights = list(weights)
1887 if not isinstance(weights, list):
1888 weights = [weights]
1890 with ops.name_scope(name, "compute_sampled_logits",
1891 weights + [biases, inputs, labels]):
1892 if labels.dtype != dtypes.int64:
1893 labels = math_ops.cast(labels, dtypes.int64)
1894 labels_flat = array_ops.reshape(labels, [-1])
1896 # Sample the negative labels.
1897 # sampled shape: [num_sampled] tensor
1898 # true_expected_count shape = [batch_size, 1] tensor
1899 # sampled_expected_count shape = [num_sampled] tensor
1900 if sampled_values is None:
1901 sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
1902 true_classes=labels,
1903 num_true=num_true,
1904 num_sampled=num_sampled,
1905 unique=True,
1906 range_max=num_classes,
1907 seed=seed)
1908 # NOTE: pylint cannot tell that 'sampled_values' is a sequence
1909 # pylint: disable=unpacking-non-sequence
1910 sampled, true_expected_count, sampled_expected_count = (
1911 array_ops.stop_gradient(s) for s in sampled_values)
1912 # pylint: enable=unpacking-non-sequence
1913 sampled = math_ops.cast(sampled, dtypes.int64)
1915 # labels_flat is a [batch_size * num_true] tensor
1916 # sampled is a [num_sampled] int tensor
1917 all_ids = array_ops.concat([labels_flat, sampled], 0)
1919 # Retrieve the true weights and the logits of the sampled weights.
1921 # weights shape is [num_classes, dim]
1922 all_w = embedding_ops.embedding_lookup(
1923 weights, all_ids, partition_strategy=partition_strategy)
1924 if all_w.dtype != inputs.dtype:
1925 all_w = math_ops.cast(all_w, inputs.dtype)
1927 # true_w shape is [batch_size * num_true, dim]
1928 true_w = array_ops.slice(all_w, [0, 0],
1929 array_ops_stack.stack(
1930 [array_ops.shape(labels_flat)[0], -1]))
1932 sampled_w = array_ops.slice(
1933 all_w,
1934 array_ops_stack.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
1935 # inputs has shape [batch_size, dim]
1936 # sampled_w has shape [num_sampled, dim]
1937 # Apply X*W', which yields [batch_size, num_sampled]
1938 sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
1940 # Retrieve the true and sampled biases, compute the true logits, and
1941 # add the biases to the true and sampled logits.
1942 all_b = embedding_ops.embedding_lookup(
1943 biases, all_ids, partition_strategy=partition_strategy)
1944 if all_b.dtype != inputs.dtype:
1945 all_b = math_ops.cast(all_b, inputs.dtype)
1946 # true_b is a [batch_size * num_true] tensor
1947 # sampled_b is a [num_sampled] float tensor
1948 true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
1949 sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
1951 # inputs shape is [batch_size, dim]
1952 # true_w shape is [batch_size * num_true, dim]
1953 # row_wise_dots is [batch_size, num_true, dim]
1954 dim = array_ops.shape(true_w)[1:2]
1955 new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
1956 row_wise_dots = math_ops.multiply(
1957 array_ops.expand_dims(inputs, 1),
1958 array_ops.reshape(true_w, new_true_w_shape))
1959 # We want the row-wise dot plus biases which yields a
1960 # [batch_size, num_true] tensor of true_logits.
1961 dots_as_matrix = array_ops.reshape(row_wise_dots,
1962 array_ops.concat([[-1], dim], 0))
1963 true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
1964 true_b = array_ops.reshape(true_b, [-1, num_true])
1965 true_logits += true_b
1966 sampled_logits += sampled_b
1968 if remove_accidental_hits:
1969 acc_hits = candidate_sampling_ops.compute_accidental_hits(
1970 labels, sampled, num_true=num_true)
1971 acc_indices, acc_ids, acc_weights = acc_hits
1973 # This is how SparseToDense expects the indices.
1974 acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
1975 acc_ids_2d_int32 = array_ops.reshape(
1976 math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
1977 sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
1978 "sparse_indices")
1979 # Create sampled_logits_shape = [batch_size, num_sampled]
1980 sampled_logits_shape = array_ops.concat(
1981 [array_ops.shape(labels)[:1],
1982 array_ops.expand_dims(num_sampled, 0)], 0)
1983 if sampled_logits.dtype != acc_weights.dtype:
1984 acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
1985 sampled_logits += gen_sparse_ops.sparse_to_dense(
1986 sparse_indices,
1987 sampled_logits_shape,
1988 acc_weights,
1989 default_value=0.0,
1990 validate_indices=False)
1992 if subtract_log_q:
1993 # Subtract log of Q(l), prior probability that l appears in sampled.
1994 true_logits -= math_ops.log(true_expected_count)
1995 sampled_logits -= math_ops.log(sampled_expected_count)
1997 # Construct output logits and labels. The true labels/logits start at col 0.
1998 out_logits = array_ops.concat([true_logits, sampled_logits], 1)
2000 # true_logits is a float tensor, ones_like(true_logits) is a float
2001 # tensor of ones. We then divide by num_true to ensure the per-example
2002 # labels sum to 1.0, i.e. form a proper probability distribution.
2003 out_labels = array_ops.concat([
2004 array_ops.ones_like(true_logits) / num_true,
2005 array_ops.zeros_like(sampled_logits)
2006 ], 1)
2008 return out_logits, out_labels
2011@tf_export("nn.nce_loss", v1=[])
2012@dispatch.add_dispatch_support
2013def nce_loss_v2(weights,
2014 biases,
2015 labels,
2016 inputs,
2017 num_sampled,
2018 num_classes,
2019 num_true=1,
2020 sampled_values=None,
2021 remove_accidental_hits=False,
2022 name="nce_loss"):
2023 """Computes and returns the noise-contrastive estimation training loss.
2025 See [Noise-contrastive estimation: A new estimation principle for
2026 unnormalized statistical
2027 models](http://proceedings.mlr.press/v9/gutmann10a).
2028 Also see our [Candidate Sampling Algorithms
2029 Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf)
2031 A common use case is to use this method for training, and calculate the full
2032 sigmoid loss for evaluation or inference as in the following example:
2034 ```python
2035 if mode == "train":
2036 loss = tf.nn.nce_loss(
2037 weights=weights,
2038 biases=biases,
2039 labels=labels,
2040 inputs=inputs,
2041 ...)
2042 elif mode == "eval":
2043 logits = tf.matmul(inputs, tf.transpose(weights))
2044 logits = tf.nn.bias_add(logits, biases)
2045 labels_one_hot = tf.one_hot(labels, n_classes)
2046 loss = tf.nn.sigmoid_cross_entropy_with_logits(
2047 labels=labels_one_hot,
2048 logits=logits)
2049 loss = tf.reduce_sum(loss, axis=1)
2050 ```
2052 Note: when doing embedding lookup on `weights` and `biases`, the "div"
2053 partition strategy will be used. Support for other partition strategies
2054 will be added later.
2056 Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
2057 so your labels must be sorted in order of decreasing frequency to achieve
2058 good results. For more details, see
2059 `tf.random.log_uniform_candidate_sampler`.
2061 Note: In the case where `num_true` > 1, we assign to each target class
2062 the target probability 1 / `num_true` so that the target probabilities
2063 sum to 1 per-example.
2065 Note: It would be useful to allow a variable number of target classes per
2066 example. We hope to provide this functionality in a future release.
2067 For now, if you have a variable number of target classes, you can pad them
2068 out to a constant number by either repeating them or by padding
2069 with an otherwise unused class.
2071 Args:
2072 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2073 objects whose concatenation along dimension 0 has shape [num_classes,
2074 dim]. The (possibly-partitioned) class embeddings.
2075 biases: A `Tensor` of shape `[num_classes]`. The class biases.
2076 labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2077 target classes.
2078 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of
2079 the input network.
2080 num_sampled: An `int`. The number of negative classes to randomly sample
2081 per batch. This single sample of negative classes is evaluated for each
2082 element in the batch.
2083 num_classes: An `int`. The number of possible classes.
2084 num_true: An `int`. The number of target classes per training example.
2085 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2086 `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2087 (if None, we default to `log_uniform_candidate_sampler`)
2088 remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
2089 where a sampled class equals one of the target classes. If set to `True`,
2090 this is a "Sampled Logistic" loss instead of NCE, and we are learning to
2091 generate log-odds instead of log probabilities. See our [Candidate
2092 Sampling Algorithms Reference]
2093 (https://www.tensorflow.org/extras/candidate_sampling.pdf). Default is
2094 False.
2095 name: A name for the operation (optional).
2097 Returns:
2098 A `batch_size` 1-D tensor of per-example NCE losses.
2099 """
2100 # TODO(yuefengz): get partition_strategy from either variables or distribution
2101 # strategies.
2102 return nce_loss(
2103 weights,
2104 biases,
2105 labels,
2106 inputs,
2107 num_sampled,
2108 num_classes,
2109 num_true=num_true,
2110 sampled_values=sampled_values,
2111 remove_accidental_hits=remove_accidental_hits,
2112 partition_strategy="div",
2113 name=name)
2116@tf_export(v1=["nn.nce_loss"])
2117@dispatch.add_dispatch_support
2118def nce_loss(weights,
2119 biases,
2120 labels,
2121 inputs,
2122 num_sampled,
2123 num_classes,
2124 num_true=1,
2125 sampled_values=None,
2126 remove_accidental_hits=False,
2127 partition_strategy="mod",
2128 name="nce_loss"):
2129 """Computes and returns the noise-contrastive estimation training loss.
2131 A common use case is to use this method for training, and calculate the full
2132 sigmoid loss for evaluation or inference. In this case, you must set
2133 `partition_strategy="div"` for the two losses to be consistent, as in the
2134 following example:
2136 ```python
2137 if mode == "train":
2138 loss = tf.nn.nce_loss(
2139 weights=weights,
2140 biases=biases,
2141 labels=labels,
2142 inputs=inputs,
2143 ...,
2144 partition_strategy="div")
2145 elif mode == "eval":
2146 logits = tf.matmul(inputs, tf.transpose(weights))
2147 logits = tf.nn.bias_add(logits, biases)
2148 labels_one_hot = tf.one_hot(labels, n_classes)
2149 loss = tf.nn.sigmoid_cross_entropy_with_logits(
2150 labels=labels_one_hot,
2151 logits=logits)
2152 loss = tf.reduce_sum(loss, axis=1)
2153 ```
2155 Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
2156 so your labels must be sorted in order of decreasing frequency to achieve
2157 good results. For more details, see
2158 `tf.random.log_uniform_candidate_sampler`.
2160 Note: In the case where `num_true` > 1, we assign to each target class
2161 the target probability 1 / `num_true` so that the target probabilities
2162 sum to 1 per-example.
2164 Note: It would be useful to allow a variable number of target classes per
2165 example. We hope to provide this functionality in a future release.
2166 For now, if you have a variable number of target classes, you can pad them
2167 out to a constant number by either repeating them or by padding
2168 with an otherwise unused class.
2170 Args:
2171 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2172 objects whose concatenation along dimension 0 has shape
2173 [num_classes, dim]. The (possibly-partitioned) class embeddings.
2174 biases: A `Tensor` of shape `[num_classes]`. The class biases.
2175 labels: A `Tensor` of type `int64` and shape `[batch_size,
2176 num_true]`. The target classes.
2177 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
2178 activations of the input network.
2179 num_sampled: An `int`. The number of negative classes to randomly sample
2180 per batch. This single sample of negative classes is evaluated for each
2181 element in the batch.
2182 num_classes: An `int`. The number of possible classes.
2183 num_true: An `int`. The number of target classes per training example.
2184 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2185 `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2186 (if None, we default to `log_uniform_candidate_sampler`)
2187 remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
2188 where a sampled class equals one of the target classes. If set to
2189 `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
2190 learning to generate log-odds instead of log probabilities. See
2191 our Candidate Sampling Algorithms Reference
2192 ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2193 Default is False.
2194 partition_strategy: A string specifying the partitioning strategy, relevant
2195 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2196 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2197 name: A name for the operation (optional).
2199 Returns:
2200 A `batch_size` 1-D tensor of per-example NCE losses.
2202 References:
2203 Noise-contrastive estimation - A new estimation principle for unnormalized
2204 statistical models:
2205 [Gutmann et al., 2010](http://proceedings.mlr.press/v9/gutmann10a)
2206 ([pdf](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf))
2207 """
2208 logits, labels = _compute_sampled_logits(
2209 weights=weights,
2210 biases=biases,
2211 labels=labels,
2212 inputs=inputs,
2213 num_sampled=num_sampled,
2214 num_classes=num_classes,
2215 num_true=num_true,
2216 sampled_values=sampled_values,
2217 subtract_log_q=True,
2218 remove_accidental_hits=remove_accidental_hits,
2219 partition_strategy=partition_strategy,
2220 name=name)
2221 sampled_losses = sigmoid_cross_entropy_with_logits(
2222 labels=labels, logits=logits, name="sampled_losses")
2223 # sampled_losses is batch_size x {true_loss, sampled_losses...}
2224 # We sum out true and sampled losses.
2225 return _sum_rows(sampled_losses)
2228@tf_export("nn.sampled_softmax_loss", v1=[])
2229@dispatch.add_dispatch_support
2230def sampled_softmax_loss_v2(weights,
2231 biases,
2232 labels,
2233 inputs,
2234 num_sampled,
2235 num_classes,
2236 num_true=1,
2237 sampled_values=None,
2238 remove_accidental_hits=True,
2239 seed=None,
2240 name="sampled_softmax_loss"):
2241 """Computes and returns the sampled softmax training loss.
2243 This is a faster way to train a softmax classifier over a huge number of
2244 classes.
2246 This operation is for training only. It is generally an underestimate of
2247 the full softmax loss.
2249 A common use case is to use this method for training, and calculate the full
2250 softmax loss for evaluation or inference as in the following example:
2252 ```python
2253 if mode == "train":
2254 loss = tf.nn.sampled_softmax_loss(
2255 weights=weights,
2256 biases=biases,
2257 labels=labels,
2258 inputs=inputs,
2259 ...)
2260 elif mode == "eval":
2261 logits = tf.matmul(inputs, tf.transpose(weights))
2262 logits = tf.nn.bias_add(logits, biases)
2263 labels_one_hot = tf.one_hot(labels, n_classes)
2264 loss = tf.nn.softmax_cross_entropy_with_logits(
2265 labels=labels_one_hot,
2266 logits=logits)
2267 ```
2269 See our [Candidate Sampling Algorithms Reference]
2270 (https://www.tensorflow.org/extras/candidate_sampling.pdf)
2272 Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
2273 ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
2275 Note: when doing embedding lookup on `weights` and `biases`, the "div"
2276 partition strategy will be used. Support for other partition strategies
2277 will be added later.
2279 Args:
2280 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2281 objects whose concatenation along dimension 0 has shape [num_classes,
2282 dim]. The (possibly-sharded) class embeddings.
2283 biases: A `Tensor` of shape `[num_classes]`. The class biases.
2284 labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2285 target classes. Note that this format differs from the `labels` argument
2286 of `nn.softmax_cross_entropy_with_logits`.
2287 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of
2288 the input network.
2289 num_sampled: An `int`. The number of classes to randomly sample per batch.
2290 num_classes: An `int`. The number of possible classes.
2291 num_true: An `int`. The number of target classes per training example.
2292 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2293 `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2294 (if None, we default to `log_uniform_candidate_sampler`)
2295 remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
2296 where a sampled class equals one of the target classes. Default is True.
2297 seed: Random seed for candidate sampling. Defaults to None, which doesn't
2298 set the op-level random seed for candidate sampling.
2299 name: A name for the operation (optional).
2301 Returns:
2302 A `batch_size` 1-D tensor of per-example sampled softmax losses.
2304 """
2305 return sampled_softmax_loss(
2306 weights,
2307 biases,
2308 labels,
2309 inputs,
2310 num_sampled,
2311 num_classes,
2312 num_true=num_true,
2313 sampled_values=sampled_values,
2314 remove_accidental_hits=remove_accidental_hits,
2315 partition_strategy="div",
2316 name=name,
2317 seed=seed)
2320@tf_export(v1=["nn.sampled_softmax_loss"])
2321@dispatch.add_dispatch_support
2322def sampled_softmax_loss(weights,
2323 biases,
2324 labels,
2325 inputs,
2326 num_sampled,
2327 num_classes,
2328 num_true=1,
2329 sampled_values=None,
2330 remove_accidental_hits=True,
2331 partition_strategy="mod",
2332 name="sampled_softmax_loss",
2333 seed=None):
2334 """Computes and returns the sampled softmax training loss.
2336 This is a faster way to train a softmax classifier over a huge number of
2337 classes.
2339 This operation is for training only. It is generally an underestimate of
2340 the full softmax loss.
2342 A common use case is to use this method for training, and calculate the full
2343 softmax loss for evaluation or inference. In this case, you must set
2344 `partition_strategy="div"` for the two losses to be consistent, as in the
2345 following example:
2347 ```python
2348 if mode == "train":
2349 loss = tf.nn.sampled_softmax_loss(
2350 weights=weights,
2351 biases=biases,
2352 labels=labels,
2353 inputs=inputs,
2354 ...,
2355 partition_strategy="div")
2356 elif mode == "eval":
2357 logits = tf.matmul(inputs, tf.transpose(weights))
2358 logits = tf.nn.bias_add(logits, biases)
2359 labels_one_hot = tf.one_hot(labels, n_classes)
2360 loss = tf.nn.softmax_cross_entropy_with_logits(
2361 labels=labels_one_hot,
2362 logits=logits)
2363 ```
2365 See our Candidate Sampling Algorithms Reference
2366 ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2367 Also see Section 3 of (Jean et al., 2014) for the math.
2369 Args:
2370 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2371 objects whose concatenation along dimension 0 has shape
2372 [num_classes, dim]. The (possibly-sharded) class embeddings.
2373 biases: A `Tensor` of shape `[num_classes]`. The class biases.
2374 labels: A `Tensor` of type `int64` and shape `[batch_size,
2375 num_true]`. The target classes. Note that this format differs from
2376 the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
2377 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
2378 activations of the input network.
2379 num_sampled: An `int`. The number of classes to randomly sample per batch.
2380 num_classes: An `int`. The number of possible classes.
2381 num_true: An `int`. The number of target classes per training example.
2382 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2383 `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2384 (if None, we default to `log_uniform_candidate_sampler`)
2385 remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
2386 where a sampled class equals one of the target classes. Default is
2387 True.
2388 partition_strategy: A string specifying the partitioning strategy, relevant
2389 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2390 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2391 name: A name for the operation (optional).
2392 seed: Random seed for candidate sampling. Defaults to None, which doesn't
2393 set the op-level random seed for candidate sampling.
2395 Returns:
2396 A `batch_size` 1-D tensor of per-example sampled softmax losses.
2398 References:
2399 On Using Very Large Target Vocabulary for Neural Machine Translation:
2400 [Jean et al., 2014]
2401 (https://aclanthology.coli.uni-saarland.de/papers/P15-1001/p15-1001)
2402 ([pdf](http://aclweb.org/anthology/P15-1001))
2403 """
2404 logits, labels = _compute_sampled_logits(
2405 weights=weights,
2406 biases=biases,
2407 labels=labels,
2408 inputs=inputs,
2409 num_sampled=num_sampled,
2410 num_classes=num_classes,
2411 num_true=num_true,
2412 sampled_values=sampled_values,
2413 subtract_log_q=True,
2414 remove_accidental_hits=remove_accidental_hits,
2415 partition_strategy=partition_strategy,
2416 name=name,
2417 seed=seed)
2418 labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
2419 sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(
2420 labels=labels, logits=logits)
2421 # sampled_losses is a [batch_size] tensor.
2422 return sampled_losses