Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/distribution.py: 39%

399 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base classes for probability distributions.""" 

16 

17import abc 

18import contextlib 

19import types 

20 

21import numpy as np 

22 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.framework import tensor_util 

28from tensorflow.python.ops import array_ops 

29from tensorflow.python.ops import math_ops 

30from tensorflow.python.ops.distributions import kullback_leibler 

31from tensorflow.python.ops.distributions import util 

32from tensorflow.python.util import deprecation 

33from tensorflow.python.util import tf_inspect 

34from tensorflow.python.util.tf_export import tf_export 

35 

36 

37__all__ = [ 

38 "ReparameterizationType", 

39 "FULLY_REPARAMETERIZED", 

40 "NOT_REPARAMETERIZED", 

41 "Distribution", 

42] 

43 

44_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [ 

45 "batch_shape", 

46 "batch_shape_tensor", 

47 "cdf", 

48 "covariance", 

49 "cross_entropy", 

50 "entropy", 

51 "event_shape", 

52 "event_shape_tensor", 

53 "kl_divergence", 

54 "log_cdf", 

55 "log_prob", 

56 "log_survival_function", 

57 "mean", 

58 "mode", 

59 "prob", 

60 "sample", 

61 "stddev", 

62 "survival_function", 

63 "variance", 

64] 

65 

66 

67class _BaseDistribution(metaclass=abc.ABCMeta): 

68 """Abstract base class needed for resolving subclass hierarchy.""" 

69 pass 

70 

71 

72def _copy_fn(fn): 

73 """Create a deep copy of fn. 

74 

75 Args: 

76 fn: a callable 

77 

78 Returns: 

79 A `FunctionType`: a deep copy of fn. 

80 

81 Raises: 

82 TypeError: if `fn` is not a callable. 

83 """ 

84 if not callable(fn): 

85 raise TypeError("fn is not callable: %s" % fn) 

86 # The blessed way to copy a function. copy.deepcopy fails to create a 

87 # non-reference copy. Since: 

88 # types.FunctionType == type(lambda: None), 

89 # and the docstring for the function type states: 

90 # 

91 # function(code, globals[, name[, argdefs[, closure]]]) 

92 # 

93 # Create a function object from a code object and a dictionary. 

94 # ... 

95 # 

96 # Here we can use this to create a new function with the old function's 

97 # code, globals, closure, etc. 

98 return types.FunctionType( 

99 code=fn.__code__, globals=fn.__globals__, 

100 name=fn.__name__, argdefs=fn.__defaults__, 

101 closure=fn.__closure__) 

102 

103 

104def _update_docstring(old_str, append_str): 

105 """Update old_str by inserting append_str just before the "Args:" section.""" 

106 old_str = old_str or "" 

107 old_str_lines = old_str.split("\n") 

108 

109 # Step 0: Prepend spaces to all lines of append_str. This is 

110 # necessary for correct markdown generation. 

111 append_str = "\n".join(" %s" % line for line in append_str.split("\n")) 

112 

113 # Step 1: Find mention of "Args": 

114 has_args_ix = [ 

115 ix for ix, line in enumerate(old_str_lines) 

116 if line.strip().lower() == "args:"] 

117 if has_args_ix: 

118 final_args_ix = has_args_ix[-1] 

119 return ("\n".join(old_str_lines[:final_args_ix]) 

120 + "\n\n" + append_str + "\n\n" 

121 + "\n".join(old_str_lines[final_args_ix:])) 

122 else: 

123 return old_str + "\n\n" + append_str 

124 

125 

def _convert_to_tensor(value, name=None, preferred_dtype=None):
  """Converts to tensor avoiding an eager bug that loses float precision.

  Args:
    value: the object to convert.
    name: optional name forwarded to `ops.convert_to_tensor`.
    preferred_dtype: optional dtype forwarded as `preferred_dtype`; it must
      expose `is_integer` and `is_bool` when given.

  Returns:
    The result of `ops.convert_to_tensor`. When executing eagerly with an
    integer or bool `preferred_dtype`, a conversion without the preference is
    tried first and returned as-is if it comes out floating.
  """
  # TODO(b/116672045): Remove this function.
  if context.executing_eagerly() and preferred_dtype is not None:
    if preferred_dtype.is_integer or preferred_dtype.is_bool:
      unconstrained = ops.convert_to_tensor(value, name=name)
      if unconstrained.dtype.is_floating:
        # Keep the floating result rather than applying the discrete preference.
        return unconstrained
  return ops.convert_to_tensor(
      value, name=name, preferred_dtype=preferred_dtype)

136 

137 

138class _DistributionMeta(abc.ABCMeta): 

139 

140 def __new__(mcs, classname, baseclasses, attrs): 

141 """Control the creation of subclasses of the Distribution class. 

142 

143 The main purpose of this method is to properly propagate docstrings 

144 from private Distribution methods, like `_log_prob`, into their 

145 public wrappers as inherited by the Distribution base class 

146 (e.g. `log_prob`). 

147 

148 Args: 

149 classname: The name of the subclass being created. 

150 baseclasses: A tuple of parent classes. 

151 attrs: A dict mapping new attributes to their values. 

152 

153 Returns: 

154 The class object. 

155 

156 Raises: 

157 TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or 

158 the new class is derived via multiple inheritance and the first 

159 parent class is not a subclass of `BaseDistribution`. 

160 AttributeError: If `Distribution` does not implement e.g. `log_prob`. 

161 ValueError: If a `Distribution` public method lacks a docstring. 

