Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/model_config.py: 53%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=protected-access 

16"""Functions that save the model's config into different formats.""" 

17 

18from tensorflow.python.keras.saving.saved_model import json_utils 

19from tensorflow.python.util.tf_export import keras_export 

20 

21 

@keras_export('keras.models.model_from_config')
def model_from_config(config, custom_objects=None):
  """Builds a Keras model (or layer) from a serialized config.

  The config is handed to `tensorflow.python.keras.layers.deserialize`, so it
  must be in the form that function accepts.

  Usage:
  ```
  # for a Functional API model
  tf.keras.Model().from_config(model.get_config())

  # for a Sequential model
  tf.keras.Sequential().from_config(model.get_config())
  ```

  Args:
    config: Configuration dictionary to deserialize.
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.

  Returns:
    Whatever `deserialize` builds from `config`; for a model config this is
    a Keras model instance (uncompiled).

  Raises:
    TypeError: if `config` is a list. Other non-dict inputs are not checked
      here and are passed straight to `deserialize`.
  """
  # Only lists are rejected up front. Per the error message, a list is most
  # likely a Sequential layer list passed here by mistake.
  if not isinstance(config, list):
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    return deserialize_layer(config, custom_objects=custom_objects)
  raise TypeError('`model_from_config` expects a dictionary, not a list. '
                  'Maybe you meant to use '
                  '`Sequential.from_config(config)`?')

53 

54 

@keras_export('keras.models.model_from_yaml')
def model_from_yaml(yaml_string, custom_objects=None):
  """Former YAML model loader; now always refuses to run.

  Note: Since TF 2.6, this method is no longer supported and will raise a
  RuntimeError. Serialize with `Model.to_json()` and load with
  `model_from_json()` instead.

  Args:
    yaml_string: YAML string or open file encoding a model configuration.
      Ignored.
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization. Ignored.

  Returns:
    Never returns.

  Raises:
    RuntimeError: unconditionally, because YAML loading was removed over the
      risk of arbitrary code execution.
  """
  del yaml_string, custom_objects  # Unused: the loader is disabled.
  removal_notice = ('Method `model_from_yaml()` has been removed due to '
                    'security risk of arbitrary code execution. Please use '
                    '`Model.to_json()` and `model_from_json()` instead.')
  raise RuntimeError(removal_notice)

79 

80 

@keras_export('keras.models.model_from_json')
def model_from_json(json_string, custom_objects=None):
  """Rebuilds a model from the JSON text produced by `Model.to_json()`.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> config = model.to_json()
  >>> loaded_model = tf.keras.models.model_from_json(config)

  Args:
    json_string: JSON string encoding a model configuration. It is decoded
      with `json_utils.decode`.
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.

  Returns:
    A Keras model instance (uncompiled).
  """
  # Parse the text before touching the layers package, so decoding errors
  # surface first.
  model_config = json_utils.decode(json_string)
  # The layers import is deferred to call time, presumably to avoid a
  # circular import at module load -- TODO confirm.
  from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
  return deserialize_layer(model_config, custom_objects=custom_objects)