Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/bidirectional.py: 13%
255 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bidirectional wrapper for RNNs."""

import copy

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine.base_layer import Layer
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn import rnn_utils
from keras.src.layers.rnn.base_wrapper import Wrapper
from keras.src.saving import serialization_lib
from keras.src.utils import generic_utils
from keras.src.utils import tf_inspect
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.Bidirectional")
class Bidirectional(Wrapper):
38 """Bidirectional wrapper for RNNs.
40 Args:
41 layer: `keras.layers.RNN` instance, such as `keras.layers.LSTM` or
42 `keras.layers.GRU`. It could also be a `keras.layers.Layer` instance
43 that meets the following criteria:
44 1. Be a sequence-processing layer (accepts 3D+ inputs).
45 2. Have a `go_backwards`, `return_sequences` and `return_state`
46 attribute (with the same semantics as for the `RNN` class).
47 3. Have an `input_spec` attribute.
48 4. Implement serialization via `get_config()` and `from_config()`.
49 Note that the recommended way to create new RNN layers is to write a
50 custom RNN cell and use it with `keras.layers.RNN`, instead of
51 subclassing `keras.layers.Layer` directly.
52 - When the `returns_sequences` is true, the output of the masked
53 timestep will be zero regardless of the layer's original
54 `zero_output_for_mask` value.
55 merge_mode: Mode by which outputs of the forward and backward RNNs will be
56 combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the
57 outputs will not be combined, they will be returned as a list. Default
58 value is 'concat'.
59 backward_layer: Optional `keras.layers.RNN`, or `keras.layers.Layer`
60 instance to be used to handle backwards input processing.
61 If `backward_layer` is not provided, the layer instance passed as the
62 `layer` argument will be used to generate the backward layer
63 automatically.
64 Note that the provided `backward_layer` layer should have properties
65 matching those of the `layer` argument, in particular it should have the
66 same values for `stateful`, `return_states`, `return_sequences`, etc.
67 In addition, `backward_layer` and `layer` should have different
68 `go_backwards` argument values.
69 A `ValueError` will be raised if these requirements are not met.

    Call arguments:
      The call arguments for this layer are the same as those of the wrapped
      RNN layer. Beware that when passing the `initial_state` argument during
      the call of this layer, the first half of the elements in the
      `initial_state` list will be passed to the forward RNN call and the
      last half will be passed to the backward RNN call (see the initial-state
      sketch at the end of the Examples below).

    Raises:
      ValueError:
        1. If `layer` or `backward_layer` is not a `Layer` instance.
        2. In case of invalid `merge_mode` argument.
        3. If `backward_layer` has mismatched properties compared to `layer`.

    Examples:

    ```python
    model = Sequential()
    model.add(Bidirectional(LSTM(10, return_sequences=True),
                            input_shape=(5, 10)))
    model.add(Bidirectional(LSTM(10)))
    model.add(Dense(5))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

    # With custom backward layer
    model = Sequential()
    forward_layer = LSTM(10, return_sequences=True)
    backward_layer = LSTM(10, activation='relu', return_sequences=True,
                          go_backwards=True)
    model.add(Bidirectional(forward_layer, backward_layer=backward_layer,
                            input_shape=(5, 10)))
    model.add(Dense(5))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
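
    # A minimal sketch of seeding the wrapper with initial states through the
    # functional API (assumes `Input` is imported from `keras.layers`). The
    # first half of `initial_state` seeds the forward LSTM, the second half
    # seeds the backward LSTM.
    inputs = Input(shape=(5, 10))
    initial_states = [Input(shape=(10,)) for _ in range(4)]
    outputs = Bidirectional(LSTM(10))(inputs, initial_state=initial_states)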
    ```
    """

    def __init__(
        self,
        layer,
        merge_mode="concat",
        weights=None,
        backward_layer=None,
        **kwargs,
    ):
        if not isinstance(layer, Layer):
            raise ValueError(
                "Please initialize `Bidirectional` layer with a "
                f"`tf.keras.layers.Layer` instance. Received: {layer}"
            )
        if backward_layer is not None and not isinstance(
            backward_layer, Layer
        ):
            raise ValueError(
                "`backward_layer` needs to be a `tf.keras.layers.Layer` "
                f"instance. Received: {backward_layer}"
            )
        if merge_mode not in ["sum", "mul", "ave", "concat", None]:
            raise ValueError(
                f"Invalid merge mode. Received: {merge_mode}. "
                "Merge mode should be one of "
                '{"sum", "mul", "ave", "concat", None}'
            )
        # We don't want to track `layer` since we're already tracking the two
        # copies of it we actually run.
        self._setattr_tracking = False
        super().__init__(layer, **kwargs)
        self._setattr_tracking = True

        # Recreate the forward layer from the original layer config, so that
        # it will not carry over any state from the layer.
        self.forward_layer = self._recreate_layer_from_config(layer)

        if backward_layer is None:
            self.backward_layer = self._recreate_layer_from_config(
                layer, go_backwards=True
            )
        else:
            self.backward_layer = backward_layer
            # Keep the custom backward layer config, so that we can save it
            # later. The layer's name might be updated below with prefix
            # 'backward_', and we want to preserve the original config.
            self._backward_layer_config = (
                serialization_lib.serialize_keras_object(backward_layer)
            )

        self.forward_layer._name = "forward_" + self.forward_layer.name
        self.backward_layer._name = "backward_" + self.backward_layer.name

        self._verify_layer_config()

        def force_zero_output_for_mask(layer):
            # Force the zero_output_for_mask to be True if returning
            # sequences.
            if getattr(layer, "zero_output_for_mask", None) is not None:
                layer.zero_output_for_mask = layer.return_sequences

        force_zero_output_for_mask(self.forward_layer)
        force_zero_output_for_mask(self.backward_layer)

        self.merge_mode = merge_mode
        if weights:
            nw = len(weights)
            self.forward_layer.initial_weights = weights[: nw // 2]
            self.backward_layer.initial_weights = weights[nw // 2 :]
        self.stateful = layer.stateful
        self.return_sequences = layer.return_sequences
        self.return_state = layer.return_state
        self.supports_masking = True
        self._trainable = kwargs.get("trainable", layer.trainable)
        self._num_constants = 0
        self.input_spec = layer.input_spec

    @property
    def _use_input_spec_as_call_signature(self):
        return self.layer._use_input_spec_as_call_signature

    def _verify_layer_config(self):
        """Ensure the forward and backward layers have valid common properties."""
        if self.forward_layer.go_backwards == self.backward_layer.go_backwards:
            raise ValueError(
                "Forward layer and backward layer should have different "
                "`go_backwards` value. "
                "forward_layer.go_backwards = "
                f"{self.forward_layer.go_backwards}, "
                "backward_layer.go_backwards = "
                f"{self.backward_layer.go_backwards}"
            )

        common_attributes = ("stateful", "return_sequences", "return_state")
        for a in common_attributes:
            forward_value = getattr(self.forward_layer, a)
            backward_value = getattr(self.backward_layer, a)
            if forward_value != backward_value:
                raise ValueError(
                    "Forward layer and backward layer are expected to have "
                    f'the same value for attribute "{a}", got '
                    f'"{forward_value}" for forward layer and '
                    f'"{backward_value}" for backward layer'
                )

    def _recreate_layer_from_config(self, layer, go_backwards=False):
        # When recreating the layer from its config, it is possible that the
        # layer is an RNN layer that contains custom cells. In this case we
        # inspect the layer and pass the custom cell class as part of the
        # `custom_objects` argument when calling `from_config`. See
        # https://github.com/tensorflow/tensorflow/issues/26581 for more detail.
        config = layer.get_config()
        if go_backwards:
            config["go_backwards"] = not config["go_backwards"]
        if (
            "custom_objects"
            in tf_inspect.getfullargspec(layer.__class__.from_config).args
        ):
            custom_objects = {}
            cell = getattr(layer, "cell", None)
            if cell is not None:
                custom_objects[cell.__class__.__name__] = cell.__class__
                # For StackedRNNCells
                stacked_cells = getattr(cell, "cells", [])
                for c in stacked_cells:
                    custom_objects[c.__class__.__name__] = c.__class__
            return layer.__class__.from_config(
                config, custom_objects=custom_objects
            )
        else:
            return layer.__class__.from_config(config)

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        output_shape = self.forward_layer.compute_output_shape(input_shape)
        if self.return_state:
            state_shape = tf_utils.convert_shapes(
                output_shape[1:], to_tuples=False
            )
            output_shape = tf_utils.convert_shapes(
                output_shape[0], to_tuples=False
            )
        else:
            output_shape = tf_utils.convert_shapes(
                output_shape, to_tuples=False
            )

        if self.merge_mode == "concat":
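            # Concatenating the two directions doubles the feature dimension,
            # e.g. two 10-unit layers yield a last output dimension of 20.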
            output_shape = output_shape.as_list()
            output_shape[-1] *= 2
            output_shape = tf.TensorShape(output_shape)
        elif self.merge_mode is None:
            output_shape = [output_shape, copy.copy(output_shape)]

        if self.return_state:
            if self.merge_mode is None:
                return output_shape + state_shape + copy.copy(state_shape)
            return [output_shape] + state_shape + copy.copy(state_shape)
        return output_shape

    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        """`Bidirectional.__call__` implements the same API as the wrapped
        `RNN`."""
        inputs, initial_state, constants = rnn_utils.standardize_args(
            inputs, initial_state, constants, self._num_constants
        )

        if isinstance(inputs, list):
            if len(inputs) > 1:
                initial_state = inputs[1:]
            inputs = inputs[0]

        if initial_state is None and constants is None:
            return super().__call__(inputs, **kwargs)

        # Applies the same workaround as in `RNN.__call__`
        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            # Check if `initial_state` can be split in half.
            num_states = len(initial_state)
            if num_states % 2 > 0:
                raise ValueError(
                    "When passing `initial_state` to a Bidirectional RNN, "
                    "the state should be a list containing the states of "
                    "the underlying RNNs. "
                    f"Received: {initial_state}"
                )

            kwargs["initial_state"] = initial_state
            additional_inputs += initial_state
            state_specs = tf.nest.map_structure(
                lambda state: InputSpec(shape=backend.int_shape(state)),
                initial_state,
            )
            self.forward_layer.state_spec = state_specs[: num_states // 2]
            self.backward_layer.state_spec = state_specs[num_states // 2 :]
            additional_specs += state_specs
        if constants is not None:
            kwargs["constants"] = constants
            additional_inputs += constants
            constants_spec = [
                InputSpec(shape=backend.int_shape(constant))
                for constant in constants
            ]
            self.forward_layer.constants_spec = constants_spec
            self.backward_layer.constants_spec = constants_spec
            additional_specs += constants_spec

            self._num_constants = len(constants)
            self.forward_layer._num_constants = self._num_constants
            self.backward_layer._num_constants = self._num_constants

        is_keras_tensor = backend.is_keras_tensor(
            tf.nest.flatten(additional_inputs)[0]
        )
        for tensor in tf.nest.flatten(additional_inputs):
            if backend.is_keras_tensor(tensor) != is_keras_tensor:
                raise ValueError(
                    "The initial state of a Bidirectional"
                    " layer cannot be specified with a mix of"
                    " Keras tensors and non-Keras tensors"
                    ' (a "Keras tensor" is a tensor that was'
                    " returned by a Keras layer, or by `Input`)"
                )

        if is_keras_tensor:
            # Compute the full input spec, including state.
            full_input = [inputs] + additional_inputs
            # The original input_spec is None since there could be a nested
            # tensor input. Update the input_spec to match the inputs.
            full_input_spec = [
                None for _ in range(len(tf.nest.flatten(inputs)))
            ] + additional_specs
            # Remove the keyword arguments, since the values are passed with
            # the input list.
            kwargs["initial_state"] = None
            kwargs["constants"] = None

            # Perform the call with a temporarily replaced input_spec.
            original_input_spec = self.input_spec
            self.input_spec = full_input_spec
            output = super().__call__(full_input, **kwargs)
            self.input_spec = original_input_spec
            return output
        else:
            return super().__call__(inputs, **kwargs)

    def call(
        self,
        inputs,
        training=None,
        mask=None,
        initial_state=None,
        constants=None,
    ):
        """`Bidirectional.call` implements the same API as the wrapped `RNN`."""
        kwargs = {}
        if generic_utils.has_arg(self.layer.call, "training"):
            kwargs["training"] = training
        if generic_utils.has_arg(self.layer.call, "mask"):
            kwargs["mask"] = mask
        if generic_utils.has_arg(self.layer.call, "constants"):
            kwargs["constants"] = constants

        if generic_utils.has_arg(self.layer.call, "initial_state"):
            if isinstance(inputs, list) and len(inputs) > 1:
                # initial_states are Keras tensors, which means they are
                # passed in together with the inputs as a list. The
                # initial_states need to be split into forward and backward
                # sections, and fed to the layers accordingly.
                forward_inputs = [inputs[0]]
                backward_inputs = [inputs[0]]
                pivot = (len(inputs) - self._num_constants) // 2 + 1
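                # For example, a wrapped LSTM called with two forward states,
                # two backward states and one constant gives len(inputs) == 6,
                # so pivot == (6 - 1) // 2 + 1 == 3: inputs[1:3] seed the
                # forward layer and inputs[3:-1] seed the backward layer.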
                # add forward initial state
                forward_inputs += inputs[1:pivot]
                if not self._num_constants:
                    # add backward initial state
                    backward_inputs += inputs[pivot:]
                else:
                    # add backward initial state
                    backward_inputs += inputs[pivot : -self._num_constants]
                    # add constants for forward and backward layers
                    forward_inputs += inputs[-self._num_constants :]
                    backward_inputs += inputs[-self._num_constants :]
                forward_state, backward_state = None, None
                if "constants" in kwargs:
                    kwargs["constants"] = None
            elif initial_state is not None:
                # initial_states are not Keras tensors, e.g. eager tensors
                # created from a numpy array. They are only passed in via the
                # `initial_state` kwarg, and should be passed to the forward
                # and backward layers via the `initial_state` kwarg as well.
                forward_inputs, backward_inputs = inputs, inputs
                half = len(initial_state) // 2
                forward_state = initial_state[:half]
                backward_state = initial_state[half:]
            else:
                forward_inputs, backward_inputs = inputs, inputs
                forward_state, backward_state = None, None

            y = self.forward_layer(
                forward_inputs, initial_state=forward_state, **kwargs
            )
            y_rev = self.backward_layer(
                backward_inputs, initial_state=backward_state, **kwargs
            )
        else:
            y = self.forward_layer(inputs, **kwargs)
            y_rev = self.backward_layer(inputs, **kwargs)

        if self.return_state:
            states = y[1:] + y_rev[1:]
            y = y[0]
            y_rev = y_rev[0]

        if self.return_sequences:
            time_dim = (
                0 if getattr(self.forward_layer, "time_major", False) else 1
            )
            y_rev = backend.reverse(y_rev, time_dim)
        if self.merge_mode == "concat":
            output = backend.concatenate([y, y_rev])
        elif self.merge_mode == "sum":
            output = y + y_rev
        elif self.merge_mode == "ave":
            output = (y + y_rev) / 2
        elif self.merge_mode == "mul":
            output = y * y_rev
        elif self.merge_mode is None:
            output = [y, y_rev]
        else:
            raise ValueError(
                "Unrecognized value for `merge_mode`. "
                f"Received: {self.merge_mode}. "
                'Expected values are ["concat", "sum", "ave", "mul"]'
            )

        if self.return_state:
            if self.merge_mode is None:
                return output + states
            return [output] + states
        return output

    def reset_states(self, states=None):
        if not self.stateful:
            raise AttributeError("Layer must be stateful.")

        if states is None:
            self.forward_layer.reset_states()
            self.backward_layer.reset_states()
        else:
            if not isinstance(states, (list, tuple)):
                raise ValueError(
                    "Unrecognized value for `states`. "
                    "Expected `states` to be list or tuple. "
                    f"Received: {states}"
                )

            half = len(states) // 2
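            # For a wrapped LSTM, for instance, `states` would typically be
            # [forward_h, forward_c, backward_h, backward_c]: the first half
            # goes to the forward layer, the second half to the backward one.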
            self.forward_layer.reset_states(states[:half])
            self.backward_layer.reset_states(states[half:])

    def build(self, input_shape):
        with backend.name_scope(self.forward_layer.name):
            self.forward_layer.build(input_shape)
        with backend.name_scope(self.backward_layer.name):
            self.backward_layer.build(input_shape)
        self.built = True

    def compute_mask(self, inputs, mask):
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_sequences:
            if not self.merge_mode:
                output_mask = [mask, mask]
            else:
                output_mask = mask
        else:
            output_mask = [None, None] if not self.merge_mode else None

        if self.return_state:
            states = self.forward_layer.states
            state_mask = [None for _ in states]
            if isinstance(output_mask, list):
                return output_mask + state_mask * 2
            return [output_mask] + state_mask * 2
        return output_mask

    @property
    def constraints(self):
        constraints = {}
        if hasattr(self.forward_layer, "constraints"):
            constraints.update(self.forward_layer.constraints)
            constraints.update(self.backward_layer.constraints)
        return constraints

    def get_config(self):
        config = {"merge_mode": self.merge_mode}
        if self._num_constants:
            config["num_constants"] = self._num_constants

        if hasattr(self, "_backward_layer_config"):
            config["backward_layer"] = self._backward_layer_config
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # Instead of updating the input, create a copy and use that.
        config = copy.deepcopy(config)
        num_constants = config.pop("num_constants", 0)
        # Handle forward layer instantiation (as the parent class would).
        from keras.src.layers import deserialize as deserialize_layer

        config["layer"] = deserialize_layer(
            config["layer"], custom_objects=custom_objects
        )
        # Handle (optional) backward layer instantiation.
        backward_layer_config = config.pop("backward_layer", None)
        if backward_layer_config is not None:
            backward_layer = deserialize_layer(
                backward_layer_config, custom_objects=custom_objects
            )
            config["backward_layer"] = backward_layer
        # Instantiate the wrapper, adjust it and return it.
        layer = cls(**config)
        layer._num_constants = num_constants
        return layer