Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py: 29% of 524 statements
(coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-classes-have-attributes
"""Module implementing RNN Cells.

This module provides a number of basic commonly used RNN cells, such as LSTM
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
operators that allow adding dropouts, projections, or embeddings for inputs.
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
calling the `rnn` ops several times.
"""
import collections
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
from tensorflow.python.keras.legacy_tf_layers import base as base_layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"

# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"


def _hasattr(obj, attr_name):
  try:
    getattr(obj, attr_name)
  except AttributeError:
    return False
  else:
    return True


def assert_like_rnncell(cell_name, cell):
  """Raises a TypeError if cell is not like an RNNCell.

  NOTE: Do not rely on the error message (in particular in tests) which can be
  subject to change to increase readability. Use
  ASSERT_LIKE_RNNCELL_ERROR_REGEXP.

  Args:
    cell_name: A string to give a meaningful error referencing the name of the
      function argument.
    cell: The object which should behave like an RNNCell.

  Raises:
    TypeError: A human-friendly exception.
  """
  conditions = [
      _hasattr(cell, "output_size"),
      _hasattr(cell, "state_size"),
      _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"),
      callable(cell),
  ]
  errors = [
      "'output_size' property is missing", "'state_size' property is missing",
      "either 'zero_state' or 'get_initial_state' method is required",
      "is not callable"
  ]

  if not all(conditions):

    errors = [error for error, cond in zip(errors, conditions) if not cond]
    raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format(
        cell_name, cell, ", ".join(errors)))


def _concat(prefix, suffix, static=False):
  """Concat that enables int, Tensor, or TensorShape values.

  This function takes a size specification, which can be an integer, a
  TensorShape, or a Tensor, and converts it into a concatenated Tensor
  (if static = False) or a list of integers (if static = True).

  Args:
    prefix: The prefix; usually the batch size (and/or time step size).
      (TensorShape, int, or Tensor.)
    suffix: TensorShape, int, or Tensor.
    static: If `True`, return a python list with possibly unknown dimensions.
      Otherwise return a `Tensor`.

  Returns:
    shape: the concatenation of prefix and suffix.

  Raises:
    ValueError: if `suffix` is not a scalar or vector (or TensorShape).
    ValueError: if prefix or suffix was `None` and asked for dynamic
      Tensors out.
  """
  if isinstance(prefix, ops.Tensor):
    p = prefix
    p_static = tensor_util.constant_value(prefix)
    if p.shape.ndims == 0:
      p = array_ops.expand_dims(p, 0)
    elif p.shape.ndims != 1:
      raise ValueError("prefix tensor must be either a scalar or vector, "
                       "but saw tensor: %s" % p)
  else:
    p = tensor_shape.TensorShape(prefix)
    p_static = p.as_list() if p.ndims is not None else None
    p = (
        constant_op.constant(p.as_list(), dtype=dtypes.int32)
        if p.is_fully_defined() else None)
  if isinstance(suffix, ops.Tensor):
    s = suffix
    s_static = tensor_util.constant_value(suffix)
    if s.shape.ndims == 0:
      s = array_ops.expand_dims(s, 0)
    elif s.shape.ndims != 1:
      raise ValueError("suffix tensor must be either a scalar or vector, "
                       "but saw tensor: %s" % s)
  else:
    s = tensor_shape.TensorShape(suffix)
    s_static = s.as_list() if s.ndims is not None else None
    s = (
        constant_op.constant(s.as_list(), dtype=dtypes.int32)
        if s.is_fully_defined() else None)

  if static:
    shape = tensor_shape.TensorShape(p_static).concatenate(s_static)
    shape = shape.as_list() if shape.ndims is not None else None
  else:
    if p is None or s is None:
      raise ValueError("Provided a prefix or suffix of None: %s and %s" %
                       (prefix, suffix))
    shape = array_ops.concat((p, s), 0)
  return shape
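

# Illustrative sketch (editor's addition, not part of the original module):
# how `_concat` combines a batch prefix with a state suffix. The helper below
# is hypothetical and is never called by library code.
def _example_concat_usage():
  # With static=True the result is a plain Python list:
  #   _concat(32, tensor_shape.TensorShape([7]), static=True) -> [32, 7]
  static_shape = _concat(32, tensor_shape.TensorShape([7]), static=True)
  # With static=False (the default) the result is an int32 shape Tensor,
  # suitable for ops such as array_ops.zeros.
  dynamic_shape = _concat(32, 7)
  return static_shape, dynamic_shape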


def _zero_state_tensors(state_size, batch_size, dtype):
  """Create tensors of zeros based on state_size, batch_size, and dtype."""

  def get_state_shape(s):
    """Combine s with batch_size to get a proper tensor shape."""
    c = _concat(batch_size, s)
    size = array_ops.zeros(c, dtype=dtype)
    if not context.executing_eagerly():
      c_static = _concat(batch_size, s, static=True)
      size.set_shape(c_static)
    return size

  return nest.map_structure(get_state_shape, state_size)


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.RNNCell"])
@tf_export(v1=["nn.rnn_cell.RNNCell"])
class RNNCell(base_layer.Layer):
  """Abstract object representing an RNN cell.

  Every `RNNCell` must have the properties below and implement `call` with
  the signature `(output, next_state) = call(input, state)`. The optional
  third input argument, `scope`, is allowed for backwards compatibility
  purposes; but should be left off for new subclasses.

  This definition of cell differs from the definition used in the literature.
  In the literature, 'cell' refers to an object with a single scalar output.
  This definition refers to a horizontal array of such units.

  An RNN cell, in the most abstract setting, is anything that has
  a state and performs some operation that takes a matrix of inputs.
  This operation results in an output matrix with `self.output_size` columns.
  If `self.state_size` is an integer, this operation also results in a new
  state matrix with `self.state_size` columns. If `self.state_size` is a
  (possibly nested tuple of) TensorShape object(s), then it should return a
  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
  """

  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
    super(RNNCell, self).__init__(
        trainable=trainable, name=name, dtype=dtype, **kwargs)
    # Attribute that indicates whether the cell is a TF RNN cell, due the slight
    # difference between TF and Keras RNN cell. Notably the state is not wrapped
    # in a list for TF cell where they are single tensor state, whereas keras
    # cell will wrap the state into a list, and call() will have to unwrap them.
    self._is_tf_rnn_cell = True

  def __call__(self, inputs, state, scope=None):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size, self.state_size]`. Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple with
        shapes `[batch_size, s] for s in self.state_size`.
      scope: VariableScope for the created subgraph; defaults to class name.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """
    if scope is not None:
      with vs.variable_scope(
          scope, custom_getter=self._rnn_get_variable) as scope:
        return super(RNNCell, self).__call__(inputs, state, scope=scope)
    else:
      scope_attrname = "rnncell_scope"
      scope = getattr(self, scope_attrname, None)
      if scope is None:
        scope = vs.variable_scope(
            vs.get_variable_scope(), custom_getter=self._rnn_get_variable)
        setattr(self, scope_attrname, scope)
      with scope:
        return super(RNNCell, self).__call__(inputs, state)

  def _rnn_get_variable(self, getter, *args, **kwargs):
    variable = getter(*args, **kwargs)
    if ops.executing_eagerly_outside_functions():
      trainable = variable.trainable
    else:
      trainable = (
          variable in tf_variables.trainable_variables() or
          (base_layer_utils.is_split_variable(variable) and
           list(variable)[0] in tf_variables.trainable_variables()))
    if trainable and all(variable is not v for v in self._trainable_weights):
      self._trainable_weights.append(variable)
    elif not trainable and all(
        variable is not v for v in self._non_trainable_weights):
      self._non_trainable_weights.append(variable)
    return variable

  @property
  def state_size(self):
    """size(s) of state(s) used by this cell.

    It can be represented by an Integer, a TensorShape or a tuple of Integers
    or TensorShapes.
    """
    raise NotImplementedError("Abstract method")

  @property
  def output_size(self):
    """Integer or TensorShape: size of outputs produced by this cell."""
    raise NotImplementedError("Abstract method")

  def build(self, _):
    # This tells the parent Layer object that it's OK to call
    # self.add_variable() inside the call() method.
    pass

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    if inputs is not None:
      # Validate the given batch_size and dtype against inputs if provided.
      inputs = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          inputs, name="inputs"
      )
      if batch_size is not None:
        if tensor_util.is_tf_type(batch_size):
          static_batch_size = tensor_util.constant_value(
              batch_size, partial=True)
        else:
          static_batch_size = batch_size
        if inputs.shape.dims[0].value != static_batch_size:
          raise ValueError(
              "batch size from input tensor is different from the "
              "input param. Input tensor batch: {}, batch_size: {}".format(
                  inputs.shape.dims[0].value, batch_size))

      if dtype is not None and inputs.dtype != dtype:
        raise ValueError(
            "dtype from input tensor is different from the "
            "input param. Input tensor dtype: {}, dtype: {}".format(
                inputs.dtype, dtype))

      batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0]
      dtype = inputs.dtype
    if batch_size is None or dtype is None:
      raise ValueError(
          "batch_size and dtype cannot be None while constructing initial "
          "state: batch_size={}, dtype={}".format(batch_size, dtype))
    return self.zero_state(batch_size, dtype)

  def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is an int or TensorShape, then the return value is a
      `N-D` tensor of shape `[batch_size, state_size]` filled with zeros.

      If `state_size` is a nested list or tuple, then the return value is
      a nested list or tuple (of the same structure) of `2-D` tensors with
      the shapes `[batch_size, s]` for each s in `state_size`.
    """
    # Try to use the last cached zero_state. This is done to avoid recreating
    # zeros, especially when eager execution is enabled.
    state_size = self.state_size
    is_eager = context.executing_eagerly()
    if is_eager and _hasattr(self, "_last_zero_state"):
      (last_state_size, last_batch_size, last_dtype,
       last_output) = getattr(self, "_last_zero_state")
      if (last_batch_size == batch_size and last_dtype == dtype and
          last_state_size == state_size):
        return last_output
    with backend.name_scope(type(self).__name__ + "ZeroState"):
      output = _zero_state_tensors(state_size, batch_size, dtype)
    if is_eager:
      self._last_zero_state = (state_size, batch_size, dtype, output)
    return output

  # TODO(b/134773139): Remove when contrib RNN cells implement `get_config`
  def get_config(self):  # pylint: disable=useless-super-delegation
    return super(RNNCell, self).get_config()

  @property
  def _use_input_spec_as_call_signature(self):
    # We do not store the shape information for the state argument in the call
    # function for legacy RNN cells, so do not generate an input signature.
    return False
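

# Illustrative sketch (editor's addition, not part of the original module):
# a minimal concrete `RNNCell` satisfying the contract described in the class
# docstring above -- `state_size`, `output_size`, and a `call` that returns
# `(output, new_state)`. The class is hypothetical and unused by the library.
class _ExampleIdentityCell(RNNCell):
  """Toy cell whose output and new state are simply the previous state."""

  def __init__(self, num_units, **kwargs):
    super(_ExampleIdentityCell, self).__init__(**kwargs)
    self._num_units = num_units

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    del inputs  # This toy cell ignores its input.
    return state, state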


class LayerRNNCell(RNNCell):
  """Subclass of RNNCells that act like proper `tf.Layer` objects.

  For backwards compatibility purposes, most `RNNCell` instances allow their
  `call` methods to instantiate variables via `tf.compat.v1.get_variable`. The
  underlying variable scope thus keeps track of any variables, and returns
  cached versions. This is atypical of `tf.layer` objects, which separate this
  part of layer building into a `build` method that is only called once.

  Here we provide a subclass for `RNNCell` objects that act exactly as
  `Layer` objects do. They must provide a `build` method and their
  `call` methods do not access Variables via `tf.compat.v1.get_variable`.
  """

  def __call__(self, inputs, state, scope=None, *args, **kwargs):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size, self.state_size]`. Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple with
        shapes `[batch_size, s] for s in self.state_size`.
      scope: optional cell scope.
      *args: Additional positional arguments.
      **kwargs: Additional keyword arguments.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """
    # Bypass RNNCell's variable capturing semantics for LayerRNNCell.
    # Instead, it is up to subclasses to provide a proper build
    # method. See the class docstring for more details.
    return base_layer.Layer.__call__(
        self, inputs, state, scope=scope, *args, **kwargs)


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicRNNCell"])
@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
class BasicRNNCell(LayerRNNCell):
  """The most basic RNN cell.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU.

  Args:
    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use. Default: `tanh`. It could also be a
      string that is within Keras activation function names.
    reuse: (optional) Python boolean describing whether to reuse variables in
      an existing scope. If not `True`, and the existing scope already has the
      given variables, an error is raised.
    name: String, the name of the layer. Layers with the same name will share
      weights, but to avoid mistakes we require reuse=True in such cases.
    dtype: Default dtype of the layer (default of `None` means use the type of
      the first input). Required when `build` is called before `call`.
    **kwargs: Dict, keyword named properties for common layer attributes, like
      `trainable` etc when constructing the cell from configs of get_config().
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               name=None,
               dtype=None,
               **kwargs):
    warnings.warn("`tf.nn.rnn_cell.BasicRNNCell` is deprecated and will be "
                  "removed in a future version. This class "
                  "is equivalent as `tf.keras.layers.SimpleRNNCell`, "
                  "and will be replaced by that in Tensorflow 2.0.")
    super(BasicRNNCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)
    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
      logging.warning(
          "%s: Note that this cell is not optimized for performance. "
          "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better "
          "performance on GPU.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  @tf_utils.shape_type_conversion
  def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       str(inputs_shape))
    _check_supported_dtypes(self.dtype)

    input_depth = inputs_shape[-1]
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    _check_rnn_cell_input_dtypes([inputs, state])
    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(BasicRNNCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
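

# Illustrative sketch (editor's addition, not part of the original module):
# one step of `BasicRNNCell`, assuming a TF1-style graph workflow (e.g. within
# `tf.compat.v1.Graph().as_default()`). Shapes and the helper name below are
# hypothetical; the function is never called by library code.
def _example_basic_rnn_cell_step():
  cell = BasicRNNCell(num_units=4)
  inputs = array_ops.ones([2, 3])  # [batch_size, input_size]
  state = cell.zero_state(batch_size=2, dtype=dtypes.float32)  # [2, 4]
  output, new_state = cell(inputs, state)  # both [2, 4]; output is new_state
  return output, new_state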


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.GRUCell"])
@tf_export(v1=["nn.rnn_cell.GRUCell"])
class GRUCell(LayerRNNCell):
  """Gated Recurrent Unit cell.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or
  `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU.

  Args:
    num_units: int, The number of units in the GRU cell.
    activation: Nonlinearity to use. Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables in
      an existing scope. If not `True`, and the existing scope already has the
      given variables, an error is raised.
    kernel_initializer: (optional) The initializer to use for the weight and
      projection matrices.
    bias_initializer: (optional) The initializer to use for the bias.
    name: String, the name of the layer. Layers with the same name will share
      weights, but to avoid mistakes we require reuse=True in such cases.
    dtype: Default dtype of the layer (default of `None` means use the type of
      the first input). Required when `build` is called before `call`.
    **kwargs: Dict, keyword named properties for common layer attributes, like
      `trainable` etc when constructing the cell from configs of get_config().

  References:
    Learning Phrase Representations using RNN Encoder Decoder for Statistical
    Machine Translation:
      [Cho et al., 2014]
      (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179)
      ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf))
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               name=None,
               dtype=None,
               **kwargs):
    warnings.warn("`tf.nn.rnn_cell.GRUCell` is deprecated and will be removed "
                  "in a future version. This class "
                  "is equivalent as `tf.keras.layers.GRUCell`, "
                  "and will be replaced by that in Tensorflow 2.0.")
    super(GRUCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)

    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
      logging.warning(
          "%s: Note that this cell is not optimized for performance. "
          "Please use tf.contrib.cudnn_rnn.CudnnGRU for better "
          "performance on GPU.", self)
    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh
    self._kernel_initializer = initializers.get(kernel_initializer)
    self._bias_initializer = initializers.get(bias_initializer)

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  @tf_utils.shape_type_conversion
  def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       str(inputs_shape))
    _check_supported_dtypes(self.dtype)
    input_depth = inputs_shape[-1]
    self._gate_kernel = self.add_variable(
        "gates/%s" % _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, 2 * self._num_units],
        initializer=self._kernel_initializer)
    self._gate_bias = self.add_variable(
        "gates/%s" % _BIAS_VARIABLE_NAME,
        shape=[2 * self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.constant_initializer(1.0, dtype=self.dtype)))
    self._candidate_kernel = self.add_variable(
        "candidate/%s" % _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units],
        initializer=self._kernel_initializer)
    self._candidate_bias = self.add_variable(
        "candidate/%s" % _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.zeros_initializer(dtype=self.dtype)))

    self.built = True

  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
    _check_rnn_cell_input_dtypes([inputs, state])

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._gate_kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

    value = math_ops.sigmoid(gate_inputs)
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state

    candidate = math_ops.matmul(
        array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
    candidate = nn_ops.bias_add(candidate, self._candidate_bias)

    c = self._activation(candidate)
    new_h = u * state + (1 - u) * c
    return new_h, new_h

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "kernel_initializer": initializers.serialize(self._kernel_initializer),
        "bias_initializer": initializers.serialize(self._bias_initializer),
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(GRUCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
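

# Editor's note (addition, not part of the original module): the update
# computed by `GRUCell.call` above corresponds to the standard GRU equations,
# with reset gate r, update gate u, and state h:
#   [r, u] = sigmoid([x, h] @ W_gates + b_gates)
#   c      = activation([x, r * h] @ W_candidate + b_candidate)
#   h_new  = u * h + (1 - u) * c
# Note that the gate bias defaults to ones while the candidate bias defaults
# to zeros (see `build` above).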


_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMStateTuple"])
@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"])
class LSTMStateTuple(_LSTMStateTuple):
  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.

  Stores two elements: `(c, h)`, in that order, where `c` is the hidden state
  and `h` is the output.

  Only used when `state_is_tuple=True`.
  """
  __slots__ = ()

  @property
  def dtype(self):
    (c, h) = self
    if c.dtype != h.dtype:
      raise TypeError("Inconsistent internal state: %s vs %s" %
                      (str(c.dtype), str(h.dtype)))
    return c.dtype


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicLSTMCell"])
@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"])
class BasicLSTMCell(LayerRNNCell):
  """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead.

  Basic LSTM recurrent network cell.

  The implementation is based on

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell`
  that follows.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
  better performance on CPU.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None,
               **kwargs):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above). Must set
        to `0.0` manually when restoring from CudnnLSTM-trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the
        `c_state` and `m_state`. If False, they are concatenated along the
        column axis. The latter behavior will soon be deprecated.
      activation: Activation function of the inner states. Default: `tanh`. It
        could also be a string that is within Keras activation function names.
      reuse: (optional) Python boolean describing whether to reuse variables in
        an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share
        weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes,
        like `trainable` etc when constructing the cell from configs of
        get_config(). When restoring from CudnnLSTM-trained checkpoints, must
        use `CudnnCompatibleLSTMCell` instead.
    """
    warnings.warn("`tf.nn.rnn_cell.BasicLSTMCell` is deprecated and will be "
                  "removed in a future version. This class "
                  "is equivalent as `tf.keras.layers.LSTMCell`, "
                  "and will be replaced by that in Tensorflow 2.0.")
    super(BasicLSTMCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)
    if not state_is_tuple:
      logging.warning(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated. Use state_is_tuple=True.", self)
    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
      logging.warning(
          "%s: Note that this cell is not optimized for performance. "
          "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
          "performance on GPU.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  @tf_utils.shape_type_conversion
  def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       str(inputs_shape))
    _check_supported_dtypes(self.dtype)
    input_depth = inputs_shape[-1]
    h_depth = self._num_units
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size,
        num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a
        `Tensor` shaped `[batch_size, 2 * num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
      `LSTMStateTuple` or a concatenated state, depending on
      `state_is_tuple`).
    """
    _check_rnn_cell_input_dtypes([inputs, state])

    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(
        multiply(c, sigmoid(add(f, forget_bias_tensor))),
        multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "forget_bias": self._forget_bias,
        "state_is_tuple": self._state_is_tuple,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(BasicLSTMCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
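

# Illustrative sketch (editor's addition, not part of the original module):
# one step of `BasicLSTMCell` with the default `state_is_tuple=True`, so the
# state is an `LSTMStateTuple(c, h)`. Shapes and the helper name below are
# hypothetical; the function is never called by library code.
def _example_basic_lstm_cell_step():
  cell = BasicLSTMCell(num_units=5)
  inputs = array_ops.ones([3, 2])  # [batch_size, input_size]
  state = cell.zero_state(batch_size=3, dtype=dtypes.float32)  # LSTMStateTuple
  output, new_state = cell(inputs, state)
  # `output` equals `new_state.h`; both new_state.c and new_state.h are [3, 5].
  return output, new_state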


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMCell"])
@tf_export(v1=["nn.rnn_cell.LSTMCell"])
class LSTMCell(LayerRNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on (Gers et al., 1999).
  The peephole implementation is based on (Sak et al., 2014).

  The class uses optional peep-hole connections, optional cell clipping, and
  an optional projection layer.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
  better performance on CPU.
  References:
    Long short-term memory recurrent neural network architectures for large
    scale acoustic modeling:
      [Sak et al., 2014]
      (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html)
      ([pdf]
      (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf))
    Learning to forget:
      [Gers et al., 1999]
      (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218)
      ([pdf](https://arxiv.org/pdf/1409.2329.pdf))
    Long Short-Term Memory:
      [Hochreiter et al., 1997]
      (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735)
      ([pdf](http://ml.jku.at/publications/older/3504.pdf))
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               num_unit_shards=None,
               num_proj_shards=None,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None,
               **kwargs):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is
        clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices. If None, no projection is performed.
      proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip`
        is provided, then the projected values are clipped elementwise to
        within `[-proj_clip, proj_clip]`.
      num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a
        variable_scope partitioner instead.
      num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a
        variable_scope partitioner instead.
      forget_bias: Biases of the forget gate are initialized by default to 1 in
        order to reduce the scale of forgetting at the beginning of the
        training. Must set it manually to `0.0` when restoring from CudnnLSTM
        trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the
        `c_state` and `m_state`. If False, they are concatenated along the
        column axis. This latter behavior will soon be deprecated.
      activation: Activation function of the inner states. Default: `tanh`. It
        could also be a string that is within Keras activation function names.
      reuse: (optional) Python boolean describing whether to reuse variables in
        an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share
        weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes,
        like `trainable` etc when constructing the cell from configs of
        get_config(). When restoring from CudnnLSTM-trained checkpoints, use
        `CudnnCompatibleLSTMCell` instead.
    """
    warnings.warn("`tf.nn.rnn_cell.LSTMCell` is deprecated and will be "
                  "removed in a future version. This class "
                  "is equivalent as `tf.keras.layers.LSTMCell`, "
                  "and will be replaced by that in Tensorflow 2.0.")
    super(LSTMCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    _check_supported_dtypes(self.dtype)
    if not state_is_tuple:
      logging.warning(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated. Use state_is_tuple=True.", self)
    if num_unit_shards is not None or num_proj_shards is not None:
      logging.warning(
          "%s: The num_unit_shards and proj_unit_shards parameters are "
          "deprecated and will be removed in Jan 2017. "
          "Use a variable scope with a partitioner instead.", self)
    if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
      logging.warning(
          "%s: Note that this cell is not optimized for performance. "
          "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
          "performance on GPU.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializers.get(initializer)
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh

    if num_proj:
      self._state_size = (
          LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units +
          num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 *
          num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  @tf_utils.shape_type_conversion
  def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       str(inputs_shape))
    _check_supported_dtypes(self.dtype)
    input_depth = inputs_shape[-1]
    h_depth = self._num_units if self._num_proj is None else self._num_proj
    maybe_partitioner = (
        partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
        if self._num_unit_shards is not None else None)
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units],
        initializer=self._initializer,
        partitioner=maybe_partitioner)
    if self.dtype is None:
      initializer = init_ops.zeros_initializer
    else:
      initializer = init_ops.zeros_initializer(dtype=self.dtype)
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=initializer)
    if self._use_peepholes:
      self._w_f_diag = self.add_variable(
          "w_f_diag", shape=[self._num_units], initializer=self._initializer)
      self._w_i_diag = self.add_variable(
          "w_i_diag", shape=[self._num_units], initializer=self._initializer)
      self._w_o_diag = self.add_variable(
          "w_o_diag", shape=[self._num_units], initializer=self._initializer)

    if self._num_proj is not None:
      maybe_proj_partitioner = (
          partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
          if self._num_proj_shards is not None else None)
      self._proj_kernel = self.add_variable(
          "projection/%s" % _WEIGHTS_VARIABLE_NAME,
          shape=[self._num_units, self._num_proj],
          initializer=self._initializer,
          partitioner=maybe_proj_partitioner)

    self.built = True

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
      state: if `state_is_tuple` is False, this must be a state Tensor, `2-D,
        [batch, state_size]`. If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
          num_proj if num_proj was set,
          num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs`
        when the previous state was `state`. Same type and shape(s) as
        `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    _check_rnn_cell_input_dtypes([inputs, state])

    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    input_size = inputs.get_shape().with_rank(2).dims[1].value
    if input_size is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    lstm_matrix = math_ops.matmul(
        array_ops.concat([inputs, m_prev], 1), self._kernel)
    lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)

    i, j, f, o = array_ops.split(
        value=lstm_matrix, num_or_size_splits=4, axis=1)
    # Diagonal connections
    if self._use_peepholes:
      c = (
          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (
          sigmoid(f + self._forget_bias) * c_prev +
          sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      m = math_ops.matmul(m, self._proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "use_peepholes": self._use_peepholes,
        "cell_clip": self._cell_clip,
        "initializer": initializers.serialize(self._initializer),
        "num_proj": self._num_proj,
        "proj_clip": self._proj_clip,
        "num_unit_shards": self._num_unit_shards,
        "num_proj_shards": self._num_proj_shards,
        "forget_bias": self._forget_bias,
        "state_is_tuple": self._state_is_tuple,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(LSTMCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
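

# Illustrative sketch (editor's addition, not part of the original module):
# an `LSTMCell` with peepholes, cell clipping, and a projection layer. With
# `num_proj` set, the output (and the `h` part of the state) has `num_proj`
# columns while `c` keeps `num_units` columns, as implemented in `call` above.
# Shapes and the helper name are hypothetical; the function is never called.
def _example_projected_lstm_cell_step():
  cell = LSTMCell(num_units=8, num_proj=4, use_peepholes=True, cell_clip=3.0)
  inputs = array_ops.ones([2, 6])  # [batch_size, input_size]
  state = cell.zero_state(batch_size=2, dtype=dtypes.float32)
  m, new_state = cell(inputs, state)
  # m and new_state.h have shape [2, 4]; new_state.c has shape [2, 8].
  return m, new_state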


class _RNNCellWrapperV1(RNNCell):
  """Base class for cell wrappers, V1 compatibility.

  This class along with `_RNNCellWrapperV2` allows defining cell wrappers that
  are compatible with V1 and V2, and defines helper methods for this purpose.
  """

  def __init__(self, cell, *args, **kwargs):
    super(_RNNCellWrapperV1, self).__init__(*args, **kwargs)
    assert_like_rnncell("cell", cell)
    self.cell = cell
    if isinstance(cell, trackable.Trackable):
      self._track_trackable(self.cell, name="cell")

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Calls the wrapped cell and performs the wrapping logic.

    This method is called from the wrapper's `call` or `__call__` methods.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or 'call' method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:
      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    raise NotImplementedError

  def __call__(self, inputs, state, scope=None):
    """Runs the RNN cell step computation.

    We assume that the wrapped RNNCell is being built within its `__call__`
    method. We directly use the wrapped cell's `__call__` in the overridden
    wrapper `__call__` method.

    This allows using the wrapped cell and the non-wrapped cell equivalently
    when using `__call__`.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      scope: VariableScope for the subgraph created in the wrapped cells'
        `__call__`.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    return self._call_wrapped_cell(
        inputs, state, cell_call_fn=self.cell.__call__, scope=scope)

  def get_config(self):
    config = {
        "cell": {
            "class_name": self.cell.__class__.__name__,
            "config": self.cell.get_config()
        },
    }
    base_config = super(_RNNCellWrapperV1, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    config = config.copy()
    cell = config.pop("cell")
    try:
      assert_like_rnncell("cell", cell)
      return cls(cell, **config)
    except TypeError:
      raise ValueError("RNNCellWrapper cannot reconstruct the wrapped cell. "
                       "Please overwrite the cell in the config with a RNNCell "
                       "instance.")


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DropoutWrapper"])
@tf_export(v1=["nn.rnn_cell.DropoutWrapper"])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV1):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DropoutWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
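

# Illustrative sketch (editor's addition, not part of the original module):
# wrapping a cell with `DropoutWrapper`. The keep-probability keyword
# arguments are assumed to match `DropoutWrapperBase.__init__` (see the
# `__doc__` assignment above); the values and helper name are hypothetical.
def _example_dropout_wrapped_cell():
  base_cell = GRUCell(num_units=16)
  return DropoutWrapper(
      base_cell, input_keep_prob=0.9, output_keep_prob=0.9)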


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.ResidualWrapper"])
@tf_export(v1=["nn.rnn_cell.ResidualWrapper"])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV1):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(ResidualWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DeviceWrapper"])
@tf_export(v1=["nn.rnn_cell.DeviceWrapper"])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV1):

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DeviceWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__


@keras_export(v1=["keras.__internal__.legacy.rnn_cell.MultiRNNCell"])
@tf_export(v1=["nn.rnn_cell.MultiRNNCell"])
class MultiRNNCell(RNNCell):
  """RNN cell composed sequentially of multiple simple cells.

  Example:

  ```python
  num_units = [128, 64]
  cells = [BasicLSTMCell(num_units=n) for n in num_units]
  stacked_rnn_cell = MultiRNNCell(cells)
  ```
  """

  def __init__(self, cells, state_is_tuple=True):
    """Create a RNN cell composed sequentially of a number of RNNCells.

    Args:
      cells: list of RNNCells that will be composed in this order.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`. If False, the states are all concatenated along the
        column axis. This latter behavior will soon be deprecated.

    Raises:
      ValueError: if cells is empty (not allowed), or at least one of the cells
        returns a state tuple but the flag `state_is_tuple` is `False`.
    """
    logging.warning("`tf.nn.rnn_cell.MultiRNNCell` is deprecated. This class "
                    "is equivalent as `tf.keras.layers.StackedRNNCells`, "
                    "and will be replaced by that in Tensorflow 2.0.")
    super(MultiRNNCell, self).__init__()
    if not cells:
      raise ValueError("Must specify at least one cell for MultiRNNCell.")
    if not nest.is_nested(cells):
      raise TypeError("cells must be a list or tuple, but saw: %s." % cells)

    if len(set(id(cell) for cell in cells)) < len(cells):
      logging.log_first_n(
          logging.WARN, "At least two cells provided to MultiRNNCell "
          "are the same object and will share weights.", 1)

    self._cells = cells
    for cell_number, cell in enumerate(self._cells):
      # Add Trackable dependencies on these cells so their variables get
      # saved with this object when using object-based saving.
      if isinstance(cell, trackable.Trackable):
        # TODO(allenl): Track down non-Trackable callers.
        self._track_trackable(cell, name="cell-%d" % (cell_number,))
    self._state_is_tuple = state_is_tuple
    if not state_is_tuple:
      if any(nest.is_nested(c.state_size) for c in self._cells):
        raise ValueError("Some cells return tuples of states, but the flag "
                         "state_is_tuple is not set. State sizes are: %s" %
                         str([c.state_size for c in self._cells]))

  @property
  def state_size(self):
    if self._state_is_tuple:
      return tuple(cell.state_size for cell in self._cells)
    else:
      return sum(cell.state_size for cell in self._cells)

  @property
  def output_size(self):
    return self._cells[-1].output_size

  def zero_state(self, batch_size, dtype):
    with backend.name_scope(type(self).__name__ + "ZeroState"):
      if self._state_is_tuple:
        return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
      else:
        # We know here that state_size of each cell is not a tuple and
        # presumably does not contain TensorArrays or anything else fancy
        return super(MultiRNNCell, self).zero_state(batch_size, dtype)

  @property
  def trainable_weights(self):
    if not self.trainable:
      return []
    weights = []
    for cell in self._cells:
      if isinstance(cell, base_layer.Layer):
        weights += cell.trainable_weights
    return weights

  @property
  def non_trainable_weights(self):
    weights = []
    for cell in self._cells:
      if isinstance(cell, base_layer.Layer):
        weights += cell.non_trainable_weights
    if not self.trainable:
      trainable_weights = []
      for cell in self._cells:
        if isinstance(cell, base_layer.Layer):
          trainable_weights += cell.trainable_weights
      return trainable_weights + weights
    return weights

  def call(self, inputs, state):
    """Run this multi-layer cell on inputs, starting from state."""
    cur_state_pos = 0
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % i):
        if self._state_is_tuple:
          if not nest.is_nested(state):
            raise ValueError(
                "Expected state to be a tuple of length %d, but received: %s" %
                (len(self.state_size), state))
          cur_state = state[i]
        else:
          cur_state = array_ops.slice(state, [0, cur_state_pos],
                                      [-1, cell.state_size])
          cur_state_pos += cell.state_size
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)

    new_states = (
        tuple(new_states) if self._state_is_tuple else array_ops.concat(
            new_states, 1))

    return cur_inp, new_states


def _check_rnn_cell_input_dtypes(inputs):
  """Check whether the input tensors have supported dtypes.

  Default RNN cells only support floats and complex as their dtypes since the
  activation functions (tanh and sigmoid) only allow those types. This function
  raises a proper error message if the input is not of a supported type.

  Args:
    inputs: tensor or nested structure of tensors that are fed to the RNN cell
      as input or state.

  Raises:
    ValueError: if any of the input tensors do not have a float or complex
      dtype.
  """
  for t in nest.flatten(inputs):
    _check_supported_dtypes(t.dtype)


def _check_supported_dtypes(dtype):
  if dtype is None:
    return
  dtype = dtypes.as_dtype(dtype)
  if not (dtype.is_floating or dtype.is_complex):
    raise ValueError("RNN cell only supports floating point inputs, "
                     "but saw dtype: %s" % dtype)