Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/layer_utils.py: 11% (186 statements)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality."""

import functools
import weakref

import numpy as np

from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output will always be a list of tensors
  (potentially with 1 element).

  Args:
    tensor: The tensor to start from.
    layer: Origin layer of the tensor. Will be
      determined via tensor._keras_history if not provided.
    node_index: Origin node index of the tensor.

  Returns:
    List of input tensors.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if node.is_input:
      # Reached an Input layer, stop recursion.
      return nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors
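

# Illustrative usage sketch, not part of the original module: tracing a
# functional-API tensor back to its source inputs. The public `tensorflow`
# import and the layer shapes are assumptions for the demo.
def _demo_get_source_inputs():
  import tensorflow as tf  # Local import keeps the module import cheap.
  x = tf.keras.Input(shape=(4,))
  y = tf.keras.layers.Dense(2)(x)
  # Walks `y._keras_history` back through the graph; returns `[x]`.
  return get_source_inputs(y)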


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Validates the correctness of a string-based arg."""
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data, str) and input_data in allowable_strings:
    return
  else:
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    allowed_args += 'or one of the following values: %s' % (allowable_strings,)
    raise ValueError(('The %s argument of layer %s received an invalid '
                      'value %s. Allowed values are: %s.') %
                     (arg_name, layer_name, input_data, allowed_args))
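

# Illustrative usage sketch, not part of the original module: validating a
# string hyperparameter the way a layer's __init__ might. All names here are
# made up for the demo.
def _demo_validate_string_arg():
  # Accepted: the value is in the allowable set.
  validate_string_arg('sum', {'sum', 'mean'},
                      layer_name='demo_layer', arg_name='reduction')
  # Rejected: raises ValueError since callables were not allowed.
  try:
    validate_string_arg(len, {'sum', 'mean'},
                        layer_name='demo_layer', arg_name='reduction')
  except ValueError as e:
    return str(e)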


def count_params(weights):
  """Count the total number of scalars composing the weights.

  Args:
    weights: An iterable containing the weights on which to compute params.

  Returns:
    The total number of scalars composing the weights.
  """
  unique_weights = {id(w): w for w in weights}.values()
  weight_shapes = [w.shape.as_list() for w in unique_weights]
  standardized_weight_shapes = [
      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
  ]
  return int(sum(np.prod(p) for p in standardized_weight_shapes))
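

# Illustrative usage sketch, not part of the original module: the weight list
# is deduplicated by `id()`, so a variable passed twice is counted once. The
# public `tensorflow` import is an assumption for the demo.
def _demo_count_params():
  import tensorflow as tf
  w = tf.Variable(tf.zeros([3, 4]))  # 12 scalars.
  b = tf.Variable(tf.zeros([4]))     # 4 scalars.
  return count_params([w, b, w])     # -> 16; the duplicate `w` is ignored.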


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Args:
    model: Keras model instance.
    line_length: Total length of printed lines
      (e.g. set this to adapt the display to different
      terminal window sizes).
    positions: Relative or absolute positions of log elements in each line.
      If not provided, defaults to `[.45, .85, 1.]` for sequential-like
      models and `[.33, .55, .67, 1.]` otherwise.
    print_fn: Print function to use.
      It will be called on each line of the summary.
      You can set it to a custom function
      in order to capture the string summary.
      It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(nest.flatten(v[0].keras_inputs)) > 1):
        # If the model has multiple nodes,
        # or if the nodes have multiple inbound_layers,
        # the model is no longer sequential.
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # Search for shared layers.
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # Header names for the different log elements.
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # Header names for the different log elements.
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    line = ''
    for i in range(len(fields)):
      if i > 0:
        line = line[:-1] + ' '
      line += str(fields[i])
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Args:
      layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    if not layer.built and not getattr(layer, '_is_graph_network', False):
      # If a subclassed model has a layer that is not called in Model.call, the
      # layer will not be built and we cannot call layer.count_params().
      params = '0 (unused)'
    else:
      params = layer.count_params()
    fields = [name + ' (' + cls_name + ')', output_shape, params]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Args:
      layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # The node is not part of the current network.
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    if not connections:
      first_connection = ''
    else:
      first_connection = connections[0]
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    if len(connections) > 1:
      for i in range(1, len(connections)):
        fields = ['', '', '', connections[i]]
        print_row(fields, positions)

  layers = model.layers
  for i in range(len(layers)):
    if sequential_like:
      print_layer_summary(layers[i])
    else:
      print_layer_summary_with_connections(layers[i])
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)
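

# Illustrative usage sketch, not part of the original module: `print_fn` can
# capture the summary as a string instead of printing it. The tiny Sequential
# model is an arbitrary choice for the demo.
def _demo_print_summary():
  import tensorflow as tf
  model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
  lines = []
  print_summary(model, print_fn=lines.append)  # Collects one string per line.
  return '\n'.join(lines)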


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Args:
    dense: The target `Dense` layer.
    previous_feature_map_shape: A shape tuple of 3 integers,
      e.g. `(512, 7, 7)`. The shape of the convolutional
      feature map right before the `Flatten` layer that
      came before the target `Dense` layer.
    target_data_format: One of "channels_last", "channels_first".
      Set it to "channels_last"
      if converting a "channels_first" model to "channels_last",
      or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()
  for i in range(kernel.shape[1]):
    if target_data_format == 'channels_first':
      c, h, w = previous_feature_map_shape
      original_fm_shape = (h, w, c)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (2, 0, 1))  # last -> first
    else:
      h, w, c = previous_feature_map_shape
      original_fm_shape = (c, h, w)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (1, 2, 0))  # first -> last
    kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
  dense.set_weights([kernel, bias])
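

# Illustrative usage sketch, not part of the original module: permuting the
# kernel of a Dense layer that follows a Flatten when converting a
# channels_first convnet to channels_last. The shapes are arbitrary; note
# that `previous_feature_map_shape` is given in the *target* layout.
def _demo_convert_dense_weights_data_format():
  import tensorflow as tf
  # Dense layer fed by a flattened 3x3 feature map with 2 channels (18 units).
  dense = tf.keras.layers.Dense(5)
  dense.build((None, 3 * 3 * 2))
  convert_dense_weights_data_format(
      dense,
      previous_feature_map_shape=(3, 3, 2),  # (h, w, c) for channels_last.
      target_data_format='channels_last')
  return dense.get_weights()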


def is_builtin_layer(layer):
  """Returns whether `layer` is exported under a built-in Keras API name."""
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))
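

# Illustrative usage sketch, not part of the original module: exported layers
# such as Dense report True, while user subclasses only inherit the base
# `keras.layers.Layer` export name and report False.
def _demo_is_builtin_layer():
  import tensorflow as tf

  class MyLayer(tf.keras.layers.Layer):
    pass

  dense = tf.keras.layers.Dense(1)
  return is_builtin_layer(dense), is_builtin_layer(MyLayer())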


def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily
  computed.) Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:

  ```
  def __setattr__(self, key, value):
    super(MyClass, self).__setattr__(key, value)
  ```

  slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @tracking.cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict
  type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakrefs as keys raises
  the lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Args:
    f: The function to cache.

  Returns:
    f decorated with simple caching behavior.
  """

  cache = weakref.WeakKeyDictionary()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped
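

# Illustrative usage sketch, not part of the original module: the decorated
# function is keyed on its single argument (here `self`) in a
# WeakKeyDictionary, so cached instances can still be garbage collected.
class _DemoCached(object):
  """Hypothetical class with an expensive, per-instance cached property."""

  @property
  @cached_per_instance
  def thing(self):
    # Computed once per instance; later lookups are served from the cache
    # that `cached_per_instance` attaches to the wrapped function.
    return sum(range(1000))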


def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify."""
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = set()
  to_visit = layer_list[::-1]
  while to_visit:
    obj = to_visit.pop()
    if id(obj) in existing:
      continue
    existing.add(id(obj))
    if hasattr(obj, '_is_layer') and not isinstance(obj, type):
      yield obj
    else:
      sub_layers = getattr(obj, 'layers', None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(sub_layers[::-1])
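

# Illustrative usage sketch, not part of the original module: objects without
# `_is_layer` are treated as containers; an empty one yields nothing, and
# duplicates are dropped by identity. The names are made up for the demo.
def _demo_filter_empty_layer_containers():
  import tensorflow as tf
  dense = tf.keras.layers.Dense(1)

  class EmptyContainer(object):  # No `_is_layer` and no `layers`: filtered.
    pass

  # Yields only `dense`; the generator must be consumed, e.g. via `list`.
  return list(filter_empty_layer_containers([dense, EmptyContainer(), dense]))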