162 """ 

163 if not baseclasses: # Nothing to be done for Distribution 

164 raise TypeError("Expected non-empty baseclass. Does Distribution " 

165 "not subclass _BaseDistribution?") 

166 which_base = [ 

167 base for base in baseclasses 

168 if base == _BaseDistribution or issubclass(base, Distribution)] 

169 base = which_base[0] 

170 if base == _BaseDistribution: # Nothing to be done for Distribution 

171 return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) 

172 if not issubclass(base, Distribution): 

173 raise TypeError("First parent class declared for %s must be " 

174 "Distribution, but saw '%s'" % (classname, base.__name__)) 

175 for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS: 

176 special_attr = "_%s" % attr 

177 class_attr_value = attrs.get(attr, None) 

178 if attr in attrs: 

179 # The method is being overridden, do not update its docstring 

180 continue 

181 base_attr_value = getattr(base, attr, None) 

182 if not base_attr_value: 

183 raise AttributeError( 

184 "Internal error: expected base class '%s' to implement method '%s'" 

185 % (base.__name__, attr)) 

186 class_special_attr_value = attrs.get(special_attr, None) 

187 if class_special_attr_value is None: 

188 # No _special method available, no need to update the docstring. 

189 continue 

190 class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value) 

191 if not class_special_attr_docstring: 

192 # No docstring to append. 

193 continue 

194 class_attr_value = _copy_fn(base_attr_value) 

195 class_attr_docstring = tf_inspect.getdoc(base_attr_value) 

196 if class_attr_docstring is None: 

197 raise ValueError( 

198 "Expected base class fn to contain a docstring: %s.%s" 

199 % (base.__name__, attr)) 

200 class_attr_value.__doc__ = _update_docstring( 

201 class_attr_value.__doc__, 

202 ("Additional documentation from `%s`:\n\n%s" 

203 % (classname, class_special_attr_docstring))) 

204 attrs[attr] = class_attr_value 

205 

206 return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) 

207 

208 

@tf_export(v1=["distributions.ReparameterizationType"])
class ReparameterizationType:
  """Instances of this class represent how sampling is reparameterized.

  Two static instances exist in the distributions library, signifying
  one of two possible properties for samples from a distribution:

  `FULLY_REPARAMETERIZED`: Samples from the distribution are fully
    reparameterized, and straight-through gradients are supported.

  `NOT_REPARAMETERIZED`: Samples from the distribution are not fully
    reparameterized, and straight-through gradients are either partially
    unsupported or are not supported at all. In this case, for purposes of
    e.g. RL or variational inference, it is generally safest to wrap the
    sample results in a `stop_gradient` call and use policy
    gradients / surrogate loss instead.

  Instances compare equal only to themselves and hash by identity, so they can
  be used as dictionary keys and set members.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, rep_type):
    """Creates the marker.

    Args:
      rep_type: label stored on the instance and shown by `repr()`.
    """
    self._rep_type = rep_type

  def __repr__(self):
    return "<Reparameterization Type: %s>" % self._rep_type

  def __eq__(self, other):
    """Determine if this `ReparameterizationType` is equal to another.

    Since ReparameterizationType instances are constant static global
    instances, equality checks if two instances' id() values are equal.

    Args:
      other: Object to compare against.

    Returns:
      `self is other`.
    """
    return self is other

  def __hash__(self):
    """Identity-based hash, consistent with the identity-based `__eq__`.

    Defining `__eq__` alone would set `__hash__` to `None` and make the
    instances unhashable, so the default object hash is restored explicitly.

    Returns:
      `object.__hash__(self)`.
    """
    return object.__hash__(self)

254 

255 

# Fully reparameterized distribution: samples from a fully
# reparameterized distribution support straight-through gradients with
# respect to all parameters.
# `ReparameterizationType.__eq__` compares by identity, so code must compare
# against this module-level instance rather than construct a new one.
FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
# Register the constant with the export decorator under the v1 API name.
tf_export(v1=["distributions.FULLY_REPARAMETERIZED"]).export_constant(
    __name__, "FULLY_REPARAMETERIZED")


# Not reparameterized distribution: samples from a non-
# reparameterized distribution do not support straight-through gradients for
# at least some of the parameters.
# As above, this exact instance is the one to compare against.
NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
# Register the constant with the export decorator under the v1 API name.
tf_export(v1=["distributions.NOT_REPARAMETERIZED"]).export_constant(
    __name__, "NOT_REPARAMETERIZED")

270 

271 

272@tf_export(v1=["distributions.Distribution"]) 

273class Distribution(_BaseDistribution, metaclass=_DistributionMeta): 

274 """A generic probability distribution base class. 

275 

276 `Distribution` is a base class for constructing and organizing properties 

277 (e.g., mean, variance) of random variables (e.g., Bernoulli, Gaussian).

278 

279 #### Subclassing 

280 

281 Subclasses are expected to implement a leading-underscore version of the 

282 same-named function. The argument signature should be identical except for 

283 the omission of `name="..."`. For example, to enable `log_prob(value, 

284 name="log_prob")` a subclass should implement `_log_prob(value)`. 

285 

286 Subclasses can append to public-level docstrings by providing 

287 docstrings for their method specializations. For example: 

288 

289 ```python 

290 @util.AppendDocstring("Some other details.") 

291 def _log_prob(self, value): 

292 ... 

293 ``` 

294 

295 would add the string "Some other details." to the `log_prob` function 

296 docstring. This is implemented as a simple decorator to avoid python 

297 linter complaining about missing Args/Returns/Raises sections in the 

298 partial docstrings. 

299 

300 #### Broadcasting, batching, and shapes 

301 

302 All distributions support batches of independent distributions of that type. 

303 The batch shape is determined by broadcasting together the parameters. 

304 

305 The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and 

306 `log_prob` reflect this broadcasting, as does the return value of `sample` and 

307 `sample_n`. 

308 

309 `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is 

310 the shape of the `Tensor` returned from `sample_n`, `n` is the number of 

311 samples, `batch_shape` defines how many independent distributions there are, 

312 and `event_shape` defines the shape of samples from each of those independent 

313 distributions. Samples are independent along the `batch_shape` dimensions, but 

314 not necessarily so along the `event_shape` dimensions (depending on the 

315 particulars of the underlying distribution). 

316 

317 Using the `Uniform` distribution as an example: 

318 

319 ```python 

320 minval = 3.0 

321 maxval = [[4.0, 6.0], 

322 [10.0, 12.0]] 

323 

324 # Broadcasting: 

325 # This instance represents 4 Uniform distributions. Each has a lower bound at 

326 # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape. 

327 u = Uniform(minval, maxval) 

328 

329 # `event_shape` is `TensorShape([])`. 

330 event_shape = u.event_shape 

331 # `event_shape_t` is a `Tensor` which will evaluate to []. 

332 event_shape_t = u.event_shape_tensor() 

333 

334 # Sampling returns a sample per distribution. `samples` has shape 

335 # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5, 

336 # batch_shape=[2, 2], and event_shape=[]. 

337 samples = u.sample_n(5) 

338 

339 # The broadcasting holds across methods. Here we use `cdf` as an example. The 

340 # same holds for `log_cdf` and the likelihood functions. 

341 

342 # `cum_prob` has shape [2, 2] as the `value` argument was broadcasted to the 

343 # shape of the `Uniform` instance. 

344 cum_prob_broadcast = u.cdf(4.0) 

345 

346 # `cum_prob`'s shape is [2, 2], one per distribution. No broadcasting 

347 # occurred. 

348 cum_prob_per_dist = u.cdf([[4.0, 5.0], 

349 [6.0, 7.0]]) 

350 

351 # INVALID as the `value` argument is not broadcastable to the distribution's 

352 # shape. 

353 cum_prob_invalid = u.cdf([4.0, 5.0, 6.0]) 

354 ``` 

355 

356 #### Shapes 

357 

358 There are three important concepts associated with TensorFlow Distributions 

359 shapes: 

360 - Event shape describes the shape of a single draw from the distribution; 

361 it may be dependent across dimensions. For scalar distributions, the event 

362 shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is 

363 `[5]`. 

364 - Batch shape describes independent, not identically distributed draws, aka a 

365 "collection" or "bunch" of distributions. 

366 - Sample shape describes independent, identically distributed draws of batches 

367 from the distribution family. 

368 

369 The event shape and the batch shape are properties of a Distribution object, 

370 whereas the sample shape is associated with a specific call to `sample` or 

371 `log_prob`. 

372 

373 For detailed usage examples of TensorFlow Distributions shapes, see 

374 [this tutorial]( 

375 https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb) 

376 

377 #### Parameter values leading to undefined statistics or distributions. 

378 

379 Some distributions do not have well-defined statistics for all initialization 

380 parameter values. For example, the beta distribution is parameterized by 

381 positive real numbers `concentration1` and `concentration0`, and does not have 

382 well-defined mode if `concentration1 < 1` or `concentration0 < 1`. 

383 

384 The user is given the option of raising an exception or returning `NaN`. 

385 

386 ```python 

387 a = tf.exp(tf.matmul(logits, weights_a)) 

388 b = tf.exp(tf.matmul(logits, weights_b)) 

389 

390 # Will raise exception if ANY batch member has a < 1 or b < 1. 

391 dist = distributions.beta(a, b, allow_nan_stats=False) 

392 mode = dist.mode().eval() 

393 

394 # Will return NaN for batch members with either a < 1 or b < 1. 

395 dist = distributions.beta(a, b, allow_nan_stats=True) # Default behavior 

396 mode = dist.mode().eval() 

397 ``` 

398 

399 In all cases, an exception is raised if *invalid* parameters are passed, e.g. 

400 

401 ```python 

402 # Will raise an exception if any Op is run. 

403 negative_a = -1.0 * a # beta distribution by definition has a > 0. 

404 dist = distributions.beta(negative_a, b, allow_nan_stats=True) 

405 dist.mean().eval() 

406 ``` 

407 

408 """ 

