# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module contains the implementation of RNN cell wrappers."""

import hashlib
import numbers
import sys
import types as python_types
import warnings

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest


class DropoutWrapperBase(object):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self,
               cell,
               input_keep_prob=1.0,
               output_keep_prob=1.0,
               state_keep_prob=1.0,
               variational_recurrent=False,
               input_size=None,
               dtype=None,
               seed=None,
               dropout_state_filter_visitor=None,
               **kwargs):
49 """Create a cell with added input, state, and/or output dropout. 

50 

51 If `variational_recurrent` is set to `True` (**NOT** the default behavior), 

52 then the same dropout mask is applied at every step, as described in: 

53 [A Theoretically Grounded Application of Dropout in Recurrent 

54 Neural Networks. Y. Gal, Z. Ghahramani](https://arxiv.org/abs/1512.05287). 

55 

56 Otherwise a different dropout mask is applied at every time step. 

57 

58 Note, by default (unless a custom `dropout_state_filter` is provided), 

59 the memory state (`c` component of any `LSTMStateTuple`) passing through 

60 a `DropoutWrapper` is never modified. This behavior is described in the 

61 above article. 

62 

63 Args: 

64 cell: an RNNCell, a projection to output_size is added to it. 

65 input_keep_prob: unit Tensor or float between 0 and 1, input keep 

66 probability; if it is constant and 1, no input dropout will be added. 

67 output_keep_prob: unit Tensor or float between 0 and 1, output keep 

68 probability; if it is constant and 1, no output dropout will be added. 

69 state_keep_prob: unit Tensor or float between 0 and 1, output keep 

70 probability; if it is constant and 1, no output dropout will be added. 

71 State dropout is performed on the outgoing states of the cell. **Note** 

72 the state components to which dropout is applied when `state_keep_prob` 

73 is in `(0, 1)` are also determined by the argument 

74 `dropout_state_filter_visitor` (e.g. by default dropout is never applied 

75 to the `c` component of an `LSTMStateTuple`). 

76 variational_recurrent: Python bool. If `True`, then the same dropout 

77 pattern is applied across all time steps per run call. If this parameter 

78 is set, `input_size` **must** be provided. 

79 input_size: (optional) (possibly nested tuple of) `TensorShape` objects 

80 containing the depth(s) of the input tensors expected to be passed in to 

81 the `DropoutWrapper`. Required and used **iff** `variational_recurrent 

82 = True` and `input_keep_prob < 1`. 

83 dtype: (optional) The `dtype` of the input, state, and output tensors. 

84 Required and used **iff** `variational_recurrent = True`. 

85 seed: (optional) integer, the randomness seed. 

86 dropout_state_filter_visitor: (optional), default: (see below). Function 

87 that takes any hierarchical level of the state and returns a scalar or 

88 depth=1 structure of Python booleans describing which terms in the state 

89 should be dropped out. In addition, if the function returns `True`, 

90 dropout is applied across this sublevel. If the function returns 

91 `False`, dropout is not applied across this entire sublevel. 

92 Default behavior: perform dropout on all terms except the memory (`c`) 

93 state of `LSTMCellState` objects, and don't try to apply dropout to 

94 `TensorArray` objects: ``` 

95 def dropout_state_filter_visitor(s): 

96 if isinstance(s, LSTMCellState): # Never perform dropout on the c 

97 state. return LSTMCellState(c=False, h=True) 

98 elif isinstance(s, TensorArray): return False return True ``` 

99 **kwargs: dict of keyword arguments for base layer. 

100 

101 Raises: 

102 TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is provided 

103 but not `callable`. 

104 ValueError: if any of the keep_probs are not between 0 and 1. 

105 """ 

    super(DropoutWrapperBase, self).__init__(cell, dtype=dtype, **kwargs)

    if (dropout_state_filter_visitor is not None and
        not callable(dropout_state_filter_visitor)):
      raise TypeError("dropout_state_filter_visitor must be callable")
    self._dropout_state_filter = (
        dropout_state_filter_visitor or _default_dropout_state_filter_visitor)
    with ops.name_scope_v2("DropoutWrapperInit"):

      def tensor_and_const_value(v):
        tensor_value = tensor_conversion.convert_to_tensor_v2_with_dispatch(v)
        const_value = tensor_util.constant_value(tensor_value)
        return (tensor_value, const_value)

      for prob, attr in [(input_keep_prob, "input_keep_prob"),
                         (state_keep_prob, "state_keep_prob"),
                         (output_keep_prob, "output_keep_prob")]:
        tensor_prob, const_prob = tensor_and_const_value(prob)
        if const_prob is not None:
          if const_prob < 0 or const_prob > 1:
            raise ValueError("Parameter %s must be between 0 and 1: %g" %
                             (attr, const_prob))
          setattr(self, "_%s" % attr, float(const_prob))
        else:
          setattr(self, "_%s" % attr, tensor_prob)

    # Set variational_recurrent, seed before running the code below
    self._variational_recurrent = variational_recurrent
    self._input_size = input_size
    self._seed = seed

    self._recurrent_input_noise = None
    self._recurrent_state_noise = None
    self._recurrent_output_noise = None

    if variational_recurrent:
      if dtype is None:
        raise ValueError(
            "When variational_recurrent=True, dtype must be provided")

      def convert_to_batch_shape(s):
        # Prepend a 1 for the batch dimension; for recurrent
        # variational dropout we use the same dropout mask for all
        # batch elements.
        return array_ops.concat(([1], tensor_shape.TensorShape(s).as_list()), 0)

      def batch_noise(s, inner_seed):
        shape = convert_to_batch_shape(s)
        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)

      if (not isinstance(self._input_keep_prob, numbers.Real) or
          self._input_keep_prob < 1.0):
        if input_size is None:
          raise ValueError(
              "When variational_recurrent=True and input_keep_prob < 1.0 or "
              "is unknown, input_size must be provided")
        self._recurrent_input_noise = _enumerated_map_structure_up_to(
            input_size,
            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
            input_size)
      self._recurrent_state_noise = _enumerated_map_structure_up_to(
          cell.state_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
          cell.state_size)
      self._recurrent_output_noise = _enumerated_map_structure_up_to(
          cell.output_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
          cell.output_size)

  def _gen_seed(self, salt_prefix, index):
    if self._seed is None:
      return None
    salt = "%s_%d" % (salt_prefix, index)
    string = (str(self._seed) + salt).encode("utf-8")
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

  @property
  def wrapped_cell(self):
    return self.cell

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def build(self, inputs_shape):
    self.cell.build(inputs_shape)
    self.built = True

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _variational_recurrent_dropout_value(
      self, unused_index, value, noise, keep_prob):
    """Performs dropout given the pre-calculated noise tensor."""
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob + noise

    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    ret = math_ops.divide(value, keep_prob) * binary_tensor
    ret.set_shape(value.get_shape())
    return ret
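  # A minimal numeric sketch (illustrative, not part of the class) of how the
  # mask above is derived, assuming a scalar value and pre-sampled uniform
  # noise in [0, 1):
  #
  #   keep_prob = 0.8
  #   noise = 0.35                       # sampled once per run call
  #   random_tensor = 0.8 + 0.35         # 1.15, lies in [keep_prob, 1.8)
  #   binary_tensor = floor(1.15)        # 1.0 -> the unit is kept
  #   ret = (value / 0.8) * 1.0          # kept units scaled up by 1/keep_prob
  #
  # Had noise been 0.1, random_tensor would be 0.9, the floor 0.0, and the
  # unit would be dropped; the keep probability is exactly keep_prob, and the
  # 1/keep_prob scaling keeps the expected activation unchanged.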

  def _dropout(self,
               values,
               salt_prefix,
               recurrent_noise,
               keep_prob,
               shallow_filtered_substructure=None):
    """Decides whether to perform standard dropout or recurrent dropout."""

    if shallow_filtered_substructure is None:
      # Put something so we traverse the entire structure; inside the
      # dropout function we check to see if leafs of this are bool or not.
      shallow_filtered_substructure = values

    if not self._variational_recurrent:

      def dropout(i, do_dropout, v):
        if not isinstance(do_dropout, bool) or do_dropout:
          return nn_ops.dropout_v2(
              v, rate=1. - keep_prob, seed=self._gen_seed(salt_prefix, i))
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values])
    else:

      def dropout(i, do_dropout, v, n):
        if not isinstance(do_dropout, bool) or do_dropout:
          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
        else:
          return v

      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values, recurrent_noise])

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the wrapped cell and applies dropout.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """

    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input", self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = cell_call_fn(inputs, state, **kwargs)
    if _should_dropout(self._state_keep_prob):
      # Identify which subsets of the state to perform dropout on and
      # which ones to keep.
      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
          self._dropout_state_filter, new_state)
      new_state = self._dropout(new_state, "state", self._recurrent_state_noise,
                                self._state_keep_prob,
                                shallow_filtered_substructure)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output", self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state

  def get_config(self):
    """Returns the config of the dropout wrapper."""
    config = {
        "input_keep_prob": self._input_keep_prob,
        "output_keep_prob": self._output_keep_prob,
        "state_keep_prob": self._state_keep_prob,
        "variational_recurrent": self._variational_recurrent,
        "input_size": self._input_size,
        "seed": self._seed,
    }
    if self._dropout_state_filter != _default_dropout_state_filter_visitor:
      function, function_type, function_module = _serialize_function_to_config(
          self._dropout_state_filter)
      config.update({"dropout_fn": function,
                     "dropout_fn_type": function_type,
                     "dropout_fn_module": function_module})
    base_config = super(DropoutWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "dropout_fn" in config:
      config = config.copy()
      dropout_state_filter = _parse_config_to_function(
          config, custom_objects, "dropout_fn", "dropout_fn_type",
          "dropout_fn_module")
      config.pop("dropout_fn")
      config["dropout_state_filter_visitor"] = dropout_state_filter
    return super(DropoutWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)
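# A minimal usage sketch (an assumption, not part of this module): concrete
# wrappers derived from DropoutWrapperBase are exposed elsewhere, e.g. as
# tf.compat.v1.nn.rnn_cell.DropoutWrapper. Variational recurrent dropout
# requires `dtype` and, when inputs are dropped, `input_size`:
#
#   import tensorflow as tf
#
#   cell = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
#   wrapped = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
#       cell,
#       input_keep_prob=0.9,
#       output_keep_prob=0.9,
#       state_keep_prob=0.9,                # `c` state left intact by default
#       variational_recurrent=True,         # one mask per run call
#       input_size=tf.TensorShape([128]),   # depth of the inputs
#       dtype=tf.float32,
#       seed=42)
#   # `inputs` has shape [batch, time, 128]; graph-mode TF1-style usage:
#   outputs, final_state = tf.compat.v1.nn.dynamic_rnn(
#       wrapped, inputs, dtype=tf.float32)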

class ResidualWrapperBase(object):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, cell, residual_fn=None, **kwargs):
    """Constructs a `ResidualWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      residual_fn: (Optional) The function to map raw cell inputs and raw cell
        outputs to the actual cell outputs of the residual network.
        Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs
        and outputs.
      **kwargs: dict of keyword arguments for base layer.
    """
    super(ResidualWrapperBase, self).__init__(cell, **kwargs)
    self._residual_fn = residual_fn

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the cell and applies the residual_fn to its inputs and outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = cell_call_fn(inputs, state, **kwargs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    def default_residual_fn(inputs, outputs):
      nest.assert_same_structure(inputs, outputs)
      nest.map_structure(assert_shape_match, inputs, outputs)
      return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)

    res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
    return (res_outputs, new_state)

  def get_config(self):
    """Returns the config of the residual wrapper."""
    if self._residual_fn is not None:
      function, function_type, function_module = _serialize_function_to_config(
          self._residual_fn)
      config = {
          "residual_fn": function,
          "residual_fn_type": function_type,
          "residual_fn_module": function_module
      }
    else:
      config = {}
    base_config = super(ResidualWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if "residual_fn" in config:
      config = config.copy()
      residual_function = _parse_config_to_function(config, custom_objects,
                                                    "residual_fn",
                                                    "residual_fn_type",
                                                    "residual_fn_module")
      config["residual_fn"] = residual_function
    return super(ResidualWrapperBase, cls).from_config(
        config, custom_objects=custom_objects)
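# A minimal usage sketch (an assumption, not part of this module): the concrete
# wrapper built on ResidualWrapperBase is exposed e.g. as
# tf.compat.v1.nn.rnn_cell.ResidualWrapper. The default residual connection
# requires the cell's input and output depths to match; a custom residual_fn
# can reconcile or combine them differently:
#
#   import tensorflow as tf
#
#   cell = tf.compat.v1.nn.rnn_cell.GRUCell(128)
#   # Default behavior: output = input + cell_output (shapes must match).
#   res_cell = tf.compat.v1.nn.rnn_cell.ResidualWrapper(cell)
#
#   # Custom residual_fn, e.g. averaging instead of adding:
#   res_cell = tf.compat.v1.nn.rnn_cell.ResidualWrapper(
#       cell, residual_fn=lambda inp, out: 0.5 * (inp + out))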

class DeviceWrapperBase(object):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, cell, device, **kwargs):
    """Construct a `DeviceWrapper` for `cell` with device `device`.

    Ensures the wrapped `cell` is called with `tf.device(device)`.

    Args:
      cell: An instance of `RNNCell`.
      device: A device string or function, for passing to `tf.device`.
      **kwargs: dict of keyword arguments for base layer.
    """
    super(DeviceWrapperBase, self).__init__(cell, **kwargs)
    self._device = device

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope_v2(type(self).__name__ + "ZeroState"):
      with ops.device(self._device):
        return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Run the cell on specified device."""
    with ops.device(self._device):
      return cell_call_fn(inputs, state, **kwargs)

  def get_config(self):
    config = {"device": self._device}
    base_config = super(DeviceWrapperBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
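# A minimal usage sketch (an assumption, not part of this module): the concrete
# wrapper built on DeviceWrapperBase is exposed e.g. as
# tf.compat.v1.nn.rnn_cell.DeviceWrapper, pinning every step of the cell to a
# single device:
#
#   import tensorflow as tf
#
#   cell = tf.compat.v1.nn.rnn_cell.LSTMCell(256)
#   gpu_cell = tf.compat.v1.nn.rnn_cell.DeviceWrapper(cell, "/GPU:0")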

def _serialize_function_to_config(function):
  """Serialize the function for get_config()."""
  if isinstance(function, python_types.LambdaType):
    output = generic_utils.func_dump(function)
    output_type = "lambda"
    module = function.__module__
  elif callable(function):
    output = function.__name__
    output_type = "function"
    module = function.__module__
  else:
    raise ValueError("Unrecognized function type for input: {}".format(
        type(function)))

  return output, output_type, module


def _parse_config_to_function(config, custom_objects, func_attr_name,
                              func_type_attr_name, module_attr_name):
  """Reconstruct the function from the config."""
  globs = globals()
  module = config.pop(module_attr_name, None)
  if module in sys.modules:
    globs.update(sys.modules[module].__dict__)
  elif module is not None:
    # Note: we don't know the name of the function if it's a lambda.
    warnings.warn("{} is not loaded, but a layer uses it. "
                  "It may cause errors.".format(module), UserWarning)
  if custom_objects:
    globs.update(custom_objects)
  function_type = config.pop(func_type_attr_name)
  if function_type == "function":
    # Simple lookup in custom objects
    function = generic_utils.deserialize_keras_object(
        config[func_attr_name],
        custom_objects=custom_objects,
        printable_module_name="function in wrapper")
  elif function_type == "lambda":
    # Unsafe deserialization from bytecode
    function = generic_utils.func_load(
        config[func_attr_name], globs=globs)
  else:
    raise TypeError("Unknown function type: {}".format(function_type))
  return function
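# A hedged note on the helpers above: callables are serialized either by name
# ("function") or as marshalled bytecode produced by generic_utils.func_dump
# ("lambda"); because types.LambdaType is the same type as types.FunctionType,
# plain `def` functions also take the bytecode path. Reviving bytecode with
# generic_utils.func_load is unsafe and tied to the Python version, so such
# configs are not portable. A hypothetical config fragment produced by
# DropoutWrapperBase.get_config() for a custom filter carries three keys:
#
#   {"dropout_fn": <name or dumped bytecode>,
#    "dropout_fn_type": "function" or "lambda",
#    "dropout_fn_module": "<module that defined the callable>"}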

def _default_dropout_state_filter_visitor(substate):
  from tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl import LSTMStateTuple  # pylint: disable=g-import-not-at-top
  if isinstance(substate, LSTMStateTuple):
    # Do not perform dropout on the memory state.
    return LSTMStateTuple(c=False, h=True)
  elif isinstance(substate, tensor_array_ops.TensorArray):
    return False
  return True
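# A minimal sketch of a custom filter (illustrative, not part of this module):
# a visitor mirrors the structure of the state and returns booleans saying
# which sub-states may be dropped. For example, to additionally exempt the
# hidden state `h` of every LSTMStateTuple from dropout:
#
#   def keep_lstm_state_intact(substate):
#     if isinstance(substate, LSTMStateTuple):
#       return LSTMStateTuple(c=False, h=False)  # never drop c or h
#     return True
#
#   # Passed as `dropout_state_filter_visitor=keep_lstm_state_intact` to the
#   # DropoutWrapper constructor.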

def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
  ix = [0]

  def enumerated_fn(*inner_args, **inner_kwargs):
    r = map_fn(ix[0], *inner_args, **inner_kwargs)
    ix[0] += 1
    return r

  return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args,
                                  **kwargs)
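# A minimal sketch of the helper above (illustrative values): it behaves like
# nest.map_structure_up_to but also passes a running leaf index to map_fn,
# which is how the wrapper derives a distinct seed per input/state/output
# component. With dict keys visited in sorted order:
#
#   structure = {"a": 1, "b": (2, 3)}
#   _enumerated_map_structure_up_to(
#       structure, lambda i, x: (i, x * 10), structure)
#   # -> {"a": (0, 10), "b": ((1, 20), (2, 30))}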