Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/nested_structure_coder.py: 39%
252 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module that encodes (decodes) nested structures into (from) protos.
17The intended use is to serialize everything needed to restore a `Function` that
18was saved into a SavedModel. This may include concrete function inputs and
19outputs, signatures, function specs, etc.
21Example use:
22# Encode into proto.
23signature_proto = nested_structure_coder.encode_structure(
24 function.input_signature)
25# Decode into a Python object.
26restored_signature = nested_structure_coder.decode_proto(signature_proto)
27"""
29import collections
30import functools
31import warnings
33from tensorflow.core.protobuf import struct_pb2
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import type_spec_registry
36from tensorflow.python.types import internal
37from tensorflow.python.util import compat
38from tensorflow.python.util import nest
39from tensorflow.python.util.compat import collections_abc
40from tensorflow.python.util.tf_export import tf_export
class NotEncodableError(Exception):
  """Raised when no registered codec accepts a value (encoding or decoding)."""
def register_codec(x):
  """Adds a codec to the module-level registry `_codecs`.

  `_map_structure` scans the registry from its end, so a codec registered
  later is consulted before codecs registered earlier.

  Args:
    x: Codec instance providing `can_encode`, `do_encode`, `can_decode` and
      `do_decode`. The private `_*Codec` classes in this module show the
      expected shape.
  """
  registry = _codecs
  registry.append(x)
def _get_encoders():
  """Returns `(can_encode, do_encode)` pairs, one per codec, in registry order."""
  return [(codec.can_encode, codec.do_encode) for codec in _codecs]
def _get_decoders():
  """Returns `(can_decode, do_decode)` pairs, one per codec, in registry order."""
  return [(codec.can_decode, codec.do_decode) for codec in _codecs]
def _map_structure(pyobj, coders):
  """Dispatches `pyobj` to the newest coder whose predicate accepts it.

  Args:
    pyobj: A Python value to encode, or a proto to decode.
    coders: List of `(predicate, transform)` pairs in registration order.

  Returns:
    `transform(pyobj, recurse)` for the first pair, scanning from the END of
    `coders`, whose predicate returns a truthy value. `recurse` re-enters this
    function with the same `coders` so that nested values can be handled.

  Raises:
    NotEncodableError: If no predicate accepts `pyobj`.
  """
  # Later registrations are more specific, so they are checked first.
  recurse = functools.partial(_map_structure, coders=coders)
  for predicate, transform in reversed(coders):
    if predicate(pyobj):
      return transform(pyobj, recurse)
  raise NotEncodableError(
      f"No encoder for object {pyobj!s} of type {type(pyobj)}.")
@tf_export("__internal__.saved_model.encode_structure", v1=[])
def encode_structure(nested_structure):
  """Serializes a nested structure of encodable values into a proto.

  Args:
    nested_structure: The (possibly nested) value to serialize.

  Returns:
    The proto produced by the matching codec(s).

  Raises:
    NotEncodableError: If some value in the structure has no encoder.
  """
  encoders = _get_encoders()
  return _map_structure(nested_structure, encoders)
def can_encode(nested_structure):
  """Reports whether `encode_structure` succeeds on `nested_structure`.

  The structure is fully encoded and the result discarded. Only
  `NotEncodableError` counts as "not encodable"; any other exception raised
  during encoding propagates to the caller.

  Args:
    nested_structure: The value to test.

  Returns:
    True if encoding completed, False if it raised `NotEncodableError`.
  """
  try:
    encode_structure(nested_structure)
  except NotEncodableError:
    return False
  else:
    return True
@tf_export("__internal__.saved_model.decode_proto", v1=[])
def decode_proto(proto):
  """Rebuilds the Python structure that a proto represents.

  Args:
    proto: The proto to decode.

  Returns:
    The decoded Python value.

  Raises:
    NotEncodableError: If some part of the proto has no decoder.
  """
  decoders = _get_decoders()
  return _map_structure(proto, decoders)
class _ListCodec:
  """Codec mapping Python `list`s to the `list_value` field and back."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, list)

  def do_encode(self, list_value, encode_fn):
    """Encodes each element with `encode_fn`, preserving order."""
    result = struct_pb2.StructuredValue()
    # Copy in an empty ListValue so that `list_value` is marked present even
    # when the Python list has no elements (`can_decode` checks HasField).
    result.list_value.CopyFrom(struct_pb2.ListValue())
    repeated = result.list_value.values
    for item in list_value:
      repeated.add().CopyFrom(encode_fn(item))
    return result

  def can_decode(self, value):
    return value.HasField("list_value")

  def do_decode(self, value, decode_fn):
    """Decodes each element with `decode_fn` into a new list."""
    return list(map(decode_fn, value.list_value.values))