409 

@deprecation.deprecated(
    "2019-01-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.distributions`.",
    warn_once=True)
def __init__(self,
             dtype,
             reparameterization_type,
             validate_args,
             allow_nan_stats,
             parameters=None,
             graph_parents=None,
             name=None):
  """Constructs the `Distribution`.

  **This is a private method for subclass use.**

  Args:
    dtype: The type of the event samples. `None` implies no type-enforcement.
    reparameterization_type: Instance of `ReparameterizationType`.
      If `distributions.FULLY_REPARAMETERIZED`, this
      `Distribution` can be reparameterized in terms of some standard
      distribution with a function whose Jacobian is constant for the support
      of the standard distribution. If `distributions.NOT_REPARAMETERIZED`,
      then no such reparameterization is available.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    parameters: Python `dict` of parameters used to instantiate this
      `Distribution`. A falsy value is replaced by an empty dict.
    graph_parents: Python `list` of graph prerequisites of this
      `Distribution`. `None` is replaced by an empty list.
    name: Python `str` name prefixed to Ops created by this class. Default:
      subclass name. A value ending in "/" is stored unchanged.

  Raises:
    ValueError: if any member of graph_parents is `None` or not a `Tensor`.
  """
  if graph_parents is None:
    graph_parents = []
  for index, parent in enumerate(graph_parents):
    if parent is None or not tensor_util.is_tf_type(parent):
      raise ValueError(
          "Graph parent item %d is not a Tensor; %s." % (index, parent))

  already_a_scope = bool(name) and name[-1] == "/"
  if not already_a_scope:
    # Open and immediately close a name scope purely to obtain the scope
    # string it yields; that string becomes this distribution's name.
    with ops.name_scope(name or type(self).__name__) as name:
      pass

  self._dtype = dtype
  self._reparameterization_type = reparameterization_type
  self._allow_nan_stats = allow_nan_stats
  self._validate_args = validate_args
  # Goes through the `_parameters` property setter.
  self._parameters = parameters or {}
  self._graph_parents = graph_parents
  self._name = name

471 

472 @property 

473 def _parameters(self): 

474 return self._parameter_dict 

475 

476 @_parameters.setter 

477 def _parameters(self, value): 

478 """Intercept assignments to self._parameters to avoid reference cycles. 

479 

480 Parameters are often created using locals(), so we need to clean out any 

481 references to `self` before assigning it to an attribute. 

482 

483 Args: 

484 value: A dictionary of parameters to assign to the `_parameters` property. 

485 """ 

486 if "self" in value: 

487 del value["self"] 

488 self._parameter_dict = value 

489 

@classmethod
def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
  """Shapes of parameters given the desired shape of a call to `sample()`.

  This class method describes which key/value arguments are needed to
  instantiate the given `Distribution` so that its `sample()` call returns a
  particular shape. The work is delegated to `cls._param_shapes`, which
  subclasses should override.

  Args:
    sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
      `sample()`.
    name: name to prepend ops with.

  Returns:
    `dict` of parameter name to `Tensor` shapes.
  """
  with ops.name_scope(name, values=[sample_shape]):
    shapes = cls._param_shapes(sample_shape)
    return shapes

510 

@classmethod
def param_static_shapes(cls, sample_shape):
  """param_shapes with static (i.e. `TensorShape`) shapes.

  Same contract as `param_shapes`, but every returned shape is converted to a
  `TensorShape`, which requires the sample shape to be known statically.

  Subclasses should override class method `_param_shapes` to return
  constant-valued tensors when constant values are fed.

  Args:
    sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
      to `sample()`.

  Returns:
    `dict` of parameter name to `TensorShape`.

  Raises:
    ValueError: if `sample_shape` is a `TensorShape` and is not fully defined,
      or if a parameter shape has no constant value.
  """
  if isinstance(sample_shape, tensor_shape.TensorShape):
    if not sample_shape.is_fully_defined():
      raise ValueError("TensorShape sample_shape must be fully defined")
    sample_shape = sample_shape.as_list()

  def _as_static(shape):
    # Resolve one parameter shape to a constant, or fail.
    constant = tensor_util.constant_value(shape)
    if constant is None:
      raise ValueError(
          "sample_shape must be a fully-defined TensorShape or list/tuple")
    return tensor_shape.TensorShape(constant)

  dynamic_shapes = cls.param_shapes(sample_shape)
  return {param_name: _as_static(shape)
          for param_name, shape in dynamic_shapes.items()}

