# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for generating random numbers."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.stateless_random_ops import Algorithm
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# A seed for random ops (stateful and stateless) will always be 1024
# bits, all of which will be sent to the C++ code. The actual C++
# implementation of some algorithms may only use a lower part of the bits.

UINT64_HALF_SPAN = 2**63
MAX_INT64 = UINT64_HALF_SPAN - 1
MIN_INT64 = -UINT64_HALF_SPAN
UINT64_SPAN = UINT64_HALF_SPAN * 2
# 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained in
# b/111604096 and cl/171681867), so I use signed int here. I choose int64
# instead of int32 here because `VarHandleOp` doesn't support int32 on GPU.
SEED_TYPE = "int64"
SEED_MIN = MIN_INT64
SEED_MAX = MAX_INT64
SEED_UINT_SPAN = UINT64_SPAN
SEED_TYPE_BITS = 64
SEED_BIT_MASK = 0xFFFFFFFFFFFFFFFF
SEED_SIZE = 16  # in units of SEED_TYPE


STATE_TYPE = SEED_TYPE
ALGORITHM_TYPE = STATE_TYPE


# The following sizes are all in units of uint64.
PHILOX_KEY_SIZE = 1
THREEFRY_KEY_SIZE = 1
PHILOX_COUNTER_SIZE = 2
THREEFRY_COUNTER_SIZE = 1
PHILOX_STATE_SIZE = PHILOX_COUNTER_SIZE + PHILOX_KEY_SIZE
THREEFRY_STATE_SIZE = THREEFRY_COUNTER_SIZE + THREEFRY_KEY_SIZE


RNG_ALG_PHILOX = Algorithm.PHILOX.value
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value


DEFAULT_ALGORITHM = RNG_ALG_PHILOX


def non_deterministic_ints(shape, dtype=dtypes.int64):
  """Non-deterministically generates some integers.

  This op may use some OS-provided source of non-determinism (e.g. an RNG), so
  each execution will give different results.

  Args:
    shape: the shape of the result.
    dtype: (optional) the dtype of the result.

  Returns:
    a tensor whose element values are non-deterministically chosen.
  """
  return gen_stateful_random_ops.non_deterministic_ints(
      shape=shape, dtype=dtype)


def _uint_to_int(n):
  if isinstance(n, int) and n > SEED_MAX:
    n = n - SEED_UINT_SPAN
  return n
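
# Illustrative sketch (editor's addition, not part of the original module):
# `_uint_to_int` reinterprets an out-of-range Python integer as a signed
# 64-bit value, so that `ops.convert_to_tensor` won't overflow later:
#
#   _uint_to_int(0xFFFFFFFFFFFFFFFF)  # -> -1 (2**64 - 1 wraps around)
#   _uint_to_int(1234)                # -> 1234 (already within int64 range)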


def _make_1d_state(state_size, seed):
  """Makes a 1-D RNG state.

  Args:
    state_size: an integer.
    seed: an integer or 1-D tensor.

  Returns:
    a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
  """
  if isinstance(seed, int):
    # Chop the Python integer (infinite precision) into chunks of SEED_TYPE.
    ls = []
    for _ in range(state_size):
      ls.append(seed & SEED_BIT_MASK)
      seed >>= SEED_TYPE_BITS
    seed = ls
  # to avoid overflow error from ops.convert_to_tensor
  seed = nest.map_structure(_uint_to_int, seed)
  seed = math_ops.cast(seed, STATE_TYPE)
  seed = array_ops.reshape(seed, [-1])
  seed = seed[0:state_size]
  # Padding with zeros on the *left* if too short. Padding on the right would
  # cause a small seed to be used as the "counter" while the "key" is always
  # zero (for counter-based RNG algorithms), because in the current memory
  # layout the counter is stored before the key. In such a situation two RNGs
  # with two different small seeds may generate overlapping outputs.
  seed_size = seed.shape[0]
  if seed_size is None:
    seed_size = array_ops.shape(seed)[0]
  padding_size = math_ops.maximum(state_size - seed_size, 0)
  padding = array_ops.zeros([padding_size], seed.dtype)
  # Can't use `pad` because it doesn't support integer dtypes on GPU.
  seed = array_ops.concat([padding, seed], axis=0)
  seed.set_shape([state_size])
  return seed
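
# Illustrative sketch (editor's addition): `_make_1d_state` left-pads short
# seed vectors with zeros (see the comment above on why right-padding would
# risk overlapping streams), while Python-int seeds are chopped into 64-bit
# chunks, low bits first:
#
#   _make_1d_state(3, [12, 34])  # -> tensor([0, 12, 34])
#   _make_1d_state(3, 1234)      # -> tensor([1234, 0, 0])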


def _get_counter_size(alg):
  if alg == Algorithm.PHILOX.value:
    return PHILOX_COUNTER_SIZE
  elif alg == Algorithm.THREEFRY.value:
    return THREEFRY_COUNTER_SIZE
  elif alg == Algorithm.AUTO_SELECT.value:
    # For AUTO_SELECT, we'll manage the counter as if it's for Philox.
    return PHILOX_COUNTER_SIZE
  else:
    raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))


def _get_state_size(alg):
  if alg == Algorithm.PHILOX.value:
    return PHILOX_STATE_SIZE
  elif alg == Algorithm.THREEFRY.value:
    return THREEFRY_STATE_SIZE
  elif alg == Algorithm.AUTO_SELECT.value:
    # For AUTO_SELECT, we'll manage the state as if it's for Philox.
    return PHILOX_STATE_SIZE
  else:
    raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))
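
# Illustrative sketch (editor's addition): the state sizes implied by the
# constants above, in units of int64:
#
#   _get_state_size(Algorithm.PHILOX.value)    # -> 3 (2 counter + 1 key)
#   _get_state_size(Algorithm.THREEFRY.value)  # -> 2 (1 counter + 1 key)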


def _check_state_shape(shape, alg):
  if isinstance(alg, ops.Tensor) and not context.executing_eagerly():
    return
  shape.assert_is_compatible_with([_get_state_size(int(alg))])


def _make_state_from_seed(seed, alg):
  return _make_1d_state(_get_state_size(alg), seed)


@tf_export("random.create_rng_state", "random.experimental.create_rng_state")
def create_rng_state(seed, alg):
  """Creates an RNG state from an integer or a vector.

  Example:

  >>> tf.random.create_rng_state(
  ...     1234, "philox")
  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1234, 0, 0])>
  >>> tf.random.create_rng_state(
  ...     [12, 34], "threefry")
  <tf.Tensor: shape=(2,), dtype=int64, numpy=array([12, 34])>

  Args:
    seed: an integer or 1-D numpy array.
    alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.

  Returns:
    a 1-D numpy array whose size depends on the algorithm.
  """
  alg = stateless_random_ops.convert_alg_to_int(alg)
  return _make_state_from_seed(seed, alg)


def _shape_tensor(shape):
  """Convert to an int32 or int64 tensor, defaulting to int64 if empty."""
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int64
  else:
    dtype = None
  return ops.convert_to_tensor(shape, dtype=dtype, name="shape")


def _convert_to_state_tensor(t):
  # to avoid out-of-range error from ops.convert_to_tensor
  t = nest.map_structure(_uint_to_int, t)
  return math_ops.cast(t, STATE_TYPE)


def get_replica_id():
  rctx = distribute_lib.get_replica_context()
  if rctx is None:
    return None
  return rctx.replica_id_in_sync_group


@tf_export("random.Generator", "random.experimental.Generator")
class Generator(autotrackable.AutoTrackable):
  """Random-number generator.

  Example:

  Creating a generator from a seed:

  >>> g = tf.random.Generator.from_seed(1234)
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[ 0.9356609 ,  1.0854305 , -0.93788373],
         [-0.5061547 ,  1.3169702 ,  0.7137579 ]], dtype=float32)>

  Creating a generator from a non-deterministic state:

  >>> g = tf.random.Generator.from_non_deterministic_state()
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>

  All the constructors allow explicitly choosing a Random-Number-Generation
  (RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
  example:

  >>> g = tf.random.Generator.from_seed(123, alg="philox")
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[ 0.8673864 , -0.29899067, -0.9310337 ],
         [-1.5828488 ,  1.2481191 , -0.6770643 ]], dtype=float32)>

  CPU, GPU and TPU with the same algorithm and seed will generate the same
  integer random numbers. Floating-point results (such as the output of
  `normal`) may have small numerical discrepancies between different devices.

  This class uses a `tf.Variable` to manage its internal state. Every time
  random numbers are generated, the state of the generator will change. For
  example:

  >>> g = tf.random.Generator.from_seed(1234)
  >>> g.state
  <tf.Variable ... numpy=array([1234, 0, 0])>
  >>> g.normal(shape=(2, 3))
  <...>
  >>> g.state
  <tf.Variable ... numpy=array([2770, 0, 0])>

  The shape of the state is algorithm-specific.

  There is also a global generator:

  >>> g = tf.random.get_global_generator()
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>

  When creating a generator inside a `tf.distribute.Strategy` scope, each
  replica will get a different stream of random numbers.

  For example, in this code:

  ```
  strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
  with strat.scope():
    g = tf.random.Generator.from_seed(1)
    def f():
      return g.normal([])
    results = strat.run(f).values
  ```

  `results[0]` and `results[1]` will have different values.

  If the generator is seeded (e.g. created via `Generator.from_seed`), the
  random numbers will be determined by the seed, even though different
  replicas get different numbers. One can think of a random number generated
  on a replica as a hash of the replica ID and a "master" random number that
  may be common to all replicas. Hence, the whole system is still
  deterministic.

  (Note that the random numbers on different replicas are not correlated,
  even if they are deterministically determined by the same seed. They are
  not correlated in the sense that no matter what statistics one calculates
  on them, there won't be any discernible correlation.)

  Generators can be freely saved and restored using `tf.train.Checkpoint`.
  The checkpoint can be restored in a distribution strategy with a different
  number of replicas than the original strategy. If a replica ID is present
  in both the original and the new distribution strategy, its state will be
  properly restored (i.e. the random-number stream from the restored point
  will be the same as that from the saving point) unless the replicas have
  already diverged in their RNG call traces before saving (e.g. one replica
  has made one RNG call while another has made two RNG calls). There is no
  such guarantee if the generator is saved in a strategy scope and restored
  outside of any strategy scope, or vice versa.

  When a generator is created within the scope of
  `tf.distribute.experimental.ParameterServerStrategy`, the workers
  will share the generator's state (placed on one of the parameter
  servers). In this way the workers will still get different
  random-number streams, as stated above. (This is similar to replicas
  in a `tf.distribute.MirroredStrategy` sequentially accessing a
  generator created outside the strategy.) Each RNG call on a worker
  will incur a round-trip to a parameter server, which may have
  performance impacts. When creating a
  `tf.distribute.experimental.ParameterServerStrategy`, please make
  sure that the `variable_partitioner` argument won't shard small
  variables of shape `[2]` or `[3]` (because generator states must not
  be sharded). Ways to avoid sharding small variables include setting
  `variable_partitioner` to `None` or to
  `tf.distribute.experimental.partitioners.MinSizePartitioner` with a
  large enough `min_shard_bytes` (see
  `tf.distribute.experimental.ParameterServerStrategy`'s documentation
  for more details).
  """

  @classmethod
  def from_state(cls, state, alg):
    """Creates a generator from a state.

    See `__init__` for description of `state` and `alg`.

    Args:
      state: the new state.
      alg: the RNG algorithm.

    Returns:
      The new generator.
    """
    return cls(alg=alg, state=state)

  @classmethod
  def from_seed(cls, seed, alg=None):
    """Creates a generator from a seed.

    A seed is a 1024-bit unsigned integer represented either as a Python
    integer or a vector of integers. Seeds shorter than 1024-bit will be
    padded. The padding, the internal structure of a seed and the way a seed
    is converted to a state are all opaque (unspecified). The only semantics
    specification of seeds is that two different seeds are likely to produce
    two independent generators (but no guarantee).

    Args:
      seed: the seed for the RNG.
      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
        See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if alg is None:
      # TODO(b/170668986): more sophisticated algorithm selection
      alg = DEFAULT_ALGORITHM
    alg = stateless_random_ops.convert_alg_to_int(alg)
    state = create_rng_state(seed, alg)
    return cls(state=state, alg=alg)

  @classmethod
  def from_non_deterministic_state(cls, alg=None):
    """Creates a generator by non-deterministically initializing its state.

    The source of the non-determinism will be platform- and time-dependent.

    Args:
      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
        See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if config.is_op_determinism_enabled():
      raise RuntimeError('"from_non_deterministic_state" cannot be called when '  # pylint: disable=g-doc-exception
                         "determinism is enabled.")
    if alg is None:
      # TODO(b/170668986): more sophisticated algorithm selection
      alg = DEFAULT_ALGORITHM
    alg = stateless_random_ops.convert_alg_to_int(alg)
    state = non_deterministic_ints(shape=[_get_state_size(alg)],
                                   dtype=SEED_TYPE)
    return cls(state=state, alg=alg)

  @classmethod
  def from_key_counter(cls, key, counter, alg):
    """Creates a generator from a key and a counter.

    This constructor only applies if the algorithm is a counter-based
    algorithm. See method `key` for the meaning of "key" and "counter".

    Args:
      key: the key for the RNG, a scalar of type STATE_TYPE.
      counter: a vector of dtype STATE_TYPE representing the initial counter
        for the RNG, whose length is algorithm-specific.
      alg: the RNG algorithm. If None, it will be auto-selected. See
        `__init__` for its possible values.

    Returns:
      The new generator.
    """
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    alg = stateless_random_ops.convert_alg_to_int(alg)
    counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
    key.shape.assert_is_compatible_with([])
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    return cls(state=state, alg=alg)
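
  # Illustrative sketch (editor's addition): since the state layout is
  # [counter..., key], a generator with a known key can be re-created at an
  # arbitrary stream position. Assumes the Philox algorithm (2 counter
  # words):
  #
  #   g1 = Generator.from_seed(1234)
  #   g2 = Generator.from_key_counter(key=g1.key, counter=[0, 0],
  #                                   alg="philox")
  #   # g2 shares g1's key, i.e. it draws from the same random-number
  #   # stream, rewound to counter position [0, 0].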

  def __init__(self, copy_from=None, state=None, alg=None):
    """Creates a generator.

    The new generator will be initialized in one of the following ways, with
    decreasing precedence:
    (1) If `copy_from` is not None, the new generator is initialized by
        copying information from another generator.
    (2) If `state` and `alg` are not None (they must be set together), the
        new generator is initialized by a state.

    Args:
      copy_from: a generator to be copied from.
      state: a vector of dtype STATE_TYPE representing the initial state of
        the RNG, whose length and semantics are algorithm-specific. If it's
        a variable, the generator will reuse it instead of creating a new
        variable.
      alg: the RNG algorithm. Possible values are
        `tf.random.Algorithm.PHILOX` for the Philox algorithm and
        `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
        (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
        [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
        The string names `"philox"` and `"threefry"` can also be used.
        Note `PHILOX` guarantees the same numbers are produced (given
        the same random state) across all architectures (CPU, GPU, XLA etc).
    """
    # TODO(b/175072242): Remove distribution-strategy dependencies in this
    # file.
    if distribute_lib.has_strategy():
      self._distribution_strategy = distribute_lib.get_strategy()
    else:
      self._distribution_strategy = None
    if copy_from is not None:
      # All other arguments should be None.
      assert (alg or state) is None
      self._state_var = self._create_variable(copy_from.state,
                                              dtype=STATE_TYPE,
                                              trainable=False)
      self._alg = copy_from.algorithm
    else:
      assert alg is not None and state is not None
      alg = stateless_random_ops.convert_alg_to_int(alg)
      if isinstance(state, variables.Variable):
        _check_state_shape(state.shape, alg)
        self._state_var = state
      else:
        state = _convert_to_state_tensor(state)
        _check_state_shape(state.shape, alg)
        self._state_var = self._create_variable(state, dtype=STATE_TYPE,
                                                trainable=False)
      self._alg = alg
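
  # Illustrative sketch (editor's addition): `copy_from` snapshots another
  # generator's state into a fresh variable, so the two then advance
  # independently:
  #
  #   g1 = Generator.from_seed(1)
  #   g2 = Generator(copy_from=g1)  # same state now, but a separate variable
  #   g1.normal([3])                # advances g1 only; g2.state is unchanged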

  def _create_variable(self, *args, **kwargs):
    """Creates a variable.

    Args:
      *args: positional arguments passed along to `variables.Variable`.
      **kwargs: keyword arguments passed along to `variables.Variable`.

    Returns:
      The created variable.
    """
    with ops.name_scope("random_generator"):
      # Make sure we don't change this name since Keras was using this name
      # to filter out the state variable.
      kwargs["name"] = "StateVar"
      v = variables.Variable(*args, **kwargs)
      if isinstance(v, sharded_variable.ShardedVariable):
        # RNG state is an atomic entity representing a 128-bit or
        # 192-bit value, so it mustn't be sharded.
        raise ValueError(
            "tf.random.Generator state is sharded, which is not allowed. When "
            "creating a tf.distribute.experimental.ParameterServerStrategy, "
            "please make sure that the `variable_partitioner` "
            "argument won't shard a "
            "small variable of shape [2] or [3]. Ways to avoid sharding small "
            "variables include setting `variable_partitioner` to None or to "
            "tf.distribute.experimental.partitioners.MinSizePartitioner with a "
            "large enough `min_shard_bytes`.")
      return v

  def reset(self, state):
    """Resets the generator by a new state.

    See `__init__` for the meaning of "state".

    Args:
      state: the new state.
    """
    state = _convert_to_state_tensor(state)
    state.shape.assert_is_compatible_with([_get_state_size(self.algorithm)])
    self._state_var.assign(state)

  def reset_from_seed(self, seed):
    """Resets the generator by a new seed.

    See `from_seed` for the meaning of "seed".

    Args:
      seed: the new seed.
    """
    state = create_rng_state(seed, self.algorithm)
    self._state_var.assign(state)

  def reset_from_key_counter(self, key, counter):
    """Resets the generator by a new key-counter pair.

    See `from_key_counter` for the meaning of "key" and "counter".

    Args:
      key: the new key.
      counter: the new counter.
    """
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    counter.shape.assert_is_compatible_with(
        [_get_state_size(self.algorithm) - 1])
    key.shape.assert_is_compatible_with([])
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    self._state_var.assign(state)

  @property
  def state(self):
    """The internal state of the RNG."""
    return self._state_var

  @property
  def algorithm(self):
    """The RNG algorithm id (a Python integer or scalar integer Tensor)."""
    return self._alg

  def _standard_normal(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_normal_v2(
        shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)

  @property
  def key(self):
    """The 'key' part of the state of a counter-based RNG.

    For a counter-based RNG algorithm such as Philox and ThreeFry (as
    described in the paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
    [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]),
    the RNG state consists of two parts: counter and key. The output is
    generated via the formula: output=hash(key, counter), i.e. a hashing of
    the counter parametrized by the key. Two RNGs with two different keys
    can be thought of as generating two independent random-number streams (a
    stream is formed by increasing the counter).

    Returns:
      A scalar which is the 'key' part of the state, if the RNG algorithm is
      counter-based; otherwise it raises a ValueError.
    """
    alg = self.algorithm
    if alg in (a.value for a in Algorithm):
      return self._state_var[-1]
    else:
      raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))

  def _skip_single_var(self, var, delta):
    resource_variable_ops.variable_accessed(var)
    # TODO(wangpeng): Cache the cast algorithm instead of casting every time.
    return gen_stateful_random_ops.rng_read_and_skip(
        var.handle,
        alg=math_ops.cast(self.algorithm, dtypes.int32),
        delta=math_ops.cast(delta, dtypes.uint64))

  def skip(self, delta):
    """Advance the counter of a counter-based RNG.

    Args:
      delta: the amount of advancement. The state of the RNG after
        `skip(n)` will be the same as that after `normal([n])`
        (or any other distribution). The actual increment added to the
        counter is an unspecified implementation detail.

    Returns:
      A `Tensor` of type `int64`.
    """

    def update_fn(v):
      return self._skip_single_var(v, delta)
    # TODO(b/170515001): Always call strategy.extended.update after calling
    # it from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
      # Assumes replica context with replica_id=0, since we only save the
      # first replica.
      return update_fn(self.state)
    if self._distribution_strategy is not None:
      with distribute_lib.enter_or_assert_strategy(self._distribution_strategy):
        if distribute_lib.in_cross_replica_context():
          # Code that operates on all replicas of a variable cannot be saved
          # without retracing.
          values_util.mark_as_unsaveable()
        if (distribute_lib.in_cross_replica_context() or
            "CentralStorage" in type(self._distribution_strategy).__name__):
          # In cross-replica context we need to use strategy.extended.update.
          # In CentralStorageStrategy we also need to use
          # strategy.extended.update (even in replica context), because
          # variable updates here must be within a merge_call.
          return distribute_lib.get_strategy().extended.update(
              self.state, update_fn)
    return update_fn(self.state)
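
  # Illustrative sketch (editor's addition): per the docstring above,
  # `skip(n)` leaves the state exactly as if n scalars had been drawn:
  #
  #   g1 = Generator.from_seed(1)
  #   g2 = Generator.from_seed(1)
  #   _ = g1.normal([5])
  #   _ = g2.skip(5)
  #   # g1.state and g2.state now hold identical values.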

  def _preprocess_key(self, key):
    if self._distribution_strategy is None:
      return key
    with distribute_lib.enter_or_assert_strategy(self._distribution_strategy):
      replica_id = get_replica_id()
      if replica_id is not None:
        replica_id = array_ops_stack.stack([replica_id, 0], axis=0)
        replica_id = math_ops.cast(replica_id, dtypes.uint64)
        # Conceptually: key = hash(key, replica_id)
        key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
            shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
            alg=self.algorithm)
      return key

  def _prepare_key_counter(self, shape):
    delta = math_ops.reduce_prod(shape)
    counter_key = self.skip(delta)
    counter_size = _get_counter_size(self.algorithm)
    counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
    key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
                            dtypes.uint64)
    key = self._preprocess_key(key)
    return key, counter

  # The following functions return a tensor and as a side effect update
  # self._state_var.
  def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
             name=None):
    """Outputs random values from a normal distribution.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random normal values.
    """
    with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
      shape = _shape_tensor(shape)
      mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._standard_normal(shape, dtype=dtype)
      return math_ops.add(rnd * stddev, mean, name=name)

  def _truncated_normal(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
        shape=shape, key=key, counter=counter, dtype=dtype,
        alg=self.algorithm)

  def truncated_normal(self, shape,
                       mean=0.0,
                       stddev=1.0,
                       dtype=dtypes.float32,
                       name=None):
    """Outputs random values from a truncated normal distribution.

    The generated values follow a normal distribution with specified mean
    and standard deviation, except that values whose magnitude is more than
    2 standard deviations from the mean are dropped and re-picked.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        truncated normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution, before truncation.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random truncated normal
      values.
    """
    with ops.name_scope(
        name, "truncated_normal", [shape, mean, stddev]) as name:
      shape_tensor = _shape_tensor(shape)
      mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._truncated_normal(shape_tensor, dtype=dtype)
      mul = rnd * stddev_tensor
      return math_ops.add(mul, mean_tensor, name=name)

  def _uniform(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
        shape=shape, key=key, counter=counter, dtype=dtype,
        alg=self.algorithm)

  def _uniform_full_int(self, shape, dtype, name=None):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
        shape=shape,
        key=key,
        counter=counter,
        dtype=dtype,
        alg=self.algorithm,
        name=name)

  def uniform(self, shape, minval=0, maxval=None,
              dtype=dtypes.float32, name=None):
    """Outputs random values from a uniform distribution.

    The generated values follow a uniform distribution in the range
    `[minval, maxval)`. The lower bound `minval` is included in the range,
    while the upper bound `maxval` is excluded. (For floating-point numbers,
    especially low-precision types like bfloat16, the result may sometimes
    include `maxval` because of rounding.)

    For floats, the default range is `[0, 1)`. For ints, at least `maxval`
    must be specified explicitly.

    In the integer case, the random integers are slightly biased unless
    `maxval - minval` is an exact power of two. The bias is small for values
    of `maxval - minval` significantly smaller than the range of the output
    (either `2**32` or `2**64`).

    For full-range random integers, pass `minval=None` and `maxval=None`
    with an integer `dtype` (for integer dtypes, `minval` and `maxval` must
    be both `None` or both not `None`).

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      minval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it
        needs to be a scalar). The lower bound (included) on the range of
        random values to generate. Pass `None` for full-range integers.
        Defaults to 0.
      maxval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it
        needs to be a scalar). The upper bound (excluded) on the range of
        random values to generate. Pass `None` for full-range integers.
        Defaults to 1 if `dtype` is floating point.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random uniform values.

    Raises:
      ValueError: If `dtype` is integral and `maxval` is not specified.
    """
    dtype = dtypes.as_dtype(dtype)
    if dtype.is_integer:
      if (minval is None) != (maxval is None):
        raise ValueError("For integer dtype {}, minval and maxval must be "
                         "both `None` or both non-`None`; got minval={} and "
                         "maxval={}".format(dtype, minval, maxval))
    elif maxval is None:
      maxval = 1
    with ops.name_scope(name, "stateful_uniform",
                        [shape, minval, maxval]) as name:
      shape = _shape_tensor(shape)
      if dtype.is_integer and minval is None:
        return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        key, counter = self._prepare_key_counter(shape)
        return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
            shape=shape,
            key=key,
            counter=counter,
            minval=minval,
            maxval=maxval,
            alg=self.algorithm,
            name=name)
      else:
        rnd = self._uniform(shape=shape, dtype=dtype)
        return math_ops.add(rnd * (maxval - minval), minval, name=name)
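
  # Illustrative sketch (editor's addition): the three modes of `uniform`:
  #
  #   g = Generator.from_seed(1)
  #   g.uniform([2], minval=0, maxval=10,
  #             dtype=dtypes.int32)             # integers in [0, 10)
  #   g.uniform([2])                            # floats in [0, 1)
  #   g.uniform([2], minval=None, maxval=None,
  #             dtype=dtypes.int64)             # full int64 range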

  def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
    """Uniform distribution on an integer type's entire range.

    This method is the same as setting `minval` and `maxval` to `None` in
    the `uniform` method.

    Args:
      shape: the shape of the output.
      dtype: (optional) the integer type, defaults to uint64.
      name: (optional) the name of the node.

    Returns:
      A tensor of random numbers of the required shape.
    """
    dtype = dtypes.as_dtype(dtype)
    with ops.name_scope(name, "stateful_uniform_full_int",
                        [shape]) as name:
      shape = _shape_tensor(shape)
      return self._uniform_full_int(shape=shape, dtype=dtype, name=name)

  def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
    """Outputs random values from a binomial distribution.

    The generated values follow a binomial distribution with specified count
    and probability of success parameters.

    Example:

    ```python
    counts = [10., 20.]
    # Probability of success.
    probs = [0.8]

    rng = tf.random.Generator.from_seed(seed=234)
    binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)


    counts = ...  # Shape [3, 1, 2]
    probs = ...   # Shape [1, 4, 2]
    shape = [3, 4, 3, 4, 2]
    rng = tf.random.Generator.from_seed(seed=1717)
    # Sample shape will be [3, 4, 3, 4, 2]
    binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs)
    ```

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      counts: Tensor. The counts of the binomial distribution. Must be
        broadcastable with `probs`, and broadcastable with the rightmost
        dimensions of `shape`.
      probs: Tensor. The probability of success for the
        binomial distribution. Must be broadcastable with `counts` and
        broadcastable with the rightmost dimensions of `shape`.
      dtype: The type of the output. Default: tf.int32
      name: A name for the operation (optional).

    Returns:
      samples: A Tensor of the specified shape filled with random binomial
        values. For each i, each samples[i, ...] is an independent draw from
        the binomial distribution on counts[i] trials with probability of
        success probs[i].
    """
    dtype = dtypes.as_dtype(dtype)
    with ops.name_scope(name, "binomial", [shape, counts, probs]) as name:
      counts = ops.convert_to_tensor(counts, name="counts")
      probs = ops.convert_to_tensor(probs, name="probs")
      shape_tensor = _shape_tensor(shape)
      return gen_stateful_random_ops.stateful_random_binomial(
          self.state.handle,
          self.algorithm,
          shape=shape_tensor,
          counts=counts,
          probs=probs,
          dtype=dtype,
          name=name)

  # TODO(wangpeng): implement other distributions

  def _make_int64_keys(self, shape=()):
    # New independent keys are generated via
    # `new_key[i] = hash(old_key, counter+i)`, which is exactly what
    # `uniform_full_int(dtype=int64)` does for PhiloxRandom_64_128_128 and
    # ThreeFry_64_64_64.
    return self.uniform_full_int(shape=shape, dtype=dtypes.int64)

  def make_seeds(self, count=1):
    """Generates seeds for stateless random ops.

    For example:

    ```python
    seeds = get_global_generator().make_seeds(count=10)
    for i in range(10):
      seed = seeds[:, i]
      numbers = stateless_random_normal(shape=[2, 3], seed=seed)
      ...
    ```

    Args:
      count: the number of seed pairs (note that stateless random ops need a
        pair of seeds to invoke).

    Returns:
      A tensor of shape [2, count] and dtype int64.
    """
    alg = self.algorithm
    if alg in (a.value for a in Algorithm):
      keys = self._make_int64_keys(shape=[count])
      # The two seeds for stateless random ops don't have individual
      # semantics and are scrambled together, so setting one to zero is fine.
      zeros = array_ops.zeros_like(keys)
      return array_ops_stack.stack([keys, zeros])
    else:
      raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))

  def split(self, count=1):
    """Returns a list of independent `Generator` objects.

    Two generators are independent of each other in the sense that the
    random-number streams they generate don't have statistically detectable
    correlations. The new generators are also independent of the old one.
    The old generator's state will be changed (like other random-number
    generating methods), so two calls of `split` will return different
    new generators.

    For example:

    ```python
    gens = get_global_generator().split(count=10)
    for gen in gens:
      numbers = gen.normal(shape=[2, 3])
      # ...
    gens2 = get_global_generator().split(count=10)
    # gens2 will be different from gens
    ```

    The new generators will be put on the current device (possibly different
    from the old generator's), for example:

    ```python
    with tf.device("/device:CPU:0"):
      gen = Generator.from_seed(1234)  # gen is on CPU
    with tf.device("/device:GPU:0"):
      gens = gen.split(count=10)  # gens are on GPU
    ```

    Args:
      count: the number of generators to return.

    Returns:
      A list (length `count`) of `Generator` objects independent of each
      other. The new generators have the same RNG algorithm as the old one.
    """
    def _key_to_state(alg, key):
      # Padding with zeros on the left. The zeros will be the counter.
      return [0] * (_get_state_size(alg) - 1) + [key]

    alg = self.algorithm
    if alg in (a.value for a in Algorithm):
      keys = self._make_int64_keys(shape=[count])
      return [Generator(state=_key_to_state(alg, key), alg=alg)
              for key in array_ops_stack.unstack(keys, num=count)]
    else:
      raise ValueError(stateless_random_ops.unsupported_alg_error_msg(alg))


# It's not safe to create TF ops before `init_google` is called, so this is
# initialized to None and gets a value the first time `get_global_generator`
# is called.
global_generator = None


@tf_export("random.get_global_generator",
           "random.experimental.get_global_generator")
def get_global_generator():
  """Retrieves the global generator.

  This function will create the global generator the first time it is
  called, and the generator will be placed at the default device at that
  time, so one needs to be careful when this function is first called.
  Using a generator placed on a less-ideal device will incur performance
  regression.

  Returns:
    The global `tf.random.Generator` object.
  """
  global global_generator
  if global_generator is None:
    if config.is_op_determinism_enabled():
      raise RuntimeError('"get_global_generator" cannot be called if '  # pylint: disable=g-doc-exception
                         "determinism is enabled, unless "
                         '"set_global_generator" has already been called. '
                         'Please call "set_global_generator" first.')
    with ops.init_scope():
      global_generator = Generator.from_non_deterministic_state()
  return global_generator


@tf_export("random.set_global_generator",
           "random.experimental.set_global_generator")
def set_global_generator(generator):
  """Replaces the global generator with another `Generator` object.

  This function replaces the global generator with the provided `generator`
  object.
  A random number generator utilizes a `tf.Variable` object to store its
  state. Users should be aware of caveats in how `set_global_generator`
  interacts with `tf.function`:

  - tf.function puts restrictions on Variable creation, thus one cannot
    freely create a new random generator instance inside `tf.function`. To
    call `set_global_generator` inside `tf.function`, the generator instance
    must have already been created eagerly.
  - tf.function captures the Variable during trace-compilation, thus a
    compiled tf.function will not be affected by `set_global_generator`, as
    demonstrated by
    random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun.

  For most use cases, avoid calling `set_global_generator` after program
  initialization, and prefer to reset the state of the existing global
  generator instead, such as,

  >>> rng = tf.random.get_global_generator()
  >>> rng.reset_from_seed(30)


  Args:
    generator: the new `Generator` object.
  """
  global global_generator
  global_generator = generator
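

# Illustrative sketch (editor's addition): the tf.function caveat described
# in `set_global_generator`'s docstring. A function traced while g1 was the
# global generator keeps using the variable it captured at trace time:
#
#   g1 = Generator.from_seed(1)
#   set_global_generator(g1)
#
#   @tf.function
#   def f():
#     return get_global_generator().normal([])
#
#   f()  # traces, capturing g1's state variable
#   set_global_generator(Generator.from_seed(2))
#   f()  # may still draw from g1's stream, since the trace captured it
#
# Prefer `get_global_generator().reset_from_seed(2)` to re-seed in place.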