Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/range_op.py: 35%

37 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The implementation of `tf.data.Dataset.range`.""" 

16 

17from tensorflow.python.data.ops import dataset_ops 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_spec 

21from tensorflow.python.ops import gen_dataset_ops 

22 

23 

def _range(*args, **kwargs):  # pylint: disable=unused-private-name
  """Builds the dataset that implements `tf.data.Dataset.range`.

  Every positional and keyword argument is forwarded unchanged to
  `_RangeDataset`, which is responsible for interpreting them.

  Returns:
    A new `_RangeDataset` instance.
  """
  dataset = _RangeDataset(*args, **kwargs)
  return dataset

26 

27 

class _RangeDataset(dataset_ops.DatasetSource):
  """A `Dataset` of a step separated range of values.

  Positional arguments follow the `range()` builtin: `(stop)`,
  `(start, stop)` or `(start, stop, step)`. The keyword arguments read by
  this class are `output_type` (dtype of each element, `dtypes.int64` when
  omitted) and `name`; any other keyword argument is ignored.
  """

  def __init__(self, *args, **kwargs):
    """See `Dataset.range()` for details."""
    # Parsing must happen first: it sets `_start`/`_stop`/`_step`,
    # `_output_type` and `_name`, which are read below.
    # NOTE(review): `_common_args` is inherited (not shown here) and may also
    # depend on `_name`/`_structure` -- confirm before reordering.
    self._parse_args(*args, **kwargs)
    # Each element is a scalar (empty shape) of the requested output dtype.
    self._structure = tensor_spec.TensorSpec([], self._output_type)
    range_variant = gen_dataset_ops.range_dataset(
        start=self._start,
        stop=self._stop,
        step=self._step,
        **self._common_args)
    super().__init__(range_variant)

  def _parse_args(self, *args, **kwargs):
    """Parses arguments according to the same rules as the `range()` builtin.

    Sets `self._start`, `self._stop` and `self._step` (int64 tensors),
    `self._output_type` and `self._name`.

    Raises:
      ValueError: if the number of positional arguments is not 1, 2 or 3.
    """
    arg_count = len(args)
    if arg_count not in (1, 2, 3):
      raise ValueError(f"Invalid `args`. The length of `args` should be "
                       f"between 1 and 3 but was {len(args)}.")

    # Fill in the `range()` defaults: start=0 and step=1 when not supplied.
    if arg_count == 1:
      start_value, stop_value, step_value = 0, args[0], 1
    elif arg_count == 2:
      start_value, stop_value, step_value = args[0], args[1], 1
    else:
      start_value, stop_value, step_value = args

    # Tensors are created in start/stop/step order.
    self._start = self._build_tensor(start_value, "start")
    self._stop = self._build_tensor(stop_value, "stop")
    self._step = self._build_tensor(step_value, "step")

    # In this class `output_type` only feeds the element spec; the bounds
    # above are always converted as int64.
    self._output_type = kwargs.get("output_type", dtypes.int64)
    self._name = kwargs.get("name")

  def _build_tensor(self, int64_value, name):
    """Converts `int64_value` into an int64 tensor named `name`."""
    tensor = ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
    return tensor

  @property
  def element_spec(self):
    """The scalar `TensorSpec` built in `__init__` from `output_type`."""
    spec = self._structure
    return spec