Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/get_single_element.py: 70%
10 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Python wrappers for Datasets and Iterators."""
16from tensorflow.python.types import data as data_types
17from tensorflow.python.util import deprecation
18from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated(None, "Use `tf.data.Dataset.get_single_element()`.")
@tf_export("data.experimental.get_single_element")
def get_single_element(dataset):
  """Returns the single element of the `dataset` as a nested structure of tensors.

  The function enables you to use a `tf.data.Dataset` in a stateless
  "tensor-in tensor-out" expression, without creating an iterator.
  This makes it easy to apply optimized `tf.data.Dataset` transformations
  directly on top of tensors.

  For example, let's consider a `preprocessing_fn` which takes the raw
  features as input and returns the processed feature along with its
  label.

  ```python
  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  raw_features = ...  # input batch of BATCH_SIZE elements.
  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  processed_features = tf.data.experimental.get_single_element(dataset)
  ```

  In the above example, the `raw_features` tensor of length `BATCH_SIZE`
  was converted to a `tf.data.Dataset`. Next, each `raw_feature` was
  mapped using `preprocessing_fn` and the processed features were
  grouped into a single batch. The final `dataset` contains only one
  element, which is a batch of all the processed features.

  NOTE: The `dataset` should contain only one element.

  Now, instead of creating an iterator for the `dataset` and retrieving the
  batch of features, the `tf.data.experimental.get_single_element()` function
  is used to skip the iterator creation process and directly output the batch
  of features.
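
  For instance, a minimal sketch of the pattern (the dataset holds exactly
  one element, so no iterator is needed):

  ```python
  ds = tf.data.Dataset.from_tensors([1, 2, 3])  # a dataset with one element
  value = tf.data.experimental.get_single_element(ds)
  # value is tf.Tensor([1 2 3], shape=(3,), dtype=int32)
  ```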

  This can be particularly useful when your tensor transformations are
  expressed as `tf.data.Dataset` operations, and you want to use those
  transformations while serving your model.

  # Keras

  ```python

  model = ...  # A pre-built or custom model

  class PreprocessingModel(tf.keras.Model):
    def __init__(self, model):
      super().__init__()
      self.model = model

    @tf.function(input_signature=[...])
    def serving_fn(self, data):
      ds = tf.data.Dataset.from_tensor_slices(data)
      ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
      ds = ds.batch(batch_size=BATCH_SIZE)
      return tf.argmax(
          self.model(tf.data.experimental.get_single_element(ds)),
          axis=-1
      )

  preprocessing_model = PreprocessingModel(model)
  your_exported_model_dir = ...  # save the model to this path.
  tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                      signatures={'serving_default': preprocessing_model.serving_fn})
  ```

  # Estimator

  In the case of estimators, you generally need to define a `serving_input_fn`
  which prepares the features that the model consumes at inference time.

  ```python
  def serving_input_fn():

    raw_feature_spec = ...  # Spec for the raw_features
    input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = input_fn()
    raw_features = serving_input_receiver.features

    def preprocessing_fn(raw_feature):
      # ... the raw_feature is preprocessed as per the use-case
      return feature

    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    processed_features = tf.data.experimental.get_single_element(dataset)

    # Please note that the value of `BATCH_SIZE` should be equal to
    # the size of the leading dimension of `raw_features`. This ensures
    # that `dataset` has only one element, which is a prerequisite for
    # using `tf.data.experimental.get_single_element(dataset)`.

    return tf.estimator.export.ServingInputReceiver(
        processed_features, serving_input_receiver.receiver_tensors)

  estimator = ...  # A pre-built or custom estimator
  estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.

  Returns:
    A nested structure of `tf.Tensor` objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, data_types.DatasetV2):
    raise TypeError(
        f"Invalid `dataset`. Expected a `tf.data.Dataset` object "
        f"but got {type(dataset)}.")

  return dataset.get_single_element()
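
# A minimal usage sketch (illustrative only, not part of the module): it
# exercises the two failure modes documented above. A non-dataset argument
# raises `TypeError` immediately; a dataset with more than one element fails
# with `InvalidArgumentError` when the op runs.
#
#   import tensorflow as tf
#
#   ds = tf.data.Dataset.range(1)                 # exactly one element
#   tf.data.experimental.get_single_element(ds)   # => tf.Tensor(0, dtype=int64)
#
#   tf.data.experimental.get_single_element([0])  # TypeError: not a Dataset
#   tf.data.experimental.get_single_element(
#       tf.data.Dataset.range(2))                 # InvalidArgumentError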