Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/layer_utils.py: 13%
437 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Utilities related to layer/model functionality."""
18import copy
19import functools
20import re
21import weakref
23import numpy as np
24import tensorflow.compat.v2 as tf
26from keras.src import initializers
27from keras.src.utils import io_utils
29# isort: off
30from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.utils.get_source_inputs")
def get_source_inputs(tensor, layer=None, node_index=None):
    """Returns the list of input tensors necessary to compute `tensor`.

    Output will be a list of tensors (potentially with 1 element). As a
    special case, an object that has no `_keras_history` attribute is
    returned unchanged (not wrapped in a list).

    Args:
        tensor: The tensor to start from.
        layer: Origin layer of the tensor. Will be
            determined via tensor._keras_history if not provided.
        node_index: Origin node index of the tensor. Will be
            determined via tensor._keras_history if not provided.

    Returns:
        List of input tensors.
    """
    if not hasattr(tensor, "_keras_history"):
        return tensor

    # Fall back to the tensor's recorded history only when the caller did not
    # supply an origin. `node_index` 0 is a valid index, so compare against
    # None rather than testing truthiness.
    if layer is None or node_index is None:
        layer, node_index, _ = tensor._keras_history
    if not layer._inbound_nodes:
        return [tensor]

    node = layer._inbound_nodes[node_index]
    if node.is_input:
        # Reached an Input layer, stop recursion.
        return tf.nest.flatten(node.input_tensors)

    source_tensors = []
    for (
        inbound_layer,
        inbound_node_index,
        _,
        inbound_tensor,
    ) in node.iterate_inbound():
        previous_sources = get_source_inputs(
            inbound_tensor, inbound_layer, inbound_node_index
        )
        # Avoid input redundancy; tensors are compared by identity.
        for x in previous_sources:
            if all(x is not t for t in source_tensors):
                source_tensors.append(x)
    return source_tensors
def validate_string_arg(
    input_data,
    allowable_strings,
    layer_name,
    arg_name,
    allow_none=False,
    allow_callables=False,
):
    """Validates the correctness of a string-based arg.

    Args:
        input_data: The value passed for the argument.
        allowable_strings: Container of accepted string values.
        layer_name: Name of the layer, used in the error message.
        arg_name: Name of the argument, used in the error message.
        allow_none: Whether `None` is an accepted value.
        allow_callables: Whether any callable is an accepted value.

    Raises:
        ValueError: If `input_data` is not an accepted value.
    """
    if allow_none and input_data is None:
        return
    if allow_callables and callable(input_data):
        return
    if isinstance(input_data, str) and input_data in allowable_strings:
        return

    allowed_args = "`None`, " if allow_none else ""
    allowed_args += "a `Callable`, " if allow_callables else ""
    # Only say "or" when it actually joins alternatives listed before it.
    allowed_args += "or " if allowed_args else ""
    allowed_args += f"one of the following values: {allowable_strings}"
    if allow_callables:
        callable_note = (
            f"If restoring a model and `{arg_name}` is a custom callable, "
            "please ensure the callable is registered as a custom object. "
            "See https://www.tensorflow.org/guide/keras/save_and_serialize"
            "#registering_the_custom_object for details. "
        )
    else:
        callable_note = ""
    raise ValueError(
        f"Unknown value for `{arg_name}` argument of layer {layer_name}. "
        f"{callable_note}Allowed values are: {allowed_args}. Received: "
        f"{input_data}"
    )
def count_params(weights):
    """Return how many scalar entries the given weights hold in total.

    Each weight object is counted once, even if it appears several times.
    Objects without a `shape` attribute (such as TrackableWeightHandlers)
    are skipped, and unknown (`None`) dimensions count as size 0.

    Args:
        weights: An iterable of weight objects.

    Returns:
        The total number of scalars, as a Python `int`.
    """
    by_identity = {}
    for weight in weights:
        by_identity[id(weight)] = weight

    total = 0
    for weight in by_identity.values():
        if not hasattr(weight, "shape"):
            continue
        dims = [0 if dim is None else dim for dim in weight.shape.as_list()]
        total += np.prod(dims)
    return int(total)
def weight_memory_size(weights):
    """Sum the byte footprint of weights from their shapes and dtypes.

    Duplicate weight objects are counted once. Weights lacking a `shape`
    attribute (TrackableWeightHandlers) and weights whose shape still has an
    unknown (`None`) dimension contribute nothing.

    Args:
        weights: An iterable of weight objects.

    Returns:
        Total size in bytes (`0` when nothing was counted).
    """
    unique = dict((id(w), w) for w in weights)
    total_bytes = 0
    for weight in unique.values():
        if not hasattr(weight, "shape"):
            continue
        if None in weight.shape.as_list():
            # Partially-unknown shapes have no fixed size yet.
            continue
        total_bytes += np.prod(weight.shape.as_list()) * weight.dtype.size
    return total_bytes
def dtensor_variable_summary(weights):
    """Tally DTensor weight counts and byte sizes, grouped by sharding.

    Only unique `tf.experimental.dtensor.DVariable` objects that have a
    `shape` are considered. Each variable's layout sharding specs are
    reduced to a sorted tuple of distinct mesh dimensions with
    `UNSHARDED` removed. As a result, differently-ordered or repeated
    specs with the same memory footprint land in the same group.

    Args:
        weights: An iterable of weight objects.

    Returns:
        A `(total_weight_count, total_memory_size, per_sharing_spec_result)`
        tuple. The dict maps each reduced spec tuple to a
        `(weight_count, memory_size)` pair.
    """
    distinct = {id(w): w for w in weights}
    grand_count = 0
    grand_bytes = 0
    by_spec = {}
    for var in distinct.values():
        # TrackableWeightHandlers have no shape; plain variables are not
        # DTensor-laid-out. Both are ignored.
        if not hasattr(var, "shape"):
            continue
        if not isinstance(var, tf.experimental.dtensor.DVariable):
            continue
        unsharded = tf.experimental.dtensor.UNSHARDED
        spec_key = tuple(
            axis
            for axis in sorted(set(var.layout.sharding_specs))
            if axis != unsharded
        )
        n_params = np.prod(var.shape.as_list())
        n_bytes = n_params * var.dtype.size
        prev_count, prev_bytes = by_spec.get(spec_key, (0, 0))
        by_spec[spec_key] = (prev_count + n_params, prev_bytes + n_bytes)
        grand_count += n_params
        grand_bytes += n_bytes
    return grand_count, grand_bytes, by_spec
def print_dtensor_variable_summary(model, print_fn, line_length):
    """Print how a DTensor model's weights are distributed across devices.

    Nothing is printed unless a DTensor mesh is found. The mesh comes either
    from the model's `_layout_map` or from a `_mesh` attribute on
    `model.distribute_strategy`.

    Args:
        model: Model whose `weights` are summarized.
        print_fn: Callable invoked once per output line.
        line_length: Width of the trailing `_` separator line.
    """
    if getattr(model, "_layout_map", None) is not None:
        mesh = model._layout_map.get_default_mesh()
    elif hasattr(model, "distribute_strategy") and hasattr(
        model.distribute_strategy, "_mesh"
    ):
        mesh = model.distribute_strategy._mesh
    else:
        # Not running with DTensor
        mesh = None
    if not mesh:
        return

    (
        total_weight_count,
        total_memory_size,
        per_sharing_spec_result,
    ) = dtensor_variable_summary(model.weights)
    total_per_device_memory_size = 0
    for sharding_spec in sorted(per_sharing_spec_result):
        count, memory_size = per_sharing_spec_result[sharding_spec]
        if not sharding_spec:
            print_fn(
                f"{count} / {total_weight_count} params "
                f"({readable_memory_size(memory_size)}) "
                "are fully replicated"
            )
            per_device_size = memory_size
        else:
            # Memory is divided by the product of the sizes of the mesh
            # dimensions named in the spec.
            sharding_factor = np.prod(
                [mesh.dim_size(s) for s in sharding_spec]
            )
            per_device_size = memory_size / sharding_factor
            print_fn(
                f"{count} / {total_weight_count} params "
                f"({readable_memory_size(memory_size)}) are sharded based "
                f"on spec '{sharding_spec}' and across {sharding_factor} "
                f"devices."
            )
        total_per_device_memory_size += per_device_size
    print_fn(
        "Overall per device memory usage: "
        f"{readable_memory_size(total_per_device_memory_size)}"
    )
    if total_per_device_memory_size:
        print_fn(
            "Overall sharding factor: {:.2f}".format(
                total_memory_size / total_per_device_memory_size
            )
        )
    else:
        # No per-device memory was accumulated (e.g. the model has no
        # DVariable weights), so the ratio is undefined; avoid a
        # ZeroDivisionError.
        print_fn("Overall sharding factor: N/A")
    print_fn("_" * line_length)
def readable_memory_size(weight_memory_size):
    """Convert the weight memory size (Bytes) to a readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit ("PB") is reached. Sizes of 1024 PB or more are therefore
    reported in PB instead of being divided a further time and mislabeled.

    Args:
        weight_memory_size: Size in bytes.

    Returns:
        A string such as `"1.50 KB"`, with two decimal places.
    """
    units = ["Byte", "KB", "MB", "GB", "TB", "PB"]
    scale = 1024
    # Only scale up while a larger unit exists; the last unit absorbs
    # whatever magnitude remains.
    for unit in units[:-1]:
        if weight_memory_size / scale < 1:
            return "{:.2f} {}".format(weight_memory_size, unit)
        weight_memory_size /= scale
    return "{:.2f} {}".format(weight_memory_size, units[-1])
def get_layer_index_bound_by_layer_name(model, layer_range=None):
    """Get the layer indexes from the model based on layer names.

    The layer indexes can be used to slice the model into sub models for
    display.

    Args:
        model: `tf.keras.Model` instance.
        layer_range: a list or tuple of 2 strings, the starting layer name
            and ending layer name (both inclusive) for the result. Each
            string is treated as a regular expression and tested with
            `re.match`, i.e. it must match at the start of a layer's name.
            All layers will be included when `None` is provided.

    Returns:
        `[first_layer_index, last_layer_index + 1]`. Normally this spans
        from the first layer matching `layer_range[0]` to the last layer
        matching `layer_range[1]`. If every `layer_range[0]` match comes
        after every `layer_range[1]` match, the span is instead taken from
        the first `layer_range[1]` match to the last `layer_range[0]` match.

    Raises:
        ValueError: If `layer_range` is not of length 2, contains a
            non-string, or either pattern matches no layer name.
    """
    if layer_range is None:
        return [0, len(model.layers)]

    if len(layer_range) != 2:
        raise ValueError(
            "layer_range must be a list or tuple of length 2. Received: "
            f"layer_range = {layer_range} of length {len(layer_range)}"
        )
    if not isinstance(layer_range[0], str) or not isinstance(
        layer_range[1], str
    ):
        raise ValueError(
            "layer_range should contain string type only. "
            f"Received: {layer_range}"
        )

    lower_index = [
        idx
        for idx, layer in enumerate(model.layers)
        if re.match(layer_range[0], layer.name)
    ]
    upper_index = [
        idx
        for idx, layer in enumerate(model.layers)
        if re.match(layer_range[1], layer.name)
    ]

    if not lower_index or not upper_index:
        raise ValueError(
            "Passed layer_names do not match the layer names in the model. "
            f"Received: {layer_range}"
        )

    # Tolerate a reversed range by swapping which pattern bounds each end.
    if min(lower_index) > max(upper_index):
        return [min(upper_index), max(lower_index) + 1]
    return [min(lower_index), max(upper_index) + 1]
def print_summary(
    model,
    line_length=None,
    positions=None,
    print_fn=None,
    expand_nested=False,
    show_trainable=False,
    layer_range=None,
):
    """Prints a summary of a model.

    Args:
        model: Keras model instance.
        line_length: Total length of printed lines
            (e.g. set this to adapt the display to different
            terminal window sizes). If not provided, defaults to 65 for
            sequential-like models and 98 otherwise.
        positions: Relative or absolute positions of log elements in each
            line. When the last value is `<= 1`, the values are treated as
            fractions of `line_length`. If not provided, defaults to
            `[0.45, 0.85, 1.0]` for sequential-like models and
            `[0.3, 0.6, 0.70, 1.]` otherwise. The caller's sequence is
            never modified.
        print_fn: Print function to use.
            It will be called on each line of the summary.
            You can set it to a custom function
            in order to capture the string summary.
            It defaults to `io_utils.print_msg`.
        expand_nested: Whether to expand the nested models.
            If not provided, defaults to `False`.
        show_trainable: Whether to show if a layer is trainable.
            If not provided, defaults to `False`.
        layer_range: List or tuple containing two strings,
            the starting layer name and ending layer name (both inclusive),
            indicating the range of layers to be printed in the summary. The
            strings could also be regexes instead of an exact name. In this
            case, the starting layer will be the first layer that matches
            `layer_range[0]` and the ending layer will be the last element that
            matches `layer_range[1]`. By default (`None`) all
            layers in the model are included in the summary.
    """
    if print_fn is None:
        print_fn = io_utils.print_msg

    if model.__class__.__name__ == "Sequential":
        sequential_like = True
    elif not model._is_graph_network:
        # We treat subclassed models as a simple sequence of layers, for logging
        # purposes.
        sequential_like = True
    else:
        sequential_like = True
        nodes_by_depth = model._nodes_by_depth.values()
        nodes = []
        for v in nodes_by_depth:
            if (len(v) > 1) or (
                len(v) == 1 and len(tf.nest.flatten(v[0].keras_inputs)) > 1
            ):
                # if the model has multiple nodes
                # or if the nodes have multiple inbound_layers
                # the model is no longer sequential
                sequential_like = False
                break
            nodes += v
        if sequential_like:
            # search for shared layers
            for layer in model.layers:
                flag = False
                for node in layer._inbound_nodes:
                    if node in nodes:
                        if flag:
                            sequential_like = False
                            break
                        else:
                            flag = True
                if not sequential_like:
                    break

    if sequential_like:
        line_length = line_length or 65
        positions = positions or [0.45, 0.85, 1.0]
        if positions[-1] <= 1:
            positions = [int(line_length * p) for p in positions]
        # header names for the different log elements
        to_display = ["Layer (type)", "Output Shape", "Param #"]
    else:
        line_length = line_length or 98
        positions = positions or [0.3, 0.6, 0.70, 1.0]
        if positions[-1] <= 1:
            positions = [int(line_length * p) for p in positions]
        # header names for the different log elements
        to_display = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
        relevant_nodes = []
        for v in model._nodes_by_depth.values():
            relevant_nodes += v

    # Work on a private list: absolute `positions` passed by the caller would
    # otherwise be mutated by the `append` below (or fail if it is a tuple).
    positions = list(positions)
    if show_trainable:
        line_length += 11
        positions.append(line_length)
        to_display.append("Trainable")

    layer_range = get_layer_index_bound_by_layer_name(model, layer_range)

    def print_row(fields, positions, nested_level=0):
        left_to_print = [str(x) for x in fields]
        while any(left_to_print):
            line = ""
            for col in range(len(left_to_print)):
                if col > 0:
                    start_pos = positions[col - 1]
                else:
                    start_pos = 0
                end_pos = positions[col]
                # Leave room for 2 spaces to delineate columns
                # we don't need any if we are printing the last column
                space = 2 if col != len(positions) - 1 else 0
                cutoff = end_pos - start_pos - space
                # Except for last col, offset by one to align the start of col
                if col != len(positions) - 1:
                    cutoff -= 1
                if col == 0:
                    cutoff -= nested_level
                fit_into_line = left_to_print[col][:cutoff]
                # For nicer formatting we line-break on seeing end of
                # tuple/dict etc.
                line_break_conditions = ("),", "},", "],", "',")
                candidate_cutoffs = [
                    fit_into_line.find(x) + len(x)
                    for x in line_break_conditions
                    if fit_into_line.find(x) >= 0
                ]
                if candidate_cutoffs:
                    cutoff = min(candidate_cutoffs)
                    fit_into_line = fit_into_line[:cutoff]

                if col == 0:
                    line += "|" * nested_level + " "
                line += fit_into_line
                line += " " * space if space else ""
                left_to_print[col] = left_to_print[col][cutoff:]

                # Pad out to the next position
                # Make space for nested_level for last column
                if nested_level and col == len(positions) - 1:
                    line += " " * (positions[col] - len(line) - nested_level)
                else:
                    line += " " * (positions[col] - len(line))
            line += "|" * nested_level
            print_fn(line)

    print_fn(f'Model: "{model.name}"')
    print_fn("_" * line_length)
    print_row(to_display, positions)
    print_fn("=" * line_length)

    def print_layer_summary(layer, nested_level=0):
        """Prints a summary for a single layer.

        Args:
            layer: target layer.
            nested_level: level of nesting of the layer inside its parent layer
              (e.g. 0 for a top-level layer, 1 for a nested layer).
        """
        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = "multiple"
        except RuntimeError:  # output_shape unknown in Eager mode.
            output_shape = "?"
        name = layer.name
        cls_name = layer.__class__.__name__
        if not layer.built and not getattr(layer, "_is_graph_network", False):
            # If a subclassed model has a layer that is not called in
            # Model.call, the layer will not be built and we cannot call
            # layer.count_params().
            params = "0 (unused)"
        else:
            params = layer.count_params()
        fields = [name + " (" + cls_name + ")", output_shape, params]

        if show_trainable:
            fields.append("Y" if layer.trainable else "N")

        print_row(fields, positions, nested_level)

    def print_layer_summary_with_connections(layer, nested_level=0):
        """Prints a summary for a single layer (including its connections).

        Args:
            layer: target layer.
            nested_level: level of nesting of the layer inside its parent layer
              (e.g. 0 for a top-level layer, 1 for a nested layer).
        """
        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = "multiple"
        except RuntimeError:  # output_shape unknown in Eager mode.
            # Mirror `print_layer_summary` instead of aborting the summary.
            output_shape = "?"
        connections = []
        for node in layer._inbound_nodes:
            if relevant_nodes and node not in relevant_nodes:
                # node is not part of the current network
                continue

            for (
                inbound_layer,
                node_index,
                tensor_index,
                _,
            ) in node.iterate_inbound():
                connections.append(
                    f"{inbound_layer.name}[{node_index}][{tensor_index}]"
                )

        name = layer.name
        cls_name = layer.__class__.__name__
        fields = [
            name + " (" + cls_name + ")",
            output_shape,
            layer.count_params(),
            connections,
        ]

        if show_trainable:
            fields.append("Y" if layer.trainable else "N")

        print_row(fields, positions, nested_level)

    def print_layer(layer, nested_level=0, is_nested_last=False):
        if sequential_like:
            print_layer_summary(layer, nested_level)
        else:
            print_layer_summary_with_connections(layer, nested_level)

        if expand_nested and hasattr(layer, "layers") and layer.layers:
            print_fn(
                "|" * (nested_level + 1)
                + "¯" * (line_length - 2 * nested_level - 2)
                + "|" * (nested_level + 1)
            )

            nested_layer = layer.layers
            is_nested_last = False
            for i in range(len(nested_layer)):
                if i == len(nested_layer) - 1:
                    is_nested_last = True
                print_layer(nested_layer[i], nested_level + 1, is_nested_last)

            print_fn(
                "|" * nested_level
                + "¯" * (line_length - 2 * nested_level)
                + "|" * nested_level
            )

        if not is_nested_last:
            print_fn(
                "|" * nested_level
                + " " * (line_length - 2 * nested_level)
                + "|" * nested_level
            )

    for layer in model.layers[layer_range[0] : layer_range[1]]:
        print_layer(layer)
    print_fn("=" * line_length)

    if hasattr(model, "_collected_trainable_weights"):
        trainable_count = count_params(model._collected_trainable_weights)
        trainable_memory_size = weight_memory_size(
            model._collected_trainable_weights
        )
    else:
        trainable_count = count_params(model.trainable_weights)
        trainable_memory_size = weight_memory_size(model.trainable_weights)

    non_trainable_count = count_params(model.non_trainable_weights)
    non_trainable_memory_size = weight_memory_size(model.non_trainable_weights)

    total_memory_size = trainable_memory_size + non_trainable_memory_size

    print_fn(
        f"Total params: {trainable_count + non_trainable_count} "
        f"({readable_memory_size(total_memory_size)})"
    )
    print_fn(
        f"Trainable params: {trainable_count} "
        f"({readable_memory_size(trainable_memory_size)})"
    )
    print_fn(
        f"Non-trainable params: {non_trainable_count} "
        f"({readable_memory_size(non_trainable_memory_size)})"
    )
    print_fn("_" * line_length)

    print_dtensor_variable_summary(model, print_fn, line_length)
def convert_dense_weights_data_format(
    dense, previous_feature_map_shape, target_data_format="channels_first"
):
    """Utility useful when changing a convnet's `data_format`.

    When porting the weights of a convnet from one data format to the other,
    if the convnet includes a `Flatten` layer
    (applied to the last convolutional feature map)
    followed by a `Dense` layer, the weights of that `Dense` layer
    should be updated to reflect the new dimension ordering.

    Args:
        dense: The target `Dense` layer.
        previous_feature_map_shape: A shape tuple of 3 integers,
            e.g. `(512, 7, 7)`. The shape of the convolutional
            feature map right before the `Flatten` layer that
            came before the target `Dense` layer.
        target_data_format: One of "channels_last", "channels_first".
            Set it "channels_last"
            if converting a "channels_first" model to "channels_last",
            or reciprocally.

    Raises:
        ValueError: If `target_data_format` is not one of the two values.
    """
    # Explicit check (not `assert`, which is stripped under `python -O`).
    if target_data_format not in {"channels_last", "channels_first"}:
        raise ValueError(
            "`target_data_format` must be 'channels_last' or "
            f"'channels_first'. Received: {target_data_format}"
        )
    kernel, bias = dense.get_weights()
    kernel_shape = kernel.shape
    units = kernel_shape[1]
    if target_data_format == "channels_first":
        c, h, w = previous_feature_map_shape
        # Kernel rows are a flattened (h, w, c) map; reorder to (c, h, w).
        original_fm_shape = (h, w, c)
        axes = (2, 0, 1)  # last -> first
    else:
        h, w, c = previous_feature_map_shape
        # Kernel rows are a flattened (c, h, w) map; reorder to (h, w, c).
        original_fm_shape = (c, h, w)
        axes = (1, 2, 0)  # first -> last
    # Permute the rows of every output column at once; the trailing `units`
    # axis (index 3) keeps its place.
    kernel = np.transpose(
        kernel.reshape(original_fm_shape + (units,)), axes + (3,)
    ).reshape(kernel_shape)
    dense.set_weights([kernel, bias])
def is_builtin_layer(layer):
    """Tell whether `layer` comes from a class that Keras itself exports.

    A layer without any `_keras_api_names` is not built-in. Non-exported
    subclasses of `Layer` inherit the base class's export name, so a layer
    whose v2 or v1 export names are exactly `("keras.layers.Layer",)` is
    not considered built-in either.
    """
    api_names = getattr(layer, "_keras_api_names", None)
    if not api_names:
        return False

    base_export = ("keras.layers.Layer",)
    return not (
        api_names == base_export or layer._keras_api_names_v1 == base_export
    )
def cached_per_instance(f):
    """Lightweight decorator for caching lazily constructed properties.

    When to use:
    This decorator provides simple caching with minimal overhead. It is designed
    for properties which are expensive to compute and static over the life of a
    class instance, and provides no mechanism for cache invalidation. Thus it is
    best suited for lazily exposing derived properties of other static data.

    For classes with custom getattr / setattr behavior (such as trackable
    objects), storing cache results as object attributes is not performant.
    Instead, a specialized cache can significantly reduce property lookup
    overhead. (While still allowing the decorated property to be lazily
    computed.) Consider the following class:

    ```
    class MyClass:
      def __setattr__(self, key, value):
        # Some expensive class specific code
        # ...
        # ...

        super(MyClass, self).__setattr__(key, value)

      @property
      def thing(self):
        # `thing` is expensive to compute (and may not even be requested), so we
        # want to lazily compute it and then cache it.
        output = getattr(self, '_thing', None)
        if output is None:
          self._thing = output = compute_thing(self)
        return output
    ```

    It's also worth noting that ANY overriding of __setattr__, even something as
    simple as:
    ```
      def __setattr__(self, key, value):
        super(MyClass, self).__setattr__(key, value)
    ```

    Slows down attribute assignment by nearly 10x.

    By contrast, replacing the definition of `thing` with the following
    sidesteps the expensive __setattr__ altogether:

    ```
    @property
    @cached_per_instance
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      return compute_thing(self)
    ```

    Performance:
    The overhead for this decorator is ~0.4 us / call. A much lower overhead
    implementation (~0.085 us / call) can be achieved by using a custom dict
    type:

    ```
    def dict_based_cache(f):
      class Cache(dict):
        __slots__ = ()
        def __missing__(self, key):
          self[key] = output = f(key)
          return output

      return property(Cache().__getitem__)
    ```

    However, that implementation holds class instances as keys, and as a result
    blocks garbage collection. (And modifying it to use weakref's as keys raises
    the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
    implementation below turns out to be more prudent.

    Args:
        f: The function to cache. It is called with a single argument (the
            instance), which is used as a `weakref.WeakKeyDictionary` key and
            so must be hashable and weak-referenceable.

    Returns:
        f decorated with simple caching behavior. The underlying
        `WeakKeyDictionary` is exposed as the `cache` attribute of the
        returned function.
    """

    cache = weakref.WeakKeyDictionary()
    # Private sentinel so that results equal to `None` are cached too; an
    # `output is None` check would recompute them on every call.
    missing = object()

    @functools.wraps(f)
    def wrapped(item):
        output = cache.get(item, missing)
        if output is missing:
            cache[item] = output = f(item)
        return output

    wrapped.cache = cache
    return wrapped
def filter_empty_layer_containers(layer_list):
    """Flatten Layer-like containers into their layers and uniquify.

    The walk over `layer_list` is depth-first and preserves order. Objects
    that have an `_is_layer` attribute (and are not classes themselves) are
    yielded. Any other object is treated as a container, and its `layers`
    attribute, if present and non-empty, is expanded in place. Every object
    is visited at most once (by identity), so empty containers produce
    nothing and a repeated layer is yielded only the first time.

    Args:
        layer_list: An iterable (list, tuple, ...) of layers and/or layer
            containers.

    Yields:
        The unique layers, in depth-first order.
    """
    # TODO(b/130381733): Make this an attribute in base_layer.Layer.
    existing = set()
    # A reversed list used as a stack, so `pop()` returns items in their
    # original order. Converting with `list(...)` also supports tuples and
    # other iterables, which have no `pop`/`extend`.
    to_visit = list(layer_list)[::-1]
    while to_visit:
        obj = to_visit.pop()
        if id(obj) in existing:
            continue
        existing.add(id(obj))
        if hasattr(obj, "_is_layer") and not isinstance(obj, type):
            yield obj
        else:
            sub_layers = getattr(obj, "layers", None) or []

            # Trackable data structures will not show up in ".layers" lists, but
            # the layers they contain will.
            to_visit.extend(list(sub_layers)[::-1])
class CallFunctionSpec:
    """Caches the spec and provides utilities for handling call function
    args."""

    def __init__(self, full_argspec):
        """Initializes a `CallFunctionSpec`.

        Args:
            full_argspec: the FullArgSpec of a call function of a layer.
        """
        self._full_argspec = full_argspec

        self._arg_names = list(self._full_argspec.args)
        # Scrub `self` that appears if a decorator was applied.
        if self._arg_names and self._arg_names[0] == "self":
            self._arg_names = self._arg_names[1:]
        self._arg_names += self._full_argspec.kwonlyargs or []

        call_accepts_kwargs = self._full_argspec.varkw is not None
        self._expects_training_arg = (
            "training" in self._arg_names or call_accepts_kwargs
        )
        self._expects_mask_arg = (
            "mask" in self._arg_names or call_accepts_kwargs
        )

        call_fn_defaults = self._full_argspec.defaults or []
        defaults = dict()
        # The call arg defaults are an n-tuple for the last n *positional*
        # args (n = # of elements that have a default argument). Align them
        # against `full_argspec.args`, not `self._arg_names`: the latter has
        # keyword-only names appended, which would shift the alignment.
        if call_fn_defaults:
            positional_names = self._full_argspec.args[
                -len(call_fn_defaults) :
            ]
            defaults.update(zip(positional_names, call_fn_defaults))
        # Keyword-only defaults are already keyed by name.
        defaults.update(self._full_argspec.kwonlydefaults or {})
        # The default training arg will be any (non-None) default specified in
        # the method signature, or None if no value is specified.
        self._default_training_arg = defaults.get("training")

    @property
    def full_argspec(self):
        """Returns the FullArgSpec of the call function."""
        return self._full_argspec

    @property
    def arg_names(self):
        """List of names of args and kwonlyargs."""
        # `arg_names` is not accurate if the layer has variable positional args.
        return self._arg_names

    @arg_names.setter
    def arg_names(self, value):
        self._arg_names = value
        # `arg_positions` is cached per instance by `cached_per_instance`,
        # which has no built-in invalidation; drop the stale entry so it is
        # recomputed from the new names.
        CallFunctionSpec.arg_positions.fget.cache.pop(self, None)

    @property
    @cached_per_instance
    def arg_positions(self):
        """Returns a dict mapping arg names to their index positions."""
        # `arg_positions` is not accurate if the layer has variable positional
        # args.
        call_fn_arg_positions = dict()
        for pos, arg in enumerate(self._arg_names):
            call_fn_arg_positions[arg] = pos
        return call_fn_arg_positions

    @property
    def expects_training_arg(self):
        """Whether the call function uses 'training' as a parameter."""
        return self._expects_training_arg

    @expects_training_arg.setter
    def expects_training_arg(self, value):
        self._expects_training_arg = value

    @property
    def expects_mask_arg(self):
        """Whether the call function uses `mask` as a parameter."""
        return self._expects_mask_arg

    @expects_mask_arg.setter
    def expects_mask_arg(self, value):
        self._expects_mask_arg = value

    @property
    def default_training_arg(self):
        """The default value given to the "training" argument."""
        return self._default_training_arg

    def arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
        """Returns true if argument is present in `args` or `kwargs`.

        Args:
            arg_name: String name of the argument to find.
            args: Tuple of args passed to the call function.
            kwargs: Dictionary of kwargs passed to the call function.
            inputs_in_args: Whether the input argument (the first argument in the
                call function) is included in `args`. Defaults to `False`.

        Returns:
            True if argument with `arg_name` is present in `args` or `kwargs`.
        """
        # Performance optimization: do no work in most common case.
        if not args and not kwargs:
            return False

        if arg_name in kwargs:
            return True
        call_fn_args = self._arg_names
        if not inputs_in_args:
            # Ignore `inputs` arg.
            call_fn_args = call_fn_args[1:]
        return arg_name in dict(zip(call_fn_args, args))

    def get_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
        """Retrieves the value for the argument with name `arg_name`.

        Args:
            arg_name: String name of the argument to find.
            args: Tuple of args passed to the call function.
            kwargs: Dictionary of kwargs passed to the call function.
            inputs_in_args: Whether the input argument (the first argument in the
                call function) is included in `args`. Defaults to `False`.

        Returns:
            The value of the argument with name `arg_name`, extracted from `args`
            or `kwargs`.

        Raises:
            KeyError if the value of `arg_name` cannot be found.
        """
        if arg_name in kwargs:
            return kwargs[arg_name]
        call_fn_args = self._arg_names
        if not inputs_in_args:
            # Ignore `inputs` arg.
            call_fn_args = call_fn_args[1:]
        args_dict = dict(zip(call_fn_args, args))
        return args_dict[arg_name]

    def set_arg_value(
        self,
        arg_name,
        new_value,
        args,
        kwargs,
        inputs_in_args=False,
        pop_kwarg_if_none=False,
    ):
        """Sets the value of an argument into the given args/kwargs.

        Args:
            arg_name: String name of the argument to find.
            new_value: New value to give to the argument.
            args: Tuple of args passed to the call function.
            kwargs: Dictionary of kwargs passed to the call function.
            inputs_in_args: Whether the input argument (the first argument in the
                call function) is included in `args`. Defaults to `False`.
            pop_kwarg_if_none: If the new value is `None`, and this is `True`,
                then the argument is deleted from `kwargs`.

        Returns:
            The updated `(args, kwargs)`.
        """
        if self.full_argspec.varargs:
            try:
                arg_pos = self.full_argspec.args.index(arg_name)
                if self.full_argspec.args[0] == "self":
                    arg_pos -= 1
            except ValueError:
                arg_pos = None
        else:
            arg_pos = self.arg_positions.get(arg_name, None)

        if arg_pos is not None:
            if not inputs_in_args:
                # Ignore `inputs` arg.
                arg_pos = arg_pos - 1
            if len(args) > arg_pos:
                args = list(args)
                args[arg_pos] = new_value
                return tuple(args), kwargs
        if new_value is None and pop_kwarg_if_none:
            kwargs.pop(arg_name, None)
        else:
            kwargs[arg_name] = new_value
        return args, kwargs

    def split_out_first_arg(self, args, kwargs):
        """Splits (args, kwargs) into (inputs, args, kwargs).

        Raises:
            ValueError: If the first call argument was passed neither
                positionally nor by keyword.
        """
        # Grab the argument corresponding to the first argument in the
        # layer's `call` method spec. This will either be the first positional
        # argument, or it will be provided as a keyword argument.
        if args:
            inputs = args[0]
            args = args[1:]
        elif self._arg_names and self._arg_names[0] in kwargs:
            kwargs = copy.copy(kwargs)
            inputs = kwargs.pop(self._arg_names[0])
        else:
            raise ValueError(
                "The first argument to `Layer.call` must always be passed."
            )
        return inputs, args, kwargs
@keras_export("keras.utils.warmstart_embedding_matrix")
def warmstart_embedding_matrix(
    base_vocabulary,
    new_vocabulary,
    base_embeddings,
    new_embeddings_initializer="uniform",
):
    """Warm start embedding matrix with changing vocab.

    This util can be used to warmstart the embedding layer matrix when the
    vocabulary changes between a previously saved checkpoint and the model.
    A vocabulary change could mean that the size of the new vocab is
    different, that the vocabulary is reshuffled, or that new terms have
    been added to the old vocabulary. If the vocabulary size changes, the
    size of the embedding layer matrix also changes. This util remaps the
    old vocabulary embeddings to the new embedding layer matrix.

    Example:
    Here is an example that demonstrates how to use the
    `warmstart_embedding_matrix` util.
    >>> import keras
    >>> vocab_base = tf.convert_to_tensor(["unk", "a", "b", "c"])
    >>> vocab_new = tf.convert_to_tensor(
    ...        ["unk", "unk", "a", "b", "c", "d", "e"])
    >>> vectorized_vocab_base = np.random.rand(vocab_base.shape[0], 3)
    >>> vectorized_vocab_new = np.random.rand(vocab_new.shape[0], 3)
    >>> warmstarted_embedding_matrix = warmstart_embedding_matrix(
    ...       base_vocabulary=vocab_base,
    ...       new_vocabulary=vocab_new,
    ...       base_embeddings=vectorized_vocab_base,
    ...       new_embeddings_initializer=keras.initializers.Constant(
    ...         vectorized_vocab_new))

    Here is an example that demonstrates how to get vocabulary and embedding
    weights from layers, use the `warmstart_embedding_matrix` util to remap
    the layer embeddings and continue with model training.
    ```
    # get old and new vocabulary by using layer.get_vocabulary()
    # for example assume TextVectorization layer is used
    base_vocabulary = old_text_vectorization_layer.get_vocabulary()
    new_vocabulary = new_text_vectorization_layer.get_vocabulary()
    # get previous embedding layer weights
    embedding_weights_base = model.get_layer('embedding').get_weights()[0]
    warmstarted_embedding = keras.utils.warmstart_embedding_matrix(
                                  base_vocabulary,
                                  new_vocabulary,
                                  base_embeddings=embedding_weights_base,
                                  new_embeddings_initializer="uniform")
    updated_embedding_variable = tf.Variable(warmstarted_embedding)

    # update embedding layer weights
    model.layers[1].embeddings = updated_embedding_variable
    model.fit(..)
    # continue with model training

    ```

    Args:
        base_vocabulary: The list of vocabulary terms that
            the preexisting embedding matrix `base_embeddings` represents.
            It can be either a 1D array/tensor or a tuple/list of vocabulary
            terms (strings), or a path to a vocabulary text file. If passing
            a file path, the file should contain one line per term in the
            vocabulary.
        new_vocabulary: The list of vocabulary terms for the new vocabulary
            (same format as above).
        base_embeddings: NumPy array or tensor representing the preexisting
            embedding matrix. Row `i` holds the embedding of
            `base_vocabulary[i]`, and the number of columns
            (`base_embeddings.shape[1]`) is the embedding dimension of the
            returned matrix.
        new_embeddings_initializer: Initializer (or its string identifier,
            see `keras.initializers`) used to create the new embedding
            matrix. Its values are kept for terms of `new_vocabulary` that
            do not appear in `base_vocabulary`. Defaults to `"uniform"`. To
            start from a specific matrix, pass
            `keras.initializers.Constant(matrix)`.

    Returns:
        A tensor of shape `(len(new_vocabulary), base_embeddings.shape[1])`
        in which every term shared with `base_vocabulary` has its row copied
        from `base_embeddings`. If a term occurs more than once in
        `base_vocabulary`, the row of its last occurrence is used.

    Raises:
        ValueError: If either vocabulary is empty, of an unsupported type,
            or a path to a file that does not exist.
    """
    # Normalize every supported vocabulary format to a plain list of terms.
    base_vocabulary = convert_vocab_to_list(base_vocabulary)
    new_vocabulary = convert_vocab_to_list(new_vocabulary)

    # Initialize the new embedding matrix. Its dtype follows
    # `base_embeddings` so that the scatter update below mixes like dtypes.
    new_embeddings_initializer = initializers.get(new_embeddings_initializer)
    new_embedding = new_embeddings_initializer(
        shape=(len(new_vocabulary), base_embeddings.shape[1]),
        dtype=base_embeddings.dtype,
    )

    # Map each base term to its row index; for duplicated terms the last
    # occurrence wins.
    base_vocabulary_index = {
        term: index for index, term in enumerate(base_vocabulary)
    }

    # Collect (row in base matrix, row in new matrix) pairs for shared terms.
    indices_base_vocabulary = []
    indices_new_vocabulary = []
    for new_index, term in enumerate(new_vocabulary):
        base_index = base_vocabulary_index.get(term)
        if base_index is not None:
            indices_base_vocabulary.append(base_index)
            indices_new_vocabulary.append(new_index)

    # Copy the shared rows into the new matrix; when no term is shared the
    # freshly initialized matrix is returned unchanged.
    if indices_base_vocabulary:
        values_to_update = tf.gather(base_embeddings, indices_base_vocabulary)
        new_embedding = tf.tensor_scatter_nd_update(
            new_embedding,
            tf.expand_dims(indices_new_vocabulary, axis=1),
            values_to_update,
        )
    return new_embedding
def convert_vocab_to_list(vocab):
    """Convert an input vocabulary to a list of terms.

    Args:
        vocab: A NumPy array, tuple, list or 1D tensor of terms, or a path
            to a vocabulary text file containing one term per line.

    Returns:
        A non-empty list of terms. `bytes` terms (for example the elements
        of a `tf.string` tensor) are decoded as UTF-8, so every supported
        input format yields comparable `str` terms.

    Raises:
        ValueError: If `vocab` is a path that does not exist, is of an
            unsupported type, or contains no terms.
    """
    # `tf.is_tensor` is False for arrays, tuples, lists and str, so checking
    # those first is equivalent to checking tensors first.
    if isinstance(vocab, (np.ndarray, tuple, list)):
        vocab_list = list(vocab)
    elif isinstance(vocab, str):
        if not tf.io.gfile.exists(vocab):
            raise ValueError(f"Vocabulary file {vocab} does not exist.")
        with tf.io.gfile.GFile(vocab, "r") as vocabulary_file:
            vocab_list = vocabulary_file.read().splitlines()
    elif tf.is_tensor(vocab):
        vocab_list = list(vocab.numpy())
    else:
        raise ValueError(
            "Vocabulary is expected to be either a NumPy array, "
            "list, 1D tensor or a vocabulary text file. Instead type "
            f"{type(vocab)} was received."
        )
    if not vocab_list:
        raise ValueError(
            "Vocabulary is expected to be either a NumPy array, "
            "list, 1D tensor or a vocabulary text file with at least one token."
            " Received 0 instead."
        )
    # String tensors and bytes arrays produce `bytes` items, while lists of
    # str and vocabulary files produce `str`. Decode so that terms coming
    # from differently formatted vocabularies compare equal.
    return [
        term.decode("utf-8") if isinstance(term, bytes) else term
        for term in vocab_list
    ]