Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/core/function/trace_type/serialization.py: 54%

39 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-05 06:32 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utils for serializing and deserializing TraceTypes.""" 

16 

17import abc 

18from typing import Type 

19 

20from google.protobuf import message 

21from tensorflow.core.function.trace_type import serialization_pb2 

22 

# Short alias for the generated proto message used as the serialized envelope.
SerializedTraceType = serialization_pb2.SerializedTraceType

# Registry keyed by proto message class, valued by the Python class registered
# for it. Written by register_serializable() and scanned by deserialize().
PROTO_CLASS_TO_PY_CLASS = {}

26 

27 

class Serializable(metaclass=abc.ABCMeta):
  """Interface for TraceTypes that are portable via a proto representation.

  Implementations declare which proto message class represents them and
  provide conversions in both directions between instances and that proto.
  """

  @classmethod
  @abc.abstractmethod
  def experimental_type_proto(cls) -> Type[message.Message]:
    """Returns the proto message class uniquely associated with this class."""
    raise NotImplementedError

  @classmethod
  @abc.abstractmethod
  def experimental_from_proto(cls, proto: message.Message) -> "Serializable":
    """Builds and returns an instance from the given proto."""
    raise NotImplementedError

  @abc.abstractmethod
  def experimental_as_proto(self) -> message.Message:
    """Returns the proto that represents this instance."""
    raise NotImplementedError

47 

48 

def register_serializable(cls: Type[Serializable]):
  """Adds a Python class to the proto-class registry used for serialization.

  Only register standard TF types. Custom types should NOT be registered.

  Args:
    cls: Python class to register.

  Raises:
    ValueError: If the proto class reported by `cls.experimental_type_proto()`
      is already a key in `PROTO_CLASS_TO_PY_CLASS`.
  """
  already_taken = cls.experimental_type_proto() in PROTO_CLASS_TO_PY_CLASS

  # Common case: the proto class is unclaimed, so record the mapping and stop.
  if not already_taken:
    PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()] = cls
    return

  # Conflict: name both the current owner and the newcomer in the error.
  owner_name = PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()].__name__
  proto_name = cls.experimental_type_proto().__name__
  raise ValueError(
      "Existing Python class " + owner_name + " already has " + proto_name +
      " as its associated proto representation. Please ensure " +
      cls.__name__ + " has a unique proto representation.")

66 

67 

def serialize(to_serialize: Serializable) -> SerializedTraceType:
  """Wraps a Serializable's proto inside a SerializedTraceType message.

  Args:
    to_serialize: The object to convert; must be a `Serializable` instance.

  Returns:
    A `SerializedTraceType` whose `representation` field has the object's
    proto packed into it.

  Raises:
    ValueError: If `to_serialize` is not a `Serializable`, or if the proto it
      produces is not an instance of its `experimental_type_proto()` class.
  """
  if not isinstance(to_serialize, Serializable):
    error_text = "".join((
        "Can not serialize ",
        type(to_serialize).__name__,
        " since it is not Serializable. For object ",
        str(to_serialize),
    ))
    raise ValueError(error_text)

  payload = to_serialize.experimental_as_proto()
  declared_proto_class = to_serialize.experimental_type_proto()

  # The instance must produce exactly the kind of proto its class advertises.
  if not isinstance(payload, declared_proto_class):
    error_text = "".join((
        type(to_serialize).__name__,
        " returned different type of proto than specified by ",
        "experimental_type_proto()",
    ))
    raise ValueError(error_text)

  envelope = SerializedTraceType()
  envelope.representation.Pack(payload)
  return envelope

86 

87 

def deserialize(proto: SerializedTraceType) -> Serializable:
  """Converts a proto SerializedTraceType to an instance of Serializable.

  Scans the registered proto classes for the one whose descriptor matches the
  packed `representation` field, unpacks into it, and delegates construction
  to the registered Python class.

  Args:
    proto: The serialized message whose `representation` field holds the
      packed proto.

  Returns:
    The object built by the matching class's `experimental_from_proto()`.

  Raises:
    ValueError: If no registered proto class matches the packed message.
  """
  packed = proto.representation

  for proto_class, py_class in PROTO_CLASS_TO_PY_CLASS.items():
    if packed.Is(proto_class.DESCRIPTOR):
      actual_proto = proto_class()
      packed.Unpack(actual_proto)
      return py_class.experimental_from_proto(actual_proto)

  # Build ONE message string. Passing the fragments as separate arguments
  # would make str(error) render as a tuple instead of readable text.
  # `!r` is used because the payload is printed as a raw value, not decoded.
  raise ValueError(
      f"Can not deserialize proto of url: {packed.type_url}"
      " since no matching Python class could be found. For value "
      f"{packed.value!r}")