Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/rnn/abstract_rnn_cell.py: 39%

28 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class for RNN cells. 

16 

17Adapted from legacy github.com/keras-team/tf-keras. 

18""" 

19 

20import tensorflow as tf 

21 

22 

def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
    """Build a zero-filled initial state shaped after ``cell.state_size``.

    When ``inputs`` is provided, its leading dimension (read dynamically via
    ``tf.shape``) and its dtype take precedence over the ``batch_size`` and
    ``dtype`` arguments, which are then ignored.

    Args:
      cell: object exposing a ``state_size`` attribute.
      inputs: optional tensor; if not None it supplies batch size and dtype.
      batch_size: batch size to use when ``inputs`` is None.
      dtype: dtype to use when ``inputs`` is None.

    Returns:
      Whatever ``_generate_zero_filled_state`` builds for ``cell.state_size``.
    """
    # Tuple elements are evaluated left to right: tf.shape first, then dtype.
    effective_batch, effective_dtype = (
        (tf.shape(inputs)[0], inputs.dtype)
        if inputs is not None
        else (batch_size, dtype)
    )
    return _generate_zero_filled_state(
        effective_batch, cell.state_size, effective_dtype
    )

28 

29 

30def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): 

31 """Generate a zero filled tensor with shape [batch_size, state_size].""" 

32 if batch_size_tensor is None or dtype is None: 

33 raise ValueError( 

34 "batch_size and dtype cannot be None while constructing initial state: " 

35 "batch_size={}, dtype={}".format(batch_size_tensor, dtype) 

36 ) 

37 

38 def create_zeros(unnested_state_size): 

39 flat_dims = tf.TensorShape(unnested_state_size).as_list() 

40 init_state_size = [batch_size_tensor] + flat_dims 

41 return tf.zeros(init_state_size, dtype=dtype) 

42 

43 if tf.nest.is_nested(state_size): 

44 return tf.nest.map_structure(create_zeros, state_size) 

45 else: 

46 return create_zeros(state_size) 

47 

48 

class AbstractRNNCell(tf.keras.layers.Layer):
    """Abstract object representing an RNN cell.

    This is a base class for implementing RNN cells with custom behavior.

    Every `RNNCell` must implement the `state_size` and `output_size`
    properties (the base-class versions raise `NotImplementedError`) and
    implement `call` with the signature `(output, next_state) = call(input, state)`.

    Examples:

    ```python
    class MinimalRNNCell(AbstractRNNCell):

      def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

      @property
      def state_size(self):
        return self.units

      @property
      def output_size(self):
        return self.units

      def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

      def call(self, inputs, states):
        prev_output = states[0]
        h = tf.matmul(inputs, self.kernel)
        output = h + tf.matmul(prev_output, self.recurrent_kernel)
        return output, output
    ```

    This definition of cell differs from the definition used in the literature.
    In the literature, 'cell' refers to an object with a single scalar output.
    This definition refers to a horizontal array of such units.

    An RNN cell, in the most abstract setting, is anything that has
    a state and performs some operation that takes a matrix of inputs.
    This operation results in an output matrix with `self.output_size` columns.
    If `self.state_size` is an integer, this operation also results in a new
    state matrix with `self.state_size` columns. If `self.state_size` is a
    (possibly nested tuple of) TensorShape object(s), then it should return a
    matching structure of Tensors having shape `[batch_size].concatenate(s)`
    for each `s` in `self.state_size`.
    """

    def call(self, inputs, states):
        """The function that contains the logic for one RNN step calculation.

        Args:
          inputs: the input tensor, which is a slice from the overall RNN input
            along the time dimension (usually the second dimension).
          states: the state tensor from the previous step, which has the same
            shape as `(batch, state_size)`. At timestep 0 it is the initial
            state the user specified, or a zero-filled tensor otherwise.

        Returns:
          A tuple of two tensors:
            1. output tensor for the current timestep, with size `output_size`.
            2. state tensor for the next step, which has the shape of
               `state_size`.

        Raises:
          NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("Abstract method")

    @property
    def state_size(self):
        """Size(s) of state(s) used by this cell.

        It can be represented by an Integer, a TensorShape or a tuple of
        Integers or TensorShapes.

        Raises:
          NotImplementedError: always; subclasses must override this property.
        """
        raise NotImplementedError("Abstract method")

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell.

        Raises:
          NotImplementedError: always; subclasses must override this property.
        """
        raise NotImplementedError("Abstract method")

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return a zero-filled initial state matching `self.state_size`.

        Args:
          inputs: optional input tensor. When given, its leading dimension and
            dtype are used, and `batch_size` / `dtype` are ignored.
          batch_size: batch size to use when `inputs` is None.
          dtype: dtype to use when `inputs` is None.

        Returns:
          A zero tensor of shape `[batch_size] + state_size`, or a nested
          structure of such tensors mirroring `self.state_size`.

        Raises:
          ValueError: if `inputs` is None and either `batch_size` or `dtype`
            is None.
        """
        return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)