Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/rnn_utils.py: 17%
59 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for RNN cells and layers."""
18import tensorflow.compat.v2 as tf
20from keras.src.utils import control_flow_util
22# isort: off
23from tensorflow.python.platform import tf_logging as logging
def standardize_args(inputs, initial_state, constants, num_constants):
    """Standardizes `__call__` to a single list of tensor inputs.

    When running a model loaded from a file, the `initial_state` and
    `constants` tensors can arrive packed inside `inputs` instead of
    through their dedicated keyword arguments. This helper separates the
    arguments again and normalizes `initial_state` and `constants` into
    lists of tensors (or None).

    Args:
        inputs: Tensor or list/tuple of tensors, which may include constants
            and initial states. In that case `num_constants` must be
            specified.
        initial_state: Tensor or list of tensors or None, initial states.
        constants: Tensor or list of tensors or None, constant tensors.
        num_constants: Expected number of constants (if constants are passed
            as part of the `inputs` list).

    Returns:
        inputs: Single tensor or tuple of tensors.
        initial_state: List of tensors or None.
        constants: List of tensors or None.
    """
    if isinstance(inputs, list):
        # Two scenarios lead here:
        # Graph mode: __call__ runs once, and initial_state/constants may
        # have been folded into `inputs` when the model was deserialized.
        # Eager mode: __call__ runs twice — first as
        # rnn_layer(inputs=input_t, constants=c_t, ...), then again from
        # fit/train_on_batch/predict with real np data, where `inputs`
        # carries the initial_state and constants as eager tensors.
        #
        # Either way the layout is: the real input first (possibly itself a
        # nested structure), then the initial states (a list, or list of
        # lists for complex state structures), then a flat list of
        # constants.
        assert initial_state is None and constants is None
        if num_constants:
            inputs, constants = (
                inputs[:-num_constants],
                inputs[-num_constants:],
            )
        if len(inputs) > 1:
            inputs, initial_state = inputs[:1], inputs[1:]
        inputs = tuple(inputs) if len(inputs) > 1 else inputs[0]

    def _as_list_or_none(value):
        # Normalize to a list, passing through None and existing lists.
        if value is None or isinstance(value, list):
            return value
        return list(value) if isinstance(value, tuple) else [value]

    initial_state = _as_list_or_none(initial_state)
    constants = _as_list_or_none(constants)

    return inputs, initial_state, constants
def is_multiple_state(state_size):
    """Check whether the state_size contains multiple states."""
    if isinstance(state_size, tf.TensorShape):
        # A TensorShape has __len__ but describes a single state.
        return False
    return hasattr(state_size, "__len__")
def generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
    """Returns zero-filled initial states for `cell`.

    When `inputs` is provided, its leading (batch) dimension and dtype take
    precedence over the `batch_size` and `dtype` arguments.
    """
    if inputs is None:
        return generate_zero_filled_state(batch_size, cell.state_size, dtype)
    return generate_zero_filled_state(
        tf.shape(inputs)[0], cell.state_size, inputs.dtype
    )
def generate_zero_filled_state(batch_size_tensor, state_size, dtype):
    """Generate a zero filled tensor with shape [batch_size, state_size]."""
    if batch_size_tensor is None or dtype is None:
        raise ValueError(
            "batch_size and dtype cannot be None while constructing initial "
            f"state. Received: batch_size={batch_size_tensor}, dtype={dtype}"
        )

    def _zeros_for(single_state_size):
        # Prepend the batch dimension to the (possibly scalar) state size.
        dims = tf.TensorShape(single_state_size).as_list()
        return tf.zeros([batch_size_tensor] + dims, dtype=dtype)

    if not tf.nest.is_nested(state_size):
        return _zeros_for(state_size)
    return tf.nest.map_structure(_zeros_for, state_size)
def caching_device(rnn_cell):
    """Returns the caching device for the RNN variable.

    This is useful for distributed training, when the variable does not
    live on the same device as the training worker. With the device cache
    enabled, a worker reads the variable once and caches it locally,
    instead of fetching it from the remote device on every time step.

    Note that this assumes any variable the cell needs per time step keeps
    the same value throughout the forward path and is only updated during
    backprop. That holds for all the default cells (SimpleRNN, GRU, LSTM).
    If the cell body relies on a variable that is updated every time step,
    device caching would make it read a stale value.

    Args:
        rnn_cell: the rnn cell instance.
    """
    # caching_device is not supported in eager mode.
    if tf.executing_eagerly():
        return None
    if not getattr(rnn_cell, "_enable_caching_device", False):
        return None
    # Skip caching when inside a loop: train steps may be wrapped in a
    # tf.while_loop, and caching would prevent forward computations in
    # later iterations from re-reading the updated weights.
    if control_flow_util.IsInWhileLoop(tf.compat.v1.get_default_graph()):
        logging.warning(
            "Variable read device caching has been disabled because the "
            "RNN is in tf.while_loop loop context, which will cause "
            "reading stalled value in forward path. This could slow down "
            "the training due to duplicated variable reads. Please "
            "consider updating your code to remove tf.while_loop if possible."
        )
        return None
    policy = rnn_cell._dtype_policy
    if policy.compute_dtype != policy.variable_dtype:
        logging.warning(
            "Variable read device caching has been disabled since it "
            "doesn't work with the mixed precision API. This is "
            "likely to cause a slowdown for RNN training due to "
            "duplicated read of variable for each timestep, which "
            "will be significant in a multi remote worker setting. "
            "Please consider disabling mixed precision API if "
            "the performance has been affected."
        )
        return None
    # Cache the value on whichever device the reading op runs on.
    return lambda op: op.device
def config_for_enable_caching_device(rnn_cell):
    """Return the dict config for the RNN cell's enable_caching_device field.

    Since enable_caching_device is an internal implementation detail that
    speeds up RNN variable reads in multi remote worker settings, we don't
    want it serialized into the JSON all the time. The field is only
    serialized when the cell was created with a non-default value.

    Args:
        rnn_cell: the RNN cell for serialize.

    Returns:
        A dict which contains the JSON config for enable_caching_device
        value or empty dict if the enable_caching_device value is same as
        the default value.
    """
    default_enable_caching_device = (
        tf.compat.v1.executing_eagerly_outside_functions()
    )
    if rnn_cell._enable_caching_device == default_enable_caching_device:
        return {}
    return {"enable_caching_device": rnn_cell._enable_caching_device}