Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/model_config.py: 52%
21 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Functions that recreate a Keras model from its config in different formats."""
18# isort: off
20import threading
21from tensorflow.python.util.tf_export import keras_export
22from keras.src.saving.legacy import serialization
# Cache for the table of deserializable Keras objects that
# `model_from_config` passes as `module_objects`. Because this is a
# `threading.local()`, each thread sees its own attributes, so the table is
# filled lazily (on first use) separately in every thread.
MODULE_OBJECTS = threading.local()
@keras_export("keras.models.model_from_config")
def model_from_config(config, custom_objects=None):
    """Instantiates a Keras model from its config.

    Usage:
    ```
    # for a Functional API model
    tf.keras.Model.from_config(model.get_config())

    # for a Sequential model
    tf.keras.Sequential.from_config(model.get_config())
    ```

    Args:
        config: Configuration dictionary.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A Keras model instance (uncompiled).

    Raises:
        TypeError: if `config` is a list rather than a dictionary (for
            example, a bare list of layer configs).
    """
    # Only lists are rejected explicitly; every other value is forwarded to
    # `deserialize_keras_object` unchanged.
    if isinstance(config, list):
        raise TypeError(
            "`model_from_config` expects a dictionary, not a list. "
            f"Received: config={config}. Did you mean to use "
            "`Sequential.from_config(config)`?"
        )
    # Imported lazily, presumably to avoid an import cycle with
    # `keras.src.layers` at module load time -- TODO confirm.
    from keras.src import layers

    # `MODULE_OBJECTS` is a `threading.local()`, so the lookup table is
    # populated once per thread. The global name is only attribute-assigned,
    # never rebound, so no `global` statement is required here.
    if not hasattr(MODULE_OBJECTS, "ALL_OBJECTS"):
        layers.serialization.populate_deserializable_objects()
        MODULE_OBJECTS.ALL_OBJECTS = layers.serialization.LOCAL.ALL_OBJECTS

    return serialization.deserialize_keras_object(
        config,
        module_objects=MODULE_OBJECTS.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="layer",
    )
@keras_export("keras.models.model_from_yaml")
def model_from_yaml(yaml_string, custom_objects=None):
    """Removed API kept only to fail loudly: always raises `RuntimeError`.

    Since TF 2.6 loading a model from YAML is no longer supported, because of
    the risk of arbitrary code execution. Callers are pointed at the JSON
    round-trip (`Model.to_json()` / `model_from_json()`) instead.

    Args:
        yaml_string: YAML string or open file encoding a model configuration.
            Not used.
        custom_objects: Optional dictionary mapping names (strings) to custom
            classes or functions. Not used.

    Raises:
        RuntimeError: on every call, describing the security risk and the
            JSON-based replacement.
    """
    message = " ".join(
        (
            "Method `model_from_yaml()` has been removed due to security risk of",
            "arbitrary code execution. Please use `Model.to_json()` and",
            "`model_from_json()` instead.",
        )
    )
    raise RuntimeError(message)
@keras_export("keras.models.model_from_json")
def model_from_json(json_string, custom_objects=None):
    """Builds an uncompiled Keras model from a JSON configuration string.

    Usage:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(5, input_shape=(3,)),
    ...     tf.keras.layers.Softmax()])
    >>> config = model.to_json()
    >>> loaded_model = tf.keras.models.model_from_json(config)

    Args:
        json_string: JSON string encoding a model configuration, such as the
            output of `Model.to_json()`.
        custom_objects: Optional dictionary mapping names (strings) to custom
            classes or functions to be considered during deserialization.

    Returns:
        A Keras model instance (uncompiled).
    """
    # Deferred import, presumably to avoid an import cycle between this module
    # and `keras.src.layers` -- TODO confirm.
    from keras.src import layers as layers_module

    parse_and_build = layers_module.deserialize_from_json
    return parse_and_build(json_string, custom_objects=custom_objects)