Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/utils/keras_utils.py: 21%
53 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for tf.keras."""
17import tensorflow as tf
def is_tensor_or_variable(x):
    """Return whether `x` is a TensorFlow tensor or a `tf.Variable`.

    Args:
      x: Any Python object.

    Returns:
      True if `tf.is_tensor(x)` reports a tensor, or if `x` is an instance
      of `tf.Variable`; False otherwise.
    """
    if tf.is_tensor(x):
        return True
    # Checked separately from tf.is_tensor, as the original expression does.
    return isinstance(x, tf.Variable)
class LossFunctionWrapper(tf.keras.losses.Loss):
    """Adapts a plain loss function to the `tf.keras.losses.Loss` interface."""

    def __init__(
        self, fn, reduction=tf.keras.losses.Reduction.AUTO, name=None, **kwargs
    ):
        """Creates a `LossFunctionWrapper` around `fn`.

        Args:
          fn: Loss function to wrap. It is invoked as
            `fn(y_true, y_pred, **kwargs)`.
          reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
            loss. Default value is `AUTO`. `AUTO` indicates that the reduction
            option will be determined by the usage context. For almost all cases
            this defaults to `SUM_OVER_BATCH_SIZE`. When used with
            `tf.distribute.Strategy`, outside of built-in training loops such as
            `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
            will raise an error. Please see this custom training [tutorial](
            https://www.tensorflow.org/tutorials/distribute/custom_training)
            for more details.
          name: (Optional) name for the loss.
          **kwargs: Extra keyword arguments forwarded to `fn` on every call.
        """
        super().__init__(reduction=reduction, name=name)
        self.fn = fn
        # Stored privately; forwarded by `call` and serialized by `get_config`.
        self._fn_kwargs = kwargs

    def call(self, y_true, y_pred):
        """Evaluates the wrapped loss function.

        Args:
          y_true: Ground truth values.
          y_pred: The predicted values.

        Returns:
          Whatever the wrapped `fn` returns for these inputs (per-sample loss
          values, before the base class applies its reduction).
        """
        extra_kwargs = self._fn_kwargs
        return self.fn(y_true, y_pred, **extra_kwargs)

    def get_config(self):
        """Returns the base `Loss` config merged with the forwarded kwargs.

        Tensor or variable kwargs are converted to concrete values with
        `tf.keras.backend.eval`. Other kwargs are stored unchanged. On key
        clashes, the forwarded kwargs override the base config entries.
        """
        fn_config = {
            key: tf.keras.backend.eval(val) if is_tensor_or_variable(val) else val
            for key, val in self._fn_kwargs.items()
        }
        parent_config = super().get_config()
        return {**parent_config, **fn_config}
def normalize_data_format(value):
    """Validates a data-format string and returns it in lower case.

    Args:
      value: "channels_first" or "channels_last" (any letter case), or None
        to use `tf.keras.backend.image_data_format()`.

    Returns:
      The lower-cased data format string.

    Raises:
      ValueError: If the lower-cased value is not one of the two formats.
    """
    resolved = tf.keras.backend.image_data_format() if value is None else value
    lowered = resolved.lower()
    if lowered in {"channels_first", "channels_last"}:
        return lowered
    raise ValueError(
        "The `data_format` argument must be one of "
        '"channels_first", "channels_last". Received: ' + str(resolved)
    )
def normalize_tuple(value, n, name):
    """Transforms an integer or iterable of integers into a tuple of length n.

    A copy of the `normalize_tuple` helper from `tensorflow.python.keras.util`.

    Args:
      value: The value to validate and convert. Could be an int, or any
        iterable of ints.
      n: The size of the tuple to be returned.
      name: The name of the argument being validated, e.g. "strides" or
        "kernel_size". This is only used to format error messages.

    Returns:
      A tuple of n elements. An int `value` is repeated n times. For an
      iterable `value`, its elements are returned as-is: each is only checked
      to be convertible with `int()`, not converted.

    Raises:
      TypeError: If `value` is neither an int nor an iterable.
      ValueError: If the iterable does not have exactly n elements, or one of
        its elements cannot be converted with `int()`.
    """
    if isinstance(value, int):
        return (value,) * n

    def _error_message():
        # Built lazily so the success path never pays for formatting `value`.
        return (
            f"The `{name}` argument must be a tuple of {n!s} integers. "
            f"Received: {value!s}"
        )

    try:
        value_tuple = tuple(value)
    except TypeError as err:
        raise TypeError(_error_message()) from err

    if len(value_tuple) != n:
        raise ValueError(_error_message())

    for single_value in value_tuple:
        try:
            int(single_value)
        except (ValueError, TypeError) as err:
            raise ValueError(
                f"{_error_message()} including element {single_value!s} "
                f"of type {type(single_value)!s}"
            ) from err
    return value_tuple
def _hasattr(obj, attr_name):
    """Reports whether `obj` exposes `attr_name`.

    `dir(obj)` is consulted first so that a listed attribute is never actually
    retrieved, because retrieving it might trigger lazy computation. Only when
    the name is not listed does it fall back to the builtin `hasattr`. That
    fallback calls `getattr` and treats only `AttributeError` as "missing";
    any other exception propagates.
    """
    return attr_name in dir(obj) or hasattr(obj, attr_name)
def assert_like_rnncell(cell_name, cell):
    """Raises a TypeError if `cell` does not look like an RNN cell.

    The object must resemble a `tf.keras.layers.AbstractRNNCell`: it needs
    `output_size`, `state_size` and `get_initial_state` attributes, and must
    be callable.

    Args:
      cell_name: A string naming the function argument being checked, used
        to make the error message meaningful.
      cell: The object which should behave like a
        `tf.keras.layers.AbstractRNNCell`.

    Raises:
      TypeError: Listing every requirement the object fails.
    """
    # All checks run up front, in order; each is paired with its message.
    checks = (
        (_hasattr(cell, "output_size"), "'output_size' property is missing"),
        (_hasattr(cell, "state_size"), "'state_size' property is missing"),
        (
            _hasattr(cell, "get_initial_state"),
            "'get_initial_state' method is required",
        ),
        (callable(cell), "is not callable"),
    )
    failures = [message for passed, message in checks if not passed]
    if failures:
        raise TypeError(
            "The argument {!r} ({}) is not an RNNCell: {}.".format(
                cell_name, cell, ", ".join(failures)
            )
        )