Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/optional_ops.py: 62%

82 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

15"""A type for representing values that may or may not exist.""" 

16import abc 

17 

18from tensorflow.core.protobuf import struct_pb2 

19from tensorflow.python.data.util import structure 

20from tensorflow.python.framework import composite_tensor 

21from tensorflow.python.framework import dtypes 

22from tensorflow.python.framework import ops 

23from tensorflow.python.framework import tensor_spec 

24from tensorflow.python.framework import type_spec 

25from tensorflow.python.ops import gen_optional_ops 

26from tensorflow.python.saved_model import nested_structure_coder 

27from tensorflow.python.util import deprecation 

28from tensorflow.python.util.tf_export import tf_export 

29 

30 

@tf_export("experimental.Optional", "data.experimental.Optional")
@deprecation.deprecated_endpoints("data.experimental.Optional")
class Optional(composite_tensor.CompositeTensor, metaclass=abc.ABCMeta):
  """Represents a value that may or may not be present.

  A `tf.experimental.Optional` can represent the result of an operation that may
  fail as a value, rather than raising an exception and halting execution. For
  example, `tf.data.Iterator.get_next_as_optional()` returns a
  `tf.experimental.Optional` that either contains the next element of an
  iterator if one exists, or an "empty" value that indicates the end of the
  sequence has been reached.

  `tf.experimental.Optional` can only be used with values that are convertible
  to `tf.Tensor` or `tf.CompositeTensor`.

  One can create a `tf.experimental.Optional` from a value using the
  `from_value()` method:

  >>> optional = tf.experimental.Optional.from_value(42)
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> print(optional.get_value())
  tf.Tensor(42, shape=(), dtype=int32)

  or without a value using the `empty()` method:

  >>> optional = tf.experimental.Optional.empty(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  @abc.abstractmethod
  def has_value(self, name=None):
    """Returns a tensor that evaluates to `True` if this optional has a value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.bool`.

    Raises:
      NotImplementedError: If a subclass does not override this method.
    """
    raise NotImplementedError("Optional.has_value()")

  @abc.abstractmethod
  def get_value(self, name=None):
    """Returns the value wrapped by this optional.

    If this optional does not have a value (i.e. `self.has_value()` evaluates to
    `False`), this operation will raise `tf.errors.InvalidArgumentError` at
    runtime.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      The wrapped value.

    Raises:
      NotImplementedError: If a subclass does not override this method.
    """
    raise NotImplementedError("Optional.get_value()")

  # `abc.abstractproperty` has been deprecated since Python 3.3; stacking
  # `@property` on top of `@abc.abstractmethod` is the supported equivalent
  # and still prevents instantiation of subclasses that do not override it.
  @property
  @abc.abstractmethod
  def element_spec(self):
    """The type specification of an element of this optional.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.element_spec)
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this optional, specifying the type of individual components.

    Raises:
      NotImplementedError: If a subclass does not override this property.
    """
    raise NotImplementedError("Optional.element_spec")

  @staticmethod
  def empty(element_spec):
    """Returns an `Optional` that has no value.

    NOTE: This method takes an argument that defines the structure of the value
    that would be contained in the returned `Optional` if it had a value.

    >>> optional = tf.experimental.Optional.empty(
    ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Args:
      element_spec: A (nested) structure of `tf.TypeSpec` objects matching the
        structure of an element of this optional.

    Returns:
      A `tf.experimental.Optional` with no value.
    """
    return _OptionalImpl(gen_optional_ops.optional_none(), element_spec)

  @staticmethod
  def from_value(value):
    """Returns a `tf.experimental.Optional` that wraps the given value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      value: A value to wrap. The value must be convertible to `tf.Tensor` or
        `tf.CompositeTensor`.

    Returns:
      A `tf.experimental.Optional` that wraps `value`.
    """
    with ops.name_scope("optional") as scope:
      # Ops created while flattening `value` live under "optional/value".
      with ops.name_scope("value"):
        element_spec = structure.type_spec_from_value(value)
        encoded_value = structure.to_tensor_list(element_spec, value)

    # The outer scope string is used as the name of the op itself.
    return _OptionalImpl(
        gen_optional_ops.optional_from_value(encoded_value, name=scope),
        element_spec,
    )

160 

161 

class _OptionalImpl(Optional):
  """Concrete `tf.experimental.Optional` holding a variant tensor.

  Instances are created by `Optional.empty()`, `Optional.from_value()` and
  `OptionalSpec._from_components()`. Each one pairs the variant tensor with the
  element spec needed to decode its contents.

  NOTE(mrry): This implementation is kept private, to avoid defining
  `Optional.__init__()` in the public API.
  """

  def __init__(self, variant_tensor, element_spec):
    super().__init__()
    self._variant_tensor = variant_tensor
    self._element_spec = element_spec

  def has_value(self, name=None):
    """Builds an `optional_has_value` op next to the wrapped variant tensor."""
    variant = self._variant_tensor
    with ops.colocate_with(variant):
      return gen_optional_ops.optional_has_value(variant, name=name)

  def get_value(self, name=None):
    """Extracts the flat components and rebuilds the structured value."""
    # TODO(b/110122868): Consolidate the restructuring logic with similar logic
    # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
    spec = self._element_spec
    variant = self._variant_tensor
    with ops.name_scope(name, "OptionalGetValue", [variant]) as scope:
      with ops.colocate_with(variant):
        flat_types = structure.get_flat_tensor_types(spec)
        flat_shapes = structure.get_flat_tensor_shapes(spec)
        flat_components = gen_optional_ops.optional_get_value(
            variant,
            name=scope,
            output_types=flat_types,
            output_shapes=flat_shapes,
        )
      # NOTE: We do not colocate the deserialization of composite tensors
      # because not all ops are guaranteed to have non-GPU kernels.
      return structure.from_tensor_list(spec, flat_components)

  @property
  def element_spec(self):
    """The element spec supplied at construction time."""
    return self._element_spec

  @property
  def _type_spec(self):
    # Derived on demand from this optional's element spec.
    return OptionalSpec.from_value(self)

203 

204 

@tf_export(
    "OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"])
class OptionalSpec(type_spec.TypeSpec):
  """Type specification for `tf.experimental.Optional`.

  A `tf.OptionalSpec` describes an optional by the type spec of the element it
  may hold. For example, it can be used in the input signature of a
  `tf.function` that takes a `tf.experimental.Optional` argument:

  >>> @tf.function(input_signature=[tf.OptionalSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def maybe_square(optional):
  ...   if optional.has_value():
  ...     x = optional.get_value()
  ...     return x * x
  ...   return -1
  >>> optional = tf.experimental.Optional.from_value(5)
  >>> print(maybe_square(optional))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A (nested) structure of `TypeSpec` objects that represents the
      type specification of the optional element.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    super().__init__()
    self._element_spec = element_spec

  @staticmethod
  def from_value(value):
    """Returns an `OptionalSpec` built from `value.element_spec`."""
    wrapped_spec = value.element_spec
    return OptionalSpec(wrapped_spec)

  @property
  def value_type(self):
    """The Python type of values described by this spec."""
    return _OptionalImpl

  def _serialize(self):
    # The element spec is the only state needed to rebuild this spec.
    state = (self._element_spec,)
    return state

  @property
  def _component_specs(self):
    # The whole optional is carried by a single scalar variant tensor.
    variant_spec = tensor_spec.TensorSpec((), dtypes.variant)
    return [variant_spec]

  def _to_components(self, value):
    # pylint: disable=protected-access
    variant_tensor = value._variant_tensor
    return [variant_tensor]

  def _from_components(self, flat_value):
    # pylint: disable=protected-access
    variant_tensor = flat_value[0]
    return _OptionalImpl(variant_tensor, self._element_spec)

  # The legacy types/shapes/classes views are all represented by the spec.
  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self


# Register `OptionalSpec` with the saved-model nested structure coder under the
# `OPTIONAL_SPEC` enum value of `struct_pb2.TypeSpecProto`.
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        OptionalSpec,
        struct_pb2.TypeSpecProto.OPTIONAL_SPEC,
    )
)