def _is_tuple(obj):
  """Returns True for tuples that are not namedtuples."""
  return isinstance(obj, tuple) and not _is_named_tuple(obj)


def _is_named_tuple(instance):
  """Returns True iff `instance` looks like a `namedtuple`.

  A value qualifies when it is a tuple whose `_fields` attribute is a sequence
  made up only of strings.

  Args:
    instance: Any Python object.

  Returns:
    True if `instance` is a `namedtuple`, otherwise False.
  """
  if isinstance(instance, tuple) and hasattr(instance, "_fields"):
    field_names = instance._fields
    if isinstance(field_names, collections_abc.Sequence):
      return all(isinstance(name, str) for name in field_names)
  return False
class _TupleCodec:
  """Codec mapping plain (non-named) tuples to the `tuple_value` field."""

  def can_encode(self, pyobj):
    return _is_tuple(pyobj)

  def do_encode(self, tuple_value, encode_fn):
    """Encodes each element with `encode_fn`, preserving order."""
    result = struct_pb2.StructuredValue()
    # Copy in an empty TupleValue so the field is marked present even for ().
    result.tuple_value.CopyFrom(struct_pb2.TupleValue())
    repeated = result.tuple_value.values
    for item in tuple_value:
      repeated.add().CopyFrom(encode_fn(item))
    return result

  def can_decode(self, value):
    return value.HasField("tuple_value")

  def do_decode(self, value, decode_fn):
    """Decodes each element with `decode_fn` into a new tuple."""
    return tuple(map(decode_fn, value.tuple_value.values))
class _DictCodec:
  """Codec mapping Python `dict`s to the `dict_value.fields` map."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, dict)

  def do_encode(self, dict_value, encode_fn):
    """Encodes every value with `encode_fn` under its original key.

    NOTE(review): keys are written directly as `dict_value.fields` map keys;
    presumably they must be strings (the proto map key type). Confirm this
    against struct.proto.
    """
    result = struct_pb2.StructuredValue()
    # Copy in an empty DictValue so the field is marked present even for {}.
    result.dict_value.CopyFrom(struct_pb2.DictValue())
    fields = result.dict_value.fields
    for name, entry in dict_value.items():
      fields[name].CopyFrom(encode_fn(entry))
    return result

  def can_decode(self, value):
    return value.HasField("dict_value")

  def do_decode(self, value, decode_fn):
    """Decodes every map entry with `decode_fn` into a new dict."""
    decoded = {}
    for name, entry in value.dict_value.fields.items():
      decoded[name] = decode_fn(entry)
    return decoded
class _NamedTupleCodec:
  """Codec for namedtuples.

  Encoding and decoding a namedtuple reconstructs a namedtuple with a different
  actual Python type, but with the same `typename` and `fields`.
  """

  def can_encode(self, pyobj):
    """Returns True if `pyobj` passes `_is_named_tuple`."""
    return _is_named_tuple(pyobj)

  def do_encode(self, named_tuple_value, encode_fn):
    """Encodes a namedtuple as its class name plus ordered (field, value) pairs.

    Args:
      named_tuple_value: The namedtuple instance to encode.
      encode_fn: Callback used to recursively encode each field value.

    Returns:
      A `struct_pb2.StructuredValue` with `named_tuple_value` set.
    """
    encoded_named_tuple = struct_pb2.StructuredValue()
    encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue())
    encoded_named_tuple.named_tuple_value.name = (
        named_tuple_value.__class__.__name__)
    # Pair each field name with its positional value. For a real namedtuple
    # this is exactly what `_asdict()` computes, but it avoids rebuilding a
    # dict for every field. It also works for tuple subclasses that define
    # `_fields` (which `_is_named_tuple` accepts) but not `_asdict`.
    for key, field_value in zip(named_tuple_value._fields, named_tuple_value):
      pair = encoded_named_tuple.named_tuple_value.values.add()
      pair.key = key
      pair.value.CopyFrom(encode_fn(field_value))
    return encoded_named_tuple

  def can_decode(self, value):
    """Returns True if `value` has its `named_tuple_value` field set."""
    return value.HasField("named_tuple_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a namedtuple from its encoded name and (field, value) pairs.

    Args:
      value: A `StructuredValue` with `named_tuple_value` set.
      decode_fn: Callback used to recursively decode each field value.

    Returns:
      An instance of a freshly created `collections.namedtuple` type with the
      encoded typename and field names.
    """
    key_value_pairs = value.named_tuple_value.values
    field_names = [pair.key for pair in key_value_pairs]
    field_values = [decode_fn(pair.value) for pair in key_value_pairs]
    # `rename=True` lets names such as `_1` round-trip. `namedtuple(...,
    # rename=True)` produces such names, but `collections.namedtuple` would
    # otherwise reject them. Valid field names are left unchanged; rejected
    # names become positional `_<index>` names.
    named_tuple_type = collections.namedtuple(
        value.named_tuple_value.name, field_names, rename=True)
    # Construct positionally so values line up with fields even when a name
    # had to be renamed.
    return named_tuple_type(*field_values)
