Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/range_op.py: 35%
37 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The implementation of `tf.data.Dataset.range`."""
17from tensorflow.python.data.ops import dataset_ops
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_spec
21from tensorflow.python.ops import gen_dataset_ops
def _range(*positional, **options):  # pylint: disable=unused-private-name
  """Creates the dataset behind this module's `tf.data.Dataset.range` logic.

  Every positional and keyword argument is handed, unchanged, to
  `_RangeDataset`, which performs all parsing and validation.

  NOTE(review): the pylint suppression suggests this private function is
  looked up from another module rather than called here -- confirm the caller.
  """
  dataset = _RangeDataset(*positional, **options)
  return dataset
class _RangeDataset(dataset_ops.DatasetSource):
  """A `Dataset` of a step separated range of values.

  The positional bounds follow the rules of Python's builtin `range()`, and
  every element is a scalar (shape `[]`) of dtype `output_type`, which
  defaults to `tf.int64`.
  """

  def __init__(self, *args, **kwargs):
    """See `Dataset.range()` for details.

    Args:
      *args: One to three values interpreted like `range()`: `(stop)`,
        `(start, stop)` or `(start, stop, step)`.
      **kwargs: Only `output_type` (element dtype, default `tf.int64`) and
        `name` are read; any other keyword is ignored.

    Raises:
      ValueError: If `args` does not contain between 1 and 3 values.
    """
    self._parse_args(*args, **kwargs)
    # Set the element spec before the range op is built.
    # NOTE(review): `_common_args` is inherited and not visible here; it
    # presumably derives the op's output types/shapes from `element_spec`
    # (i.e. from `_structure`), which would make this ordering required --
    # confirm.
    self._structure = tensor_spec.TensorSpec([], self._output_type)
    variant_tensor = gen_dataset_ops.range_dataset(
        start=self._start,
        stop=self._stop,
        step=self._step,
        **self._common_args)
    super().__init__(variant_tensor)

  def _parse_args(self, *args, **kwargs):
    """Parses arguments according to the same rules as the `range()` builtin.

    An omitted `start` defaults to 0 and an omitted `step` defaults to 1.
    On success this sets `self._start`, `self._stop` and `self._step` (int64
    tensors), plus `self._output_type` and `self._name`.

    Raises:
      ValueError: If `args` does not contain between 1 and 3 values.
    """
    # Pick the raw bound values first so the tensor conversion below is
    # written once, instead of once per arity.
    if len(args) == 1:
      start, stop, step = 0, args[0], 1
    elif len(args) == 2:
      start, stop, step = args[0], args[1], 1
    elif len(args) == 3:
      start, stop, step = args
    else:
      raise ValueError(f"Invalid `args`. The length of `args` should be "
                       f"between 1 and 3 but was {len(args)}.")
    # The bounds are always converted to int64, independently of
    # `output_type`. The conversion order (start, stop, step) is preserved
    # so that a conversion error is reported for the same argument as before.
    self._start = self._build_tensor(start, "start")
    self._stop = self._build_tensor(stop, "stop")
    self._step = self._build_tensor(step, "step")
    # Only these two keywords are consumed; anything else in `kwargs` is
    # silently ignored.
    self._output_type = kwargs.get("output_type", dtypes.int64)
    # NOTE(review): `_name` is not read in this class; presumably the base
    # class consumes it (e.g. when building op metadata) -- confirm.
    self._name = kwargs.get("name")

  def _build_tensor(self, int64_value, name):
    """Converts `int64_value` to an int64 tensor whose op is named `name`."""
    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)

  @property
  def element_spec(self):
    """The spec of each element: a scalar `TensorSpec` of `output_type`."""
    return self._structure