# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module contains the implementation of RNN cell wrappers."""
import hashlib
import numbers
import sys
import types as python_types
import warnings

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest


class DropoutWrapperBase(object):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self,
               cell,
               input_keep_prob=1.0,
               output_keep_prob=1.0,
               state_keep_prob=1.0,
               variational_recurrent=False,
               input_size=None,
               dtype=None,
               seed=None,
               dropout_state_filter_visitor=None,
               **kwargs):
    """Create a cell with added input, state, and/or output dropout.

    If `variational_recurrent` is set to `True` (**NOT** the default behavior),
    then the same dropout mask is applied at every step, as described in:
    [A Theoretically Grounded Application of Dropout in Recurrent
    Neural Networks. Y. Gal, Z. Ghahramani](https://arxiv.org/abs/1512.05287).

    Otherwise a different dropout mask is applied at every time step.

    Note, by default (unless a custom `dropout_state_filter_visitor` is
    provided), the memory state (`c` component of any `LSTMStateTuple`) passing
    through a `DropoutWrapper` is never modified. This behavior is described in
    the above article.
    Args:
      cell: an RNNCell, the cell to which input, state, and/or output dropout
        is added.
      input_keep_prob: unit Tensor or float between 0 and 1, input keep
        probability; if it is constant and 1, no input dropout will be added.
      output_keep_prob: unit Tensor or float between 0 and 1, output keep
        probability; if it is constant and 1, no output dropout will be added.
      state_keep_prob: unit Tensor or float between 0 and 1, state keep
        probability; if it is constant and 1, no state dropout will be added.
        State dropout is performed on the outgoing states of the cell. **Note**
        the state components to which dropout is applied when `state_keep_prob`
        is in `(0, 1)` are also determined by the argument
        `dropout_state_filter_visitor` (e.g. by default dropout is never
        applied to the `c` component of an `LSTMStateTuple`).
      variational_recurrent: Python bool. If `True`, then the same dropout
        pattern is applied across all time steps per run call. If this
        parameter is set, `input_size` **must** be provided.
      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
        containing the depth(s) of the input tensors expected to be passed in
        to the `DropoutWrapper`. Required and used **iff**
        `variational_recurrent = True` and `input_keep_prob < 1`.
      dtype: (optional) The `dtype` of the input, state, and output tensors.
        Required and used **iff** `variational_recurrent = True`.
      seed: (optional) integer, the randomness seed.
      dropout_state_filter_visitor: (optional), default: (see below). Function
        that takes any hierarchical level of the state and returns a scalar or
        depth=1 structure of Python booleans describing which terms in the
        state should be dropped out. If the function returns `True`, dropout
        is applied across this sublevel. If the function returns `False`,
        dropout is not applied across this entire sublevel.
        Default behavior: perform dropout on all terms except the memory (`c`)
        state of `LSTMCellState` objects, and don't try to apply dropout to
        `TensorArray` objects:

        ```
        def dropout_state_filter_visitor(s):
          if isinstance(s, LSTMCellState):
            # Never perform dropout on the c state.
            return LSTMCellState(c=False, h=True)
          elif isinstance(s, TensorArray):
            return False
          return True
        ```
      **kwargs: dict of keyword arguments for base layer.

    Raises:
      TypeError: if `cell` is not an `RNNCell`, or if
        `dropout_state_filter_visitor` is provided but not `callable`.
      ValueError: if any of the keep_probs are not between 0 and 1.
    """
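    # Illustrative sketch (an assumption about typical usage, not part of this
    # module): a public wrapper built on this base class, such as
    # `tf.compat.v1.nn.rnn_cell.DropoutWrapper`, is usually applied roughly as
    #
    #   cell = tf.compat.v1.nn.rnn_cell.LSTMCell(128)
    #   cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
    #       cell,
    #       input_keep_prob=0.9,
    #       output_keep_prob=0.9,
    #       state_keep_prob=0.9,
    #       variational_recurrent=True,
    #       input_size=tf.TensorShape([64]),  # required with variational dropout
    #       dtype=tf.float32)
    #
    # where the keep probabilities and sizes above are placeholder values.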
    super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs)

    if (dropout_state_filter_visitor is not None and
        not callable(dropout_state_filter_visitor)):
      raise TypeError("dropout_state_filter_visitor must be callable")
    self._dropout_state_filter = (
        dropout_state_filter_visitor or _default_dropout_state_filter_visitor)
    with ops.name_scope_v2("DropoutWrapperInit"):

      def tensor_and_const_value(v):
        tensor_value = tensor_conversion.convert_to_tensor_v2_with_dispatch(v)
        const_value = tensor_util.constant_value(tensor_value)
        return (tensor_value, const_value)

      for prob, attr in [(input_keep_prob, "input_keep_prob"),
                         (state_keep_prob, "state_keep_prob"),
                         (output_keep_prob, "output_keep_prob")]:
        tensor_prob, const_prob = tensor_and_const_value(prob)
        if const_prob is not None:
          if const_prob < 0 or const_prob > 1:
            raise ValueError("Parameter %s must be between 0 and 1: %g" %
                             (attr, const_prob))
          setattr(self, "_%s" % attr, float(const_prob))
        else:
          setattr(self, "_%s" % attr, tensor_prob)

    # Set variational_recurrent, seed before running the code below
    self._variational_recurrent = variational_recurrent
    self._input_size = input_size
    self._seed = seed

    self._recurrent_input_noise = None
    self._recurrent_state_noise = None
    self._recurrent_output_noise = None

    if variational_recurrent:
      if dtype is None:
        raise ValueError(
            "When variational_recurrent=True, dtype must be provided")

      def convert_to_batch_shape(s):
        # Prepend a 1 for the batch dimension; for recurrent
        # variational dropout we use the same dropout mask for all
        # batch elements.
        return array_ops.concat(
            ([1], tensor_shape.TensorShape(s).as_list()), 0)

      def batch_noise(s, inner_seed):
        shape = convert_to_batch_shape(s)
        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)

      if (not isinstance(self._input_keep_prob, numbers.Real) or
          self._input_keep_prob < 1.0):
        if input_size is None:
          raise ValueError(
              "When variational_recurrent=True and input_keep_prob < 1.0 or "
              "is unknown, input_size must be provided")
        self._recurrent_input_noise = _enumerated_map_structure_up_to(
            input_size,
            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
            input_size)
      self._recurrent_state_noise = _enumerated_map_structure_up_to(
          cell.state_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
          cell.state_size)
      self._recurrent_output_noise = _enumerated_map_structure_up_to(
          cell.output_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
          cell.output_size)

  def _gen_seed(self, salt_prefix, index):
    if self._seed is None:
      return None
    salt = "%s_%d" % (salt_prefix, index)
    string = (str(self._seed) + salt).encode("utf-8")
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
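
  # Illustrative note (added here, not part of the original code): each noise
  # term gets a deterministic per-term seed derived from the user seed. For
  # example, with self._seed = 42 and the first input term, the salt is
  # "input_0" and the seed is
  #
  #   int(hashlib.md5("42input_0".encode("utf-8")).hexdigest()[:8], 16) & 0x7FFFFFFF
  #
  # so the same wrapper configuration always reproduces the same dropout masks.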

  @property
  def wrapped_cell(self):
    return self.cell

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def build(self, inputs_shape):
    self.cell.build(inputs_shape)
    self.built = True

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _variational_recurrent_dropout_value(
      self, unused_index, value, noise, keep_prob):
    """Performs dropout given the pre-calculated noise tensor."""
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob + noise

    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    ret = math_ops.divide(value, keep_prob) * binary_tensor
    ret.set_shape(value.get_shape())
    return ret
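
  # Worked example (added note, not from the original source): with
  # keep_prob = 0.8, noise ~ U[0, 1) gives random_tensor ~ U[0.8, 1.8).
  # floor() maps values below 1.0 (probability 0.2) to 0 and the rest to 1,
  # so each unit is kept with probability 0.8. Dividing the kept values by
  # keep_prob rescales them so the expected activation is unchanged:
  # E[value / 0.8 * mask] = value.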

  def _dropout(self,
               values,
               salt_prefix,
               recurrent_noise,
               keep_prob,
               shallow_filtered_substructure=None):
    """Decides whether to perform standard dropout or recurrent dropout."""

    if shallow_filtered_substructure is None:
      # Put something so we traverse the entire structure; inside the
      # dropout function we check to see if leaves of this are bool or not.
      shallow_filtered_substructure = values

    if not self._variational_recurrent:

      def dropout(i, do_dropout, v):
        if not isinstance(do_dropout, bool) or do_dropout:
          return nn_ops.dropout_v2(
              v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i))
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values])
    else:

      def dropout(i, do_dropout, v, n):
        if not isinstance(do_dropout, bool) or do_dropout:
          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values, recurrent_noise])
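
  # Illustrative sketch (an assumption, not from the original code): for an
  # LSTM state, the default filter yields a shallow structure such as
  #
  #   LSTMStateTuple(c=False, h=True)
  #
  # so `dropout()` above receives do_dropout=False for the memory state `c`
  # (returned unchanged) and do_dropout=True for the hidden state `h`.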

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the wrapped cell and applies dropout.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """

    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = cell_call_fn(inputs, state, **kwargs)
    if _should_dropout(self._state_keep_prob):
      # Identify which subsets of the state to perform dropout on and
      # which ones to keep.
      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
          self._dropout_state_filter, new_state)
      new_state = self._dropout(new_state, "state", self._recurrent_state_noise,
                                self._state_keep_prob,
                                shallow_filtered_substructure)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output", self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state
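
  # Usage sketch (an assumption about how a concrete wrapper built on this
  # base class is stepped; not part of the original code):
  #
  #   wrapper = tf.compat.v1.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.9)
  #   state = wrapper.zero_state(batch_size, tf.float32)
  #   for x_t in inputs_by_time:
  #     output, state = wrapper(x_t, state)  # dropout is applied at every step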

  def get_config(self):
    """Returns the config of the dropout wrapper."""
    config = {
        "input_keep_prob": self._input_keep_prob,
        "output_keep_prob": self._output_keep_prob,
        "state_keep_prob": self._state_keep_prob,
        "variational_recurrent": self._variational_recurrent,
        "input_size": self._input_size,
        "seed": self._seed,
    }
    if self._dropout_state_filter != _default_dropout_state_filter_visitor:
      function, function_type, function_module = _serialize_function_to_config(
          self._dropout_state_filter)
      config.update({
          "dropout_fn": function,
          "dropout_fn_type": function_type,
          "dropout_fn_module": function_module
      })
    base_config = super(DropoutWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "dropout_fn" in config:
      config = config.copy()
      dropout_state_filter = _parse_config_to_function(
          config, custom_objects, "dropout_fn", "dropout_fn_type",
          "dropout_fn_module")
      config.pop("dropout_fn")
      config["dropout_state_filter_visitor"] = dropout_state_filter
    return super(DropoutWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)
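
  # Illustrative note (an assumption, not from the original code): get_config
  # and from_config are intended to round-trip through Keras serialization,
  # e.g.
  #
  #   config = wrapper.get_config()
  #   restored = type(wrapper).from_config(config)
  #
  # with a custom dropout_state_filter_visitor carried via the "dropout_fn*"
  # keys handled above.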


class ResidualWrapperBase(object):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, cell, residual_fn=None, **kwargs):
    """Constructs a `ResidualWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      residual_fn: (Optional) The function to map raw cell inputs and raw cell
        outputs to the actual cell outputs of the residual network. Defaults to
        calling `nest.map_structure(lambda i, o: i + o, inputs, outputs)`.
      **kwargs: dict of keyword arguments for base layer.
    """
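    # Illustrative sketch (an assumption, not part of the original code): a
    # custom `residual_fn` is useful when inputs and outputs have different
    # depths, e.g. projecting the inputs before the residual addition:
    #
    #   proj = tf.keras.layers.Dense(cell.output_size, use_bias=False)
    #
    #   def residual_fn(inputs, outputs):
    #     return proj(inputs) + outputs
    #
    #   wrapper = MyResidualWrapper(cell, residual_fn=residual_fn)
    #
    # where `MyResidualWrapper` stands in for a concrete wrapper built on this
    # base class.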
    super(ResidualWrapperBase, self).__init__(cell, **kwargs)
    self._residual_fn = residual_fn

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the cell, then applies `residual_fn` to its inputs and outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = cell_call_fn(inputs, state, **kwargs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    def default_residual_fn(inputs, outputs):
      nest.assert_same_structure(inputs, outputs)
      nest.map_structure(assert_shape_match, inputs, outputs)
      return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)

    res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
    return (res_outputs, new_state)

  def get_config(self):
    """Returns the config of the residual wrapper."""
    if self._residual_fn is not None:
      function, function_type, function_module = _serialize_function_to_config(
          self._residual_fn)
      config = {
          "residual_fn": function,
          "residual_fn_type": function_type,
          "residual_fn_module": function_module
      }
    else:
      config = {}
    base_config = super(ResidualWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "residual_fn" in config:
      config = config.copy()
      residual_function = _parse_config_to_function(config, custom_objects,
                                                    "residual_fn",
                                                    "residual_fn_type",
                                                    "residual_fn_module")
      config["residual_fn"] = residual_function
    return super(ResidualWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)


class DeviceWrapperBase(object):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, cell, device, **kwargs):
    """Construct a `DeviceWrapper` for `cell` with device `device`.

    Ensures the wrapped `cell` is called with `tf.device(device)`.

    Args:
      cell: An instance of `RNNCell`.
      device: A device string or function, for passing to `tf.device`.
      **kwargs: dict of keyword arguments for base layer.
    """
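    # Illustrative sketch (an assumption about typical usage of a public
    # wrapper built on this base class, not part of the original code):
    #
    #   cell = tf.compat.v1.nn.rnn_cell.LSTMCell(128)
    #   cell = tf.compat.v1.nn.rnn_cell.DeviceWrapper(cell, "/device:GPU:0")
    #
    # Every step of the wrapped cell then runs under
    # tf.device("/device:GPU:0").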
    super(DeviceWrapperBase, self).__init__(cell, **kwargs)
    self._device = device

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      with ops.device(self._device):
        return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Run the cell on specified device."""
    with ops.device(self._device):
      return cell_call_fn(inputs, state, **kwargs)

  def get_config(self):
    config = {"device": self._device}
    base_config = super(DeviceWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def _serialize_function_to_config(function):
  """Serialize the function for get_config()."""
  if isinstance(function, python_types.LambdaType):
    output = generic_utils.func_dump(function)
    output_type = "lambda"
    module = function.__module__
  elif callable(function):
    output = function.__name__
    output_type = "function"
    module = function.__module__
  else:
    raise ValueError("Unrecognized function type for input: {}".format(
        type(function)))

  return output, output_type, module
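
# Illustrative note (not from the original source): for a plain named function
# such as `def my_filter(s): ...` (a hypothetical example), this returns
# roughly ("my_filter", "function", "<defining module>"), while a lambda is
# dumped with generic_utils.func_dump and tagged "lambda" so that
# _parse_config_to_function below knows how to rebuild it.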


def _parse_config_to_function(config, custom_objects, func_attr_name,
                              func_type_attr_name, module_attr_name):
  """Reconstruct the function from the config."""
  globs = globals()
  module = config.pop(module_attr_name, None)
  if module in sys.modules:
    globs.update(sys.modules[module].__dict__)
  elif module is not None:
    # Note: we don't know the name of the function if it's a lambda.
    warnings.warn(
        "{} is not loaded, but a layer uses it. "
        "It may cause errors.".format(module), UserWarning)
  if custom_objects:
    globs.update(custom_objects)
  function_type = config.pop(func_type_attr_name)
  if function_type == "function":
    # Simple lookup in custom objects
    function = generic_utils.deserialize_keras_object(
        config[func_attr_name],
        custom_objects=custom_objects,
        printable_module_name="function in wrapper")
  elif function_type == "lambda":
    # Unsafe deserialization from bytecode
    function = generic_utils.func_load(config[func_attr_name], globs=globs)
  else:
    raise TypeError("Unknown function type:", function_type)
  return function


def _default_dropout_state_filter_visitor(substate):
  from tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl import LSTMStateTuple  # pylint: disable=g-import-not-at-top
  if isinstance(substate, LSTMStateTuple):
    # Do not perform dropout on the memory state.
    return LSTMStateTuple(c=False, h=True)
  elif isinstance(substate, tensor_array_ops.TensorArray):
    return False
  return True


def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
  ix = [0]

  def enumerated_fn(*inner_args, **inner_kwargs):
    r = map_fn(ix[0], *inner_args, **inner_kwargs)
    ix[0] += 1
    return r

  return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args,
                                  **kwargs)
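
# Illustrative note (not from the original source): the closure above threads a
# running index through nest.map_structure_up_to, so for a structure such as
# {"a": s1, "b": (s2, s3)} the map_fn is called as map_fn(0, s1), map_fn(1, s2),
# map_fn(2, s3). This is how per-term seeds ("input_0", "input_1", ...) are
# assigned in DropoutWrapperBase.__init__.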