Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py: 49%

55 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Layer serialization/deserialization functions. 

16""" 

17# pylint: disable=wildcard-import 

18# pylint: disable=unused-import 

19 

20import threading 

21 

22from tensorflow.python import tf2 

23from tensorflow.python.keras.engine import base_layer 

24from tensorflow.python.keras.engine import input_layer 

25from tensorflow.python.keras.engine import input_spec 

26from tensorflow.python.keras.layers import advanced_activations 

27from tensorflow.python.keras.layers import convolutional 

28from tensorflow.python.keras.layers import convolutional_recurrent 

29from tensorflow.python.keras.layers import core 

30from tensorflow.python.keras.layers import dense_attention 

31from tensorflow.python.keras.layers import embeddings 

32from tensorflow.python.keras.layers import merge 

33from tensorflow.python.keras.layers import pooling 

34from tensorflow.python.keras.layers import recurrent 

35from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 

36from tensorflow.python.keras.utils import generic_utils 

37from tensorflow.python.keras.utils import tf_inspect as inspect 

38from tensorflow.python.util.tf_export import keras_export 

39 

# Modules whose layer classes are always registered as deserializable
# objects by `populate_deserializable_objects`.
ALL_MODULES = (
    base_layer,
    input_layer,
    advanced_activations,
    convolutional,
    convolutional_recurrent,
    core,
    dense_attention,
    embeddings,
    merge,
    pooling,
    recurrent,
)
# Modules that are additionally registered only when `tf2.enabled()` is true;
# their classes overwrite same-named entries coming from ALL_MODULES.
ALL_V2_MODULES = (rnn_cell_wrapper_v2,)
# The ALL_OBJECTS registry is mutable shared state. It is kept on a
# thread-local object so that each thread mutates only its own copy.
LOCAL = threading.local()

47 

48 

def populate_deserializable_objects():
  """Populates the thread-local dict `LOCAL.ALL_OBJECTS` with built-in layers.

  The registry is (re)built when it is empty for the calling thread or when it
  was generated under a different `tf2.enabled()` setting than the current
  one. Otherwise the call does nothing.

  The registry is assembled in a local dict and published to `LOCAL` only once
  it is complete. If population raises part-way through (for example while
  importing `models`), the previously published state is left untouched, so a
  later call retries instead of treating a half-filled registry as up to date.
  """
  if not hasattr(LOCAL, 'ALL_OBJECTS'):
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None

  # Read the flag once so the recorded version and the V2 overrides below are
  # guaranteed to agree with each other.
  v2_enabled = tf2.enabled()

  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
    # Objects dict is already generated for the proper TF version:
    # do nothing.
    return

  layer_base = base_layer.Layer

  def _is_layer_class(obj):
    # Only classes deriving from the base Layer are registered from modules.
    return inspect.isclass(obj) and issubclass(obj, layer_base)

  objects = {}
  generic_utils.populate_dict_with_module_objects(
      objects, ALL_MODULES, obj_filter=_is_layer_class)

  # Overwrite certain V1 objects with V2 versions.
  if v2_enabled:
    generic_utils.populate_dict_with_module_objects(
        objects, ALL_V2_MODULES, obj_filter=_is_layer_class)

  # Imported here rather than at module level to prevent circular
  # dependencies.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

  objects.update({
      'Input': input_layer.Input,
      'InputSpec': input_spec.InputSpec,
      'Functional': models.Functional,
      'Model': models.Model,
      'Sequential': models.Sequential,
  })

  # Merge layers, function versions.
  objects.update({
      'add': merge.add,
      'subtract': merge.subtract,
      'multiply': merge.multiply,
      'average': merge.average,
      'maximum': merge.maximum,
      'minimum': merge.minimum,
      'concatenate': merge.concatenate,
      'dot': merge.dot,
  })

  # Publish only after everything above succeeded.
  LOCAL.ALL_OBJECTS = objects
  LOCAL.GENERATED_WITH_V2 = v2_enabled

96 

97 

@keras_export('keras.layers.serialize')
def serialize(layer):
  """Serializes a layer via `generic_utils.serialize_keras_object`.

  Args:
    layer: the object to serialize (presumably a Keras layer instance; this
      function itself performs no type check).

  Returns:
    The value returned by `generic_utils.serialize_keras_object(layer)`.
  """
  serialized = generic_utils.serialize_keras_object(layer)
  return serialized

101 

102 

@keras_export('keras.layers.deserialize')
def deserialize(config, custom_objects=None):
  """Builds a layer from its config dictionary.

  Args:
    config: dict of the form {'class_name': str, 'config': dict}
    custom_objects: dict mapping class names (or function names)
      of custom (non-Keras) objects to class/functions

  Returns:
    Layer instance (may be Model, Sequential, Network, Layer...)
  """
  # Make sure the calling thread's registry of built-in objects is populated
  # before it is read below.
  populate_deserializable_objects()
  lookup_kwargs = dict(
      module_objects=LOCAL.ALL_OBJECTS,
      custom_objects=custom_objects,
      printable_module_name='layer')
  return generic_utils.deserialize_keras_object(config, **lookup_kwargs)