Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/random_op.py: 56%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The implementation of `tf.data.Dataset.random`.""" 

16 

17import warnings 

18 

19from tensorflow.python import tf2 

20from tensorflow.python.data.ops import dataset_ops 

21from tensorflow.python.data.util import random_seed 

22from tensorflow.python.framework import dtypes 

23from tensorflow.python.framework import tensor_spec 

24from tensorflow.python.ops import gen_dataset_ops 

25from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 

26 

27 

def _random(  # pylint: disable=unused-private-name
    seed=None,
    rerandomize_each_iteration=None,
    name=None):
  """Builds the dataset returned by `Dataset.random()`.

  See `Dataset.random()` for details. This is a thin functional wrapper that
  forwards every argument, by keyword, to `_RandomDataset`.

  Args:
    seed: Optional seed forwarded to `_RandomDataset`.
    rerandomize_each_iteration: Optional flag forwarded to `_RandomDataset`.
    name: Optional dataset name forwarded to `_RandomDataset`.

  Returns:
    The newly constructed `_RandomDataset` instance.
  """
  forwarded = dict(
      seed=seed,
      rerandomize_each_iteration=rerandomize_each_iteration,
      name=name,
  )
  return _RandomDataset(**forwarded)

37 

38 

class _RandomDataset(dataset_ops.DatasetSource):
  """Source `Dataset` producing pseudorandom scalar `tf.int64` elements.

  The backing op is `random_dataset_v2` when `rerandomize_each_iteration` is
  truthy, and `random_dataset` otherwise.
  """

  def __init__(self, seed=None, rerandomize_each_iteration=None, name=None):
    """Constructs the variant tensor for the random source dataset.

    Args:
      seed: Optional seed. It is passed to `random_seed.get_seed`, whose
        `(seed, seed2)` result is stored on the instance and fed to the op.
      rerandomize_each_iteration: If truthy, the op is built with
        `random_dataset_v2` (using a dummy seed generator) and this value is
        forwarded to it. Otherwise `random_dataset` is used. Under TF1 a
        truthy value also emits a warning.
      name: Optional dataset name, stored as `self._name`.
    """
    self._seed, self._seed2 = random_seed.get_seed(seed)
    self._rerandomize = rerandomize_each_iteration
    # Assigned before `self._common_args` is read below.
    # NOTE(review): presumably `_common_args` consults `self._name` -- confirm.
    self._name = name

    if rerandomize_each_iteration and not tf2.enabled():
      # In TF1, per-iteration re-randomization is limited, so warn the caller.
      warnings.warn("In TF 1, the `rerandomize_each_iteration=True` option "
                    "is only supported for repeat-based epochs.")

    if not rerandomize_each_iteration:
      variant_tensor = ged_ops.random_dataset(
          seed=self._seed,
          seed2=self._seed2,
          **self._common_args)
    else:
      variant_tensor = ged_ops.random_dataset_v2(
          seed=self._seed,
          seed2=self._seed2,
          seed_generator=gen_dataset_ops.dummy_seed_generator(),
          rerandomize_each_iteration=self._rerandomize,
          **self._common_args)

    super().__init__(variant_tensor)

  @property
  def element_spec(self):
    """Spec of one element: a scalar (shape `[]`) `tf.int64` tensor."""
    return tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)