549 

550 @staticmethod 

551 def _param_shapes(sample_shape): 

552 raise NotImplementedError("_param_shapes not implemented") 

553 

@property
def name(self):
  """Name prepended to all ops created by this `Distribution`.

  Returns:
    The value stored in `self._name`.
  """
  scope_name = self._name
  return scope_name

558 

@property
def dtype(self):
  """The `DType` of `Tensor`s handled by this `Distribution`.

  Returns:
    The value stored in `self._dtype`.
  """
  event_dtype = self._dtype
  return event_dtype

563 

@property
def parameters(self):
  """Dictionary of parameters used to instantiate this `Distribution`.

  Returns:
    A new `dict` built from `self._parameters` without the key "self" and
    without keys that start with a double underscore.
  """
  # "self", "__class__" and other special variables can appear if the
  # subclass used `parameters = dict(locals())`.
  public = {}
  for key, value in self._parameters.items():
    if key == "self" or key.startswith("__"):
      continue
    public[key] = value
  return public

572 

@property
def reparameterization_type(self):
  """Describes how samples from the distribution are reparameterized.

  Currently this is one of the static instances
  `distributions.FULLY_REPARAMETERIZED`
  or `distributions.NOT_REPARAMETERIZED`.

  Returns:
    An instance of `ReparameterizationType`, as stored in
    `self._reparameterization_type`.
  """
  rep_type = self._reparameterization_type
  return rep_type

585 

@property
def allow_nan_stats(self):
  """Python `bool` describing behavior when a stat is undefined.

  Stats return +/- infinity when it makes sense. E.g., the variance of a
  Cauchy distribution is infinity. However, sometimes the statistic is
  undefined, e.g., if a distribution's pdf does not achieve a maximum within
  the support of the distribution, the mode is undefined. If the mean is
  undefined, then by definition the variance is undefined. E.g. the mean for
  Student's T for df = 1 is undefined (no clear way to say it is either + or -
  infinity), so the variance = E[(X - mean)**2] is also undefined.

  Returns:
    allow_nan_stats: Python `bool`, as stored in `self._allow_nan_stats`.
  """
  flag = self._allow_nan_stats
  return flag

602 

@property
def validate_args(self):
  """Python `bool` indicating possibly expensive checks are enabled.

  Returns:
    The value stored in `self._validate_args`.
  """
  flag = self._validate_args
  return flag

607 

def copy(self, **override_parameters_kwargs):
  """Creates a copy of the distribution by re-running its constructor.

  Note: the copy distribution may continue to depend on the original
  initialization arguments.

  Args:
    **override_parameters_kwargs: String/value dictionary of initialization
      arguments to override with new values.

  Returns:
    distribution: A new instance of `type(self)` initialized from the union
      of self.parameters and override_parameters_kwargs, i.e.,
      `dict(self.parameters, **override_parameters_kwargs)`.
  """
  init_kwargs = dict(self.parameters)
  # Overrides win over the recorded constructor arguments.
  init_kwargs.update(override_parameters_kwargs)
  return type(self)(**init_kwargs)

625 

626 def _batch_shape_tensor(self): 

627 raise NotImplementedError( 

628 "batch_shape_tensor is not implemented: {}".format(type(self).__name__)) 

629 

def batch_shape_tensor(self, name="batch_shape_tensor"):
  """Batch shape of this distribution as a 1-D `Tensor`.

  The batch dimensions are indexes into independent, non-identical
  parameterizations of this distribution.

  When the static `batch_shape` is fully defined it is converted to an int32
  tensor directly; otherwise the result of `_batch_shape_tensor()` is
  returned.

  Args:
    name: name to give to the op

  Returns:
    batch_shape: `Tensor`.
  """
  with self._name_scope(name):
    if not self.batch_shape.is_fully_defined():
      return self._batch_shape_tensor()
    return ops.convert_to_tensor(
        self.batch_shape.as_list(), dtype=dtypes.int32, name="batch_shape")

648 

def _batch_shape(self):
  # Default static batch shape used by the `batch_shape` property:
  # `TensorShape(None)`. Documented with a comment so the hook's `__doc__`
  # stays None, exactly as before.
  default_shape = tensor_shape.TensorShape(None)
  return default_shape

651 

@property
def batch_shape(self):
  """Static batch shape of this distribution as a `TensorShape`.

  May be partially defined or unknown.

  The batch dimensions are indexes into independent, non-identical
  parameterizations of this distribution.

  Returns:
    batch_shape: `TensorShape`, possibly unknown. It is the result of
      `_batch_shape()` passed through `tensor_shape.as_shape`.
  """
  raw_shape = self._batch_shape()
  return tensor_shape.as_shape(raw_shape)

665 

666 def _event_shape_tensor(self): 

667 raise NotImplementedError( 

668 "event_shape_tensor is not implemented: {}".format(type(self).__name__)) 

669 

def event_shape_tensor(self, name="event_shape_tensor"):
  """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

  When the static `event_shape` is fully defined it is converted to an int32
  tensor directly; otherwise the result of `_event_shape_tensor()` is
  returned.

  Args:
    name: name to give to the op

  Returns:
    event_shape: `Tensor`.
  """
  with self._name_scope(name):
    if not self.event_shape.is_fully_defined():
      return self._event_shape_tensor()
    return ops.convert_to_tensor(
        self.event_shape.as_list(), dtype=dtypes.int32, name="event_shape")

685 

def _event_shape(self):
  # Default static event shape used by the `event_shape` property:
  # `TensorShape(None)`. Documented with a comment so the hook's `__doc__`
  # stays None, exactly as before.
  default_shape = tensor_shape.TensorShape(None)
  return default_shape

688 

@property
def event_shape(self):
  """Shape of a single sample from a single batch as a `TensorShape`.

  May be partially defined or unknown.

  Returns:
    event_shape: `TensorShape`, possibly unknown. It is the result of
      `_event_shape()` passed through `tensor_shape.as_shape`.
  """
  raw_shape = self._event_shape()
  return tensor_shape.as_shape(raw_shape)

699 