class _Float64Codec:
  """Codec for Python `float`s, stored in the `float64_value` field."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, float)

  def do_encode(self, float64_value, encode_fn):
    del encode_fn  # Scalars have no nested values to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.float64_value = float64_value
    return encoded

  def can_decode(self, value):
    return value.HasField("float64_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Scalars have no nested values to recurse into.
    return value.float64_value
class _Int64Codec:
  """Codec for Python `int`s (excluding `bool`), limited to 64-bit values."""

  def can_encode(self, pyobj):
    # `bool` subclasses `int`; booleans are handled by `_BoolCodec` instead.
    return isinstance(pyobj, int) and not isinstance(pyobj, bool)

  def do_encode(self, int_value, encode_fn):
    del encode_fn  # Scalars have no nested values to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.int64_value = int_value
    return encoded

  def can_decode(self, value):
    return value.HasField("int64_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Scalars have no nested values to recurse into.
    return int(value.int64_value)
class _StringCodec:
  """Codec for Python `str`s, stored in the `string_value` field.

  See StructuredValue.string_value in proto/struct.proto for more detailed
  explanation.
  """

  def can_encode(self, pyobj):
    return isinstance(pyobj, str)

  def do_encode(self, string_value, encode_fn):
    del encode_fn  # Scalars have no nested values to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.string_value = string_value
    return encoded

  def can_decode(self, value):
    return value.HasField("string_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Scalars have no nested values to recurse into.
    # Normalize to `str` via `compat.as_str`.
    return compat.as_str(value.string_value)
class _NoneCodec:
  """Codec for `None`, stored as an empty `NoneValue` message."""

  def can_encode(self, pyobj):
    return pyobj is None

  def do_encode(self, none_value, encode_fn):
    # Only the presence of the `none_value` field carries information.
    del none_value, encode_fn
    encoded = struct_pb2.StructuredValue()
    encoded.none_value.CopyFrom(struct_pb2.NoneValue())
    return encoded

  def can_decode(self, value):
    return value.HasField("none_value")

  def do_decode(self, value, decode_fn):
    del value, decode_fn
    return None
class _BoolCodec:
  """Codec for Python `bool`s, stored in the `bool_value` field."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, bool)

  def do_encode(self, bool_value, encode_fn):
    del encode_fn  # Scalars have no nested values to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.bool_value = bool_value
    return encoded

  def can_decode(self, value):
    return value.HasField("bool_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Scalars have no nested values to recurse into.
    return value.bool_value
class _TensorTypeCodec:
  """Codec for `tf.DType`, stored as its DataType enum in `tensor_dtype_value`."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, dtypes.DType)

  def do_encode(self, tensor_dtype_value, encode_fn):
    del encode_fn  # DTypes have no nested values to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.tensor_dtype_value = tensor_dtype_value.as_datatype_enum
    return encoded

  def can_decode(self, value):
    return value.HasField("tensor_dtype_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # DTypes have no nested values to recurse into.
    return dtypes.DType(value.tensor_dtype_value)
class BuiltInTypeSpecCodec:
  """Generic codec for TensorFlow's built-in `TypeSpec` classes.

  A built-in TypeSpec that needs no bespoke codec registers itself by
  constructing an instance of this class and passing it to `register_codec`.
  Each instance handles exactly one TypeSpec class, which is identified in the
  proto by `type_spec_proto_enum`.

  Attributes:
    type_spec_class: The built-in TypeSpec class this codec handles.
    type_spec_proto_enum: The `TypeSpecProto.type_spec_class` value written
      for `type_spec_class`.
  """

  # Class-level lists shared by every instance. They let the constructor
  # reject a second codec for the same TypeSpec class or the same enum value.
  _BUILT_IN_TYPE_SPEC_PROTOS = []
  _BUILT_IN_TYPE_SPECS = []

  def __init__(self, type_spec_class, type_spec_proto_enum):
    """Validates the arguments and records them in the shared lists.

    Args:
      type_spec_class: A `tf.TypeSpec` subclass.
      type_spec_proto_enum: An int in the range [1, 10].

    Raises:
      ValueError: If `type_spec_class` is not a `tf.TypeSpec` subclass or
        already has a codec, or if `type_spec_proto_enum` is already
        registered or is not an int in [1, 10].
    """
    if not issubclass(type_spec_class, internal.TypeSpec):
      raise ValueError(
          f"The type '{type_spec_class}' does not subclass tf.TypeSpec.")
    if type_spec_class in self._BUILT_IN_TYPE_SPECS:
      raise ValueError(
          f"The type '{type_spec_class}' already has an instantiated codec.")
    if type_spec_proto_enum in self._BUILT_IN_TYPE_SPEC_PROTOS:
      raise ValueError(
          f"The proto value '{type_spec_proto_enum}' is already registered.")
    enum_in_range = (isinstance(type_spec_proto_enum, int) and
                     0 < type_spec_proto_enum <= 10)
    if not enum_in_range:
      raise ValueError(f"The proto value '{type_spec_proto_enum}' is invalid.")

    self.type_spec_class = type_spec_class
    self.type_spec_proto_enum = type_spec_proto_enum
    self._BUILT_IN_TYPE_SPECS.append(type_spec_class)
    self._BUILT_IN_TYPE_SPEC_PROTOS.append(type_spec_proto_enum)

  def can_encode(self, pyobj):
    """Returns true if `pyobj` can be encoded as the built-in TypeSpec."""
    return isinstance(pyobj, self.type_spec_class)

  def do_encode(self, type_spec_value, encode_fn):
    """Returns an encoded proto for the given built-in TypeSpec.

    The TypeSpec's serialized state is encoded recursively with `encode_fn`.
    The number of flattened component specs is recorded alongside it.
    """
    # pylint: disable=protected-access
    serialized_state = type_spec_value._serialize()
    flat_components = nest.flatten(
        type_spec_value._component_specs, expand_composites=True)
    # pylint: enable=protected-access
    type_spec_proto = struct_pb2.TypeSpecProto(
        type_spec_class=self.type_spec_proto_enum,
        type_state=encode_fn(serialized_state),
        type_spec_class_name=self.type_spec_class.__name__,
        num_flat_components=len(flat_components))
    encoded_type_spec = struct_pb2.StructuredValue()
    encoded_type_spec.type_spec_value.CopyFrom(type_spec_proto)
    return encoded_type_spec

  def can_decode(self, value):
    """Returns true if `value` can be decoded into its built-in TypeSpec."""
    return (value.HasField("type_spec_value") and
            value.type_spec_value.type_spec_class == self.type_spec_proto_enum)

  def do_decode(self, value, decode_fn):
    """Returns the built in `TypeSpec` encoded by the proto `value`."""
    state_proto = value.type_spec_value.type_state
    # pylint: disable=protected-access
    return self.type_spec_class._deserialize(decode_fn(state_proto))
# TODO(b/238903802): Use TraceType serialization and specific protos.
class _TypeSpecCodec:
  """Codec for `tf.TypeSpec` instances.

  Handles TypeSpec classes listed in `TYPE_SPEC_CLASS_FROM_PROTO` (currently
  empty) and any TypeSpec class known to `type_spec_registry`.
  `can_decode` accepts every `type_spec_value`, so codecs registered after
  this one, such as `BuiltInTypeSpecCodec` instances, get the first chance to
  decode a TypeSpec proto.
  """

  # Maps proto enum value -> TypeSpec subclass. Intentionally empty, but kept
  # for backwards-compatibility until all external usages have been removed.
  TYPE_SPEC_CLASS_FROM_PROTO = {
  }

  # Inverse mapping (TypeSpec subclass -> enum value). It is computed once,
  # when the class body runs.
  TYPE_SPEC_CLASS_TO_PROTO = {
      cls: enum for enum, cls in TYPE_SPEC_CLASS_FROM_PROTO.items()
  }

  def can_encode(self, pyobj):
    """Returns true if `pyobj` can be encoded as a TypeSpec."""
    # pylint: disable=unidiomatic-typecheck
    if type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO:
      return True
    # pylint: enable=unidiomatic-typecheck
    if not isinstance(pyobj, internal.TypeSpec):
      return False
    # Otherwise the TypeSpec is encodable only if its type is registered.
    try:
      type_spec_registry.get_name(type(pyobj))
    except ValueError:
      return False
    return True

  def do_encode(self, type_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.TypeSpec`.

    Raises:
      ValueError: Propagated from `type_spec_registry.get_name` when the type
        is neither in `TYPE_SPEC_CLASS_TO_PROTO` nor registered.
    """
    spec_type = type(type_spec_value)
    type_spec_class = self.TYPE_SPEC_CLASS_TO_PROTO.get(spec_type)
    if type_spec_class is None:
      type_spec_class_name = type_spec_registry.get_name(spec_type)
      type_spec_class = struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC
      # Saving registered TypeSpecs is experimental; warn about the loading
      # requirement.
      warnings.warn("Encoding a StructuredValue with type %s; loading this "
                    "StructuredValue will require that this type be "
                    "imported and registered." % type_spec_class_name)
    else:
      type_spec_class_name = spec_type.__name__

    # pylint: disable=protected-access
    serialized_state = type_spec_value._serialize()
    flat_components = nest.flatten(
        type_spec_value._component_specs, expand_composites=True)
    # pylint: enable=protected-access
    type_spec_proto = struct_pb2.TypeSpecProto(
        type_spec_class=type_spec_class,
        type_state=encode_fn(serialized_state),
        type_spec_class_name=type_spec_class_name,
        num_flat_components=len(flat_components))
    encoded_type_spec = struct_pb2.StructuredValue()
    encoded_type_spec.type_spec_value.CopyFrom(type_spec_proto)
    return encoded_type_spec

  def can_decode(self, value):
    """Returns true if `value` can be decoded into a `tf.TypeSpec`."""
    return value.HasField("type_spec_value")

  def do_decode(self, value, decode_fn):
    """Returns the `tf.TypeSpec` encoded by the proto `value`.

    Raises:
      ValueError: If the proto refers to a registered type that is not
        registered in this process, or to an enum value missing from
        `TYPE_SPEC_CLASS_FROM_PROTO`.
    """
    type_spec_proto = value.type_spec_value
    class_enum = type_spec_proto.type_spec_class
    class_name = type_spec_proto.type_spec_class_name

    if class_enum != struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC:
      if class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO:
        raise ValueError(
            f"The type '{class_name}' is not supported by this version of "
            "TensorFlow. (The object you are loading must have been created "
            "with a newer version of TensorFlow.)")
      type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[class_enum]
    else:
      try:
        type_spec_class = type_spec_registry.lookup(class_name)
      except ValueError as e:
        raise ValueError(
            f"The type '{class_name}' has not been registered. It must be "
            "registered before you load this object (typically by importing "
            "its module).") from e

    # pylint: disable=protected-access
    return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state))
# Default codec registry; `register_codec` appends further codecs to it.
# Order matters: `_map_structure` scans this list from the end, so later
# entries are tried before earlier ones.
_codecs = [
    codec_class() for codec_class in (
        _ListCodec,
        _TupleCodec,
        _NamedTupleCodec,
        _StringCodec,
        _Float64Codec,
        _NoneCodec,
        _Int64Codec,
        _BoolCodec,
        _DictCodec,
        _TypeSpecCodec,
        _TensorTypeCodec,
    )
]