Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/abstract_rnn_cell.py: 75%

16 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class for RNN cells.""" 

16 

17 

18from keras.src.engine import base_layer 

19from keras.src.layers.rnn import rnn_utils 

20 

21# isort: off 

22from tensorflow.python.util.tf_export import keras_export 

23 

24 

@keras_export("keras.layers.AbstractRNNCell")
class AbstractRNNCell(base_layer.Layer):
    """Abstract object representing an RNN cell.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of RNN API.

    This is the base class for implementing RNN cells with custom behavior.

    Every `RNNCell` must have the properties below (`state_size` and
    `output_size`) and implement `call` with the signature
    `(output, next_state) = call(input, state)`. The base implementations of
    `call`, `state_size` and `output_size` all raise `NotImplementedError`,
    so subclasses must override each of them.

    Examples:

    ```python
    from keras import backend

    class MinimalRNNCell(AbstractRNNCell):

        def __init__(self, units, **kwargs):
            super(MinimalRNNCell, self).__init__(**kwargs)
            self.units = units

        @property
        def state_size(self):
            return self.units

        @property
        def output_size(self):
            return self.units

        def build(self, input_shape):
            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True

        def call(self, inputs, states):
            prev_output = states[0]
            h = backend.dot(inputs, self.kernel)
            output = h + backend.dot(prev_output, self.recurrent_kernel)
            return output, output
    ```

    This definition of cell differs from the definition used in the literature.
    In the literature, 'cell' refers to an object with a single scalar output.
    This definition refers to a horizontal array of such units.

    An RNN cell, in the most abstract setting, is anything that has
    a state and performs some operation that takes a matrix of inputs.
    This operation results in an output matrix with `self.output_size` columns.
    If `self.state_size` is an integer, this operation also results in a new
    state matrix with `self.state_size` columns. If `self.state_size` is a
    (possibly nested tuple of) TensorShape object(s), then it should return a
    matching structure of Tensors having shape `[batch_size].concatenate(s)`
    for each `s` in `self.state_size`.
    """

    def call(self, inputs, states):
        """The function that contains the logic for one RNN step calculation.

        Args:
          inputs: the input tensor, which is a slice of the overall RNN input
            along the time dimension (usually the second dimension).
          states: the state tensor from the previous step, which has the same
            shape as `(batch, state_size)`. At timestep 0, it is the initial
            state specified by the user, or a zero-filled tensor otherwise.

        Returns:
          A tuple of two tensors:
            1. output tensor for the current timestep, with size `output_size`.
            2. state tensor for the next step, which has the shape of
               `state_size`.

        Raises:
          NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    @property
    def state_size(self):
        """size(s) of state(s) used by this cell.

        It can be represented by an Integer, a TensorShape or a tuple of
        Integers or TensorShapes.

        Raises:
          NotImplementedError: always; subclasses must override this property.
        """
        raise NotImplementedError

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell.

        Raises:
          NotImplementedError: always; subclasses must override this property.
        """
        raise NotImplementedError

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Returns the initial state(s) for this cell.

        Delegates to `rnn_utils.generate_zero_filled_state_for_cell`, so by
        default the initial state is zero-filled (matching the timestep-0
        behavior described in `call`). Subclasses may override this method to
        supply a different initial state.

        Args:
          inputs: Optional input tensor, forwarded to the helper unchanged.
            Presumably used to infer `batch_size`/`dtype` when those are not
            given — TODO confirm against `rnn_utils`.
          batch_size: Optional batch size of the state to create, forwarded
            unchanged.
          dtype: Optional dtype of the state to create, forwarded unchanged.

        Returns:
          The value produced by `generate_zero_filled_state_for_cell`;
          presumably a tensor, or a nested structure of tensors, shaped
          according to `self.state_size` — verify in `rnn_utils`.
        """
        return rnn_utils.generate_zero_filled_state_for_cell(
            self, inputs, batch_size, dtype
        )

116