Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/kullback_leibler.py: 50%

52 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Registration and usage mechanisms for KL-divergences.""" 

16 

17from tensorflow.python.framework import ops 

18from tensorflow.python.ops import array_ops 

19from tensorflow.python.ops import control_flow_assert 

20from tensorflow.python.ops import math_ops 

21from tensorflow.python.util import deprecation 

22from tensorflow.python.util import tf_inspect 

23from tensorflow.python.util.tf_export import tf_export 

24 

25 

# Registry of KL-divergence implementations. `RegisterKL` stores each function
# under the `(dist_cls_a, dist_cls_b)` tuple it was constructed with, and
# `_registered_kl` looks pairs of classes up here.
_DIVERGENCES = dict()


# Public names exported by `from kullback_leibler import *`.
__all__ = ["RegisterKL", "kl_divergence"]

33 

34 

def _registered_kl(type_a, type_b):
  """Get the KL function registered for classes a and b.

  Every pair `(parent_a, parent_b)` drawn from the MROs of `type_a` and
  `type_b` is looked up in `_DIVERGENCES`. Among the pairs that have a
  registered function, the one with the smallest sum of MRO positions wins.
  On ties, the pair found first is kept; the outer loop walks `type_a`'s
  hierarchy, so that is the pair closest to `type_a`.

  Args:
    type_a: Class of the first distribution.
    type_b: Class of the second distribution.

  Returns:
    The registered KL function, or `None` if no pair of classes in the two
    hierarchies has one registered.
  """
  hierarchy_a = tf_inspect.getmro(type_a)
  hierarchy_b = tf_inspect.getmro(type_b)
  best_kl_fn = None
  best_dist = None
  for mro_to_a, parent_a in enumerate(hierarchy_a):
    for mro_to_b, parent_b in enumerate(hierarchy_b):
      candidate_kl_fn = _DIVERGENCES.get((parent_a, parent_b))
      # Compare against None rather than relying on truthiness. Otherwise a
      # registered callable that happens to be falsy (e.g. it defines
      # __len__ returning 0) would be discarded and overwritten by later
      # misses.
      if candidate_kl_fn is None:
        continue
      candidate_dist = mro_to_a + mro_to_b
      # Strict "<" keeps the earliest match on ties, which favors a shorter
      # MRO distance to type_a.
      if best_kl_fn is None or candidate_dist < best_dist:
        best_kl_fn = candidate_kl_fn
        best_dist = candidate_dist
  return best_kl_fn

49 

50 

@deprecation.deprecated(
    "2019-01-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.distributions`.",
    warn_once=True)
@tf_export(v1=["distributions.kl_divergence"])
def kl_divergence(distribution_a, distribution_b,
                  allow_nan_stats=True, name=None):
  """Get the KL-divergence KL(distribution_a || distribution_b).

  If there is no KL method registered specifically for `type(distribution_a)`
  and `type(distribution_b)`, then the class hierarchies of these types are
  searched.

  If one KL method is registered between any pair of classes in these two
  parent hierarchies, it is used.

  If more than one such registered method exists, the method whose registered
  classes have the shortest sum of MRO paths to the input types is used.

  If more than one such shortest path exists, the first method identified in
  the search is used (favoring a shorter MRO distance to
  `type(distribution_a)`).

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`, the value
      produced by the registered KL function is returned unchanged, possibly
      containing `NaN`. When `False`, the result is guarded by a
      `control_flow_assert.Assert` op that errors if any element is `NaN`.
    name: Python `str`, forwarded as the `name` keyword argument of the
      registered KL function.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b`.
  """
  type_a = type(distribution_a)
  type_b = type(distribution_b)
  kl_fn = _registered_kl(type_a, type_b)
  if kl_fn is None:
    raise NotImplementedError(
        "No KL(distribution_a || distribution_b) registered for distribution_a "
        "type %s and distribution_b type %s"
        % (type_a.__name__, type_b.__name__))

  with ops.name_scope("KullbackLeibler"):
    kl_t = kl_fn(distribution_a, distribution_b, name=name)
    if allow_nan_stats:
      return kl_t

    # NaNs are not allowed: name the raw value, then gate the returned
    # tensor on an assertion that none of its elements is NaN.
    kl_t = array_ops.identity(kl_t, name="kl")
    nan_free = math_ops.logical_not(
        math_ops.reduce_any(math_ops.is_nan(kl_t)))
    nan_message = (
        "KL calculation between %s and %s returned NaN values "
        "(and was called with allow_nan_stats=False). Values:"
        % (distribution_a.name, distribution_b.name))
    nan_check = control_flow_assert.Assert(nan_free, [nan_message, kl_t])
    with ops.control_dependencies([nan_check]):
      return array_ops.identity(kl_t, name="checked_kl")

119 

120 

@deprecation.deprecated(
    "2019-01-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.distributions`.",
    warn_once=True)
def cross_entropy(ref, other,
                  allow_nan_stats=True, name=None):
  """Computes the (Shannon) cross entropy.

  Denote two distributions by `P` (`ref`) and `Q` (`other`). Assuming `P, Q`
  are absolutely continuous with respect to one another and permit densities
  `p(x) dr(x)` and `q(x) dr(x)`, (Shannon) cross entropy is defined as:

  ```none
  H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
  ```

  where `F` denotes the support of the random variable `X ~ P`.

  It is evaluated through the identity `H[P, Q] = H[P] + KL(P || Q)`, i.e.
  `ref.entropy()` plus `kl_divergence(ref, other)`.

  Args:
    ref: `tfd.Distribution` instance.
    other: `tfd.Distribution` instance.
    allow_nan_stats: Python `bool`, default `True`. Forwarded to
      `kl_divergence`; when `False`, a `NaN` in the KL term triggers an
      assertion failure.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    cross_entropy: `ref.dtype` `Tensor` with shape `[B1, ..., Bn]`
      representing `n` different calculations of (Shannon) cross entropy.
  """
  with ops.name_scope(name, "cross_entropy"):
    ref_entropy = ref.entropy()
    kl_ref_other = kl_divergence(ref, other, allow_nan_stats=allow_nan_stats)
    return ref_entropy + kl_ref_other

159 

160 

@tf_export(v1=["distributions.RegisterKL"])
class RegisterKL:
  """Decorator to register a KL divergence implementation function.

  The decorated function is stored in the module-level `_DIVERGENCES`
  registry under the `(dist_cls_a, dist_cls_b)` key. `kl_divergence` later
  finds it there and calls it as `kl_fn(distribution_a, distribution_b,
  name=name)`, so the function must accept a `name` keyword argument.

  Usage:

  @distributions.RegisterKL(distributions.Normal, distributions.Normal)
  def _kl_normal_normal(norm_a, norm_b, name=None):
    # Return KL(norm_a || norm_b)
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, dist_cls_a, dist_cls_b):
    """Initialize the KL registrar.

    Args:
      dist_cls_a: the class of the first argument of the KL divergence.
      dist_cls_b: the class of the second argument of the KL divergence.
    """
    self._key = (dist_cls_a, dist_cls_b)

  def __call__(self, kl_fn):
    """Perform the KL registration.

    Args:
      kl_fn: The function to use for the KL divergence.

    Returns:
      kl_fn, unchanged, so this can be used as a decorator.

    Raises:
      TypeError: if kl_fn is not a callable.
      ValueError: if a KL divergence function has already been registered for
        the given argument classes.
    """
    if not callable(kl_fn):
      # Wrap in a 1-tuple so a tuple-valued kl_fn is reported verbatim instead
      # of being consumed as the %-format argument list (an empty tuple would
      # otherwise raise an unrelated "not enough arguments" error).
      raise TypeError("kl_fn must be callable, received: %s" % (kl_fn,))
    if self._key in _DIVERGENCES:
      raise ValueError("KL(%s || %s) has already been registered to: %s"
                       % (self._key[0].__name__, self._key[1].__name__,
                          _DIVERGENCES[self._key]))
    _DIVERGENCES[self._key] = kl_fn
    return kl_fn