def is_scalar_event(self, name="is_scalar_event"):
  """Indicates that `event_shape == []`.

  Args:
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    is_scalar_event: `bool` scalar `Tensor`.
  """
  with self._name_scope(name):
    # The helper receives the static shape and the *uncalled* bound method
    # that produces the dynamic shape.
    is_scalar = self._is_scalar_helper(
        self.event_shape, self.event_shape_tensor)
    return ops.convert_to_tensor(is_scalar, name="is_scalar_event")

713 

def is_scalar_batch(self, name="is_scalar_batch"):
  """Indicates that `batch_shape == []`.

  Args:
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    is_scalar_batch: `bool` scalar `Tensor`.
  """
  with self._name_scope(name):
    # The helper receives the static shape and the *uncalled* bound method
    # that produces the dynamic shape.
    is_scalar = self._is_scalar_helper(
        self.batch_shape, self.batch_shape_tensor)
    return ops.convert_to_tensor(is_scalar, name="is_scalar_batch")

727 

728 def _sample_n(self, n, seed=None): 

729 raise NotImplementedError("sample_n is not implemented: {}".format( 

730 type(self).__name__)) 

731 

def _call_sample_n(self, sample_shape, seed, name, **kwargs):
  # Draws samples through `_sample_n` and reshapes them so their leading
  # dimensions equal the requested `sample_shape`.
  with self._name_scope(name, values=[sample_shape]):
    requested_shape = ops.convert_to_tensor(
        sample_shape, dtype=dtypes.int32, name="sample_shape")
    requested_shape, n = self._expand_sample_shape_to_vector(
        requested_shape, "sample_shape")
    flat_samples = self._sample_n(n, seed, **kwargs)
    # Everything after the first dimension of the drawn samples is kept.
    trailing_dims = array_ops.shape(flat_samples)[1:]
    target_shape = array_ops.concat([requested_shape, trailing_dims], 0)
    reshaped = array_ops.reshape(flat_samples, target_shape)
    return self._set_sample_static_shape(reshaped, requested_shape)

744 

def sample(self, sample_shape=(), seed=None, name="sample"):
  """Generate samples of the specified shape.

  Note that a call to `sample()` without arguments will generate a single
  sample.

  Args:
    sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
    seed: Python integer seed for RNG
    name: name to give to the op.

  Returns:
    samples: a `Tensor` with prepended dimensions `sample_shape`.
  """
  samples = self._call_sample_n(sample_shape, seed, name)
  return samples

760 

761 def _log_prob(self, value): 

762 raise NotImplementedError("log_prob is not implemented: {}".format( 

763 type(self).__name__)) 

764 

def _call_log_prob(self, value, name, **kwargs):
  # Evaluates `_log_prob`; if that is not implemented, falls back to the log
  # of `_prob`. When neither exists the `_log_prob` error is the one raised.
  with self._name_scope(name, values=[value]):
    tensor = _convert_to_tensor(
        value, name="value", preferred_dtype=self.dtype)
    try:
      return self._log_prob(tensor, **kwargs)
    except NotImplementedError as log_prob_error:
      try:
        density = self._prob(tensor, **kwargs)
        return math_ops.log(density)
      except NotImplementedError:
        raise log_prob_error

776 

def log_prob(self, value, name="log_prob"):
  """Log probability density/mass function.

  Delegates to `_call_log_prob`.

  Args:
    value: `float` or `double` `Tensor`.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
      values of type `self.dtype`.
  """
  result = self._call_log_prob(value, name)
  return result

789 

790 def _prob(self, value): 

791 raise NotImplementedError("prob is not implemented: {}".format( 

792 type(self).__name__)) 

793 

def _call_prob(self, value, name, **kwargs):
  # Evaluates `_prob`; if that is not implemented, falls back to the
  # exponential of `_log_prob`. When neither exists the `_prob` error is the
  # one raised.
  with self._name_scope(name, values=[value]):
    tensor = _convert_to_tensor(
        value, name="value", preferred_dtype=self.dtype)
    try:
      return self._prob(tensor, **kwargs)
    except NotImplementedError as prob_error:
      try:
        log_density = self._log_prob(tensor, **kwargs)
        return math_ops.exp(log_density)
      except NotImplementedError:
        raise prob_error

805 

def prob(self, value, name="prob"):
  """Probability density/mass function.

  Delegates to `_call_prob`.

  Args:
    value: `float` or `double` `Tensor`.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
      values of type `self.dtype`.
  """
  result = self._call_prob(value, name)
  return result

818 

819 def _log_cdf(self, value): 

820 raise NotImplementedError("log_cdf is not implemented: {}".format( 

821 type(self).__name__)) 

822 

def _call_log_cdf(self, value, name, **kwargs):
  # Evaluates `_log_cdf`; if that is not implemented, falls back to the log of
  # `_cdf`. When neither exists the `_log_cdf` error is the one raised.
  with self._name_scope(name, values=[value]):
    tensor = _convert_to_tensor(
        value, name="value", preferred_dtype=self.dtype)
    try:
      return self._log_cdf(tensor, **kwargs)
    except NotImplementedError as log_cdf_error:
      try:
        cumulative = self._cdf(tensor, **kwargs)
        return math_ops.log(cumulative)
      except NotImplementedError:
        raise log_cdf_error

834 

def log_cdf(self, value, name="log_cdf"):
  """Evaluates the log of the cumulative distribution function at `value`.

  Given random variable `X`, the log cumulative distribution function
  `log_cdf` is:

  ```none
  log_cdf(x) := Log[ P[X <= x] ]
  ```

  Often, a numerical approximation can be used for `log_cdf(x)` that yields
  a more accurate answer than simply taking the logarithm of the `cdf` when
  `x << -1`.

  Args:
    value: `float` or `double` `Tensor`.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
      values of type `self.dtype`.
  """
  # Thin public entry point; the work is delegated to `_call_log_cdf`.
  log_cumulative = self._call_log_cdf(value, name)
  return log_cumulative

857 

858 def _cdf(self, value): 

859 raise NotImplementedError("cdf is not implemented: {}".format( 

860 type(self).__name__)) 

861 

