Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/cell_wrappers.py: 26%
229 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing RNN wrappers."""

# Note that all the APIs under this module are exported as tf.nn.*. This is
# because those APIs came from tf.nn.rnn_cell_impl. They are ported here to
# avoid a cyclic dependency issue for serialization. These APIs will probably
# be deprecated and removed in the future, since a similar API is available in
# the existing Keras RNN API.

import hashlib
import numbers
import sys
import types as python_types
import warnings

import tensorflow.compat.v2 as tf

from keras.src.layers.rnn import lstm
from keras.src.layers.rnn.abstract_rnn_cell import AbstractRNNCell
from keras.src.saving import serialization_lib
from keras.src.utils import generic_utils
from keras.src.utils import tf_inspect

# isort: off
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.deprecation import deprecated


class _RNNCellWrapper(AbstractRNNCell):
    """Base class for cell wrappers providing V2 compatibility.

    This class, along with `rnn_cell_impl._RNNCellWrapperV1`, makes it
    possible to define wrappers that are compatible with both V1 and V2, and
    provides helper methods for this purpose.
    """

    def __init__(self, cell, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cell = cell
        cell_call_spec = tf_inspect.getfullargspec(cell.call)
        self._call_spec.expects_training_arg = (
            "training" in cell_call_spec.args
        ) or (cell_call_spec.varkw is not None)

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Calls the wrapped cell and performs the wrapping logic.

        This method is called from the wrapper's `call` or `__call__` methods.

        Args:
          inputs: A tensor with the wrapped cell's input.
          state: A tensor or tuple of tensors with the wrapped cell's state.
          cell_call_fn: The wrapped cell's method to use for step computation
            (the cell's `__call__` or `call` method).
          **kwargs: Additional arguments.

        Returns:
          A pair containing:
          - Output: A tensor with the cell's output.
          - New state: A tensor or tuple of tensors with the new wrapped
            cell's state.
        """
        raise NotImplementedError

    def call(self, inputs, state, **kwargs):
        """Runs the RNN cell step computation.

        When `call` is used, we assume that the wrapper object has been
        built, and therefore the wrapped cell has been built via its `build`
        method and its `call` method can be used directly.

        This allows the wrapped and non-wrapped cells to be used
        equivalently when using `call` and `build`.

        Args:
          inputs: A tensor with the wrapped cell's input.
          state: A tensor or tuple of tensors with the wrapped cell's state.
          **kwargs: Additional arguments passed to the wrapped cell's `call`.

        Returns:
          A pair containing:

          - Output: A tensor with the cell's output.
          - New state: A tensor or tuple of tensors with the new wrapped
            cell's state.
        """
        return self._call_wrapped_cell(
            inputs, state, cell_call_fn=self.cell.call, **kwargs
        )

    def build(self, inputs_shape):
        """Builds the wrapped cell."""
        self.cell.build(inputs_shape)
        self.built = True

    @property
    def wrapped_cell(self):
        return self.cell

    @property
    def state_size(self):
        return self.cell.state_size

    @property
    def output_size(self):
        return self.cell.output_size

    def zero_state(self, batch_size, dtype):
        with tf.name_scope(type(self).__name__ + "ZeroState"):
            return self.cell.zero_state(batch_size, dtype)

    def get_config(self):
        config = {
            "cell": {
                "class_name": self.cell.__class__.__name__,
                "config": self.cell.get_config(),
            },
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = config.copy()
        from keras.src.layers.serialization import deserialize as deserialize_layer

        cell = deserialize_layer(
            config.pop("cell"), custom_objects=custom_objects
        )
        return cls(cell, **config)
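
# A minimal sketch (not part of the library) of how a concrete wrapper is
# expected to subclass `_RNNCellWrapper`: override `_call_wrapped_cell`,
# delegate the step computation to `cell_call_fn`, and post-process the
# result. The `ScaleOutputWrapper` name and the scaling behavior below are
# purely illustrative assumptions.
#
#     class ScaleOutputWrapper(_RNNCellWrapper):
#         """Illustrative wrapper that scales the wrapped cell's output."""
#
#         def __init__(self, cell, scale=0.5, **kwargs):
#             super().__init__(cell, **kwargs)
#             self._scale = scale
#
#         def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
#             output, new_state = cell_call_fn(inputs, state, **kwargs)
#             return output * self._scale, new_state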


@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(_RNNCellWrapper):
    """Operator adding dropout to inputs and outputs of the given cell."""

    def __init__(
        self,
        cell,
        input_keep_prob=1.0,
        output_keep_prob=1.0,
        state_keep_prob=1.0,
        variational_recurrent=False,
        input_size=None,
        dtype=None,
        seed=None,
        dropout_state_filter_visitor=None,
        **kwargs,
    ):
        """Create a cell with added input, state, and/or output dropout.

        If `variational_recurrent` is set to `True` (**NOT** the default
        behavior), then the same dropout mask is applied at every step, as
        described in: [A Theoretically Grounded Application of Dropout in
        Recurrent Neural Networks. Y. Gal, Z.
        Ghahramani](https://arxiv.org/abs/1512.05287).

        Otherwise a different dropout mask is applied at every time step.

        Note, by default (unless a custom `dropout_state_filter_visitor` is
        provided), the memory state (`c` component of any `LSTMStateTuple`)
        passing through a `DropoutWrapper` is never modified. This behavior
        is described in the above article.

        Args:
          cell: an RNNCell whose inputs, state, and/or outputs will have
            dropout applied.
          input_keep_prob: unit Tensor or float between 0 and 1, input keep
            probability; if it is constant and 1, no input dropout will be
            added.
          output_keep_prob: unit Tensor or float between 0 and 1, output keep
            probability; if it is constant and 1, no output dropout will be
            added.
          state_keep_prob: unit Tensor or float between 0 and 1, state keep
            probability; if it is constant and 1, no state dropout will be
            added. State dropout is performed on the outgoing states of the
            cell. **Note** the state components to which dropout is applied
            when `state_keep_prob` is in `(0, 1)` are also determined by the
            argument `dropout_state_filter_visitor` (e.g. by default dropout
            is never applied to the `c` component of an `LSTMStateTuple`).
          variational_recurrent: Python bool. If `True`, then the same
            dropout pattern is applied across all time steps per run call.
            If this parameter is set, `input_size` **must** be provided.
          input_size: (optional) (possibly nested tuple of) `TensorShape`
            objects containing the depth(s) of the input tensors expected to
            be passed in to the `DropoutWrapper`. Required and used **iff**
            `variational_recurrent = True` and `input_keep_prob < 1`.
          dtype: (optional) The `dtype` of the input, state, and output
            tensors. Required and used **iff**
            `variational_recurrent = True`.
          seed: (optional) integer, the randomness seed.
          dropout_state_filter_visitor: (optional), default: (see below).
            Function that takes any hierarchical level of the state and
            returns a scalar or depth=1 structure of Python booleans
            describing which terms in the state should be dropped out. In
            addition, if the function returns `True`, dropout is applied
            across this sublevel. If the function returns `False`, dropout is
            not applied across this entire sublevel. Default behavior:
            perform dropout on all terms except the memory (`c`) state of
            `LSTMStateTuple` objects, and don't try to apply dropout to
            `TensorArray` objects:
            ```
            def dropout_state_filter_visitor(s):
                # Never perform dropout on the c state.
                if isinstance(s, LSTMStateTuple):
                    return LSTMStateTuple(c=False, h=True)
                elif isinstance(s, TensorArray):
                    return False
                return True
            ```
          **kwargs: dict of keyword arguments for base layer.

        Raises:
          TypeError: if `cell` is not an `RNNCell`, or
            `dropout_state_filter_visitor` is provided but not callable.
          ValueError: if any of the keep probabilities is not between 0
            and 1.
        """
        if isinstance(cell, lstm.LSTMCell):
            raise ValueError(
                "keras LSTM cell does not work with DropoutWrapper. "
                "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
                "instead."
            )
        super().__init__(cell, dtype=dtype, **kwargs)

        if dropout_state_filter_visitor is not None and not callable(
            dropout_state_filter_visitor
        ):
            raise TypeError(
                "dropout_state_filter_visitor must be callable. "
                f"Received: {dropout_state_filter_visitor}"
            )
        self._dropout_state_filter = (
            dropout_state_filter_visitor
            or _default_dropout_state_filter_visitor
        )
        with tf.name_scope("DropoutWrapperInit"):

            def tensor_and_const_value(v):
                tensor_value = tf.convert_to_tensor(v)
                const_value = tf.get_static_value(tensor_value)
                return (tensor_value, const_value)

            for prob, attr in [
                (input_keep_prob, "input_keep_prob"),
                (state_keep_prob, "state_keep_prob"),
                (output_keep_prob, "output_keep_prob"),
            ]:
                tensor_prob, const_prob = tensor_and_const_value(prob)
                if const_prob is not None:
                    if const_prob < 0 or const_prob > 1:
                        raise ValueError(
                            f"Parameter {attr} must be between 0 and 1. "
                            f"Received {const_prob}"
                        )
                    setattr(self, f"_{attr}", float(const_prob))
                else:
                    setattr(self, f"_{attr}", tensor_prob)

        # Set variational_recurrent, seed before running the code below
        self._variational_recurrent = variational_recurrent
        self._input_size = input_size
        self._seed = seed

        self._recurrent_input_noise = None
        self._recurrent_state_noise = None
        self._recurrent_output_noise = None

        if variational_recurrent:
            if dtype is None:
                raise ValueError(
                    "When variational_recurrent=True, dtype must be provided"
                )

            def convert_to_batch_shape(s):
                # Prepend a 1 for the batch dimension; for recurrent
                # variational dropout we use the same dropout mask for all
                # batch elements.
                return tf.concat(([1], tf.TensorShape(s).as_list()), 0)

            def batch_noise(s, inner_seed):
                shape = convert_to_batch_shape(s)
                return tf.random.uniform(shape, seed=inner_seed, dtype=dtype)

            if (
                not isinstance(self._input_keep_prob, numbers.Real)
                or self._input_keep_prob < 1.0
            ):
                if input_size is None:
                    raise ValueError(
                        "When variational_recurrent=True and input_keep_prob < "
                        "1.0 or is unknown, input_size must be provided"
                    )
                self._recurrent_input_noise = _enumerated_map_structure_up_to(
                    input_size,
                    lambda i, s: batch_noise(
                        s, inner_seed=self._gen_seed("input", i)
                    ),
                    input_size,
                )
            self._recurrent_state_noise = _enumerated_map_structure_up_to(
                cell.state_size,
                lambda i, s: batch_noise(
                    s, inner_seed=self._gen_seed("state", i)
                ),
                cell.state_size,
            )
            self._recurrent_output_noise = _enumerated_map_structure_up_to(
                cell.output_size,
                lambda i, s: batch_noise(
                    s, inner_seed=self._gen_seed("output", i)
                ),
                cell.output_size,
            )

    def _gen_seed(self, salt_prefix, index):
        # Derive a deterministic per-(prefix, index) seed from the user seed
        # by hashing, so each noise tensor gets its own reproducible stream.
        if self._seed is None:
            return None
        salt = "%s_%d" % (salt_prefix, index)
        string = (str(self._seed) + salt).encode("utf-8")
        return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

    def _variational_recurrent_dropout_value(
        self, unused_index, value, noise, keep_prob
    ):
        """Performs dropout given the pre-calculated noise tensor."""
        # uniform [keep_prob, 1.0 + keep_prob)
        random_tensor = keep_prob + noise

        # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
        binary_tensor = tf.floor(random_tensor)
        ret = tf.divide(value, keep_prob) * binary_tensor
        ret.set_shape(value.get_shape())
        return ret
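
    # A small worked example (illustrative only, not part of the library) of
    # the arithmetic above, assuming keep_prob = 0.75 and a noise value drawn
    # uniformly from [0, 1):
    #
    #     noise = 0.40  ->  random_tensor = 1.15  ->  floor(...) = 1.0 (keep)
    #     noise = 0.90  ->  random_tensor = 1.65  ->  floor(...) = 1.0 (keep)
    #     noise = 0.20  ->  random_tensor = 0.95  ->  floor(...) = 0.0 (drop)
    #
    # Each unit is therefore kept with probability keep_prob, and kept values
    # are scaled by 1 / keep_prob so the expected activation is unchanged
    # (standard "inverted dropout" scaling).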

    def _dropout(
        self,
        values,
        salt_prefix,
        recurrent_noise,
        keep_prob,
        shallow_filtered_substructure=None,
    ):
        """Decides whether to perform standard dropout or recurrent dropout."""

        if shallow_filtered_substructure is None:
            # Put something so we traverse the entire structure; inside the
            # dropout function we check to see if leaves of this are bool or
            # not.
            shallow_filtered_substructure = values

        if not self._variational_recurrent:

            def dropout(i, do_dropout, v):
                if not isinstance(do_dropout, bool) or do_dropout:
                    return tf.nn.dropout(
                        v,
                        rate=1.0 - keep_prob,
                        seed=self._gen_seed(salt_prefix, i),
                    )
                else:
                    return v

            return _enumerated_map_structure_up_to(
                shallow_filtered_substructure,
                dropout,
                *[shallow_filtered_substructure, values],
            )
        else:

            def dropout(i, do_dropout, v, n):
                if not isinstance(do_dropout, bool) or do_dropout:
                    return self._variational_recurrent_dropout_value(
                        i, v, n, keep_prob
                    )
                else:
                    return v

            return _enumerated_map_structure_up_to(
                shallow_filtered_substructure,
                dropout,
                *[shallow_filtered_substructure, values, recurrent_noise],
            )

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Runs the wrapped cell and applies dropout.

        Args:
          inputs: A tensor with the wrapped cell's input.
          state: A tensor or tuple of tensors with the wrapped cell's state.
          cell_call_fn: The wrapped cell's method to use for step computation
            (the cell's `__call__` or `call` method).
          **kwargs: Additional arguments.

        Returns:
          A pair containing:

          - Output: A tensor with the cell's output.
          - New state: A tensor or tuple of tensors with the new wrapped
            cell's state.
        """

        def _should_dropout(p):
            return (not isinstance(p, float)) or p < 1

        if _should_dropout(self._input_keep_prob):
            inputs = self._dropout(
                inputs,
                "input",
                self._recurrent_input_noise,
                self._input_keep_prob,
            )
        output, new_state = cell_call_fn(inputs, state, **kwargs)
        if _should_dropout(self._state_keep_prob):
            # Identify which subsets of the state to perform dropout on and
            # which ones to keep.
            shallow_filtered_substructure = (
                tf.__internal__.nest.get_traverse_shallow_structure(
                    self._dropout_state_filter, new_state
                )
            )
            new_state = self._dropout(
                new_state,
                "state",
                self._recurrent_state_noise,
                self._state_keep_prob,
                shallow_filtered_substructure,
            )
        if _should_dropout(self._output_keep_prob):
            output = self._dropout(
                output,
                "output",
                self._recurrent_output_noise,
                self._output_keep_prob,
            )
        return output, new_state

    def get_config(self):
        """Returns the config of the dropout wrapper."""
        config = {
            "input_keep_prob": self._input_keep_prob,
            "output_keep_prob": self._output_keep_prob,
            "state_keep_prob": self._state_keep_prob,
            "variational_recurrent": self._variational_recurrent,
            "input_size": self._input_size,
            "seed": self._seed,
        }
        if self._dropout_state_filter != _default_dropout_state_filter_visitor:
            (
                function,
                function_type,
                function_module,
            ) = _serialize_function_to_config(self._dropout_state_filter)
            config.update(
                {
                    "dropout_fn": function,
                    "dropout_fn_type": function_type,
                    "dropout_fn_module": function_module,
                }
            )
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        if "dropout_fn" in config:
            config = config.copy()
            dropout_state_filter = _parse_config_to_function(
                config,
                custom_objects,
                "dropout_fn",
                "dropout_fn_type",
                "dropout_fn_module",
            )
            config.pop("dropout_fn")
            config["dropout_state_filter_visitor"] = dropout_state_filter
        return super(DropoutWrapper, cls).from_config(
            config, custom_objects=custom_objects
        )
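
# A minimal usage sketch (illustrative only, not part of the library): wrap a
# Keras GRUCell so that non-variational dropout is applied to the inputs and
# outputs at every step, then drive it with the Keras RNN layer. The cell
# size, keep probabilities, and tensor shapes below are arbitrary assumptions.
#
#     from keras.src.layers.rnn import gru
#
#     base_cell = gru.GRUCell(8)
#     wrapped = DropoutWrapper(
#         base_cell, input_keep_prob=0.9, output_keep_prob=0.9, seed=1234
#     )
#     layer = tf.keras.layers.RNN(wrapped)
#     outputs = layer(tf.random.uniform((2, 5, 4)))  # (batch, time, features)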


@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(_RNNCellWrapper):
    """RNNCell wrapper that ensures cell inputs are added to the outputs."""

    def __init__(self, cell, residual_fn=None, **kwargs):
        """Constructs a `ResidualWrapper` for `cell`.

        Args:
          cell: An instance of `RNNCell`.
          residual_fn: (Optional) The function to map raw cell inputs and raw
            cell outputs to the actual cell outputs of the residual network.
            Defaults to `tf.nest.map_structure(lambda i, o: i + o, inputs,
            outputs)`.
          **kwargs: dict of keyword arguments for base layer.
        """
        super().__init__(cell, **kwargs)
        self._residual_fn = residual_fn

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Run the cell and apply the residual_fn.

        Args:
          inputs: cell inputs.
          state: cell state.
          cell_call_fn: The wrapped cell's method to use for step computation
            (the cell's `__call__` or `call` method).
          **kwargs: Additional arguments passed to the wrapped cell's `call`.

        Returns:
          Tuple of cell outputs and new state.

        Raises:
          TypeError: If cell inputs and outputs have different structure
            (type).
          ValueError: If cell inputs and outputs have different structure
            (value).
        """
        outputs, new_state = cell_call_fn(inputs, state, **kwargs)

        # Ensure shapes match
        def assert_shape_match(inp, out):
            inp.get_shape().assert_is_compatible_with(out.get_shape())

        def default_residual_fn(inputs, outputs):
            tf.nest.assert_same_structure(inputs, outputs)
            tf.nest.map_structure(assert_shape_match, inputs, outputs)
            return tf.nest.map_structure(
                lambda inp, out: inp + out, inputs, outputs
            )

        res_outputs = (self._residual_fn or default_residual_fn)(
            inputs, outputs
        )
        return (res_outputs, new_state)

    def get_config(self):
        """Returns the config of the residual wrapper."""
        if self._residual_fn is not None:
            (
                function,
                function_type,
                function_module,
            ) = _serialize_function_to_config(self._residual_fn)
            config = {
                "residual_fn": function,
                "residual_fn_type": function_type,
                "residual_fn_module": function_module,
            }
        else:
            config = {}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        if "residual_fn" in config:
            config = config.copy()
            residual_function = _parse_config_to_function(
                config,
                custom_objects,
                "residual_fn",
                "residual_fn_type",
                "residual_fn_module",
            )
            config["residual_fn"] = residual_function
        return super(ResidualWrapper, cls).from_config(
            config, custom_objects=custom_objects
        )
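
# A minimal usage sketch (illustrative only, not part of the library): with a
# ResidualWrapper the step output becomes `inputs + cell_output`, so the
# wrapped cell's output size must match the input depth. The SimpleRNNCell
# size and tensor shapes below are arbitrary assumptions.
#
#     from keras.src.layers.rnn import simple_rnn
#
#     cell = ResidualWrapper(simple_rnn.SimpleRNNCell(4))
#     layer = tf.keras.layers.RNN(cell)
#     outputs = layer(tf.random.uniform((2, 5, 4)))  # input depth 4 == units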


@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(_RNNCellWrapper):
    """Operator that ensures an RNNCell runs on a particular device."""

    def __init__(self, cell, device, **kwargs):
        """Construct a `DeviceWrapper` for `cell` with device `device`.

        Ensures the wrapped `cell` is called with `tf.device(device)`.

        Args:
          cell: An instance of `RNNCell`.
          device: A device string or function, for passing to `tf.device`.
          **kwargs: dict of keyword arguments for base layer.
        """
        super().__init__(cell, **kwargs)
        self._device = device

    def zero_state(self, batch_size, dtype):
        with tf.name_scope(type(self).__name__ + "ZeroState"):
            with tf.compat.v1.device(self._device):
                return self.cell.zero_state(batch_size, dtype)

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Run the cell on the specified device."""
        with tf.compat.v1.device(self._device):
            return cell_call_fn(inputs, state, **kwargs)

    def get_config(self):
        config = {"device": self._device}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
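
# A minimal usage sketch (illustrative only, not part of the library): pin a
# wrapped cell's step computation to a specific device. The device string,
# cell, and shapes below are arbitrary assumptions.
#
#     from keras.src.layers.rnn import gru
#
#     cell = DeviceWrapper(gru.GRUCell(8), "/cpu:0")
#     layer = tf.keras.layers.RNN(cell)
#     outputs = layer(tf.random.uniform((2, 5, 4)))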


def _serialize_function_to_config(function):
    """Serialize the function for get_config()."""
    if isinstance(function, python_types.LambdaType):
        output = generic_utils.func_dump(function)
        output_type = "lambda"
        module = function.__module__
    elif callable(function):
        output = function.__name__
        output_type = "function"
        module = function.__module__
    else:
        raise ValueError(
            f"Unrecognized function type for input: {type(function)}"
        )

    return output, output_type, module


def _parse_config_to_function(
    config,
    custom_objects,
    func_attr_name,
    func_type_attr_name,
    module_attr_name,
):
    """Reconstruct the function from the config."""
    globs = globals()
    module = config.pop(module_attr_name, None)
    if module in sys.modules:
        globs.update(sys.modules[module].__dict__)
    elif module is not None:
        # Note: we don't know the name of the function if it's a lambda.
        warnings.warn(
            "{} is not loaded, but a layer uses it. "
            "It may cause errors.".format(module),
            UserWarning,
            stacklevel=2,
        )
    if custom_objects:
        globs.update(custom_objects)
    function_type = config.pop(func_type_attr_name)
    if function_type == "function":
        # Simple lookup in custom objects
        function = serialization_lib.deserialize_keras_object(
            config[func_attr_name],
            custom_objects=custom_objects,
            printable_module_name="function in wrapper",
        )
    elif function_type == "lambda":
        if serialization_lib.in_safe_mode():
            raise ValueError(
                "Requested the deserialization of a layer with a "
                "Python `lambda` inside it. "
                "This carries a potential risk of arbitrary code execution "
                "and thus it is disallowed by default. If you trust the "
                "source of the saved model, you can pass `safe_mode=False` to "
                "the loading function in order to allow "
                "`lambda` loading."
            )
        # Unsafe deserialization from bytecode
        function = generic_utils.func_load(config[func_attr_name], globs=globs)
    else:
        raise TypeError(
            f"Unknown function type received: {function_type}. "
            "Expected types are ['function', 'lambda']"
        )
    return function


def _default_dropout_state_filter_visitor(substate):
    return not isinstance(substate, tf.TensorArray)


def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
    ix = [0]

    def enumerated_fn(*inner_args, **inner_kwargs):
        r = map_fn(ix[0], *inner_args, **inner_kwargs)
        ix[0] += 1
        return r

    return tf.__internal__.nest.map_structure_up_to(
        shallow_structure, enumerated_fn, *args, **kwargs
    )
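
# A small illustrative example (not part of the library) of how the helper
# above threads a running index through a nested structure; the structure and
# values are arbitrary assumptions:
#
#     structure = {"a": 1.0, "b": (2.0, 3.0)}
#     result = _enumerated_map_structure_up_to(
#         structure, lambda i, v: (i, v * 10), structure
#     )
#     # result == {"a": (0, 10.0), "b": ((1, 20.0), (2, 30.0))}
#
# Note that the traversal order (and hence which index each leaf receives)
# follows `tf.nest`'s flattening order.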