Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/rnn_utils.py: 17%

59 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for RNN cells and layers.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src.utils import control_flow_util 

21 

22# isort: off 

23from tensorflow.python.platform import tf_logging as logging 

24 

25 

def standardize_args(inputs, initial_state, constants, num_constants):
    """Standardizes `__call__` to a single list of tensor inputs.

    When running a model loaded from a file, the input tensors
    `initial_state` and `constants` can be passed to `RNN.__call__()` as part
    of `inputs` instead of by the dedicated keyword arguments. This method
    makes sure the arguments are separated and that `initial_state` and
    `constants` are lists of tensors (or None).

    Args:
      inputs: Tensor or list/tuple of tensors, which may include constants
        and initial states. In that case `num_constants` must be specified.
      initial_state: Tensor or list of tensors or None, initial states.
      constants: Tensor or list of tensors or None, constant tensors.
      num_constants: Expected number of constants (if constants are passed as
        part of the `inputs` list).

    Returns:
      inputs: Single tensor or tuple of tensors.
      initial_state: List of tensors or None.
      constants: List of tensors or None.
    """
    if isinstance(inputs, list):
        # A list `inputs` can arise in two ways:
        # - Graph mode / model loading: `__call__` runs once and the
        #   initial_state and constants arrive packed inside `inputs`.
        # - Eager mode: `__call__` runs twice — first as
        #   rnn_layer(inputs=input_t, constants=c_t, ...), then again from
        #   model.fit/train_on_batch/predict with real np data, where
        #   `inputs` carries initial_state and constants as eager tensors.
        # Either way the layout is: the real input first (possibly itself a
        # nested structure), then the initial states (a list, or list of
        # lists for complex state structures), then a flat list of constants.
        assert initial_state is None and constants is None
        if num_constants:
            # Constants occupy the tail of the list.
            constants = inputs[-num_constants:]
            inputs = inputs[:-num_constants]
        if len(inputs) > 1:
            # Everything after the first item is initial state.
            initial_state = inputs[1:]
            inputs = inputs[:1]

        # Collapse back to a bare tensor (or a tuple when several remain).
        inputs = tuple(inputs) if len(inputs) > 1 else inputs[0]

    def _as_list_or_none(value):
        # Normalize to a list of tensors, preserving None.
        if value is None or isinstance(value, list):
            return value
        if isinstance(value, tuple):
            return list(value)
        return [value]

    return (
        inputs,
        _as_list_or_none(initial_state),
        _as_list_or_none(constants),
    )

87 

88 

def is_multiple_state(state_size):
    """Check whether `state_size` describes multiple states.

    A `tf.TensorShape` has `__len__` but represents a single state, so it is
    explicitly excluded; any other sized container counts as multiple states.
    """
    if isinstance(state_size, tf.TensorShape):
        return False
    return hasattr(state_size, "__len__")

94 

95 

def generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
    """Returns zero-filled initial state(s) for `cell`.

    When `inputs` is provided, its leading dimension and dtype override the
    given `batch_size` and `dtype`.
    """
    if inputs is not None:
        # Infer batch size and dtype from the actual input tensor.
        batch_size = tf.shape(inputs)[0]
        dtype = inputs.dtype
    return generate_zero_filled_state(batch_size, cell.state_size, dtype)

101 

102 

def generate_zero_filled_state(batch_size_tensor, state_size, dtype):
    """Generate a zero filled tensor with shape [batch_size, state_size].

    Handles both a single state size and an arbitrarily nested structure of
    state sizes, mirroring the structure in the result.

    Raises:
      ValueError: If `batch_size_tensor` or `dtype` is None.
    """
    if batch_size_tensor is None or dtype is None:
        raise ValueError(
            "batch_size and dtype cannot be None while constructing initial "
            f"state. Received: batch_size={batch_size_tensor}, dtype={dtype}"
        )

    def _zeros_for(unnested_state_size):
        # Prepend the batch dimension to the (possibly multi-dim) state size.
        dims = tf.TensorShape(unnested_state_size).as_list()
        return tf.zeros([batch_size_tensor] + dims, dtype=dtype)

    if not tf.nest.is_nested(state_size):
        return _zeros_for(state_size)
    return tf.nest.map_structure(_zeros_for, state_size)

120 

121 

def caching_device(rnn_cell):
    """Returns the caching device for the RNN variable.

    This is useful for distributed training, when variable is not located as
    same device as the training worker. By enabling the device cache, this
    allows worker to read the variable once and cache locally, rather than
    read it every time step from remote when it is needed.

    Note that this is assuming the variable that cell needs for each time
    step is having the same value in the forward path, and only gets updated
    in the backprop. It is true for all the default cells (SimpleRNN, GRU,
    LSTM). If the cell body relies on any variable that gets updated every
    time step, then caching device will cause it to read the stall value.

    Args:
      rnn_cell: the rnn cell instance.

    Returns:
      A callable mapping an op to its device, or None when caching is
      disabled.
    """
    # caching_device is not supported in eager mode.
    if tf.executing_eagerly():
        return None
    if not getattr(rnn_cell, "_enable_caching_device", False):
        return None

    # Don't set a caching device when running in a loop, since it is possible
    # that train steps could be wrapped in a tf.while_loop. In that scenario
    # caching prevents forward computations in loop iterations from
    # re-reading the updated weights.
    graph = tf.compat.v1.get_default_graph()
    if control_flow_util.IsInWhileLoop(graph):
        logging.warning(
            "Variable read device caching has been disabled because the "
            "RNN is in tf.while_loop loop context, which will cause "
            "reading stalled value in forward path. This could slow down "
            "the training due to duplicated variable reads. Please "
            "consider updating your code to remove tf.while_loop if possible."
        )
        return None

    # Mixed precision keeps compute and variable dtypes distinct, which is
    # incompatible with read caching.
    policy = rnn_cell._dtype_policy
    if policy.compute_dtype != policy.variable_dtype:
        logging.warning(
            "Variable read device caching has been disabled since it "
            "doesn't work with the mixed precision API. This is "
            "likely to cause a slowdown for RNN training due to "
            "duplicated read of variable for each timestep, which "
            "will be significant in a multi remote worker setting. "
            "Please consider disabling mixed precision API if "
            "the performance has been affected."
        )
        return None

    # Cache the value on the device that access the variable.
    return lambda op: op.device

173 

174 

def config_for_enable_caching_device(rnn_cell):
    """Return the dict config for RNN cell wrt enable_caching_device field.

    Since enable_caching_device is an internal implementation detail to speed
    up RNN variable reads in the multi remote worker setting, we don't want
    this config to be serialized constantly in the JSON. We only serialize
    this field when a non-default value is used to create the cell.

    Args:
      rnn_cell: the RNN cell to serialize.

    Returns:
      A dict which contains the JSON config for the enable_caching_device
      value, or an empty dict if the enable_caching_device value is the same
      as the default value.
    """
    default = tf.compat.v1.executing_eagerly_outside_functions()
    if rnn_cell._enable_caching_device == default:
        # Default value: omit from the serialized config.
        return {}
    return {"enable_caching_device": rnn_cell._enable_caching_device}

196