def _call_cdf(self, value, name, **kwargs):
  """Evaluates `_cdf` on `value`, falling back to `exp(_log_cdf)`.

  Args:
    value: Input converted to a `Tensor`, with `self.dtype` as the preferred
      dtype.
    name: Name handed to `self._name_scope` for the ops created here.
    **kwargs: Forwarded unchanged to `_cdf` / `_log_cdf`.

  Returns:
    The result of `self._cdf`, or `exp` of `self._log_cdf` when `_cdf`
    raises `NotImplementedError`.

  Raises:
    NotImplementedError: The error raised by `_cdf`, when the `_log_cdf`
      fallback raises `NotImplementedError` as well.
  """
  with self._name_scope(name, values=[value]):
    value = _convert_to_tensor(
        value, name="value", preferred_dtype=self.dtype)
    try:
      return self._cdf(value, **kwargs)
    except NotImplementedError as cdf_error:
      # `_cdf` is unavailable; exponentiate the log-cdf instead.
      try:
        log_cumulative = self._log_cdf(value, **kwargs)
        return math_ops.exp(log_cumulative)
      except NotImplementedError:
        # Surface the error about `cdf`, which is what was requested.
        raise cdf_error

873 

def cdf(self, value, name="cdf"):
  """Evaluates the cumulative distribution function at `value`.

  Given random variable `X`, the cumulative distribution function `cdf` is:

  ```none
  cdf(x) := P[X <= x]
  ```

  Args:
    value: `float` or `double` `Tensor`.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
      values of type `self.dtype`.
  """
  # Thin public entry point; the work is delegated to `_call_cdf`.
  cumulative = self._call_cdf(value, name)
  return cumulative

892 

893 def _log_survival_function(self, value): 

894 raise NotImplementedError( 

895 "log_survival_function is not implemented: {}".format( 

896 type(self).__name__)) 

897 

def _call_log_survival_function(self, value, name, **kwargs):
  """Evaluates `_log_survival_function`, falling back to `log1p(-cdf)`.

  Args:
    value: Input converted to a `Tensor`, with `self.dtype` as the preferred
      dtype.
    name: Name handed to `self._name_scope` for the ops created here.
    **kwargs: Forwarded unchanged to `_log_survival_function`, and to the cdf
      computation used as the fallback.

  Returns:
    The result of `self._log_survival_function`, or `log1p(-cdf(value))` when
    that hook raises `NotImplementedError`.

  Raises:
    NotImplementedError: The error raised by `_log_survival_function`, when
      the cdf-based fallback raises `NotImplementedError` as well.
  """
  with self._name_scope(name, values=[value]):
    value = _convert_to_tensor(
        value, name="value", preferred_dtype=self.dtype)
    try:
      return self._log_survival_function(value, **kwargs)
    except NotImplementedError as original_exception:
      try:
        # Call `_call_cdf` directly instead of the public `cdf`: the public
        # method's signature is `(value, name="cdf")`, so forwarding
        # non-empty `kwargs` to it would raise `TypeError`. With empty
        # `kwargs` this is exactly what `self.cdf(value)` does.
        return math_ops.log1p(-self._call_cdf(value, "cdf", **kwargs))
      except NotImplementedError:
        # Surface the error about the function that was actually requested.
        raise original_exception

909 

def log_survival_function(self, value, name="log_survival_function"):
  """Evaluates the log of the survival function at `value`.

  Given random variable `X`, the survival function is defined:

  ```none
  log_survival_function(x) = Log[ P[X > x] ]
                           = Log[ 1 - P[X <= x] ]
                           = Log[ 1 - cdf(x) ]
  ```

  Typically, different numerical approximations can be used for the log
  survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`.

  Args:
    value: `float` or `double` `Tensor`.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
      `self.dtype`.
  """
  # Thin public entry point; the work is delegated to
  # `_call_log_survival_function`.
  log_survival = self._call_log_survival_function(value, name)
  return log_survival

933 

934 def _survival_function(self, value): 

935 raise NotImplementedError("survival_function is not implemented: {}".format( 

936 type(self).__name__)) 

937 

def _call_survival_function(self, value, name, **kwargs):
  """Evaluates `_survival_function`, falling back to `1 - cdf`.

  Args:
    value: Input converted to a `Tensor`, with `self.dtype` as the preferred
      dtype.
    name: Name handed to `self._name_scope` for the ops created here.
    **kwargs: Forwarded unchanged to `_survival_function`, and to the cdf
      computation used as the fallback.

  Returns:
    The result of `self._survival_function`, or `1. - cdf(value)` when that
    hook raises `NotImplementedError`.

  Raises:
    NotImplementedError: The error raised by `_survival_function`, when the
      cdf-based fallback raises `NotImplementedError` as well.
  """
  with self._name_scope(name, values=[value]):
    value = _convert_to_tensor(
        value, name="value", preferred_dtype=self.dtype)
    try:
      return self._survival_function(value, **kwargs)
    except NotImplementedError as original_exception:
      try:
        # Call `_call_cdf` directly instead of the public `cdf`: the public
        # method's signature is `(value, name="cdf")`, so forwarding
        # non-empty `kwargs` to it would raise `TypeError`. With empty
        # `kwargs` this is exactly what `self.cdf(value)` does.
        return 1. - self._call_cdf(value, "cdf", **kwargs)
      except NotImplementedError:
        # Surface the error about the function that was actually requested.
        raise original_exception

949 

def survival_function(self, value, name="survival_function"):
  """Evaluates the survival function at `value`.

  Given random variable `X`, the survival function is defined:

  ```none
  survival_function(x) = P[X > x]
                       = 1 - P[X <= x]
                       = 1 - cdf(x).
  ```

  Args:
    value: `float` or `double` `Tensor`.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
      `self.dtype`.
  """
  # Thin public entry point; the work is delegated to
  # `_call_survival_function`.
  survival = self._call_survival_function(value, name)
  return survival

970 

971 def _entropy(self): 

972 raise NotImplementedError("entropy is not implemented: {}".format( 

973 type(self).__name__)) 

974 

def entropy(self, name="entropy"):
  """Shannon entropy in nats."""
  # Run the `_entropy` hook inside this distribution's name scope.
  with self._name_scope(name):
    entropy_value = self._entropy()
  return entropy_value

979 

980 def _mean(self): 

981 raise NotImplementedError("mean is not implemented: {}".format( 

982 type(self).__name__)) 

983 

def mean(self, name="mean"):
  """Mean."""
  # Run the `_mean` hook inside this distribution's name scope.
  with self._name_scope(name):
    mean_value = self._mean()
  return mean_value

988 

989 def _quantile(self, value): 

990 raise NotImplementedError("quantile is not implemented: {}".format( 

991 type(self).__name__)) 

992 

