Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/prefetch_op.py: 45%
20 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The implementation of `tf.data.Dataset.prefetch`."""
17from tensorflow.python.data.ops import dataset_ops
18from tensorflow.python.data.ops import debug_mode
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import gen_dataset_ops
def _prefetch(input_dataset, buffer_size, name=None):  # pylint: disable=unused-private-name
  """Returns a prefetching version of `input_dataset`.

  See `Dataset.prefetch()` for details.

  Args:
    input_dataset: The dataset to prefetch from.
    buffer_size: The prefetch buffer size, forwarded unchanged to
      `_PrefetchDataset` (which substitutes `dataset_ops.AUTOTUNE` for `None`).
    name: Optional name, forwarded to `_PrefetchDataset`.

  Returns:
    `input_dataset` itself when `debug_mode.DEBUG_MODE` is truthy; otherwise a
    new `_PrefetchDataset` wrapping it.
  """
  # Normal path: wrap the input in a prefetching dataset.
  if not debug_mode.DEBUG_MODE:
    return _PrefetchDataset(input_dataset, buffer_size, name=name)
  # Debug mode bypasses prefetching entirely and hands back the input as-is.
  return input_dataset
class _PrefetchDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that asynchronously prefetches elements from its input."""

  def __init__(self, input_dataset, buffer_size, slack_period=None, name=None):
    """Builds a prefetch dataset op on top of `input_dataset`.

    See `Dataset.prefetch()` for details.

    Args:
      input_dataset: The upstream dataset to prefetch from.
      buffer_size: The prefetch buffer size. `None` is replaced by
        `dataset_ops.AUTOTUNE`; the result is converted to an int64 tensor
        named "buffer_size".
      slack_period: Optional value forwarded as-is to
        `gen_dataset_ops.prefetch_dataset`.
      name: Optional name, stored on `self._name`.
    """
    self._input_dataset = input_dataset
    requested_size = (
        dataset_ops.AUTOTUNE if buffer_size is None else buffer_size)
    self._buffer_size = ops.convert_to_tensor(
        requested_size, dtype=dtypes.int64, name="buffer_size")
    # NOTE(review): presumably read via `self._common_args` below; keep this
    # assignment ahead of that access.
    self._name = name
    # pylint: disable=protected-access
    upstream_variant = input_dataset._variant_tensor
    # Colocate the prefetch op with its input explicitly, since this
    # colocation only happens automatically in graph mode.
    with ops.colocate_with(upstream_variant):
      prefetch_variant = gen_dataset_ops.prefetch_dataset(
          upstream_variant,
          buffer_size=self._buffer_size,
          slack_period=slack_period,
          **self._common_args)
    # pylint: enable=protected-access
    super().__init__(input_dataset, prefetch_variant)