Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/writers.py: 65%

23 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python wrappers for tf.data writers.""" 

16from tensorflow.python.data.ops import dataset_ops 

17from tensorflow.python.data.util import convert 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_spec 

21from tensorflow.python.ops import gen_experimental_dataset_ops 

22from tensorflow.python.types import data as data_types 

23from tensorflow.python.util import deprecation 

24from tensorflow.python.util.tf_export import tf_export 

25 

26 

@tf_export("data.experimental.TFRecordWriter")
@deprecation.deprecated(
    None, "To write TFRecords to disk, use `tf.io.TFRecordWriter`. To save "
    "and load the contents of a dataset, use `tf.data.experimental.save` "
    "and `tf.data.experimental.load`")
class TFRecordWriter:
  """Writes a dataset to a TFRecord file.

  The elements of the dataset must be scalar strings. To serialize dataset
  elements as strings, you can use the `tf.io.serialize_tensor` function.

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.map(tf.io.serialize_tensor)
  writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
  writer.write(dataset)
  ```

  To read back the elements, use `TFRecordDataset`.

  ```python
  dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
  dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
  ```

  To shard a `dataset` across multiple TFRecord files:

  ```python
  dataset = ...  # dataset to be written

  def reduce_func(key, dataset):
    filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(dataset.map(lambda _, x: x))
    return tf.data.Dataset.from_tensors(filename)

  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.group_by_window(
    lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
  ))

  # Iterate through the dataset to trigger data writing.
  for _ in dataset:
    pass
  ```
  """

  def __init__(self, filename, compression_type=None):
    """Initializes a `TFRecordWriter`.

    Args:
      filename: a string path indicating where to write the TFRecord data.
      compression_type: (Optional.) a string indicating what type of compression
        to use when writing the file. See `tf.io.TFRecordCompressionType` for
        what types of compression are available. Defaults to `None`.
    """
    # Normalize both arguments to tensors up front so the write op below can
    # consume them directly; `optional_param_to_tensor` maps `None` to the
    # empty string, which the kernel interprets as "no compression".
    self._filename = ops.convert_to_tensor(
        filename, dtypes.string, name="filename")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)

  def write(self, dataset):
    """Writes a dataset to a TFRecord file.

    An operation that writes the content of the specified dataset to the file
    specified in the constructor.

    If the file exists, it will be overwritten.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file

    Returns:
      In graph mode, this returns an operation which when executed performs the
      write. In eager mode, the write is performed by the method itself and
      there is no return value.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset`.
      TypeError: if the elements produced by the dataset are not scalar strings.
    """
    if not isinstance(dataset, data_types.DatasetV2):
      raise TypeError(
          f"Invalid `dataset`. Expected a `tf.data.Dataset` object but got "
          f"{type(dataset)}."
      )
    # The write kernel only accepts rank-0 tf.string elements; reject anything
    # else eagerly with a descriptive message rather than failing inside the op.
    if not dataset_ops.get_structure(dataset).is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.string)):
      raise TypeError(
          f"Invalid `dataset`. Expected a `dataset` that produces scalar "
          f"`tf.string` elements, but got a dataset which produces elements "
          f"with shapes {dataset_ops.get_legacy_output_shapes(dataset)} and "
          f"types {dataset_ops.get_legacy_output_types(dataset)}.")
    # pylint: disable=protected-access
    # Apply any debug options (e.g. tf.data service debug mode) configured on
    # the dataset before handing its variant tensor to the write op.
    dataset = dataset._apply_debug_options()
    return gen_experimental_dataset_ops.dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)