Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/parsing_ops.py: 32%

57 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

15"""Experimental `dataset` API for parsing example.""" 

16from tensorflow.python.data.ops import dataset_ops 

17from tensorflow.python.data.util import structure 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import sparse_tensor 

20from tensorflow.python.framework import tensor_spec 

21from tensorflow.python.ops import gen_experimental_dataset_ops 

22from tensorflow.python.ops import parsing_ops 

23from tensorflow.python.ops.ragged import ragged_tensor 

24from tensorflow.python.util import deprecation 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

class _ParseExampleDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that parses an `example` dataset into a `dict` dataset."""

  def __init__(self, input_dataset, features, num_parallel_calls,
               deterministic):
    self._input_dataset = input_dataset
    if not structure.are_compatible(
        input_dataset.element_spec,
        tensor_spec.TensorSpec([None], dtypes.string)):
      raise TypeError("Input dataset should be a dataset of vectors of "
                      f"strings. Instead it is `{input_dataset.element_spec}`.")
    self._num_parallel_calls = num_parallel_calls
    if deterministic is None:
      self._deterministic = "default"
    elif deterministic:
      self._deterministic = "true"
    else:
      self._deterministic = "false"
    # pylint: disable=protected-access
    self._features = parsing_ops._prepend_none_dimension(features)
    params = parsing_ops._ParseOpParams.from_features(self._features, [
        parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
        parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature,
        parsing_ops.RaggedFeature
    ])
    # pylint: enable=protected-access
    self._sparse_keys = params.sparse_keys
    self._sparse_types = params.sparse_types
    self._ragged_keys = params.ragged_keys
    self._ragged_value_types = params.ragged_value_types
    self._ragged_split_types = params.ragged_split_types
    self._dense_keys = params.dense_keys
    self._dense_defaults = params.dense_defaults_vec
    self._dense_shapes = params.dense_shapes_as_proto
    self._dense_types = params.dense_types
    input_dataset_shape = dataset_ops.get_legacy_output_shapes(
        self._input_dataset)

    self._element_spec = {}

    for (key, value_type) in zip(params.sparse_keys, params.sparse_types):
      self._element_spec[key] = sparse_tensor.SparseTensorSpec(
          input_dataset_shape.concatenate([None]), value_type)

    for (key, value_type, dense_shape) in zip(params.dense_keys,
                                              params.dense_types,
                                              params.dense_shapes):
      self._element_spec[key] = tensor_spec.TensorSpec(
          input_dataset_shape.concatenate(dense_shape), value_type)

    for (key, value_type, splits_type) in zip(params.ragged_keys,
                                              params.ragged_value_types,
                                              params.ragged_split_types):
      self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
          input_dataset_shape.concatenate([None]), value_type, 1, splits_type)

    variant_tensor = (
        gen_experimental_dataset_ops.parse_example_dataset_v2(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._num_parallel_calls,
            self._dense_defaults,
            self._sparse_keys,
            self._dense_keys,
            self._sparse_types,
            self._dense_shapes,
            deterministic=self._deterministic,
            ragged_keys=self._ragged_keys,
            ragged_value_types=self._ragged_value_types,
            ragged_split_types=self._ragged_split_types,
            **self._flat_structure))
    super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec
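# Illustrative note (not part of the original module): the `element_spec`
# built above concatenates the batched input shape with each feature's own
# shape. A minimal sketch, assuming a hypothetical batched string dataset
# `batched_ds` whose element_spec is TensorSpec([None], tf.string):
#
#   features = {
#       "a": parsing_ops.FixedLenFeature([2], dtypes.float32,
#                                        default_value=[0.0, 0.0]),
#       "b": parsing_ops.VarLenFeature(dtypes.string),
#   }
#   ds = _ParseExampleDataset(batched_ds, features,
#                             num_parallel_calls=1, deterministic=None)
#   # ds.element_spec == {
#   #     "a": TensorSpec([None, 2], dtypes.float32),    # [None] + dense shape [2]
#   #     "b": SparseTensorSpec([None, None], dtypes.string),  # [None] + [None]
#   # }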

@tf_export("data.experimental.parse_example_dataset")
@deprecation.deprecated(
    None, "Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.")
def parse_example_dataset(features, num_parallel_calls=1, deterministic=None):
  """A transformation that parses `Example` protos into a `dict` of tensors.

  Parses a number of serialized `Example` protos given in `serialized`. We
  refer to `serialized` as a batch with `batch_size` many entries of
  individual `Example` protos.

  This op parses serialized examples into a dictionary mapping keys to
  `Tensor`, `SparseTensor`, and `RaggedTensor` objects. `features` is a dict
  from keys to `VarLenFeature`, `RaggedFeature`, `SparseFeature`, and
  `FixedLenFeature` objects. Each `VarLenFeature` and `SparseFeature` is
  mapped to a `SparseTensor`; each `RaggedFeature` is mapped to a
  `RaggedTensor`; and each `FixedLenFeature` is mapped to a `Tensor`. See
  `tf.io.parse_example` for more details about feature dictionaries.

  Args:
    features: A `dict` mapping feature keys to `FixedLenFeature`,
      `VarLenFeature`, `RaggedFeature`, and `SparseFeature` values.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of parsing processes to call in parallel.
    deterministic: (Optional.) A boolean controlling whether determinism
      should be traded for performance by allowing elements to be produced out
      of order if some parsing calls complete faster than others. If
      `deterministic` is `None`, the `tf.data.Options.deterministic` dataset
      option (`True` by default) is used to decide whether to produce elements
      deterministically.

  Returns:
    A dataset transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if the `features` argument is `None`.
  """
  if features is None:
    raise ValueError("Argument `features` is required, but not specified.")

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls,
                                       deterministic)
    if any(
        isinstance(feature, parsing_ops.SparseFeature) or
        isinstance(feature, parsing_ops.RaggedFeature)
        for feature in features.values()):
      # pylint: disable=protected-access
      # pylint: disable=g-long-lambda
      out_dataset = out_dataset.map(
          lambda x: parsing_ops._construct_tensors_for_composite_features(
              features, x),
          num_parallel_calls=num_parallel_calls)
    return out_dataset

  return _apply_fn
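# --- Illustrative usage (not part of the original module) ---
# A minimal sketch of the public API exported above, applied to a batched
# dataset of serialized tf.train.Example protos. The feature name "x" and the
# helper `make_example` are hypothetical; only documented TensorFlow APIs are
# used.
import tensorflow as tf

def make_example(value):
  # Serialize a single tf.train.Example holding one int64 feature "x".
  return tf.train.Example(features=tf.train.Features(feature={
      "x": tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
  })).SerializeToString()

serialized = [make_example(i) for i in range(4)]
ds = (tf.data.Dataset.from_tensor_slices(serialized)
      .batch(2)  # the transformation requires vectors of strings, not scalars
      .apply(tf.data.experimental.parse_example_dataset(
          {"x": tf.io.FixedLenFeature([1], tf.int64)})))
# Each element is a dict, e.g. {"x": <tf.Tensor shape=(2, 1) dtype=int64>}.
# Note this API is deprecated; the decorator above recommends
# `tf.data.Dataset.map(tf.io.parse_example(...))` instead.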