Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/writers.py: 65%
23 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for tf.data writers."""
16from tensorflow.python.data.ops import dataset_ops
17from tensorflow.python.data.util import convert
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_spec
21from tensorflow.python.ops import gen_experimental_dataset_ops
22from tensorflow.python.types import data as data_types
23from tensorflow.python.util import deprecation
24from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.TFRecordWriter")
@deprecation.deprecated(
    None, "To write TFRecords to disk, use `tf.io.TFRecordWriter`. To save "
    "and load the contents of a dataset, use `tf.data.experimental.save` "
    "and `tf.data.experimental.load`")
class TFRecordWriter:
  """Writes a dataset to a TFRecord file.

  The elements of the dataset must be scalar strings. To serialize dataset
  elements as strings, you can use the `tf.io.serialize_tensor` function.

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.map(tf.io.serialize_tensor)
  writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
  writer.write(dataset)
  ```

  To read back the elements, use `TFRecordDataset`.

  ```python
  dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
  dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
  ```

  To shard a `dataset` across multiple TFRecord files:

  ```python
  dataset = ...  # dataset to be written

  def reduce_func(key, dataset):
    filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(dataset.map(lambda _, x: x))
    return tf.data.Dataset.from_tensors(filename)

  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.group_by_window(
    lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
  ))

  # Iterate through the dataset to trigger data writing.
  for _ in dataset:
    pass
  ```
  """

  def __init__(self, filename, compression_type=None):
    """Initializes a `TFRecordWriter`.

    Args:
      filename: a string path indicating where to write the TFRecord data.
      compression_type: (Optional.) a string indicating what type of compression
        to use when writing the file. See `tf.io.TFRecordCompressionType` for
        what types of compression are available. Defaults to `None`.
    """
    # Convert eagerly so an invalid filename fails at construction time
    # rather than at write time.
    self._filename = ops.convert_to_tensor(
        filename, dtypes.string, name="filename")
    # `None` maps to the empty string, which the dataset-to-TFRecord kernel
    # interprets as "no compression".
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)

  def write(self, dataset):
    """Writes a dataset to a TFRecord file.

    An operation that writes the content of the specified dataset to the file
    specified in the constructor.

    If the file exists, it will be overwritten.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file

    Returns:
      In graph mode, this returns an operation which when executed performs the
      write. In eager mode, the write is performed by the method itself and
      there is no return value.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset`.
      TypeError: if the elements produced by the dataset are not scalar strings.
    """
    if not isinstance(dataset, data_types.DatasetV2):
      raise TypeError(
          f"Invalid `dataset`. Expected a `tf.data.Dataset` object but got "
          f"{type(dataset)}.")
    # The writer op only accepts scalar `tf.string` elements; reject anything
    # else up front with a readable error instead of a kernel failure.
    if not dataset_ops.get_structure(dataset).is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.string)):
      raise TypeError(
          f"Invalid `dataset`. Expected a `dataset` that produces scalar "
          f"`tf.string` elements, but got a dataset which produces elements "
          f"with shapes {dataset_ops.get_legacy_output_shapes(dataset)} and "
          f"types {dataset_ops.get_legacy_output_types(dataset)}.")
    # pylint: disable=protected-access
    # Apply any tf.data debug options the user set before handing the
    # variant tensor to the write op.
    dataset = dataset._apply_debug_options()
    return gen_experimental_dataset_ops.dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)