Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/bidirectional.py: 13%

255 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bidirectional wrapper for RNNs."""


import copy

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine.base_layer import Layer
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn import rnn_utils
from keras.src.layers.rnn.base_wrapper import Wrapper
from keras.src.saving import serialization_lib
from keras.src.utils import generic_utils
from keras.src.utils import tf_inspect
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.Bidirectional")
class Bidirectional(Wrapper):
    """Bidirectional wrapper for RNNs.

    Args:
      layer: `keras.layers.RNN` instance, such as `keras.layers.LSTM` or
        `keras.layers.GRU`. It could also be a `keras.layers.Layer` instance
        that meets the following criteria:
        1. Be a sequence-processing layer (accepts 3D+ inputs).
        2. Have a `go_backwards`, `return_sequences` and `return_state`
        attribute (with the same semantics as for the `RNN` class).
        3. Have an `input_spec` attribute.
        4. Implement serialization via `get_config()` and `from_config()`.
        Note that the recommended way to create new RNN layers is to write a
        custom RNN cell and use it with `keras.layers.RNN`, instead of
        subclassing `keras.layers.Layer` directly.
        - When `return_sequences` is `True`, the output of the masked
        timestep will be zero regardless of the layer's original
        `zero_output_for_mask` value.
      merge_mode: Mode by which outputs of the forward and backward RNNs will be
        combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the
        outputs will not be combined, they will be returned as a list. Default
        value is 'concat'.
      backward_layer: Optional `keras.layers.RNN`, or `keras.layers.Layer`
        instance to be used to handle backwards input processing.
        If `backward_layer` is not provided, the layer instance passed as the
        `layer` argument will be used to generate the backward layer
        automatically.
        Note that the provided `backward_layer` layer should have properties
        matching those of the `layer` argument, in particular it should have the
        same values for `stateful`, `return_state`, `return_sequences`, etc.
        In addition, `backward_layer` and `layer` should have different
        `go_backwards` argument values.
        A `ValueError` will be raised if these requirements are not met.

    Call arguments:
      The call arguments for this layer are the same as those of the wrapped RNN
      layer.
      Beware that when passing the `initial_state` argument during the call of
      this layer, the first half in the list of elements in the `initial_state`
      list will be passed to the forward RNN call and the last half in the list
      of elements will be passed to the backward RNN call.

    Raises:
      ValueError:
        1. If `layer` or `backward_layer` is not a `Layer` instance.
        2. In case of invalid `merge_mode` argument.
        3. If `backward_layer` has mismatched properties compared to `layer`.

    Examples:

    ```python
    model = Sequential()
    model.add(Bidirectional(LSTM(10, return_sequences=True),
                            input_shape=(5, 10)))
    model.add(Bidirectional(LSTM(10)))
    model.add(Dense(5))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

    # With custom backward layer
    model = Sequential()
    forward_layer = LSTM(10, return_sequences=True)
    backward_layer = LSTM(10, activation='relu', return_sequences=True,
                          go_backwards=True)
    model.add(Bidirectional(forward_layer, backward_layer=backward_layer,
                            input_shape=(5, 10)))
    model.add(Dense(5))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    ```
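
    A minimal sketch of passing `initial_state` directly (assuming
    `import tensorflow as tf`): the first half of the state list seeds the
    forward RNN and the second half seeds the backward RNN.

    ```python
    bidir = Bidirectional(LSTM(4))
    x = tf.zeros((2, 5, 10))
    # [forward_h, forward_c, backward_h, backward_c]
    states = [tf.zeros((2, 4)) for _ in range(4)]
    output = bidir(x, initial_state=states)
    ```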

    """

    def __init__(
        self,
        layer,
        merge_mode="concat",
        weights=None,
        backward_layer=None,
        **kwargs,
    ):
        if not isinstance(layer, Layer):
            raise ValueError(
                "Please initialize `Bidirectional` layer with a "
                f"`tf.keras.layers.Layer` instance. Received: {layer}"
            )
        if backward_layer is not None and not isinstance(backward_layer, Layer):
            raise ValueError(
                "`backward_layer` needs to be a `tf.keras.layers.Layer` "
                f"instance. Received: {backward_layer}"
            )
        if merge_mode not in ["sum", "mul", "ave", "concat", None]:
            raise ValueError(
                f"Invalid merge mode. Received: {merge_mode}. "
                "Merge mode should be one of "
                '{"sum", "mul", "ave", "concat", None}'
            )
        # We don't want to track `layer` since we're already tracking the two
        # copies of it we actually run.
        self._setattr_tracking = False
        super().__init__(layer, **kwargs)
        self._setattr_tracking = True

        # Recreate the forward layer from the original layer config, so that it
        # will not carry over any state from the layer.
        self.forward_layer = self._recreate_layer_from_config(layer)

        if backward_layer is None:
            self.backward_layer = self._recreate_layer_from_config(
                layer, go_backwards=True
            )
        else:
            self.backward_layer = backward_layer

            # Keep the custom backward layer config, so that we can save it
            # later. The layer's name might be updated below with prefix
            # 'backward_', and we want to preserve the original config.
            self._backward_layer_config = (
                serialization_lib.serialize_keras_object(backward_layer)
            )

        self.forward_layer._name = "forward_" + self.forward_layer.name
        self.backward_layer._name = "backward_" + self.backward_layer.name

        self._verify_layer_config()

        def force_zero_output_for_mask(layer):
            # Force the zero_output_for_mask to be True if returning sequences.
            if getattr(layer, "zero_output_for_mask", None) is not None:
                layer.zero_output_for_mask = layer.return_sequences
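        # Zeroing the output of masked timesteps keeps the forward and
        # backward outputs aligned when the backward sequence is reversed and
        # merged in `call`.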

        force_zero_output_for_mask(self.forward_layer)
        force_zero_output_for_mask(self.backward_layer)

        self.merge_mode = merge_mode
        if weights:
            nw = len(weights)
            self.forward_layer.initial_weights = weights[: nw // 2]
            self.backward_layer.initial_weights = weights[nw // 2 :]
        self.stateful = layer.stateful
        self.return_sequences = layer.return_sequences
        self.return_state = layer.return_state
        self.supports_masking = True
        self._trainable = kwargs.get("trainable", layer.trainable)
        self._num_constants = 0
        self.input_spec = layer.input_spec

    @property
    def _use_input_spec_as_call_signature(self):
        return self.layer._use_input_spec_as_call_signature

188 """Ensure the forward and backward layers have valid common property.""" 

189 if self.forward_layer.go_backwards == self.backward_layer.go_backwards: 

190 raise ValueError( 

191 "Forward layer and backward layer should have different " 

192 "`go_backwards` value." 

193 "forward_layer.go_backwards = " 

194 f"{self.forward_layer.go_backwards}," 

195 "backward_layer.go_backwards = " 

196 f"{self.backward_layer.go_backwards}" 

197 ) 

198 

199 common_attributes = ("stateful", "return_sequences", "return_state") 

200 for a in common_attributes: 

201 forward_value = getattr(self.forward_layer, a) 

202 backward_value = getattr(self.backward_layer, a) 

203 if forward_value != backward_value: 

204 raise ValueError( 

205 "Forward layer and backward layer are expected to have " 

206 f'the same value for attribute "{a}", got ' 

207 f'"{forward_value}" for forward layer and ' 

208 f'"{backward_value}" for backward layer' 

209 ) 

210 

211 def _recreate_layer_from_config(self, layer, go_backwards=False): 

212 # When recreating the layer from its config, it is possible that the 

213 # layer is a RNN layer that contains custom cells. In this case we 

214 # inspect the layer and pass the custom cell class as part of the 

215 # `custom_objects` argument when calling `from_config`. See 

216 # https://github.com/tensorflow/tensorflow/issues/26581 for more detail. 

217 config = layer.get_config() 

218 if go_backwards: 

219 config["go_backwards"] = not config["go_backwards"] 

220 if ( 

221 "custom_objects" 

222 in tf_inspect.getfullargspec(layer.__class__.from_config).args 

223 ): 

224 custom_objects = {} 

225 cell = getattr(layer, "cell", None) 

226 if cell is not None: 

227 custom_objects[cell.__class__.__name__] = cell.__class__ 

228 # For StackedRNNCells 

229 stacked_cells = getattr(cell, "cells", []) 

230 for c in stacked_cells: 

231 custom_objects[c.__class__.__name__] = c.__class__ 

232 return layer.__class__.from_config( 

233 config, custom_objects=custom_objects 

234 ) 

235 else: 

236 return layer.__class__.from_config(config) 

237 

238 @tf_utils.shape_type_conversion 

239 def compute_output_shape(self, input_shape): 

240 output_shape = self.forward_layer.compute_output_shape(input_shape) 

241 if self.return_state: 

242 state_shape = tf_utils.convert_shapes( 

243 output_shape[1:], to_tuples=False 

244 ) 

245 output_shape = tf_utils.convert_shapes( 

246 output_shape[0], to_tuples=False 

247 ) 

248 else: 

249 output_shape = tf_utils.convert_shapes( 

250 output_shape, to_tuples=False 

251 ) 

252 
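        # "concat" doubles the feature dimension because the forward and
        # backward outputs are concatenated; the other merge modes keep the
        # per-direction shape, and `None` yields one shape per direction.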

        if self.merge_mode == "concat":
            output_shape = output_shape.as_list()
            output_shape[-1] *= 2
            output_shape = tf.TensorShape(output_shape)
        elif self.merge_mode is None:
            output_shape = [output_shape, copy.copy(output_shape)]

        if self.return_state:
            if self.merge_mode is None:
                return output_shape + state_shape + copy.copy(state_shape)
            return [output_shape] + state_shape + copy.copy(state_shape)
        return output_shape

    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        """`Bidirectional.__call__` implements the same API as the wrapped
        `RNN`."""
        inputs, initial_state, constants = rnn_utils.standardize_args(
            inputs, initial_state, constants, self._num_constants
        )

        if isinstance(inputs, list):
            if len(inputs) > 1:
                initial_state = inputs[1:]
            inputs = inputs[0]

        if initial_state is None and constants is None:
            return super().__call__(inputs, **kwargs)

        # Applies the same workaround as in `RNN.__call__`
        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            # Check if `initial_state` can be split into half
            num_states = len(initial_state)
            if num_states % 2 > 0:
                raise ValueError(
                    "When passing `initial_state` to a Bidirectional RNN, "
                    "the state should be a list containing the states of "
                    "the underlying RNNs. "
                    f"Received: {initial_state}"
                )

            kwargs["initial_state"] = initial_state
            additional_inputs += initial_state
            state_specs = tf.nest.map_structure(
                lambda state: InputSpec(shape=backend.int_shape(state)),
                initial_state,
            )
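            # The first half of the user-provided states (and their specs)
            # belongs to the forward layer, the second half to the backward
            # layer.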

            self.forward_layer.state_spec = state_specs[: num_states // 2]
            self.backward_layer.state_spec = state_specs[num_states // 2 :]
            additional_specs += state_specs
        if constants is not None:
            kwargs["constants"] = constants
            additional_inputs += constants
            constants_spec = [
                InputSpec(shape=backend.int_shape(constant))
                for constant in constants
            ]
            self.forward_layer.constants_spec = constants_spec
            self.backward_layer.constants_spec = constants_spec
            additional_specs += constants_spec

            self._num_constants = len(constants)
            self.forward_layer._num_constants = self._num_constants
            self.backward_layer._num_constants = self._num_constants

        is_keras_tensor = backend.is_keras_tensor(
            tf.nest.flatten(additional_inputs)[0]
        )
        for tensor in tf.nest.flatten(additional_inputs):
            if backend.is_keras_tensor(tensor) != is_keras_tensor:
                raise ValueError(
                    "The initial state of a Bidirectional"
                    " layer cannot be specified with a mix of"
                    " Keras tensors and non-Keras tensors"
                    ' (a "Keras tensor" is a tensor that was'
                    " returned by a Keras layer, or by `Input`)"
                )

        if is_keras_tensor:
            # Compute the full input spec, including state
            full_input = [inputs] + additional_inputs
            # The original input_spec is None since there could be a nested
            # tensor input. Update the input_spec to match the inputs.
            full_input_spec = [
                None for _ in range(len(tf.nest.flatten(inputs)))
            ] + additional_specs
            # Removing kwargs since the values are passed with the input list.
            kwargs["initial_state"] = None
            kwargs["constants"] = None

            # Perform the call with temporarily replaced input_spec
            original_input_spec = self.input_spec
            self.input_spec = full_input_spec
            output = super().__call__(full_input, **kwargs)
            self.input_spec = original_input_spec
            return output
        else:
            return super().__call__(inputs, **kwargs)

    def call(
        self,
        inputs,
        training=None,
        mask=None,
        initial_state=None,
        constants=None,
    ):
        """`Bidirectional.call` implements the same API as the wrapped `RNN`."""
        kwargs = {}
        if generic_utils.has_arg(self.layer.call, "training"):
            kwargs["training"] = training
        if generic_utils.has_arg(self.layer.call, "mask"):
            kwargs["mask"] = mask
        if generic_utils.has_arg(self.layer.call, "constants"):
            kwargs["constants"] = constants

        if generic_utils.has_arg(self.layer.call, "initial_state"):
            if isinstance(inputs, list) and len(inputs) > 1:
                # initial_states are keras tensors, which means they are passed
                # in together with inputs as a list. The initial_states need to
                # be split into forward and backward sections, and be fed to
                # the layers accordingly.
                forward_inputs = [inputs[0]]
                backward_inputs = [inputs[0]]
                pivot = (len(inputs) - self._num_constants) // 2 + 1
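                # The flattened input list is laid out as
                # [inputs, forward states..., backward states..., constants...],
                # and `pivot` marks where the backward states begin.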

                # add forward initial state
                forward_inputs += inputs[1:pivot]
                if not self._num_constants:
                    # add backward initial state
                    backward_inputs += inputs[pivot:]
                else:
                    # add backward initial state
                    backward_inputs += inputs[pivot : -self._num_constants]
                    # add constants for forward and backward layers
                    forward_inputs += inputs[-self._num_constants :]
                    backward_inputs += inputs[-self._num_constants :]
                forward_state, backward_state = None, None
                if "constants" in kwargs:
                    kwargs["constants"] = None
            elif initial_state is not None:
                # initial_states are not keras tensors, e.g., eager tensors
                # from a np array. They are only passed in from the kwarg
                # initial_state, and should be passed to the forward/backward
                # layer via the kwarg initial_state as well.
                forward_inputs, backward_inputs = inputs, inputs
                half = len(initial_state) // 2
                forward_state = initial_state[:half]
                backward_state = initial_state[half:]
            else:
                forward_inputs, backward_inputs = inputs, inputs
                forward_state, backward_state = None, None

            y = self.forward_layer(
                forward_inputs, initial_state=forward_state, **kwargs
            )
            y_rev = self.backward_layer(
                backward_inputs, initial_state=backward_state, **kwargs
            )
        else:
            y = self.forward_layer(inputs, **kwargs)
            y_rev = self.backward_layer(inputs, **kwargs)

        if self.return_state:
            states = y[1:] + y_rev[1:]
            y = y[0]
            y_rev = y_rev[0]

        if self.return_sequences:
            time_dim = (
                0 if getattr(self.forward_layer, "time_major", False) else 1
            )
            y_rev = backend.reverse(y_rev, time_dim)
        if self.merge_mode == "concat":
            output = backend.concatenate([y, y_rev])
        elif self.merge_mode == "sum":
            output = y + y_rev
        elif self.merge_mode == "ave":
            output = (y + y_rev) / 2
        elif self.merge_mode == "mul":
            output = y * y_rev
        elif self.merge_mode is None:
            output = [y, y_rev]
        else:
            raise ValueError(
                "Unrecognized value for `merge_mode`. "
                f"Received: {self.merge_mode}. "
                'Expected values are ["concat", "sum", "ave", "mul"]'
            )

        if self.return_state:
            if self.merge_mode is None:
                return output + states
            return [output] + states
        return output

    def reset_states(self, states=None):
        if not self.stateful:
            raise AttributeError("Layer must be stateful.")

        if states is None:
            self.forward_layer.reset_states()
            self.backward_layer.reset_states()
        else:
            if not isinstance(states, (list, tuple)):
                raise ValueError(
                    "Unrecognized value for `states`. "
                    "Expected `states` to be list or tuple. "
                    f"Received: {states}"
                )
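            # As with `initial_state`, the first half of `states` belongs to
            # the forward layer and the second half to the backward layer.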

            half = len(states) // 2
            self.forward_layer.reset_states(states[:half])
            self.backward_layer.reset_states(states[half:])

    def build(self, input_shape):
        with backend.name_scope(self.forward_layer.name):
            self.forward_layer.build(input_shape)
        with backend.name_scope(self.backward_layer.name):
            self.backward_layer.build(input_shape)
        self.built = True

    def compute_mask(self, inputs, mask):
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_sequences:
            if not self.merge_mode:
                output_mask = [mask, mask]
            else:
                output_mask = mask
        else:
            output_mask = [None, None] if not self.merge_mode else None

        if self.return_state:
            states = self.forward_layer.states
            state_mask = [None for _ in states]
            if isinstance(output_mask, list):
                return output_mask + state_mask * 2
            return [output_mask] + state_mask * 2
        return output_mask

    @property
    def constraints(self):
        constraints = {}
        if hasattr(self.forward_layer, "constraints"):
            constraints.update(self.forward_layer.constraints)
            constraints.update(self.backward_layer.constraints)
        return constraints

    def get_config(self):
        config = {"merge_mode": self.merge_mode}
        if self._num_constants:
            config["num_constants"] = self._num_constants

        if hasattr(self, "_backward_layer_config"):
            config["backward_layer"] = self._backward_layer_config
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # Instead of updating the input, create a copy and use that.
        config = copy.deepcopy(config)
        num_constants = config.pop("num_constants", 0)
        # Handle forward layer instantiation (as would parent class).
        from keras.src.layers import deserialize as deserialize_layer

        config["layer"] = deserialize_layer(
            config["layer"], custom_objects=custom_objects
        )
        # Handle (optional) backward layer instantiation.
        backward_layer_config = config.pop("backward_layer", None)
        if backward_layer_config is not None:
            backward_layer = deserialize_layer(
                backward_layer_config, custom_objects=custom_objects
            )
            config["backward_layer"] = backward_layer
        # Instantiate the wrapper, adjust it and return it.
        layer = cls(**config)
        layer._num_constants = num_constants
        return layer