Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/stateful_random_ops.py: 31%
307 statements
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for generating random numbers."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.stateless_random_ops import Algorithm
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# A seed for random ops (stateful and stateless) will always be 1024
# bits, all of which will be sent to the C++ code. The actual C++
# implementation of some algorithms may only use a lower part of the bits.

UINT64_HALF_SPAN = 2**63
MAX_INT64 = UINT64_HALF_SPAN - 1
MIN_INT64 = -UINT64_HALF_SPAN
UINT64_SPAN = UINT64_HALF_SPAN * 2
# 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained in
# b/111604096 and cl/171681867), so I use signed int here. I choose int64
# instead of int32 here because `VarHandleOp` doesn't support int32 on GPU.
SEED_TYPE = "int64"
SEED_MIN = MIN_INT64
SEED_MAX = MAX_INT64
SEED_UINT_SPAN = UINT64_SPAN
SEED_TYPE_BITS = 64
SEED_BIT_MASK = 0xFFFFFFFFFFFFFFFF
SEED_SIZE = 16  # in units of SEED_TYPE

STATE_TYPE = SEED_TYPE
ALGORITHM_TYPE = STATE_TYPE

# The following sizes are all in units of uint64.
PHILOX_KEY_SIZE = 1
THREEFRY_KEY_SIZE = 1
PHILOX_COUNTER_SIZE = 2
THREEFRY_COUNTER_SIZE = 1
PHILOX_STATE_SIZE = PHILOX_COUNTER_SIZE + PHILOX_KEY_SIZE
THREEFRY_STATE_SIZE = THREEFRY_COUNTER_SIZE + THREEFRY_KEY_SIZE

RNG_ALG_PHILOX = Algorithm.PHILOX.value
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value

DEFAULT_ALGORITHM = RNG_ALG_PHILOX

def non_deterministic_ints(shape, dtype=dtypes.int64):
  """Non-deterministically generates some integers.

  This op may use some OS-provided source of non-determinism (e.g. an RNG), so
  each execution will give different results.

  Args:
    shape: the shape of the result.
    dtype: (optional) the dtype of the result.

  Returns:
    a tensor whose element values are non-deterministically chosen.
  """
  return gen_stateful_random_ops.non_deterministic_ints(
      shape=shape, dtype=dtype)

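# Illustrative usage (a non-normative sketch; the values are platform- and
# time-dependent, so each run gives different results):
#   >>> ints = non_deterministic_ints(shape=[2], dtype=dtypes.int64)
#   >>> ints.shape
#   TensorShape([2])

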
def _uint_to_int(n):
  if isinstance(n, int) and n > SEED_MAX:
    n = n - SEED_UINT_SPAN
  return n

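# Example (illustrative): since the state variable is int64, uint64 values
# above MAX_INT64 are mapped to their two's-complement counterparts:
#   >>> _uint_to_int(2**64 - 1)
#   -1
#   >>> _uint_to_int(1234)
#   1234

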
def _make_1d_state(state_size, seed):
  """Makes a 1-D RNG state.

  Args:
    state_size: an integer.
    seed: an integer or 1-D tensor.

  Returns:
    a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
  """
  if isinstance(seed, int):
    # chop the Python integer (infinite precision) into chunks of SEED_TYPE
    ls = []
    for _ in range(state_size):
      ls.append(seed & SEED_BIT_MASK)
      seed >>= SEED_TYPE_BITS
    seed = ls
  # to avoid overflow error from ops.convert_to_tensor
  seed = nest.map_structure(_uint_to_int, seed)
  seed = math_ops.cast(seed, STATE_TYPE)
  seed = array_ops.reshape(seed, [-1])
  seed = seed[0:state_size]
  # Padding with zeros on the *left* if too short. Padding on the right would
  # cause a small seed to be used as the "counter" while the "key" is always
  # zero (for counter-based RNG algorithms), because in the current memory
  # layout the counter is stored before the key. In such a situation two RNGs
  # with two different small seeds may generate overlapping outputs.
  seed_size = seed.shape[0]
  if seed_size is None:
    seed_size = array_ops.shape(seed)[0]
  padding_size = math_ops.maximum(state_size - seed_size, 0)
  padding = array_ops.zeros([padding_size], seed.dtype)
  # can't use `pad` because it doesn't support integer dtypes on GPU
  seed = array_ops.concat([padding, seed], axis=0)
  seed.set_shape([state_size])
  return seed

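# A non-normative sketch of `_make_1d_state`'s behavior:
#   >>> _make_1d_state(3, 1234)      # int seed, chopped into 64-bit words
#   <tf.Tensor: ... numpy=array([1234, 0, 0])>
#   >>> _make_1d_state(3, [12, 34])  # short vector, zero-padded on the left
#   <tf.Tensor: ... numpy=array([0, 12, 34])>

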
def _get_counter_size(alg):
  if alg == Algorithm.PHILOX.value:
    return PHILOX_COUNTER_SIZE
  elif alg == Algorithm.THREEFRY.value:
    return THREEFRY_COUNTER_SIZE
  elif alg == Algorithm.AUTO_SELECT.value:
    # For AUTO_SELECT, we'll manage the counter as if it's for Philox.
    return PHILOX_COUNTER_SIZE
  else:
    raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))


def _get_state_size(alg):
  if alg == Algorithm.PHILOX.value:
    return PHILOX_STATE_SIZE
  elif alg == Algorithm.THREEFRY.value:
    return THREEFRY_STATE_SIZE
  elif alg == Algorithm.AUTO_SELECT.value:
    # For AUTO_SELECT, we'll manage the state as if it's for Philox.
    return PHILOX_STATE_SIZE
  else:
    raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))


def _check_state_shape(shape, alg):
  if isinstance(alg, ops.Tensor) and not context.executing_eagerly():
    return
  shape.assert_is_compatible_with([_get_state_size(int(alg))])


def _make_state_from_seed(seed, alg):
  return _make_1d_state(_get_state_size(alg), seed)

@tf_export("random.create_rng_state", "random.experimental.create_rng_state")
def create_rng_state(seed, alg):
  """Creates a RNG state from an integer or a vector.

  Example:

  >>> tf.random.create_rng_state(
  ...     1234, "philox")
  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1234, 0, 0])>
  >>> tf.random.create_rng_state(
  ...     [12, 34], "threefry")
  <tf.Tensor: shape=(2,), dtype=int64, numpy=array([12, 34])>

  Args:
    seed: an integer or 1-D numpy array.
    alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.

  Returns:
    a 1-D numpy array whose size depends on the algorithm.
  """
  alg = stateless_random_ops.convert_alg_to_int(alg)
  return _make_state_from_seed(seed, alg)

def _shape_tensor(shape):
  """Convert to an int32 or int64 tensor, defaulting to int64 if empty."""
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int64
  else:
    dtype = None
  return ops.convert_to_tensor(shape, dtype=dtype, name="shape")


def _convert_to_state_tensor(t):
  # to avoid out-of-range error from ops.convert_to_tensor
  t = nest.map_structure(_uint_to_int, t)
  return math_ops.cast(t, STATE_TYPE)


def get_replica_id():
  rctx = distribute_lib.get_replica_context()
  if rctx is None:
    return None
  return rctx.replica_id_in_sync_group

@tf_export("random.Generator", "random.experimental.Generator")
class Generator(autotrackable.AutoTrackable):
  """Random-number generator.

  Example:

  Creating a generator from a seed:

  >>> g = tf.random.Generator.from_seed(1234)
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[ 0.9356609 ,  1.0854305 , -0.93788373],
         [-0.5061547 ,  1.3169702 ,  0.7137579 ]], dtype=float32)>

  Creating a generator from a non-deterministic state:

  >>> g = tf.random.Generator.from_non_deterministic_state()
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>

  All the constructors allow explicitly choosing a Random-Number-Generation
  (RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
  example:

  >>> g = tf.random.Generator.from_seed(123, alg="philox")
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[ 0.8673864 , -0.29899067, -0.9310337 ],
         [-1.5828488 ,  1.2481191 , -0.6770643 ]], dtype=float32)>

  CPU, GPU and TPU with the same algorithm and seed will generate the same
  integer random numbers. Floating-point results (such as the output of
  `normal`) may have small numerical discrepancies between different devices.

  This class uses a `tf.Variable` to manage its internal state. Every time
  random numbers are generated, the state of the generator will change. For
  example:

  >>> g = tf.random.Generator.from_seed(1234)
  >>> g.state
  <tf.Variable ... numpy=array([1234, 0, 0])>
  >>> g.normal(shape=(2, 3))
  <...>
  >>> g.state
  <tf.Variable ... numpy=array([2770, 0, 0])>

  The shape of the state is algorithm-specific.

  There is also a global generator:

  >>> g = tf.random.get_global_generator()
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>

  When creating a generator inside a `tf.distribute.Strategy` scope, each
  replica will get a different stream of random numbers.

  For example, in this code:

  ```
  strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
  with strat.scope():
    g = tf.random.Generator.from_seed(1)
    def f():
      return g.normal([])
    results = strat.run(f).values
  ```

  `results[0]` and `results[1]` will have different values.

  If the generator is seeded (e.g. created via `Generator.from_seed`), the
  random numbers will be determined by the seed, even though different
  replicas get different numbers. One can think of a random number generated
  on a replica as a hash of the replica ID and a "master" random number that
  may be common to all replicas. Hence, the whole system is still
  deterministic.

  (Note that the random numbers on different replicas are not correlated,
  even if they are deterministically determined by the same seed. They are
  not correlated in the sense that no matter what statistics one calculates
  on them, there won't be any discernible correlation.)

  Generators can be freely saved and restored using `tf.train.Checkpoint`.
  The checkpoint can be restored in a distribution strategy with a different
  number of replicas than the original strategy. If a replica ID is present
  in both the original and the new distribution strategy, its state will be
  properly restored (i.e. the random-number stream from the restored point
  will be the same as that from the saving point), unless the replicas had
  already diverged in their RNG call traces before saving (e.g. one replica
  had made one RNG call while another had made two RNG calls). There is no
  such guarantee if the generator is saved in a strategy scope and restored
  outside of any strategy scope, or vice versa.

  When a generator is created within the scope of
  `tf.distribute.experimental.ParameterServerStrategy`, the workers
  will share the generator's state (placed on one of the parameter
  servers). In this way the workers will still get different
  random-number streams, as stated above. (This is similar to replicas
  in a `tf.distribute.MirroredStrategy` sequentially accessing a
  generator created outside the strategy.) Each RNG call on a worker
  will incur a round-trip to a parameter server, which may have
  performance impacts. When creating a
  `tf.distribute.experimental.ParameterServerStrategy`, please make
  sure that the `variable_partitioner` argument won't shard small
  variables of shape `[2]` or `[3]` (because generator states must not
  be sharded). Ways to avoid sharding small variables include setting
  `variable_partitioner` to `None` or to
  `tf.distribute.experimental.partitioners.MinSizePartitioner` with a
  large enough `min_shard_bytes` (see
  `tf.distribute.experimental.ParameterServerStrategy`'s documentation
  for more details).
  """

  @classmethod
  def from_state(cls, state, alg):
    """Creates a generator from a state.

    See `__init__` for description of `state` and `alg`.

    Args:
      state: the new state.
      alg: the RNG algorithm.

    Returns:
      The new generator.
    """
    return cls(alg=alg, state=state)

  @classmethod
  def from_seed(cls, seed, alg=None):
    """Creates a generator from a seed.

    A seed is a 1024-bit unsigned integer represented either as a Python
    integer or a vector of integers. Seeds shorter than 1024 bits will be
    padded. The padding, the internal structure of a seed and the way a seed
    is converted to a state are all opaque (unspecified). The only specified
    semantics of seeds is that two different seeds are likely to produce two
    independent generators (but there is no guarantee).

    Args:
      seed: the seed for the RNG.
      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
        See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if alg is None:
      # TODO(b/170668986): more sophisticated algorithm selection
      alg = DEFAULT_ALGORITHM
    alg = stateless_random_ops.convert_alg_to_int(alg)
    state = create_rng_state(seed, alg)
    return cls(state=state, alg=alg)

  @classmethod
  def from_non_deterministic_state(cls, alg=None):
    """Creates a generator by non-deterministically initializing its state.

    The source of the non-determinism will be platform- and time-dependent.

    Args:
      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
        See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if config.is_op_determinism_enabled():
      raise RuntimeError('"from_non_deterministic_state" cannot be called when '  # pylint: disable=g-doc-exception
                         "determinism is enabled.")
    if alg is None:
      # TODO(b/170668986): more sophisticated algorithm selection
      alg = DEFAULT_ALGORITHM
    alg = stateless_random_ops.convert_alg_to_int(alg)
    state = non_deterministic_ints(shape=[_get_state_size(alg)],
                                   dtype=SEED_TYPE)
    return cls(state=state, alg=alg)

  @classmethod
  def from_key_counter(cls, key, counter, alg):
    """Creates a generator from a key and a counter.

    This constructor only applies if the algorithm is a counter-based
    algorithm. See method `key` for the meaning of "key" and "counter".

    Args:
      key: the key for the RNG, a scalar of type STATE_TYPE.
      counter: a vector of dtype STATE_TYPE representing the initial counter
        for the RNG, whose length is algorithm-specific.
      alg: the RNG algorithm. If None, it will be auto-selected. See
        `__init__` for its possible values.

    Returns:
      The new generator.
    """
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    alg = stateless_random_ops.convert_alg_to_int(alg)
    counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
    key.shape.assert_is_compatible_with([])
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    return cls(state=state, alg=alg)

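  # Illustrative round trip (a sketch, not part of the documented API): a
  # Philox generator's state can be rebuilt from its key and counter, e.g.:
  #   >>> g1 = tf.random.Generator.from_seed(1234)
  #   >>> g2 = tf.random.Generator.from_key_counter(
  #   ...     key=g1.key, counter=g1.state[:-1], alg=g1.algorithm)
  #   >>> # g2 now produces the same stream as g1 from this point on.
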
  def __init__(self, copy_from=None, state=None, alg=None):
    """Creates a generator.

    The new generator will be initialized by one of the following ways, with
    decreasing precedence:
    (1) If `copy_from` is not None, the new generator is initialized by
        copying information from another generator.
    (2) If `state` and `alg` are not None (they must be set together), the
        new generator is initialized by a state.

    Args:
      copy_from: a generator to be copied from.
      state: a vector of dtype STATE_TYPE representing the initial state of
        the RNG, whose length and semantics are algorithm-specific. If it's a
        variable, the generator will reuse it instead of creating a new
        variable.
      alg: the RNG algorithm. Possible values are
        `tf.random.Algorithm.PHILOX` for the Philox algorithm and
        `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
        (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
        [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
        The string names `"philox"` and `"threefry"` can also be used.
        Note `PHILOX` guarantees the same numbers are produced (given
        the same random state) across all architectures (CPU, GPU, XLA etc).
    """
    # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
    if distribute_lib.has_strategy():
      self._distribution_strategy = distribute_lib.get_strategy()
    else:
      self._distribution_strategy = None
    if copy_from is not None:
      # All other arguments should be None
      assert (alg or state) is None
      self._state_var = self._create_variable(copy_from.state,
                                              dtype=STATE_TYPE,
                                              trainable=False)
      self._alg = copy_from.algorithm
    else:
      assert alg is not None and state is not None
      alg = stateless_random_ops.convert_alg_to_int(alg)
      if isinstance(state, variables.Variable):
        _check_state_shape(state.shape, alg)
        self._state_var = state
      else:
        state = _convert_to_state_tensor(state)
        _check_state_shape(state.shape, alg)
        self._state_var = self._create_variable(state, dtype=STATE_TYPE,
                                                trainable=False)
      self._alg = alg

  def _create_variable(self, *args, **kwargs):
    """Creates a variable.

    Args:
      *args: positional arguments passed along to `variables.Variable`.
      **kwargs: keyword arguments passed along to `variables.Variable`.

    Returns:
      The created variable.
    """
    with ops.name_scope("random_generator"):
      # Make sure we don't change this name since Keras was using this name
      # to filter out the state variable.
      kwargs["name"] = "StateVar"
      v = variables.Variable(*args, **kwargs)
    if isinstance(v, sharded_variable.ShardedVariable):
      # RNG state is an atomic entity representing a 128-bit or
      # 192-bit value, so it mustn't be sharded.
      raise ValueError(
          "tf.random.Generator state is sharded, which is not allowed. When "
          "creating a tf.distribute.experimental.ParameterServerStrategy, "
          "please make sure that the `variable_partitioner` "
          "argument won't shard a "
          "small variable of shape [2] or [3]. Ways to avoid sharding small "
          "variables include setting `variable_partitioner` to None or to "
          "tf.distribute.experimental.partitioners.MinSizePartitioner with a "
          "large enough `min_shard_bytes`.")
    return v

  def reset(self, state):
    """Resets the generator by a new state.

    See `__init__` for the meaning of "state".

    Args:
      state: the new state.
    """
    state = _convert_to_state_tensor(state)
    state.shape.assert_is_compatible_with([_get_state_size(self.algorithm)])
    self._state_var.assign(state)

  def reset_from_seed(self, seed):
    """Resets the generator by a new seed.

    See `from_seed` for the meaning of "seed".

    Args:
      seed: the new seed.
    """
    state = create_rng_state(seed, self.algorithm)
    self._state_var.assign(state)

  def reset_from_key_counter(self, key, counter):
    """Resets the generator by a new key-counter pair.

    See `from_key_counter` for the meaning of "key" and "counter".

    Args:
      key: the new key.
      counter: the new counter.
    """
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    counter.shape.assert_is_compatible_with(
        [_get_state_size(self.algorithm) - 1])
    key.shape.assert_is_compatible_with([])
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    self._state_var.assign(state)

  @property
  def state(self):
    """The internal state of the RNG."""
    return self._state_var

  @property
  def algorithm(self):
    """The RNG algorithm id (a Python integer or scalar integer Tensor)."""
    return self._alg

  def _standard_normal(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_normal_v2(
        shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)

  @property
  def key(self):
    """The 'key' part of the state of a counter-based RNG.

    For a counter-based RNG algorithm such as Philox and ThreeFry (as
    described in paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
    [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]),
    the RNG state consists of two parts: counter and key. The output is
    generated via the formula: output=hash(key, counter), i.e. a hashing of
    the counter parametrized by the key. Two RNGs with two different keys can
    be thought of as generating two independent random-number streams (a
    stream is formed by increasing the counter).

    Returns:
      A scalar which is the 'key' part of the state, if the RNG algorithm is
      counter-based; otherwise it raises a ValueError.
    """
    alg = self.algorithm
    if alg in (a.value for a in Algorithm):
      return self._state_var[-1]
    else:
      raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))

  def _skip_single_var(self, var, delta):
    resource_variable_ops.variable_accessed(var)
    # TODO(wangpeng): Cache the cast algorithm instead of casting every time.
    return gen_stateful_random_ops.rng_read_and_skip(
        var.handle,
        alg=math_ops.cast(self.algorithm, dtypes.int32),
        delta=math_ops.cast(delta, dtypes.uint64))

  def skip(self, delta):
    """Advances the counter of a counter-based RNG.

    Args:
      delta: the amount of advancement. The state of the RNG after
        `skip(n)` will be the same as that after `normal([n])`
        (or any other distribution). The actual increment added to the
        counter is an unspecified implementation detail.

    Returns:
      A `Tensor` of type `int64`.
    """

    def update_fn(v):
      return self._skip_single_var(v, delta)

    # TODO(b/170515001): Always call strategy.extended.update after calling it
    # from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
      # Assumes replica context with replica_id=0, since we only save the first
      # replica.
      return update_fn(self.state)
    if self._distribution_strategy is not None:
      with distribute_lib.enter_or_assert_strategy(self._distribution_strategy):
        if distribute_lib.in_cross_replica_context():
          # Code that operates on all replicas of a variable cannot be saved
          # without retracing.
          values_util.mark_as_unsaveable()
        if (distribute_lib.in_cross_replica_context() or
            "CentralStorage" in type(self._distribution_strategy).__name__):
          # In cross-replica context we need to use strategy.extended.update.
          # In CentralStorageStrategy we also need to use
          # strategy.extended.update (even for replica context),
          # because variable updates here must be within merge_call.
          return distribute_lib.get_strategy().extended.update(
              self.state, update_fn)
    return update_fn(self.state)

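  # Illustrative property of `skip` (a sketch): advancing the counter by `n`
  # leaves the generator in the same state as drawing `n` samples, e.g.:
  #   >>> g1 = tf.random.Generator.from_seed(1)
  #   >>> g2 = tf.random.Generator.from_seed(1)
  #   >>> _ = g1.normal([4])
  #   >>> _ = g2.skip(4)
  #   >>> # g1.state and g2.state are now equal.
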
  def _preprocess_key(self, key):
    if self._distribution_strategy is None:
      return key
    with distribute_lib.enter_or_assert_strategy(self._distribution_strategy):
      replica_id = get_replica_id()
      if replica_id is not None:
        replica_id = array_ops_stack.stack([replica_id, 0], axis=0)
        replica_id = math_ops.cast(replica_id, dtypes.uint64)
        # Conceptually: key = hash(key, replica_id)
        key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
            shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
            alg=self.algorithm)
      return key

  def _prepare_key_counter(self, shape):
    delta = math_ops.reduce_prod(shape)
    counter_key = self.skip(delta)
    counter_size = _get_counter_size(self.algorithm)
    counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
    key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
                            dtypes.uint64)
    key = self._preprocess_key(key)
    return key, counter

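  # Shape note (an illustrative aside, not normative): for Philox, `skip`
  # returns a length-3 int64 tensor laid out as [counter0, counter1, key];
  # `_prepare_key_counter` bitcasts the first `counter_size` elements to a
  # uint64 counter and the remaining element to a uint64 key, so
  # counter.shape == [2] and key.shape == [1].
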
  # The following functions return a tensor and as a side effect update
  # self._state_var.
  def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
             name=None):
    """Outputs random values from a normal distribution.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random normal values.
    """
    with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
      shape = _shape_tensor(shape)
      mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._standard_normal(shape, dtype=dtype)
      return math_ops.add(rnd * stddev, mean, name=name)

  def _truncated_normal(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
        shape=shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)

  def truncated_normal(self, shape,
                       mean=0.0,
                       stddev=1.0,
                       dtype=dtypes.float32,
                       name=None):
    """Outputs random values from a truncated normal distribution.

    The generated values follow a normal distribution with specified mean and
    standard deviation, except that values whose magnitude is more than
    2 standard deviations from the mean are dropped and re-picked.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        truncated normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution, before truncation.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random truncated normal
      values.
    """
    with ops.name_scope(
        name, "truncated_normal", [shape, mean, stddev]) as name:
      shape_tensor = _shape_tensor(shape)
      mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._truncated_normal(shape_tensor, dtype=dtype)
      mul = rnd * stddev_tensor
      return math_ops.add(mul, mean_tensor, name=name)

  def _uniform(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
        shape=shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)

  def _uniform_full_int(self, shape, dtype, name=None):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
        shape=shape,
        key=key,
        counter=counter,
        dtype=dtype,
        alg=self.algorithm,
        name=name)

  def uniform(self, shape, minval=0, maxval=None,
              dtype=dtypes.float32, name=None):
    """Outputs random values from a uniform distribution.

    The generated values follow a uniform distribution in the range
    `[minval, maxval)`. The lower bound `minval` is included in the range,
    while the upper bound `maxval` is excluded. (For floating-point numbers,
    especially low-precision types like bfloat16, the result may occasionally
    include `maxval` because of rounding.)

    For floats, the default range is `[0, 1)`. For ints, at least `maxval`
    must be specified explicitly.

    In the integer case, the random integers are slightly biased unless
    `maxval - minval` is an exact power of two. The bias is small for values
    of `maxval - minval` significantly smaller than the range of the output
    (either `2**32` or `2**64`).

    For full-range random integers, pass `minval=None` and `maxval=None` with
    an integer `dtype` (for integer dtypes, `minval` and `maxval` must be
    both `None` or both not `None`).

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      minval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it
        needs to be a scalar). The lower bound (included) on the range of
        random values to generate. Pass `None` for full-range integers.
        Defaults to 0.
      maxval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it
        needs to be a scalar). The upper bound (excluded) on the range of
        random values to generate. Pass `None` for full-range integers.
        Defaults to 1 if `dtype` is floating point.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random uniform values.

    Raises:
      ValueError: If `dtype` is integral and `maxval` is not specified.
    """
    dtype = dtypes.as_dtype(dtype)
    if dtype.is_integer:
      if (minval is None) != (maxval is None):
        raise ValueError("For integer dtype {}, minval and maxval must be "
                         "both `None` or both non-`None`; got minval={} and "
                         "maxval={}".format(dtype, minval, maxval))
    elif maxval is None:
      maxval = 1
    with ops.name_scope(name, "stateful_uniform",
                        [shape, minval, maxval]) as name:
      shape = _shape_tensor(shape)
      if dtype.is_integer and minval is None:
        return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        key, counter = self._prepare_key_counter(shape)
        return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
            shape=shape,
            key=key,
            counter=counter,
            minval=minval,
            maxval=maxval,
            alg=self.algorithm,
            name=name)
      else:
        rnd = self._uniform(shape=shape, dtype=dtype)
        return math_ops.add(rnd * (maxval - minval), minval, name=name)

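  # Illustrative usage (a sketch): float, bounded-integer and full-range
  # integer sampling via `uniform`:
  #   >>> g = tf.random.Generator.from_seed(1234)
  #   >>> g.uniform(shape=[2])                          # floats in [0, 1)
  #   >>> g.uniform([2], minval=0, maxval=10, dtype=tf.int32)
  #   >>> g.uniform([2], minval=None, maxval=None, dtype=tf.int64)
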
  def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
    """Uniform distribution on an integer type's entire range.

    This method is the same as setting `minval` and `maxval` to `None` in the
    `uniform` method.

    Args:
      shape: the shape of the output.
      dtype: (optional) the integer type, defaulting to uint64.
      name: (optional) the name of the node.

    Returns:
      A tensor of random numbers of the required shape.
    """
    dtype = dtypes.as_dtype(dtype)
    with ops.name_scope(name, "stateful_uniform_full_int",
                        [shape]) as name:
      shape = _shape_tensor(shape)
      return self._uniform_full_int(shape=shape, dtype=dtype, name=name)

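  # Equivalent formulations (illustrative): the following two calls sample
  # from the same distribution, the full range of the integer type:
  #   >>> g = tf.random.Generator.from_seed(5)
  #   >>> a = g.uniform_full_int(shape=[3], dtype=tf.uint64)
  #   >>> b = g.uniform([3], minval=None, maxval=None, dtype=tf.uint64)
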
  def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
    """Outputs random values from a binomial distribution.

    The generated values follow a binomial distribution with specified count
    and probability of success parameters.

    Example:

    ```python
    counts = [10., 20.]
    # Probability of success.
    probs = [0.8]

    rng = tf.random.Generator.from_seed(seed=234)
    binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)

    counts = ...  # Shape [3, 1, 2]
    probs = ...  # Shape [1, 4, 2]
    shape = [3, 4, 3, 4, 2]
    rng = tf.random.Generator.from_seed(seed=1717)
    # Sample shape will be [3, 4, 3, 4, 2]
    binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs)
    ```

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      counts: Tensor. The counts of the binomial distribution. Must be
        broadcastable with `probs`, and broadcastable with the rightmost
        dimensions of `shape`.
      probs: Tensor. The probability of success for the
        binomial distribution. Must be broadcastable with `counts` and
        broadcastable with the rightmost dimensions of `shape`.
      dtype: The type of the output. Default: tf.int32
      name: A name for the operation (optional).

    Returns:
      samples: A Tensor of the specified shape filled with random binomial
        values. For each i, each samples[i, ...] is an independent draw from
        the binomial distribution on counts[i] trials with probability of
        success probs[i].
    """
    dtype = dtypes.as_dtype(dtype)
    with ops.name_scope(name, "binomial", [shape, counts, probs]) as name:
      counts = ops.convert_to_tensor(counts, name="counts")
      probs = ops.convert_to_tensor(probs, name="probs")
      shape_tensor = _shape_tensor(shape)
      return gen_stateful_random_ops.stateful_random_binomial(
          self.state.handle,
          self.algorithm,
          shape=shape_tensor,
          counts=counts,
          probs=probs,
          dtype=dtype,
          name=name)

  # TODO(wangpeng): implement other distributions

  def _make_int64_keys(self, shape=()):
    # New independent keys are generated via
    # `new_key[i] = hash(old_key, counter+i)`, which is exactly what
    # `uniform_full_int(dtype=int64)` does for PhiloxRandom_64_128_128 and
    # ThreeFry_64_64_64.
    return self.uniform_full_int(shape=shape, dtype=dtypes.int64)

  def make_seeds(self, count=1):
    """Generates seeds for stateless random ops.

    For example:

    ```python
    seeds = get_global_generator().make_seeds(count=10)
    for i in range(10):
      seed = seeds[:, i]
      numbers = stateless_random_normal(shape=[2, 3], seed=seed)
      ...
    ```

    Args:
      count: the number of seed pairs (note that stateless random ops need a
        pair of seeds to invoke).

    Returns:
      A tensor of shape [2, count] and dtype int64.
    """
    alg = self.algorithm
    if alg in (a.value for a in Algorithm):
      keys = self._make_int64_keys(shape=[count])
      # The two seeds for stateless random ops don't have individual semantics
      # and are scrambled together, so setting one to zero is fine.
      zeros = array_ops.zeros_like(keys)
      return array_ops_stack.stack([keys, zeros])
    else:
      raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))

  def split(self, count=1):
    """Returns a list of independent `Generator` objects.

    Two generators are independent of each other in the sense that the
    random-number streams they generate don't have statistically detectable
    correlations. The new generators are also independent of the old one.
    The old generator's state will be changed (like other random-number
    generating methods), so two calls of `split` will return different
    new generators.

    For example:

    ```python
    gens = get_global_generator().split(count=10)
    for gen in gens:
      numbers = gen.normal(shape=[2, 3])
      # ...
    gens2 = get_global_generator().split(count=10)
    # gens2 will be different from gens
    ```

    The new generators will be put on the current device (possibly different
    from the old generator's), for example:

    ```python
    with tf.device("/device:CPU:0"):
      gen = Generator.from_seed(1234)  # gen is on CPU
    with tf.device("/device:GPU:0"):
      gens = gen.split(count=10)  # gens are on GPU
    ```

    Args:
      count: the number of generators to return.

    Returns:
      A list (length `count`) of `Generator` objects independent of each
      other. The new generators have the same RNG algorithm as the old one.
    """
    def _key_to_state(alg, key):
      # Padding with zeros on the left. The zeros will be the counter.
      return [0] * (_get_state_size(alg) - 1) + [key]

    alg = self.algorithm
    if alg in (a.value for a in Algorithm):
      keys = self._make_int64_keys(shape=[count])
      return [Generator(state=_key_to_state(alg, key), alg=alg)
              for key in array_ops_stack.unstack(keys, num=count)]
    else:
      raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))


# It's not safe to create TF ops before `init_google` is called, so this is
# initialized to None and gets a value the first time `get_global_generator`
# is called.
global_generator = None


@tf_export("random.get_global_generator",
           "random.experimental.get_global_generator")
def get_global_generator():
  """Retrieves the global generator.

  This function will create the global generator the first time it is called,
  and the generator will be placed on the default device at that time, so one
  needs to be careful when this function is first called. Using a generator
  placed on a less-ideal device will incur a performance regression.

  Returns:
    The global `tf.random.Generator` object.
  """
  global global_generator
  if global_generator is None:
    if config.is_op_determinism_enabled():
      raise RuntimeError('"get_global_generator" cannot be called if '  # pylint: disable=g-doc-exception
                         "determinism is enabled, unless "
                         '"set_global_generator" has already been called. '
                         'Please call "set_global_generator" first.')
    with ops.init_scope():
      global_generator = Generator.from_non_deterministic_state()
  return global_generator


@tf_export("random.set_global_generator",
           "random.experimental.set_global_generator")
def set_global_generator(generator):
  """Replaces the global generator with another `Generator` object.

  This function replaces the global generator with the provided `generator`
  object.
  A random number generator utilizes a `tf.Variable` object to store its
  state. Be aware of the following caveats in how `set_global_generator`
  interacts with `tf.function`:

  - `tf.function` puts restrictions on Variable creation, so one cannot
    freely create a new random generator instance inside `tf.function`.
    To call `set_global_generator` inside `tf.function`, the generator
    instance must have already been created eagerly.
  - `tf.function` captures the Variable during trace-compilation, so a
    compiled `tf.function` will not be affected by `set_global_generator`,
    as demonstrated by
    random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun.

  For most use cases, avoid calling `set_global_generator` after program
  initialization, and prefer to reset the state of the existing global
  generator instead, for example:

  >>> rng = tf.random.get_global_generator()
  >>> rng.reset_from_seed(30)

  Args:
    generator: the new `Generator` object.
  """
  global global_generator
  global_generator = generator