Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py: 60%

47 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Module implementing RNN wrappers for TF v2."""

16 

17# Note that all the APIs under this module are exported as tf.nn.*. This is due 

18# to the fact that those APIs were from tf.nn.rnn_cell_impl. They are ported 

19# here to avoid the cyclic dependency issue for serialization. These APIs will 

20# probably be deprecated and removed in the future, since similar functionality

21# is available in the existing Keras RNN API.

22 

23 

24from tensorflow.python.keras.layers import recurrent 

25from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl 

26from tensorflow.python.keras.utils import tf_inspect 

27from tensorflow.python.util.deprecation import deprecated 

28from tensorflow.python.util.tf_export import tf_export 

29 

30 

class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
  """Base class for cells wrappers V2 compatibility.

  This class along with `rnn_cell_impl._RNNCellWrapperV1` allows to define
  wrappers that are compatible with V1 and V2, and defines helper methods for
  this purpose.

  The wrapped cell is stored as `self.cell`. Building, calling and config
  serialization are all delegated to it.
  """

  def __init__(self, cell, *args, **kwargs):
    """Wraps `cell`.

    Args:
      cell: The RNN cell instance to wrap.
      *args: Positional arguments forwarded to the parent constructor.
      **kwargs: Keyword arguments forwarded to the parent constructor.
    """
    super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
    self.cell = cell
    cell_call_spec = tf_inspect.getfullargspec(cell.call)
    # `call` forwards **kwargs straight to `cell.call`, so the wrapper should
    # advertise a `training` argument whenever the wrapped cell can accept
    # one. That covers three cases: an ordinary (positional-or-keyword)
    # parameter, a keyword-only parameter, or a catch-all **kwargs.
    self._expects_training_arg = (
        "training" in cell_call_spec.args
        or "training" in cell_call_spec.kwonlyargs
        or cell_call_spec.varkw is not None)

  def call(self, inputs, state, **kwargs):
    """Runs the RNN cell step computation.

    When `call` is being used, we assume that the wrapper object has been built,
    and therefore the wrapped cell has been built via its `build` method and
    its `call` method can be used directly.

    This allows to use the wrapped cell and the non-wrapped cell equivalently
    when using `call` and `build`.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    # `_call_wrapped_cell` is not defined in this class. It is expected to be
    # supplied by the concrete wrapper base class mixed in alongside this one
    # (e.g. `rnn_cell_wrapper_impl.DropoutWrapperBase`).
    return self._call_wrapped_cell(
        inputs, state, cell_call_fn=self.cell.call, **kwargs)

  def build(self, inputs_shape):
    """Builds the wrapped cell and marks this wrapper as built.

    Args:
      inputs_shape: Input shape, passed unchanged to `self.cell.build`.
    """
    self.cell.build(inputs_shape)
    self.built = True

  def get_config(self):
    """Returns the serializable config of this wrapper.

    The parent layer config is extended with a `cell` entry. That entry holds
    the wrapped cell's class name and its own `get_config()` output. On a key
    collision, the wrapper's entries override the parent's.

    Returns:
      A dict suitable for `from_config`.
    """
    config = {
        "cell": {
            "class_name": self.cell.__class__.__name__,
            "config": self.cell.get_config()
        },
    }
    base_config = super(_RNNCellWrapperV2, self).get_config()
    return {**base_config, **config}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates a wrapper from the output of `get_config`.

    Args:
      config: Dict produced by `get_config`. It is not mutated; a shallow copy
        is used.
      custom_objects: Optional mapping of names to custom classes, passed to
        layer deserialization for the wrapped cell.

    Returns:
      A new instance of `cls` wrapping the deserialized cell.
    """
    config = config.copy()
    # Imported here rather than at module top level. This is presumably to
    # avoid an import cycle with the serialization module; confirm before
    # hoisting.
    from tensorflow.python.keras.layers.serialization import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    cell = deserialize_layer(config.pop("cell"), custom_objects=custom_objects)
    return cls(cell, **config)

92 

93 

@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV2):
  """Operator adding dropout to inputs and outputs of the given cell.

  Wrapping a Keras `LSTMCell` is rejected with a `ValueError`. For that cell,
  configure dropout on the cell itself instead.
  """

  def __init__(self, *args, **kwargs):
    super(DropoutWrapper, self).__init__(*args, **kwargs)
    # `self.cell` is set during the cooperative super() chain above
    # (presumably by `_RNNCellWrapperV2.__init__`), so the check can only run
    # after it.
    if isinstance(self.cell, recurrent.LSTMCell):
      raise ValueError("keras LSTM cell does not work with DropoutWrapper. "
                       "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
                       "instead.")

  # Expose the base implementation's argument documentation on the public
  # constructor.
  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__

108 

109 

@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs.

  This class combines `rnn_cell_wrapper_impl.ResidualWrapperBase` with the
  TF v2 wrapper base `_RNNCellWrapperV2`.
  """

  # This constructor only forwards to the base classes. It exists so that the
  # base implementation's constructor docstring can be attached to it below.
  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__

120 

121 

@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device.

  This class combines `rnn_cell_wrapper_impl.DeviceWrapperBase` with the
  TF v2 wrapper base `_RNNCellWrapperV2`.
  """

  # This constructor only forwards to the base classes. It exists so that the
  # base implementation's constructor docstring can be attached to it below.
  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__