Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/backend_config.py: 94%
32 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras backend config API."""
17import tensorflow.compat.v2 as tf
19# isort: off
20from tensorflow.python.util.tf_export import keras_export
# Module-level defaults read by the getters below and overwritten (via
# `global`) by the corresponding setters.

# Default float type name; see `floatx()` / `set_floatx()`.
_FLOATX: str = "float32"

# Fuzz factor for numeric expressions; see `epsilon()` / `set_epsilon()`.
_EPSILON: float = 1e-7

# Image layout convention, "channels_last" or "channels_first"; see
# `image_data_format()` / `set_image_data_format()`.
_IMAGE_DATA_FORMAT: str = "channels_last"
@keras_export("keras.backend.epsilon")
@tf.__internal__.dispatch.add_dispatch_support
def epsilon():
    """Fetch the fuzz factor currently used in numeric expressions.

    The value is a module-wide setting; `set_epsilon` replaces it.

    Returns:
        A float.

    Example:
    >>> tf.keras.backend.epsilon()
    1e-07
    """
    fuzz = _EPSILON
    return fuzz
@keras_export("keras.backend.set_epsilon")
def set_epsilon(value):
    """Replace the module-wide fuzz factor used in numeric expressions.

    The new value is stored exactly as given: no type or range checking is
    performed, and subsequent calls to `epsilon()` return it.

    Args:
        value: float. New value of epsilon.

    Example:
    >>> tf.keras.backend.epsilon()
    1e-07
    >>> tf.keras.backend.set_epsilon(1e-5)
    >>> tf.keras.backend.epsilon()
    1e-05
    >>> tf.keras.backend.set_epsilon(1e-7)
    """
    # Rebind the module global rather than creating a function local.
    global _EPSILON
    _EPSILON = value
@keras_export("keras.backend.floatx")
def floatx():
    """Report the name of the current default float type.

    Typical values are `'float16'`, `'float32'` and `'float64'`; the
    setting is changed with `set_floatx`.

    Returns:
        String, the current default float type.

    Example:
    >>> tf.keras.backend.floatx()
    'float32'
    """
    current_dtype_name = _FLOATX
    return current_dtype_name
@keras_export("keras.backend.set_floatx")
def set_floatx(value):
    """Sets the default float type.

    Note: It is not recommended to set this to float16 for training, as this
    will likely cause numeric stability issues. Instead, mixed precision, which
    is using a mix of float16 and float32, can be used by calling
    `tf.keras.mixed_precision.set_global_policy('mixed_float16')`. See the
    [mixed precision guide](
      https://www.tensorflow.org/guide/keras/mixed_precision) for details.

    Args:
        value: String; `'float16'`, `'float32'`, or `'float64'`.

    Example:
    >>> tf.keras.backend.floatx()
    'float32'
    >>> tf.keras.backend.set_floatx('float64')
    >>> tf.keras.backend.floatx()
    'float64'
    >>> tf.keras.backend.set_floatx('float32')

    Raises:
        ValueError: In case of invalid value, including non-string inputs
            (e.g. a list).
    """
    global _FLOATX
    accepted_dtypes = {"float16", "float32", "float64"}
    # Check the type first: a set lookup on an unhashable value (e.g. a
    # list) would raise TypeError instead of the documented ValueError.
    if not isinstance(value, str) or value not in accepted_dtypes:
        raise ValueError(
            f"Unknown `floatx` value: {value}. "
            # Sorted so the message is identical across runs; set iteration
            # order of strings depends on hash randomization.
            f"Expected one of {sorted(accepted_dtypes)}"
        )
    # `str()` normalizes str subclasses to a plain str.
    _FLOATX = str(value)
@keras_export("keras.backend.image_data_format")
@tf.__internal__.dispatch.add_dispatch_support
def image_data_format():
    """Report the image data format convention currently in effect.

    The setting is module-wide and is changed with `set_image_data_format`.

    Returns:
        A string, either `'channels_first'` or `'channels_last'`

    Example:
    >>> tf.keras.backend.image_data_format()
    'channels_last'
    """
    layout = _IMAGE_DATA_FORMAT
    return layout
@keras_export("keras.backend.set_image_data_format")
def set_image_data_format(data_format):
    """Sets the value of the image data format convention.

    Args:
        data_format: string. `'channels_first'` or `'channels_last'`.

    Example:
    >>> tf.keras.backend.image_data_format()
    'channels_last'
    >>> tf.keras.backend.set_image_data_format('channels_first')
    >>> tf.keras.backend.image_data_format()
    'channels_first'
    >>> tf.keras.backend.set_image_data_format('channels_last')

    Raises:
        ValueError: In case of invalid `data_format` value, including
            non-string inputs (e.g. a list).
    """
    global _IMAGE_DATA_FORMAT
    accepted_formats = {"channels_last", "channels_first"}
    # Check the type first: a set lookup on an unhashable value (e.g. a
    # list) would raise TypeError instead of the documented ValueError.
    if not isinstance(data_format, str) or data_format not in accepted_formats:
        raise ValueError(
            f"Unknown `data_format`: {data_format}. "
            # Sorted so the message is identical across runs; set iteration
            # order of strings depends on hash randomization.
            f"Expected one of {sorted(accepted_formats)}"
        )
    # `str()` normalizes str subclasses to a plain str.
    _IMAGE_DATA_FORMAT = str(data_format)