# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Functions for saving and loading a Keras Model from HDF5 format."""

import json
import os

import numpy as np

from tensorflow.python.keras import backend
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.saving import model_config as model_config_lib
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging

# pylint: disable=g-import-not-at-top
try:
  import h5py
  HDF5_OBJECT_HEADER_LIMIT = 64512
except ImportError:
  h5py = None
# pylint: enable=g-import-not-at-top

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
sequential_lib = LazyLoader(
    "sequential_lib", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes


def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
  """Saves a model to a HDF5 file.

  The saved model contains:
      - the model's configuration (topology)
      - the model's weights
      - the model's optimizer's state (if any)

  Thus the saved model can be reinstantiated in
  the exact same state, without any of the code
  used for model definition or training.

  Args:
      model: Keras model instance to be saved.
      filepath: One of the following:
          - String, path where to save the model
          - `h5py.File` object where to save the model
      overwrite: Whether we should overwrite any existing
          model at the target location, or instead
          ask the user with a manual prompt.
      include_optimizer: If True, save optimizer's state together.

  Raises:
      ImportError: if h5py is not available.
  """

  if h5py is None:
    raise ImportError('`save_model` requires h5py.')

  # TODO(psv) Add warning when we save models that contain non-serializable
  # entities like metrics added using `add_metric` and losses added using
  # `add_loss.`
  if len(model.weights) != len(model._undeduplicated_weights):
    logging.warning('Found duplicated `Variable`s in Model\'s `weights`. '
                    'This is usually caused by `Variable`s being shared by '
                    'Layers in the Model. These `Variable`s will be treated '
                    'as separate `Variable`s when the Model is restored. To '
                    'avoid this, please save with `save_format="tf"`.')

  if not isinstance(filepath, h5py.File):
    # If file exists and should not be overwritten.
    if not overwrite and os.path.isfile(filepath):
      proceed = ask_to_proceed_with_overwrite(filepath)
      if not proceed:
        return

    # Try creating the enclosing directory if it does not exist. Guard against
    # an empty dirname, which happens for bare filenames such as 'model.h5'.
    dirpath = os.path.dirname(filepath)
    if dirpath and not os.path.exists(dirpath):
      gfile.MakeDirs(dirpath)

    f = h5py.File(filepath, mode='w')
    opened_new_file = True
  else:
    f = filepath
    opened_new_file = False

  try:
    model_metadata = saving_utils.model_metadata(model, include_optimizer)
    for k, v in model_metadata.items():
      if isinstance(v, (dict, list, tuple)):
        f.attrs[k] = json.dumps(
            v, default=json_utils.get_json_type).encode('utf8')
      else:
        f.attrs[k] = v

    model_weights_group = f.create_group('model_weights')
    model_layers = model.layers
    save_weights_to_hdf5_group(model_weights_group, model_layers)

    # TODO(b/128683857): Add integration tests between tf.keras and external
    # Keras, to avoid breaking TF.js users.
    if (include_optimizer and model.optimizer and
        not isinstance(model.optimizer, optimizer_v1.TFOptimizer)):
      save_optimizer_weights_to_hdf5_group(f, model.optimizer)

    f.flush()
  finally:
    if opened_new_file:
      f.close()
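

# Usage sketch for `save_model_to_hdf5` (illustrative, not part of the
# original module; the toy model and the path 'toy_model.h5' are hypothetical):
#
#   import tensorflow as tf
#   model = tf.keras.Sequential(
#       [tf.keras.layers.Dense(4, activation='relu', input_shape=(8,))])
#   model.compile(optimizer='adam', loss='mse')
#   save_model_to_hdf5(model, 'toy_model.h5', overwrite=True)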


def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
  """Loads a model saved via `save_model_to_hdf5`.

  Args:
      filepath: One of the following:
          - String, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed. When `compile` is set
      to False, the compilation is omitted without any
      warning.

  Raises:
      ImportError: if h5py is not available.
      ValueError: In case of an invalid savefile.
  """
  if h5py is None:
    raise ImportError('`load_model` requires h5py.')

  if not custom_objects:
    custom_objects = {}

  opened_new_file = not isinstance(filepath, h5py.File)
  if opened_new_file:
    f = h5py.File(filepath, mode='r')
  else:
    f = filepath

  model = None
  try:
    # instantiate model
    model_config = f.attrs.get('model_config')
    if model_config is None:
      raise ValueError('No model found in config file.')
    if hasattr(model_config, 'decode'):
      model_config = model_config.decode('utf-8')
    model_config = json_utils.decode(model_config)
    model = model_config_lib.model_from_config(model_config,
                                               custom_objects=custom_objects)

    # set weights
    load_weights_from_hdf5_group(f['model_weights'], model.layers)

    if compile:
      # instantiate optimizer
      training_config = f.attrs.get('training_config')
      if hasattr(training_config, 'decode'):
        training_config = training_config.decode('utf-8')
      if training_config is None:
        logging.warning('No training configuration found in the save file, so '
                        'the model was *not* compiled. Compile it manually.')
        return model
      training_config = json_utils.decode(training_config)

      # Compile model.
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config, custom_objects), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)

      # Set optimizer weights.
      if 'optimizer_weights' in f:
        try:
          model.optimizer._create_all_weights(model.trainable_variables)
        except (NotImplementedError, AttributeError):
          logging.warning(
              'Error when creating the weights of optimizer {}, making it '
              'impossible to restore the saved optimizer state. As a result, '
              'your model is starting with a freshly initialized optimizer.'
              .format(model.optimizer))

        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        try:
          model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
          logging.warning('Error in loading the saved optimizer '
                          'state. As a result, your model is '
                          'starting with a freshly initialized '
                          'optimizer.')
  finally:
    if opened_new_file:
      f.close()
  return model
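

# Counterpart sketch for `load_model_from_hdf5` (illustrative; assumes
# 'toy_model.h5' was written by the save example above, and `MyLayer` in the
# commented line is a hypothetical custom class):
#
#   restored = load_model_from_hdf5('toy_model.h5')
#   # A training config was saved, so `restored` is already compiled.
#   restored.predict(np.zeros((1, 8)))
#
#   # Custom classes referenced by the config go in `custom_objects`:
#   # load_model_from_hdf5('toy_model.h5', custom_objects={'MyLayer': MyLayer})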


def preprocess_weights_for_loading(layer,
                                   weights,
                                   original_keras_version=None,
                                   original_backend=None):
  """Preprocess layer weights between different Keras formats.

  Converts layer weights from Keras 1 format to Keras 2 and also weights of
  CuDNN layers in Keras 2.

  Args:
      layer: Layer instance.
      weights: List of weights values (Numpy arrays).
      original_keras_version: Keras version for the weights, as a string.
      original_backend: Keras backend the weights were trained with,
          as a string.

  Returns:
      A list of weights values (Numpy arrays).
  """
  def convert_nested_bidirectional(weights):
    """Converts layers nested in `Bidirectional` wrapper.

    This function uses `preprocess_weights_for_loading()` for converting
    layers.

    Args:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    num_weights_per_layer = len(weights) // 2
    forward_weights = preprocess_weights_for_loading(
        layer.forward_layer, weights[:num_weights_per_layer],
        original_keras_version, original_backend)
    backward_weights = preprocess_weights_for_loading(
        layer.backward_layer, weights[num_weights_per_layer:],
        original_keras_version, original_backend)
    return forward_weights + backward_weights

  def convert_nested_time_distributed(weights):
    """Converts layers nested in `TimeDistributed` wrapper.

    This function uses `preprocess_weights_for_loading()` for converting nested
    layers.

    Args:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    return preprocess_weights_for_loading(
        layer.layer, weights, original_keras_version, original_backend)

  def convert_nested_model(weights):
    """Converts layers nested in `Model` or `Sequential`.

    This function uses `preprocess_weights_for_loading()` for converting nested
    layers.

    Args:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    trainable_weights = weights[:len(layer.trainable_weights)]
    non_trainable_weights = weights[len(layer.trainable_weights):]

    new_trainable_weights = []
    new_non_trainable_weights = []

    for sublayer in layer.layers:
      num_trainable_weights = len(sublayer.trainable_weights)
      num_non_trainable_weights = len(sublayer.non_trainable_weights)
      if sublayer.weights:
        preprocessed = preprocess_weights_for_loading(
            layer=sublayer,
            weights=(trainable_weights[:num_trainable_weights] +
                     non_trainable_weights[:num_non_trainable_weights]),
            original_keras_version=original_keras_version,
            original_backend=original_backend)
        new_trainable_weights.extend(preprocessed[:num_trainable_weights])
        new_non_trainable_weights.extend(preprocessed[num_trainable_weights:])

        trainable_weights = trainable_weights[num_trainable_weights:]
        non_trainable_weights = non_trainable_weights[
            num_non_trainable_weights:]

    return new_trainable_weights + new_non_trainable_weights

  # Convert layers nested in Bidirectional/Model/Sequential.
  # Both transformations should be run for both Keras 1->2 conversion
  # and for conversion of CuDNN layers.
  if layer.__class__.__name__ == 'Bidirectional':
    weights = convert_nested_bidirectional(weights)
  if layer.__class__.__name__ == 'TimeDistributed':
    weights = convert_nested_time_distributed(weights)
  elif layer.__class__.__name__ in ['Model', 'Sequential', 'Functional']:
    weights = convert_nested_model(weights)

  if original_keras_version == '1':
    if layer.__class__.__name__ == 'TimeDistributed':
      weights = preprocess_weights_for_loading(
          layer.layer, weights, original_keras_version, original_backend)

    if layer.__class__.__name__ == 'Conv1D':
      shape = weights[0].shape
      # Handle Keras 1.1 format
      if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters:
        # Legacy shape:
        # (filters, input_dim, filter_length, 1)
        assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0],
                                                           1)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
      weights[0] = weights[0][:, 0, :, :]

    if layer.__class__.__name__ == 'Conv2D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))

    if layer.__class__.__name__ == 'Conv2DTranspose':
      if layer.data_format == 'channels_last':
        # old: (kernel_rows, kernel_cols, stack_size, filters)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (2, 3, 0, 1))

    if layer.__class__.__name__ == 'Conv3D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, ...)
        # new: (..., stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))

    if layer.__class__.__name__ == 'GRU':
      if len(weights) == 9:
        kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[4], weights[7]], axis=-1)
        bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if layer.__class__.__name__ == 'LSTM':
      if len(weights) == 12:
        # old: i, c, f, o
        # new: i, f, c, o
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if layer.__class__.__name__ == 'ConvLSTM2D':
      if len(weights) == 12:
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        if layer.data_format == 'channels_first':
          # old: (filters, stack_size, kernel_rows, kernel_cols)
          # new: (kernel_rows, kernel_cols, stack_size, filters)
          kernel = np.transpose(kernel, (2, 3, 1, 0))
          recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
        weights = [kernel, recurrent_kernel, bias]

    conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose',
                   'ConvLSTM2D']
    if layer.__class__.__name__ in conv_layers:
      if backend.int_shape(layer.weights[0]) != weights[0].shape:
        weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
        if layer.__class__.__name__ == 'ConvLSTM2D':
          weights[1] = np.transpose(weights[1], (3, 2, 0, 1))

  # convert CuDNN layers
  return _convert_rnn_weights(layer, weights)
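

# A numpy sketch (illustrative shapes only, not part of the original module)
# of the Keras 1 -> Keras 2 `Conv2D` kernel conversion above for
# `data_format='channels_first'`:
#
#   old_kernel = np.zeros((32, 3, 5, 5))  # (filters, stack_size, rows, cols)
#   new_kernel = np.transpose(old_kernel, (2, 3, 1, 0))
#   assert new_kernel.shape == (5, 5, 3, 32)  # (rows, cols, stack, filters)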


def _convert_rnn_weights(layer, weights):
  """Converts weights for RNN layers between native and CuDNN format.

  Input kernels for each gate are transposed and converted between Fortran
  and C layout, recurrent kernels are transposed. For LSTM biases are summed/
  split in half, for GRU biases are reshaped.

  Weights can be converted in both directions between `LSTM` and `CuDNNLSTM`
  and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not
  compatible with `CuDNNGRU`.

  For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made.

  Args:
      layer: Target layer instance.
      weights: List of source weights values (input kernels, recurrent
          kernels, [biases]) (Numpy arrays).

  Returns:
      A list of converted weights values (Numpy arrays).

  Raises:
      ValueError: for incompatible GRU layer/weights or incompatible biases
  """

  def transform_kernels(kernels, func, n_gates):
    """Transforms kernel for each gate separately using given function.

    Args:
        kernels: Stacked array of kernels for individual gates.
        func: Function applied to kernel of each gate.
        n_gates: Number of gates (4 for LSTM, 3 for GRU).

    Returns:
        Stacked array of transformed kernels.
    """
    return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)])

  def transpose_input(from_cudnn):
    """Makes a function that transforms input kernels from/to CuDNN format.

    It keeps the shape, but changes between the layout (Fortran/C). Eg.:

    ```
    Keras                 CuDNN
    [[0, 1, 2],  <--->  [[0, 2, 4],
     [3, 4, 5]]          [1, 3, 5]]
    ```

    It can be passed to `transform_kernels()`.

    Args:
        from_cudnn: `True` if source weights are in CuDNN format, `False`
            if they're in plain Keras format.

    Returns:
        Function that converts input kernel to the other format.
    """
    order = 'F' if from_cudnn else 'C'

    def transform(kernel):
      return kernel.T.reshape(kernel.shape, order=order)

    return transform

  target_class = layer.__class__.__name__

  # convert the weights between CuDNNLSTM and LSTM
  if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3:
    # determine if we're loading a CuDNNLSTM layer
    # from the number of bias weights:
    # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4)
    # if there's no bias weight in the file, skip this conversion
    units = weights[1].shape[0]
    bias_shape = weights[2].shape
    n_gates = 4

    if bias_shape == (2 * units * n_gates,):
      source = 'CuDNNLSTM'
    elif bias_shape == (units * n_gates,):
      source = 'LSTM'
    else:
      raise ValueError('Invalid bias shape: ' + str(bias_shape))

    def convert_lstm_weights(weights, from_cudnn=True):
      """Converts the weights between CuDNNLSTM and LSTM.

      Args:
          weights: Original weights.
          from_cudnn: Indicates whether original weights are from CuDNN layer.

      Returns:
          Updated weights compatible with LSTM.
      """

      # Transpose (and reshape) input and recurrent kernels
      kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
                                  n_gates)
      recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
      if from_cudnn:
        # merge input and recurrent biases into a single set
        biases = np.sum(np.split(weights[2], 2, axis=0), axis=0)
      else:
        # Split single set of biases evenly to two sets. The way of
        # splitting doesn't matter as long as the sum of the two sets is kept.
        biases = np.tile(0.5 * weights[2], 2)
      return [kernels, recurrent_kernels, biases]

    if source != target_class:
      weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM')

  # convert the weights between CuDNNGRU and GRU(reset_after=True)
  if target_class in ['GRU', 'CuDNNGRU'] and len(weights) == 3:
    # We can determine the source of the weights from the shape of the bias.
    # If there is no bias we skip the conversion since
    # CuDNNGRU always has biases.

    units = weights[1].shape[0]
    bias_shape = weights[2].shape
    n_gates = 3

    def convert_gru_weights(weights, from_cudnn=True):
      """Converts the weights between CuDNNGRU and GRU.

      Args:
          weights: Original weights.
          from_cudnn: Indicates whether original weights are from CuDNN layer.

      Returns:
          Updated weights compatible with GRU.
      """

      kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
                                  n_gates)
      recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
      biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1)
      return [kernels, recurrent_kernels, biases]

    if bias_shape == (2 * units * n_gates,):
      source = 'CuDNNGRU'
    elif bias_shape == (2, units * n_gates):
      source = 'GRU(reset_after=True)'
    elif bias_shape == (units * n_gates,):
      source = 'GRU(reset_after=False)'
    else:
      raise ValueError('Invalid bias shape: ' + str(bias_shape))

    if target_class == 'CuDNNGRU':
      target = 'CuDNNGRU'
    elif layer.reset_after:
      target = 'GRU(reset_after=True)'
    else:
      target = 'GRU(reset_after=False)'

    # only convert between different types
    if source != target:
      types = (source, target)
      if 'GRU(reset_after=False)' in types:
        raise ValueError('%s is not compatible with %s' % types)
      if source == 'CuDNNGRU':
        weights = convert_gru_weights(weights, from_cudnn=True)
      elif source == 'GRU(reset_after=True)':
        weights = convert_gru_weights(weights, from_cudnn=False)

  return weights
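

# A numpy sketch (illustrative, not part of the original module) of the layout
# change made by `transpose_input` above: sending a kernel to CuDNN layout
# (`from_cudnn=False`, order 'C') and back (`from_cudnn=True`, order 'F')
# recovers the original array.
#
#   kernel = np.arange(6).reshape(2, 3)
#   to_cudnn = kernel.T.reshape(kernel.shape, order='C')
#   back = to_cudnn.T.reshape(to_cudnn.shape, order='F')
#   assert np.array_equal(back, kernel)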


def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
  """Saves optimizer weights of an optimizer to a HDF5 group.

  Args:
      hdf5_group: HDF5 group.
      optimizer: optimizer instance.
  """

  symbolic_weights = getattr(optimizer, 'weights')
  if symbolic_weights:
    weights_group = hdf5_group.create_group('optimizer_weights')
    weight_names = [str(w.name).encode('utf8') for w in symbolic_weights]
    save_attributes_to_hdf5_group(weights_group, 'weight_names', weight_names)
    weight_values = backend.batch_get_value(symbolic_weights)
    for name, val in zip(weight_names, weight_values):
      param_dset = weights_group.create_dataset(
          name, val.shape, dtype=val.dtype)
      if not val.shape:
        # scalar
        param_dset[()] = val
      else:
        param_dset[:] = val


def load_optimizer_weights_from_hdf5_group(hdf5_group):
  """Load optimizer weights from a HDF5 group.

  Args:
      hdf5_group: A pointer to a HDF5 group.

  Returns:
      data: List of optimizer weight values (as HDF5 datasets), in the order
          given by the saved weight names.
  """
  weights_group = hdf5_group['optimizer_weights']
  optimizer_weight_names = load_attributes_from_hdf5_group(
      weights_group, 'weight_names')
  return [weights_group[weight_name] for weight_name in optimizer_weight_names]
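

# Roundtrip sketch for the two optimizer-weight helpers above (illustrative;
# 'opt.h5' and the toy model are hypothetical):
#
#   import h5py
#   import tensorflow as tf
#
#   model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
#   model.compile(optimizer='adam', loss='mse')
#   model.fit(np.zeros((4, 3)), np.zeros((4, 2)), verbose=0)  # create slots
#
#   with h5py.File('opt.h5', 'w') as f:
#     save_optimizer_weights_to_hdf5_group(f, model.optimizer)
#   with h5py.File('opt.h5', 'r') as f:
#     model.optimizer.set_weights(load_optimizer_weights_from_hdf5_group(f))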


def save_weights_to_hdf5_group(f, layers):
  """Saves the weights of a list of layers to a HDF5 group.

  Args:
      f: HDF5 group.
      layers: List of layer instances.
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  save_attributes_to_hdf5_group(
      f, 'layer_names', [layer.name.encode('utf8') for layer in layers])
  f.attrs['backend'] = backend.backend().encode('utf8')
  f.attrs['keras_version'] = str(keras_version).encode('utf8')

  # Sort model layers by layer name to ensure that group names are strictly
  # growing to avoid prefix issues.
  for layer in sorted(layers, key=lambda x: x.name):
    g = f.create_group(layer.name)
    weights = _legacy_weights(layer)
    weight_values = backend.batch_get_value(weights)
    weight_names = [w.name.encode('utf8') for w in weights]
    save_attributes_to_hdf5_group(g, 'weight_names', weight_names)
    for name, val in zip(weight_names, weight_values):
      param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
      if not val.shape:
        # scalar
        param_dset[()] = val
      else:
        param_dset[:] = val
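

# Sketch of the on-disk layout `save_weights_to_hdf5_group` produces
# (illustrative; 'weights.h5' is a hypothetical path, and the 'dense' group
# and 'dense/kernel:0' dataset names depend on the model's layer names):
#
#   import h5py
#   with h5py.File('weights.h5', 'w') as f:
#     save_weights_to_hdf5_group(f, model.layers)
#   with h5py.File('weights.h5', 'r') as f:
#     print(load_attributes_from_hdf5_group(f, 'layer_names'))
#     g = f['dense']  # one subgroup per layer name
#     print(load_attributes_from_hdf5_group(g, 'weight_names'))
#     kernel = np.asarray(g['dense/kernel:0'])  # datasets keyed by weight name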


def load_weights_from_hdf5_group(f, layers):
  """Implements topological (order-based) weight loading.

  Args:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  if 'keras_version' in f.attrs:
    original_keras_version = f.attrs['keras_version']
    if hasattr(original_keras_version, 'decode'):
      original_keras_version = original_keras_version.decode('utf8')
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = f.attrs['backend']
    if hasattr(original_backend, 'decode'):
      original_backend = original_backend.decode('utf8')
  else:
    original_backend = None

  filtered_layers = []
  for layer in layers:
    weights = _legacy_weights(layer)
    if weights:
      filtered_layers.append(layer)

  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
  filtered_layer_names = []
  for name in layer_names:
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    if weight_names:
      filtered_layer_names.append(name)
  layer_names = filtered_layer_names
  if len(layer_names) != len(filtered_layers):
    raise ValueError('You are trying to load a weight file '
                     'containing ' + str(len(layer_names)) +
                     ' layers into a model with ' + str(len(filtered_layers)) +
                     ' layers.')

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
    layer = filtered_layers[k]
    symbolic_weights = _legacy_weights(layer)
    weight_values = preprocess_weights_for_loading(
        layer, weight_values, original_keras_version, original_backend)
    if len(weight_values) != len(symbolic_weights):
      raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                       '" in the current model) was found to '
                       'correspond to layer ' + name + ' in the save file. '
                       'However the new layer ' + layer.name + ' expects ' +
                       str(len(symbolic_weights)) +
                       ' weights, but the saved weights have ' +
                       str(len(weight_values)) + ' elements.')
    weight_value_tuples += zip(symbolic_weights, weight_values)
  backend.batch_set_value(weight_value_tuples)
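

# Topological-loading sketch: weights are matched to layers by position, so
# the target model must mirror the saved layer structure (illustrative;
# 'weights.h5' is a hypothetical path):
#
#   import h5py
#   with h5py.File('weights.h5', 'r') as f:
#     load_weights_from_hdf5_group(f, model.layers)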


def load_weights_from_hdf5_group_by_name(
    f, layers, skip_mismatch=False):
  """Implements name-based weight loading (instead of topological loading).

  Layers that have no matching name are skipped.

  Args:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.
      skip_mismatch: Boolean, whether to skip loading of layers
          where there is a mismatch in the number of weights,
          or a mismatch in the shape of the weights.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file and skip_mismatch=False.
  """
  if 'keras_version' in f.attrs:
    original_keras_version = f.attrs['keras_version']
    if hasattr(original_keras_version, 'decode'):
      original_keras_version = original_keras_version.decode('utf8')
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = f.attrs['backend']
    if hasattr(original_backend, 'decode'):
      original_backend = original_backend.decode('utf8')
  else:
    original_backend = None

  # New file format.
  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')

  # Reverse index of layer name to list of layers with name.
  index = {}
  for layer in layers:
    if layer.name:
      index.setdefault(layer.name, []).append(layer)

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]

    for layer in index.get(name, []):
      symbolic_weights = _legacy_weights(layer)
      weight_values = preprocess_weights_for_loading(
          layer, weight_values, original_keras_version, original_backend)
      if len(weight_values) != len(symbolic_weights):
        if skip_mismatch:
          logging.warning('Skipping loading of weights for '
                          'layer {}'.format(layer.name) + ' due to mismatch '
                          'in number of weights ({} vs {}).'.format(
                              len(symbolic_weights), len(weight_values)))
          continue
        raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                         '") expects ' + str(len(symbolic_weights)) +
                         ' weight(s), but the saved weights' + ' have ' +
                         str(len(weight_values)) + ' element(s).')
      # Set values.
      for i in range(len(weight_values)):
        if backend.int_shape(symbolic_weights[i]) != weight_values[i].shape:
          if skip_mismatch:
            logging.warning('Skipping loading of weights for '
                            'layer {}'.format(layer.name) + ' due to '
                            'mismatch in shape ({} vs {}).'.format(
                                symbolic_weights[i].shape,
                                weight_values[i].shape))
            continue
          raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                           '"), weight ' + str(symbolic_weights[i]) +
                           ' has shape {}'.format(backend.int_shape(
                               symbolic_weights[i])) +
                           ', but the saved weight has shape ' +
                           str(weight_values[i].shape) + '.')
        else:
          weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
  backend.batch_set_value(weight_value_tuples)
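

# Name-based-loading sketch, e.g. for transfer learning where only some layer
# names match; with `skip_mismatch=True`, incompatible layers are skipped with
# a warning instead of raising (illustrative; 'weights.h5' and `new_model` are
# hypothetical):
#
#   import h5py
#   with h5py.File('weights.h5', 'r') as f:
#     load_weights_from_hdf5_group_by_name(f, new_model.layers,
#                                          skip_mismatch=True)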


def save_attributes_to_hdf5_group(group, name, data):
  """Saves attributes (data) of the specified name into the HDF5 group.

  This method deals with an inherent problem of HDF5 files: a single attribute
  cannot hold data larger than HDF5_OBJECT_HEADER_LIMIT bytes, so oversized
  data is split across numbered chunk attributes.

  Args:
      group: A pointer to a HDF5 group.
      name: A name of the attributes to save.
      data: Attributes data to store.

  Raises:
      RuntimeError: If any single attribute is too large to be saved.
  """
  # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
  # because in that case even chunking the array would not make the saving
  # possible.
  bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]

  # Expecting this to never be true.
  if bad_attributes:
    raise RuntimeError('The following attributes cannot be saved to HDF5 '
                       'file because they are larger than %d bytes: %s' %
                       (HDF5_OBJECT_HEADER_LIMIT, ', '.join(bad_attributes)))

  data_npy = np.asarray(data)

  num_chunks = 1
  chunked_data = np.array_split(data_npy, num_chunks)

  # This will never loop forever thanks to the test above.
  while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
    num_chunks += 1
    chunked_data = np.array_split(data_npy, num_chunks)

  if num_chunks > 1:
    for chunk_id, chunk_data in enumerate(chunked_data):
      group.attrs['%s%d' % (name, chunk_id)] = chunk_data
  else:
    group.attrs[name] = data


def load_attributes_from_hdf5_group(group, name):
  """Loads attributes of the specified name from the HDF5 group.

  This method deals with an inherent problem of HDF5 files: a single attribute
  cannot hold data larger than HDF5_OBJECT_HEADER_LIMIT bytes, so oversized
  data may have been split across numbered chunk attributes on save.

  Args:
      group: A pointer to a HDF5 group.
      name: A name of the attributes to load.

  Returns:
      data: Attributes data.
  """
  if name in group.attrs:
    data = [
        n.decode('utf8') if hasattr(n, 'decode') else n
        for n in group.attrs[name]
    ]
  else:
    data = []
    chunk_id = 0
    while '%s%d' % (name, chunk_id) in group.attrs:
      data.extend([
          n.decode('utf8') if hasattr(n, 'decode') else n
          for n in group.attrs['%s%d' % (name, chunk_id)]
      ])
      chunk_id += 1
  return data
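

# Chunking roundtrip sketch for the two attribute helpers above: a list whose
# serialized size exceeds HDF5_OBJECT_HEADER_LIMIT is split across attributes
# 'weight_names0', 'weight_names1', ... and reassembled on load (illustrative;
# 'attrs.h5' and the toy names are hypothetical):
#
#   import h5py
#   names = [('layer_%d/kernel:0' % i).encode('utf8') for i in range(5000)]
#   with h5py.File('attrs.h5', 'w') as f:
#     save_attributes_to_hdf5_group(f.create_group('demo'), 'weight_names',
#                                   names)
#   with h5py.File('attrs.h5', 'r') as f:
#     loaded = load_attributes_from_hdf5_group(f['demo'], 'weight_names')
#     assert loaded == [n.decode('utf8') for n in names]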


def _legacy_weights(layer):
  """DO NOT USE.

  For legacy reasons, `layer.weights` was in the order of
  [self.trainable_weights + self.non_trainable_weights], and this order was
  used for preserving the weights in h5 format. The new order of
  `layer.weights` is the same as `layer.get_weights()`, which is more
  intuitive for users. To keep supporting existing saved h5 files, this method
  should be used to save/load weights. In a future version, we will delete
  this method and introduce a breaking change for h5, staying with the new
  order for weights.

  Args:
      layer: a `tf.keras.Model` or `tf.keras.layers.Layer` instance.

  Returns:
      A list of variables with the order of trainable_weights, followed by
      non_trainable_weights.
  """
  weights = layer.trainable_weights + layer.non_trainable_weights
  if any(not isinstance(w, variables_module.Variable) for w in weights):
    raise NotImplementedError(
        'Saving or restoring weights that are not instances of `tf.Variable` '
        'is not supported in h5; use `save_format=\'tf\'` instead. Got a '
        'model or layer {} with weights {}'.format(
            layer.__class__.__name__, weights))
  return weights