Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/optional_ops.py: 62%
82 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A type for representing values that may or may not exist."""
16import abc
18from tensorflow.core.protobuf import struct_pb2
19from tensorflow.python.data.util import structure
20from tensorflow.python.framework import composite_tensor
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_spec
24from tensorflow.python.framework import type_spec
25from tensorflow.python.ops import gen_optional_ops
26from tensorflow.python.saved_model import nested_structure_coder
27from tensorflow.python.util import deprecation
28from tensorflow.python.util.tf_export import tf_export
31@tf_export("experimental.Optional", "data.experimental.Optional")
32@deprecation.deprecated_endpoints("data.experimental.Optional")
33class Optional(composite_tensor.CompositeTensor, metaclass=abc.ABCMeta):
34 """Represents a value that may or may not be present.
36 A `tf.experimental.Optional` can represent the result of an operation that may
37 fail as a value, rather than raising an exception and halting execution. For
38 example, `tf.data.Iterator.get_next_as_optional()` returns a
39 `tf.experimental.Optional` that either contains the next element of an
40 iterator if one exists, or an "empty" value that indicates the end of the
41 sequence has been reached.
43 `tf.experimental.Optional` can only be used with values that are convertible
44 to `tf.Tensor` or `tf.CompositeTensor`.
46 One can create a `tf.experimental.Optional` from a value using the
47 `from_value()` method:
49 >>> optional = tf.experimental.Optional.from_value(42)
50 >>> print(optional.has_value())
51 tf.Tensor(True, shape=(), dtype=bool)
52 >>> print(optional.get_value())
53 tf.Tensor(42, shape=(), dtype=int32)
55 or without a value using the `empty()` method:
57 >>> optional = tf.experimental.Optional.empty(
58 ... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
59 >>> print(optional.has_value())
60 tf.Tensor(False, shape=(), dtype=bool)
61 """
63 @abc.abstractmethod
64 def has_value(self, name=None):
65 """Returns a tensor that evaluates to `True` if this optional has a value.
67 >>> optional = tf.experimental.Optional.from_value(42)
68 >>> print(optional.has_value())
69 tf.Tensor(True, shape=(), dtype=bool)
71 Args:
72 name: (Optional.) A name for the created operation.
74 Returns:
75 A scalar `tf.Tensor` of type `tf.bool`.
76 """
77 raise NotImplementedError("Optional.has_value()")
79 @abc.abstractmethod
80 def get_value(self, name=None):
81 """Returns the value wrapped by this optional.
83 If this optional does not have a value (i.e. `self.has_value()` evaluates to
84 `False`), this operation will raise `tf.errors.InvalidArgumentError` at
85 runtime.
87 >>> optional = tf.experimental.Optional.from_value(42)
88 >>> print(optional.get_value())
89 tf.Tensor(42, shape=(), dtype=int32)
91 Args:
92 name: (Optional.) A name for the created operation.
94 Returns:
95 The wrapped value.
96 """
97 raise NotImplementedError("Optional.get_value()")
99 @abc.abstractproperty
100 def element_spec(self):
101 """The type specification of an element of this optional.
103 >>> optional = tf.experimental.Optional.from_value(42)
104 >>> print(optional.element_spec)
105 tf.TensorSpec(shape=(), dtype=tf.int32, name=None)
107 Returns:
108 A (nested) structure of `tf.TypeSpec` objects matching the structure of an
109 element of this optional, specifying the type of individual components.
110 """
111 raise NotImplementedError("Optional.element_spec")
113 @staticmethod
114 def empty(element_spec):
115 """Returns an `Optional` that has no value.
117 NOTE: This method takes an argument that defines the structure of the value
118 that would be contained in the returned `Optional` if it had a value.
120 >>> optional = tf.experimental.Optional.empty(
121 ... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
122 >>> print(optional.has_value())
123 tf.Tensor(False, shape=(), dtype=bool)
125 Args:
126 element_spec: A (nested) structure of `tf.TypeSpec` objects matching the
127 structure of an element of this optional.
129 Returns:
130 A `tf.experimental.Optional` with no value.
131 """
132 return _OptionalImpl(gen_optional_ops.optional_none(), element_spec)
134 @staticmethod
135 def from_value(value):
136 """Returns a `tf.experimental.Optional` that wraps the given value.
138 >>> optional = tf.experimental.Optional.from_value(42)
139 >>> print(optional.has_value())
140 tf.Tensor(True, shape=(), dtype=bool)
141 >>> print(optional.get_value())
142 tf.Tensor(42, shape=(), dtype=int32)
144 Args:
145 value: A value to wrap. The value must be convertible to `tf.Tensor` or
146 `tf.CompositeTensor`.
148 Returns:
149 A `tf.experimental.Optional` that wraps `value`.
150 """
151 with ops.name_scope("optional") as scope:
152 with ops.name_scope("value"):
153 element_spec = structure.type_spec_from_value(value)
154 encoded_value = structure.to_tensor_list(element_spec, value)
156 return _OptionalImpl(
157 gen_optional_ops.optional_from_value(encoded_value, name=scope),
158 element_spec,
159 )
162class _OptionalImpl(Optional):
163 """Concrete implementation of `tf.experimental.Optional`.
165 NOTE(mrry): This implementation is kept private, to avoid defining
166 `Optional.__init__()` in the public API.
167 """
169 def __init__(self, variant_tensor, element_spec):
170 super().__init__()
171 self._variant_tensor = variant_tensor
172 self._element_spec = element_spec
174 def has_value(self, name=None):
175 with ops.colocate_with(self._variant_tensor):
176 return gen_optional_ops.optional_has_value(
177 self._variant_tensor, name=name
178 )
180 def get_value(self, name=None):
181 # TODO(b/110122868): Consolidate the restructuring logic with similar logic
182 # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
183 with ops.name_scope(name, "OptionalGetValue",
184 [self._variant_tensor]) as scope:
185 with ops.colocate_with(self._variant_tensor):
186 result = gen_optional_ops.optional_get_value(
187 self._variant_tensor,
188 name=scope,
189 output_types=structure.get_flat_tensor_types(self._element_spec),
190 output_shapes=structure.get_flat_tensor_shapes(self._element_spec),
191 )
192 # NOTE: We do not colocate the deserialization of composite tensors
193 # because not all ops are guaranteed to have non-GPU kernels.
194 return structure.from_tensor_list(self._element_spec, result)
196 @property
197 def element_spec(self):
198 return self._element_spec
200 @property
201 def _type_spec(self):
202 return OptionalSpec.from_value(self)
205@tf_export(
206 "OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"])
207class OptionalSpec(type_spec.TypeSpec):
208 """Type specification for `tf.experimental.Optional`.
210 For instance, `tf.OptionalSpec` can be used to define a tf.function that takes
211 `tf.experimental.Optional` as an input argument:
213 >>> @tf.function(input_signature=[tf.OptionalSpec(
214 ... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
215 ... def maybe_square(optional):
216 ... if optional.has_value():
217 ... x = optional.get_value()
218 ... return x * x
219 ... return -1
220 >>> optional = tf.experimental.Optional.from_value(5)
221 >>> print(maybe_square(optional))
222 tf.Tensor(25, shape=(), dtype=int32)
224 Attributes:
225 element_spec: A (nested) structure of `TypeSpec` objects that represents the
226 type specification of the optional element.
227 """
229 __slots__ = ["_element_spec"]
231 def __init__(self, element_spec):
232 super().__init__()
233 self._element_spec = element_spec
235 @property
236 def value_type(self):
237 return _OptionalImpl
239 def _serialize(self):
240 return (self._element_spec,)
242 @property
243 def _component_specs(self):
244 return [tensor_spec.TensorSpec((), dtypes.variant)]
246 def _to_components(self, value):
247 return [value._variant_tensor] # pylint: disable=protected-access
249 def _from_components(self, flat_value):
250 # pylint: disable=protected-access
251 return _OptionalImpl(flat_value[0], self._element_spec)
253 @staticmethod
254 def from_value(value):
255 return OptionalSpec(value.element_spec)
257 def _to_legacy_output_types(self):
258 return self
260 def _to_legacy_output_shapes(self):
261 return self
263 def _to_legacy_output_classes(self):
264 return self
267nested_structure_coder.register_codec(
268 nested_structure_coder.BuiltInTypeSpecCodec(
269 OptionalSpec, struct_pb2.TypeSpecProto.OPTIONAL_SPEC
270 )
271)