def _call_quantile(self, value, name, **kwargs):
  """Converts `value` to a `Tensor` and evaluates `_quantile` on it.

  Args:
    value: Input converted to a `Tensor`, with `self.dtype` as the preferred
      dtype.
    name: Name handed to `self._name_scope` for the ops created here.
    **kwargs: Forwarded unchanged to `_quantile`.

  Returns:
    Whatever `self._quantile` returns; unlike the other `_call_*` helpers
    there is no fallback computation.
  """
  with self._name_scope(name, values=[value]):
    as_tensor = _convert_to_tensor(
        value, name="value", preferred_dtype=self.dtype)
    return self._quantile(as_tensor, **kwargs)

998 

def quantile(self, value, name="quantile"):
  """Quantile function. Aka "inverse cdf" or "percent point function".

  Given random variable `X` and `p in [0, 1]`, the `quantile` is:

  ```none
  quantile(p) := x such that P[X <= x] == p
  ```

  Args:
    value: `float` or `double` `Tensor`.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
      values of type `self.dtype`.
  """
  # Thin public entry point; the work is delegated to `_call_quantile`.
  quantile_value = self._call_quantile(value, name)
  return quantile_value

1017 

1018 def _variance(self): 

1019 raise NotImplementedError("variance is not implemented: {}".format( 

1020 type(self).__name__)) 

1021 

def variance(self, name="variance"):
  """Variance.

  Variance is defined as,

  ```none
  Var = E[(X - E[X])**2]
  ```

  where `X` is the random variable associated with this distribution, `E`
  denotes expectation, and `Var.shape = batch_shape + event_shape`.

  When `_variance` is not implemented, the square of `_stddev` is used
  instead.

  Args:
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    variance: Floating-point `Tensor` with shape identical to
      `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
  """
  with self._name_scope(name):
    try:
      return self._variance()
    except NotImplementedError as variance_error:
      # No direct implementation; square the standard deviation instead.
      try:
        deviation = self._stddev()
        return math_ops.square(deviation)
      except NotImplementedError:
        # Surface the error about `variance`, which is what was requested.
        raise variance_error

1049 

1050 def _stddev(self): 

1051 raise NotImplementedError("stddev is not implemented: {}".format( 

1052 type(self).__name__)) 

1053 

def stddev(self, name="stddev"):
  """Standard deviation.

  Standard deviation is defined as,

  ```none
  stddev = E[(X - E[X])**2]**0.5
  ```

  where `X` is the random variable associated with this distribution, `E`
  denotes expectation, and `stddev.shape = batch_shape + event_shape`.

  When `_stddev` is not implemented, the square root of `_variance` is used
  instead.

  Args:
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    stddev: Floating-point `Tensor` with shape identical to
      `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
  """
  with self._name_scope(name):
    try:
      return self._stddev()
    except NotImplementedError as stddev_error:
      # No direct implementation; take the root of the variance instead.
      try:
        spread = self._variance()
        return math_ops.sqrt(spread)
      except NotImplementedError:
        # Surface the error about `stddev`, which is what was requested.
        raise stddev_error

1082 

1083 def _covariance(self): 

1084 raise NotImplementedError("covariance is not implemented: {}".format( 

1085 type(self).__name__)) 

1086 

def covariance(self, name="covariance"):
  """Covariance.

  Covariance is (possibly) defined only for non-scalar-event distributions.

  For example, for a length-`k`, vector-valued distribution, it is calculated
  as,

  ```none
  Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
  ```

  where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
  denotes expectation.

  Alternatively, for non-vector, multivariate distributions (e.g.,
  matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
  under some vectorization of the events, i.e.,

  ```none
  Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
  ```

  where `Cov` is a (batch of) `k' x k'` matrices,
  `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
  mapping indices of this distribution's event dimensions to indices of a
  length-`k'` vector.

  Args:
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
      where the first `n` dimensions are batch coordinates and
      `k' = reduce_prod(self.event_shape)`.
  """
  # Run the `_covariance` hook inside this distribution's name scope.
  with self._name_scope(name):
    covariance_value = self._covariance()
  return covariance_value

1125 

1126 def _mode(self): 

1127 raise NotImplementedError("mode is not implemented: {}".format( 

1128 type(self).__name__)) 

1129 

def mode(self, name="mode"):
  """Mode."""
  # Run the `_mode` hook inside this distribution's name scope.
  with self._name_scope(name):
    mode_value = self._mode()
  return mode_value

1134 

def _cross_entropy(self, other):
  # Default implementation: defer to the `kullback_leibler` module, passing
  # along this distribution's `allow_nan_stats` setting.
  return kullback_leibler.cross_entropy(
      self,
      other,
      allow_nan_stats=self.allow_nan_stats)

1138 

def cross_entropy(self, other, name="cross_entropy"):
  """Computes the (Shannon) cross entropy.

  Denote this distribution (`self`) by `P` and the `other` distribution by
  `Q`. Assuming `P, Q` are absolutely continuous with respect to
  one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shannon)
  cross entropy is defined as:

  ```none
  H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
  ```

  where `F` denotes the support of the random variable `X ~ P`.

  Args:
    other: `tfp.distributions.Distribution` instance.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
      representing `n` different calculations of (Shannon) cross entropy.
  """
  # Run the `_cross_entropy` hook inside this distribution's name scope.
  with self._name_scope(name):
    cross_entropy_value = self._cross_entropy(other)
  return cross_entropy_value

1163 

def _kl_divergence(self, other):
  # Default implementation: defer to the `kullback_leibler` module, passing
  # along this distribution's `allow_nan_stats` setting.
  return kullback_leibler.kl_divergence(
      self,
      other,
      allow_nan_stats=self.allow_nan_stats)

1167 

def kl_divergence(self, other, name="kl_divergence"):
  """Computes the Kullback--Leibler divergence.

  Denote this distribution (`self`) by `p` and the `other` distribution by
  `q`. Assuming `p, q` are absolutely continuous with respect to reference
  measure `r`, the KL divergence is defined as:

  ```none
  KL[p, q] = E_p[log(p(X)/q(X))]
           = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
           = H[p, q] - H[p]
  ```

  where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
  denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.

  Args:
    other: `tfp.distributions.Distribution` instance.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
      representing `n` different calculations of the Kullback-Leibler
      divergence.
  """
  # Run the `_kl_divergence` hook inside this distribution's name scope.
  with self._name_scope(name):
    divergence = self._kl_divergence(other)
  return divergence

1195 

1196 def __str__(self): 

