Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/rnn.py: 12%

453 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""RNN helpers for TensorFlow models.""" 

16from tensorflow.python.eager import context 

17from tensorflow.python.framework import constant_op 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_shape 

21from tensorflow.python.framework import tensor_util 

22from tensorflow.python.ops import array_ops 

23from tensorflow.python.ops import array_ops_stack 

24from tensorflow.python.ops import cond 

25from tensorflow.python.ops import control_flow_assert 

26from tensorflow.python.ops import control_flow_util 

27from tensorflow.python.ops import control_flow_util_v2 

28from tensorflow.python.ops import math_ops 

29from tensorflow.python.ops import rnn_cell_impl 

30from tensorflow.python.ops import tensor_array_ops 

31from tensorflow.python.ops import variable_scope as vs 

32from tensorflow.python.ops import while_loop 

33from tensorflow.python.util import deprecation 

34from tensorflow.python.util import dispatch 

35from tensorflow.python.util import nest 

36from tensorflow.python.util.tf_export import tf_export 

37 

38# pylint: disable=protected-access 

39_concat = rnn_cell_impl._concat 

40# pylint: enable=protected-access 

41 

42 

43def _transpose_batch_time(x): 

44 """Transposes the batch and time dimensions of a Tensor. 

45 

46 If the input tensor has rank < 2 it returns the original tensor. Retains as 

47 much of the static shape information as possible. 

48 

49 Args: 

50 x: A Tensor. 

51 

52 Returns: 

53 x transposed along the first two dimensions. 

54 """ 

55 x_static_shape = x.get_shape() 

56 if x_static_shape.rank is not None and x_static_shape.rank < 2: 

57 return x 

58 

59 x_rank = array_ops.rank(x) 

60 x_t = array_ops.transpose( 

61 x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0)) 

62 x_t.set_shape( 

63 tensor_shape.TensorShape( 

64 [x_static_shape.dims[1].value, 

65 x_static_shape.dims[0].value]).concatenate(x_static_shape[2:])) 

66 return x_t 
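
The helper above is private, but the same batch/time swap can be reproduced with public TensorFlow ops. A minimal sketch, assuming a batch-major input; the name `swap_batch_time` and the example shapes are illustrative, not part of this module:

```python
import tensorflow as tf

def swap_batch_time(x):
  """Illustrative public-API equivalent of the batch/time transpose."""
  # Permutation [1, 0, 2, 3, ...]: swap the first two axes, keep the rest.
  perm = tf.concat([[1, 0], tf.range(2, tf.rank(x))], axis=0)
  return tf.transpose(x, perm)

x = tf.zeros([8, 10, 32])      # [batch, time, depth]
y = swap_batch_time(x)         # [time, batch, depth] -> shape [10, 8, 32]
```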

67 

68 

69def _best_effort_input_batch_size(flat_input): 

70 """Get static input batch size if available, with fallback to the dynamic one. 

71 

72 Args: 

73 flat_input: An iterable of time major input Tensors of shape `[max_time, 

74 batch_size, ...]`. All inputs should have compatible batch sizes. 

75 

76 Returns: 

77 The batch size in Python integer if available, or a scalar Tensor otherwise. 

78 

79 Raises: 

80 ValueError: if there is any input with an invalid shape. 

81 """ 

82 for input_ in flat_input: 

83 shape = input_.shape 

84 if shape.rank is None: 

85 continue 

86 if shape.rank < 2: 

87 raise ValueError("Input tensor should have rank >= 2. Received input=" 

88 f"{input_} of rank {shape.rank}") 

89 batch_size = shape.dims[1].value 

90 if batch_size is not None: 

91 return batch_size 

92 # Fallback to the dynamic batch size of the first input. 

93 return array_ops.shape(flat_input[0])[1] 
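
A standalone sketch of the static-first, dynamic-fallback pattern this helper implements, written against the public API; `batch_size_of` and the example tensor are assumed names for illustration:

```python
import tensorflow as tf

def batch_size_of(time_major_input):
  """Prefer the statically known batch dimension, else a runtime scalar."""
  static_batch = time_major_input.shape[1]   # batch is dim 1 in time-major data
  if static_batch is not None:
    return static_batch                      # Python int, known at graph-build time
  return tf.shape(time_major_input)[1]       # scalar int32 Tensor at run time

print(batch_size_of(tf.zeros([10, 8, 32])))  # -> 8
```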

94 

95 

96def _infer_state_dtype(explicit_dtype, state): 

97 """Infer the dtype of an RNN state. 

98 

99 Args: 

100 explicit_dtype: explicitly declared dtype or None. 

101 state: RNN's hidden state. Must be a Tensor or a nested iterable containing 

102 Tensors. 

103 

104 Returns: 

105 dtype: inferred dtype of hidden state. 

106 

107 Raises: 

108 ValueError: if `state` has heterogeneous dtypes or is empty. 

109 """ 

110 if explicit_dtype is not None: 

111 return explicit_dtype 

112 elif nest.is_nested(state): 

113 inferred_dtypes = [element.dtype for element in nest.flatten(state)] 

114 if not inferred_dtypes: 

115 raise ValueError(f"Unable to infer dtype from argument state={state}.") 

116 all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes) 

117 if not all_same: 

118 raise ValueError( 

119 f"Argument state={state} has tensors of different inferred dtypes. " 

120 "Unable to infer a single representative dtype. Dtypes received: " 

121 f"{inferred_dtypes}") 

122 return inferred_dtypes[0] 

123 else: 

124 return state.dtype 
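
A toy re-statement of the same inference rule with public symbols; `infer_state_dtype` here is an illustrative stand-in, not the module's private helper:

```python
import tensorflow as tf

def infer_state_dtype(explicit_dtype, state):
  """Return explicit_dtype if given, else the single dtype shared by `state`."""
  if explicit_dtype is not None:
    return explicit_dtype
  dtypes_seen = {t.dtype for t in tf.nest.flatten(state)}
  if len(dtypes_seen) != 1:
    raise ValueError(f"Expected exactly one dtype, saw: {dtypes_seen}")
  return dtypes_seen.pop()

state = (tf.zeros([2, 4]), tf.zeros([2, 4]))   # e.g. an LSTM-style (c, h) pair
print(infer_state_dtype(None, state))          # -> <dtype: 'float32'>
```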

125 

126 

127def _maybe_tensor_shape_from_tensor(shape): 

128 if isinstance(shape, ops.Tensor): 

129 return tensor_shape.as_shape(tensor_util.constant_value(shape)) 

130 else: 

131 return shape 
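
For reference, roughly the same conversion using public symbols; `shape_t` is an assumed example tensor and this is a sketch, not the module's code path:

```python
import tensorflow as tf

shape_t = tf.constant([4, 16])                 # a shape carried as a constant Tensor
dims = tf.get_static_value(shape_t)            # -> array([ 4, 16], dtype=int32)
static_shape = tf.TensorShape([int(d) for d in dims])   # TensorShape([4, 16])
# Anything that is not a Tensor (e.g. TensorShape([4, 16])) passes through unchanged.
```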

132 

133 

134def _should_cache(): 

135 """Returns True if a default caching device should be set, otherwise False.""" 

136 if context.executing_eagerly(): 

137 return False 

138 # Don't set a caching device when running in a loop, since it is possible that 

139 # train steps could be wrapped in a tf.while_loop. In that scenario caching 

140 # prevents forward computations in loop iterations from re-reading the 

141 # updated weights. 

142 graph = ops.get_default_graph() 

143 ctxt = graph._get_control_flow_context() # pylint: disable=protected-access 

144 in_v1_while_loop = ( 

145 control_flow_util.GetContainingWhileContext(ctxt) is not None) 

146 in_v2_while_loop = control_flow_util_v2.in_while_loop_defun(graph) 

147 return not in_v1_while_loop and not in_v2_while_loop 

148 

149 

150# pylint: disable=unused-argument 

151def _rnn_step(time, 

152 sequence_length, 

153 min_sequence_length, 

154 max_sequence_length, 

155 zero_output, 

156 state, 

157 call_cell, 

158 state_size, 

159 skip_conditionals=False): 

160 """Calculate one step of a dynamic RNN minibatch. 

161 

162 Returns an (output, state) pair conditioned on `sequence_length`. 

163 When skip_conditionals=False, the pseudocode is something like: 

164 

165 if t >= max_sequence_length: 

166 return (zero_output, state) 

167 if t < min_sequence_length: 

168 return call_cell() 

169 

170 # Selectively output zeros or output, old state or new state depending 

171 # on whether we've finished calculating each row. 

172 new_output, new_state = call_cell() 

173 final_output = np.vstack([ 

174 zero_output if time >= sequence_length[r] else new_output_r 

175 for r, new_output_r in enumerate(new_output) 

176 ]) 

177 final_state = np.vstack([ 

178 state[r] if time >= sequence_length[r] else new_state_r 

179 for r, new_state_r in enumerate(new_state) 

180 ]) 

181 return (final_output, final_state) 

182 

183 Args: 

184 time: int32 `Tensor` scalar. 

185 sequence_length: int32 `Tensor` vector of size [batch_size]. 

186 min_sequence_length: int32 `Tensor` scalar, min of sequence_length. 

187 max_sequence_length: int32 `Tensor` scalar, max of sequence_length. 

188 zero_output: `Tensor` vector of shape [output_size]. 

189 state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, 

190 or a list/tuple of such tensors. 

191 call_cell: lambda returning tuple of (new_output, new_state) where 

192 new_output is a `Tensor` matrix of shape `[batch_size, output_size]`. 

193 new_state is a `Tensor` matrix of shape `[batch_size, state_size]`. 

194 state_size: The `cell.state_size` associated with the state. 

195 skip_conditionals: Python bool, whether to skip using the conditional 

196 calculations. This is useful for `dynamic_rnn`, where the input tensor 

197 matches `max_sequence_length`, and using conditionals just slows 

198 everything down. 

199 

200 Returns: 

201 A tuple of (`final_output`, `final_state`) as given by the pseudocode above: 

202 final_output is a `Tensor` matrix of shape [batch_size, output_size] 

203 final_state is either a single `Tensor` matrix, or a tuple of such 

204 matrices (matching length and shapes of input `state`). 

205 

206 Raises: 

207 ValueError: If the cell returns a state tuple whose length does not match 

208 that returned by `state_size`. 

209 """ 

210 

211 # Convert state to a list for ease of use 

212 flat_state = nest.flatten(state) 

213 flat_zero_output = nest.flatten(zero_output) 

214 

215 # Vector describing which batch entries are finished. 

216 copy_cond = time >= sequence_length 

217 

218 def _copy_one_through(output, new_output): 

219 # TensorArray and scalar get passed through. 

220 if isinstance(output, tensor_array_ops.TensorArray): 

221 return new_output 

222 if output.shape.rank == 0: 

223 return new_output 

224 # Otherwise propagate the old or the new value. 

225 with ops.colocate_with(new_output): 

226 return array_ops.where(copy_cond, output, new_output) 

227 

228 def _copy_some_through(flat_new_output, flat_new_state): 

229 # Use broadcasting select to determine which values should get 

230 # the previous state & zero output, and which values should get 

231 # a calculated state & output. 

232 flat_new_output = [ 

233 _copy_one_through(zero_output, new_output) 

234 for zero_output, new_output in zip(flat_zero_output, flat_new_output) 

235 ] 

236 flat_new_state = [ 

237 _copy_one_through(state, new_state) 

238 for state, new_state in zip(flat_state, flat_new_state) 

239 ] 

240 return flat_new_output + flat_new_state 

241 

242 def _maybe_copy_some_through(): 

243 """Run RNN step. Pass through either no or some past state.""" 

244 new_output, new_state = call_cell() 

245 

246 nest.assert_same_structure(zero_output, new_output) 

247 nest.assert_same_structure(state, new_state) 

248 

249 flat_new_state = nest.flatten(new_state) 

250 flat_new_output = nest.flatten(new_output) 

251 return cond.cond( 

252 # if t < min_seq_len: calculate and return everything 

253 time < min_sequence_length, 

254 lambda: flat_new_output + flat_new_state, 

255 # else copy some of it through 

256 lambda: _copy_some_through(flat_new_output, flat_new_state)) 

257 

258 # TODO(ebrevdo): skipping these conditionals may cause a slowdown, 

259 # but benefits from removing cond() and its gradient. We should 

260 # profile with and without this switch here. 

261 if skip_conditionals: 

262 # Instead of using conditionals, perform the selective copy at all time 

263 # steps. This is faster when max_seq_len is equal to the number of unrolls 

264 # (which is typical for dynamic_rnn). 

265 new_output, new_state = call_cell() 

266 nest.assert_same_structure(zero_output, new_output) 

267 nest.assert_same_structure(state, new_state) 

268 new_state = nest.flatten(new_state) 

269 new_output = nest.flatten(new_output) 

270 final_output_and_state = _copy_some_through(new_output, new_state) 

271 else: 

272 empty_update = lambda: flat_zero_output + flat_state 

273 final_output_and_state = cond.cond( 

274 # if t >= max_seq_len: copy all state through, output zeros 

275 time >= max_sequence_length, 

276 empty_update, 

277 # otherwise calculation is required: copy some or all of it through 

278 _maybe_copy_some_through) 

279 

280 if len(final_output_and_state) != len(flat_zero_output) + len(flat_state): 

281 raise ValueError("Internal error: state and output were not concatenated " 

282 f"correctly. Received state length: {len(flat_state)}, " 

283 f"output length: {len(flat_zero_output)}. Expected " 

284 f"contatenated length: {len(final_output_and_state)}.") 

285 final_output = final_output_and_state[:len(flat_zero_output)] 

286 final_state = final_output_and_state[len(flat_zero_output):] 

287 

288 for output, flat_output in zip(final_output, flat_zero_output): 

289 output.set_shape(flat_output.get_shape()) 

290 for substate, flat_substate in zip(final_state, flat_state): 

291 if not isinstance(substate, tensor_array_ops.TensorArray): 

292 substate.set_shape(flat_substate.get_shape()) 

293 

294 final_output = nest.pack_sequence_as( 

295 structure=zero_output, flat_sequence=final_output) 

296 final_state = nest.pack_sequence_as( 

297 structure=state, flat_sequence=final_state) 

298 

299 return final_output, final_state 
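
The heart of the step above is the per-row copy-through select. A self-contained sketch of that selection with public ops and made-up values (note `tf.where` in TF2 broadcasts, so the condition is expanded to `[batch, 1]`):

```python
import tensorflow as tf

time = tf.constant(3)
sequence_length = tf.constant([5, 2, 3])       # per-row lengths
copy_cond = time >= sequence_length            # [False, True, True]

old_output = tf.zeros([3, 4])                  # value kept for finished rows
new_output = tf.ones([3, 4])                   # freshly computed cell output

# Finished rows keep the old value, unfinished rows take the new one.
# The [3] condition is expanded to [3, 1] so it broadcasts against [3, 4].
stepped = tf.where(copy_cond[:, None], old_output, new_output)
# stepped: row 0 is ones (still running), rows 1 and 2 are zeros (finished).
```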

300 

301 

302def _reverse_seq(input_seq, lengths): 

303 """Reverse a list of Tensors up to specified lengths. 

304 

305 Args: 

306 input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features) 

307 or nested tuples of tensors. 

308 lengths: A `Tensor` of dimension batch_size, containing lengths for each 

309 sequence in the batch. If `None`, the list is simply reversed. 

310 

311 Returns: 

312 time-reversed sequence 

313 """ 

314 if lengths is None: 

315 return list(reversed(input_seq)) 

316 

317 flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq) 

318 

319 flat_results = [[] for _ in range(len(input_seq))] 

320 for sequence in zip(*flat_input_seq): 

321 input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank) 

322 for input_ in sequence: 

323 input_shape.assert_is_compatible_with(input_.get_shape()) 

324 input_.set_shape(input_shape) 

325 

326 # Join into (time, batch_size, depth) 

327 s_joined = array_ops_stack.stack(sequence) 

328 

329 # Reverse along dimension 0 

330 s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) 

331 # Split again into list 

332 result = array_ops_stack.unstack(s_reversed) 

333 for r, flat_result in zip(result, flat_results): 

334 r.set_shape(input_shape) 

335 flat_result.append(r) 

336 

337 results = [ 

338 nest.pack_sequence_as(structure=input_, flat_sequence=flat_result) 

339 for input_, flat_result in zip(input_seq, flat_results) 

340 ] 

341 return results 
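
The batched reversal relies on `tf.reverse_sequence`. A small standalone sketch on a time-major tensor with per-row lengths, using made-up data:

```python
import tensorflow as tf

# [time, batch] toy data: two sequences of lengths 3 and 2 (zero-padded).
x = tf.constant([[1, 10],
                 [2, 20],
                 [3,  0]])
lengths = tf.constant([3, 2])

# Reverse along the time axis (0) independently per batch entry (axis 1);
# padding past each row's length is left where it is.
rev = tf.reverse_sequence(x, lengths, seq_axis=0, batch_axis=1)
# rev == [[3, 20],
#         [2, 10],
#         [1,  0]]
```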

342 

343 

344@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional(" 

345 "keras.layers.RNN(cell))`, which is equivalent to " 

346 "this API") 

347@tf_export(v1=["nn.bidirectional_dynamic_rnn"]) 

348@dispatch.add_dispatch_support 

349def bidirectional_dynamic_rnn(cell_fw, 

350 cell_bw, 

351 inputs, 

352 sequence_length=None, 

353 initial_state_fw=None, 

354 initial_state_bw=None, 

355 dtype=None, 

356 parallel_iterations=None, 

357 swap_memory=False, 

358 time_major=False, 

359 scope=None): 

360 """Creates a dynamic version of bidirectional recurrent neural network. 

361 

362 Takes input and builds independent forward and backward RNNs. The input_size 

363 of the forward and backward cells must match. The initial state for both 

364 directions defaults to zero (but can be set optionally), and no intermediate 

365 states are ever returned -- the network is fully unrolled for the given 

366 length(s) of the sequence(s), or completely unrolled if no lengths are 

367 given. 

368 

369 Args: 

370 cell_fw: An instance of RNNCell, to be used for forward direction. 

371 cell_bw: An instance of RNNCell, to be used for backward direction. 

372 inputs: The RNN inputs. 

373 If time_major == False (default), this must be a tensor of shape: 

374 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 

375 If time_major == True, this must be a tensor of shape: `[max_time, 

376 batch_size, ...]`, or a nested tuple of such elements. 

377 sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 

378 containing the actual lengths for each of the sequences in the batch. If 

379 not provided, all batch entries are assumed to be full sequences; and time 

380 reversal is applied from time `0` to `max_time` for each sequence. 

381 initial_state_fw: (optional) An initial state for the forward RNN. This must 

382 be a tensor of appropriate type and shape `[batch_size, 

383 cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a 

384 tuple of tensors having shapes `[batch_size, s] for s in 

385 cell_fw.state_size`. 

386 initial_state_bw: (optional) Same as for `initial_state_fw`, but using the 

387 corresponding properties of `cell_bw`. 

388 dtype: (optional) The data type for the initial states and expected output. 

389 Required if initial_states are not provided or RNN states have a 

390 heterogeneous dtype. 

391 parallel_iterations: (Default: 32). The number of iterations to run in 

392 parallel. Those operations which do not have any temporal dependency and 

393 can be run in parallel, will be. This parameter trades off time for 

394 space. Values >> 1 use more memory but take less time, while smaller 

395 values use less memory but computations take longer. 

396 swap_memory: Transparently swap the tensors produced in forward inference 

397 but needed for back prop from GPU to CPU. This allows training RNNs which 

398 would typically not fit on a single GPU, with very minimal (or no) 

399 performance penalty. 

400 time_major: The shape format of the `inputs` and `outputs` Tensors. If true, 

401 these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, 

402 these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using 

403 `time_major = True` is a bit more efficient because it avoids transposes 

404 at the beginning and end of the RNN calculation. However, most TensorFlow 

405 data is batch-major, so by default this function accepts input and emits 

406 output in batch-major form. 

407 scope: VariableScope for the created subgraph; defaults to 

408 "bidirectional_rnn" 

409 

410 Returns: 

411 A tuple (outputs, output_states) where: 

412 outputs: A tuple (output_fw, output_bw) containing the forward and 

413 the backward rnn output `Tensor`. 

414 If time_major == False (default), 

415 output_fw will be a `Tensor` shaped: 

416 `[batch_size, max_time, cell_fw.output_size]` 

417 and output_bw will be a `Tensor` shaped: 

418 `[batch_size, max_time, cell_bw.output_size]`. 

419 If time_major == True, 

420 output_fw will be a `Tensor` shaped: 

421 `[max_time, batch_size, cell_fw.output_size]` 

422 and output_bw will be a `Tensor` shaped: 

423 `[max_time, batch_size, cell_bw.output_size]`. 

424 It returns a tuple instead of a single concatenated `Tensor`, unlike 

425 in the `bidirectional_rnn`. If the concatenated one is preferred, 

426 the forward and backward outputs can be concatenated as 

427 `tf.concat(outputs, 2)`. 

428 output_states: A tuple (output_state_fw, output_state_bw) containing 

429 the forward and the backward final states of bidirectional rnn. 

430 

431 Raises: 

432 TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 

433 """ 

434 rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) 

435 rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) 

436 

437 with vs.variable_scope(scope or "bidirectional_rnn"): 

438 # Forward direction 

439 with vs.variable_scope("fw") as fw_scope: 

440 output_fw, output_state_fw = dynamic_rnn( 

441 cell=cell_fw, 

442 inputs=inputs, 

443 sequence_length=sequence_length, 

444 initial_state=initial_state_fw, 

445 dtype=dtype, 

446 parallel_iterations=parallel_iterations, 

447 swap_memory=swap_memory, 

448 time_major=time_major, 

449 scope=fw_scope) 

450 

451 # Backward direction 

452 if not time_major: 

453 time_axis = 1 

454 batch_axis = 0 

455 else: 

456 time_axis = 0 

457 batch_axis = 1 

458 

459 def _reverse(input_, seq_lengths, seq_axis, batch_axis): 

460 if seq_lengths is not None: 

461 return array_ops.reverse_sequence( 

462 input=input_, 

463 seq_lengths=seq_lengths, 

464 seq_axis=seq_axis, 

465 batch_axis=batch_axis) 

466 else: 

467 return array_ops.reverse(input_, axis=[seq_axis]) 

468 

469 with vs.variable_scope("bw") as bw_scope: 

470 

471 def _map_reverse(inp): 

472 return _reverse( 

473 inp, 

474 seq_lengths=sequence_length, 

475 seq_axis=time_axis, 

476 batch_axis=batch_axis) 

477 

478 inputs_reverse = nest.map_structure(_map_reverse, inputs) 

479 tmp, output_state_bw = dynamic_rnn( 

480 cell=cell_bw, 

481 inputs=inputs_reverse, 

482 sequence_length=sequence_length, 

483 initial_state=initial_state_bw, 

484 dtype=dtype, 

485 parallel_iterations=parallel_iterations, 

486 swap_memory=swap_memory, 

487 time_major=time_major, 

488 scope=bw_scope) 

489 

490 output_bw = _reverse( 

491 tmp, 

492 seq_lengths=sequence_length, 

493 seq_axis=time_axis, 

494 batch_axis=batch_axis) 

495 

496 outputs = (output_fw, output_bw) 

497 output_states = (output_state_fw, output_state_bw) 

498 

499 return (outputs, output_states) 
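
A minimal graph-mode usage sketch of this API; the cell sizes, placeholder shapes, and variable names are illustrative, and new code should prefer the Keras replacement named in the deprecation notice:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()   # this v1 API targets graph mode

inputs = tf.compat.v1.placeholder(tf.float32, [None, 20, 16])  # [batch, time, depth]
seq_len = tf.compat.v1.placeholder(tf.int32, [None])

cell_fw = tf.compat.v1.nn.rnn_cell.BasicRNNCell(32)
cell_bw = tf.compat.v1.nn.rnn_cell.BasicRNNCell(32)

(out_fw, out_bw), (state_fw, state_bw) = tf.compat.v1.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, sequence_length=seq_len, dtype=tf.float32)

# Concatenate the two directions along the feature axis if a single
# [batch, time, 64] tensor is wanted.
combined = tf.concat([out_fw, out_bw], axis=2)
```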

500 

501 

502@deprecation.deprecated( 

503 None, 

504 "Please use `keras.layers.RNN(cell)`, which is equivalent to this API") 

505@tf_export(v1=["nn.dynamic_rnn"]) 

506@dispatch.add_dispatch_support 

507def dynamic_rnn(cell, 

508 inputs, 

509 sequence_length=None, 

510 initial_state=None, 

511 dtype=None, 

512 parallel_iterations=None, 

513 swap_memory=False, 

514 time_major=False, 

515 scope=None): 

516 """Creates a recurrent neural network specified by RNNCell `cell`. 

517 

518 Performs fully dynamic unrolling of `inputs`. 

519 

520 Example: 

521 

522 ```python 

523 # create a BasicRNNCell 

524 rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) 

525 

526 # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] 

527 

528 # defining initial state 

529 initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 

530 

531 # 'state' is a tensor of shape [batch_size, cell_state_size] 

532 outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data, 

533 initial_state=initial_state, 

534 dtype=tf.float32) 

535 ``` 

536 

537 ```python 

538 # create 2 LSTMCells 

539 rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 

540 

541 # create a RNN cell composed sequentially of a number of RNNCells 

542 multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers) 

543 

544 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 

545 # 'state' is a N-tuple where N is the number of LSTMCells containing a 

546 # tf.nn.rnn_cell.LSTMStateTuple for each cell 

547 outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell, 

548 inputs=data, 

549 dtype=tf.float32) 

550 ``` 

551 

552 

553 Args: 

554 cell: An instance of RNNCell. 

555 inputs: The RNN inputs. 

556 If `time_major == False` (default), this must be a `Tensor` of shape: 

557 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 

558 If `time_major == True`, this must be a `Tensor` of shape: `[max_time, 

559 batch_size, ...]`, or a nested tuple of such elements. This may also be 

560 a (possibly nested) tuple of Tensors satisfying this property. The 

561 first two dimensions must match across all the inputs, but otherwise the 

562 ranks and other shape components may differ. In this case, input to 

563 `cell` at each time-step will replicate the structure of these tuples, 

564 except for the time dimension (from which the time is taken). The input 

565 to `cell` at each time step will be a `Tensor` or (possibly nested) 

566 tuple of Tensors each with dimensions `[batch_size, ...]`. 

567 sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used 

568 to copy-through state and zero-out outputs when past a batch element's 

569 sequence length. This parameter enables users to extract the last valid 

570 state and properly padded outputs, so it is provided for correctness. 

571 initial_state: (optional) An initial state for the RNN. If `cell.state_size` 

572 is an integer, this must be a `Tensor` of appropriate type and shape 

573 `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this 

574 should be a tuple of tensors having shapes `[batch_size, s] for s in 

575 cell.state_size`. 

576 dtype: (optional) The data type for the initial state and expected output. 

577 Required if initial_state is not provided or RNN state has a heterogeneous 

578 dtype. 

579 parallel_iterations: (Default: 32). The number of iterations to run in 

580 parallel. Those operations which do not have any temporal dependency and 

581 can be run in parallel, will be. This parameter trades off time for 

582 space. Values >> 1 use more memory but take less time, while smaller 

583 values use less memory but computations take longer. 

584 swap_memory: Transparently swap the tensors produced in forward inference 

585 but needed for back prop from GPU to CPU. This allows training RNNs which 

586 would typically not fit on a single GPU, with very minimal (or no) 

587 performance penalty. 

588 time_major: The shape format of the `inputs` and `outputs` Tensors. If true, 

589 these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, 

590 these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using 

591 `time_major = True` is a bit more efficient because it avoids transposes 

592 at the beginning and end of the RNN calculation. However, most TensorFlow 

593 data is batch-major, so by default this function accepts input and emits 

594 output in batch-major form. 

595 scope: VariableScope for the created subgraph; defaults to "rnn". 

596 

597 Returns: 

598 A pair (outputs, state) where: 

599 

600 outputs: The RNN output `Tensor`. 

601 

602 If time_major == False (default), this will be a `Tensor` shaped: 

603 `[batch_size, max_time, cell.output_size]`. 

604 

605 If time_major == True, this will be a `Tensor` shaped: 

606 `[max_time, batch_size, cell.output_size]`. 

607 

608 Note, if `cell.output_size` is a (possibly nested) tuple of integers 

609 or `TensorShape` objects, then `outputs` will be a tuple having the 

610 same structure as `cell.output_size`, containing Tensors having shapes 

611 corresponding to the shape data in `cell.output_size`. 

612 

613 state: The final state. If `cell.state_size` is an int, this 

614 will be shaped `[batch_size, cell.state_size]`. If it is a 

615 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 

616 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 

617 be a tuple having the corresponding shapes. If cells are `LSTMCells` 

618 `state` will be a tuple containing a `LSTMStateTuple` for each cell. 

619 

620 Raises: 

621 TypeError: If `cell` is not an instance of RNNCell. 

622 ValueError: If inputs is None or an empty list. 

623 

624 @compatibility(TF2) 

625 `tf.compat.v1.nn.dynamic_rnn` is not compatible with eager execution and 

626 `tf.function`. Please use `tf.keras.layers.RNN` instead for TF2 migration. 

627 Take LSTM as an example, you can instantiate a `tf.keras.layers.RNN` layer 

628 with `tf.keras.layers.LSTMCell`, or directly via `tf.keras.layers.LSTM`. Once 

629 the keras layer is created, you can get the output and states by calling 

630 the layer with input and states. Please refer to [this 

631 guide](https://www.tensorflow.org/guide/keras/rnn) for more details about 

632 Keras RNN. You can also find more details about the difference and comparison 

633 between Keras RNN and TF compat v1 rnn in [this 

634 document](https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md) 

635 

636 #### Structural Mapping to Native TF2 

637 

638 Before: 

639 

640 ```python 

641 # create 2 LSTMCells 

642 rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 

643 

644 # create a RNN cell composed sequentially of a number of RNNCells 

645 multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers) 

646 

647 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 

648 # 'state' is a N-tuple where N is the number of LSTMCells containing a 

649 # tf.nn.rnn_cell.LSTMStateTuple for each cell 

650 outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell, 

651 inputs=data, 

652 dtype=tf.float32) 

653 ``` 

654 

655 After: 

656 

657 ```python 

658 # RNN layer can take a list of cells, which will then stack them together. 

659 # By default, keras RNN will only return the last timestep output and will not 

660 # return states. If you need whole time sequence output as well as the states, 

661 # you can set `return_sequences` and `return_state` to True. 

662 rnn_layer = tf.keras.layers.RNN([tf.keras.layers.LSTMCell(128), 

663 tf.keras.layers.LSTMCell(256)], 

664 return_sequences=True, 

665 return_state=True) 

666 outputs, output_states = rnn_layer(inputs, states) 

667 ``` 

668 

669 #### How to Map Arguments 

670 

671 | TF1 Arg Name | TF2 Arg Name | Note | 

672 | :-------------------- | :-------------- | :------------------------------- | 

673 | `cell` | `cell` | In the RNN layer constructor | 

674 | `inputs` | `inputs` | In the RNN layer `__call__` | 

675 | `sequence_length` | Not used | Adding masking layer before RNN : 

676 : : : to achieve the same result. : 

677 | `initial_state` | `initial_state` | In the RNN layer `__call__` | 

678 | `dtype` | `dtype` | In the RNN layer constructor | 

679 | `parallel_iterations` | Not supported | | 

680 | `swap_memory` | Not supported | | 

681 | `time_major` | `time_major` | In the RNN layer constructor | 

682 | `scope` | Not supported | | 

683 @end_compatibility 

684 """ 

685 rnn_cell_impl.assert_like_rnncell("cell", cell) 

686 

687 with vs.variable_scope(scope or "rnn") as varscope: 

688 # Create a new scope in which the caching device is either 

689 # determined by the parent scope, or is set to place the cached 

690 # Variable using the same placement as for the rest of the RNN. 

691 if _should_cache(): 

692 if varscope.caching_device is None: 

693 varscope.set_caching_device(lambda op: op.device) 

694 

695 # By default, time_major==False and inputs are batch-major: shaped 

696 # [batch, time, depth] 

697 # For internal calculations, we transpose to [time, batch, depth] 

698 flat_input = nest.flatten(inputs) 

699 

700 if not time_major: 

701 # (B,T,D) => (T,B,D) 

702 flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] 

703 flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) 

704 

705 parallel_iterations = parallel_iterations or 32 

706 if sequence_length is not None: 

707 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 

708 if sequence_length.get_shape().rank not in (None, 1): 

709 raise ValueError( 

710 f"Argument sequence_length must be a vector of length batch_size." 

711 f" Received sequence_length={sequence_length} of shape: " 

712 f"{sequence_length.get_shape()}") 

713 sequence_length = array_ops.identity( # Just to find it in the graph. 

714 sequence_length, 

715 name="sequence_length") 

716 

717 batch_size = _best_effort_input_batch_size(flat_input) 

718 

719 if initial_state is not None: 

720 state = initial_state 

721 else: 

722 if not dtype: 

723 raise ValueError("If no initial_state is provided, argument `dtype` " 

724 "must be specified") 

725 if getattr(cell, "get_initial_state", None) is not None: 

726 state = cell.get_initial_state( 

727 inputs=None, batch_size=batch_size, dtype=dtype) 

728 else: 

729 state = cell.zero_state(batch_size, dtype) 

730 

731 def _assert_has_shape(x, shape): 

732 x_shape = array_ops.shape(x) 

733 packed_shape = array_ops_stack.stack(shape) 

734 return control_flow_assert.Assert( 

735 math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [ 

736 "Expected shape for Tensor %s is " % x.name, packed_shape, 

737 " but saw shape: ", x_shape 

738 ]) 

739 

740 if not context.executing_eagerly() and sequence_length is not None: 

741 # Perform some shape validation 

742 with ops.control_dependencies( 

743 [_assert_has_shape(sequence_length, [batch_size])]): 

744 sequence_length = array_ops.identity( 

745 sequence_length, name="CheckSeqLen") 

746 

747 inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) 

748 

749 (outputs, final_state) = _dynamic_rnn_loop( 

750 cell, 

751 inputs, 

752 state, 

753 parallel_iterations=parallel_iterations, 

754 swap_memory=swap_memory, 

755 sequence_length=sequence_length, 

756 dtype=dtype) 

757 

758 # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 

759 # If we are performing batch-major calculations, transpose output back 

760 # to shape [batch, time, depth] 

761 if not time_major: 

762 # (T,B,D) => (B,T,D) 

763 outputs = nest.map_structure(_transpose_batch_time, outputs) 

764 

765 return (outputs, final_state) 
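
To complement the docstring examples, a short graph-mode sketch that exercises `sequence_length`; the shapes, names, and random data below are illustrative only:

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

data = tf.compat.v1.placeholder(tf.float32, [None, 10, 8])   # [batch, time, depth]
lengths = tf.compat.v1.placeholder(tf.int32, [None])

cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(16)
outputs, state = tf.compat.v1.nn.dynamic_rnn(
    cell, data, sequence_length=lengths, dtype=tf.float32)

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  out_np, state_np = sess.run(
      [outputs, state],
      feed_dict={data: np.random.rand(4, 10, 8).astype(np.float32),
                 lengths: [10, 7, 3, 10]})
  # Outputs are zeroed past each row's length (e.g. out_np[2, 3:] is all zeros),
  # while state_np holds each row's state at its last valid step.
```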

766 

767 

768def _dynamic_rnn_loop(cell, 

769 inputs, 

770 initial_state, 

771 parallel_iterations, 

772 swap_memory, 

773 sequence_length=None, 

774 dtype=None): 

775 """Internal implementation of Dynamic RNN. 

776 

777 Args: 

778 cell: An instance of RNNCell. 

779 inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested 

780 tuple of such elements. 

781 initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if 

782 `cell.state_size` is a tuple, then this should be a tuple of tensors 

783 having shapes `[batch_size, s] for s in cell.state_size`. 

784 parallel_iterations: Positive Python int. 

785 swap_memory: A Python boolean 

786 sequence_length: (optional) An `int32` `Tensor` of shape [batch_size]. 

787 dtype: (optional) Expected dtype of output. If not specified, inferred from 

788 initial_state. 

789 

790 Returns: 

791 Tuple `(final_outputs, final_state)`. 

792 final_outputs: 

793 A `Tensor` of shape `[time, batch_size, cell.output_size]`. If 

794 `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape` 

795 objects, then this returns a (possibly nested) tuple of Tensors matching 

796 the corresponding shapes. 

797 final_state: 

798 A `Tensor`, or possibly nested tuple of Tensors, matching in length 

799 and shapes to `initial_state`. 

800 

801 Raises: 

802 ValueError: If the input depth cannot be inferred via shape inference 

803 from the inputs. 

804 ValueError: If time_step is not the same for all the elements in the 

805 inputs. 

806 ValueError: If batch_size is not the same for all the elements in the 

807 inputs. 

808 """ 

809 state = initial_state 

810 assert isinstance(parallel_iterations, int), "parallel_iterations must be int" 

811 

812 state_size = cell.state_size 

813 

814 flat_input = nest.flatten(inputs) 

815 flat_output_size = nest.flatten(cell.output_size) 

816 

817 # Construct an initial output 

818 input_shape = array_ops.shape(flat_input[0]) 

819 time_steps = input_shape[0] 

820 batch_size = _best_effort_input_batch_size(flat_input) 

821 

822 inputs_got_shape = tuple( 

823 input_.get_shape().with_rank_at_least(3) for input_ in flat_input) 

824 

825 const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2] 

826 

827 for i, shape in enumerate(inputs_got_shape): 

828 if not shape[2:].is_fully_defined(): 

829 raise ValueError( 

830 "Input size (depth of inputs) must be accessible via shape inference," 

831 f" but saw value None for input={flat_input[i]}.") 

832 got_time_steps = shape.dims[0].value 

833 got_batch_size = shape.dims[1].value 

834 if const_time_steps != got_time_steps: 

835 raise ValueError( 

836 "Time steps is not the same for all the elements in the input in a " 

837 f"batch. Received time steps={got_time_steps} for input=" 

838 f"{flat_input[i]}.") 

839 if const_batch_size != got_batch_size: 

840 raise ValueError( 

841 "Batch_size is not the same for all the elements in the input. " 

842 f"Received batch size={got_batch_size} for input={flat_input[i]}.") 

843 

844 # Prepare dynamic conditional copying of state & output 

845 def _create_zero_arrays(size): 

846 size = _concat(batch_size, size) 

847 return array_ops.zeros( 

848 array_ops_stack.stack(size), _infer_state_dtype(dtype, state)) 

849 

850 flat_zero_output = tuple( 

851 _create_zero_arrays(output) for output in flat_output_size) 

852 zero_output = nest.pack_sequence_as( 

853 structure=cell.output_size, flat_sequence=flat_zero_output) 

854 

855 if sequence_length is not None: 

856 min_sequence_length = math_ops.reduce_min(sequence_length) 

857 max_sequence_length = math_ops.reduce_max(sequence_length) 

858 else: 

859 max_sequence_length = time_steps 

860 

861 time = array_ops.constant(0, dtype=dtypes.int32, name="time") 

862 

863 with ops.name_scope("dynamic_rnn") as scope: 

864 base_name = scope 

865 

866 def _create_ta(name, element_shape, dtype): 

867 return tensor_array_ops.TensorArray( 

868 dtype=dtype, 

869 size=time_steps, 

870 element_shape=element_shape, 

871 tensor_array_name=base_name + name) 

872 

873 in_graph_mode = not context.executing_eagerly() 

874 if in_graph_mode: 

875 output_ta = tuple( 

876 _create_ta( 

877 "output_%d" % i, 

878 element_shape=( 

879 tensor_shape.TensorShape([const_batch_size]).concatenate( 

880 _maybe_tensor_shape_from_tensor(out_size))), 

881 dtype=_infer_state_dtype(dtype, state)) 

882 for i, out_size in enumerate(flat_output_size)) 

883 input_ta = tuple( 

884 _create_ta( 

885 "input_%d" % i, 

886 element_shape=flat_input_i.shape[1:], 

887 dtype=flat_input_i.dtype) 

888 for i, flat_input_i in enumerate(flat_input)) 

889 input_ta = tuple( 

890 ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input)) 

891 else: 

892 output_ta = tuple([0 for _ in range(time_steps.numpy())] 

893 for i in range(len(flat_output_size))) 

894 input_ta = flat_input 

895 

896 def _time_step(time, output_ta_t, state): 

897 """Take a time step of the dynamic RNN. 

898 

899 Args: 

900 time: int32 scalar Tensor. 

901 output_ta_t: List of `TensorArray`s that represent the output. 

902 state: nested tuple of vector tensors that represent the state. 

903 

904 Returns: 

905 The tuple (time + 1, output_ta_t with updated flow, new_state). 

906 """ 

907 

908 if in_graph_mode: 

909 input_t = tuple(ta.read(time) for ta in input_ta) 

910 # Restore some shape information 

911 for input_, shape in zip(input_t, inputs_got_shape): 

912 input_.set_shape(shape[1:]) 

913 else: 

914 input_t = tuple(ta[time.numpy()] for ta in input_ta) 

915 

916 input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) 

917 # Keras RNN cells only accept state as a list, even if it is a single tensor. 

918 call_cell = lambda: cell(input_t, state) 

919 

920 if sequence_length is not None: 

921 (output, new_state) = _rnn_step( 

922 time=time, 

923 sequence_length=sequence_length, 

924 min_sequence_length=min_sequence_length, 

925 max_sequence_length=max_sequence_length, 

926 zero_output=zero_output, 

927 state=state, 

928 call_cell=call_cell, 

929 state_size=state_size, 

930 skip_conditionals=True) 

931 else: 

932 (output, new_state) = call_cell() 

933 

934 # Pack state if using state tuples 

935 output = nest.flatten(output) 

936 

937 if in_graph_mode: 

938 output_ta_t = tuple( 

939 ta.write(time, out) for ta, out in zip(output_ta_t, output)) 

940 else: 

941 for ta, out in zip(output_ta_t, output): 

942 ta[time.numpy()] = out 

943 

944 return (time + 1, output_ta_t, new_state) 

945 

946 if in_graph_mode: 

947 # Make sure that we run at least 1 step, if necessary, to ensure 

948 # the TensorArrays pick up the dynamic shape. 

949 loop_bound = math_ops.minimum(time_steps, 

950 math_ops.maximum(1, max_sequence_length)) 

951 else: 

952 # Using max_sequence_length isn't currently supported in the Eager branch. 

953 loop_bound = time_steps 

954 

955 _, output_final_ta, final_state = while_loop.while_loop( 

956 cond=lambda time, *_: time < loop_bound, 

957 body=_time_step, 

958 loop_vars=(time, output_ta, state), 

959 parallel_iterations=parallel_iterations, 

960 maximum_iterations=time_steps, 

961 swap_memory=swap_memory) 

962 

963 # Unpack final output if not using output tuples. 

964 if in_graph_mode: 

965 final_outputs = tuple(ta.stack() for ta in output_final_ta) 

966 # Restore some shape information 

967 for output, output_size in zip(final_outputs, flat_output_size): 

968 shape = _concat([const_time_steps, const_batch_size], 

969 output_size, 

970 static=True) 

971 output.set_shape(shape) 

972 else: 

973 final_outputs = output_final_ta 

974 

975 final_outputs = nest.pack_sequence_as( 

976 structure=cell.output_size, flat_sequence=final_outputs) 

977 if not in_graph_mode: 

978 final_outputs = nest.map_structure_up_to( 

979 cell.output_size, 

980 lambda x: array_ops_stack.stack(x, axis=0), final_outputs) 

981 

982 return (final_outputs, final_state) 
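
The loop above is essentially a TensorArray-plus-`tf.while_loop` accumulation. A stripped-down sketch of that pattern with no RNN cell, just writing one stand-in tensor per step:

```python
import tensorflow as tf

time_steps = 5

def body(t, ta):
  # Stand-in for the cell output written at step t in the real loop.
  value = tf.fill([3, 4], tf.cast(t, tf.float32))
  return t + 1, ta.write(t, value)

ta0 = tf.TensorArray(tf.float32, size=time_steps)
_, ta_final = tf.while_loop(
    cond=lambda t, ta: t < time_steps,
    body=body,
    loop_vars=(tf.constant(0), ta0))

stacked = ta_final.stack()   # time-major result, shape [5, 3, 4]
```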

983 

984 

985@tf_export(v1=["nn.raw_rnn"]) 

986@dispatch.add_dispatch_support 

987def raw_rnn(cell, 

988 loop_fn, 

989 parallel_iterations=None, 

990 swap_memory=False, 

991 scope=None): 

992 """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`. 

993 

994 **NOTE: This method is still in testing, and the API may change.** 

995 

996 This function is a more primitive version of `dynamic_rnn` that provides 

997 more direct access to the inputs each iteration. It also provides more 

998 control over when to start and finish reading the sequence, and 

999 what to emit for the output. 

1000 

1001 For example, it can be used to implement the dynamic decoder of a seq2seq 

1002 model. 

1003 

1004 Instead of working with `Tensor` objects, most operations work with 

1005 `TensorArray` objects directly. 

1006 

1007 The operation of `raw_rnn`, in pseudo-code, is basically the following: 

1008 

1009 ```python 

1010 time = tf.constant(0, dtype=tf.int32) 

1011 (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn( 

1012 time=time, cell_output=None, cell_state=None, loop_state=None) 

1013 emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype) 

1014 state = initial_state 

1015 while not all(finished): 

1016 (output, cell_state) = cell(next_input, state) 

1017 (next_finished, next_input, next_state, emit, loop_state) = loop_fn( 

1018 time=time + 1, cell_output=output, cell_state=cell_state, 

1019 loop_state=loop_state) 

1020 # Emit zeros and copy forward state for minibatch entries that are finished. 

1021 state = tf.where(finished, state, next_state) 

1022 emit = tf.where(finished, tf.zeros_like(emit_structure), emit) 

1023 emit_ta = emit_ta.write(time, emit) 

1024 # If any new minibatch entries are marked as finished, mark these. 

1025 finished = tf.logical_or(finished, next_finished) 

1026 time += 1 

1027 return (emit_ta, state, loop_state) 

1028 ``` 

1029 

1030 with the additional properties that output and state may be (possibly nested) 

1031 tuples, as determined by `cell.output_size` and `cell.state_size`, and 

1032 as a result the final `state` and `emit_ta` may themselves be tuples. 

1033 

1034 A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this: 

1035 

1036 ```python 

1037 inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth), 

1038 dtype=tf.float32) 

1039 sequence_length = tf.compat.v1.placeholder(shape=(batch_size,), 

1040 dtype=tf.int32) 

1041 inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 

1042 inputs_ta = inputs_ta.unstack(inputs) 

1043 

1044 cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units) 

1045 

1046 def loop_fn(time, cell_output, cell_state, loop_state): 

1047 emit_output = cell_output # == None for time == 0 

1048 if cell_output is None: # time == 0 

1049 next_cell_state = cell.zero_state(batch_size, tf.float32) 

1050 else: 

1051 next_cell_state = cell_state 

1052 elements_finished = (time >= sequence_length) 

1053 finished = tf.reduce_all(elements_finished) 

1054 next_input = tf.cond( 

1055 finished, 

1056 lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32), 

1057 lambda: inputs_ta.read(time)) 

1058 next_loop_state = None 

1059 return (elements_finished, next_input, next_cell_state, 

1060 emit_output, next_loop_state) 

1061 

1062 outputs_ta, final_state, _ = raw_rnn(cell, loop_fn) 

1063 outputs = outputs_ta.stack() 

1064 ``` 

1065 

1066 Args: 

1067 cell: An instance of RNNCell. 

1068 loop_fn: A callable that takes inputs `(time, cell_output, cell_state, 

1069 loop_state)` and returns the tuple `(finished, next_input, 

1070 next_cell_state, emit_output, next_loop_state)`. Here `time` is an int32 

1071 scalar `Tensor`, `cell_output` is a `Tensor` or (possibly nested) tuple of 

1072 tensors as determined by `cell.output_size`, and `cell_state` is a 

1073 `Tensor` or (possibly nested) tuple of tensors, as determined by the 

1074 `loop_fn` on its first call (and should match `cell.state_size`). 

1075 The outputs are: `finished`, a boolean `Tensor` of 

1076 shape `[batch_size]`, `next_input`: the next input to feed to `cell`, 

1077 `next_cell_state`: the next state to feed to `cell`, 

1078 and `emit_output`: the output to store for this iteration. Note that 

1079 `emit_output` should be a `Tensor` or (possibly nested) tuple of tensors 

1080 which is aggregated in the `emit_ta` inside the `while_loop`. For the 

1081 first call to `loop_fn`, the `emit_output` corresponds to the 

1082 `emit_structure` which is then used to determine the size of the 

1083 `zero_tensor` for the `emit_ta` (defaults to `cell.output_size`). For 

1084 the subsequent calls to the `loop_fn`, the `emit_output` corresponds to 

1085 the actual output tensor that is to be aggregated in the `emit_ta`. The 

1086 parameter `cell_state` and output `next_cell_state` may be either a 

1087 single or (possibly nested) tuple of tensors. The parameter 

1088 `loop_state` and output `next_loop_state` may be either a single or 

1089 (possibly nested) tuple of `Tensor` and `TensorArray` objects. This 

1090 last parameter may be ignored by `loop_fn` and the return value may be 

1091 `None`. If it is not `None`, then the `loop_state` will be propagated 

1092 through the RNN loop, for use purely by `loop_fn` to keep track of its 

1093 own state. The `next_loop_state` parameter returned may be `None`. The 

1094 first call to `loop_fn` will be `time = 0`, `cell_output = None`, 

1095 `cell_state = None`, and `loop_state = None`. For this call: The 

1096 `next_cell_state` value should be the value with which to initialize the 

1097 cell's state. It may be a final state from a previous RNN or it may be 

1098 the output of `cell.zero_state()`. It should be a (possibly nested) 

1099 tuple structure of tensors. If `cell.state_size` is an integer, this 

1100 must be a `Tensor` of appropriate type and shape `[batch_size, 

1101 cell.state_size]`. If `cell.state_size` is a `TensorShape`, this must be 

1102 a `Tensor` of appropriate type and shape `[batch_size] + 

1103 cell.state_size`. If `cell.state_size` is a (possibly nested) tuple of 

1104 ints or `TensorShape`, this will be a tuple having the corresponding 

1105 shapes. The `emit_output` value may be either `None` or a (possibly 

1106 nested) tuple structure of tensors, e.g., `(tf.zeros(shape_0, 

1107 dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. If this first 

1108 `emit_output` return value is `None`, then the `emit_ta` result of 

1109 `raw_rnn` will have the same structure and dtypes as `cell.output_size`. 

1110 Otherwise `emit_ta` will have the same structure, shapes (prepended with 

1111 a `batch_size` dimension), and dtypes as `emit_output`. The actual 

1112 values returned for `emit_output` at this initializing call are ignored. 

1113 Note, this emit structure must be consistent across all time steps. 

1114 parallel_iterations: (Default: 32). The number of iterations to run in 

1115 parallel. Those operations which do not have any temporal dependency and 

1116 can be run in parallel, will be. This parameter trades off time for 

1117 space. Values >> 1 use more memory but take less time, while smaller 

1118 values use less memory but computations take longer. 

1119 swap_memory: Transparently swap the tensors produced in forward inference 

1120 but needed for back prop from GPU to CPU. This allows training RNNs which 

1121 would typically not fit on a single GPU, with very minimal (or no) 

1122 performance penalty. 

1123 scope: VariableScope for the created subgraph; defaults to "rnn". 

1124 

1125 Returns: 

1126 A tuple `(emit_ta, final_state, final_loop_state)` where: 

1127 

1128 `emit_ta`: The RNN output `TensorArray`. 

1129 If `loop_fn` returns a (possibly nested) set of Tensors for 

1130 `emit_output` during initialization, (inputs `time = 0`, 

1131 `cell_output = None`, and `loop_state = None`), then `emit_ta` will 

1132 have the same structure, dtypes, and shapes as `emit_output` instead. 

1133 If `loop_fn` returns `emit_output = None` during this call, 

1134 the structure of `cell.output_size` is used: 

1135 If `cell.output_size` is a (possibly nested) tuple of integers 

1136 or `TensorShape` objects, then `emit_ta` will be a tuple having the 

1137 same structure as `cell.output_size`, containing TensorArrays whose 

1138 elements' shapes correspond to the shape data in `cell.output_size`. 

1139 

1140 `final_state`: The final cell state. If `cell.state_size` is an int, this 

1141 will be shaped `[batch_size, cell.state_size]`. If it is a 

1142 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 

1143 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 

1144 be a tuple having the corresponding shapes. 

1145 

1146 `final_loop_state`: The final loop state as returned by `loop_fn`. 

1147 

1148 Raises: 

1149 TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not 

1150 a `callable`. 

1151 """ 

1152 rnn_cell_impl.assert_like_rnncell("cell", cell) 

1153 

1154 if not callable(loop_fn): 

1155 raise TypeError("Argument `loop_fn` must be a callable. Received: " 

1156 f"{loop_fn}.") 

1157 

1158 parallel_iterations = parallel_iterations or 32 

1159 

1160 # Create a new scope in which the caching device is either 

1161 # determined by the parent scope, or is set to place the cached 

1162 # Variable using the same placement as for the rest of the RNN. 

1163 with vs.variable_scope(scope or "rnn") as varscope: 

1164 if _should_cache(): 

1165 if varscope.caching_device is None: 

1166 varscope.set_caching_device(lambda op: op.device) 

1167 

1168 time = constant_op.constant(0, dtype=dtypes.int32) 

1169 (elements_finished, next_input, 

1170 initial_state, emit_structure, init_loop_state) = loop_fn( 

1171 time, None, None, None) # time, cell_output, cell_state, loop_state 

1172 flat_input = nest.flatten(next_input) 

1173 

1174 # Need a surrogate loop state for the while_loop if none is available. 

1175 loop_state = ( 

1176 init_loop_state if init_loop_state is not None else 

1177 constant_op.constant(0, dtype=dtypes.int32)) 

1178 

1179 input_shape = [input_.get_shape() for input_ in flat_input] 

1180 static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0) 

1181 

1182 for input_shape_i in input_shape: 

1183 # Static verification that batch sizes all match 

1184 static_batch_size.assert_is_compatible_with( 

1185 tensor_shape.dimension_at_index(input_shape_i, 0)) 

1186 

1187 batch_size = tensor_shape.dimension_value(static_batch_size) 

1188 const_batch_size = batch_size 

1189 if batch_size is None: 

1190 batch_size = array_ops.shape(flat_input[0])[0] 

1191 

1192 nest.assert_same_structure(initial_state, cell.state_size) 

1193 state = initial_state 

1194 flat_state = nest.flatten(state) 

1195 flat_state = [ops.convert_to_tensor(s) for s in flat_state] 

1196 state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state) 

1197 

1198 if emit_structure is not None: 

1199 flat_emit_structure = nest.flatten(emit_structure) 

1200 flat_emit_size = [ 

1201 emit.shape if emit.shape.is_fully_defined() else array_ops.shape(emit) 

1202 for emit in flat_emit_structure 

1203 ] 

1204 flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure] 

1205 else: 

1206 emit_structure = cell.output_size 

1207 flat_emit_size = nest.flatten(emit_structure) 

1208 flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size) 

1209 

1210 flat_emit_ta = [ 

1211 tensor_array_ops.TensorArray( 

1212 dtype=dtype_i, 

1213 dynamic_size=True, 

1214 element_shape=(tensor_shape.TensorShape([ 

1215 const_batch_size 

1216 ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))), 

1217 size=0, 

1218 name="rnn_output_%d" % i) 

1219 for i, (dtype_i, 

1220 size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size)) 

1221 ] 

1222 emit_ta = nest.pack_sequence_as( 

1223 structure=emit_structure, flat_sequence=flat_emit_ta) 

1224 flat_zero_emit = [ 

1225 array_ops.zeros(_concat(batch_size, size_i), dtype_i) 

1226 for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes) 

1227 ] 

1228 zero_emit = nest.pack_sequence_as( 

1229 structure=emit_structure, flat_sequence=flat_zero_emit) 

1230 

1231 def condition(unused_time, elements_finished, *_): 

1232 return math_ops.logical_not(math_ops.reduce_all(elements_finished)) 

1233 

1234 def body(time, elements_finished, current_input, emit_ta, state, 

1235 loop_state): 

1236 """Internal while loop body for raw_rnn. 

1237 

1238 Args: 

1239 time: time scalar. 

1240 elements_finished: batch-size vector. 

1241 current_input: possibly nested tuple of input tensors. 

1242 emit_ta: possibly nested tuple of output TensorArrays. 

1243 state: possibly nested tuple of state tensors. 

1244 loop_state: possibly nested tuple of loop state tensors. 

1245 

1246 Returns: 

1247 Tuple having the same size as Args but with updated values. 

1248 """ 

1249 (next_output, cell_state) = cell(current_input, state) 

1250 

1251 nest.assert_same_structure(state, cell_state) 

1252 nest.assert_same_structure(cell.output_size, next_output) 

1253 

1254 next_time = time + 1 

1255 (next_finished, next_input, next_state, emit_output, 

1256 next_loop_state) = loop_fn(next_time, next_output, cell_state, 

1257 loop_state) 

1258 

1259 nest.assert_same_structure(state, next_state) 

1260 nest.assert_same_structure(current_input, next_input) 

1261 nest.assert_same_structure(emit_ta, emit_output) 

1262 

1263 # If loop_fn returns None for next_loop_state, just reuse the 

1264 # previous one. 

1265 loop_state = loop_state if next_loop_state is None else next_loop_state 

1266 

1267 def _copy_some_through(current, candidate): 

1268 """Copy some tensors through via array_ops.where.""" 

1269 

1270 def copy_fn(cur_i, cand_i): 

1271 # TensorArray and scalar get passed through. 

1272 if isinstance(cur_i, tensor_array_ops.TensorArray): 

1273 return cand_i 

1274 if cur_i.shape.rank == 0: 

1275 return cand_i 

1276 # Otherwise propagate the old or the new value. 

1277 with ops.colocate_with(cand_i): 

1278 return array_ops.where(elements_finished, cur_i, cand_i) 

1279 

1280 return nest.map_structure(copy_fn, current, candidate) 

1281 

1282 emit_output = _copy_some_through(zero_emit, emit_output) 

1283 next_state = _copy_some_through(state, next_state) 

1284 

1285 emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), 

1286 emit_ta, emit_output) 

1287 

1288 elements_finished = math_ops.logical_or(elements_finished, next_finished) 

1289 

1290 return (next_time, elements_finished, next_input, emit_ta, next_state, 

1291 loop_state) 

1292 

1293 returned = while_loop.while_loop( 

1294 condition, 

1295 body, 

1296 loop_vars=[ 

1297 time, elements_finished, next_input, emit_ta, state, loop_state 

1298 ], 

1299 parallel_iterations=parallel_iterations, 

1300 swap_memory=swap_memory) 

1301 

1302 (emit_ta, final_state, final_loop_state) = returned[-3:] 

1303 

1304 if init_loop_state is None: 

1305 final_loop_state = None 

1306 

1307 return (emit_ta, final_state, final_loop_state) 
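
The docstring notes that `loop_state` can carry bookkeeping through the loop but does not show it; below is a hedged sketch of a `loop_fn` that also counts executed steps, reusing the names (`cell`, `batch_size`, `sequence_length`, `input_depth`, `inputs_ta`) from the docstring example above:

```python
def loop_fn_with_counter(time, cell_output, cell_state, loop_state):
  if cell_output is None:                         # time == 0: initialization call
    next_cell_state = cell.zero_state(batch_size, tf.float32)
    steps_taken = tf.constant(0, dtype=tf.int32)  # initial loop_state
  else:
    next_cell_state = cell_state
    steps_taken = loop_state + 1                  # bump the counter every step
  elements_finished = (time >= sequence_length)
  next_input = tf.cond(
      tf.reduce_all(elements_finished),
      lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
      lambda: inputs_ta.read(time))
  # emit_output is just the cell output; steps_taken rides along as loop_state.
  return (elements_finished, next_input, next_cell_state,
          cell_output, steps_taken)

outputs_ta, final_state, steps_taken = tf.compat.v1.nn.raw_rnn(
    cell, loop_fn_with_counter)
```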

1308 

1309 

1310@deprecation.deprecated(None, 

1311 "Please use `keras.layers.RNN(cell, unroll=True)`, " 

1312 "which is equivalent to this API") 

1313@tf_export(v1=["nn.static_rnn"]) 

1314@dispatch.add_dispatch_support 

1315def static_rnn(cell, 

1316 inputs, 

1317 initial_state=None, 

1318 dtype=None, 

1319 sequence_length=None, 

1320 scope=None): 

1321 """Creates a recurrent neural network specified by RNNCell `cell`. 

1322 

1323 The simplest form of RNN network generated is: 

1324 

1325 ```python 

1326 state = cell.zero_state(...) 

1327 outputs = [] 

1328 for input_ in inputs: 

1329 output, state = cell(input_, state) 

1330 outputs.append(output) 

1331 return (outputs, state) 

1332 ``` 

1333 However, a few other options are available: 

1334 

1335 An initial state can be provided. 

1336 If the sequence_length vector is provided, dynamic calculation is performed. 

1337 This method of calculation does not compute the RNN steps past the maximum 

1338 sequence length of the minibatch (thus saving computational time), 

1339 and properly propagates the state at an example's sequence length 

1340 to the final state output. 

1341 

1342 The dynamic calculation performed is, at time `t` for batch row `b`, 

1343 

1344 ```python 

1345 (output, state)(b, t) = 

1346 (t >= sequence_length(b)) 

1347 ? (zeros(cell.output_size), states(b, sequence_length(b) - 1)) 

1348 : cell(input(b, t), state(b, t - 1)) 

1349 ``` 

1350 

1351 Args: 

1352 cell: An instance of RNNCell. 

1353 inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size, 

1354 input_size]`, or a nested tuple of such elements. 

1355 initial_state: (optional) An initial state for the RNN. If `cell.state_size` 

1356 is an integer, this must be a `Tensor` of appropriate type and shape 

1357 `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this 

1358 should be a tuple of tensors having shapes `[batch_size, s] for s in 

1359 cell.state_size`. 

1360 dtype: (optional) The data type for the initial state and expected output. 

1361 Required if initial_state is not provided or RNN state has a heterogeneous 

1362 dtype. 

1363 sequence_length: Specifies the length of each sequence in inputs. An int32 

1364 or int64 vector (tensor) of size `[batch_size]`, with values in `[0, T)`. 

1365 scope: VariableScope for the created subgraph; defaults to "rnn". 

1366 

1367 Returns: 

1368 A pair (outputs, state) where: 

1369 

1370 - outputs is a length T list of outputs (one for each input), or a nested 

1371 tuple of such elements. 

1372 - state is the final state. 

1373 

1374 Raises: 

1375 TypeError: If `cell` is not an instance of RNNCell. 

1376 ValueError: If `inputs` is `None` or an empty list, or if the input depth 

1377 (column size) cannot be inferred from inputs via shape inference. 

1378 """ 

1379 rnn_cell_impl.assert_like_rnncell("cell", cell) 

1380 if not nest.is_nested(inputs): 

1381 raise TypeError(f"Argument `inputs` must be a sequence. Received: {inputs}") 

1382 if not inputs: 

1383 raise ValueError("Argument `inputs` must not be empty.") 

1384 

1385 outputs = [] 

1386 # Create a new scope in which the caching device is either 

1387 # determined by the parent scope, or is set to place the cached 

1388 # Variable using the same placement as for the rest of the RNN. 

1389 with vs.variable_scope(scope or "rnn") as varscope: 

1390 if _should_cache(): 

1391 if varscope.caching_device is None: 

1392 varscope.set_caching_device(lambda op: op.device) 

1393 

1394 # Obtain the first sequence of the input 

1395 first_input = inputs 

1396 while nest.is_nested(first_input): 

1397 first_input = first_input[0] 

1398 

1399 # Temporarily avoid EmbeddingWrapper and seq2seq badness 

1400 # TODO(lukaszkaiser): remove EmbeddingWrapper 

1401 if first_input.get_shape().rank != 1: 

1402 

1403 input_shape = first_input.get_shape().with_rank_at_least(2) 

1404 fixed_batch_size = input_shape.dims[0] 

1405 

1406 flat_inputs = nest.flatten(inputs) 

1407 for flat_input in flat_inputs: 

1408 input_shape = flat_input.get_shape().with_rank_at_least(2) 

1409 batch_size, input_size = tensor_shape.dimension_at_index( 

1410 input_shape, 0), input_shape[1:] 

1411 fixed_batch_size.assert_is_compatible_with(batch_size) 

1412 for i, size in enumerate(input_size.dims): 

1413 if tensor_shape.dimension_value(size) is None: 

1414 raise ValueError( 

1415 f"Input size (dimension {i} of input {flat_input}) must be " 

1416 "accessible via shape inference, but saw value None.") 

1417 else: 

1418 fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0] 

1419 

1420 if tensor_shape.dimension_value(fixed_batch_size): 

1421 batch_size = tensor_shape.dimension_value(fixed_batch_size) 

1422 else: 

1423 batch_size = array_ops.shape(first_input)[0] 

1424 if initial_state is not None: 

1425 state = initial_state 

1426 else: 

1427 if not dtype: 

1428 raise ValueError("If no initial_state is provided, argument `dtype` " 

1429 "must be specified") 

1430 if getattr(cell, "get_initial_state", None) is not None: 

1431 state = cell.get_initial_state( 

1432 inputs=None, batch_size=batch_size, dtype=dtype) 

1433 else: 

1434 state = cell.zero_state(batch_size, dtype) 

1435 

1436 if sequence_length is not None: # Prepare variables 

1437 sequence_length = ops.convert_to_tensor( 

1438 sequence_length, name="sequence_length") 

1439 if sequence_length.get_shape().rank not in (None, 1): 

1440 raise ValueError( 

1441 "Argument `sequence_length` must be a vector of length " 

1442 f"{batch_size}. Received sequence_length={sequence_length}.") 

1443 

1444 def _create_zero_output(output_size): 

1445 # Build a zero-filled output of shape [batch_size] + output_size. 

1446 size = _concat(batch_size, output_size) 

1447 output = array_ops.zeros( 

1448 array_ops_stack.stack(size), _infer_state_dtype(dtype, state)) 

1449 shape = _concat( 

1450 tensor_shape.dimension_value(fixed_batch_size), 

1451 output_size, 

1452 static=True) 

1453 output.set_shape(tensor_shape.TensorShape(shape)) 

1454 return output 

1455 

1456 output_size = cell.output_size 

1457 flat_output_size = nest.flatten(output_size) 

1458 flat_zero_output = tuple( 

1459 _create_zero_output(size) for size in flat_output_size) 

1460 zero_output = nest.pack_sequence_as( 

1461 structure=output_size, flat_sequence=flat_zero_output) 

1462 

1463 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 

1464 min_sequence_length = math_ops.reduce_min(sequence_length) 

1465 max_sequence_length = math_ops.reduce_max(sequence_length) 

1466 

1467 for time, input_ in enumerate(inputs): 

1468 if time > 0: 

1469 varscope.reuse_variables() 

1470 # pylint: disable=cell-var-from-loop 

1471 call_cell = lambda: cell(input_, state) 

1472 # pylint: enable=cell-var-from-loop 

1473 if sequence_length is not None: 

1474 (output, state) = _rnn_step( 

1475 time=time, 

1476 sequence_length=sequence_length, 

1477 min_sequence_length=min_sequence_length, 

1478 max_sequence_length=max_sequence_length, 

1479 zero_output=zero_output, 

1480 state=state, 

1481 call_cell=call_cell, 

1482 state_size=cell.state_size) 

1483 else: 

1484 (output, state) = call_cell() 

1485 outputs.append(output) 

1486 

1487 return (outputs, state) 

1488 

1489 
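A minimal usage sketch for `static_rnn`, assuming TF1 graph mode via `tf.compat.v1`; the cell choice and the sizes `batch_size`, `num_steps`, `input_size`, and `num_units` are illustrative:

```python 
# Hedged usage sketch for tf.compat.v1.nn.static_rnn. 
import tensorflow.compat.v1 as tf 

tf.disable_v2_behavior() 

batch_size, num_steps, input_size, num_units = 4, 5, 3, 8 
x = tf.placeholder(tf.float32, [batch_size, num_steps, input_size]) 
# static_rnn expects a length-T list of [batch_size, input_size] Tensors. 
inputs = tf.unstack(x, num=num_steps, axis=1) 

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units) 
sequence_length = tf.placeholder(tf.int32, [batch_size]) 

# outputs: length-T list of [batch_size, num_units] Tensors; 
# state: the final LSTMStateTuple, frozen at each example's sequence length. 
outputs, state = tf.nn.static_rnn( 
    cell, inputs, dtype=tf.float32, sequence_length=sequence_length) 
```
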

1490@deprecation.deprecated(None, 

1491 "Please use `keras.layers.RNN(cell, stateful=True)`, " 

1492 "which is equivalent to this API") 

1493@tf_export(v1=["nn.static_state_saving_rnn"]) 

1494@dispatch.add_dispatch_support 

1495def static_state_saving_rnn(cell, 

1496 inputs, 

1497 state_saver, 

1498 state_name, 

1499 sequence_length=None, 

1500 scope=None): 

1501 """RNN that accepts a state saver for time-truncated RNN calculation. 

1502 

1503 Args: 

1504 cell: An instance of `RNNCell`. 

1505 inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size, 

1506 input_size]`. 

1507 state_saver: A state saver object with methods `state` and `save_state`. 

1508 state_name: Python string or tuple of strings. The name to use with the 

1509 state_saver. If the cell returns tuples of states (i.e., `cell.state_size` 

1510 is a tuple) then `state_name` should be a tuple of strings having the same 

1511 length as `cell.state_size`. Otherwise it should be a single string. 

1512 sequence_length: (optional) An int32/int64 vector of size `[batch_size]`. See 

1513 the documentation for `static_rnn()` for more details about sequence_length. 

1514 scope: VariableScope for the created subgraph; defaults to "rnn". 

1515 

1516 Returns: 

1517 A pair (outputs, state) where: 

1518 outputs is a length T list of outputs (one for each input). 

1519 state is the final state. 

1520 

1521 Raises: 

1522 TypeError: If `cell` is not an instance of RNNCell. 

1523 ValueError: If `inputs` is `None` or an empty list, or if the arity and 

1524 type of `state_name` does not match that of `cell.state_size`. 

1525 """ 

1526 state_size = cell.state_size 

1527 state_is_tuple = nest.is_nested(state_size) 

1528 state_name_tuple = nest.is_nested(state_name) 

1529 

1530 if state_is_tuple != state_name_tuple: 

1531 raise ValueError("Argument `state_name` should be the same type as " 

1532 f"`cell.state_size`. Received: state_name={state_name!s}, " 

1533 f"cell.state_size={state_size!s}.") 

1534 

1535 if state_is_tuple: 

1536 state_name_flat = nest.flatten(state_name) 

1537 state_size_flat = nest.flatten(state_size) 

1538 

1539 if len(state_name_flat) != len(state_size_flat): 

1540 raise ValueError("Number of elements in argument `state_name` and " 

1541 "`cell.state_size` are mismatched. Received " 

1542 f"state_name={state_name} with {len(state_name_flat)} " 

1543 f"elements and cell.state_size={cell.state_size} with " 

1544 f"{len(state_size_flat)} elements.") 

1545 

1546 initial_state = nest.pack_sequence_as( 

1547 structure=state_size, 

1548 flat_sequence=[state_saver.state(s) for s in state_name_flat]) 

1549 else: 

1550 initial_state = state_saver.state(state_name) 

1551 

1552 (outputs, state) = static_rnn( 

1553 cell, 

1554 inputs, 

1555 initial_state=initial_state, 

1556 sequence_length=sequence_length, 

1557 scope=scope) 

1558 

1559 if state_is_tuple: 

1560 flat_state = nest.flatten(state) 

1561 state_name = nest.flatten(state_name) 

1562 save_state = [ 

1563 state_saver.save_state(name, substate) 

1564 for name, substate in zip(state_name, flat_state) 

1565 ] 

1566 else: 

1567 save_state = [state_saver.save_state(state_name, state)] 

1568 

1569 with ops.control_dependencies(save_state): 

1570 last_output = outputs[-1] 

1571 flat_last_output = nest.flatten(last_output) 

1572 flat_last_output = [ 

1573 array_ops.identity(output) for output in flat_last_output 

1574 ] 

1575 outputs[-1] = nest.pack_sequence_as( 

1576 structure=last_output, flat_sequence=flat_last_output) 

1577 

1578 if state_is_tuple: 

1579 state = nest.pack_sequence_as( 

1580 structure=state, 

1581 flat_sequence=[array_ops.identity(s) for s in flat_state]) 

1582 else: 

1583 state = array_ops.identity(state) 

1584 

1585 return (outputs, state) 

1586 

1587 
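The `state_saver` contract above only requires `state(name)` and `save_state(name, value)`. The sketch below implements that contract with a hypothetical `SimpleStateSaver` backed by non-trainable variables; the class, its name, and all sizes are illustrative assumptions, not part of this module.

```python 
# Hypothetical in-graph state saver for tf.compat.v1.nn.static_state_saving_rnn. 
import tensorflow.compat.v1 as tf 

tf.disable_v2_behavior() 


class SimpleStateSaver(object): 
  """Keeps one non-trainable variable per state name across session runs.""" 

  def __init__(self, batch_size, state_sizes, dtype=tf.float32): 
    self._vars = { 
        name: tf.get_variable( 
            name, shape=[batch_size, size], dtype=dtype, 
            initializer=tf.zeros_initializer(), trainable=False) 
        for name, size in state_sizes.items() 
    } 

  def state(self, name): 
    return self._vars[name] 

  def save_state(self, name, value): 
    return tf.assign(self._vars[name], value) 


batch_size, num_steps, input_size, num_units = 4, 5, 3, 8 
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units) 
# The LSTM state is a (c, h) tuple, so state_name must be a matching tuple. 
saver = SimpleStateSaver(batch_size, 
                         {"lstm_c": num_units, "lstm_h": num_units}) 

x = tf.placeholder(tf.float32, [batch_size, num_steps, input_size]) 
inputs = tf.unstack(x, num=num_steps, axis=1) 
outputs, state = tf.nn.static_state_saving_rnn( 
    cell, inputs, state_saver=saver, state_name=("lstm_c", "lstm_h")) 
```

Each call then consumes `num_steps` steps, and the saved `(c, h)` variables carry the final state into the next truncated segment.
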

1588@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional(" 

1589 "keras.layers.RNN(cell, unroll=True))`, which is " 

1590 "equivalent to this API") 

1591@tf_export(v1=["nn.static_bidirectional_rnn"]) 

1592@dispatch.add_dispatch_support 

1593def static_bidirectional_rnn(cell_fw, 

1594 cell_bw, 

1595 inputs, 

1596 initial_state_fw=None, 

1597 initial_state_bw=None, 

1598 dtype=None, 

1599 sequence_length=None, 

1600 scope=None): 

1601 """Creates a bidirectional recurrent neural network. 

1602 

1603 Similar to the unidirectional case above (rnn) but takes input and builds 

1604 independent forward and backward RNNs with the final forward and backward 

1605 outputs depth-concatenated, such that the output will have the format 

1606 [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of 

1607 forward and backward cell must match. The initial state for both directions 

1608 is zero by default (but can be set optionally) and no intermediate states are 

1609 ever returned -- the network is fully unrolled for the given sequence 

1610 length(s), or completely unrolled if no lengths are given. 

1611 

1612 Args: 

1613 cell_fw: An instance of RNNCell, to be used for forward direction. 

1614 cell_bw: An instance of RNNCell, to be used for backward direction. 

1615 inputs: A length T list of inputs, each a tensor of shape [batch_size, 

1616 input_size], or a nested tuple of such elements. 

1617 initial_state_fw: (optional) An initial state for the forward RNN. This must 

1618 be a tensor of appropriate type and shape `[batch_size, 

1619 cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a 

1620 tuple of tensors having shapes `[batch_size, s] for s in 

1621 cell_fw.state_size`. 

1622 initial_state_bw: (optional) Same as for `initial_state_fw`, but using the 

1623 corresponding properties of `cell_bw`. 

1624 dtype: (optional) The data type for the initial state. Required if either 

1625 of the initial states are not provided. 

1626 sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 

1627 containing the actual lengths for each of the sequences. 

1628 scope: VariableScope for the created subgraph; defaults to 

1629 "bidirectional_rnn" 

1630 

1631 Returns: 

1632 A tuple (outputs, output_state_fw, output_state_bw) where: 

1633 outputs is a length `T` list of outputs (one for each input), which 

1634 are depth-concatenated forward and backward outputs. 

1635 output_state_fw is the final state of the forward rnn. 

1636 output_state_bw is the final state of the backward rnn. 

1637 

1638 Raises: 

1639 TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 

1640 ValueError: If inputs is None or an empty list. 

1641 """ 

1642 rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) 

1643 rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) 

1644 if not nest.is_nested(inputs): 

1645 raise TypeError(f"Argument `inputs` must be a sequence. Received: {inputs}") 

1646 if not inputs: 

1647 raise ValueError("Argument `inputs` must not be empty.") 

1648 

1649 with vs.variable_scope(scope or "bidirectional_rnn"): 

1650 # Forward direction 

1651 with vs.variable_scope("fw") as fw_scope: 

1652 output_fw, output_state_fw = static_rnn( 

1653 cell_fw, 

1654 inputs, 

1655 initial_state_fw, 

1656 dtype, 

1657 sequence_length, 

1658 scope=fw_scope) 

1659 

1660 # Backward direction 

1661 with vs.variable_scope("bw") as bw_scope: 

1662 reversed_inputs = _reverse_seq(inputs, sequence_length) 

1663 tmp, output_state_bw = static_rnn( 

1664 cell_bw, 

1665 reversed_inputs, 

1666 initial_state_bw, 

1667 dtype, 

1668 sequence_length, 

1669 scope=bw_scope) 

1670 

1671 output_bw = _reverse_seq(tmp, sequence_length) 

1672 # Concat each of the forward/backward outputs 

1673 flat_output_fw = nest.flatten(output_fw) 

1674 flat_output_bw = nest.flatten(output_bw) 

1675 

1676 flat_outputs = tuple( 

1677 array_ops.concat([fw, bw], 1) 

1678 for fw, bw in zip(flat_output_fw, flat_output_bw)) 

1679 

1680 outputs = nest.pack_sequence_as( 

1681 structure=output_fw, flat_sequence=flat_outputs) 

1682 

1683 return (outputs, output_state_fw, output_state_bw)
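
A minimal usage sketch for `static_bidirectional_rnn`, assuming TF1 graph mode via `tf.compat.v1`; the GRU cells and all sizes are illustrative:

```python 
# Hedged usage sketch for tf.compat.v1.nn.static_bidirectional_rnn. 
import tensorflow.compat.v1 as tf 

tf.disable_v2_behavior() 

batch_size, num_steps, input_size, fw_units, bw_units = 4, 6, 3, 8, 8 
x = tf.placeholder(tf.float32, [batch_size, num_steps, input_size]) 
inputs = tf.unstack(x, num=num_steps, axis=1) 

cell_fw = tf.nn.rnn_cell.GRUCell(fw_units) 
cell_bw = tf.nn.rnn_cell.GRUCell(bw_units) 
sequence_length = tf.placeholder(tf.int32, [batch_size]) 

# Each element of `outputs` has shape [batch_size, fw_units + bw_units]: 
# the forward and (re-reversed) backward outputs, concatenated on depth. 
outputs, state_fw, state_bw = tf.nn.static_bidirectional_rnn( 
    cell_fw, cell_bw, inputs, dtype=tf.float32, 
    sequence_length=sequence_length) 
```
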