Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/kullback_leibler.py: 50%
52 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registration and usage mechanisms for KL-divergences."""
17from tensorflow.python.framework import ops
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import control_flow_assert
20from tensorflow.python.ops import math_ops
21from tensorflow.python.util import deprecation
22from tensorflow.python.util import tf_inspect
23from tensorflow.python.util.tf_export import tf_export
# Registry of KL implementations, keyed by the pair of classes
# `(dist_cls_a, dist_cls_b)` they were registered for. Entries are added by
# `RegisterKL.__call__` and looked up by `_registered_kl`.
_DIVERGENCES = dict()

# Public API of this module.
__all__ = ["RegisterKL", "kl_divergence"]
def _registered_kl(type_a, type_b):
  """Get the KL function registered for classes a and b.

  Every pair `(parent_a, parent_b)` drawn from the MROs of `type_a` and
  `type_b` is looked up in `_DIVERGENCES`. Among the registered pairs, the one
  whose MRO indices have the smallest sum wins. Ties go to the pair found
  first, i.e. the one closest to `type_a` in its MRO (outer loop).

  Args:
    type_a: Class of the first distribution.
    type_b: Class of the second distribution.

  Returns:
    The registered KL function for the closest matching pair of classes, or
    `None` if no pair drawn from the two hierarchies is registered.
  """
  hierarchy_a = tf_inspect.getmro(type_a)
  hierarchy_b = tf_inspect.getmro(type_b)
  best_dist = None
  kl_fn = None
  for mro_to_a, parent_a in enumerate(hierarchy_a):
    for mro_to_b, parent_b in enumerate(hierarchy_b):
      candidate_kl_fn = _DIVERGENCES.get((parent_a, parent_b))
      # Compare with None rather than relying on truthiness: a registered
      # callable object may define a falsy __bool__/__len__ and must still be
      # treated as a match.
      if candidate_kl_fn is None:
        continue
      candidate_dist = mro_to_a + mro_to_b
      # Strict "<" keeps the first match on ties, which favors a shorter MRO
      # distance to `type_a`.
      if kl_fn is None or candidate_dist < best_dist:
        best_dist = candidate_dist
        kl_fn = candidate_kl_fn
  return kl_fn
@deprecation.deprecated(
    "2019-01-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.distributions`.",
    warn_once=True)
@tf_export(v1=["distributions.kl_divergence"])
def kl_divergence(distribution_a, distribution_b,
                  allow_nan_stats=True, name=None):
  """Get the KL-divergence KL(distribution_a || distribution_b).

  The implementation is chosen by searching the class hierarchies of
  `type(distribution_a)` and `type(distribution_b)` for a registered KL
  function:

  * An exact registration for the two types is used when it exists.
  * Otherwise, any registration between a pair of ancestor classes is used.
  * When several registrations apply, the pair whose classes have the
    smallest summed MRO distance to the input types is preferred.
  * Among equally distant pairs, the first one found wins, which favors a
    shorter MRO distance to `type(distribution_a)`.

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`, the KL
      tensor is returned as computed, `NaN` values included. When `False`, an
      assertion op is attached that fails if any element of the result is
      `NaN`.
    name: Python `str` name forwarded to the registered KL function as its
      `name` keyword argument.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b`.
  """
  type_a = type(distribution_a)
  type_b = type(distribution_b)
  kl_fn = _registered_kl(type_a, type_b)
  if kl_fn is None:
    raise NotImplementedError(
        "No KL(distribution_a || distribution_b) registered for distribution_a "
        "type %s and distribution_b type %s"
        % (type_a.__name__, type_b.__name__))

  with ops.name_scope("KullbackLeibler"):
    kl_t = kl_fn(distribution_a, distribution_b, name=name)
    if not allow_nan_stats:
      # Give the raw result a stable name before attaching the NaN check.
      kl_t = array_ops.identity(kl_t, name="kl")
      no_nans = math_ops.logical_not(
          math_ops.reduce_any(math_ops.is_nan(kl_t)))
      nan_message = (
          "KL calculation between %s and %s returned NaN values "
          "(and was called with allow_nan_stats=False). Values:" %
          (distribution_a.name, distribution_b.name))
      nan_assert = control_flow_assert.Assert(no_nans, [nan_message, kl_t])
      # The returned tensor depends on the assertion, so the check runs
      # whenever the result is evaluated.
      with ops.control_dependencies([nan_assert]):
        kl_t = array_ops.identity(kl_t, name="checked_kl")
    return kl_t
@deprecation.deprecated(
    "2019-01-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.distributions`.",
    warn_once=True)
def cross_entropy(ref, other,
                  allow_nan_stats=True, name=None):
  """Computes the (Shannon) cross entropy.

  Denote two distributions by `P` (`ref`) and `Q` (`other`). Assuming `P, Q`
  are absolutely continuous with respect to one another and permit densities
  `p(x) dr(x)` and `q(x) dr(x)`, (Shannon) cross entropy is defined as:

  ```none
  H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
  ```

  where `F` denotes the support of the random variable `X ~ P`. It is
  evaluated here through the identity `H[P, Q] = H[P] + KL(P || Q)`.

  Args:
    ref: `tfd.Distribution` instance.
    other: `tfd.Distribution` instance.
    allow_nan_stats: Python `bool`, default `True`. Forwarded to
      `kl_divergence`; when `False`, a `NaN` in the KL term triggers an
      assertion failure.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    cross_entropy: `ref.dtype` `Tensor` with shape `[B1, ..., Bn]`
      representing `n` different calculations of (Shannon) cross entropy.
  """
  with ops.name_scope(name, "cross_entropy"):
    # Entropy is computed first, then the KL term, matching the original order.
    ref_entropy = ref.entropy()
    divergence = kl_divergence(ref, other, allow_nan_stats=allow_nan_stats)
    return ref_entropy + divergence
@tf_export(v1=["distributions.RegisterKL"])
class RegisterKL:
  """Decorator to register a KL divergence implementation function.

  Usage:

  @distributions.RegisterKL(distributions.Normal, distributions.Normal)
  def _kl_normal_normal(norm_a, norm_b, name=None):
    # Return KL(norm_a || norm_b)

  `kl_divergence` invokes the registered function as
  `kl_fn(distribution_a, distribution_b, name=name)`, so implementations must
  accept a `name` keyword argument.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, dist_cls_a, dist_cls_b):
    """Initialize the KL registrar.

    Args:
      dist_cls_a: the class of the first argument of the KL divergence.
      dist_cls_b: the class of the second argument of the KL divergence.
    """
    # Registry key; the pair order matters because KL is not symmetric.
    self._key = (dist_cls_a, dist_cls_b)

  def __call__(self, kl_fn):
    """Perform the KL registration.

    Args:
      kl_fn: The function to use for the KL divergence.

    Returns:
      kl_fn, unchanged, so the instance can be used as a decorator.

    Raises:
      TypeError: if kl_fn is not a callable.
      ValueError: if a KL divergence function has already been registered for
        the given argument classes.
    """
    if not callable(kl_fn):
      # Wrap in a 1-tuple so that tuple-valued inputs are formatted as a
      # single value rather than being spread across the format string.
      raise TypeError("kl_fn must be callable, received: %s" % (kl_fn,))
    if self._key in _DIVERGENCES:
      raise ValueError("KL(%s || %s) has already been registered to: %s"
                       % (self._key[0].__name__, self._key[1].__name__,
                          _DIVERGENCES[self._key]))
    _DIVERGENCES[self._key] = kl_fn
    return kl_fn