Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/core/function/trace_type/serialization.py: 54%
39 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utils for serializing and deserializing TraceTypes."""
17import abc
18from typing import Type
20from google.protobuf import message
21from tensorflow.core.function.trace_type import serialization_pb2
# Wrapper proto produced by `serialize` and consumed by `deserialize`; its
# `representation` field holds the packed type-specific proto.
SerializedTraceType = serialization_pb2.SerializedTraceType

# Registry filled by `register_serializable`: maps each proto message class to
# the Python class that was registered for it.
PROTO_CLASS_TO_PY_CLASS = dict()
class Serializable(metaclass=abc.ABCMeta):
  """Interface that makes a TraceType portable through protocol buffers.

  A TraceType implementing this interface names exactly one proto message
  type for itself, can convert an instance into that proto, and can rebuild
  an instance from it.
  """

  @classmethod
  @abc.abstractmethod
  def experimental_type_proto(cls) -> Type[message.Message]:
    """Returns the proto message class uniquely associated with this class."""
    raise NotImplementedError()

  @classmethod
  @abc.abstractmethod
  def experimental_from_proto(cls, proto: message.Message) -> "Serializable":
    """Builds and returns an instance of this class from `proto`."""
    raise NotImplementedError()

  @abc.abstractmethod
  def experimental_as_proto(self) -> message.Message:
    """Returns the proto that represents this instance."""
    raise NotImplementedError()
def register_serializable(cls: Type[Serializable]):
  """Registers a Python class to support serialization.

  Only register standard TF types. Custom types should NOT be registered.

  After registration, `cls` is stored in `PROTO_CLASS_TO_PY_CLASS` under the
  key `cls.experimental_type_proto()`.

  Args:
    cls: Python class to register.

  Returns:
    `cls` itself, so this function may also be used as a class decorator.

  Raises:
    ValueError: If a Python class is already registered for the proto type
      returned by `cls.experimental_type_proto()`.
  """
  # Query the proto type once instead of once per use.
  proto_class = cls.experimental_type_proto()

  if proto_class in PROTO_CLASS_TO_PY_CLASS:
    existing_name = PROTO_CLASS_TO_PY_CLASS[proto_class].__name__
    raise ValueError(
        f"Existing Python class {existing_name} already has "
        f"{proto_class.__name__} as its associated proto representation. "
        f"Please ensure {cls.__name__} has a unique proto representation.")

  PROTO_CLASS_TO_PY_CLASS[proto_class] = cls
  return cls
def serialize(to_serialize: Serializable) -> SerializedTraceType:
  """Packs a Serializable object into a SerializedTraceType proto.

  Args:
    to_serialize: The object to serialize.

  Returns:
    A new SerializedTraceType whose `representation` field holds the packed
    result of `to_serialize.experimental_as_proto()`.

  Raises:
    ValueError: If `to_serialize` is not a Serializable, or if the proto it
      returns is not an instance of its `experimental_type_proto()`.
  """
  if not isinstance(to_serialize, Serializable):
    raise ValueError(
        f"Can not serialize {type(to_serialize).__name__} since it is not "
        f"Serializable. For object {to_serialize!s}")

  proto_message = to_serialize.experimental_as_proto()
  expected_class = to_serialize.experimental_type_proto()
  # Guard against implementations whose two proto methods disagree.
  if not isinstance(proto_message, expected_class):
    raise ValueError(
        f"{type(to_serialize).__name__} returned different type of proto than "
        "specified by experimental_type_proto()")

  wrapper = SerializedTraceType()
  wrapper.representation.Pack(proto_message)
  return wrapper
def deserialize(proto: SerializedTraceType) -> Serializable:
  """Converts a proto SerializedTraceType to instance of Serializable.

  The registered proto classes in `PROTO_CLASS_TO_PY_CLASS` are checked in
  order. The first one matching the packed `representation` is unpacked and
  handed to the corresponding Python class's `experimental_from_proto`.

  Args:
    proto: The SerializedTraceType to convert.

  Returns:
    The object returned by `experimental_from_proto` of the matching
    registered Python class.

  Raises:
    ValueError: If no registered proto class matches the packed type.
  """
  representation = proto.representation
  for proto_class, py_class in PROTO_CLASS_TO_PY_CLASS.items():
    if representation.Is(proto_class.DESCRIPTOR):
      actual_proto = proto_class()
      representation.Unpack(actual_proto)
      return py_class.experimental_from_proto(actual_proto)

  # Build one message string: passing several positional arguments to
  # ValueError (as before) made str(error) render as a tuple repr.
  raise ValueError(
      f"Can not deserialize proto of url: {representation.type_url} since no "
      "matching Python class could be found. For value "
      f"{representation.value!r}")