Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/numpy_ops/np_random.py: 2%

64 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-05 06:32 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Random functions.""" 

16 

17# pylint: disable=g-direct-tensorflow-import 

18 

19import numpy as onp 

20 

21from tensorflow.python.framework import random_seed 

22from tensorflow.python.ops import array_ops 

23from tensorflow.python.ops import random_ops 

24from tensorflow.python.ops.numpy_ops import np_array_ops 

25from tensorflow.python.ops.numpy_ops import np_dtypes 

26from tensorflow.python.ops.numpy_ops import np_utils 

27 

# TODO(agarwal): deprecate this.
# Legacy dtype constant. Despite its name, the sampling functions in this
# module (including `randn` / `standard_normal`) do not read it; they use
# `np_dtypes.default_float_type()` instead. Presumably kept only for
# backward compatibility with external importers -- TODO confirm before
# removing.
DEFAULT_RANDN_DTYPE = onp.float32

30 

31 

@np_utils.np_doc('random.seed')
def seed(s):
  """Sets the seed for the random number generator.

  Delegates to `random_seed.set_seed` (the implementation behind
  `tf.random.set_seed`), i.e. it sets the global TensorFlow seed.

  Args:
    s: an integer, or any value `int()` accepts (note that `int()` truncates
      floats toward zero).

  Raises:
    ValueError: if `s` cannot be converted to an integer.
  """
  try:
    s = int(s)
  except (TypeError, ValueError) as e:
    # `int()` raises TypeError for unsupported types (e.g. None, lists) and
    # ValueError for unparsable values (e.g. 'abc'); report both uniformly
    # and keep the original error as the cause.
    # TODO(wangpeng): support this?
    raise ValueError(
        f'Argument `s` got an invalid value {s}. Only integers are supported.'
    ) from e
  random_seed.set_seed(s)

48 

49 

@np_utils.np_doc('random.randn')
def randn(*args):
  """Returns samples from a standard normal distribution.

  Thin wrapper around `standard_normal`, which uses `tf.random.normal`.

  Args:
    *args: The shape of the output array, given as separate dimensions
      (e.g. `randn(2, 3)`). With no arguments a scalar sample is returned.

  Returns:
    An ndarray with shape `args` and dtype `np_dtypes.default_float_type()`.
    That dtype is not necessarily `float64`; it follows the module's
    float-dtype configuration.
  """
  return standard_normal(size=args)

63 

64 

@np_utils.np_doc('random.standard_normal')
def standard_normal(size=None):
  """Draws samples from a standard normal distribution.

  Args:
    size: `None` for a single scalar sample, a scalar `n` for a 1-D result
      of length `n`, or a shape sequence.

  Returns:
    A tensor of the requested shape with dtype
    `np_dtypes.default_float_type()`.
  """
  # TODO(wangpeng): Use new stateful RNG
  if size is None:
    shape = ()
  else:
    # A bare length is promoted to a 1-D shape.
    shape = (size,) if np_utils.isscalar(size) else size
  return random_ops.random_normal(
      shape, dtype=np_dtypes.default_float_type())

74 

75 

@np_utils.np_doc('random.uniform')
def uniform(low=0.0, high=1.0, size=None):
  """Draws samples from a uniform distribution over `[low, high)`.

  Args:
    low: lower bound(s), converted to `np_dtypes.default_float_type()`.
    high: upper bound(s), converted to `np_dtypes.default_float_type()`.
    size: output shape. `None` uses the broadcast of the shapes of `low` and
      `high`. A scalar `n` is treated as `(n,)`, matching `standard_normal`
      and `poisson`. Any other value is used as the shape directly.

  Returns:
    A tensor of shape `size` with dtype `np_dtypes.default_float_type()`.
  """
  dtype = np_dtypes.default_float_type()
  low = np_array_ops.asarray(low, dtype=dtype)
  high = np_array_ops.asarray(high, dtype=dtype)
  if size is None:
    # Infer the output shape from the broadcast of the bounds' shapes.
    size = array_ops.broadcast_dynamic_shape(low.shape, high.shape)
  elif np_utils.isscalar(size):
    # `random_uniform` requires a 1-D shape; a bare length such as
    # `uniform(size=3)` would otherwise be passed as a rank-0 shape and fail.
    size = (size,)
  return random_ops.random_uniform(
      shape=size, minval=low, maxval=high, dtype=dtype)

85 

86 

@np_utils.np_doc('random.poisson')
def poisson(lam=1.0, size=None):
  """Draws samples from a Poisson distribution with rate `lam`.

  Args:
    lam: the rate parameter, forwarded to `random_ops.random_poisson`.
    size: `None` for a scalar sample, a scalar `n` for a 1-D result of
      length `n`, or a shape sequence.

  Returns:
    A tensor of samples with dtype `np_dtypes.int_`.
  """
  shape = () if size is None else (
      (size,) if np_utils.isscalar(size) else size)
  return random_ops.random_poisson(
      lam=lam, shape=shape, dtype=np_dtypes.int_)

94 

95 

@np_utils.np_doc('random.random')
def random(size=None):
  """Draws samples uniformly from `[0, 1)`.

  Args:
    size: output shape, forwarded unchanged to `uniform`.
  """
  return uniform(low=0., high=1., size=size)

99 

100 

@np_utils.np_doc('random.rand')
def rand(*size):
  """Draws samples uniformly from `[0, 1)`, with the shape given as varargs.

  Args:
    *size: output dimensions as separate arguments (e.g. `rand(2, 3)`);
      with no arguments the shape is `()`.
  """
  shape = tuple(size)
  return uniform(low=0., high=1., size=shape)

104 

105 

@np_utils.np_doc('random.randint')
def randint(low, high=None, size=None, dtype=onp.int64):
  """Draws random integers from the half-open interval `[low, high)`.

  If `high` is None, samples are drawn from `[0, low)` instead.

  Args:
    low: inclusive lower bound, or the exclusive upper bound when `high` is
      None. Converted with `int()`.
    high: exclusive upper bound, or None (see above).
    size: `None` for a scalar sample. A scalar `n` (Python int, NumPy
      integer, or scalar tensor) is treated as `(n,)`. Any other value is
      used as the shape directly.
    dtype: output dtype; must be convertible to int32 or int64.

  Returns:
    A tensor of shape `size` with the resolved integer dtype.

  Raises:
    ValueError: if `dtype` is not convertible to int32 or int64.
  """
  low = int(low)
  if high is None:
    high = low
    low = 0
  if size is None:
    size = ()
  elif np_utils.isscalar(size):
    # Use `isscalar` (as `standard_normal`/`poisson` do) instead of
    # `isinstance(size, int)` so NumPy integer scalars such as onp.int64(3)
    # are also wrapped into the 1-D shape `random_uniform` expects.
    size = (size,)
  dtype_orig = dtype
  dtype = np_utils.result_type(dtype)
  accepted_dtypes = (onp.int32, onp.int64)
  if dtype not in accepted_dtypes:
    raise ValueError(
        f'Argument `dtype` got an invalid value {dtype_orig}. Only those '
        f'convertible to {accepted_dtypes} are supported.')
  return random_ops.random_uniform(
      shape=size, minval=low, maxval=high, dtype=dtype)