Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/shuffle_op.py: 48%

23 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The implementation of `tf.data.Dataset.shuffle`.""" 

16from tensorflow.python import tf2 

17from tensorflow.python.data.ops import dataset_ops 

18from tensorflow.python.data.util import random_seed 

19from tensorflow.python.eager import context 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.framework import ops 

22from tensorflow.python.ops import gen_dataset_ops 

23 

24 

def _shuffle(  # pylint: disable=unused-private-name
    input_dataset,
    buffer_size,
    seed=None,
    reshuffle_each_iteration=None,
    name=None):
  """Returns a `_ShuffleDataset` that shuffles `input_dataset`.

  Functional entry point for this module's `tf.data.Dataset.shuffle`
  implementation. Every argument is forwarded unchanged to `_ShuffleDataset`.

  Args:
    input_dataset: The dataset whose elements are shuffled.
    buffer_size: Shuffle buffer size, forwarded as-is.
    seed: Optional seed, forwarded as-is.
    reshuffle_each_iteration: Optional flag, forwarded as-is (`None` is
      resolved by `_ShuffleDataset`).
    name: Optional name for the resulting dataset.

  Returns:
    The newly constructed `_ShuffleDataset`.
  """
  shuffled = _ShuffleDataset(
      input_dataset=input_dataset,
      buffer_size=buffer_size,
      seed=seed,
      reshuffle_each_iteration=reshuffle_each_iteration,
      name=name)
  return shuffled

33 

34 

class _ShuffleDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` yielding the elements of its input in randomized order.

  User-facing semantics (buffer behaviour, seeding, reshuffling) are documented
  on `Dataset.shuffle()`.
  """

  def __init__(self,
               input_dataset,
               buffer_size,
               seed=None,
               reshuffle_each_iteration=None,
               name=None):
    """See `Dataset.shuffle()` for details.

    Args:
      input_dataset: The dataset to shuffle; its `_variant_tensor` is fed to
        the shuffle kernel.
      buffer_size: Converted to an int64 tensor named "buffer_size".
      seed: Optional seed, expanded into a `(seed, seed2)` pair by
        `random_seed.get_seed`.
      reshuffle_each_iteration: Optional flag; `None` is treated as `True`.
      name: Optional name for this dataset, stored on `self._name`.
    """
    self._input_dataset = input_dataset
    # Keep this ordering: the buffer-size conversion precedes seed derivation,
    # and seed derivation may depend on the ops already created in the graph.
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    self._seed, self._seed2 = random_seed.get_seed(seed)
    self._reshuffle_each_iteration = (
        True if reshuffle_each_iteration is None else reshuffle_each_iteration)
    # Must be set before `self._common_args` is read below.
    self._name = name

    # Under TF2 behaviour, in eager mode or while tracing a function, use the V3
    # kernel. It takes an explicit seed-generator input; a dummy one is
    # supplied here.
    use_v3_kernel = tf2.enabled() and (
        context.executing_eagerly() or ops.inside_function())

    input_variant = input_dataset._variant_tensor  # pylint: disable=protected-access
    kernel_kwargs = {
        "buffer_size": self._buffer_size,
        "seed": self._seed,
        "seed2": self._seed2,
    }
    if use_v3_kernel:
      kernel = gen_dataset_ops.shuffle_dataset_v3
      kernel_kwargs["seed_generator"] = gen_dataset_ops.dummy_seed_generator()
    else:
      kernel = gen_dataset_ops.shuffle_dataset
    kernel_kwargs["reshuffle_each_iteration"] = self._reshuffle_each_iteration

    # `_common_args` is unpacked separately so a key collision still raises.
    variant_tensor = kernel(input_variant, **kernel_kwargs, **self._common_args)
    super().__init__(input_dataset, variant_tensor)