Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/get_single_element.py: 70%

10 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python wrappers for Datasets and Iterators.""" 

16from tensorflow.python.types import data as data_types 

17from tensorflow.python.util import deprecation 

18from tensorflow.python.util.tf_export import tf_export 

19 

20 

@deprecation.deprecated(None, "Use `tf.data.Dataset.get_single_element()`.")
@tf_export("data.experimental.get_single_element")
def get_single_element(dataset):
  """Returns the single element of the `dataset` as a nested structure of tensors.

  The function enables you to use a `tf.data.Dataset` in a stateless
  "tensor-in tensor-out" expression, without creating an iterator.
  This facilitates the ease of data transformation on tensors using the
  optimized `tf.data.Dataset` abstraction on top of them.

  For example, lets consider a `preprocessing_fn` which would take as an
  input the raw features and returns the processed feature along with
  its label.

  ```python
  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  raw_features = ...  # input batch of BATCH_SIZE elements.
  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  processed_features = tf.data.experimental.get_single_element(dataset)
  ```

  In the above example, the `raw_features` tensor of length=BATCH_SIZE
  was converted to a `tf.data.Dataset`. Next, each of the `raw_feature` was
  mapped using the `preprocessing_fn` and the processed features were
  grouped into a single batch. The final `dataset` contains only one element
  which is a batch of all the processed features.

  NOTE: The `dataset` should contain only one element.

  Now, instead of creating an iterator for the `dataset` and retrieving the
  batch of features, the `tf.data.experimental.get_single_element()` function
  is used to skip the iterator creation process and directly output the batch
  of features.

  This can be particularly useful when your tensor transformations are
  expressed as `tf.data.Dataset` operations, and you want to use those
  transformations while serving your model.

  # Keras

  ```python

  model = ... # A pre-built or custom model

  class PreprocessingModel(tf.keras.Model):
    def __init__(self, model):
      super().__init__(self)
      self.model = model

    @tf.function(input_signature=[...])
    def serving_fn(self, data):
      ds = tf.data.Dataset.from_tensor_slices(data)
      ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
      ds = ds.batch(batch_size=BATCH_SIZE)
      return tf.argmax(
          self.model(tf.data.experimental.get_single_element(ds)),
          axis=-1
      )

  preprocessing_model = PreprocessingModel(model)
  your_exported_model_dir = ... # save the model to this path.
  tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                      signatures={'serving_default': preprocessing_model.serving_fn})
  ```

  # Estimator

  In the case of estimators, you need to generally define a `serving_input_fn`
  which would require the features to be processed by the model while
  inferencing.

  ```python
  def serving_input_fn():

    raw_feature_spec = ... # Spec for the raw_features
    input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = input_fn()
    raw_features = serving_input_receiver.features

    def preprocessing_fn(raw_feature):
      # ... the raw_feature is preprocessed as per the use-case
      return feature

    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    processed_features = tf.data.experimental.get_single_element(dataset)

    # Please note that the value of `BATCH_SIZE` should be equal to
    # the size of the leading dimension of `raw_features`. This ensures
    # that `dataset` has only one element, which is a pre-requisite for
    # using `tf.data.experimental.get_single_element(dataset)`.

    return tf.estimator.export.ServingInputReceiver(
        processed_features, serving_input_receiver.receiver_tensors)

  estimator = ... # A pre-built or custom estimator
  estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
      one element.
  """
  # Fail fast with a clear error for non-Dataset inputs; the "exactly one
  # element" constraint is only checkable at runtime and is enforced by the
  # underlying op invoked via `Dataset.get_single_element()`.
  if not isinstance(dataset, data_types.DatasetV2):
    raise TypeError(
        f"Invalid `dataset`. Expected a `tf.data.Dataset` object "
        f"but got {type(dataset)}.")

  # This wrapper is deprecated; it simply delegates to the core Dataset API.
  return dataset.get_single_element()