1197 return ("tfp.distributions.{type_name}(" 

1198 "\"{self_name}\"" 

1199 "{maybe_batch_shape}" 

1200 "{maybe_event_shape}" 

1201 ", dtype={dtype})".format( 

1202 type_name=type(self).__name__, 

1203 self_name=self.name, 

1204 maybe_batch_shape=(", batch_shape={}".format(self.batch_shape) 

1205 if self.batch_shape.ndims is not None 

1206 else ""), 

1207 maybe_event_shape=(", event_shape={}".format(self.event_shape) 

1208 if self.event_shape.ndims is not None 

1209 else ""), 

1210 dtype=self.dtype.name)) 

1211 

1212 def __repr__(self): 

1213 return ("<tfp.distributions.{type_name} " 

1214 "'{self_name}'" 

1215 " batch_shape={batch_shape}" 

1216 " event_shape={event_shape}" 

1217 " dtype={dtype}>".format( 

1218 type_name=type(self).__name__, 

1219 self_name=self.name, 

1220 batch_shape=self.batch_shape, 

1221 event_shape=self.event_shape, 

1222 dtype=self.dtype.name)) 

1223 

@contextlib.contextmanager
def _name_scope(self, name=None, values=None):
  """Context manager nesting `name` inside this distribution's own scope.

  Args:
    name: Name of the inner scope.
    values: Optional list of values for the inner scope;
      `self._graph_parents` is appended to it (to an empty list if `None`).

  Yields:
    The scope produced by the inner `ops.name_scope`.
  """
  with ops.name_scope(self.name):
    explicit_values = [] if values is None else values
    # `+` builds a new list, so the caller's `values` is left untouched.
    scope_values = explicit_values + self._graph_parents
    with ops.name_scope(name, values=scope_values) as scope:
      yield scope

1231 

def _expand_sample_shape_to_vector(self, x, name):
  """Helper to `sample` which ensures input is 1D.

  Args:
    x: Scalar or vector `Tensor`.
    name: Name given to the rebuilt tensor when a statically known scalar is
      wrapped into a length-1 vector.

  Returns:
    x: The input as a vector; a scalar input becomes a length-1 vector.
    prod: Product of the elements of the original `x`; a NumPy value when `x`
      is statically known, otherwise the result of `math_ops.reduce_prod`.

  Raises:
    ValueError: If the static rank of `x` is known and is neither 0 nor 1.
  """
  static_value = tensor_util.constant_value(x)
  if static_value is not None:
    prod = np.prod(static_value, dtype=x.dtype.as_numpy_dtype())
  else:
    prod = math_ops.reduce_prod(x)

  static_rank = x.get_shape().ndims  # != sample_ndims
  if static_rank is None:
    # Rank unknown until runtime: choose between `[1]` and the actual shape
    # depending on whether `x` turns out to be a scalar.
    dynamic_rank = array_ops.rank(x)
    is_scalar = math_ops.equal(dynamic_rank, 0)
    expanded_shape = util.pick_vector(
        is_scalar, np.array([1], dtype=np.int32), array_ops.shape(x))
    x = array_ops.reshape(x, expanded_shape)
  elif static_rank == 0:
    # Known scalar: always becomes a length-1 vector.
    if static_value is None:
      x = array_ops.reshape(x, [1])
    else:
      wrapped = np.array([static_value], dtype=x.dtype.as_numpy_dtype())
      x = ops.convert_to_tensor(wrapped, name=name)
  elif static_rank != 1:
    raise ValueError("Input is neither scalar nor vector.")

  return x, prod

1260 

def _set_sample_static_shape(self, x, sample_shape):
  """Helper to `sample`; sets static shape info.

  Merges whatever is statically known about the sample, batch and event
  shapes into the static shape of `x`, which is treated as laid out as
  `sample dims + batch dims + event dims`.

  Args:
    x: Object exposing `get_shape()` / `set_shape()`; presumably the `Tensor`
      of samples produced by `sample` (the caller is not visible here).
    sample_shape: Requested sample shape; only the value obtained from
      `tensor_util.constant_value` is used.

  Returns:
    `x`, after its static shape has been refined through `set_shape`.
  """
  # Set shape hints.
  sample_shape = tensor_shape.TensorShape(
      tensor_util.constant_value(sample_shape))

  ndims = x.get_shape().ndims
  sample_ndims = sample_shape.ndims
  batch_ndims = self.batch_shape.ndims
  event_ndims = self.event_shape.ndims

  # Infer rank(x).
  # Only possible when all three component ranks are known.
  if (ndims is None and
      sample_ndims is not None and
      batch_ndims is not None and
      event_ndims is not None):
    ndims = sample_ndims + batch_ndims + event_ndims
    x.set_shape([None] * ndims)

  # Infer sample shape.
  # The sample dims lead; the remaining dims are left unknown.
  if ndims is not None and sample_ndims is not None:
    shape = sample_shape.concatenate([None]*(ndims - sample_ndims))
    x.set_shape(x.get_shape().merge_with(shape))

  # Infer event shape.
  # The event dims trail; the leading dims are left unknown.
  if ndims is not None and event_ndims is not None:
    shape = tensor_shape.TensorShape(
        [None]*(ndims - event_ndims)).concatenate(self.event_shape)
    x.set_shape(x.get_shape().merge_with(shape))

  # Infer batch shape.
  if batch_ndims is not None:
    if ndims is not None:
      # Recover whichever of the sample/event ranks is missing from the
      # total rank and the other two.
      if sample_ndims is None and event_ndims is not None:
        sample_ndims = ndims - batch_ndims - event_ndims
      elif event_ndims is None and sample_ndims is not None:
        event_ndims = ndims - batch_ndims - sample_ndims
    if sample_ndims is not None and event_ndims is not None:
      # The batch dims sit between the sample dims and the event dims.
      shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate(
          self.batch_shape).concatenate([None]*event_ndims)
      x.set_shape(x.get_shape().merge_with(shape))

  return x

1304 

1305 def _is_scalar_helper(self, static_shape, dynamic_shape_fn): 

1306 """Implementation for `is_scalar_batch` and `is_scalar_event`.""" 

1307 if static_shape.ndims is not None: 

1308 return static_shape.ndims == 0 

1309 shape = dynamic_shape_fn() 

1310 if (shape.get_shape().ndims is not None and 

1311 shape.get_shape().dims[0].value is not None): 

1312 # If the static_shape is correctly written then we should never execute 

1313 # this branch. We keep it just in case there's some unimagined corner 

1314 # case. 

1315 return shape.get_shape().as_list() == [0] 

1316 return math_ops.equal(array_ops.shape(shape)[0], 0)